diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-07-16 16:25:02 -0700 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-07-16 16:25:02 -0700 |
commit | da7827189181328baa329d2984d0fd9c6d476be3 (patch) | |
tree | 56308a5595e8e01ab83974688d8ac8e4945f319a /continuedev/src/continuedev/core | |
parent | d80119982e9b60ca0022533a0086eb526dc7d957 (diff) | |
parent | eab69781a3e3b5236916d9057ce29aba2e868913 (diff) | |
download | sncontinue-da7827189181328baa329d2984d0fd9c6d476be3.tar.gz sncontinue-da7827189181328baa329d2984d0fd9c6d476be3.tar.bz2 sncontinue-da7827189181328baa329d2984d0fd9c6d476be3.zip |
Merge branch 'main' into ggml-server
Diffstat (limited to 'continuedev/src/continuedev/core')
-rw-r--r-- | continuedev/src/continuedev/core/abstract_sdk.py | 4 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 44 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/config.py | 9 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/policy.py | 7 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 95 |
5 files changed, 103 insertions, 56 deletions
diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py index 7bd3da6c..94d7be10 100644 --- a/continuedev/src/continuedev/core/abstract_sdk.py +++ b/continuedev/src/continuedev/core/abstract_sdk.py @@ -76,9 +76,7 @@ class AbstractContinueSDK(ABC): async def get_user_secret(self, env_var: str, prompt: str) -> str: pass - @abstractproperty - def config(self) -> ContinueConfig: - pass + config: ContinueConfig @abstractmethod def set_loading_message(self, message: str): diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 5c3baafd..0696c360 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -1,13 +1,13 @@ from functools import cached_property import traceback import time -from typing import Any, Callable, Coroutine, Dict, List +from typing import Any, Callable, Coroutine, Dict, List, Union import os from aiohttp import ClientPayloadError +from pydantic import root_validator from ..models.filesystem import RangeInFileWithContents from ..models.filesystem_edit import FileEditWithFullContents -from ..libs.llm import LLM from .observation import Observation, InternalErrorObservation from ..server.ide_protocol import AbstractIdeProtocolServer from ..libs.util.queue import AsyncSubscriptionQueue @@ -16,10 +16,10 @@ from .main import Context, ContinueCustomException, HighlightedRangeContext, Pol from ..steps.core.core import ReversibleStep, ManualEditStep, UserInputStep from ..libs.util.telemetry import capture_event from .sdk import ContinueSDK -import asyncio from ..libs.util.step_name_to_steps import get_step_from_name from ..libs.util.traceback_parsers import get_python_traceback, get_javascript_traceback from openai import error as openai_errors +from ..libs.util.create_async_task import create_async_task def get_error_title(e: Exception) -> str: @@ -34,9 +34,11 @@ def get_error_title(e: Exception) -> str: elif isinstance(e, ClientPayloadError): return "The request to OpenAI failed. Please try again." elif isinstance(e, openai_errors.APIConnectionError): - return "The request failed. Please check your internet connection and try again." + return "The request failed. Please check your internet connection and try again. If this issue persists, you can use our API key for free by going to VS Code settings and changing the value of continue.OPENAI_API_KEY to \"\"" elif isinstance(e, openai_errors.InvalidRequestError): return 'Your API key does not have access to GPT-4. You can use ours for free by going to VS Code settings and changing the value of continue.OPENAI_API_KEY to ""' + elif e.__str__().startswith("Cannot connect to host"): + return "The request failed. Please check your internet connection and try again." return e.__str__() or e.__repr__() @@ -45,8 +47,11 @@ class Autopilot(ContinueBaseModel): ide: AbstractIdeProtocolServer history: History = History.from_empty() context: Context = Context() + full_state: Union[FullState, None] = None _on_update_callbacks: List[Callable[[FullState], None]] = [] + continue_sdk: ContinueSDK = None + _active: bool = False _should_halt: bool = False _main_user_input_queue: List[str] = [] @@ -54,16 +59,25 @@ class Autopilot(ContinueBaseModel): _user_input_queue = AsyncSubscriptionQueue() _retry_queue = AsyncSubscriptionQueue() - @cached_property - def continue_sdk(self) -> ContinueSDK: - return ContinueSDK(self) + @classmethod + async def create(cls, policy: Policy, ide: AbstractIdeProtocolServer, full_state: FullState) -> "Autopilot": + autopilot = cls(ide=ide, policy=policy) + autopilot.continue_sdk = await ContinueSDK.create(autopilot) + return autopilot class Config: arbitrary_types_allowed = True keep_untouched = (cached_property,) + @root_validator(pre=True) + def fill_in_values(cls, values): + full_state: FullState = values.get('full_state') + if full_state is not None: + values['history'] = full_state.history + return values + def get_full_state(self) -> FullState: - return FullState( + full_state = FullState( history=self.history, active=self._active, user_input_queue=self._main_user_input_queue, @@ -72,6 +86,8 @@ class Autopilot(ContinueBaseModel): slash_commands=self.get_available_slash_commands(), adding_highlighted_code=self._adding_highlighted_code, ) + self.full_state = full_state + return full_state def get_available_slash_commands(self) -> List[Dict]: custom_commands = list(map(lambda x: { @@ -207,6 +223,8 @@ class Autopilot(ContinueBaseModel): async def delete_at_index(self, index: int): self.history.timeline[index].step.hide = True self.history.timeline[index].deleted = True + self.history.timeline[index].active = False + await self.update_subscribers() async def delete_context_at_indices(self, indices: List[int]): @@ -250,7 +268,7 @@ class Autopilot(ContinueBaseModel): # i -= 1 capture_event(self.continue_sdk.ide.unique_id, 'step run', { - 'step_name': step.name, 'params': step.dict()}) + 'step_name': step.name, 'params': step.dict()}) if not is_future_step: # Check manual edits buffer, clear out if needed by creating a ManualEditStep @@ -284,12 +302,13 @@ class Autopilot(ContinueBaseModel): e.__class__, ContinueCustomException) error_string = e.message if is_continue_custom_exception else '\n'.join( - traceback.format_tb(e.__traceback__)) + f"\n\n{e.__repr__()}" + traceback.format_exception(e)) error_title = e.title if is_continue_custom_exception else get_error_title( e) # Attach an InternalErrorObservation to the step and unhide it. - print(f"Error while running step: \n{error_string}\n{error_title}") + print( + f"Error while running step: \n{error_string}\n{error_title}") capture_event(self.continue_sdk.ide.unique_id, 'step error', { 'error_message': error_string, 'error_title': error_title, 'step_name': step.name, 'params': step.dict()}) @@ -341,7 +360,8 @@ class Autopilot(ContinueBaseModel): # Update subscribers with new description await self.update_subscribers() - asyncio.create_task(update_description()) + create_async_task(update_description(), + self.continue_sdk.ide.unique_id) return observation diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index 55f5bc60..6e430c04 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -45,6 +45,11 @@ DEFAULT_SLASH_COMMANDS = [ step_name="OpenConfigStep", ), SlashCommand( + name="help", + description="Ask a question like '/help what is given to the llm as context?'", + step_name="HelpStep", + ), + SlashCommand( name="comment", description="Write comments for the current file or highlighted code", step_name="CommentCodeStep", @@ -131,7 +136,7 @@ def load_global_config() -> ContinueConfig: config_path = os.path.join(global_dir, 'config.json') if not os.path.exists(config_path): with open(config_path, 'w') as f: - json.dump(ContinueConfig().dict(), f) + json.dump(ContinueConfig().dict(), f, indent=4) with open(config_path, 'r') as f: try: config_dict = json.load(f) @@ -151,7 +156,7 @@ def update_global_config(config: ContinueConfig): yaml_path = os.path.join(global_dir, 'config.yaml') if os.path.exists(yaml_path): with open(config_path, 'w') as f: - yaml.dump(config.dict(), f) + yaml.dump(config.dict(), f, indent=4) else: config_path = os.path.join(global_dir, 'config.json') with open(config_path, 'w') as f: diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py index b8363df2..bc897357 100644 --- a/continuedev/src/continuedev/core/policy.py +++ b/continuedev/src/continuedev/core/policy.py @@ -59,11 +59,8 @@ class DemoPolicy(Policy): return ( MessageStep(name="Welcome to Continue", message=dedent("""\ - Highlight code and ask a question or give instructions - - Use `cmd+k` (Mac) / `ctrl+k` (Windows) to open Continue - - Use `cmd+shift+e` / `ctrl+shift+e` to open file Explorer - - Add your own OpenAI API key to VS Code Settings with `cmd+,` - - Use slash commands when you want fine-grained control - - Past steps are included as part of the context by default""")) >> + - Use `cmd+m` (Mac) / `ctrl+m` (Windows) to open Continue + - Use `/help` to ask questions about how to use Continue""")) >> WelcomeStep() >> # SetupContinueWorkspaceStep() >> # CreateCodebaseIndexChroma() >> diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 22393746..9389e1e9 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -1,6 +1,6 @@ import asyncio from functools import cached_property -from typing import Coroutine, Union +from typing import Coroutine, Dict, Union import os from ..steps.core.core import DefaultModelEditCodeStep @@ -14,7 +14,7 @@ from ..libs.llm.openai import OpenAI from ..libs.llm.ggml import GGML from .observation import Observation from ..server.ide_protocol import AbstractIdeProtocolServer -from .main import Context, ContinueCustomException, HighlightedRangeContext, History, Step, ChatMessage, ChatMessageRole +from .main import Context, ContinueCustomException, History, Step, ChatMessage from ..steps.core.core import * from ..libs.llm.proxy_server import ProxyServer @@ -23,26 +23,46 @@ class Autopilot: pass +ModelProvider = Literal["openai", "hf_inference_api", "ggml", "anthropic"] +MODEL_PROVIDER_TO_ENV_VAR = { + "openai": "OPENAI_API_KEY", + "hf_inference_api": "HUGGING_FACE_TOKEN", + "anthropic": "ANTHROPIC_API_KEY" +} + + class Models: - def __init__(self, sdk: "ContinueSDK"): + provider_keys: Dict[ModelProvider, str] = {} + model_providers: List[ModelProvider] + + def __init__(self, sdk: "ContinueSDK", model_providers: List[ModelProvider]): self.sdk = sdk + self.model_providers = model_providers + + @classmethod + async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models": + models = Models(sdk, with_providers) + for provider in with_providers: + if provider in MODEL_PROVIDER_TO_ENV_VAR: + env_var = MODEL_PROVIDER_TO_ENV_VAR[provider] + models.provider_keys[provider] = await sdk.get_user_secret( + env_var, f'Please add your {env_var} to the .env file') + + return models def __load_openai_model(self, model: str) -> OpenAI: - async def load_openai_model(): - api_key = await self.sdk.get_user_secret( - 'OPENAI_API_KEY', 'Enter your OpenAI API key or press enter to try for free') - if api_key == "": - return ProxyServer(self.sdk.ide.unique_id, model) - return OpenAI(api_key=api_key, default_model=model) - return asyncio.get_event_loop().run_until_complete(load_openai_model()) + api_key = self.provider_keys["openai"] + if api_key == "": + return ProxyServer(self.sdk.ide.unique_id, model) + return OpenAI(api_key=api_key, default_model=model) + + def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI: + api_key = self.provider_keys["hf_inference_api"] + return HuggingFaceInferenceAPI(api_key=api_key, model=model) @cached_property def starcoder(self): - async def load_starcoder(): - api_key = await self.sdk.get_user_secret( - 'HUGGING_FACE_TOKEN', 'Please add your Hugging Face token to the .env file') - return HuggingFaceInferenceAPI(api_key=api_key) - return asyncio.get_event_loop().run_until_complete(load_starcoder()) + return self.__load_hf_inference_api_model("bigcode/starcoder") @cached_property def gpt35(self): @@ -80,7 +100,7 @@ class Models: def default(self): return self.ggml default_model = self.sdk.config.default_model - return self.__model_from_name(default_model) if default_model is not None else self.gpt35 + return self.__model_from_name(default_model) if default_model is not None else self.gpt4 class ContinueSDK(AbstractContinueSDK): @@ -93,8 +113,27 @@ class ContinueSDK(AbstractContinueSDK): def __init__(self, autopilot: Autopilot): self.ide = autopilot.ide self.__autopilot = autopilot - self.models = Models(self) self.context = autopilot.context + self.config = self._load_config() + + @classmethod + async def create(cls, autopilot: Autopilot) -> "ContinueSDK": + sdk = ContinueSDK(autopilot) + sdk.models = await Models.create(sdk) + return sdk + + config: ContinueConfig + + def _load_config(self) -> ContinueConfig: + dir = self.ide.workspace_directory + yaml_path = os.path.join(dir, '.continue', 'config.yaml') + json_path = os.path.join(dir, '.continue', 'config.json') + if os.path.exists(yaml_path): + return load_config(yaml_path) + elif os.path.exists(json_path): + return load_config(json_path) + else: + return load_global_config() @property def history(self) -> History: @@ -172,18 +211,6 @@ class ContinueSDK(AbstractContinueSDK): async def get_user_secret(self, env_var: str, prompt: str) -> str: return await self.ide.getUserSecret(env_var) - @property - def config(self) -> ContinueConfig: - dir = self.ide.workspace_directory - yaml_path = os.path.join(dir, '.continue', 'config.yaml') - json_path = os.path.join(dir, '.continue', 'config.json') - if os.path.exists(yaml_path): - return load_config(yaml_path) - elif os.path.exists(json_path): - return load_config(json_path) - else: - return load_global_config() - def get_code_context(self, only_editing: bool = False) -> List[RangeInFileWithContents]: context = list(filter(lambda x: x.editing, self.__autopilot._highlighted_ranges) ) if only_editing else self.__autopilot._highlighted_ranges @@ -208,14 +235,14 @@ class ContinueSDK(AbstractContinueSDK): preface = "The following code is highlighted" + # If no higlighted ranges, use first file as context if len(highlighted_code) == 0: preface = "The following file is open" - # Get the full contents of all open files - files = await self.ide.getOpenFiles() - if len(files) > 0: - content = await self.ide.readFile(files[0]) + visible_files = await self.ide.getVisibleFiles() + if len(visible_files) > 0: + content = await self.ide.readFile(visible_files[0]) highlighted_code = [ - RangeInFileWithContents.from_entire_file(files[0], content)] + RangeInFileWithContents.from_entire_file(visible_files[0], content)] for rif in highlighted_code: msg = ChatMessage(content=f"{preface} ({rif.filepath}):\n```\n{rif.contents}\n```", |