author     Nate Sesti <sestinj@gmail.com>    2023-07-19 00:33:50 -0700
committer  Nate Sesti <sestinj@gmail.com>    2023-07-19 00:33:50 -0700
commit     1b92180d4b7720bf1cf36dd63142760d421dabf8 (patch)
tree       26e25e005b06526267c2a140c1fbf1cbf822f066 /continuedev/src
parent     924a0c09259d25a4dfe62c0a626a9204df45daa9 (diff)
parent     a7c57e1d1e4a0eff3e4b598f8bf0448ea6068353 (diff)
Merge branch 'main' into config-py
Diffstat (limited to 'continuedev/src')
34 files changed, 1227 insertions, 443 deletions
diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py
index 7bd3da6c..94d7be10 100644
--- a/continuedev/src/continuedev/core/abstract_sdk.py
+++ b/continuedev/src/continuedev/core/abstract_sdk.py
@@ -76,9 +76,7 @@ class AbstractContinueSDK(ABC):
     async def get_user_secret(self, env_var: str, prompt: str) -> str:
         pass
 
-    @abstractproperty
-    def config(self) -> ContinueConfig:
-        pass
+    config: ContinueConfig
 
     @abstractmethod
     def set_loading_message(self, message: str):
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 118744f9..1f3e6323 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -1,13 +1,13 @@
 from functools import cached_property
 import traceback
 import time
-from typing import Any, Callable, Coroutine, Dict, List
+from typing import Any, Callable, Coroutine, Dict, List, Union
 import os
 from aiohttp import ClientPayloadError
+from pydantic import root_validator
 
 from ..models.filesystem import RangeInFileWithContents
 from ..models.filesystem_edit import FileEditWithFullContents
-from ..libs.llm import LLM
 from .observation import Observation, InternalErrorObservation
 from ..server.ide_protocol import AbstractIdeProtocolServer
 from ..libs.util.queue import AsyncSubscriptionQueue
@@ -16,9 +16,10 @@ from .main import Context, ContinueCustomException, HighlightedRangeContext, Pol
 from ..steps.core.core import ReversibleStep, ManualEditStep, UserInputStep
 from ..libs.util.telemetry import capture_event
 from .sdk import ContinueSDK
-import asyncio
+from ..libs.util.step_name_to_steps import get_step_from_name
 from ..libs.util.traceback_parsers import get_python_traceback, get_javascript_traceback
 from openai import error as openai_errors
+from ..libs.util.create_async_task import create_async_task
 
 
 def get_error_title(e: Exception) -> str:
@@ -33,9 +34,11 @@ def get_error_title(e: Exception) -> str:
     elif isinstance(e, ClientPayloadError):
         return "The request to OpenAI failed. Please try again."
     elif isinstance(e, openai_errors.APIConnectionError):
-        return "The request failed. Please check your internet connection and try again."
+        return "The request failed. Please check your internet connection and try again. If this issue persists, you can use our API key for free by going to VS Code settings and changing the value of continue.OPENAI_API_KEY to \"\""
     elif isinstance(e, openai_errors.InvalidRequestError):
-        return 'Your API key does not have access to GPT-4. You can use ours for free by going to VS Code settings and changing the value of continue.OPENAI_API_KEY to ""'
+        return 'Invalid request sent to OpenAI. Please try again.'
+    elif e.__str__().startswith("Cannot connect to host"):
+        return "The request failed. Please check your internet connection and try again."
     return e.__str__() or e.__repr__()
 
 
@@ -44,8 +47,11 @@ class Autopilot(ContinueBaseModel):
     ide: AbstractIdeProtocolServer
     history: History = History.from_empty()
     context: Context = Context()
+    full_state: Union[FullState, None] = None
 
     _on_update_callbacks: List[Callable[[FullState], None]] = []
 
+    continue_sdk: ContinueSDK = None
+
     _active: bool = False
     _should_halt: bool = False
     _main_user_input_queue: List[str] = []
@@ -53,16 +59,25 @@ class Autopilot(ContinueBaseModel):
     _user_input_queue = AsyncSubscriptionQueue()
     _retry_queue = AsyncSubscriptionQueue()
 
-    @cached_property
-    def continue_sdk(self) -> ContinueSDK:
-        return ContinueSDK(self)
+    @classmethod
+    async def create(cls, policy: Policy, ide: AbstractIdeProtocolServer, full_state: FullState) -> "Autopilot":
+        autopilot = cls(ide=ide, policy=policy)
+        autopilot.continue_sdk = await ContinueSDK.create(autopilot)
+        return autopilot
 
     class Config:
         arbitrary_types_allowed = True
         keep_untouched = (cached_property,)
 
+    @root_validator(pre=True)
+    def fill_in_values(cls, values):
+        full_state: FullState = values.get('full_state')
+        if full_state is not None:
+            values['history'] = full_state.history
+        return values
+
     def get_full_state(self) -> FullState:
-        return FullState(
+        full_state = FullState(
             history=self.history,
             active=self._active,
             user_input_queue=self._main_user_input_queue,
@@ -71,6 +86,8 @@ class Autopilot(ContinueBaseModel):
             slash_commands=self.get_available_slash_commands(),
             adding_highlighted_code=self._adding_highlighted_code,
         )
+        self.full_state = full_state
+        return full_state
 
     def get_available_slash_commands(self) -> List[Dict]:
         custom_commands = list(map(lambda x: {
@@ -83,9 +100,14 @@ class Autopilot(ContinueBaseModel):
         self.continue_sdk.update_default_model(model)
 
     async def clear_history(self):
+        # Reset history
         self.history = History.from_empty()
         self._main_user_input_queue = []
        self._active = False
+
+        # Also remove all context
+        self._highlighted_ranges = []
+
         await self.update_subscribers()
 
     def on_update(self, callback: Coroutine["FullState", None, None]):
@@ -148,31 +170,46 @@ class Autopilot(ContinueBaseModel):
         if not any(map(lambda x: x.editing, self._highlighted_ranges)):
             self._highlighted_ranges[0].editing = True
 
-    async def handle_highlighted_code(self, range_in_files: List[RangeInFileWithContents]):
-
-        # If un-highlighting, then remove the range
-        if len(self._highlighted_ranges) == 1 and len(range_in_files) <= 1 and (len(range_in_files) == 0 or range_in_files[0].range.start == range_in_files[0].range.end) and not self._adding_highlighted_code:
-            self._highlighted_ranges = []
-            await self.update_subscribers()
-            return
+    def _disambiguate_highlighted_ranges(self):
+        """If any files have the same name, also display their folder name"""
+        name_status: Dict[str, set] = {
+        }  # basename -> set of full paths with that basename
+        for rif in self._highlighted_ranges:
+            basename = os.path.basename(rif.range.filepath)
+            if basename in name_status:
+                name_status[basename].add(rif.range.filepath)
+            else:
+                name_status[basename] = {rif.range.filepath}
 
-        # If not toggled to be adding context, only edit or add the first range
-        if not self._adding_highlighted_code and len(self._highlighted_ranges) > 0:
-            if len(range_in_files) == 0:
-                return
-            if range_in_files[0].range.overlaps_with(self._highlighted_ranges[0].range.range) and range_in_files[0].filepath == self._highlighted_ranges[0].range.filepath:
-                self._highlighted_ranges = [HighlightedRangeContext(
-                    range=range_in_files[0].range, editing=True, pinned=False)]
-                await self.update_subscribers()
-                return
+        for rif in self._highlighted_ranges:
+            basename = os.path.basename(rif.range.filepath)
+            if len(name_status[basename]) > 1:
+                rif.display_name = os.path.join(
+                    os.path.basename(os.path.dirname(rif.range.filepath)), basename)
+            else:
+                rif.display_name = basename
 
+    async def handle_highlighted_code(self, range_in_files: List[RangeInFileWithContents]):
         # Filter out rifs from ~/.continue/diffs folder
         range_in_files = [
             rif for rif in range_in_files if not os.path.dirname(rif.filepath) == os.path.expanduser("~/.continue/diffs")]
 
+        # Make sure all filepaths are relative to workspace
         workspace_path = self.continue_sdk.ide.workspace_directory
-        for rif in range_in_files:
-            rif.filepath = os.path.basename(rif.filepath)
+
+        # If not adding highlighted code
+        if not self._adding_highlighted_code:
+            if len(self._highlighted_ranges) == 1 and len(range_in_files) <= 1 and (len(range_in_files) == 0 or range_in_files[0].range.start == range_in_files[0].range.end):
+                # If un-highlighting the range to edit, then remove the range
+                self._highlighted_ranges = []
+                await self.update_subscribers()
+            elif len(range_in_files) > 0:
+                # Otherwise, replace the current range with the new one
+                # This is the first range to be highlighted
+                self._highlighted_ranges = [HighlightedRangeContext(
+                    range=range_in_files[0], editing=True, pinned=False, display_name=os.path.basename(range_in_files[0].filepath))]
+                await self.update_subscribers()
+            return
 
         # If current range overlaps with any others, delete them and only keep the new range
         new_ranges = []
@@ -193,10 +230,11 @@ class Autopilot(ContinueBaseModel):
             new_ranges.append(rif)
 
         self._highlighted_ranges = new_ranges + [HighlightedRangeContext(
-            range=rif, editing=False, pinned=False
+            range=rif, editing=False, pinned=False, display_name=os.path.basename(rif.filepath)
         ) for rif in range_in_files]
 
         self._make_sure_is_editing_range()
+        self._disambiguate_highlighted_ranges()
 
         await self.update_subscribers()
 
@@ -209,6 +247,8 @@ class Autopilot(ContinueBaseModel):
     async def delete_at_index(self, index: int):
         self.history.timeline[index].step.hide = True
         self.history.timeline[index].deleted = True
+        self.history.timeline[index].active = False
+
         await self.update_subscribers()
 
     async def delete_context_at_indices(self, indices: List[int]):
@@ -252,7 +292,7 @@ class Autopilot(ContinueBaseModel):
         #     i -= 1
 
         capture_event(self.continue_sdk.ide.unique_id, 'step run', {
-                      'step_name': step.name, 'params': step.dict()})
+            'step_name': step.name, 'params': step.dict()})
 
         if not is_future_step:
             # Check manual edits buffer, clear out if needed by creating a ManualEditStep
@@ -286,12 +326,13 @@ class Autopilot(ContinueBaseModel):
                 e.__class__, ContinueCustomException)
             error_string = e.message if is_continue_custom_exception else '\n'.join(
-                traceback.format_tb(e.__traceback__)) + f"\n\n{e.__repr__()}"
+                traceback.format_exception(e))
             error_title = e.title if is_continue_custom_exception else get_error_title(
                 e)
 
             # Attach an InternalErrorObservation to the step and unhide it.
-            print(f"Error while running step: \n{error_string}\n{error_title}")
+            print(
+                f"Error while running step: \n{error_string}\n{error_title}")
             capture_event(self.continue_sdk.ide.unique_id, 'step error', {
                           'error_message': error_string, 'error_title': error_title, 'step_name': step.name, 'params': step.dict()})
 
@@ -343,7 +384,8 @@ class Autopilot(ContinueBaseModel):
             # Update subscribers with new description
             await self.update_subscribers()
 
-        asyncio.create_task(update_description())
+        create_async_task(update_description(),
+                          self.continue_sdk.ide.unique_id)
 
         return observation
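The new `_disambiguate_highlighted_ranges` above is easy to check in isolation: when two highlighted files share a basename, the display name grows a parent-folder prefix. A minimal standalone sketch of the same logic, operating on plain path strings instead of the real `HighlightedRangeContext` objects:

```python
import os
from typing import Dict, List, Set


def disambiguate(paths: List[str]) -> Dict[str, str]:
    # basename -> set of full paths that share it
    name_status: Dict[str, Set[str]] = {}
    for p in paths:
        name_status.setdefault(os.path.basename(p), set()).add(p)

    display_names = {}
    for p in paths:
        basename = os.path.basename(p)
        if len(name_status[basename]) > 1:
            # Collision: prefix the immediate parent folder
            display_names[p] = os.path.join(
                os.path.basename(os.path.dirname(p)), basename)
        else:
            display_names[p] = basename
    return display_names


# {'a/utils.py': 'a/utils.py', 'b/utils.py': 'b/utils.py', 'c/main.py': 'main.py'}
print(disambiguate(["a/utils.py", "b/utils.py", "c/main.py"]))
```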
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 8c7ed2fd..54f15143 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -24,16 +24,22 @@ class OnTracebackSteps(BaseModel):
     params: Optional[Dict] = {}
 
 
+class AzureInfo(BaseModel):
+    endpoint: str
+    engine: str
+    api_version: str
+
+
 class ContinueConfig(BaseModel):
     """
     A pydantic class for the continue config file.
     """
     steps_on_startup: List[Step] = []
     disallowed_steps: Optional[List[str]] = []
-    server_url: Optional[str] = None
     allow_anonymous_telemetry: Optional[bool] = True
     default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k",
-                           "gpt-4"] = 'gpt-4'
+                           "gpt-4", "claude-2", "ggml"] = 'gpt-4'
+    temperature: Optional[float] = 0.5
     custom_commands: Optional[List[CustomCommand]] = [CustomCommand(
         name="test",
         description="This is an example custom command. Use /config to edit it and create more",
@@ -41,21 +47,16 @@ class ContinueConfig(BaseModel):
     )]
     slash_commands: Optional[List[SlashCommand]] = []
     on_traceback: Optional[List[OnTracebackSteps]] = []
+    system_message: Optional[str] = None
+    azure_openai_info: Optional[AzureInfo] = None
 
     # Want to force these to be the slash commands for now
     @validator('slash_commands', pre=True)
     def default_slash_commands_validator(cls, v):
-        from ..steps.core.core import UserInputStep
         from ..steps.open_config import OpenConfigStep
         from ..steps.clear_history import ClearHistoryStep
-        from ..steps.on_traceback import DefaultOnTracebackStep
-        from ..recipes.DeployPipelineAirflowRecipe.main import DeployPipelineAirflowRecipe
-        from ..recipes.DDtoBQRecipe.main import DDtoBQRecipe
-        from ..recipes.CreatePipelineRecipe.main import CreatePipelineRecipe
-        from ..recipes.AddTransformRecipe.main import AddTransformRecipe
         from ..steps.feedback import FeedbackStep
         from ..steps.comment_code import CommentCodeStep
-        from ..steps.chat import SimpleChatStep
         from ..steps.main import EditHighlightedCodeStep
 
         DEFAULT_SLASH_COMMANDS = [
@@ -88,6 +89,10 @@ class ContinueConfig(BaseModel):
 
         return DEFAULT_SLASH_COMMANDS + v
 
+    @validator('temperature', pre=True)
+    def temperature_validator(cls, v):
+        return max(0.0, min(1.0, v))
+
 
 def load_config(config_file: str) -> ContinueConfig:
     """
@@ -133,7 +138,7 @@ def load_global_config() -> ContinueConfig:
     config_path = os.path.join(global_dir, 'config.json')
     if not os.path.exists(config_path):
         with open(config_path, 'w') as f:
-            json.dump(ContinueConfig().dict(), f)
+            json.dump(ContinueConfig().dict(), f, indent=4)
     with open(config_path, 'r') as f:
         try:
             config_dict = json.load(f)
@@ -153,7 +158,7 @@ def update_global_config(config: ContinueConfig):
     yaml_path = os.path.join(global_dir, 'config.yaml')
     if os.path.exists(yaml_path):
         with open(config_path, 'w') as f:
-            yaml.dump(config.dict(), f)
+            yaml.dump(config.dict(), f, indent=4)
     else:
         config_path = os.path.join(global_dir, 'config.json')
         with open(config_path, 'w') as f:
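The `temperature` validator above clamps rather than rejects out-of-range values. A minimal pydantic v1 sketch of the same pattern (a stand-in model, not the real `ContinueConfig`):

```python
from typing import Optional

from pydantic import BaseModel, validator


class ExampleConfig(BaseModel):
    temperature: Optional[float] = 0.5

    @validator('temperature', pre=True)
    def temperature_validator(cls, v):
        # Clamp into [0.0, 1.0] instead of raising a ValidationError
        return max(0.0, min(1.0, v))


print(ExampleConfig(temperature=2.7).temperature)   # 1.0
print(ExampleConfig(temperature=-0.3).temperature)  # 0.0
```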
diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py
index 4ea17f20..50d01f8d 100644
--- a/continuedev/src/continuedev/core/main.py
+++ b/continuedev/src/continuedev/core/main.py
@@ -102,6 +102,7 @@ class HistoryNode(ContinueBaseModel):
     depth: int
     deleted: bool = False
     active: bool = True
+    logs: List[str] = []
 
     def to_chat_messages(self) -> List[ChatMessage]:
         if self.step.description is None or self.step.manage_own_chat_context:
@@ -205,6 +206,7 @@ class HighlightedRangeContext(ContinueBaseModel):
     range: RangeInFileWithContents
     editing: bool
     pinned: bool
+    display_name: str
 
 
 class FullState(ContinueBaseModel):
@@ -257,10 +259,8 @@ class Step(ContinueBaseModel):
 
     def dict(self, *args, **kwargs):
         d = super().dict(*args, **kwargs)
-        if self.description is not None:
-            d["description"] = self.description
-        else:
-            d["description"] = ""
+        # Make sure description is always a string
+        d["description"] = self.description or ""
         return d
 
     @validator("name", pre=True, always=True)
diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py
index 86e85e7e..53e482fa 100644
--- a/continuedev/src/continuedev/core/policy.py
+++ b/continuedev/src/continuedev/core/policy.py
@@ -1,22 +1,12 @@
 from textwrap import dedent
-from typing import List, Tuple, Type, Union
+from typing import Union
 
+from ..steps.chat import SimpleChatStep
 from ..steps.welcome import WelcomeStep
 from .config import ContinueConfig
-from ..steps.chroma import AnswerQuestionChroma, EditFileChroma, CreateCodebaseIndexChroma
 from ..steps.steps_on_startup import StepsOnStartupStep
-from ..recipes.CreatePipelineRecipe.main import CreatePipelineRecipe
-from ..recipes.DeployPipelineAirflowRecipe.main import DeployPipelineAirflowRecipe
-from ..recipes.AddTransformRecipe.main import AddTransformRecipe
-from .main import Step, Validator, History, Policy
-from .observation import Observation, TracebackObservation, UserInputObservation
-from ..steps.main import EditHighlightedCodeStep, SolveTracebackStep
-from ..recipes.WritePytestsRecipe.main import WritePytestsRecipe
-from ..recipes.ContinueRecipeRecipe.main import ContinueStepStep
-from ..steps.comment_code import CommentCodeStep
-from ..steps.react import NLDecisionStep
-from ..steps.chat import SimpleChatStep, ChatWithFunctions, EditFileStep, AddFileStep
-from ..recipes.DDtoBQRecipe.main import DDtoBQRecipe
+from .main import Step, History, Policy
+from .observation import UserInputObservation
 from ..steps.core.core import MessageStep
 from ..steps.custom_command import CustomCommandStep
 
@@ -61,12 +51,9 @@ class DemoPolicy(Policy):
         if history.get_current() is None:
             return (
                 MessageStep(name="Welcome to Continue", message=dedent("""\
-                    - Highlight code and ask a question or give instructions
-                    - Use `cmd+k` (Mac) / `ctrl+k` (Windows) to open Continue
-                    - Use `cmd+shift+e` / `ctrl+shift+e` to open file Explorer
-                    - Add your own OpenAI API key to VS Code Settings with `cmd+,`
-                    - Use slash commands when you want fine-grained control
-                    - Past steps are included as part of the context by default""")) >>
+                    - Highlight code section and ask a question or give instructions
+                    - Use `cmd+m` (Mac) / `ctrl+m` (Windows) to open Continue
+                    - Use `/help` to ask questions about how to use Continue""")) >>
                 WelcomeStep() >>
                 # SetupContinueWorkspaceStep() >>
                 # CreateCodebaseIndexChroma() >>
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 48d41f7c..4100efa6 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -1,6 +1,6 @@
 import asyncio
 from functools import cached_property
-from typing import Coroutine, Union
+from typing import Coroutine, Dict, Union
 import os
 
 from ..steps.core.core import DefaultModelEditCodeStep
@@ -11,9 +11,11 @@ from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFi
 from ..models.filesystem import RangeInFile
 from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
 from ..libs.llm.openai import OpenAI
+from ..libs.llm.anthropic import AnthropicLLM
+from ..libs.llm.ggml import GGML
 from .observation import Observation
 from ..server.ide_protocol import AbstractIdeProtocolServer
-from .main import Context, ContinueCustomException, History, Step, ChatMessage, ChatMessageRole
+from .main import Context, ContinueCustomException, History, HistoryNode, Step, ChatMessage
 from ..steps.core.core import *
 from ..libs.llm.proxy_server import ProxyServer
 
@@ -22,26 +24,78 @@ class Autopilot:
     pass
 
 
+ModelProvider = Literal["openai", "hf_inference_api", "ggml", "anthropic"]
+MODEL_PROVIDER_TO_ENV_VAR = {
+    "openai": "OPENAI_API_KEY",
+    "hf_inference_api": "HUGGING_FACE_TOKEN",
+    "anthropic": "ANTHROPIC_API_KEY",
+}
+
+
 class Models:
-    def __init__(self, sdk: "ContinueSDK"):
+    provider_keys: Dict[ModelProvider, str] = {}
+    model_providers: List[ModelProvider]
+    system_message: str
+
+    """
+    Better to have sdk.llm.stream_chat(messages, model="claude-2").
+    Then you also don't care that it's async.
+    And it's easier to add more models.
+    And intermediate shared code is easier to add.
+    And you can make constants like ContinueModels.GPT35 = "gpt-3.5-turbo"
+    PromptTransformer would be a good concept: You pass a prompt or list of messages and a model, then it outputs the prompt for that model.
+    Easy to reason about, can place anywhere.
+    And you can even pass a Prompt object to sdk.llm.stream_chat maybe, and it'll automatically be transformed for the given model.
+    This can all happen inside of Models?
+
+    class Prompt:
+        def __init__(self, ...info):
+            '''take whatever info is needed to describe the prompt'''
+
+        def to_string(self, model: str) -> str:
+            '''depending on the model, return the single prompt string'''
+    """
+
+    def __init__(self, sdk: "ContinueSDK", model_providers: List[ModelProvider]):
         self.sdk = sdk
+        self.model_providers = model_providers
+        self.system_message = sdk.config.system_message
+
+    @classmethod
+    async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models":
+        if sdk.config.default_model == "claude-2":
+            with_providers.append("anthropic")
+
+        models = Models(sdk, with_providers)
+        for provider in with_providers:
+            if provider in MODEL_PROVIDER_TO_ENV_VAR:
+                env_var = MODEL_PROVIDER_TO_ENV_VAR[provider]
+                models.provider_keys[provider] = await sdk.get_user_secret(
+                    env_var, f'Please add your {env_var} to the .env file')
+
+        return models
 
     def __load_openai_model(self, model: str) -> OpenAI:
-        async def load_openai_model():
-            api_key = await self.sdk.get_user_secret(
-                'OPENAI_API_KEY', 'Enter your OpenAI API key or press enter to try for free')
-            if api_key == "":
-                return ProxyServer(self.sdk.ide.unique_id, model)
-            return OpenAI(api_key=api_key, default_model=model)
-        return asyncio.get_event_loop().run_until_complete(load_openai_model())
+        api_key = self.provider_keys["openai"]
+        if api_key == "":
+            return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message, write_log=self.sdk.write_log)
+        return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, azure_info=self.sdk.config.azure_openai_info, write_log=self.sdk.write_log)
+
+    def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI:
+        api_key = self.provider_keys["hf_inference_api"]
+        return HuggingFaceInferenceAPI(api_key=api_key, model=model, system_message=self.system_message)
+
+    def __load_anthropic_model(self, model: str) -> AnthropicLLM:
+        api_key = self.provider_keys["anthropic"]
+        return AnthropicLLM(api_key, model, self.system_message)
+
+    @cached_property
+    def claude2(self):
+        return self.__load_anthropic_model("claude-2")
 
     @cached_property
     def starcoder(self):
-        async def load_starcoder():
-            api_key = await self.sdk.get_user_secret(
-                'HUGGING_FACE_TOKEN', 'Please add your Hugging Face token to the .env file')
-            return HuggingFaceInferenceAPI(api_key=api_key)
-        return asyncio.get_event_loop().run_until_complete(load_starcoder())
+        return self.__load_hf_inference_api_model("bigcode/starcoder")
 
     @cached_property
     def gpt35(self):
@@ -59,6 +113,10 @@ class Models:
     def gpt4(self):
         return self.__load_openai_model("gpt-4")
 
+    @cached_property
+    def ggml(self):
+        return GGML(system_message=self.system_message)
+
     def __model_from_name(self, model_name: str):
         if model_name == "starcoder":
             return self.starcoder
@@ -68,13 +126,17 @@ class Models:
             return self.gpt3516k
         elif model_name == "gpt-4":
             return self.gpt4
+        elif model_name == "claude-2":
+            return self.claude2
+        elif model_name == "ggml":
+            return self.ggml
         else:
             raise Exception(f"Unknown model {model_name}")
 
     @property
     def default(self):
         default_model = self.sdk.config.default_model
-        return self.__model_from_name(default_model) if default_model is not None else self.gpt35
+        return self.__model_from_name(default_model) if default_model is not None else self.gpt4
 
 
 class ContinueSDK(AbstractContinueSDK):
@@ -82,18 +144,44 @@ class ContinueSDK(AbstractContinueSDK):
     ide: AbstractIdeProtocolServer
     models: Models
     context: Context
+    config: ContinueConfig
     __autopilot: Autopilot
 
     def __init__(self, autopilot: Autopilot):
         self.ide = autopilot.ide
         self.__autopilot = autopilot
-        self.models = Models(self)
         self.context = autopilot.context
 
+    @classmethod
+    async def create(cls, autopilot: Autopilot) -> "ContinueSDK":
+        sdk = ContinueSDK(autopilot)
+
+        try:
+            config = sdk._load_config_dot_py()
+            sdk.config = config
+        except Exception as e:
+            print(e)
+            sdk.config = ContinueConfig()
+            msg_step = MessageStep(
+                name="Invalid Continue Config File", message=e.__repr__())
+            msg_step.description = e.__repr__()
+            sdk.history.add_node(HistoryNode(
+                step=msg_step,
+                observation=None,
+                depth=0,
+                active=False
+            ))
+
+        sdk.models = await Models.create(sdk)
+        return sdk
+
     @property
     def history(self) -> History:
         return self.__autopilot.history
 
+    def write_log(self, message: str):
+        self.history.timeline[self.history.current_index].logs.append(message)
+
     async def _ensure_absolute_path(self, path: str) -> str:
         if os.path.isabs(path):
             return path
@@ -168,8 +256,9 @@ class ContinueSDK(AbstractContinueSDK):
 
     _last_valid_config: ContinueConfig = None
 
-    def load_config_dot_py(self, path: str) -> ContinueConfig:
+    def _load_config_dot_py(self) -> ContinueConfig:
         # Use importlib to load the config file config.py at the given path
+        path = os.path.join(os.path.expanduser("~"), ".continue", "config.py")
         try:
             import importlib.util
             spec = importlib.util.spec_from_file_location("config", path)
@@ -181,22 +270,10 @@ class ContinueSDK(AbstractContinueSDK):
             print("Error loading config.py: ", e)
             return ContinueConfig() if self._last_valid_config is None else self._last_valid_config
 
-    @property
-    def config(self) -> ContinueConfig:
-        # TODO: Workspace config files should override global
-        dir = self.ide.workspace_directory
-        path = os.path.join(dir, '.continue', 'config.py')
-        if not os.path.exists(path):
-            global_dir = os.path.expanduser('~/.continue')
-            if not os.path.exists(global_dir):
-                os.mkdir(global_dir)
-            path = os.path.join(global_dir, 'config.py')
-            if not os.path.exists(path):
-                # Need to copy over the default config
-                return ContinueConfig()
-
-        config = self.load_config_dot_py(path)
-        return config
+    def get_code_context(self, only_editing: bool = False) -> List[RangeInFileWithContents]:
+        context = list(filter(lambda x: x.editing, self.__autopilot._highlighted_ranges)
+                       ) if only_editing else self.__autopilot._highlighted_ranges
+        return [c.range for c in context]
 
     def update_default_model(self, model: str):
         config = self.config
@@ -217,18 +294,18 @@ class ContinueSDK(AbstractContinueSDK):
 
         preface = "The following code is highlighted"
 
+        # If no highlighted ranges, use first file as context
         if len(highlighted_code) == 0:
             preface = "The following file is open"
-            # Get the full contents of all open files
-            files = await self.ide.getOpenFiles()
-            if len(files) > 0:
-                content = await self.ide.readFile(files[0])
+            visible_files = await self.ide.getVisibleFiles()
+            if len(visible_files) > 0:
+                content = await self.ide.readFile(visible_files[0])
                 highlighted_code = [
-                    RangeInFileWithContents.from_entire_file(files[0], content)]
+                    RangeInFileWithContents.from_entire_file(visible_files[0], content)]
 
         for rif in highlighted_code:
             msg = ChatMessage(content=f"{preface} ({rif.filepath}):\n```\n{rif.contents}\n```",
-                              role="system", summary=f"{preface}: {rif.filepath}")
+                              role="user", summary=f"{preface}: {rif.filepath}")
 
             # Don't insert after latest user message or function call
             i = -1
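`Autopilot.create` and `ContinueSDK.create` both follow the async-factory pattern this commit moves to: `__init__` cannot await, so anything needing I/O (loading config.py, fetching API keys) happens in an async classmethod that returns the finished object. A generic sketch of the pattern, independent of the Continue classes:

```python
import asyncio


async def load_config() -> dict:
    await asyncio.sleep(0)  # stand-in for real async I/O (disk, network, secrets)
    return {"default_model": "gpt-4"}


class Service:
    def __init__(self, config: dict):
        # __init__ stays synchronous and cheap
        self.config = config

    @classmethod
    async def create(cls) -> "Service":
        # Awaitable setup lives here; callers write `await Service.create()`
        config = await load_config()
        return cls(config)


async def main():
    service = await Service.create()
    print(service.config)


asyncio.run(main())
```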
diff --git a/continuedev/src/continuedev/libs/constants/main.py b/continuedev/src/continuedev/libs/constants/main.py
new file mode 100644
index 00000000..96eb6e69
--- /dev/null
+++ b/continuedev/src/continuedev/libs/constants/main.py
@@ -0,0 +1,6 @@
+## PATHS ##
+
+CONTINUE_GLOBAL_FOLDER = ".continue"
+CONTINUE_SESSIONS_FOLDER = "sessions"
+CONTINUE_SERVER_FOLDER = "server"
+
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
new file mode 100644
index 00000000..c82895c6
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -0,0 +1,97 @@
+
+from functools import cached_property
+import time
+from typing import Any, Coroutine, Dict, Generator, List, Union
+from ...core.main import ChatMessage
+from anthropic import HUMAN_PROMPT, AI_PROMPT, AsyncAnthropic
+from ..llm import LLM
+from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top
+
+
+class AnthropicLLM(LLM):
+    api_key: str
+    default_model: str
+    async_client: AsyncAnthropic
+
+    def __init__(self, api_key: str, default_model: str, system_message: str = None):
+        self.api_key = api_key
+        self.default_model = default_model
+        self.system_message = system_message
+
+        self.async_client = AsyncAnthropic(api_key=api_key)
+
+    @cached_property
+    def name(self):
+        return self.default_model
+
+    @property
+    def default_args(self):
+        return {**DEFAULT_ARGS, "model": self.default_model}
+
+    def _transform_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
+        args = args.copy()
+        if "max_tokens" in args:
+            args["max_tokens_to_sample"] = args["max_tokens"]
+            del args["max_tokens"]
+        if "frequency_penalty" in args:
+            del args["frequency_penalty"]
+        if "presence_penalty" in args:
+            del args["presence_penalty"]
+        return args
+
+    def count_tokens(self, text: str):
+        return count_tokens(self.default_model, text)
+
+    def __messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
+        prompt = ""
+
+        # Anthropic prompt must start with a Human turn
+        if len(messages) > 0 and messages[0]["role"] != "user" and messages[0]["role"] != "system":
+            prompt += f"{HUMAN_PROMPT} Hello."
+        for msg in messages:
+            prompt += f"{HUMAN_PROMPT if (msg['role'] == 'user' or msg['role'] == 'system') else AI_PROMPT} {msg['content']} "
+
+        prompt += AI_PROMPT
+        return prompt
+
+    async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        args = self.default_args.copy()
+        args.update(kwargs)
+        args["stream"] = True
+        args = self._transform_args(args)
+
+        async for chunk in await self.async_client.completions.create(
+            prompt=f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}",
+            **args
+        ):
+            yield chunk.completion
+
+    async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        args = self.default_args.copy()
+        args.update(kwargs)
+        args["stream"] = True
+        args = self._transform_args(args)
+
+        messages = compile_chat_messages(
+            args["model"], messages, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message)
+        async for chunk in await self.async_client.completions.create(
+            prompt=self.__messages_to_prompt(messages),
+            **args
+        ):
+            yield {
+                "role": "assistant",
+                "content": chunk.completion
+            }
+
+    async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
+        args = {**self.default_args, **kwargs}
+        args = self._transform_args(args)
+
+        messages = compile_chat_messages(
+            args["model"], with_history, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message)
+        resp = (await self.async_client.completions.create(
+            prompt=self.__messages_to_prompt(messages),
+            **args
+        )).completion
+
+        return resp
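`__messages_to_prompt` above flattens chat messages into Anthropic's alternating Human/Assistant transcript. A standalone sketch of the conversion (the two markers are hardcoded to the values the `anthropic` package exports, so this runs without the dependency):

```python
# The anthropic package defines HUMAN_PROMPT = "\n\nHuman:" and
# AI_PROMPT = "\n\nAssistant:"; hardcoded here for a dependency-free demo.
HUMAN_PROMPT = "\n\nHuman:"
AI_PROMPT = "\n\nAssistant:"


def messages_to_prompt(messages) -> str:
    prompt = ""
    # The prompt must open with a Human turn
    if messages and messages[0]["role"] not in ("user", "system"):
        prompt += f"{HUMAN_PROMPT} Hello."
    for msg in messages:
        speaker = HUMAN_PROMPT if msg["role"] in ("user", "system") else AI_PROMPT
        prompt += f"{speaker} {msg['content']} "
    # End with an open Assistant turn for the model to complete
    return prompt + AI_PROMPT


print(messages_to_prompt([
    {"role": "user", "content": "What is 2+2?"},
    {"role": "assistant", "content": "4"},
    {"role": "user", "content": "And 3+3?"},
]))
```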
args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) + args["stream"] = True + + async with aiohttp.ClientSession() as session: + async with session.post(f"{SERVER_URL}/v1/chat/completions", json={ + "messages": messages, + **args + }) as resp: + # This is streaming application/json instaed of text/event-stream + async for line in resp.content.iter_chunks(): + if line[1]: + try: + json_chunk = line[0].decode("utf-8") + if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"): + continue + chunks = json_chunk.split("\n") + for chunk in chunks: + if chunk.strip() != "": + yield json.loads(chunk[6:])["choices"][0]["delta"] + except: + raise Exception(str(line[0])) + + async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]: + args = {**self.default_args, **kwargs} + + async with aiohttp.ClientSession() as session: + async with session.post(f"{SERVER_URL}/v1/completions", json={ + "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message), + **args + }) as resp: + try: + return await resp.text() + except: + raise Exception(await resp.text()) diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py index 1586c620..7e11fbbe 100644 --- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py +++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py @@ -9,7 +9,12 @@ DEFAULT_MAX_TIME = 120. class HuggingFaceInferenceAPI(LLM): api_key: str - model: str = "bigcode/starcoder" + model: str + + def __init__(self, api_key: str, model: str, system_message: str = None): + self.api_key = api_key + self.model = model + self.system_message = system_message # TODO: Nothing being done with this def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs): """Return the completion of the text with the given temperature.""" diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index c4e4139f..64bb39a2 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -1,30 +1,43 @@ from functools import cached_property -import time -from typing import Any, Coroutine, Dict, Generator, List, Union +import json +from typing import Any, Callable, Coroutine, Dict, Generator, List, Union + from ...core.main import ChatMessage import openai from ..llm import LLM -from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top +from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top +from ...core.config import AzureInfo class OpenAI(LLM): api_key: str default_model: str - def __init__(self, api_key: str, default_model: str, system_message: str = None): + def __init__(self, api_key: str, default_model: str, system_message: str = None, azure_info: AzureInfo = None, write_log: Callable[[str], None] = None): self.api_key = api_key self.default_model = default_model self.system_message = system_message + self.azure_info = azure_info + self.write_log = write_log openai.api_key = api_key + # Using an Azure OpenAI deployment + if azure_info is not None: + openai.api_type = "azure" + openai.api_base = azure_info.endpoint + openai.api_version = 
azure_info.api_version + @cached_property def name(self): return self.default_model @property def default_args(self): - return {**DEFAULT_ARGS, "model": self.default_model} + args = {**DEFAULT_ARGS, "model": self.default_model} + if self.azure_info is not None: + args["engine"] = self.azure_info.engine + return args def count_tokens(self, text: str): return count_tokens(self.default_model, text) @@ -35,18 +48,29 @@ class OpenAI(LLM): args["stream"] = True if args["model"] in CHAT_MODELS: + messages = compile_chat_messages( + args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message) + self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") + completion = "" async for chunk in await openai.ChatCompletion.acreate( - messages=compile_chat_messages( - args["model"], with_history, prompt, functions=None), + messages=messages, **args, ): if "content" in chunk.choices[0].delta: yield chunk.choices[0].delta.content + completion += chunk.choices[0].delta.content else: continue + + self.write_log(f"Completion: \n\n{completion}") else: + self.write_log(f"Prompt:\n\n{prompt}") + completion = "" async for chunk in await openai.Completion.acreate(prompt=prompt, **args): yield chunk.choices[0].text + completion += chunk.choices[0].text + + self.write_log(f"Completion:\n\n{completion}") async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]: args = self.default_args.copy() @@ -56,27 +80,39 @@ class OpenAI(LLM): if not args["model"].endswith("0613") and "functions" in args: del args["functions"] + messages = compile_chat_messages( + args["model"], messages, args["max_tokens"], functions=args.get("functions", None), system_message=self.system_message) + self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") + completion = "" async for chunk in await openai.ChatCompletion.acreate( - messages=compile_chat_messages( - args["model"], messages, functions=args.get("functions", None)), + messages=messages, **args, ): yield chunk.choices[0].delta + if "content" in chunk.choices[0].delta: + completion += chunk.choices[0].delta.content + self.write_log(f"Completion: \n\n{completion}") async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]: args = {**self.default_args, **kwargs} if args["model"] in CHAT_MODELS: + messages = compile_chat_messages( + args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message) + self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") resp = (await openai.ChatCompletion.acreate( - messages=compile_chat_messages( - args["model"], with_history, prompt, functions=None), + messages=messages, **args, )).choices[0].message.content + self.write_log(f"Completion: \n\n{resp}") else: + prompt = prune_raw_prompt_from_top( + args["model"], prompt, args["max_tokens"]) + self.write_log(f"Prompt:\n\n{prompt}") resp = (await openai.Completion.acreate( - prompt=prune_raw_prompt_from_top( - args["model"], prompt, args["max_tokens"]), + prompt=prompt, **args, )).choices[0].text + self.write_log(f"Completion:\n\n{resp}") return resp diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py index 05ece394..bd50fe02 100644 --- a/continuedev/src/continuedev/libs/llm/proxy_server.py +++ b/continuedev/src/continuedev/libs/llm/proxy_server.py @@ -1,10 +1,13 @@ + from functools import cached_property import 
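For context on the `AzureInfo` wiring above: the pre-1.0 `openai` package that this code targets addresses Azure deployments through module-level settings plus an `engine` argument, which is why `default_args` injects `engine` when `azure_info` is set. A sketch with hypothetical endpoint, version, and deployment names:

```python
import openai

# Pre-1.0 openai-python Azure configuration (module-level, as in this diff).
# All three values below are hypothetical placeholders.
openai.api_type = "azure"
openai.api_base = "https://example-resource.openai.azure.com"
openai.api_version = "2023-05-15"
openai.api_key = "<your-azure-openai-key>"

# Azure routes by deployment name via `engine`; `model` alone is not enough.
completion = openai.ChatCompletion.create(
    engine="my-gpt4-deployment",  # hypothetical deployment name
    messages=[{"role": "user", "content": "Hello"}],
)
print(completion.choices[0].message.content)
```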
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index 05ece394..bd50fe02 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -1,10 +1,13 @@
+
 from functools import cached_property
 import json
-from typing import Any, Coroutine, Dict, Generator, List, Literal, Union
+import traceback
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union
 import aiohttp
+from ..util.telemetry import capture_event
 from ...core.main import ChatMessage
 from ..llm import LLM
-from ..util.count_tokens import DEFAULT_ARGS, DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, count_tokens
+from ..util.count_tokens import DEFAULT_ARGS, DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, count_tokens, format_chat_messages
 import certifi
 import ssl
@@ -19,12 +22,14 @@ class ProxyServer(LLM):
     unique_id: str
     name: str
     default_model: Literal["gpt-3.5-turbo", "gpt-4"]
+    write_log: Callable[[str], None]
 
-    def __init__(self, unique_id: str, default_model: Literal["gpt-3.5-turbo", "gpt-4"], system_message: str = None):
+    def __init__(self, unique_id: str, default_model: Literal["gpt-3.5-turbo", "gpt-4"], system_message: str = None, write_log: Callable[[str], None] = None):
         self.unique_id = unique_id
         self.default_model = default_model
         self.system_message = system_message
         self.name = default_model
+        self.write_log = write_log
 
     @property
     def default_args(self):
@@ -36,21 +41,27 @@ class ProxyServer(LLM):
 
     async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
         args = {**self.default_args, **kwargs}
+        messages = compile_chat_messages(
+            args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+        self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
 
         async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
             async with session.post(f"{SERVER_URL}/complete", json={
-                "messages": compile_chat_messages(args["model"], with_history, prompt, functions=None),
+                "messages": messages,
                 "unique_id": self.unique_id,
                 **args
             }) as resp:
                 try:
-                    return await resp.text()
+                    response_text = await resp.text()
+                    self.write_log(f"Completion: \n\n{response_text}")
+                    return response_text
                 except:
                     raise Exception(await resp.text())
 
     async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            self.default_model, messages, None, functions=args.get("functions", None))
+            args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+        self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
 
         async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
             async with session.post(f"{SERVER_URL}/stream_chat", json={
@@ -59,6 +70,7 @@ class ProxyServer(LLM):
                 **args
             }) as resp:
                 # This is streaming application/json instead of text/event-stream
+                completion = ""
                 async for line in resp.content.iter_chunks():
                     if line[1]:
                         try:
@@ -67,14 +79,21 @@ class ProxyServer(LLM):
                             chunks = json_chunk.split("\n")
                             for chunk in chunks:
                                 if chunk.strip() != "":
-                                    yield json.loads(chunk)
-                        except:
-                            raise Exception(str(line[0]))
+                                    loaded_chunk = json.loads(chunk)
+                                    yield loaded_chunk
+                                    if "content" in loaded_chunk:
+                                        completion += loaded_chunk["content"]
+                        except Exception as e:
+                            capture_event(self.unique_id, "proxy_server_parse_error", {
+                                "error_title": "Proxy server stream_chat parsing failed", "error_message": '\n'.join(traceback.format_exception(e))})
+
+        self.write_log(f"Completion: \n\n{completion}")
 
     async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            self.default_model, with_history, prompt, functions=args.get("functions", None))
+            self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
+        self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
 
         async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
             async with session.post(f"{SERVER_URL}/stream_complete", json={
@@ -82,9 +101,13 @@ class ProxyServer(LLM):
                 "unique_id": self.unique_id,
                 **args
            }) as resp:
+                completion = ""
                 async for line in resp.content.iter_any():
                     if line:
                         try:
-                            yield line.decode("utf-8")
+                            decoded_line = line.decode("utf-8")
+                            yield decoded_line
+                            completion += decoded_line
                         except:
                             raise Exception(str(line))
+                self.write_log(f"Completion: \n\n{completion}")
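A pattern worth noting across `OpenAI`, `ProxyServer`, and `AnthropicLLM` in this commit: an injected `write_log` callback records the exact prompt before the request and the accumulated completion after it, which is what the new `show_logs_at_index` GUI message later surfaces. A reduced sketch of the idea:

```python
import asyncio
from typing import Callable, List


class LoggingLLM:
    def __init__(self, write_log: Callable[[str], None] = None):
        # The SDK injects a callback that appends to the current
        # HistoryNode's `logs`; default to a no-op for standalone use.
        self.write_log = write_log or (lambda msg: None)

    async def complete(self, prompt: str) -> str:
        self.write_log(f"Prompt:\n\n{prompt}")
        completion = await self._call_model(prompt)  # hypothetical stand-in
        self.write_log(f"Completion:\n\n{completion}")
        return completion

    async def _call_model(self, prompt: str) -> str:
        return "ok"  # placeholder for the real API call


logs: List[str] = []
asyncio.run(LoggingLLM(logs.append).complete("hi"))
print(logs)  # ['Prompt:\n\nhi', 'Completion:\n\nok']
```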
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index 8b06fef9..987aa722 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -1,15 +1,21 @@
 import json
 from typing import Dict, List, Union
 from ...core.main import ChatMessage
+from .templating import render_templated_string
 import tiktoken
 
-aliases = {}
+aliases = {
+    "ggml": "gpt-3.5-turbo",
+    "claude-2": "gpt-3.5-turbo",
+}
 DEFAULT_MAX_TOKENS = 2048
 MAX_TOKENS_FOR_MODEL = {
     "gpt-3.5-turbo": 4096,
     "gpt-3.5-turbo-0613": 4096,
     "gpt-3.5-turbo-16k": 16384,
-    "gpt-4": 8192
+    "gpt-4": 8192,
+    "ggml": 2048,
+    "claude-2": 100000
 }
 CHAT_MODELS = {
     "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"
@@ -40,9 +46,17 @@ def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: in
     return encoding.decode(tokens[-max_tokens:])
 
 
+def count_chat_message_tokens(model: str, chat_message: ChatMessage) -> int:
+    # Doing simpler, safer version of what is here:
+    # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+    # every message follows <|start|>{role/name}\n{content}<|end|>\n
+    TOKENS_PER_MESSAGE = 4
+    return count_tokens(model, chat_message.content) + TOKENS_PER_MESSAGE
+
+
 def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int):
     total_tokens = tokens_for_completion + \
-        sum(count_tokens(model, message.content)
+        sum(count_chat_message_tokens(model, message)
             for message in chat_history)
 
     # 1. Replace beyond last 5 messages with summary
@@ -68,34 +82,64 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens:
             message.content = message.summary
         i += 1
 
-    # 4. Remove entire messages in the last 5
-    while total_tokens > max_tokens and len(chat_history) > 0:
+    # 4. Remove entire messages in the last 5, except last 1
+    while total_tokens > max_tokens and len(chat_history) > 1:
         message = chat_history.pop(0)
         total_tokens -= count_tokens(model, message.content)
 
+    # 5. Truncate last message
+    if total_tokens > max_tokens and len(chat_history) > 0:
+        message = chat_history[0]
+        message.content = prune_raw_prompt_from_top(
+            model, message.content, tokens_for_completion)
+        total_tokens = max_tokens
+
     return chat_history
 
 
-def compile_chat_messages(model: str, msgs: List[ChatMessage], prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:
-    prompt_tokens = count_tokens(model, prompt)
+# In case we've missed weird edge cases
+TOKEN_BUFFER_FOR_SAFETY = 100
+
+
+def compile_chat_messages(model: str, msgs: List[ChatMessage], max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:
+    """
+    The total number of tokens is system_message + sum(msgs) + functions + prompt after it is converted to a message
+    """
+    if prompt is not None:
+        prompt_msg = ChatMessage(role="user", content=prompt, summary=prompt)
+        msgs += [prompt_msg]
+
+    if system_message is not None:
+        # NOTE: System message takes second precedence to user prompt, so it is placed just before
+        # but move back to start after processing
+        rendered_system_message = render_templated_string(system_message)
+        system_chat_msg = ChatMessage(
+            role="system", content=rendered_system_message, summary=rendered_system_message)
+        # insert at second-to-last position
+        msgs.insert(-1, system_chat_msg)
+
+    # Add tokens from functions
+    function_tokens = 0
     if functions is not None:
         for function in functions:
-            prompt_tokens += count_tokens(model, json.dumps(function))
-
-    msgs = prune_chat_history(model,
-                              msgs, MAX_TOKENS_FOR_MODEL[model], prompt_tokens + DEFAULT_MAX_TOKENS + count_tokens(model, system_message))
-    history = []
-    if system_message:
-        history.append({
-            "role": "system",
-            "content": system_message
-        })
-    history += [msg.to_dict(with_functions=functions is not None)
-                for msg in msgs]
-    if prompt:
-        history.append({
-            "role": "user",
-            "content": prompt
-        })
+            function_tokens += count_tokens(model, json.dumps(function))
+
+    msgs = prune_chat_history(
+        model, msgs, MAX_TOKENS_FOR_MODEL[model], function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY)
+
+    history = [msg.to_dict(with_functions=functions is not None)
+               for msg in msgs]
+
+    # Move system message back to start
+    if system_message is not None and len(history) >= 2 and history[-2]["role"] == "system":
+        system_message_dict = history.pop(-2)
+        history.insert(0, system_message_dict)
 
     return history
+
+
+def format_chat_messages(messages: List[ChatMessage]) -> str:
+    formatted = ""
+    for msg in messages:
+        formatted += f"<{msg['role'].capitalize()}>\n{msg['content']}\n\n"
+    return formatted
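The docstring on `compile_chat_messages` defines the budget as system message + messages + functions + the prompt converted to a message. Back-of-the-envelope arithmetic for what `prune_chat_history` is handed, using the constants above:

```python
# Rough sketch of the token budget implied by compile_chat_messages.
MAX_TOKENS_FOR_MODEL = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "claude-2": 100000}
TOKEN_BUFFER_FOR_SAFETY = 100


def history_budget(model: str, max_tokens_for_completion: int, function_tokens: int) -> int:
    # Tokens left for system message + chat history + prompt-as-message,
    # after reserving the completion, function schemas, and the safety buffer.
    return (MAX_TOKENS_FOR_MODEL[model]
            - max_tokens_for_completion
            - function_tokens
            - TOKEN_BUFFER_FOR_SAFETY)


# gpt-4 with a 2048-token completion reserve and no functions:
print(history_budget("gpt-4", 2048, 0))  # 6044
```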
"None", "async_task_error", { + "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e)) + }) + + task.add_done_callback(callback) + return task diff --git a/continuedev/src/continuedev/libs/util/errors.py b/continuedev/src/continuedev/libs/util/errors.py new file mode 100644 index 00000000..46074cfc --- /dev/null +++ b/continuedev/src/continuedev/libs/util/errors.py @@ -0,0 +1,2 @@ +class SessionNotFound(Exception): + pass diff --git a/continuedev/src/continuedev/libs/util/paths.py b/continuedev/src/continuedev/libs/util/paths.py new file mode 100644 index 00000000..fddef887 --- /dev/null +++ b/continuedev/src/continuedev/libs/util/paths.py @@ -0,0 +1,17 @@ +import os + +from ..constants.main import CONTINUE_SESSIONS_FOLDER, CONTINUE_GLOBAL_FOLDER, CONTINUE_SERVER_FOLDER + +def getGlobalFolderPath(): + return os.path.join(os.path.expanduser("~"), CONTINUE_GLOBAL_FOLDER) + + + +def getSessionsFolderPath(): + return os.path.join(getGlobalFolderPath(), CONTINUE_SESSIONS_FOLDER) + +def getServerFolderPath(): + return os.path.join(getGlobalFolderPath(), CONTINUE_SERVER_FOLDER) + +def getSessionFilePath(session_id: str): + return os.path.join(getSessionsFolderPath(), f"{session_id}.json")
diff --git a/continuedev/src/continuedev/libs/util/dedent.py b/continuedev/src/continuedev/libs/util/strings.py
index e59c2e97..f1fb8d0b 100644
--- a/continuedev/src/continuedev/libs/util/dedent.py
+++ b/continuedev/src/continuedev/libs/util/strings.py
@@ -23,3 +23,27 @@ def dedent_and_get_common_whitespace(s: str) -> Tuple[str, str]:
             break
 
     return "\n".join(map(lambda x: x.lstrip(lcp), lines)), lcp
+
+
+def remove_quotes_and_escapes(output: str) -> str:
+    """
+    Clean up the output of the completion API, removing unnecessary escapes and quotes
+    """
+    output = output.strip()
+
+    # Replace smart quotes
+    output = output.replace("“", '"')
+    output = output.replace("”", '"')
+    output = output.replace("‘", "'")
+    output = output.replace("’", "'")
+
+    # Remove escapes
+    output = output.replace('\\"', '"')
+    output = output.replace("\\'", "'")
+    output = output.replace("\\n", "\n")
+    output = output.replace("\\t", "\t")
+    output = output.replace("\\\\", "\\")
+    if (output.startswith('"') and output.endswith('"')) or (output.startswith("'") and output.endswith("'")):
+        output = output[1:-1]
+
+    return output
diff --git a/continuedev/src/continuedev/libs/util/templating.py b/continuedev/src/continuedev/libs/util/templating.py
new file mode 100644
index 00000000..bb922ad7
--- /dev/null
+++ b/continuedev/src/continuedev/libs/util/templating.py
@@ -0,0 +1,39 @@
+import os
+import chevron
+
+
+def get_vars_in_template(template):
+    """
+    Get the variables in a template
+    """
+    return [token[1] for token in chevron.tokenizer.tokenize(template) if token[0] == 'variable']
+
+
+def escape_var(var: str) -> str:
+    """
+    Escape a variable so it can be used in a template
+    """
+    return var.replace(os.path.sep, '').replace('.', '')
+
+
+def render_templated_string(template: str) -> str:
+    """
+    Render system message or other templated string with mustache syntax.
+    Right now it only supports rendering absolute file paths as their contents.
+    """
+    vars = get_vars_in_template(template)
+
+    args = {}
+    for var in vars:
+        if var.startswith(os.path.sep):
+            # Escape vars which are filenames, because mustache doesn't allow / in variable names
+            escaped_var = escape_var(var)
+            template = template.replace(
+                var, escaped_var)
+
+            if os.path.exists(var):
+                args[escaped_var] = open(var, 'r').read()
+            else:
+                args[escaped_var] = ''
+
+    return chevron.render(template, args)
diff --git a/continuedev/src/continuedev/recipes/TemplateRecipe/main.py b/continuedev/src/continuedev/recipes/TemplateRecipe/main.py
index 94675725..16132cfd 100644
--- a/continuedev/src/continuedev/recipes/TemplateRecipe/main.py
+++ b/continuedev/src/continuedev/recipes/TemplateRecipe/main.py
@@ -20,8 +20,8 @@ class TemplateRecipe(Step):
 
     # The code executed when the recipe is run
     async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
-        open_files = await sdk.ide.getOpenFiles()
+        visible_files = await sdk.ide.getVisibleFiles()
         await sdk.edit_file(
-            filename=open_files[0],
+            filename=visible_files[0],
             prompt=f"Append a statement to print `Hello, {self.name}!` at the end of the file."
         )
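To make the templating behaviour concrete: a `system_message` can embed an absolute file path in mustache braces, and `render_templated_string` inlines that file's contents (or an empty string if the path doesn't exist). A usage sketch with a hypothetical path:

```python
# Hypothetical usage of the new templating module; the path is made up,
# so on most machines it renders as an empty string rather than raising.
from continuedev.src.continuedev.libs.util.templating import render_templated_string

system_message = "Always follow the conventions in {{/home/me/project/STYLE.md}}"
print(render_templated_string(system_message))
```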
diff --git a/continuedev/src/continuedev/recipes/WritePytestsRecipe/main.py b/continuedev/src/continuedev/recipes/WritePytestsRecipe/main.py
index 6e1244b3..c7a65fa6 100644
--- a/continuedev/src/continuedev/recipes/WritePytestsRecipe/main.py
+++ b/continuedev/src/continuedev/recipes/WritePytestsRecipe/main.py
@@ -14,7 +14,7 @@ class WritePytestsRecipe(Step):
 
     async def run(self, sdk: ContinueSDK):
         if self.for_filepath is None:
-            self.for_filepath = (await sdk.ide.getOpenFiles())[0]
+            self.for_filepath = (await sdk.ide.getVisibleFiles())[0]
 
         filename = os.path.basename(self.for_filepath)
         dirname = os.path.dirname(self.for_filepath)
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 8e9b1fb9..ae57c0b6 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -1,15 +1,17 @@
+import asyncio
 import json
 from fastapi import Depends, Header, WebSocket, APIRouter
+from starlette.websockets import WebSocketState, WebSocketDisconnect
 from typing import Any, List, Type, TypeVar, Union
 from pydantic import BaseModel
+import traceback
 from uvicorn.main import Server
 
 from .session_manager import SessionManager, session_manager, Session
 from .gui_protocol import AbstractGUIProtocolServer
 from ..libs.util.queue import AsyncSubscriptionQueue
-import asyncio
-import nest_asyncio
-nest_asyncio.apply()
+from ..libs.util.telemetry import capture_event
+from ..libs.util.create_async_task import create_async_task
 
 router = APIRouter(prefix="/gui", tags=["gui"])
 
@@ -30,12 +32,12 @@ class AppStatus:
 Server.handle_exit = AppStatus.handle_exit
 
 
-def session(x_continue_session_id: str = Header("anonymous")) -> Session:
-    return session_manager.get_session(x_continue_session_id)
+async def session(x_continue_session_id: str = Header("anonymous")) -> Session:
+    return await session_manager.get_session(x_continue_session_id)
 
 
-def websocket_session(session_id: str) -> Session:
-    return session_manager.get_session(session_id)
+async def websocket_session(session_id: str) -> Session:
+    return await session_manager.get_session(session_id)
 
 
 T = TypeVar("T", bound=BaseModel)
@@ -52,13 +54,19 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
         self.session = session
 
     async def _send_json(self, message_type: str, data: Any):
+        if self.websocket.application_state == WebSocketState.DISCONNECTED:
+            return
         await self.websocket.send_json({
             "messageType": message_type,
             "data": data
         })
 
-    async def _receive_json(self, message_type: str) -> Any:
-        return await self.sub_queue.get(message_type)
+    async def _receive_json(self, message_type: str, timeout: int = 5) -> Any:
+        try:
+            return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout)
+        except asyncio.TimeoutError:
+            raise Exception(
+                "GUI Protocol _receive_json timed out after 5 seconds")
 
     async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T:
         await self._send_json(message_type, data)
@@ -91,6 +99,8 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
                 self.on_set_editing_at_indices(data["indices"])
             elif message_type == "set_pinned_at_indices":
                 self.on_set_pinned_at_indices(data["indices"])
+            elif message_type == "show_logs_at_index":
+                self.on_show_logs_at_index(data["index"])
         except Exception as e:
             print(e)
 
@@ -102,53 +112,69 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
 
     def on_main_input(self, input: str):
         # Do something with user input
-        asyncio.create_task(self.session.autopilot.accept_user_input(input))
+        create_async_task(self.session.autopilot.accept_user_input(
+            input), self.session.autopilot.continue_sdk.ide.unique_id)
 
     def on_reverse_to_index(self, index: int):
         # Reverse the history to the given index
-        asyncio.create_task(self.session.autopilot.reverse_to_index(index))
+        create_async_task(self.session.autopilot.reverse_to_index(
+            index), self.session.autopilot.continue_sdk.ide.unique_id)
 
     def on_step_user_input(self, input: str, index: int):
-        asyncio.create_task(
-            self.session.autopilot.give_user_input(input, index))
+        create_async_task(
+            self.session.autopilot.give_user_input(input, index), self.session.autopilot.continue_sdk.ide.unique_id)
 
     def on_refinement_input(self, input: str, index: int):
-        asyncio.create_task(
-            self.session.autopilot.accept_refinement_input(input, index))
+        create_async_task(
+            self.session.autopilot.accept_refinement_input(input, index), self.session.autopilot.continue_sdk.ide.unique_id)
 
     def on_retry_at_index(self, index: int):
-        asyncio.create_task(
-            self.session.autopilot.retry_at_index(index))
+        create_async_task(
+            self.session.autopilot.retry_at_index(index), self.session.autopilot.continue_sdk.ide.unique_id)
 
     def on_change_default_model(self, model: str):
-        asyncio.create_task(self.session.autopilot.change_default_model(model))
+        create_async_task(self.session.autopilot.change_default_model(
+            model), self.session.autopilot.continue_sdk.ide.unique_id)
 
     def on_clear_history(self):
-        asyncio.create_task(self.session.autopilot.clear_history())
+        create_async_task(self.session.autopilot.clear_history(
+        ), self.session.autopilot.continue_sdk.ide.unique_id)
 
     def on_delete_at_index(self, index: int):
-        asyncio.create_task(self.session.autopilot.delete_at_index(index))
+        create_async_task(self.session.autopilot.delete_at_index(
+            index), self.session.autopilot.continue_sdk.ide.unique_id)
 
     def on_delete_context_at_indices(self, indices: List[int]):
-        asyncio.create_task(
-            self.session.autopilot.delete_context_at_indices(indices)
+        create_async_task(
+            self.session.autopilot.delete_context_at_indices(
+                indices), self.session.autopilot.continue_sdk.ide.unique_id
         )
 
     def on_toggle_adding_highlighted_code(self):
-        asyncio.create_task(
-            self.session.autopilot.toggle_adding_highlighted_code()
+        create_async_task(
+            self.session.autopilot.toggle_adding_highlighted_code(
+            ), self.session.autopilot.continue_sdk.ide.unique_id
         )
 
     def on_set_editing_at_indices(self, indices: List[int]):
-        asyncio.create_task(
-            self.session.autopilot.set_editing_at_indices(indices)
+        create_async_task(
+            self.session.autopilot.set_editing_at_indices(
+                indices), self.session.autopilot.continue_sdk.ide.unique_id
         )
 
     def on_set_pinned_at_indices(self, indices: List[int]):
-        asyncio.create_task(
-            self.session.autopilot.set_pinned_at_indices(indices)
+        create_async_task(
+            self.session.autopilot.set_pinned_at_indices(
+                indices), self.session.autopilot.continue_sdk.ide.unique_id
         )
 
+    def on_show_logs_at_index(self, index: int):
+        name = f"continue_logs.txt"
+        logs = "\n\n############################################\n\n".join(
+            ["This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"] + self.session.autopilot.continue_sdk.history.timeline[index].logs)
+        create_async_task(
+            self.session.autopilot.ide.showVirtualFile(name, logs))
+
 
 @router.websocket("/ws")
 async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)):
@@ -176,11 +202,17 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we
data = message["data"] protocol.handle_json(message_type, data) - + except WebSocketDisconnect as e: + print("GUI websocket disconnected") except Exception as e: print("ERROR in gui websocket: ", e) + capture_event(session.autopilot.continue_sdk.ide.unique_id, "gui_error", { + "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))}) raise e finally: print("Closing gui websocket") - await websocket.close() + if websocket.client_state != WebSocketState.DISCONNECTED: + await websocket.close() + + session_manager.persist_session(session.session_id) session_manager.remove_session(session.session_id) diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index e4a6266a..aeff5623 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -5,7 +5,9 @@ import os from typing import Any, Dict, List, Type, TypeVar, Union import uuid from fastapi import WebSocket, Body, APIRouter +from starlette.websockets import WebSocketState, WebSocketDisconnect from uvicorn.main import Server +import traceback from ..libs.util.telemetry import capture_event from ..libs.util.queue import AsyncSubscriptionQueue @@ -15,6 +17,7 @@ from pydantic import BaseModel from .gui import SessionManager, session_manager from .ide_protocol import AbstractIdeProtocolServer import asyncio +from ..libs.util.create_async_task import create_async_task import nest_asyncio nest_asyncio.apply() @@ -50,6 +53,10 @@ class OpenFilesResponse(BaseModel): openFiles: List[str] +class VisibleFilesResponse(BaseModel): + visibleFiles: List[str] + + class HighlightedCodeResponse(BaseModel): highlightedCode: List[RangeInFile] @@ -110,19 +117,52 @@ class IdeProtocolServer(AbstractIdeProtocolServer): websocket: WebSocket session_manager: SessionManager sub_queue: AsyncSubscriptionQueue = AsyncSubscriptionQueue() + session_id: Union[str, None] = None def __init__(self, session_manager: SessionManager, websocket: WebSocket): self.websocket = websocket self.session_manager = session_manager + workspace_directory: str = None + unique_id: str = None + + async def initialize(self, session_id: str) -> List[str]: + self.session_id = session_id + await self._send_json("workspaceDirectory", {}) + await self._send_json("uniqueId", {}) + other_msgs = [] + while True: + msg_string = await self.websocket.receive_text() + message = json.loads(msg_string) + if "messageType" not in message or "data" not in message: + continue + message_type = message["messageType"] + data = message["data"] + if message_type == "workspaceDirectory": + self.workspace_directory = data["workspaceDirectory"] + elif message_type == "uniqueId": + self.unique_id = data["uniqueId"] + else: + other_msgs.append(msg_string) + + if self.workspace_directory is not None and self.unique_id is not None: + break + return other_msgs + async def _send_json(self, message_type: str, data: Any): + if self.websocket.application_state == WebSocketState.DISCONNECTED: + return await self.websocket.send_json({ "messageType": message_type, "data": data }) - async def _receive_json(self, message_type: str) -> Any: - return await self.sub_queue.get(message_type) + async def _receive_json(self, message_type: str, timeout: int = 5) -> Any: + try: + return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout) + except asyncio.TimeoutError: + raise Exception( + "IDE Protocol _receive_json timed out after 5 seconds") async def _send_and_receive_json(self, data: Any, 
resp_model: Type[T], message_type: str) -> T: await self._send_json(message_type, data) @@ -130,8 +170,8 @@ class IdeProtocolServer(AbstractIdeProtocolServer): return resp_model.parse_obj(resp) async def handle_json(self, message_type: str, data: Any): - if message_type == "openGUI": - await self.openGUI() + if message_type == "getSessionId": + await self.getSessionId() elif message_type == "setFileOpen": await self.setFileOpen(data["filepath"], data["open"]) elif message_type == "setSuggestionsLocked": @@ -154,8 +194,12 @@ class IdeProtocolServer(AbstractIdeProtocolServer): self.onMainUserInput(data["input"]) elif message_type == "deleteAtIndex": self.onDeleteAtIndex(data["index"]) - elif message_type in ["highlightedCode", "openFiles", "readFile", "editFile", "workspaceDirectory", "getUserSecret", "runCommand", "uniqueId"]: + elif message_type in ["highlightedCode", "openFiles", "visibleFiles", "readFile", "editFile", "getUserSecret", "runCommand"]: self.sub_queue.post(message_type, data) + elif message_type == "workspaceDirectory": + self.workspace_directory = data["workspaceDirectory"] + elif message_type == "uniqueId": + self.unique_id = data["uniqueId"] else: raise ValueError("Unknown message type", message_type) @@ -180,6 +224,12 @@ class IdeProtocolServer(AbstractIdeProtocolServer): "open": open }) + async def showVirtualFile(self, name: str, contents: str): + await self._send_json("showVirtualFile", { + "name": name, + "contents": contents + }) + async def setSuggestionsLocked(self, filepath: str, locked: bool = True): # Lock suggestions in the file so they don't ruin the offset before others are inserted await self._send_json("setSuggestionsLocked", { @@ -187,9 +237,10 @@ class IdeProtocolServer(AbstractIdeProtocolServer): "locked": locked }) - async def openGUI(self): - session_id = self.session_manager.new_session(self) - await self._send_json("openGUI", { + async def getSessionId(self): + session_id = (await self.session_manager.new_session( + self, self.session_id)).session_id + await self._send_json("getSessionId", { "sessionId": session_id }) @@ -242,53 +293,42 @@ class IdeProtocolServer(AbstractIdeProtocolServer): def onOpenGUIRequest(self): pass + def __get_autopilot(self): + if self.session_id not in self.session_manager.sessions: + return None + return self.session_manager.sessions[self.session_id].autopilot + def onFileEdits(self, edits: List[FileEditWithFullContents]): - # Send the file edits to ALL autopilots. - # Maybe not ideal behavior - for _, session in self.session_manager.sessions.items(): - session.autopilot.handle_manual_edits(edits) + if autopilot := self.__get_autopilot(): + autopilot.handle_manual_edits(edits) def onDeleteAtIndex(self, index: int): - for _, session in self.session_manager.sessions.items(): - asyncio.create_task(session.autopilot.delete_at_index(index)) + if autopilot := self.__get_autopilot(): + create_async_task(autopilot.delete_at_index(index), self.unique_id) def onCommandOutput(self, output: str): - # Send the output to ALL autopilots. 
- # Maybe not ideal behavior - for _, session in self.session_manager.sessions.items(): - asyncio.create_task( - session.autopilot.handle_command_output(output)) + if autopilot := self.__get_autopilot(): + create_async_task( + autopilot.handle_command_output(output), self.unique_id) def onHighlightedCodeUpdate(self, range_in_files: List[RangeInFileWithContents]): - for _, session in self.session_manager.sessions.items(): - asyncio.create_task( - session.autopilot.handle_highlighted_code(range_in_files)) + if autopilot := self.__get_autopilot(): + create_async_task(autopilot.handle_highlighted_code( + range_in_files), self.unique_id) def onMainUserInput(self, input: str): - for _, session in self.session_manager.sessions.items(): - asyncio.create_task( - session.autopilot.accept_user_input(input)) + if autopilot := self.__get_autopilot(): + create_async_task( + autopilot.accept_user_input(input), self.unique_id) # Request information. Session doesn't matter. async def getOpenFiles(self) -> List[str]: resp = await self._send_and_receive_json({}, OpenFilesResponse, "openFiles") return resp.openFiles - async def getWorkspaceDirectory(self) -> str: - resp = await self._send_and_receive_json({}, WorkspaceDirectoryResponse, "workspaceDirectory") - return resp.workspaceDirectory - - async def get_unique_id(self) -> str: - resp = await self._send_and_receive_json({}, UniqueIdResponse, "uniqueId") - return resp.uniqueId - - @property - def workspace_directory(self) -> str: - return asyncio.run(self.getWorkspaceDirectory()) - - @cached_property_no_none - def unique_id(self) -> str: - return asyncio.run(self.get_unique_id()) + async def getVisibleFiles(self) -> List[str]: + resp = await self._send_and_receive_json({}, VisibleFilesResponse, "visibleFiles") + return resp.visibleFiles async def getHighlightedCode(self) -> List[RangeInFile]: resp = await self._send_and_receive_json({}, HighlightedCodeResponse, "highlightedCode") @@ -389,28 +429,49 @@ class IdeProtocolServer(AbstractIdeProtocolServer): @router.websocket("/ws") -async def websocket_endpoint(websocket: WebSocket): +async def websocket_endpoint(websocket: WebSocket, session_id: str = None): try: await websocket.accept() print("Accepted websocket connection from, ", websocket.client) await websocket.send_json({"messageType": "connected", "data": {}}) - ideProtocolServer = IdeProtocolServer(session_manager, websocket) - - while AppStatus.should_exit is False: - message = await websocket.receive_text() - message = json.loads(message) + def handle_msg(msg): + message = json.loads(msg) if "messageType" not in message or "data" not in message: - continue + return message_type = message["messageType"] data = message["data"] - await ideProtocolServer.handle_json(message_type, data) + create_async_task( + ideProtocolServer.handle_json(message_type, data)) + + ideProtocolServer = IdeProtocolServer(session_manager, websocket) + if session_id is not None: + session_manager.registered_ides[session_id] = ideProtocolServer + other_msgs = await ideProtocolServer.initialize(session_id) + capture_event(ideProtocolServer.unique_id, "session_started", { + "session_id": ideProtocolServer.session_id}) + + for other_msg in other_msgs: + handle_msg(other_msg) + + while AppStatus.should_exit is False: + message = await websocket.receive_text() + handle_msg(message) print("Closing ide websocket") - await websocket.close() + except WebSocketDisconnect as e: + print("IDE websocket disconnected") except Exception as e: print("Error in ide websocket: ", e) - await
websocket.close() + capture_event(ideProtocolServer.unique_id, "gui_error", { + "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))}) raise e + finally: + if websocket.client_state != WebSocketState.DISCONNECTED: + await websocket.close() + + capture_event(ideProtocolServer.unique_id, "session_ended", { + "session_id": ideProtocolServer.session_id}) + session_manager.registered_ides.pop(ideProtocolServer.session_id) diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py index dfdca504..0ae7e7fa 100644 --- a/continuedev/src/continuedev/server/ide_protocol.py +++ b/continuedev/src/continuedev/server/ide_protocol.py @@ -1,5 +1,6 @@ -from typing import Any, List +from typing import Any, List, Union from abc import ABC, abstractmethod, abstractproperty +from fastapi import WebSocket from ..models.main import Traceback from ..models.filesystem_edit import FileEdit, FileSystemEdit, EditDiff @@ -7,6 +8,9 @@ from ..models.filesystem import RangeInFile, RangeInFileWithContents class AbstractIdeProtocolServer(ABC): + websocket: WebSocket + session_id: Union[str, None] + @abstractmethod async def handle_json(self, data: Any): """Handle a json message""" @@ -16,20 +20,20 @@ class AbstractIdeProtocolServer(ABC): """Show a suggestion to the user""" @abstractmethod - async def getWorkspaceDirectory(self): - """Get the workspace directory""" - - @abstractmethod async def setFileOpen(self, filepath: str, open: bool = True): """Set whether a file is open""" @abstractmethod + async def showVirtualFile(self, name: str, contents: str): + """Show a virtual file""" + + @abstractmethod async def setSuggestionsLocked(self, filepath: str, locked: bool = True): """Set whether suggestions are locked""" @abstractmethod - async def openGUI(self): - """Open a GUI""" + async def getSessionId(self): + """Get a new session ID""" @abstractmethod async def showSuggestionsAndWait(self, suggestions: List[FileEdit]) -> bool: @@ -56,6 +60,10 @@ class AbstractIdeProtocolServer(ABC): """Get a list of open files""" @abstractmethod + async def getVisibleFiles(self) -> List[str]: + """Get a list of visible files""" + + @abstractmethod async def getHighlightedCode(self) -> List[RangeInFile]: """Get a list of highlighted code""" @@ -103,10 +111,5 @@ class AbstractIdeProtocolServer(ABC): async def showDiff(self, filepath: str, replacement: str, step_index: int): """Show a diff""" - @abstractproperty - def workspace_directory(self) -> str: - """Get the workspace directory""" - - @abstractproperty - def unique_id(self) -> str: - """Get a unique ID for this IDE""" + workspace_directory: str + unique_id: str diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py index f4d82903..42dc0cc1 100644 --- a/continuedev/src/continuedev/server/main.py +++ b/continuedev/src/continuedev/server/main.py @@ -1,10 +1,12 @@ +import time +import psutil import os -import sys from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from .ide import router as ide_router from .gui import router as gui_router -import logging +from .session_manager import session_manager +import atexit import uvicorn import argparse @@ -44,5 +46,38 @@ def run_server(): uvicorn.run(app, host="0.0.0.0", port=args.port) +def cleanup(): + print("Cleaning up sessions") + for session_id in session_manager.sessions: + session_manager.persist_session(session_id) + + +def cpu_usage_report(): + process = 
psutil.Process(os.getpid()) + # Call cpu_percent once to start measurement, but ignore the result + process.cpu_percent(interval=None) + # Wait for a short period of time + time.sleep(1) + # Call cpu_percent again to get the CPU usage over the interval + cpu_usage = process.cpu_percent(interval=None) + print(f"CPU usage: {cpu_usage}%") + + +atexit.register(cleanup) + if __name__ == "__main__": - run_server() + try: + # import threading + + # def cpu_usage_loop(): + # while True: + # cpu_usage_report() + # time.sleep(2) + + # cpu_thread = threading.Thread(target=cpu_usage_loop) + # cpu_thread.start() + + run_server() + except Exception as e: + cleanup() + raise e diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index 99a38146..90172a4e 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -1,20 +1,24 @@ +import os from fastapi import WebSocket from typing import Any, Dict, List, Union from uuid import uuid4 +import json +from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath from ..models.filesystem_edit import FileEditWithFullContents +from ..libs.constants.main import CONTINUE_SESSIONS_FOLDER from ..core.policy import DemoPolicy from ..core.main import FullState from ..core.autopilot import Autopilot from .ide_protocol import AbstractIdeProtocolServer -import asyncio -import nest_asyncio -nest_asyncio.apply() +from ..libs.util.create_async_task import create_async_task +from ..libs.util.errors import SessionNotFound class Session: session_id: str autopilot: Autopilot + # The GUI websocket for the session ws: Union[WebSocket, None] def __init__(self, session_id: str, autopilot: Autopilot): @@ -38,18 +42,35 @@ class DemoAutopilot(Autopilot): class SessionManager: sessions: Dict[str, Session] = {} - _event_loop: Union[asyncio.BaseEventLoop, None] = None + # Mapping of session_id to IDE, where the IDE is still alive + registered_ides: Dict[str, AbstractIdeProtocolServer] = {} - def get_session(self, session_id: str) -> Session: + async def get_session(self, session_id: str) -> Session: if session_id not in self.sessions: + # Then check whether it is persisted by listing all files in the sessions folder + # And only if the IDE is still alive + sessions_folder = getSessionsFolderPath() + session_files = os.listdir(sessions_folder) + if f"{session_id}.json" in session_files and session_id in self.registered_ides: + if self.registered_ides[session_id].session_id is not None: + return await self.new_session(self.registered_ides[session_id], session_id=session_id) + raise KeyError("Session ID not recognized", session_id) return self.sessions[session_id] - def new_session(self, ide: AbstractIdeProtocolServer) -> str: - autopilot = DemoAutopilot(policy=DemoPolicy(), ide=ide) - session_id = str(uuid4()) + async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session: + full_state = None + if session_id is not None and os.path.exists(getSessionFilePath(session_id)): + with open(getSessionFilePath(session_id), "r") as f: + full_state = FullState(**json.load(f)) + + autopilot = await DemoAutopilot.create( + policy=DemoPolicy(), ide=ide, full_state=full_state) + session_id = session_id or str(uuid4()) + ide.session_id = session_id session = Session(session_id=session_id, autopilot=autopilot) self.sessions[session_id] = session + self.registered_ides[session_id] = ide async def on_update(state:
FullState): await session_manager.send_ws_data(session_id, "state_update", { @@ -57,19 +78,29 @@ class SessionManager: }) autopilot.on_update(on_update) - asyncio.create_task(autopilot.run_policy()) - return session_id + create_async_task(autopilot.run_policy()) + return session def remove_session(self, session_id: str): del self.sessions[session_id] + def persist_session(self, session_id: str): + """Save the session's FullState as a json file""" + full_state = self.sessions[session_id].autopilot.get_full_state() + if not os.path.exists(getSessionsFolderPath()): + os.mkdir(getSessionsFolderPath()) + with open(getSessionFilePath(session_id), "w") as f: + json.dump(full_state.dict(), f) + def register_websocket(self, session_id: str, ws: WebSocket): self.sessions[session_id].ws = ws print("Registered websocket for session", session_id) async def send_ws_data(self, session_id: str, message_type: str, data: Any): + if session_id not in self.sessions: + raise SessionNotFound(f"Session {session_id} not found") if self.sessions[session_id].ws is None: - print(f"Session {session_id} has no websocket") + # print(f"Session {session_id} has no websocket") return await self.sessions[session_id].ws.send_json({ diff --git a/continuedev/src/continuedev/server/state_manager.py b/continuedev/src/continuedev/server/state_manager.py deleted file mode 100644 index c9bd760b..00000000 --- a/continuedev/src/continuedev/server/state_manager.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Any, List, Tuple, Union -from fastapi import WebSocket -from pydantic import BaseModel -from ..core.main import FullState - -# State updates represented as (path, replacement) pairs -StateUpdate = Tuple[List[Union[str, int]], Any] - - -class StateManager: - """ - A class that acts as the source of truth for state, ingesting changes to the entire object and streaming only the updated portions to client. 
- """ - - def __init__(self, ws: WebSocket): - self.ws = ws - - def _send_update(self, updates: List[StateUpdate]): - self.ws.send_json( - [update.dict() for update in updates] - ) diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py index a10319d8..aade1ea1 100644 --- a/continuedev/src/continuedev/steps/chat.py +++ b/continuedev/src/continuedev/steps/chat.py @@ -3,6 +3,7 @@ from typing import Any, Coroutine, List from pydantic import Field +from ..libs.util.strings import remove_quotes_and_escapes from .main import EditHighlightedCodeStep from .core.core import MessageStep from ..core.main import FunctionCall, Models @@ -27,26 +28,32 @@ class SimpleChatStep(Step): async def run(self, sdk: ContinueSDK): completion = "" messages = self.messages or await sdk.get_chat_context() - async for chunk in sdk.models.gpt4.stream_chat(messages, temperature=0.5): - if sdk.current_step_was_deleted(): - return - if "content" in chunk: - self.description += chunk["content"] - completion += chunk["content"] - await sdk.update_ui() - - self.name = (await sdk.models.gpt35.complete( - f"Write a short title for the following chat message: {self.description}")).strip() - - if self.name.startswith('"') and self.name.endswith('"'): - self.name = self.name[1:-1] + generator = sdk.models.default.stream_chat( + messages, temperature=sdk.config.temperature) + try: + async for chunk in generator: + if sdk.current_step_was_deleted(): + # So that the message doesn't disappear + self.hide = False + break - self.chat_context.append(ChatMessage( - role="assistant", - content=completion, - summary=self.name - )) + if "content" in chunk: + self.description += chunk["content"] + completion += chunk["content"] + await sdk.update_ui() + finally: + self.name = remove_quotes_and_escapes(await sdk.models.gpt35.complete( + f"Write a short title for the following chat message: {self.description}")) + + self.chat_context.append(ChatMessage( + role="assistant", + content=completion, + summary=self.name + )) + + # TODO: Never actually closing. 
+ await generator.aclose() class AddFileStep(Step): diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py index 10853828..4afc36e8 100644 --- a/continuedev/src/continuedev/steps/core/core.py +++ b/continuedev/src/continuedev/steps/core/core.py @@ -1,17 +1,19 @@ # These steps are depended upon by ContinueSDK import os import subprocess +import difflib from textwrap import dedent from typing import Coroutine, List, Literal, Union +from ...libs.llm.ggml import GGML from ...models.main import Range from ...libs.llm.prompt_utils import MarkdownStyleEncoderDecoder from ...models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullContents, FileSystemEdit from ...models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents from ...core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation -from ...core.main import ChatMessage, Step, SequentialStep +from ...core.main import ChatMessage, ContinueCustomException, Step, SequentialStep from ...libs.util.count_tokens import MAX_TOKENS_FOR_MODEL, DEFAULT_MAX_TOKENS -from ...libs.util.dedent import dedent_and_get_common_whitespace +from ...libs.util.strings import dedent_and_get_common_whitespace, remove_quotes_and_escapes import difflib @@ -152,44 +154,51 @@ class DefaultModelEditCodeStep(Step): Main task: """) - + _previous_contents: str = "" + _new_contents: str = "" _prompt_and_completion: str = "" - def _cleanup_output(self, output: str) -> str: - output = output.replace('\\"', '"') - output = output.replace("\\'", "'") - output = output.replace("\\n", "\n") - output = output.replace("\\t", "\t") - output = output.replace("\\\\", "\\") - if output.startswith('"') and output.endswith('"'): - output = output[1:-1] + async def describe(self, models: Models) -> Coroutine[str, None, None]: + if self._previous_contents.strip() == self._new_contents.strip(): + description = "No edits were made" + else: + changes = '\n'.join(difflib.ndiff( + self._previous_contents.splitlines(), self._new_contents.splitlines())) + description = await models.gpt3516k.complete(dedent(f"""\ + Diff summary: "{self.user_input}" - return output + ```diff + {changes} + ``` - async def describe(self, models: Models) -> Coroutine[str, None, None]: - description = await models.gpt3516k.complete(dedent(f"""\ - {self._prompt_and_completion} - - Please give brief a description of the changes made above using markdown bullet points. Be concise and only mention changes made to the commit before, not prefix or suffix:""")) + Please give a brief description of the changes made above using markdown bullet points. Be concise:""")) name = await models.gpt3516k.complete(f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. This is the title:") - self.name = self._cleanup_output(name) + self.name = remove_quotes_and_escapes(name) - return f"{self._cleanup_output(description)}" + return f"{remove_quotes_and_escapes(description)}" async def get_prompt_parts(self, rif: RangeInFileWithContents, sdk: ContinueSDK, full_file_contents: str): # We don't know here all of the functions being passed in. # We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion. # Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need.
- model_to_use = sdk.models.gpt4 + model_to_use = sdk.models.default + max_tokens = int(MAX_TOKENS_FOR_MODEL.get( + model_to_use.name, DEFAULT_MAX_TOKENS) / 2) - BUFFER_FOR_FUNCTIONS = 400 - total_tokens = model_to_use.count_tokens( - full_file_contents + self._prompt + self.user_input) + BUFFER_FOR_FUNCTIONS + DEFAULT_MAX_TOKENS - - TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1000 + TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200 if model_to_use.count_tokens(rif.contents) > TOKENS_TO_BE_CONSIDERED_LARGE_RANGE: self.description += "\n\n**It looks like you've selected a large range to edit, which may take a while to complete. If you'd like to cancel, click the 'X' button above. If you highlight a more specific range, Continue will only edit within it.**" + # At this point, we also increase the max_tokens parameter so it doesn't stop in the middle of generation + # Increase max_tokens to be 2.5x the size of the range + # But don't exceed 2.5x the default max tokens + max_tokens = int(min(model_to_use.count_tokens( + rif.contents), DEFAULT_MAX_TOKENS) * 2.5) + + BUFFER_FOR_FUNCTIONS = 400 + total_tokens = model_to_use.count_tokens( + full_file_contents + self._prompt + self.user_input) + BUFFER_FOR_FUNCTIONS + max_tokens + # If using 3.5 and overflows, upgrade to 3.5.16k if model_to_use.name == "gpt-3.5-turbo": if total_tokens > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]: @@ -252,9 +261,26 @@ class DefaultModelEditCodeStep(Step): file_suffix = "\n" + file_suffix rif.contents = rif.contents[:-1] - return file_prefix, rif.contents, file_suffix, model_to_use + return file_prefix, rif.contents, file_suffix, model_to_use, max_tokens def compile_prompt(self, file_prefix: str, contents: str, file_suffix: str, sdk: ContinueSDK) -> str: + if contents.strip() == "": + # Separate prompt for insertion at the cursor; the other tends to cause it to repeat the whole file + prompt = dedent(f"""\ +<file_prefix> +{file_prefix} +</file_prefix> +<insertion_code_here> +<file_suffix> +{file_suffix} +</file_suffix> +<user_request> +{self.user_input} +</user_request> + +Please output the code to be inserted at the cursor in order to fulfill the user_request. Do NOT preface your answer or write anything other than code. You should not write any tags, just the code.
Make sure to correctly indent the code:""") + return prompt + prompt = self._prompt if file_prefix.strip() != "": prompt += dedent(f""" @@ -289,22 +315,39 @@ class DefaultModelEditCodeStep(Step): await sdk.ide.saveFile(rif.filepath) full_file_contents = await sdk.ide.readFile(rif.filepath) - file_prefix, contents, file_suffix, model_to_use = await self.get_prompt_parts( + file_prefix, contents, file_suffix, model_to_use, max_tokens = await self.get_prompt_parts( rif, sdk, full_file_contents) contents, common_whitespace = dedent_and_get_common_whitespace( contents) prompt = self.compile_prompt(file_prefix, contents, file_suffix, sdk) full_file_contents_lines = full_file_contents.split("\n") - async def sendDiffUpdate(lines: List[str], sdk: ContinueSDK): - nonlocal full_file_contents_lines, rif + lines_to_display = [] + + async def sendDiffUpdate(lines: List[str], sdk: ContinueSDK, final: bool = False): + nonlocal full_file_contents_lines, rif, lines_to_display completion = "\n".join(lines) full_prefix_lines = full_file_contents_lines[:rif.range.start.line] full_suffix_lines = full_file_contents_lines[rif.range.end.line:] + + # Don't do this at the very end, just show the inserted code + if final: + lines_to_display = [] + # Only recalculate at every new-line, because this is sort of expensive + elif completion.endswith("\n"): + contents_lines = rif.contents.split("\n") + rewritten_lines = 0 + for line in lines: + for i in range(rewritten_lines, len(contents_lines)): + if difflib.SequenceMatcher(None, line, contents_lines[i]).ratio() > 0.7 and contents_lines[i].strip() != "": + rewritten_lines = i + 1 + break + lines_to_display = contents_lines[rewritten_lines:] + new_file_contents = "\n".join( - full_prefix_lines) + "\n" + completion + "\n" + "\n".join(full_suffix_lines) + full_prefix_lines) + "\n" + completion + "\n" + ("\n".join(lines_to_display) + "\n" if len(lines_to_display) > 0 else "") + "\n".join(full_suffix_lines) step_index = sdk.history.current_index @@ -423,6 +466,14 @@ class DefaultModelEditCodeStep(Step): current_block_lines.append(line) messages = await sdk.get_chat_context() + # Delete the last user and assistant messages + i = len(messages) - 1 + deleted = 0 + while i >= 0 and deleted < 2: + if messages[i].role == "user" or messages[i].role == "assistant": + messages.pop(i) + deleted += 1 + i -= 1 messages.append(ChatMessage( role="user", content=prompt, @@ -435,58 +486,68 @@ class DefaultModelEditCodeStep(Step): completion_lines_covered = 0 repeating_file_suffix = False line_below_highlighted_range = file_suffix.lstrip().split("\n")[0] - async for chunk in model_to_use.stream_chat(messages, temperature=0): - # Stop early if it is repeating the file_suffix or the step was deleted - if repeating_file_suffix: - break - if sdk.current_step_was_deleted(): - return - # Accumulate lines - if "content" not in chunk: - continue - chunk = chunk["content"] - chunk_lines = chunk.split("\n") - chunk_lines[0] = unfinished_line + chunk_lines[0] - if chunk.endswith("\n"): - unfinished_line = "" - chunk_lines.pop() # because this will be an empty string - else: - unfinished_line = chunk_lines.pop() - - # Deal with newly accumulated lines - for i in range(len(chunk_lines)): - # Trailing whitespace doesn't matter - chunk_lines[i] = chunk_lines[i].rstrip() - chunk_lines[i] = common_whitespace + chunk_lines[i] - - # Lines that should signify the end of generation - if self.is_end_line(chunk_lines[i]): - break - # Lines that should be ignored, like the <> tags - elif 
self.line_to_be_ignored(chunk_lines[i], completion_lines_covered == 0): - continue - # Check if we are currently just copying the prefix - elif (lines_of_prefix_copied > 0 or completion_lines_covered == 0) and lines_of_prefix_copied < len(file_prefix.splitlines()) and chunk_lines[i] == full_file_contents_lines[lines_of_prefix_copied]: - # This is a sketchy way of stopping it from repeating the file_prefix. Is a bug if output happens to have a matching line - lines_of_prefix_copied += 1 - continue - # Because really short lines might be expected to be repeated, this is only a !heuristic! - # Stop when it starts copying the file_suffix - elif chunk_lines[i].strip() == line_below_highlighted_range.strip() and len(chunk_lines[i].strip()) > 4 and not (len(original_lines_below_previous_blocks) > 0 and chunk_lines[i].strip() == original_lines_below_previous_blocks[0].strip()): - repeating_file_suffix = True + if isinstance(model_to_use, GGML): + messages = [ChatMessage( + role="user", content=f"```\n{rif.contents}\n```\n\nUser request: \"{self.user_input}\"\n\nThis is the code after changing to perfectly comply with the user request. It does not include any placeholder code, only real implementations:\n\n```\n", summary=self.user_input)] + + generator = model_to_use.stream_chat( + messages, temperature=sdk.config.temperature, max_tokens=max_tokens) + + try: + async for chunk in generator: + # Stop early if it is repeating the file_suffix or the step was deleted + if repeating_file_suffix: break + if sdk.current_step_was_deleted(): + return - # If none of the above, insert the line! - if False: - await handle_generated_line(chunk_lines[i]) + # Accumulate lines + if "content" not in chunk: + continue + chunk = chunk["content"] + chunk_lines = chunk.split("\n") + chunk_lines[0] = unfinished_line + chunk_lines[0] + if chunk.endswith("\n"): + unfinished_line = "" + chunk_lines.pop() # because this will be an empty string + else: + unfinished_line = chunk_lines.pop() + + # Deal with newly accumulated lines + for i in range(len(chunk_lines)): + # Trailing whitespace doesn't matter + chunk_lines[i] = chunk_lines[i].rstrip() + chunk_lines[i] = common_whitespace + chunk_lines[i] + + # Lines that should signify the end of generation + if self.is_end_line(chunk_lines[i]): + break + # Lines that should be ignored, like the <> tags + elif self.line_to_be_ignored(chunk_lines[i], completion_lines_covered == 0): + continue + # Check if we are currently just copying the prefix + elif (lines_of_prefix_copied > 0 or completion_lines_covered == 0) and lines_of_prefix_copied < len(file_prefix.splitlines()) and chunk_lines[i] == full_file_contents_lines[lines_of_prefix_copied]: + # This is a sketchy way of stopping it from repeating the file_prefix. Is a bug if output happens to have a matching line + lines_of_prefix_copied += 1 + continue + # Because really short lines might be expected to be repeated, this is only a !heuristic! + # Stop when it starts copying the file_suffix + elif chunk_lines[i].strip() == line_below_highlighted_range.strip() and len(chunk_lines[i].strip()) > 4 and not (len(original_lines_below_previous_blocks) > 0 and chunk_lines[i].strip() == original_lines_below_previous_blocks[0].strip()): + repeating_file_suffix = True + break - lines.append(chunk_lines[i]) - completion_lines_covered += 1 - current_line_in_file += 1 + # If none of the above, insert the line! 
+ if False: + await handle_generated_line(chunk_lines[i]) - await sendDiffUpdate(lines + [common_whitespace + unfinished_line], sdk) + lines.append(chunk_lines[i]) + completion_lines_covered += 1 + current_line_in_file += 1 + await sendDiffUpdate(lines + [common_whitespace if unfinished_line.startswith("<") else (common_whitespace + unfinished_line)], sdk) + finally: + await generator.aclose() # Add the unfinished line if unfinished_line != "" and not self.line_to_be_ignored(unfinished_line, completion_lines_covered == 0) and not self.is_end_line(unfinished_line): unfinished_line = common_whitespace + unfinished_line @@ -495,7 +556,7 @@ class DefaultModelEditCodeStep(Step): completion_lines_covered += 1 current_line_in_file += 1 - await sendDiffUpdate(lines, sdk) + await sendDiffUpdate(lines, sdk, final=True) if False: # If the current block isn't empty, add that suggestion @@ -529,6 +590,8 @@ class DefaultModelEditCodeStep(Step): # Record the completion completion = "\n".join(lines) + self._previous_contents = "\n".join(original_lines) + self._new_contents = completion self._prompt_and_completion += prompt + completion async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: @@ -549,6 +612,13 @@ class DefaultModelEditCodeStep(Step): rif_dict[rif.filepath] = rif.contents for rif in rif_with_contents: + # If the file doesn't exist, ask them to save it first + if not os.path.exists(rif.filepath): + message = f"The file {rif.filepath} does not exist. Please save it first." + raise ContinueCustomException( + title=message, message=message + ) + await sdk.ide.setFileOpen(rif.filepath) await sdk.ide.setSuggestionsLocked(rif.filepath, True) await self.stream_rif(rif, sdk) diff --git a/continuedev/src/continuedev/steps/custom_command.py b/continuedev/src/continuedev/steps/custom_command.py index 5a56efb0..d96ac8e2 100644 --- a/continuedev/src/continuedev/steps/custom_command.py +++ b/continuedev/src/continuedev/steps/custom_command.py @@ -1,7 +1,7 @@ +from ..libs.util.templating import render_templated_string from ..core.main import Step from ..core.sdk import ContinueSDK -from ..steps.core.core import UserInputStep -from ..steps.chat import ChatWithFunctions, SimpleChatStep +from ..steps.chat import SimpleChatStep class CustomCommandStep(Step): @@ -15,7 +15,9 @@ class CustomCommandStep(Step): return self.prompt async def run(self, sdk: ContinueSDK): - prompt_user_input = f"Task: {self.prompt}. Additional info: {self.user_input}" + task = render_templated_string(self.prompt) + + prompt_user_input = f"Task: {task}. Additional info: {self.user_input}" messages = await sdk.get_chat_context() # Find the last chat message with this slash command and replace it with the user input for i in range(len(messages) - 1, -1, -1): diff --git a/continuedev/src/continuedev/steps/help.py b/continuedev/src/continuedev/steps/help.py new file mode 100644 index 00000000..ba1e6087 --- /dev/null +++ b/continuedev/src/continuedev/steps/help.py @@ -0,0 +1,59 @@ +from textwrap import dedent +from ..core.main import ChatMessage, Step +from ..core.sdk import ContinueSDK +from ..libs.util.telemetry import capture_event + +help = dedent("""\ + Continue is an open-source coding autopilot. It is a VS Code extension that brings the power of ChatGPT to your IDE. + + It gathers context for you and stores your interactions automatically, so that you can avoid copy/paste now and benefit from a customized Large Language Model (LLM) later. + + Continue can be used to... + 1. 
Edit chunks of code with specific instructions (e.g. "/edit migrate this digital ocean terraform file into one that works for GCP") + 2. Get answers to questions without switching windows (e.g. "how do I find a running process on port 8000?") + 3. Generate files from scratch (e.g. "/edit Create a Python CLI tool that uses the posthog api to get events from DAUs") + + You tell Continue to edit a specific section of code by highlighting it. If you highlight multiple code sections, then it will only edit the one with the purple glow around it. You can switch which one has the purple glow by clicking the paint brush. + + If you don't highlight any code, then Continue will insert at the location of your cursor. + + Continue passes all of the sections of code you highlight, the code above and below the to-be-edited highlighted code section, and all previous steps above the input box as context to the LLM. + + You can use cmd+m (Mac) / ctrl+m (Windows) to open Continue. You can use cmd+shift+e / ctrl+shift+e to open the file Explorer. You can add your own OpenAI API key to VS Code Settings with `cmd+,`. + + If Continue is stuck loading, try using `cmd+shift+p` to open the command palette, search "Reload Window", and then select it. This will reload VS Code and Continue and often fixes issues. + + If you have feedback, please use /feedback to let us know how you would like to use Continue. We are excited to hear from you!""") + + +class HelpStep(Step): + + name: str = "Help" + user_input: str + manage_own_chat_context: bool = True + description: str = "" + + async def run(self, sdk: ContinueSDK): + + question = self.user_input + + prompt = dedent(f"""Please use the information below to provide a succinct answer to the following question: {question} + + Information: + + {help}""") + + self.chat_context.append(ChatMessage( + role="user", + content=prompt, + summary="Help" + )) + messages = await sdk.get_chat_context() + generator = sdk.models.gpt4.stream_chat(messages) + async for chunk in generator: + if "content" in chunk: + self.description += chunk["content"] + await sdk.update_ui() + + capture_event(sdk.ide.unique_id, "help", { + "question": question, "answer": self.description}) diff --git a/continuedev/src/continuedev/steps/main.py b/continuedev/src/continuedev/steps/main.py index 5ccffbfe..ce7cbc60 100644 --- a/continuedev/src/continuedev/steps/main.py +++ b/continuedev/src/continuedev/steps/main.py @@ -10,7 +10,7 @@ from ..models.filesystem import RangeInFile, RangeInFileWithContents from ..core.observation import Observation, TextObservation, TracebackObservation from ..libs.llm.prompt_utils import MarkdownStyleEncoderDecoder from textwrap import dedent -from ..core.main import Step +from ..core.main import ContinueCustomException, Step from ..core.sdk import ContinueSDK, Models from ..core.observation import Observation import subprocess @@ -97,29 +97,24 @@ class FasterEditHighlightedCodeStep(Step): return "Editing highlighted code" async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: - range_in_files = await sdk.ide.getHighlightedCode() + range_in_files = await sdk.get_code_context(only_editing=True) if len(range_in_files) == 0: - # Get the full contents of all open files - files = await sdk.ide.getOpenFiles() + # Get the full contents of all visible files + files = await sdk.ide.getVisibleFiles() contents = {} for file in files: contents[file] = await sdk.ide.readFile(file) - range_in_files = [RangeInFile.from_entire_file( + range_in_files =
[RangeInFileWithContents.from_entire_file( filepath, content) for filepath, content in contents.items()] - rif_with_contents = [] - for range_in_file in range_in_files: - file_contents = await sdk.ide.readRangeInFile(range_in_file) - rif_with_contents.append( - RangeInFileWithContents.from_range_in_file(range_in_file, file_contents)) - enc_dec = MarkdownStyleEncoderDecoder(rif_with_contents) + enc_dec = MarkdownStyleEncoderDecoder(range_in_files) code_string = enc_dec.encode() prompt = self._prompt.format( code=code_string, user_input=self.user_input) rif_dict = {} - for rif in rif_with_contents: + for rif in range_in_files: rif_dict[rif.filepath] = rif.contents completion = await sdk.models.gpt35.complete(prompt) @@ -193,29 +188,23 @@ class StarCoderEditHighlightedCodeStep(Step): return await models.gpt35.complete(f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:") async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: - range_in_files = await sdk.ide.getHighlightedCode() + range_in_files = await sdk.get_code_context(only_editing=True) found_highlighted_code = len(range_in_files) > 0 if not found_highlighted_code: - # Get the full contents of all open files - files = await sdk.ide.getOpenFiles() + # Get the full contents of all visible files + files = await sdk.ide.getVisibleFiles() contents = {} for file in files: contents[file] = await sdk.ide.readFile(file) - range_in_files = [RangeInFile.from_entire_file( + range_in_files = [RangeInFileWithContents.from_entire_file( filepath, content) for filepath, content in contents.items()] - rif_with_contents = [] - for range_in_file in range_in_files: - file_contents = await sdk.ide.readRangeInFile(range_in_file) - rif_with_contents.append( - RangeInFileWithContents.from_range_in_file(range_in_file, file_contents)) - rif_dict = {} - for rif in rif_with_contents: + for rif in range_in_files: rif_dict[rif.filepath] = rif.contents - for rif in rif_with_contents: + for rif in range_in_files: prompt = self._prompt.format( code=rif.contents, user_request=self.user_input) @@ -255,35 +244,35 @@ class EditHighlightedCodeStep(Step): return "Editing code" async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: - range_in_files = await sdk.ide.getHighlightedCode() - if len(range_in_files) == 0: - # Get the full contents of all open files - files = await sdk.ide.getOpenFiles() - contents = {} - for file in files: - contents[file] = await sdk.ide.readFile(file) - - range_in_files = [RangeInFile.from_entire_file( - filepath, content) for filepath, content in contents.items()] + range_in_files = sdk.get_code_context(only_editing=True) - # If still no highlighted code, create a new file and edit there + # If nothing highlighted, insert at the cursor if possible if len(range_in_files) == 0: - # Create a new file - new_file_path = "new_file.txt" - await sdk.add_file(new_file_path, "") - range_in_files = [RangeInFile.from_entire_file(new_file_path, "")] - - await sdk.run_step(DefaultModelEditCodeStep(user_input=self.user_input, range_in_files=range_in_files)) - + highlighted_code = await sdk.ide.getHighlightedCode() + if highlighted_code is not None: + for rif in highlighted_code: + if os.path.dirname(rif.filepath) == os.path.expanduser(os.path.join("~", ".continue", "diffs")): + raise ContinueCustomException( + message="Please accept or reject the change before making another edit in this file.", title="Accept/Reject First") + if rif.range.start == 
rif.range.end: + range_in_files.append( + RangeInFileWithContents.from_range_in_file(rif, "")) + + # If still no highlighted code, raise error + if len(range_in_files) == 0: + raise ContinueCustomException( + message="Please highlight some code and try again.", title="No Code Selected") -class FindCodeStep(Step): - prompt: str + range_in_files = list(map(lambda x: RangeInFile( + filepath=x.filepath, range=x.range + ), range_in_files)) - async def describe(self, models: Models) -> Coroutine[str, None, None]: - return "Finding code" + for range_in_file in range_in_files: + if os.path.dirname(range_in_file.filepath) == os.path.expanduser(os.path.join("~", ".continue", "diffs")): + self.description = "Please accept or reject the change before making another edit in this file." + return - async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: - return await sdk.ide.getOpenFiles() + await sdk.run_step(DefaultModelEditCodeStep(user_input=self.user_input, range_in_files=range_in_files)) class UserInputStep(Step): diff --git a/continuedev/src/continuedev/steps/open_config.py b/continuedev/src/continuedev/steps/open_config.py index 441cb0e7..af55a95a 100644 --- a/continuedev/src/continuedev/steps/open_config.py +++ b/continuedev/src/continuedev/steps/open_config.py @@ -14,11 +14,14 @@ class OpenConfigStep(Step): "custom_commands": [ { "name": "test", - "description": "Write unit tests like I do for the highlighted code" + "description": "Write unit tests like I do for the highlighted code", "prompt": "Write a comprehensive set of unit tests for the selected code. It should setup, run tests that check for correctness including important edge cases, and teardown. Ensure that the tests are complete and sophisticated." } - ], - ```""") + ] + ``` + `"name"` is the command you will type. + `"description"` is the description displayed in the slash command menu. + `"prompt"` is the instruction given to the model. The overall prompt becomes "Task: {prompt}, Additional info: {user_input}". 
For example, if you entered "/test exactly 5 assertions", the overall prompt would become "Task: Write a comprehensive...and sophisticated, Additional info: exactly 5 assertions".""") async def run(self, sdk: ContinueSDK): global_dir = os.path.expanduser('~/.continue') diff --git a/continuedev/src/continuedev/steps/search_directory.py b/continuedev/src/continuedev/steps/search_directory.py index 2eecc99c..bfb97630 100644 --- a/continuedev/src/continuedev/steps/search_directory.py +++ b/continuedev/src/continuedev/steps/search_directory.py @@ -6,6 +6,7 @@ from ..models.filesystem import RangeInFile from ..models.main import Range from ..core.main import Step from ..core.sdk import ContinueSDK +from ..libs.util.create_async_task import create_async_task import os import re @@ -60,9 +61,9 @@ class EditAllMatchesStep(Step): # Search all files for a given string range_in_files = find_all_matches_in_dir(self.pattern, self.directory or await sdk.ide.getWorkspaceDirectory()) - tasks = [asyncio.create_task(sdk.edit_file( + tasks = [create_async_task(sdk.edit_file( range=range_in_file.range, filename=range_in_file.filepath, prompt=self.user_request - )) for range_in_file in range_in_files] + ), sdk.ide.unique_id) for range_in_file in range_in_files] await asyncio.gather(*tasks) diff --git a/continuedev/src/continuedev/steps/welcome.py b/continuedev/src/continuedev/steps/welcome.py index 32ebc3ba..2dece649 100644 --- a/continuedev/src/continuedev/steps/welcome.py +++ b/continuedev/src/continuedev/steps/welcome.py @@ -29,4 +29,4 @@ class WelcomeStep(Step): - Ask about how the class works, how to write it in another language, etc. \"\"\""""))) - await sdk.ide.setFileOpen(filepath=filepath) + # await sdk.ide.setFileOpen(filepath=filepath)
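A note on the new create_async_task helper: every fire-and-forget coroutine in this commit is routed through create_async_task(coro, unique_id) instead of asyncio.create_task. The helper's own module (libs/util/create_async_task.py) is added by the commit but not shown in this diff, so the following is only a minimal sketch of what the call sites imply: a wrapper around asyncio.create_task whose done-callback re-raises and logs exceptions, with the IDE's unique_id available for error reporting (the telemetry hook is an assumption, not confirmed by this diff).

import asyncio
import traceback
from typing import Coroutine, Optional


def create_async_task(coro: Coroutine, unique_id: Optional[str] = None) -> asyncio.Task:
    """Sketch: fire-and-forget wrapper around asyncio.create_task that surfaces errors."""
    task = asyncio.create_task(coro)

    def callback(future: asyncio.Future):
        try:
            future.result()  # re-raises any exception the coroutine ended with
        except Exception:
            # Assumption: a real implementation could report this via capture_event,
            # which would explain why callers pass the IDE's unique_id.
            print(f"Exception in async task (unique_id={unique_id}):")
            traceback.print_exc()

    task.add_done_callback(callback)
    return task

Callers never await the returned task; errors surface through the done-callback instead of being silently dropped.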
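The reworked IDE websocket endpoint now runs an initialize handshake before entering its receive loop: the server sends workspaceDirectory and uniqueId requests and buffers any unrelated messages until both replies arrive. Below is a minimal sketch of the client side of that handshake, assuming the server's default port 65432 and using the third-party websockets library; the payload values are illustrative, not taken from this diff.

import asyncio
import json
import websockets  # third-party client library, used here only for the sketch


async def ide_client(session_id: str):
    uri = f"ws://localhost:65432/ide/ws?session_id={session_id}"
    async with websockets.connect(uri) as ws:
        async for raw in ws:
            msg = json.loads(raw)
            if msg.get("messageType") == "workspaceDirectory":
                await ws.send(json.dumps({
                    "messageType": "workspaceDirectory",
                    "data": {"workspaceDirectory": "/path/to/workspace"},  # illustrative
                }))
            elif msg.get("messageType") == "uniqueId":
                await ws.send(json.dumps({
                    "messageType": "uniqueId",
                    "data": {"uniqueId": "machine-1234"},  # illustrative
                }))


asyncio.run(ide_client("demo-session"))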
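The new remove_quotes_and_escapes helper in strings.py normalizes smart quotes, unescapes \", \', \n, \t, and \\, and strips one pair of surrounding quotes from a completion. A quick usage sketch (the import path assumes the repository root is on PYTHONPATH):

from continuedev.src.continuedev.libs.util.strings import remove_quotes_and_escapes

# A typical completion-API artifact: wrapped in quotes, with escaped internals
raw = '"Add a \\"hello\\" log line\\n"'
print(remove_quotes_and_escapes(raw))
# prints: Add a "hello" log line   (followed by the unescaped newline)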