Diffstat (limited to 'continuedev/src')
67 files changed, 1218 insertions, 436 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index e1c8a076..9dbced32 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -13,7 +13,7 @@ from ..server.ide_protocol import AbstractIdeProtocolServer from ..libs.util.queue import AsyncSubscriptionQueue from ..models.main import ContinueBaseModel from .main import Context, ContinueCustomException, HighlightedRangeContext, Policy, History, FullState, Step, HistoryNode -from ..steps.core.core import ReversibleStep, ManualEditStep, UserInputStep +from ..plugins.steps.core.core import ReversibleStep, ManualEditStep, UserInputStep from ..libs.util.telemetry import capture_event from .sdk import ContinueSDK from ..libs.util.step_name_to_steps import get_step_from_name @@ -36,7 +36,11 @@ def get_error_title(e: Exception) -> str: elif isinstance(e, openai_errors.APIConnectionError): return "The request failed. Please check your internet connection and try again. If this issue persists, you can use our API key for free by going to VS Code settings and changing the value of continue.OPENAI_API_KEY to \"\"" elif isinstance(e, openai_errors.InvalidRequestError): - return 'Your API key does not have access to GPT-4. You can use ours for free by going to VS Code settings and changing the value of continue.OPENAI_API_KEY to ""' + return 'Invalid request sent to OpenAI. Please try again.' + elif "rate_limit_ip_middleware" in e.__str__(): + return 'You have reached your limit for free usage of our token. You can continue using Continue by entering your own OpenAI API key in VS Code settings.' + elif e.__str__().startswith("Cannot connect to host"): + return "The request failed. Please check your internet connection and try again." 
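The new branches added to get_error_title above fall back to string matching on the raw exception text. A minimal standalone sketch of that idea, using plain Exception objects instead of the real openai_errors classes (the helper name and sample messages below are invented for illustration):

def friendly_error_title(e: Exception) -> str:
    # Same string checks the diff adds to get_error_title, minus the
    # openai-specific isinstance branches (hypothetical standalone helper).
    text = str(e)
    if "rate_limit_ip_middleware" in text:
        return ("You have reached your limit for free usage of our token. "
                "You can continue using Continue by entering your own OpenAI API key.")
    elif text.startswith("Cannot connect to host"):
        return "The request failed. Please check your internet connection and try again."
    return text or repr(e)

print(friendly_error_title(Exception("Cannot connect to host api.openai.com:443")))
print(friendly_error_title(Exception("429 from rate_limit_ip_middleware")))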
return e.__str__() or e.__repr__() @@ -48,6 +52,8 @@ class Autopilot(ContinueBaseModel): full_state: Union[FullState, None] = None _on_update_callbacks: List[Callable[[FullState], None]] = [] + continue_sdk: ContinueSDK = None + _active: bool = False _should_halt: bool = False _main_user_input_queue: List[str] = [] @@ -55,9 +61,11 @@ class Autopilot(ContinueBaseModel): _user_input_queue = AsyncSubscriptionQueue() _retry_queue = AsyncSubscriptionQueue() - @cached_property - def continue_sdk(self) -> ContinueSDK: - return ContinueSDK(self) + @classmethod + async def create(cls, policy: Policy, ide: AbstractIdeProtocolServer, full_state: FullState) -> "Autopilot": + autopilot = cls(ide=ide, policy=policy) + autopilot.continue_sdk = await ContinueSDK.create(autopilot) + return autopilot class Config: arbitrary_types_allowed = True @@ -94,9 +102,14 @@ class Autopilot(ContinueBaseModel): self.continue_sdk.update_default_model(model) async def clear_history(self): + # Reset history self.history = History.from_empty() self._main_user_input_queue = [] self._active = False + + # Also remove all context + self._highlighted_ranges = [] + await self.update_subscribers() def on_update(self, callback: Coroutine["FullState", None, None]): @@ -160,6 +173,25 @@ class Autopilot(ContinueBaseModel): if not any(map(lambda x: x.editing, self._highlighted_ranges)): self._highlighted_ranges[0].editing = True + def _disambiguate_highlighted_ranges(self): + """If any files have the same name, also display their folder name""" + name_status: Dict[str, set] = { + } # basename -> set of full paths with that basename + for rif in self._highlighted_ranges: + basename = os.path.basename(rif.range.filepath) + if basename in name_status: + name_status[basename].add(rif.range.filepath) + else: + name_status[basename] = {rif.range.filepath} + + for rif in self._highlighted_ranges: + basename = os.path.basename(rif.range.filepath) + if len(name_status[basename]) > 1: + rif.display_name = os.path.join( + os.path.basename(os.path.dirname(rif.range.filepath)), basename) + else: + rif.display_name = basename + async def handle_highlighted_code(self, range_in_files: List[RangeInFileWithContents]): # Filter out rifs from ~/.continue/diffs folder range_in_files = [ @@ -205,6 +237,7 @@ class Autopilot(ContinueBaseModel): ) for rif in range_in_files] self._make_sure_is_editing_range() + self._disambiguate_highlighted_ranges() await self.update_subscribers() diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index f6167638..70c4876e 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -45,6 +45,11 @@ DEFAULT_SLASH_COMMANDS = [ step_name="OpenConfigStep", ), SlashCommand( + name="help", + description="Ask a question like '/help what is given to the llm as context?'", + step_name="HelpStep", + ), + SlashCommand( name="comment", description="Write comments for the current file or highlighted code", step_name="CommentCodeStep", @@ -62,16 +67,22 @@ DEFAULT_SLASH_COMMANDS = [ ] +class AzureInfo(BaseModel): + endpoint: str + engine: str + api_version: str + + class ContinueConfig(BaseModel): """ A pydantic class for the continue config file. 
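The _disambiguate_highlighted_ranges helper added to autopilot.py above only prepends the parent folder to a range's display name when two highlighted files share a basename. A standalone sketch of that naming rule (the function name and file paths are made-up examples, not taken from the diff):

import os
from typing import Dict, List, Set

def display_names(filepaths: List[str]) -> Dict[str, str]:
    # Group full paths by basename, as the autopilot does for highlighted ranges.
    by_basename: Dict[str, Set[str]] = {}
    for path in filepaths:
        by_basename.setdefault(os.path.basename(path), set()).add(path)

    names: Dict[str, str] = {}
    for path in filepaths:
        basename = os.path.basename(path)
        if len(by_basename[basename]) > 1:
            # Ambiguous basename: include the containing folder, e.g. "core/main.py"
            names[path] = os.path.join(
                os.path.basename(os.path.dirname(path)), basename)
        else:
            names[path] = basename
    return names

print(display_names(["/repo/core/main.py", "/repo/server/main.py", "/repo/config.py"]))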
""" steps_on_startup: Optional[Dict[str, Dict]] = {} disallowed_steps: Optional[List[str]] = [] - server_url: Optional[str] = None allow_anonymous_telemetry: Optional[bool] = True default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k", - "gpt-4"] = 'gpt-4' + "gpt-4", "claude-2", "ggml"] = 'gpt-4' + temperature: Optional[float] = 0.5 custom_commands: Optional[List[CustomCommand]] = [CustomCommand( name="test", description="This is an example custom command. Use /config to edit it and create more", @@ -80,12 +91,18 @@ class ContinueConfig(BaseModel): slash_commands: Optional[List[SlashCommand]] = DEFAULT_SLASH_COMMANDS on_traceback: Optional[List[OnTracebackSteps]] = [ OnTracebackSteps(step_name="DefaultOnTracebackStep")] + system_message: Optional[str] = None + azure_openai_info: Optional[AzureInfo] = None # Want to force these to be the slash commands for now @validator('slash_commands', pre=True) def default_slash_commands_validator(cls, v): return DEFAULT_SLASH_COMMANDS + @validator('temperature', pre=True) + def temperature_validator(cls, v): + return max(0.0, min(1.0, v)) + def load_config(config_file: str) -> ContinueConfig: """ diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index 88690c83..50d01f8d 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -102,6 +102,7 @@ class HistoryNode(ContinueBaseModel): depth: int deleted: bool = False active: bool = True + logs: List[str] = [] def to_chat_messages(self) -> List[ChatMessage]: if self.step.description is None or self.step.manage_own_chat_context: @@ -258,10 +259,8 @@ class Step(ContinueBaseModel): def dict(self, *args, **kwargs): d = super().dict(*args, **kwargs) - if self.description is not None: - d["description"] = self.description - else: - d["description"] = "" + # Make sure description is always a string + d["description"] = self.description or "" return d @validator("name", pre=True, always=True) diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py index b8363df2..dfa0e7f9 100644 --- a/continuedev/src/continuedev/core/policy.py +++ b/continuedev/src/continuedev/core/policy.py @@ -1,25 +1,15 @@ from textwrap import dedent -from typing import List, Tuple, Type, Union +from typing import Union -from ..steps.welcome import WelcomeStep +from ..plugins.steps.chat import SimpleChatStep +from ..plugins.steps.welcome import WelcomeStep from .config import ContinueConfig -from ..steps.chroma import AnswerQuestionChroma, EditFileChroma, CreateCodebaseIndexChroma -from ..steps.steps_on_startup import StepsOnStartupStep -from ..recipes.CreatePipelineRecipe.main import CreatePipelineRecipe -from ..recipes.DeployPipelineAirflowRecipe.main import DeployPipelineAirflowRecipe -from ..recipes.AddTransformRecipe.main import AddTransformRecipe -from .main import Step, Validator, History, Policy -from .observation import Observation, TracebackObservation, UserInputObservation -from ..steps.main import EditHighlightedCodeStep, SolveTracebackStep -from ..recipes.WritePytestsRecipe.main import WritePytestsRecipe -from ..recipes.ContinueRecipeRecipe.main import ContinueStepStep -from ..steps.comment_code import CommentCodeStep -from ..steps.react import NLDecisionStep -from ..steps.chat import SimpleChatStep, ChatWithFunctions, EditFileStep, AddFileStep -from ..recipes.DDtoBQRecipe.main import DDtoBQRecipe -from ..steps.core.core import MessageStep +from ..plugins.steps.steps_on_startup import 
StepsOnStartupStep +from .main import Step, History, Policy +from .observation import UserInputObservation +from ..plugins.steps.core.core import MessageStep from ..libs.util.step_name_to_steps import get_step_from_name -from ..steps.custom_command import CustomCommandStep +from ..plugins.steps.custom_command import CustomCommandStep def parse_slash_command(inp: str, config: ContinueConfig) -> Union[None, Step]: @@ -50,7 +40,7 @@ def parse_custom_command(inp: str, config: ContinueConfig) -> Union[None, Step]: return None -class DemoPolicy(Policy): +class DefaultPolicy(Policy): ran_code_last: bool = False def next(self, config: ContinueConfig, history: History) -> Step: @@ -58,12 +48,9 @@ class DemoPolicy(Policy): if history.get_current() is None: return ( MessageStep(name="Welcome to Continue", message=dedent("""\ - - Highlight code and ask a question or give instructions - - Use `cmd+k` (Mac) / `ctrl+k` (Windows) to open Continue - - Use `cmd+shift+e` / `ctrl+shift+e` to open file Explorer - - Add your own OpenAI API key to VS Code Settings with `cmd+,` - - Use slash commands when you want fine-grained control - - Past steps are included as part of the context by default""")) >> + - Highlight code section and ask a question or give instructions + - Use `cmd+m` (Mac) / `ctrl+m` (Windows) to open Continue + - Use `/help` to ask questions about how to use Continue""")) >> WelcomeStep() >> # SetupContinueWorkspaceStep() >> # CreateCodebaseIndexChroma() >> diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index aa2d8892..9d1025e3 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -1,9 +1,9 @@ import asyncio from functools import cached_property -from typing import Coroutine, Union +from typing import Coroutine, Dict, Union import os -from ..steps.core.core import DefaultModelEditCodeStep +from ..plugins.steps.core.core import DefaultModelEditCodeStep from ..models.main import Range from .abstract_sdk import AbstractContinueSDK from .config import ContinueConfig, load_config, load_global_config, update_global_config @@ -11,10 +11,12 @@ from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFi from ..models.filesystem import RangeInFile from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI from ..libs.llm.openai import OpenAI +from ..libs.llm.anthropic import AnthropicLLM +from ..libs.llm.ggml import GGML from .observation import Observation from ..server.ide_protocol import AbstractIdeProtocolServer -from .main import Context, ContinueCustomException, HighlightedRangeContext, History, Step, ChatMessage, ChatMessageRole -from ..steps.core.core import * +from .main import Context, ContinueCustomException, History, HistoryNode, Step, ChatMessage +from ..plugins.steps.core.core import * from ..libs.llm.proxy_server import ProxyServer @@ -22,26 +24,78 @@ class Autopilot: pass +ModelProvider = Literal["openai", "hf_inference_api", "ggml", "anthropic"] +MODEL_PROVIDER_TO_ENV_VAR = { + "openai": "OPENAI_API_KEY", + "hf_inference_api": "HUGGING_FACE_TOKEN", + "anthropic": "ANTHROPIC_API_KEY", +} + + class Models: - def __init__(self, sdk: "ContinueSDK"): + provider_keys: Dict[ModelProvider, str] = {} + model_providers: List[ModelProvider] + system_message: str + + """ + Better to have sdk.llm.stream_chat(messages, model="claude-2"). + Then you also don't care that it' async. + And it's easier to add more models. + And intermediate shared code is easier to add. 
+ And you can make constants like ContinueModels.GPT35 = "gpt-3.5-turbo" + PromptTransformer would be a good concept: You pass a prompt or list of messages and a model, then it outputs the prompt for that model. + Easy to reason about, can place anywhere. + And you can even pass a Prompt object to sdk.llm.stream_chat maybe, and it'll automatically be transformed for the given model. + This can all happen inside of Models? + + class Prompt: + def __init__(self, ...info): + '''take whatever info is needed to describe the prompt''' + + def to_string(self, model: str) -> str: + '''depending on the model, return the single prompt string''' + """ + + def __init__(self, sdk: "ContinueSDK", model_providers: List[ModelProvider]): self.sdk = sdk + self.model_providers = model_providers + self.system_message = sdk.config.system_message + + @classmethod + async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models": + if sdk.config.default_model == "claude-2": + with_providers.append("anthropic") + + models = Models(sdk, with_providers) + for provider in with_providers: + if provider in MODEL_PROVIDER_TO_ENV_VAR: + env_var = MODEL_PROVIDER_TO_ENV_VAR[provider] + models.provider_keys[provider] = await sdk.get_user_secret( + env_var, f'Please add your {env_var} to the .env file') + + return models def __load_openai_model(self, model: str) -> OpenAI: - async def load_openai_model(): - api_key = await self.sdk.get_user_secret( - 'OPENAI_API_KEY', 'Enter your OpenAI API key or press enter to try for free') - if api_key == "": - return ProxyServer(self.sdk.ide.unique_id, model) - return OpenAI(api_key=api_key, default_model=model) - return asyncio.get_event_loop().run_until_complete(load_openai_model()) + api_key = self.provider_keys["openai"] + if api_key == "": + return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message, write_log=self.sdk.write_log) + return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, azure_info=self.sdk.config.azure_openai_info, write_log=self.sdk.write_log) + + def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI: + api_key = self.provider_keys["hf_inference_api"] + return HuggingFaceInferenceAPI(api_key=api_key, model=model, system_message=self.system_message) + + def __load_anthropic_model(self, model: str) -> AnthropicLLM: + api_key = self.provider_keys["anthropic"] + return AnthropicLLM(api_key, model, self.system_message) + + @cached_property + def claude2(self): + return self.__load_anthropic_model("claude-2") @cached_property def starcoder(self): - async def load_starcoder(): - api_key = await self.sdk.get_user_secret( - 'HUGGING_FACE_TOKEN', 'Please add your Hugging Face token to the .env file') - return HuggingFaceInferenceAPI(api_key=api_key) - return asyncio.get_event_loop().run_until_complete(load_starcoder()) + return self.__load_hf_inference_api_model("bigcode/starcoder") @cached_property def gpt35(self): @@ -59,6 +113,10 @@ class Models: def gpt4(self): return self.__load_openai_model("gpt-4") + @cached_property + def ggml(self): + return GGML(system_message=self.system_message) + def __model_from_name(self, model_name: str): if model_name == "starcoder": return self.starcoder @@ -68,13 +126,17 @@ class Models: return self.gpt3516k elif model_name == "gpt-4": return self.gpt4 + elif model_name == "claude-2": + return self.claude2 + elif model_name == "ggml": + return self.ggml else: raise Exception(f"Unknown model {model_name}") @property 
def default(self): default_model = self.sdk.config.default_model - return self.__model_from_name(default_model) if default_model is not None else self.gpt35 + return self.__model_from_name(default_model) if default_model is not None else self.gpt4 class ContinueSDK(AbstractContinueSDK): @@ -87,10 +149,32 @@ class ContinueSDK(AbstractContinueSDK): def __init__(self, autopilot: Autopilot): self.ide = autopilot.ide self.__autopilot = autopilot - self.models = Models(self) self.context = autopilot.context self.config = self._load_config() + @classmethod + async def create(cls, autopilot: Autopilot) -> "ContinueSDK": + sdk = ContinueSDK(autopilot) + + try: + config = sdk._load_config() + sdk.config = config + except Exception as e: + print(e) + sdk.config = ContinueConfig() + msg_step = MessageStep( + name="Invalid Continue Config File", message=e.__repr__()) + msg_step.description = e.__repr__() + sdk.history.add_node(HistoryNode( + step=msg_step, + observation=None, + depth=0, + active=False + )) + + sdk.models = await Models.create(sdk) + return sdk + config: ContinueConfig def _load_config(self) -> ContinueConfig: @@ -108,6 +192,9 @@ class ContinueSDK(AbstractContinueSDK): def history(self) -> History: return self.__autopilot.history + def write_log(self, message: str): + self.history.timeline[self.history.current_index].logs.append(message) + async def _ensure_absolute_path(self, path: str) -> str: if os.path.isabs(path): return path @@ -215,7 +302,7 @@ class ContinueSDK(AbstractContinueSDK): for rif in highlighted_code: msg = ChatMessage(content=f"{preface} ({rif.filepath}):\n```\n{rif.contents}\n```", - role="system", summary=f"{preface}: {rif.filepath}") + role="user", summary=f"{preface}: {rif.filepath}") # Don't insert after latest user message or function call i = -1 diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py index 4c4de213..2766db4b 100644 --- a/continuedev/src/continuedev/libs/llm/__init__.py +++ b/continuedev/src/continuedev/libs/llm/__init__.py @@ -9,15 +9,15 @@ from pydantic import BaseModel class LLM(ABC): system_message: Union[str, None] = None - async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]: + async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: """Return the completion of the text with the given temperature.""" raise NotImplementedError - def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]: + def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: """Stream the completion through generator.""" raise NotImplementedError - async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]: + async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: """Stream the chat through generator.""" raise NotImplementedError diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py new file mode 100644 index 00000000..625d4e57 --- /dev/null +++ b/continuedev/src/continuedev/libs/llm/anthropic.py @@ -0,0 +1,97 @@ + +from functools import cached_property +import time +from typing import Any, Coroutine, Dict, Generator, List, Union +from ...core.main import 
ChatMessage +from anthropic import HUMAN_PROMPT, AI_PROMPT, AsyncAnthropic +from ..llm import LLM +from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top + + +class AnthropicLLM(LLM): + api_key: str + default_model: str + async_client: AsyncAnthropic + + def __init__(self, api_key: str, default_model: str, system_message: str = None): + self.api_key = api_key + self.default_model = default_model + self.system_message = system_message + + self.async_client = AsyncAnthropic(api_key=api_key) + + @cached_property + def name(self): + return self.default_model + + @property + def default_args(self): + return {**DEFAULT_ARGS, "model": self.default_model} + + def _transform_args(self, args: Dict[str, Any]) -> Dict[str, Any]: + args = args.copy() + if "max_tokens" in args: + args["max_tokens_to_sample"] = args["max_tokens"] + del args["max_tokens"] + if "frequency_penalty" in args: + del args["frequency_penalty"] + if "presence_penalty" in args: + del args["presence_penalty"] + return args + + def count_tokens(self, text: str): + return count_tokens(self.default_model, text) + + def __messages_to_prompt(self, messages: List[Dict[str, str]]) -> str: + prompt = "" + + # Anthropic prompt must start with a Human turn + if len(messages) > 0 and messages[0]["role"] != "user" and messages[0]["role"] != "system": + prompt += f"{HUMAN_PROMPT} Hello." + for msg in messages: + prompt += f"{HUMAN_PROMPT if (msg['role'] == 'user' or msg['role'] == 'system') else AI_PROMPT} {msg['content']} " + + prompt += AI_PROMPT + return prompt + + async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: + args = self.default_args.copy() + args.update(kwargs) + args["stream"] = True + args = self._transform_args(args) + + async for chunk in await self.async_client.completions.create( + prompt=f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}", + **args + ): + yield chunk.completion + + async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: + args = self.default_args.copy() + args.update(kwargs) + args["stream"] = True + args = self._transform_args(args) + + messages = compile_chat_messages( + args["model"], messages, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message) + async for chunk in await self.async_client.completions.create( + prompt=self.__messages_to_prompt(messages), + **args + ): + yield { + "role": "assistant", + "content": chunk.completion + } + + async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: + args = {**self.default_args, **kwargs} + args = self._transform_args(args) + + messages = compile_chat_messages( + args["model"], with_history, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message) + resp = (await self.async_client.completions.create( + prompt=self.__messages_to_prompt(messages), + **args + )).completion + + return resp diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py new file mode 100644 index 00000000..4889a556 --- /dev/null +++ b/continuedev/src/continuedev/libs/llm/ggml.py @@ -0,0 +1,86 @@ +from functools import cached_property +import json +from typing import Any, Coroutine, Dict, Generator, List, Union + +import aiohttp +from ...core.main import ChatMessage +from 
..llm import LLM +from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens + +SERVER_URL = "http://localhost:8000" + + +class GGML(LLM): + + def __init__(self, system_message: str = None): + self.system_message = system_message + + @cached_property + def name(self): + return "ggml" + + @property + def default_args(self): + return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024} + + def count_tokens(self, text: str): + return count_tokens(self.name, text) + + async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: + args = self.default_args.copy() + args.update(kwargs) + args["stream"] = True + + args = {**self.default_args, **kwargs} + messages = compile_chat_messages( + self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) + + async with aiohttp.ClientSession() as session: + async with session.post(f"{SERVER_URL}/v1/completions", json={ + "messages": messages, + **args + }) as resp: + async for line in resp.content.iter_any(): + if line: + try: + yield line.decode("utf-8") + except: + raise Exception(str(line)) + + async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: + args = {**self.default_args, **kwargs} + messages = compile_chat_messages( + self.name, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) + args["stream"] = True + + async with aiohttp.ClientSession() as session: + async with session.post(f"{SERVER_URL}/v1/chat/completions", json={ + "messages": messages, + **args + }) as resp: + # This is streaming application/json instaed of text/event-stream + async for line in resp.content.iter_chunks(): + if line[1]: + try: + json_chunk = line[0].decode("utf-8") + if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"): + continue + chunks = json_chunk.split("\n") + for chunk in chunks: + if chunk.strip() != "": + yield json.loads(chunk[6:])["choices"][0]["delta"] + except: + raise Exception(str(line[0])) + + async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: + args = {**self.default_args, **kwargs} + + async with aiohttp.ClientSession() as session: + async with session.post(f"{SERVER_URL}/v1/completions", json={ + "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message), + **args + }) as resp: + try: + return await resp.text() + except: + raise Exception(await resp.text()) diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py index 1586c620..36f03270 100644 --- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py +++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py @@ -9,9 +9,14 @@ DEFAULT_MAX_TIME = 120. 
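The new GGML class above talks to a locally running, OpenAI-compatible server at http://localhost:8000 with aiohttp. A minimal sketch of the simplest (non-streaming) request it makes; this assumes such a server is already listening on that port and is not part of the Continue API itself:

import asyncio
import aiohttp

SERVER_URL = "http://localhost:8000"  # same default the GGML class uses

async def complete(prompt: str) -> str:
    # One-shot request against the local /v1/completions endpoint,
    # mirroring the payload shape GGML.complete sends.
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{SERVER_URL}/v1/completions", json={
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 128,
        }) as resp:
            return await resp.text()

if __name__ == "__main__":
    print(asyncio.run(complete("Write a haiku about local models.")))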
class HuggingFaceInferenceAPI(LLM): api_key: str - model: str = "bigcode/starcoder" + model: str - def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs): + def __init__(self, api_key: str, model: str, system_message: str = None): + self.api_key = api_key + self.model = model + self.system_message = system_message # TODO: Nothing being done with this + + def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs): """Return the completion of the text with the given temperature.""" API_URL = f"https://api-inference.huggingface.co/models/{self.model}" headers = { diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index f0877d90..a0773c1d 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -1,54 +1,78 @@ from functools import cached_property -import time -from typing import Any, Coroutine, Dict, Generator, List, Union +import json +from typing import Any, Callable, Coroutine, Dict, Generator, List, Union + from ...core.main import ChatMessage import openai from ..llm import LLM -from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top +from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top +from ...core.config import AzureInfo class OpenAI(LLM): api_key: str default_model: str - def __init__(self, api_key: str, default_model: str, system_message: str = None): + def __init__(self, api_key: str, default_model: str, system_message: str = None, azure_info: AzureInfo = None, write_log: Callable[[str], None] = None): self.api_key = api_key self.default_model = default_model self.system_message = system_message + self.azure_info = azure_info + self.write_log = write_log openai.api_key = api_key + # Using an Azure OpenAI deployment + if azure_info is not None: + openai.api_type = "azure" + openai.api_base = azure_info.endpoint + openai.api_version = azure_info.api_version + @cached_property def name(self): return self.default_model @property def default_args(self): - return {**DEFAULT_ARGS, "model": self.default_model} + args = {**DEFAULT_ARGS, "model": self.default_model} + if self.azure_info is not None: + args["engine"] = self.azure_info.engine + return args def count_tokens(self, text: str): return count_tokens(self.default_model, text) - async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]: + async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: args = self.default_args.copy() args.update(kwargs) args["stream"] = True if args["model"] in CHAT_MODELS: + messages = compile_chat_messages( + args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message) + self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") + completion = "" async for chunk in await openai.ChatCompletion.acreate( - messages=compile_chat_messages( - args["model"], with_history, args["max_tokens"], prompt, functions=None), + messages=messages, **args, ): if "content" in chunk.choices[0].delta: yield chunk.choices[0].delta.content + completion += chunk.choices[0].delta.content else: continue + + self.write_log(f"Completion: \n\n{completion}") else: + 
self.write_log(f"Prompt:\n\n{prompt}") + completion = "" async for chunk in await openai.Completion.acreate(prompt=prompt, **args): yield chunk.choices[0].text + completion += chunk.choices[0].text + + self.write_log(f"Completion:\n\n{completion}") - async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]: + async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: args = self.default_args.copy() args.update(kwargs) args["stream"] = True @@ -56,27 +80,39 @@ class OpenAI(LLM): if not args["model"].endswith("0613") and "functions" in args: del args["functions"] + messages = compile_chat_messages( + args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) + self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") + completion = "" async for chunk in await openai.ChatCompletion.acreate( - messages=compile_chat_messages( - args["model"], messages, args["max_tokens"], functions=args.get("functions", None)), + messages=messages, **args, ): yield chunk.choices[0].delta + if "content" in chunk.choices[0].delta: + completion += chunk.choices[0].delta.content + self.write_log(f"Completion: \n\n{completion}") - async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]: + async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: args = {**self.default_args, **kwargs} if args["model"] in CHAT_MODELS: + messages = compile_chat_messages( + args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message) + self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") resp = (await openai.ChatCompletion.acreate( - messages=compile_chat_messages( - args["model"], with_history, args["max_tokens"], prompt, functions=None), + messages=messages, **args, )).choices[0].message.content + self.write_log(f"Completion: \n\n{resp}") else: + prompt = prune_raw_prompt_from_top( + args["model"], prompt, args["max_tokens"]) + self.write_log(f"Prompt:\n\n{prompt}") resp = (await openai.Completion.acreate( - prompt=prune_raw_prompt_from_top( - args["model"], prompt, args["max_tokens"]), + prompt=prompt, **args, )).choices[0].text + self.write_log(f"Completion:\n\n{resp}") return resp diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py index eab6e441..75c91c4e 100644 --- a/continuedev/src/continuedev/libs/llm/proxy_server.py +++ b/continuedev/src/continuedev/libs/llm/proxy_server.py @@ -1,10 +1,12 @@ -from functools import cached_property + import json -from typing import Any, Coroutine, Dict, Generator, List, Literal, Union +import traceback +from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union import aiohttp +from ..util.telemetry import capture_event from ...core.main import ChatMessage from ..llm import LLM -from ..util.count_tokens import DEFAULT_ARGS, DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, count_tokens +from ..util.count_tokens import DEFAULT_ARGS, DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, count_tokens, format_chat_messages import certifi import ssl @@ -19,12 +21,14 @@ class ProxyServer(LLM): unique_id: str name: str default_model: Literal["gpt-3.5-turbo", "gpt-4"] + write_log: Callable[[str], None] - def __init__(self, unique_id: 
str, default_model: Literal["gpt-3.5-turbo", "gpt-4"], system_message: str = None): + def __init__(self, unique_id: str, default_model: Literal["gpt-3.5-turbo", "gpt-4"], system_message: str = None, write_log: Callable[[str], None] = None): self.unique_id = unique_id self.default_model = default_model self.system_message = system_message self.name = default_model + self.write_log = write_log @property def default_args(self): @@ -32,33 +36,44 @@ class ProxyServer(LLM): def count_tokens(self, text: str): return count_tokens(self.default_model, text) + + def get_headers(self): + # headers with unique id + return {"unique_id": self.unique_id} - async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]: + async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: args = {**self.default_args, **kwargs} + messages = compile_chat_messages( + args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message) + self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session: async with session.post(f"{SERVER_URL}/complete", json={ - "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None), - "unique_id": self.unique_id, + "messages": messages, **args - }) as resp: - try: - return await resp.text() - except: + }, headers=self.get_headers()) as resp: + if resp.status != 200: raise Exception(await resp.text()) - async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]: + response_text = await resp.text() + self.write_log(f"Completion: \n\n{response_text}") + return response_text + + async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]: args = {**self.default_args, **kwargs} messages = compile_chat_messages( - self.default_model, messages, args["max_tokens"], None, functions=args.get("functions", None)) + args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) + self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session: async with session.post(f"{SERVER_URL}/stream_chat", json={ "messages": messages, - "unique_id": self.unique_id, **args - }) as resp: + }, headers=self.get_headers()) as resp: # This is streaming application/json instaed of text/event-stream + completion = "" + if resp.status != 200: + raise Exception(await resp.text()) async for line in resp.content.iter_chunks(): if line[1]: try: @@ -67,24 +82,38 @@ class ProxyServer(LLM): chunks = json_chunk.split("\n") for chunk in chunks: if chunk.strip() != "": - yield json.loads(chunk) - except: - raise Exception(str(line[0])) + loaded_chunk = json.loads(chunk) + yield loaded_chunk + if "content" in loaded_chunk: + completion += loaded_chunk["content"] + except Exception as e: + capture_event(self.unique_id, "proxy_server_parse_error", { + "error_title": "Proxy server stream_chat parsing failed", "error_message": '\n'.join(traceback.format_exception(e))}) + else: + break - async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], 
None, None]: + self.write_log(f"Completion: \n\n{completion}") + + async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: args = {**self.default_args, **kwargs} messages = compile_chat_messages( - self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None)) + self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) + self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session: async with session.post(f"{SERVER_URL}/stream_complete", json={ "messages": messages, - "unique_id": self.unique_id, **args - }) as resp: + }, headers=self.get_headers()) as resp: + completion = "" + if resp.status != 200: + raise Exception(await resp.text()) async for line in resp.content.iter_any(): if line: try: - yield line.decode("utf-8") + decoded_line = line.decode("utf-8") + yield decoded_line + completion += decoded_line except: raise Exception(str(line)) + self.write_log(f"Completion: \n\n{completion}") diff --git a/continuedev/src/continuedev/libs/util/commonregex.py b/continuedev/src/continuedev/libs/util/commonregex.py new file mode 100644 index 00000000..55da7fc0 --- /dev/null +++ b/continuedev/src/continuedev/libs/util/commonregex.py @@ -0,0 +1,138 @@ +# coding: utf-8 +import json +import re +from typing import Any, Dict + +date = re.compile( + '(?:(?<!\:)(?<!\:\d)[0-3]?\d(?:st|nd|rd|th)?\s+(?:of\s+)?(?:jan\.?|january|feb\.?|february|mar\.?|march|apr\.?|april|may|jun\.?|june|jul\.?|july|aug\.?|august|sep\.?|september|oct\.?|october|nov\.?|november|dec\.?|december)|(?:jan\.?|january|feb\.?|february|mar\.?|march|apr\.?|april|may|jun\.?|june|jul\.?|july|aug\.?|august|sep\.?|september|oct\.?|october|nov\.?|november|dec\.?|december)\s+(?<!\:)(?<!\:\d)[0-3]?\d(?:st|nd|rd|th)?)(?:\,)?\s*(?:\d{4})?|[0-3]?\d[-\./][0-3]?\d[-\./]\d{2,4}', re.IGNORECASE) +time = re.compile( + '\d{1,2}:\d{2} ?(?:[ap]\.?m\.?)?|\d[ap]\.?m\.?', re.IGNORECASE) +phone = re.compile( + '''((?:(?<![\d-])(?:\+?\d{1,3}[-.\s*]?)?(?:\(?\d{3}\)?[-.\s*]?)?\d{3}[-.\s*]?\d{4}(?![\d-]))|(?:(?<![\d-])(?:(?:\(\+?\d{2}\))|(?:\+?\d{2}))\s*\d{2}\s*\d{3}\s*\d{4}(?![\d-])))''') +phones_with_exts = re.compile( + '((?:(?:\+?1\s*(?:[.-]\s*)?)?(?:\(\s*(?:[2-9]1[02-9]|[2-9][02-8]1|[2-9][02-8][02-9])\s*\)|(?:[2-9]1[02-9]|[2-9][02-8]1|[2-9][02-8][02-9]))\s*(?:[.-]\s*)?)?(?:[2-9]1[02-9]|[2-9][02-9]1|[2-9][02-9]{2})\s*(?:[.-]\s*)?(?:[0-9]{4})(?:\s*(?:#|x\.?|ext\.?|extension)\s*(?:\d+)?))', re.IGNORECASE) +link = 
re.compile('(?i)((?:https?://|www\d{0,3}[.])?[a-z0-9.\-]+[.](?:(?:international)|(?:construction)|(?:contractors)|(?:enterprises)|(?:photography)|(?:immobilien)|(?:management)|(?:technology)|(?:directory)|(?:education)|(?:equipment)|(?:institute)|(?:marketing)|(?:solutions)|(?:builders)|(?:clothing)|(?:computer)|(?:democrat)|(?:diamonds)|(?:graphics)|(?:holdings)|(?:lighting)|(?:plumbing)|(?:training)|(?:ventures)|(?:academy)|(?:careers)|(?:company)|(?:domains)|(?:florist)|(?:gallery)|(?:guitars)|(?:holiday)|(?:kitchen)|(?:recipes)|(?:shiksha)|(?:singles)|(?:support)|(?:systems)|(?:agency)|(?:berlin)|(?:camera)|(?:center)|(?:coffee)|(?:estate)|(?:kaufen)|(?:luxury)|(?:monash)|(?:museum)|(?:photos)|(?:repair)|(?:social)|(?:tattoo)|(?:travel)|(?:viajes)|(?:voyage)|(?:build)|(?:cheap)|(?:codes)|(?:dance)|(?:email)|(?:glass)|(?:house)|(?:ninja)|(?:photo)|(?:shoes)|(?:solar)|(?:today)|(?:aero)|(?:arpa)|(?:asia)|(?:bike)|(?:buzz)|(?:camp)|(?:club)|(?:coop)|(?:farm)|(?:gift)|(?:guru)|(?:info)|(?:jobs)|(?:kiwi)|(?:land)|(?:limo)|(?:link)|(?:menu)|(?:mobi)|(?:moda)|(?:name)|(?:pics)|(?:pink)|(?:post)|(?:rich)|(?:ruhr)|(?:sexy)|(?:tips)|(?:wang)|(?:wien)|(?:zone)|(?:biz)|(?:cab)|(?:cat)|(?:ceo)|(?:com)|(?:edu)|(?:gov)|(?:int)|(?:mil)|(?:net)|(?:onl)|(?:org)|(?:pro)|(?:red)|(?:tel)|(?:uno)|(?:xxx)|(?:ac)|(?:ad)|(?:ae)|(?:af)|(?:ag)|(?:ai)|(?:al)|(?:am)|(?:an)|(?:ao)|(?:aq)|(?:ar)|(?:as)|(?:at)|(?:au)|(?:aw)|(?:ax)|(?:az)|(?:ba)|(?:bb)|(?:bd)|(?:be)|(?:bf)|(?:bg)|(?:bh)|(?:bi)|(?:bj)|(?:bm)|(?:bn)|(?:bo)|(?:br)|(?:bs)|(?:bt)|(?:bv)|(?:bw)|(?:by)|(?:bz)|(?:ca)|(?:cc)|(?:cd)|(?:cf)|(?:cg)|(?:ch)|(?:ci)|(?:ck)|(?:cl)|(?:cm)|(?:cn)|(?:co)|(?:cr)|(?:cu)|(?:cv)|(?:cw)|(?:cx)|(?:cy)|(?:cz)|(?:de)|(?:dj)|(?:dk)|(?:dm)|(?:do)|(?:dz)|(?:ec)|(?:ee)|(?:eg)|(?:er)|(?:es)|(?:et)|(?:eu)|(?:fi)|(?:fj)|(?:fk)|(?:fm)|(?:fo)|(?:fr)|(?:ga)|(?:gb)|(?:gd)|(?:ge)|(?:gf)|(?:gg)|(?:gh)|(?:gi)|(?:gl)|(?:gm)|(?:gn)|(?:gp)|(?:gq)|(?:gr)|(?:gs)|(?:gt)|(?:gu)|(?:gw)|(?:gy)|(?:hk)|(?:hm)|(?:hn)|(?:hr)|(?:ht)|(?:hu)|(?:id)|(?:ie)|(?:il)|(?:im)|(?:in)|(?:io)|(?:iq)|(?:ir)|(?:is)|(?:it)|(?:je)|(?:jm)|(?:jo)|(?:jp)|(?:ke)|(?:kg)|(?:kh)|(?:ki)|(?:km)|(?:kn)|(?:kp)|(?:kr)|(?:kw)|(?:ky)|(?:kz)|(?:la)|(?:lb)|(?:lc)|(?:li)|(?:lk)|(?:lr)|(?:ls)|(?:lt)|(?:lu)|(?:lv)|(?:ly)|(?:ma)|(?:mc)|(?:md)|(?:me)|(?:mg)|(?:mh)|(?:mk)|(?:ml)|(?:mm)|(?:mn)|(?:mo)|(?:mp)|(?:mq)|(?:mr)|(?:ms)|(?:mt)|(?:mu)|(?:mv)|(?:mw)|(?:mx)|(?:my)|(?:mz)|(?:na)|(?:nc)|(?:ne)|(?:nf)|(?:ng)|(?:ni)|(?:nl)|(?:no)|(?:np)|(?:nr)|(?:nu)|(?:nz)|(?:om)|(?:pa)|(?:pe)|(?:pf)|(?:pg)|(?:ph)|(?:pk)|(?:pl)|(?:pm)|(?:pn)|(?:pr)|(?:ps)|(?:pt)|(?:pw)|(?:py)|(?:qa)|(?:re)|(?:ro)|(?:rs)|(?:ru)|(?:rw)|(?:sa)|(?:sb)|(?:sc)|(?:sd)|(?:se)|(?:sg)|(?:sh)|(?:si)|(?:sj)|(?:sk)|(?:sl)|(?:sm)|(?:sn)|(?:so)|(?:sr)|(?:st)|(?:su)|(?:sv)|(?:sx)|(?:sy)|(?:sz)|(?:tc)|(?:td)|(?:tf)|(?:tg)|(?:th)|(?:tj)|(?:tk)|(?:tl)|(?:tm)|(?:tn)|(?:to)|(?:tp)|(?:tr)|(?:tt)|(?:tv)|(?:tw)|(?:tz)|(?:ua)|(?:ug)|(?:uk)|(?:us)|(?:uy)|(?:uz)|(?:va)|(?:vc)|(?:ve)|(?:vg)|(?:vi)|(?:vn)|(?:vu)|(?:wf)|(?:ws)|(?:ye)|(?:yt)|(?:za)|(?:zm)|(?:zw))(?:/[^\s()<>]+[^\s`!()\[\]{};:\'".,<>?\xab\xbb\u201c\u201d\u2018\u2019])?)', re.IGNORECASE) +email = re.compile( + "([a-z0-9!#$%&'*+\/=?^_`{|.}~-]+@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?)", re.IGNORECASE) +ip = re.compile('(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)', re.IGNORECASE) +ipv6 = re.compile( + 
'\s*(?!.*::.*::)(?:(?!:)|:(?=:))(?:[0-9a-f]{0,4}(?:(?<=::)|(?<!::):)){6}(?:[0-9a-f]{0,4}(?:(?<=::)|(?<!::):)[0-9a-f]{0,4}(?:(?<=::)|(?<!:)|(?<=:)(?<!::):)|(?:25[0-4]|2[0-4]\d|1\d\d|[1-9]?\d)(?:\.(?:25[0-4]|2[0-4]\d|1\d\d|[1-9]?\d)){3})\s*', re.VERBOSE | re.IGNORECASE | re.DOTALL) +price = re.compile('[$]\s?[+-]?[0-9]{1,3}(?:(?:,?[0-9]{3}))*(?:\.[0-9]{1,2})?') +hex_color = re.compile('(#(?:[0-9a-fA-F]{8})|#(?:[0-9a-fA-F]{3}){1,2})\\b') +credit_card = re.compile('((?:(?:\\d{4}[- ]?){3}\\d{4}|\\d{15,16}))(?![\\d])') +btc_address = re.compile( + '(?<![a-km-zA-HJ-NP-Z0-9])[13][a-km-zA-HJ-NP-Z0-9]{26,33}(?![a-km-zA-HJ-NP-Z0-9])') +street_address = re.compile( + '\d{1,4} [\w\s]{1,20}(?:street|st|avenue|ave|road|rd|highway|hwy|square|sq|trail|trl|drive|dr|court|ct|park|parkway|pkwy|circle|cir|boulevard|blvd)\W?(?=\s|$)', re.IGNORECASE) +zip_code = re.compile(r'\b\d{5}(?:[-\s]\d{4})?\b') +po_box = re.compile(r'P\.? ?O\.? Box \d+', re.IGNORECASE) +ssn = re.compile( + '(?!000|666|333)0*(?:[0-6][0-9][0-9]|[0-7][0-6][0-9]|[0-7][0-7][0-2])[- ](?!00)[0-9]{2}[- ](?!0000)[0-9]{4}') +win_absolute_filepath = re.compile( + r'^(?:[a-zA-Z]\:|\\\\[\w\.]+\\[\w.$]+)\\(?:[\w]+\\)*\w([\w.])+', re.IGNORECASE) +unix_absolute_filepath = re.compile( + r'^\/(?:[\/\w]+\/)*\w([\w.])+', re.IGNORECASE) + +regexes = { + "win_absolute_filepath": win_absolute_filepath, + "unix_absolute_filepath": unix_absolute_filepath, + "dates": date, + "times": time, + "phones": phone, + "phones_with_exts": phones_with_exts, + "links": link, + "emails": email, + "ips": ip, + "ipv6s": ipv6, + "prices": price, + "hex_colors": hex_color, + "credit_cards": credit_card, + "btc_addresses": btc_address, + "street_addresses": street_address, + "zip_codes": zip_code, + "po_boxes": po_box, + "ssn_number": ssn, +} + +placeholders = { + "win_absolute_filepath": "<FILEPATH>", + "unix_absolute_filepath": "<FILEPATH>", + "dates": "<DATE>", + "times": "<TIME>", + "phones": "<PHONE>", + "phones_with_exts": "<PHONE_WITH_EXT>", + "links": "<LINK>", + "emails": "<EMAIL>", + "ips": "<IP>", + "ipv6s": "<IPV6>", + "prices": "<PRICE>", + "hex_colors": "<HEX_COLOR>", + "credit_cards": "<CREDIT_CARD>", + "btc_addresses": "<BTC_ADDRESS>", + "street_addresses": "<STREET_ADDRESS>", + "zip_codes": "<ZIP_CODE>", + "po_boxes": "<PO_BOX>", + "ssn_number": "<SSN>", +} + + +class regex: + + def __init__(self, obj, regex): + self.obj = obj + self.regex = regex + + def __call__(self, *args): + def regex_method(text=None): + return [x.strip() for x in self.regex.findall(text or self.obj.text)] + return regex_method + + +class CommonRegex(object): + + def __init__(self, text=""): + self.text = text + + for k, v in list(regexes.items()): + setattr(self, k, regex(self, v)(self)) + + if text: + for key in list(regexes.keys()): + method = getattr(self, key) + setattr(self, key, method()) + + +pii_parser = CommonRegex() + + +def clean_pii_from_str(text: str): + """Replace personally identifiable information (PII) with placeholders.""" + for regex_name, regex in list(regexes.items()): + placeholder = placeholders[regex_name] + text = regex.sub(placeholder, text) + + return text + + +def clean_pii_from_any(v: Any) -> Any: + """Replace personally identifiable information (PII) with placeholders. 
Not guaranteed to return same type as input.""" + if isinstance(v, str): + return clean_pii_from_str(v) + elif isinstance(v, dict): + cleaned_dict = {} + for key, value in v.items(): + cleaned_dict[key] = clean_pii_from_any(value) + return cleaned_dict + elif isinstance(v, list): + return [clean_pii_from_any(x) for x in v] + else: + # Try to convert to string + try: + orig_text = str(v) + cleaned_text = clean_pii_from_str(orig_text) + if orig_text != cleaned_text: + return cleaned_text + else: + return v + except: + return v diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py index 73be0717..c58ae499 100644 --- a/continuedev/src/continuedev/libs/util/count_tokens.py +++ b/continuedev/src/continuedev/libs/util/count_tokens.py @@ -1,15 +1,21 @@ import json from typing import Dict, List, Union from ...core.main import ChatMessage +from .templating import render_templated_string import tiktoken -aliases = {} +aliases = { + "ggml": "gpt-3.5-turbo", + "claude-2": "gpt-3.5-turbo", +} DEFAULT_MAX_TOKENS = 2048 MAX_TOKENS_FOR_MODEL = { "gpt-3.5-turbo": 4096, "gpt-3.5-turbo-0613": 4096, "gpt-3.5-turbo-16k": 16384, - "gpt-4": 8192 + "gpt-4": 8192, + "ggml": 2048, + "claude-2": 100000 } CHAT_MODELS = { "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613" @@ -40,9 +46,17 @@ def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: in return encoding.decode(tokens[-max_tokens:]) +def count_chat_message_tokens(model: str, chat_message: ChatMessage) -> int: + # Doing simpler, safer version of what is here: + # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb + # every message follows <|start|>{role/name}\n{content}<|end|>\n + TOKENS_PER_MESSAGE = 4 + return count_tokens(model, chat_message.content) + TOKENS_PER_MESSAGE + + def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int): total_tokens = tokens_for_completion + \ - sum(count_tokens(model, message.content) + sum(count_chat_message_tokens(model, message) for message in chat_history) # 1. Replace beyond last 5 messages with summary @@ -59,43 +73,76 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens: message = chat_history.pop(0) total_tokens -= count_tokens(model, message.content) - # 3. Truncate message in the last 5 + # 3. Truncate message in the last 5, except last 1 i = 0 - while total_tokens > max_tokens and len(chat_history) > 0 and i < len(chat_history): + while total_tokens > max_tokens and len(chat_history) > 0 and i < len(chat_history) - 1: message = chat_history[i] total_tokens -= count_tokens(model, message.content) total_tokens += count_tokens(model, message.summary) message.content = message.summary i += 1 - # 4. Remove entire messages in the last 5 - while total_tokens > max_tokens and len(chat_history) > 0: + # 4. Remove entire messages in the last 5, except last 1 + while total_tokens > max_tokens and len(chat_history) > 1: message = chat_history.pop(0) total_tokens -= count_tokens(model, message.content) + # 5. 
Truncate last message + if total_tokens > max_tokens and len(chat_history) > 0: + message = chat_history[0] + message.content = prune_raw_prompt_from_top( + model, message.content, tokens_for_completion) + total_tokens = max_tokens + return chat_history -def compile_chat_messages(model: str, msgs: List[ChatMessage], max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]: - prompt_tokens = count_tokens(model, prompt) +# In case we've missed weird edge cases +TOKEN_BUFFER_FOR_SAFETY = 100 + + +def compile_chat_messages(model: str, msgs: Union[List[ChatMessage], None], max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]: + """ + The total number of tokens is system_message + sum(msgs) + functions + prompt after it is converted to a message + """ + msgs_copy = [msg.copy(deep=True) + for msg in msgs] if msgs is not None else [] + + if prompt is not None: + prompt_msg = ChatMessage(role="user", content=prompt, summary=prompt) + msgs_copy += [prompt_msg] + + if system_message is not None: + # NOTE: System message takes second precedence to user prompt, so it is placed just before + # but move back to start after processing + rendered_system_message = render_templated_string(system_message) + system_chat_msg = ChatMessage( + role="system", content=rendered_system_message, summary=rendered_system_message) + # insert at second-to-last position + msgs_copy.insert(-1, system_chat_msg) + + # Add tokens from functions + function_tokens = 0 if functions is not None: for function in functions: - prompt_tokens += count_tokens(model, json.dumps(function)) - - msgs = prune_chat_history(model, - msgs, MAX_TOKENS_FOR_MODEL[model], prompt_tokens + max_tokens + count_tokens(model, system_message)) - history = [] - if system_message: - history.append({ - "role": "system", - "content": system_message - }) - history += [msg.to_dict(with_functions=functions is not None) - for msg in msgs] - if prompt: - history.append({ - "role": "user", - "content": prompt - }) + function_tokens += count_tokens(model, json.dumps(function)) + + msgs_copy = prune_chat_history( + model, msgs_copy, MAX_TOKENS_FOR_MODEL[model], function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY) + + history = [msg.to_dict(with_functions=functions is not None) + for msg in msgs_copy] + + # Move system message back to start + if system_message is not None and len(history) >= 2 and history[-2]["role"] == "system": + system_message_dict = history.pop(-2) + history.insert(0, system_message_dict) return history + + +def format_chat_messages(messages: List[ChatMessage]) -> str: + formatted = "" + for msg in messages: + formatted += f"<{msg['role'].capitalize()}>\n{msg['content']}\n\n" + return formatted diff --git a/continuedev/src/continuedev/libs/util/step_name_to_steps.py b/continuedev/src/continuedev/libs/util/step_name_to_steps.py index d329e110..baa25da6 100644 --- a/continuedev/src/continuedev/libs/util/step_name_to_steps.py +++ b/continuedev/src/continuedev/libs/util/step_name_to_steps.py @@ -1,18 +1,19 @@ from typing import Dict from ...core.main import Step -from ...steps.core.core import UserInputStep -from ...steps.main import EditHighlightedCodeStep -from ...steps.chat import SimpleChatStep -from ...steps.comment_code import CommentCodeStep -from ...steps.feedback import FeedbackStep -from ...recipes.AddTransformRecipe.main import AddTransformRecipe -from 
...recipes.CreatePipelineRecipe.main import CreatePipelineRecipe -from ...recipes.DDtoBQRecipe.main import DDtoBQRecipe -from ...recipes.DeployPipelineAirflowRecipe.main import DeployPipelineAirflowRecipe -from ...steps.on_traceback import DefaultOnTracebackStep -from ...steps.clear_history import ClearHistoryStep -from ...steps.open_config import OpenConfigStep +from ...plugins.steps.core.core import UserInputStep +from ...plugins.steps.main import EditHighlightedCodeStep +from ...plugins.steps.chat import SimpleChatStep +from ...plugins.steps.comment_code import CommentCodeStep +from ...plugins.steps.feedback import FeedbackStep +from ...plugins.recipes.AddTransformRecipe.main import AddTransformRecipe +from ...plugins.recipes.CreatePipelineRecipe.main import CreatePipelineRecipe +from ...plugins.recipes.DDtoBQRecipe.main import DDtoBQRecipe +from ...plugins.recipes.DeployPipelineAirflowRecipe.main import DeployPipelineAirflowRecipe +from ...plugins.steps.on_traceback import DefaultOnTracebackStep +from ...plugins.steps.clear_history import ClearHistoryStep +from ...plugins.steps.open_config import OpenConfigStep +from ...plugins.steps.help import HelpStep # This mapping is used to convert from string in ContinueConfig json to corresponding Step class. # Used for example in slash_commands and steps_on_startup @@ -28,7 +29,8 @@ step_name_to_step_class = { "DeployPipelineAirflowRecipe": DeployPipelineAirflowRecipe, "DefaultOnTracebackStep": DefaultOnTracebackStep, "ClearHistoryStep": ClearHistoryStep, - "OpenConfigStep": OpenConfigStep + "OpenConfigStep": OpenConfigStep, + "HelpStep": HelpStep, } diff --git a/continuedev/src/continuedev/libs/util/dedent.py b/continuedev/src/continuedev/libs/util/strings.py index e59c2e97..f1fb8d0b 100644 --- a/continuedev/src/continuedev/libs/util/dedent.py +++ b/continuedev/src/continuedev/libs/util/strings.py @@ -23,3 +23,27 @@ def dedent_and_get_common_whitespace(s: str) -> Tuple[str, str]: break return "\n".join(map(lambda x: x.lstrip(lcp), lines)), lcp + + +def remove_quotes_and_escapes(output: str) -> str: + """ + Clean up the output of the completion API, removing unnecessary escapes and quotes + """ + output = output.strip() + + # Replace smart quotes + output = output.replace("“", '"') + output = output.replace("”", '"') + output = output.replace("‘", "'") + output = output.replace("’", "'") + + # Remove escapes + output = output.replace('\\"', '"') + output = output.replace("\\'", "'") + output = output.replace("\\n", "\n") + output = output.replace("\\t", "\t") + output = output.replace("\\\\", "\\") + if (output.startswith('"') and output.endswith('"')) or (output.startswith("'") and output.endswith("'")): + output = output[1:-1] + + return output diff --git a/continuedev/src/continuedev/libs/util/telemetry.py b/continuedev/src/continuedev/libs/util/telemetry.py index bd9fde9d..17735dce 100644 --- a/continuedev/src/continuedev/libs/util/telemetry.py +++ b/continuedev/src/continuedev/libs/util/telemetry.py @@ -3,6 +3,7 @@ from posthog import Posthog from ...core.config import load_config import os from dotenv import load_dotenv +from .commonregex import clean_pii_from_any load_dotenv() in_codespaces = os.getenv("CODESPACES") == "true" @@ -13,10 +14,14 @@ posthog = Posthog('phc_JS6XFROuNbhJtVCEdTSYk6gl5ArRrTNMpCcguAXlSPs', def capture_event(unique_id: str, event_name: str, event_properties: Any): + # Return early if telemetry is disabled config = load_config('.continue/config.json') if not config.allow_anonymous_telemetry: return if 
in_codespaces: event_properties['codespaces'] = True - posthog.capture(unique_id, event_name, event_properties) + + # Send event to PostHog + posthog.capture(unique_id, event_name, + clean_pii_from_any(event_properties)) diff --git a/continuedev/src/continuedev/libs/util/templating.py b/continuedev/src/continuedev/libs/util/templating.py new file mode 100644 index 00000000..bb922ad7 --- /dev/null +++ b/continuedev/src/continuedev/libs/util/templating.py @@ -0,0 +1,39 @@ +import os +import chevron + + +def get_vars_in_template(template): + """ + Get the variables in a template + """ + return [token[1] for token in chevron.tokenizer.tokenize(template) if token[0] == 'variable'] + + +def escape_var(var: str) -> str: + """ + Escape a variable so it can be used in a template + """ + return var.replace(os.path.sep, '').replace('.', '') + + +def render_templated_string(template: str) -> str: + """ + Render system message or other templated string with mustache syntax. + Right now it only supports rendering absolute file paths as their contents. + """ + vars = get_vars_in_template(template) + + args = {} + for var in vars: + if var.startswith(os.path.sep): + # Escape vars which are filenames, because mustache doesn't allow / in variable names + escaped_var = escape_var(var) + template = template.replace( + var, escaped_var) + + if os.path.exists(var): + args[escaped_var] = open(var, 'r').read() + else: + args[escaped_var] = '' + + return chevron.render(template, args) diff --git a/continuedev/src/continuedev/recipes/AddTransformRecipe/README.md b/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/README.md index d735e0cd..d735e0cd 100644 --- a/continuedev/src/continuedev/recipes/AddTransformRecipe/README.md +++ b/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/README.md diff --git a/continuedev/src/continuedev/recipes/AddTransformRecipe/dlt_transform_docs.md b/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/dlt_transform_docs.md index 658b285f..658b285f 100644 --- a/continuedev/src/continuedev/recipes/AddTransformRecipe/dlt_transform_docs.md +++ b/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/dlt_transform_docs.md diff --git a/continuedev/src/continuedev/recipes/AddTransformRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/main.py index fdd343f5..5d242f7c 100644 --- a/continuedev/src/continuedev/recipes/AddTransformRecipe/main.py +++ b/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/main.py @@ -1,9 +1,9 @@ from textwrap import dedent -from ...core.main import Step -from ...core.sdk import ContinueSDK -from ...steps.core.core import WaitForUserInputStep -from ...steps.core.core import MessageStep +from ....core.main import Step +from ....core.sdk import ContinueSDK +from ....plugins.steps.core.core import WaitForUserInputStep +from ....plugins.steps.core.core import MessageStep from .steps import SetUpChessPipelineStep, AddTransformStep diff --git a/continuedev/src/continuedev/recipes/AddTransformRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/steps.py index 9744146c..8c6446da 100644 --- a/continuedev/src/continuedev/recipes/AddTransformRecipe/steps.py +++ b/continuedev/src/continuedev/plugins/recipes/AddTransformRecipe/steps.py @@ -1,14 +1,10 @@ import os from textwrap import dedent -from ...models.main import Range -from ...models.filesystem import RangeInFile -from ...steps.core.core import MessageStep -from ...core.sdk import Models -from ...core.observation 
import DictObservation -from ...models.filesystem_edit import AddFile -from ...core.main import Step -from ...core.sdk import ContinueSDK +from ....plugins.steps.core.core import MessageStep +from ....core.sdk import Models +from ....core.main import Step +from ....core.sdk import ContinueSDK AI_ASSISTED_STRING = "(✨ AI-Assisted ✨)" diff --git a/continuedev/src/continuedev/recipes/ContinueRecipeRecipe/README.md b/continuedev/src/continuedev/plugins/recipes/ContinueRecipeRecipe/README.md index df66104f..df66104f 100644 --- a/continuedev/src/continuedev/recipes/ContinueRecipeRecipe/README.md +++ b/continuedev/src/continuedev/plugins/recipes/ContinueRecipeRecipe/README.md diff --git a/continuedev/src/continuedev/recipes/ContinueRecipeRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/ContinueRecipeRecipe/main.py index 953fb0c2..c0f9e7e3 100644 --- a/continuedev/src/continuedev/recipes/ContinueRecipeRecipe/main.py +++ b/continuedev/src/continuedev/plugins/recipes/ContinueRecipeRecipe/main.py @@ -1,8 +1,7 @@ from textwrap import dedent -from ...models.filesystem import RangeInFile -from ...steps.main import EditHighlightedCodeStep -from ...core.main import Step -from ...core.sdk import ContinueSDK +from ....plugins.steps.main import EditHighlightedCodeStep +from ....core.main import Step +from ....core.sdk import ContinueSDK class ContinueStepStep(Step): diff --git a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/README.md b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/README.md index e69de29b..e69de29b 100644 --- a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/README.md +++ b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/README.md diff --git a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/main.py index 55ef107b..84363e02 100644 --- a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/main.py +++ b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/main.py @@ -1,9 +1,9 @@ from textwrap import dedent -from ...core.sdk import ContinueSDK -from ...core.main import Step -from ...steps.core.core import WaitForUserInputStep -from ...steps.core.core import MessageStep +from ....core.sdk import ContinueSDK +from ....core.main import Step +from ....plugins.steps.core.core import WaitForUserInputStep +from ....plugins.steps.core.core import MessageStep from .steps import SetupPipelineStep, ValidatePipelineStep, RunQueryStep diff --git a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py index 60218ef9..433e309e 100644 --- a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py +++ b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py @@ -1,15 +1,13 @@ import os -import subprocess from textwrap import dedent import time -from ...models.main import Range -from ...models.filesystem import RangeInFile -from ...steps.core.core import MessageStep -from ...core.observation import DictObservation, InternalErrorObservation -from ...models.filesystem_edit import AddFile, FileEdit -from ...core.main import Step -from ...core.sdk import ContinueSDK, Models +from ....models.main import Range +from ....models.filesystem import RangeInFile +from ....plugins.steps.core.core import MessageStep +from ....models.filesystem_edit import AddFile, FileEdit +from ....core.main import Step +from ....core.sdk import ContinueSDK, 
Models AI_ASSISTED_STRING = "(✨ AI-Assisted ✨)" diff --git a/continuedev/src/continuedev/recipes/DDtoBQRecipe/README.md b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/README.md index c4981e56..c4981e56 100644 --- a/continuedev/src/continuedev/recipes/DDtoBQRecipe/README.md +++ b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/README.md diff --git a/continuedev/src/continuedev/recipes/DDtoBQRecipe/dlt_duckdb_to_bigquery_docs.md b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/dlt_duckdb_to_bigquery_docs.md index eb68e117..eb68e117 100644 --- a/continuedev/src/continuedev/recipes/DDtoBQRecipe/dlt_duckdb_to_bigquery_docs.md +++ b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/dlt_duckdb_to_bigquery_docs.md diff --git a/continuedev/src/continuedev/recipes/DDtoBQRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/main.py index 1ae84310..5b6aa8f0 100644 --- a/continuedev/src/continuedev/recipes/DDtoBQRecipe/main.py +++ b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/main.py @@ -1,9 +1,8 @@ from textwrap import dedent -from ...core.main import Step -from ...core.sdk import ContinueSDK -from ...steps.core.core import WaitForUserInputStep -from ...steps.core.core import MessageStep +from ....core.main import Step +from ....core.sdk import ContinueSDK +from ....plugins.steps.core.core import MessageStep from .steps import SetUpChessPipelineStep, SwitchDestinationStep, LoadDataStep # Based on the following guide: diff --git a/continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py index df414e2e..767936b8 100644 --- a/continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py +++ b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py @@ -1,17 +1,11 @@ import os -import subprocess from textwrap import dedent -import time - -from ...steps.find_and_replace import FindAndReplaceStep -from ...models.main import Range -from ...models.filesystem import RangeInFile -from ...steps.core.core import MessageStep -from ...core.sdk import Models -from ...core.observation import DictObservation, InternalErrorObservation -from ...models.filesystem_edit import AddFile, FileEdit -from ...core.main import Step -from ...core.sdk import ContinueSDK + +from ....plugins.steps.find_and_replace import FindAndReplaceStep +from ....plugins.steps.core.core import MessageStep +from ....core.sdk import Models +from ....core.main import Step +from ....core.sdk import ContinueSDK AI_ASSISTED_STRING = "(✨ AI-Assisted ✨)" diff --git a/continuedev/src/continuedev/recipes/DeployPipelineAirflowRecipe/README.md b/continuedev/src/continuedev/plugins/recipes/DeployPipelineAirflowRecipe/README.md index e69de29b..e69de29b 100644 --- a/continuedev/src/continuedev/recipes/DeployPipelineAirflowRecipe/README.md +++ b/continuedev/src/continuedev/plugins/recipes/DeployPipelineAirflowRecipe/README.md diff --git a/continuedev/src/continuedev/recipes/DeployPipelineAirflowRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/DeployPipelineAirflowRecipe/main.py index 2a3e3566..54cba45f 100644 --- a/continuedev/src/continuedev/recipes/DeployPipelineAirflowRecipe/main.py +++ b/continuedev/src/continuedev/plugins/recipes/DeployPipelineAirflowRecipe/main.py @@ -1,10 +1,9 @@ from textwrap import dedent -from ...steps.input.nl_multiselect import NLMultiselectStep -from ...core.main import Step -from ...core.sdk import ContinueSDK -from ...steps.core.core import WaitForUserInputStep -from 
...steps.core.core import MessageStep +from ....plugins.steps.input.nl_multiselect import NLMultiselectStep +from ....core.main import Step +from ....core.sdk import ContinueSDK +from ....plugins.steps.core.core import MessageStep from .steps import SetupPipelineStep, DeployAirflowStep, RunPipelineStep diff --git a/continuedev/src/continuedev/recipes/DeployPipelineAirflowRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/DeployPipelineAirflowRecipe/steps.py index d9bdbc0a..83067d52 100644 --- a/continuedev/src/continuedev/recipes/DeployPipelineAirflowRecipe/steps.py +++ b/continuedev/src/continuedev/plugins/recipes/DeployPipelineAirflowRecipe/steps.py @@ -1,18 +1,11 @@ import os -import subprocess from textwrap import dedent -import time - -from ...steps.core.core import WaitForUserInputStep -from ...models.main import Range -from ...models.filesystem import RangeInFile -from ...steps.core.core import MessageStep -from ...core.sdk import Models -from ...core.observation import DictObservation, InternalErrorObservation -from ...models.filesystem_edit import AddFile, FileEdit -from ...core.main import Step -from ...core.sdk import ContinueSDK -from ...steps.find_and_replace import FindAndReplaceStep + +from ....plugins.steps.core.core import MessageStep +from ....core.sdk import Models +from ....core.main import Step +from ....core.sdk import ContinueSDK +from ....plugins.steps.find_and_replace import FindAndReplaceStep AI_ASSISTED_STRING = "(✨ AI-Assisted ✨)" @@ -93,5 +86,3 @@ class DeployAirflowStep(Step): # Tell the user to check the schedule and fill in owner, email, other default_args await sdk.run_step(MessageStep(message="Fill in the owner, email, and other default_args in the DAG file with your own personal information. Then the DAG will be ready to run!", name="Fill in default_args")) - - # Run the DAG locally ?? diff --git a/continuedev/src/continuedev/recipes/README.md b/continuedev/src/continuedev/plugins/recipes/README.md index d5a006fb..9860b0e2 100644 --- a/continuedev/src/continuedev/recipes/README.md +++ b/continuedev/src/continuedev/plugins/recipes/README.md @@ -1,5 +1,7 @@ # This is a collaborative collection of Continue recipes +A recipe is technically just a [Step](../steps/README.md), but is intended to be more complex, composed of multiple sub-steps. + Recipes here will automatically be made available in the [Continue VS Code extension](https://marketplace.visualstudio.com/items?itemName=Continue.continue). The `recipes` folder contains all recipes, each with the same structure. 
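Because a recipe is just a composite Step, a minimal recipe `main.py` can be sketched as below. This is an illustrative example only (the class name and message text are hypothetical); it reuses the `Step`, `ContinueSDK`, and `MessageStep` imports that the recipe diffs in this change use:

```python
from textwrap import dedent

from ....core.main import Step
from ....core.sdk import ContinueSDK
from ....plugins.steps.core.core import MessageStep


class ExampleRecipe(Step):
    async def run(self, sdk: ContinueSDK):
        # A recipe is an ordinary Step whose run method chains sub-steps through the SDK.
        await sdk.run_step(
            MessageStep(
                name="Welcome",
                message=dedent("""\
                    This recipe runs a few sub-steps in sequence.""")))
        # Further sub-steps (edits, questions, more messages) would follow here,
        # typically imported from a sibling steps.py as the existing recipes do.
```

The recipes moved in this change (e.g. `AddTransformRecipe`, `CreatePipelineRecipe`) follow this pattern, pairing a `main.py` like the above with the sub-steps defined in a sibling `steps.py`.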
**If you wish to create your own recipe, please do the following:** diff --git a/continuedev/src/continuedev/recipes/TemplateRecipe/README.md b/continuedev/src/continuedev/plugins/recipes/TemplateRecipe/README.md index 91d1123b..91d1123b 100644 --- a/continuedev/src/continuedev/recipes/TemplateRecipe/README.md +++ b/continuedev/src/continuedev/plugins/recipes/TemplateRecipe/README.md diff --git a/continuedev/src/continuedev/recipes/TemplateRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/TemplateRecipe/main.py index 16132cfd..197abe85 100644 --- a/continuedev/src/continuedev/recipes/TemplateRecipe/main.py +++ b/continuedev/src/continuedev/plugins/recipes/TemplateRecipe/main.py @@ -1,5 +1,7 @@ from typing import Coroutine -from continuedev.core import Step, ContinueSDK, Observation, Models +from ....core.main import Step, Observation +from ....core.sdk import ContinueSDK +from ....core.sdk import Models class TemplateRecipe(Step): diff --git a/continuedev/src/continuedev/recipes/WritePytestsRecipe/README.md b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/README.md index 5ce33ecb..5ce33ecb 100644 --- a/continuedev/src/continuedev/recipes/WritePytestsRecipe/README.md +++ b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/README.md diff --git a/continuedev/src/continuedev/recipes/WritePytestsRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py index c7a65fa6..6ef5ffd6 100644 --- a/continuedev/src/continuedev/recipes/WritePytestsRecipe/main.py +++ b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py @@ -1,7 +1,8 @@ from textwrap import dedent from typing import Union -from ...models.filesystem_edit import AddDirectory, AddFile -from ...core.main import Step, ContinueSDK +from ....models.filesystem_edit import AddDirectory, AddFile +from ....core.main import Step +from ....core.sdk import ContinueSDK import os diff --git a/continuedev/src/continuedev/plugins/steps/README.md b/continuedev/src/continuedev/plugins/steps/README.md new file mode 100644 index 00000000..12073835 --- /dev/null +++ b/continuedev/src/continuedev/plugins/steps/README.md @@ -0,0 +1,50 @@ +# Steps + +Steps are the composable unit of action in Continue. They define a `run` method which has access to the entire `ContinueSDK`, allowing you to take actions inside the IDE, call language models, and more. In this folder you can find a number of good examples. + +## How to write a step + +a. Start by creating a subclass of `Step` + +You should first consider what will be the parameters of your recipe. These are defined as attributes in the Pydantic class. For example, if you wanted a "filepath" attribute that would look like this: + +```python +class HelloWorldStep(Step): + filepath: str + ... +``` + +b. Next, write the `run` method + +This method takes the ContinueSDK as a parameter, giving you all the tools you need to write your steps (if it's missing something, let us know, we'll add it!). You can write any code inside the run method; this is what will happen when your step is run, line for line. As an example, here's a step that will open a file and append "Hello World!": + +```python +class HelloWorldStep(Step): + filepath: str + + async def run(self, sdk: ContinueSDK): + await sdk.ide.setFileOpen(self.filepath) + await sdk.append_to_file(self.filepath, "Hello World!") +``` + +c. 
Finally, every Step is displayed with a description of what it has done + +If you'd like to override the default description of your step, which is just the class name, then implement the `describe` method. You can: + +- Return a static string +- Store state in a class attribute (prepend with a double underscore, which signifies (through Pydantic) that this is not a parameter for the Step, just internal state) during the run method, and then grab this in the describe method. +- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.gpt35.complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`. + +Here's an example: + +```python +class HelloWorldStep(Step): + filepath: str + + async def run(self, sdk: ContinueSDK): + await sdk.ide.setFileOpen(self.filepath) + await sdk.append_to_file(self.filepath, "Hello World!") + + def describe(self, models: Models): + return f"Appended 'Hello World!' to {self.filepath}" +``` diff --git a/continuedev/src/continuedev/steps/__init__.py b/continuedev/src/continuedev/plugins/steps/__init__.py index 8b137891..8b137891 100644 --- a/continuedev/src/continuedev/steps/__init__.py +++ b/continuedev/src/continuedev/plugins/steps/__init__.py diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/plugins/steps/chat.py index 14a1cd41..2c662459 100644 --- a/continuedev/src/continuedev/steps/chat.py +++ b/continuedev/src/continuedev/plugins/steps/chat.py @@ -3,11 +3,12 @@ from typing import Any, Coroutine, List from pydantic import Field +from ...libs.util.strings import remove_quotes_and_escapes from .main import EditHighlightedCodeStep from .core.core import MessageStep -from ..core.main import FunctionCall, Models -from ..core.main import ChatMessage, Step, step_to_json_schema -from ..core.sdk import ContinueSDK +from ...core.main import FunctionCall, Models +from ...core.main import ChatMessage, Step, step_to_json_schema +from ...core.sdk import ContinueSDK import openai import os from dotenv import load_dotenv @@ -28,32 +29,31 @@ class SimpleChatStep(Step): completion = "" messages = self.messages or await sdk.get_chat_context() - generator = sdk.models.gpt4.stream_chat(messages, temperature=0.5) + generator = sdk.models.default.stream_chat( + messages, temperature=sdk.config.temperature) try: async for chunk in generator: if sdk.current_step_was_deleted(): # So that the message doesn't disappear self.hide = False - return + break if "content" in chunk: self.description += chunk["content"] completion += chunk["content"] await sdk.update_ui() finally: - await generator.aclose() - - self.name = (await sdk.models.gpt35.complete( - f"Write a short title for the following chat message: {self.description}")).strip() + self.name = remove_quotes_and_escapes(await sdk.models.gpt35.complete( + f"Write a short title for the following chat message: {self.description}")) - if self.name.startswith('"') and self.name.endswith('"'): - self.name = self.name[1:-1] + self.chat_context.append(ChatMessage( + role="assistant", + content=completion, + summary=self.name + )) - self.chat_context.append(ChatMessage( - role="assistant", - content=completion, - summary=self.name - )) + # TODO: Never actually closing. 
+ await generator.aclose() class AddFileStep(Step): diff --git a/continuedev/src/continuedev/steps/chroma.py b/continuedev/src/continuedev/plugins/steps/chroma.py index 9d085981..dbe8363e 100644 --- a/continuedev/src/continuedev/steps/chroma.py +++ b/continuedev/src/continuedev/plugins/steps/chroma.py @@ -1,10 +1,10 @@ from textwrap import dedent from typing import Coroutine, Union -from ..core.observation import Observation, TextObservation -from ..core.main import Step -from ..core.sdk import ContinueSDK +from ...core.observation import Observation, TextObservation +from ...core.main import Step +from ...core.sdk import ContinueSDK from .core.core import EditFileStep -from ..libs.chroma.query import ChromaIndexManager +from ...libs.chroma.query import ChromaIndexManager from .core.core import EditFileStep diff --git a/continuedev/src/continuedev/steps/clear_history.py b/continuedev/src/continuedev/plugins/steps/clear_history.py index a875c6d3..8f21518b 100644 --- a/continuedev/src/continuedev/steps/clear_history.py +++ b/continuedev/src/continuedev/plugins/steps/clear_history.py @@ -1,5 +1,5 @@ -from ..core.main import Step -from ..core.sdk import ContinueSDK +from ...core.main import Step +from ...core.sdk import ContinueSDK class ClearHistoryStep(Step): diff --git a/continuedev/src/continuedev/steps/comment_code.py b/continuedev/src/continuedev/plugins/steps/comment_code.py index aa17e62c..3e34ab52 100644 --- a/continuedev/src/continuedev/steps/comment_code.py +++ b/continuedev/src/continuedev/plugins/steps/comment_code.py @@ -1,4 +1,4 @@ -from ..core.main import ContinueSDK, Models, Step +from ...core.main import ContinueSDK, Models, Step from .main import EditHighlightedCodeStep diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py index 5ea95104..5a81e5ee 100644 --- a/continuedev/src/continuedev/steps/core/core.py +++ b/continuedev/src/continuedev/plugins/steps/core/core.py @@ -1,17 +1,19 @@ # These steps are depended upon by ContinueSDK import os import subprocess +import difflib from textwrap import dedent from typing import Coroutine, List, Literal, Union -from ...models.main import Range -from ...libs.llm.prompt_utils import MarkdownStyleEncoderDecoder -from ...models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullContents, FileSystemEdit -from ...models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents -from ...core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation -from ...core.main import ChatMessage, Step, SequentialStep -from ...libs.util.count_tokens import MAX_TOKENS_FOR_MODEL, DEFAULT_MAX_TOKENS -from ...libs.util.dedent import dedent_and_get_common_whitespace +from ....libs.llm.ggml import GGML +from ....models.main import Range +from ....libs.llm.prompt_utils import MarkdownStyleEncoderDecoder +from ....models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullContents, FileSystemEdit +from ....models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents +from ....core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation +from ....core.main import ChatMessage, ContinueCustomException, Step, SequentialStep +from ....libs.util.count_tokens import MAX_TOKENS_FOR_MODEL, DEFAULT_MAX_TOKENS +from ....libs.util.strings import dedent_and_get_common_whitespace, remove_quotes_and_escapes import difflib @@ -156,42 +158,32 @@ class DefaultModelEditCodeStep(Step): _new_contents: str = "" 
_prompt_and_completion: str = "" - def _cleanup_output(self, output: str) -> str: - output = output.replace('\\"', '"') - output = output.replace("\\'", "'") - output = output.replace("\\n", "\n") - output = output.replace("\\t", "\t") - output = output.replace("\\\\", "\\") - if output.startswith('"') and output.endswith('"'): - output = output[1:-1] - - return output - async def describe(self, models: Models) -> Coroutine[str, None, None]: if self._previous_contents.strip() == self._new_contents.strip(): description = "No edits were made" else: + changes = '\n'.join(difflib.ndiff( + self._previous_contents.splitlines(), self._new_contents.splitlines())) description = await models.gpt3516k.complete(dedent(f"""\ - ```original - {self._previous_contents} - ``` + Diff summary: "{self.user_input}" - ```new - {self._new_contents} + ```diff + {changes} ``` Please give brief a description of the changes made above using markdown bullet points. Be concise:""")) name = await models.gpt3516k.complete(f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. This is the title:") - self.name = self._cleanup_output(name) + self.name = remove_quotes_and_escapes(name) - return f"{self._cleanup_output(description)}" + return f"{remove_quotes_and_escapes(description)}" async def get_prompt_parts(self, rif: RangeInFileWithContents, sdk: ContinueSDK, full_file_contents: str): # We don't know here all of the functions being passed in. # We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion. # Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need. - model_to_use = sdk.models.gpt4 - max_tokens = DEFAULT_MAX_TOKENS + model_to_use = sdk.models.default + max_tokens = int(MAX_TOKENS_FOR_MODEL.get( + model_to_use.name, DEFAULT_MAX_TOKENS) / 2) TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200 if model_to_use.count_tokens(rif.contents) > TOKENS_TO_BE_CONSIDERED_LARGE_RANGE: @@ -228,13 +220,13 @@ class DefaultModelEditCodeStep(Step): if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]: break - if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]: - while cur_start_line < max_start_line: - cur_start_line += 1 - total_tokens -= model_to_use.count_tokens( - full_file_contents_lst[cur_end_line]) - if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]: - break + if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]: + while cur_start_line < max_start_line: + cur_start_line += 1 + total_tokens -= model_to_use.count_tokens( + full_file_contents_lst[cur_start_line]) + if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]: + break # Now use the found start/end lines to get the prefix and suffix strings file_prefix = "\n".join( @@ -474,6 +466,14 @@ Please output the code to be inserted at the cursor in order to fulfill the user current_block_lines.append(line) messages = await sdk.get_chat_context() + # Delete the last user and assistant messages + i = len(messages) - 1 + deleted = 0 + while i >= 0 and deleted < 2: + if messages[i].role == "user" or messages[i].role == "assistant": + messages.pop(i) + deleted += 1 + i -= 1 messages.append(ChatMessage( role="user", content=prompt, @@ -486,58 +486,68 @@ Please output the code to be inserted at the cursor in order to fulfill the user completion_lines_covered = 0 repeating_file_suffix = False line_below_highlighted_range = file_suffix.lstrip().split("\n")[0] - 
async for chunk in model_to_use.stream_chat(messages, temperature=0, max_tokens=max_tokens): - # Stop early if it is repeating the file_suffix or the step was deleted - if repeating_file_suffix: - break - if sdk.current_step_was_deleted(): - return - # Accumulate lines - if "content" not in chunk: - continue - chunk = chunk["content"] - chunk_lines = chunk.split("\n") - chunk_lines[0] = unfinished_line + chunk_lines[0] - if chunk.endswith("\n"): - unfinished_line = "" - chunk_lines.pop() # because this will be an empty string - else: - unfinished_line = chunk_lines.pop() - - # Deal with newly accumulated lines - for i in range(len(chunk_lines)): - # Trailing whitespace doesn't matter - chunk_lines[i] = chunk_lines[i].rstrip() - chunk_lines[i] = common_whitespace + chunk_lines[i] - - # Lines that should signify the end of generation - if self.is_end_line(chunk_lines[i]): - break - # Lines that should be ignored, like the <> tags - elif self.line_to_be_ignored(chunk_lines[i], completion_lines_covered == 0): - continue - # Check if we are currently just copying the prefix - elif (lines_of_prefix_copied > 0 or completion_lines_covered == 0) and lines_of_prefix_copied < len(file_prefix.splitlines()) and chunk_lines[i] == full_file_contents_lines[lines_of_prefix_copied]: - # This is a sketchy way of stopping it from repeating the file_prefix. Is a bug if output happens to have a matching line - lines_of_prefix_copied += 1 - continue - # Because really short lines might be expected to be repeated, this is only a !heuristic! - # Stop when it starts copying the file_suffix - elif chunk_lines[i].strip() == line_below_highlighted_range.strip() and len(chunk_lines[i].strip()) > 4 and not (len(original_lines_below_previous_blocks) > 0 and chunk_lines[i].strip() == original_lines_below_previous_blocks[0].strip()): - repeating_file_suffix = True + if isinstance(model_to_use, GGML): + messages = [ChatMessage( + role="user", content=f"```\n{rif.contents}\n```\n\nUser request: \"{self.user_input}\"\n\nThis is the code after changing to perfectly comply with the user request. It does not include any placeholder code, only real implementations:\n\n```\n", summary=self.user_input)] + + generator = model_to_use.stream_chat( + messages, temperature=sdk.config.temperature, max_tokens=max_tokens) + + try: + async for chunk in generator: + # Stop early if it is repeating the file_suffix or the step was deleted + if repeating_file_suffix: break + if sdk.current_step_was_deleted(): + return - # If none of the above, insert the line! 
- if False: - await handle_generated_line(chunk_lines[i]) + # Accumulate lines + if "content" not in chunk: + continue + chunk = chunk["content"] + chunk_lines = chunk.split("\n") + chunk_lines[0] = unfinished_line + chunk_lines[0] + if chunk.endswith("\n"): + unfinished_line = "" + chunk_lines.pop() # because this will be an empty string + else: + unfinished_line = chunk_lines.pop() + + # Deal with newly accumulated lines + for i in range(len(chunk_lines)): + # Trailing whitespace doesn't matter + chunk_lines[i] = chunk_lines[i].rstrip() + chunk_lines[i] = common_whitespace + chunk_lines[i] + + # Lines that should signify the end of generation + if self.is_end_line(chunk_lines[i]): + break + # Lines that should be ignored, like the <> tags + elif self.line_to_be_ignored(chunk_lines[i], completion_lines_covered == 0): + continue + # Check if we are currently just copying the prefix + elif (lines_of_prefix_copied > 0 or completion_lines_covered == 0) and lines_of_prefix_copied < len(file_prefix.splitlines()) and chunk_lines[i] == full_file_contents_lines[lines_of_prefix_copied]: + # This is a sketchy way of stopping it from repeating the file_prefix. Is a bug if output happens to have a matching line + lines_of_prefix_copied += 1 + continue + # Because really short lines might be expected to be repeated, this is only a !heuristic! + # Stop when it starts copying the file_suffix + elif chunk_lines[i].strip() == line_below_highlighted_range.strip() and len(chunk_lines[i].strip()) > 4 and not (len(original_lines_below_previous_blocks) > 0 and chunk_lines[i].strip() == original_lines_below_previous_blocks[0].strip()): + repeating_file_suffix = True + break - lines.append(chunk_lines[i]) - completion_lines_covered += 1 - current_line_in_file += 1 + # If none of the above, insert the line! + if False: + await handle_generated_line(chunk_lines[i]) - await sendDiffUpdate(lines + [common_whitespace if unfinished_line.startswith("<") else (common_whitespace + unfinished_line)], sdk) + lines.append(chunk_lines[i]) + completion_lines_covered += 1 + current_line_in_file += 1 + await sendDiffUpdate(lines + [common_whitespace if unfinished_line.startswith("<") else (common_whitespace + unfinished_line)], sdk) + finally: + await generator.aclose() # Add the unfinished line if unfinished_line != "" and not self.line_to_be_ignored(unfinished_line, completion_lines_covered == 0) and not self.is_end_line(unfinished_line): unfinished_line = common_whitespace + unfinished_line @@ -602,6 +612,13 @@ Please output the code to be inserted at the cursor in order to fulfill the user rif_dict[rif.filepath] = rif.contents for rif in rif_with_contents: + # If the file doesn't exist, ask them to save it first + if not os.path.exists(rif.filepath): + message = f"The file {rif.filepath} does not exist. Please save it first." 
+ raise ContinueCustomException( + title=message, message=message + ) + await sdk.ide.setFileOpen(rif.filepath) await sdk.ide.setSuggestionsLocked(rif.filepath, True) await self.stream_rif(rif, sdk) diff --git a/continuedev/src/continuedev/steps/custom_command.py b/continuedev/src/continuedev/plugins/steps/custom_command.py index 5a56efb0..d5b6e48b 100644 --- a/continuedev/src/continuedev/steps/custom_command.py +++ b/continuedev/src/continuedev/plugins/steps/custom_command.py @@ -1,7 +1,7 @@ -from ..core.main import Step -from ..core.sdk import ContinueSDK -from ..steps.core.core import UserInputStep -from ..steps.chat import ChatWithFunctions, SimpleChatStep +from ...libs.util.templating import render_templated_string +from ...core.main import Step +from ...core.sdk import ContinueSDK +from ..steps.chat import SimpleChatStep class CustomCommandStep(Step): @@ -15,7 +15,9 @@ class CustomCommandStep(Step): return self.prompt async def run(self, sdk: ContinueSDK): - prompt_user_input = f"Task: {self.prompt}. Additional info: {self.user_input}" + task = render_templated_string(self.prompt) + + prompt_user_input = f"Task: {task}. Additional info: {self.user_input}" messages = await sdk.get_chat_context() # Find the last chat message with this slash command and replace it with the user input for i in range(len(messages) - 1, -1, -1): diff --git a/continuedev/src/continuedev/steps/draft/abstract_method.py b/continuedev/src/continuedev/plugins/steps/draft/abstract_method.py index f3131c4b..f3131c4b 100644 --- a/continuedev/src/continuedev/steps/draft/abstract_method.py +++ b/continuedev/src/continuedev/plugins/steps/draft/abstract_method.py diff --git a/continuedev/src/continuedev/steps/draft/migration.py b/continuedev/src/continuedev/plugins/steps/draft/migration.py index f3b36b5e..a76d491b 100644 --- a/continuedev/src/continuedev/steps/draft/migration.py +++ b/continuedev/src/continuedev/plugins/steps/draft/migration.py @@ -1,7 +1,7 @@ # When an edit is made to an existing class or a new sqlalchemy class is created, # this should be kicked off. -from ...core.main import Step +from ....core.main import Step class MigrationStep(Step): diff --git a/continuedev/src/continuedev/steps/draft/redux.py b/continuedev/src/continuedev/plugins/steps/draft/redux.py index 17506316..30c8fdbb 100644 --- a/continuedev/src/continuedev/steps/draft/redux.py +++ b/continuedev/src/continuedev/plugins/steps/draft/redux.py @@ -1,5 +1,5 @@ -from ...core.main import Step -from ...core.sdk import ContinueSDK +from ....core.main import Step +from ....core.sdk import ContinueSDK from ..core.core import EditFileStep @@ -25,14 +25,14 @@ class EditReduxStateStep(Step): sdk.run_step(EditFileStep( filepath=selector_filename, prompt=f"Edit the selector to add a new property for {self.description}. The store looks like this: {store_file_contents}" - ) + )) # Reducer reducer_filename = "" sdk.run_step(EditFileStep( filepath=reducer_filename, prompt=f"Edit the reducer to add a new property for {self.description}. The store looks like this: {store_file_contents}" - + )) """ Starts with implementing selector 1. 
RootStore diff --git a/continuedev/src/continuedev/steps/draft/typeorm.py b/continuedev/src/continuedev/plugins/steps/draft/typeorm.py index 153c855f..d06a6fb4 100644 --- a/continuedev/src/continuedev/steps/draft/typeorm.py +++ b/continuedev/src/continuedev/plugins/steps/draft/typeorm.py @@ -1,6 +1,6 @@ from textwrap import dedent -from ...core.main import Step -from ...core.sdk import ContinueSDK +from ....core.main import Step +from ....core.sdk import ContinueSDK class CreateTableStep(Step): diff --git a/continuedev/src/continuedev/steps/feedback.py b/continuedev/src/continuedev/plugins/steps/feedback.py index 6f6a9b15..119e3112 100644 --- a/continuedev/src/continuedev/steps/feedback.py +++ b/continuedev/src/continuedev/plugins/steps/feedback.py @@ -1,8 +1,8 @@ from typing import Coroutine -from ..core.main import Models -from ..core.main import Step -from ..core.sdk import ContinueSDK -from ..libs.util.telemetry import capture_event +from ...core.main import Models +from ...core.main import Step +from ...core.sdk import ContinueSDK +from ...libs.util.telemetry import capture_event class FeedbackStep(Step): diff --git a/continuedev/src/continuedev/steps/find_and_replace.py b/continuedev/src/continuedev/plugins/steps/find_and_replace.py index 690872c0..a2c9c44e 100644 --- a/continuedev/src/continuedev/steps/find_and_replace.py +++ b/continuedev/src/continuedev/plugins/steps/find_and_replace.py @@ -1,6 +1,6 @@ -from ..models.filesystem_edit import FileEdit, Range -from ..core.main import Models, Step -from ..core.sdk import ContinueSDK +from ...models.filesystem_edit import FileEdit, Range +from ...core.main import Models, Step +from ...core.sdk import ContinueSDK class FindAndReplaceStep(Step): diff --git a/continuedev/src/continuedev/plugins/steps/help.py b/continuedev/src/continuedev/plugins/steps/help.py new file mode 100644 index 00000000..5111c7cf --- /dev/null +++ b/continuedev/src/continuedev/plugins/steps/help.py @@ -0,0 +1,59 @@ +from textwrap import dedent +from ...core.main import ChatMessage, Step +from ...core.sdk import ContinueSDK +from ...libs.util.telemetry import capture_event + +help = dedent("""\ + Continue is an open-source coding autopilot. It is a VS Code extension that brings the power of ChatGPT to your IDE. + + It gathers context for you and stores your interactions automatically, so that you can avoid copy/paste now and benefit from a customized Large Language Model (LLM) later. + + Continue can be used to... + 1. Edit chunks of code with specific instructions (e.g. "/edit migrate this digital ocean terraform file into one that works for GCP") + 2. Get answers to questions without switching windows (e.g. "how do I find running process on port 8000?") + 3. Generate files from scratch (e.g. "/edit Create a Python CLI tool that uses the posthog api to get events from DAUs") + + You tell Continue to edit a specific section of code by highlighting it. If you highlight multiple code sections, then it will only edit the one with the purple glow around it. You can switch which one has the purple glow by clicking the paint brush. + + If you don't highlight any code, then Continue will insert at the location of your cursor. + + Continue passes all of the sections of code you highlight, the code above and below the to-be edited highlighted code section, and all previous steps above input box as context to the LLM. + + You can use cmd+m (Mac) / ctrl+m (Windows) to open Continue. You can use cmd+shift+e / ctrl+shift+e to open file Explorer. 
You can add your own OpenAI API key to VS Code Settings with `cmd+,` + + If Continue is stuck loading, try using `cmd+shift+p` to open the command palette, search "Reload Window", and then select it. This will reload VS Code and Continue and often fixes issues. + + If you have feedback, please use /feedback to let us know how you would like to use Continue. We are excited to hear from you!""") + + +class HelpStep(Step): + + name: str = "Help" + user_input: str + manage_own_chat_context: bool = True + description: str = "" + + async def run(self, sdk: ContinueSDK): + + question = self.user_input + + prompt = dedent(f"""Please use the information below to provide a succinct answer to the following question: {question} + + Information: + + {help}""") + + self.chat_context.append(ChatMessage( + role="user", + content=prompt, + summary="Help" + )) + messages = await sdk.get_chat_context() + generator = sdk.models.gpt4.stream_chat(messages) + async for chunk in generator: + if "content" in chunk: + self.description += chunk["content"] + await sdk.update_ui() + + capture_event(sdk.ide.unique_id, "help", { + "question": question, "answer": self.description}) diff --git a/continuedev/src/continuedev/steps/input/nl_multiselect.py b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py index aee22866..b54d394a 100644 --- a/continuedev/src/continuedev/steps/input/nl_multiselect.py +++ b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py @@ -1,7 +1,7 @@ from typing import List, Union from ..core.core import WaitForUserInputStep -from ...core.main import Step -from ...core.sdk import ContinueSDK +from ....core.main import Step +from ....core.sdk import ContinueSDK class NLMultiselectStep(Step): diff --git a/continuedev/src/continuedev/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py index ce7cbc60..30117c55 100644 --- a/continuedev/src/continuedev/steps/main.py +++ b/continuedev/src/continuedev/plugins/steps/main.py @@ -1,21 +1,18 @@ import os from typing import Coroutine, List, Union - +from textwrap import dedent from pydantic import BaseModel, Field -from ..libs.llm import LLM -from ..models.main import Traceback, Range -from ..models.filesystem_edit import EditDiff, FileEdit -from ..models.filesystem import RangeInFile, RangeInFileWithContents -from ..core.observation import Observation, TextObservation, TracebackObservation -from ..libs.llm.prompt_utils import MarkdownStyleEncoderDecoder -from textwrap import dedent -from ..core.main import ContinueCustomException, Step -from ..core.sdk import ContinueSDK, Models -from ..core.observation import Observation -import subprocess +from ...models.main import Traceback, Range +from ...models.filesystem_edit import EditDiff, FileEdit +from ...models.filesystem import RangeInFile, RangeInFileWithContents +from ...core.observation import Observation +from ...libs.llm.prompt_utils import MarkdownStyleEncoderDecoder +from ...core.main import ContinueCustomException, Step +from ...core.sdk import ContinueSDK, Models +from ...core.observation import Observation from .core.core import DefaultModelEditCodeStep -from ..libs.util.calculate_diff import calculate_diff2 +from ...libs.util.calculate_diff import calculate_diff2 class SetupContinueWorkspaceStep(Step): @@ -303,8 +300,7 @@ class SolveTracebackStep(Step): range_in_files.append( RangeInFile.from_entire_file(frame.filepath, content)) - await sdk.run_step(EditCodeStep( - range_in_files=range_in_files, prompt=prompt)) + await
sdk.run_step(DefaultModelEditCodeStep(range_in_files=range_in_files, user_input=prompt)) return None diff --git a/continuedev/src/continuedev/steps/on_traceback.py b/continuedev/src/continuedev/plugins/steps/on_traceback.py index efb4c703..e99f212d 100644 --- a/continuedev/src/continuedev/steps/on_traceback.py +++ b/continuedev/src/continuedev/plugins/steps/on_traceback.py @@ -1,8 +1,8 @@ import os from .core.core import UserInputStep -from ..core.main import ChatMessage, Step -from ..core.sdk import ContinueSDK +from ...core.main import ChatMessage, Step +from ...core.sdk import ContinueSDK from .chat import SimpleChatStep diff --git a/continuedev/src/continuedev/steps/open_config.py b/continuedev/src/continuedev/plugins/steps/open_config.py index 87f03e9f..d950c26f 100644 --- a/continuedev/src/continuedev/steps/open_config.py +++ b/continuedev/src/continuedev/plugins/steps/open_config.py @@ -1,6 +1,6 @@ from textwrap import dedent -from ..core.main import Step -from ..core.sdk import ContinueSDK +from ...core.main import Step +from ...core.sdk import ContinueSDK import os @@ -14,10 +14,10 @@ class OpenConfigStep(Step): "custom_commands": [ { "name": "test", - "description": "Write unit tests like I do for the highlighted code" + "description": "Write unit tests like I do for the highlighted code", "prompt": "Write a comprehensive set of unit tests for the selected code. It should setup, run tests that check for correctness including important edge cases, and teardown. Ensure that the tests are complete and sophisticated." } - ], + ] ``` `"name"` is the command you will type. `"description"` is the description displayed in the slash command menu. diff --git a/continuedev/src/continuedev/steps/react.py b/continuedev/src/continuedev/plugins/steps/react.py index cddb8b42..8b2e7c2e 100644 --- a/continuedev/src/continuedev/steps/react.py +++ b/continuedev/src/continuedev/plugins/steps/react.py @@ -1,8 +1,7 @@ from textwrap import dedent from typing import List, Union, Tuple -from ..core.main import Step -from ..core.sdk import ContinueSDK -from .core.core import MessageStep +from ...core.main import Step +from ...core.sdk import ContinueSDK class NLDecisionStep(Step): diff --git a/continuedev/src/continuedev/steps/search_directory.py b/continuedev/src/continuedev/plugins/steps/search_directory.py index bfb97630..7d02d6fa 100644 --- a/continuedev/src/continuedev/steps/search_directory.py +++ b/continuedev/src/continuedev/plugins/steps/search_directory.py @@ -2,11 +2,11 @@ import asyncio from textwrap import dedent from typing import List, Union -from ..models.filesystem import RangeInFile -from ..models.main import Range -from ..core.main import Step -from ..core.sdk import ContinueSDK -from ..libs.util.create_async_task import create_async_task +from ...models.filesystem import RangeInFile +from ...models.main import Range +from ...core.main import Step +from ...core.sdk import ContinueSDK +from ...libs.util.create_async_task import create_async_task import os import re diff --git a/continuedev/src/continuedev/plugins/steps/steps_on_startup.py b/continuedev/src/continuedev/plugins/steps/steps_on_startup.py new file mode 100644 index 00000000..19d62d30 --- /dev/null +++ b/continuedev/src/continuedev/plugins/steps/steps_on_startup.py @@ -0,0 +1,17 @@ +from ...core.main import Step +from ...core.sdk import Models, ContinueSDK +from ...libs.util.step_name_to_steps import get_step_from_name + + +class StepsOnStartupStep(Step): + hide: bool = True + + async def describe(self, models: Models): + 
return "Running steps on startup" + + async def run(self, sdk: ContinueSDK): + steps_on_startup = sdk.config.steps_on_startup + + for step_name, step_params in steps_on_startup.items(): + step = get_step_from_name(step_name, step_params) + await sdk.run_step(step) diff --git a/continuedev/src/continuedev/steps/welcome.py b/continuedev/src/continuedev/plugins/steps/welcome.py index 2dece649..df3e9a8a 100644 --- a/continuedev/src/continuedev/steps/welcome.py +++ b/continuedev/src/continuedev/plugins/steps/welcome.py @@ -1,9 +1,10 @@ from textwrap import dedent -from ..models.filesystem_edit import AddFile -from ..core.main import Step -from ..core.sdk import ContinueSDK, Models import os +from ...models.filesystem_edit import AddFile +from ...core.main import Step +from ...core.sdk import ContinueSDK, Models + class WelcomeStep(Step): name: str = "Welcome to Continue!" diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 238273b2..ae57c0b6 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -1,3 +1,4 @@ +import asyncio import json from fastapi import Depends, Header, WebSocket, APIRouter from starlette.websockets import WebSocketState, WebSocketDisconnect @@ -53,15 +54,19 @@ class GUIProtocolServer(AbstractGUIProtocolServer): self.session = session async def _send_json(self, message_type: str, data: Any): - if self.websocket.client_state == WebSocketState.DISCONNECTED: + if self.websocket.application_state == WebSocketState.DISCONNECTED: return await self.websocket.send_json({ "messageType": message_type, "data": data }) - async def _receive_json(self, message_type: str) -> Any: - return await self.sub_queue.get(message_type) + async def _receive_json(self, message_type: str, timeout: int = 5) -> Any: + try: + return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout) + except asyncio.TimeoutError: + raise Exception( + "GUI Protocol _receive_json timed out after 5 seconds") async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T: await self._send_json(message_type, data) @@ -94,6 +99,8 @@ class GUIProtocolServer(AbstractGUIProtocolServer): self.on_set_editing_at_indices(data["indices"]) elif message_type == "set_pinned_at_indices": self.on_set_pinned_at_indices(data["indices"]) + elif message_type == "show_logs_at_index": + self.on_show_logs_at_index(data["index"]) except Exception as e: print(e) @@ -161,6 +168,13 @@ class GUIProtocolServer(AbstractGUIProtocolServer): indices), self.session.autopilot.continue_sdk.ide.unique_id ) + def on_show_logs_at_index(self, index: int): + name = f"continue_logs.txt" + logs = "\n\n############################################\n\n".join( + ["This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"] + self.session.autopilot.continue_sdk.history.timeline[index].logs) + create_async_task( + self.session.autopilot.ide.showVirtualFile(name, logs)) + @router.websocket("/ws") async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)): diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index 12a21f19..aeff5623 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -123,10 +123,13 @@ class IdeProtocolServer(AbstractIdeProtocolServer): self.websocket = websocket self.session_manager = session_manager - workspace_directory: str + 
workspace_directory: str = None + unique_id: str = None - async def initialize(self) -> List[str]: + async def initialize(self, session_id: str) -> List[str]: + self.session_id = session_id await self._send_json("workspaceDirectory", {}) + await self._send_json("uniqueId", {}) other_msgs = [] while True: msg_string = await self.websocket.receive_text() @@ -137,21 +140,29 @@ class IdeProtocolServer(AbstractIdeProtocolServer): data = message["data"] if message_type == "workspaceDirectory": self.workspace_directory = data["workspaceDirectory"] - break + elif message_type == "uniqueId": + self.unique_id = data["uniqueId"] else: other_msgs.append(msg_string) + + if self.workspace_directory is not None and self.unique_id is not None: + break return other_msgs async def _send_json(self, message_type: str, data: Any): - if self.websocket.client_state == WebSocketState.DISCONNECTED: + if self.websocket.application_state == WebSocketState.DISCONNECTED: return await self.websocket.send_json({ "messageType": message_type, "data": data }) - async def _receive_json(self, message_type: str) -> Any: - return await self.sub_queue.get(message_type) + async def _receive_json(self, message_type: str, timeout: int = 5) -> Any: + try: + return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout) + except asyncio.TimeoutError: + raise Exception( + "IDE Protocol _receive_json timed out after 5 seconds") async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T: await self._send_json(message_type, data) @@ -183,10 +194,12 @@ class IdeProtocolServer(AbstractIdeProtocolServer): self.onMainUserInput(data["input"]) elif message_type == "deleteAtIndex": self.onDeleteAtIndex(data["index"]) - elif message_type in ["highlightedCode", "openFiles", "visibleFiles", "readFile", "editFile", "getUserSecret", "runCommand", "uniqueId"]: + elif message_type in ["highlightedCode", "openFiles", "visibleFiles", "readFile", "editFile", "getUserSecret", "runCommand"]: self.sub_queue.post(message_type, data) elif message_type == "workspaceDirectory": self.workspace_directory = data["workspaceDirectory"] + elif message_type == "uniqueId": + self.unique_id = data["uniqueId"] else: raise ValueError("Unknown message type", message_type) @@ -211,6 +224,12 @@ class IdeProtocolServer(AbstractIdeProtocolServer): "open": open }) + async def showVirtualFile(self, name: str, contents: str): + await self._send_json("showVirtualFile", { + "name": name, + "contents": contents + }) + async def setSuggestionsLocked(self, filepath: str, locked: bool = True): # Lock suggestions in the file so they don't ruin the offset before others are inserted await self._send_json("setSuggestionsLocked", { @@ -219,8 +238,8 @@ class IdeProtocolServer(AbstractIdeProtocolServer): }) async def getSessionId(self): - session_id = self.session_manager.new_session( - self, self.session_id).session_id + session_id = (await self.session_manager.new_session( + self, self.session_id)).session_id await self._send_json("getSessionId", { "sessionId": session_id }) @@ -274,33 +293,33 @@ class IdeProtocolServer(AbstractIdeProtocolServer): def onOpenGUIRequest(self): pass + def __get_autopilot(self): + if self.session_id not in self.session_manager.sessions: + return None + return self.session_manager.sessions[self.session_id].autopilot + def onFileEdits(self, edits: List[FileEditWithFullContents]): - # Send the file edits to ALL autopilots. 
- # Maybe not ideal behavior - for _, session in self.session_manager.sessions.items(): - session.autopilot.handle_manual_edits(edits) + if autopilot := self.__get_autopilot(): + autopilot.handle_manual_edits(edits) def onDeleteAtIndex(self, index: int): - for _, session in self.session_manager.sessions.items(): - create_async_task( - session.autopilot.delete_at_index(index), self.unique_id) + if autopilot := self.__get_autopilot(): + create_async_task(autopilot.delete_at_index(index), self.unique_id) def onCommandOutput(self, output: str): - # Send the output to ALL autopilots. - # Maybe not ideal behavior - for _, session in self.session_manager.sessions.items(): + if autopilot := self.__get_autopilot(): create_async_task( - session.autopilot.handle_command_output(output), self.unique_id) + autopilot.handle_command_output(output), self.unique_id) def onHighlightedCodeUpdate(self, range_in_files: List[RangeInFileWithContents]): - for _, session in self.session_manager.sessions.items(): - create_async_task( - session.autopilot.handle_highlighted_code(range_in_files), self.unique_id) + if autopilot := self.__get_autopilot(): + create_async_task(autopilot.handle_highlighted_code( + range_in_files), self.unique_id) def onMainUserInput(self, input: str): - for _, session in self.session_manager.sessions.items(): + if autopilot := self.__get_autopilot(): create_async_task( - session.autopilot.accept_user_input(input), self.unique_id) + autopilot.accept_user_input(input), self.unique_id) # Request information. Session doesn't matter. async def getOpenFiles(self) -> List[str]: @@ -311,14 +330,6 @@ class IdeProtocolServer(AbstractIdeProtocolServer): resp = await self._send_and_receive_json({}, VisibleFilesResponse, "visibleFiles") return resp.visibleFiles - async def get_unique_id(self) -> str: - resp = await self._send_and_receive_json({}, UniqueIdResponse, "uniqueId") - return resp.uniqueId - - @cached_property_no_none - def unique_id(self) -> str: - return asyncio.run(self.get_unique_id()) - async def getHighlightedCode(self) -> List[RangeInFile]: resp = await self._send_and_receive_json({}, HighlightedCodeResponse, "highlightedCode") return resp.highlightedCode @@ -436,10 +447,11 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None): ideProtocolServer.handle_json(message_type, data)) ideProtocolServer = IdeProtocolServer(session_manager, websocket) - ideProtocolServer.session_id = session_id if session_id is not None: session_manager.registered_ides[session_id] = ideProtocolServer - other_msgs = await ideProtocolServer.initialize() + other_msgs = await ideProtocolServer.initialize(session_id) + capture_event(ideProtocolServer.unique_id, "session_started", { + "session_id": ideProtocolServer.session_id}) for other_msg in other_msgs: handle_msg(other_msg) @@ -460,4 +472,6 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None): if websocket.client_state != WebSocketState.DISCONNECTED: await websocket.close() + capture_event(ideProtocolServer.unique_id, "session_ended", { + "session_id": ideProtocolServer.session_id}) session_manager.registered_ides.pop(ideProtocolServer.session_id) diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py index 2f78cf0e..0ae7e7fa 100644 --- a/continuedev/src/continuedev/server/ide_protocol.py +++ b/continuedev/src/continuedev/server/ide_protocol.py @@ -24,6 +24,10 @@ class AbstractIdeProtocolServer(ABC): """Set whether a file is open""" @abstractmethod + 
async def showVirtualFile(self, name: str, contents: str): + """Show a virtual file""" + + @abstractmethod async def setSuggestionsLocked(self, filepath: str, locked: bool = True): """Set whether suggestions are locked""" @@ -108,7 +112,4 @@ class AbstractIdeProtocolServer(ABC): """Show a diff""" workspace_directory: str - - @abstractproperty - def unique_id(self) -> str: - """Get a unique ID for this IDE""" + unique_id: str diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py index aa093853..42dc0cc1 100644 --- a/continuedev/src/continuedev/server/main.py +++ b/continuedev/src/continuedev/server/main.py @@ -1,5 +1,6 @@ +import time +import psutil import os -import sys from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from .ide import router as ide_router @@ -51,9 +52,31 @@ def cleanup(): session_manager.persist_session(session_id) +def cpu_usage_report(): + process = psutil.Process(os.getpid()) + # Call cpu_percent once to start measurement, but ignore the result + process.cpu_percent(interval=None) + # Wait for a short period of time + time.sleep(1) + # Call cpu_percent again to get the CPU usage over the interval + cpu_usage = process.cpu_percent(interval=None) + print(f"CPU usage: {cpu_usage}%") + + atexit.register(cleanup) + if __name__ == "__main__": try: + # import threading + + # def cpu_usage_loop(): + # while True: + # cpu_usage_report() + # time.sleep(2) + + # cpu_thread = threading.Thread(target=cpu_usage_loop) + # cpu_thread.start() + run_server() except Exception as e: cleanup() diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index fb8ac386..20219273 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -7,7 +7,7 @@ import json from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath from ..models.filesystem_edit import FileEditWithFullContents from ..libs.constants.main import CONTINUE_SESSIONS_FOLDER -from ..core.policy import DemoPolicy +from ..core.policy import DefaultPolicy from ..core.main import FullState from ..core.autopilot import Autopilot from .ide_protocol import AbstractIdeProtocolServer @@ -53,19 +53,19 @@ class SessionManager: session_files = os.listdir(sessions_folder) if f"{session_id}.json" in session_files and session_id in self.registered_ides: if self.registered_ides[session_id].session_id is not None: - return self.new_session(self.registered_ides[session_id], session_id=session_id) + return await self.new_session(self.registered_ides[session_id], session_id=session_id) raise KeyError("Session ID not recognized", session_id) return self.sessions[session_id] - def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session: + async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session: full_state = None if session_id is not None and os.path.exists(getSessionFilePath(session_id)): with open(getSessionFilePath(session_id), "r") as f: full_state = FullState(**json.load(f)) - autopilot = DemoAutopilot( - policy=DemoPolicy(), ide=ide, full_state=full_state) + autopilot = await DemoAutopilot.create( + policy=DefaultPolicy(), ide=ide, full_state=full_state) session_id = session_id or str(uuid4()) ide.session_id = session_id session = Session(session_id=session_id, autopilot=autopilot) @@ -100,7 +100,7 @@ class SessionManager: if 
session_id not in self.sessions: raise SessionNotFound(f"Session {session_id} not found") if self.sessions[session_id].ws is None: - print(f"Session {session_id} has no websocket") + # print(f"Session {session_id} has no websocket") return await self.sessions[session_id].ws.send_json({ diff --git a/continuedev/src/continuedev/steps/steps_on_startup.py b/continuedev/src/continuedev/steps/steps_on_startup.py deleted file mode 100644 index 365cbe1a..00000000 --- a/continuedev/src/continuedev/steps/steps_on_startup.py +++ /dev/null @@ -1,23 +0,0 @@ -from ..core.main import Step -from ..core.sdk import Models, ContinueSDK -from .main import UserInputStep -from ..recipes.CreatePipelineRecipe.main import CreatePipelineRecipe -from ..recipes.DDtoBQRecipe.main import DDtoBQRecipe -from ..recipes.DeployPipelineAirflowRecipe.main import DeployPipelineAirflowRecipe -from ..recipes.DDtoBQRecipe.main import DDtoBQRecipe -from ..recipes.AddTransformRecipe.main import AddTransformRecipe -from ..libs.util.step_name_to_steps import get_step_from_name - - -class StepsOnStartupStep(Step): - hide: bool = True - - async def describe(self, models: Models): - return "Running steps on startup" - - async def run(self, sdk: ContinueSDK): - steps_on_startup = sdk.config.steps_on_startup - - for step_name, step_params in steps_on_startup.items(): - step = get_step_from_name(step_name, step_params) - await sdk.run_step(step)
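Note: both the deleted `steps_on_startup.py` above and its replacement under `plugins/steps/` resolve step names from `config.json` through `get_step_from_name`, which is backed by the `step_name_to_step_class` mapping shown near the top of this diff. A minimal sketch of that lookup, assuming it is a plain dictionary lookup plus instantiation (the actual implementation in `libs/util/step_name_to_steps.py` may differ in details such as error handling):

```python
from typing import Any, Dict

from ...core.main import Step
from ...plugins.steps.clear_history import ClearHistoryStep
from ...plugins.steps.open_config import OpenConfigStep
from ...plugins.steps.help import HelpStep

# Abbreviated form of the step_name_to_step_class mapping added in this change.
step_name_to_step_class = {
    "ClearHistoryStep": ClearHistoryStep,
    "OpenConfigStep": OpenConfigStep,
    "HelpStep": HelpStep,
}


def get_step_from_name(step_name: str, params: Dict[str, Any]) -> Step:
    """Instantiate the Step class registered under a config.json name, passing its params as keyword arguments."""
    if step_name not in step_name_to_step_class:
        raise ValueError(f"Unknown step name in config: {step_name}")
    return step_name_to_step_class[step_name](**params)
```

With that in place, `StepsOnStartupStep.run` simply iterates `sdk.config.steps_on_startup.items()` and awaits each resolved step, and slash commands resolve their `step_name` through the same mapping.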