summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/core
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-07-16 16:25:02 -0700
committerNate Sesti <sestinj@gmail.com>2023-07-16 16:25:02 -0700
commitda7827189181328baa329d2984d0fd9c6d476be3 (patch)
tree56308a5595e8e01ab83974688d8ac8e4945f319a /continuedev/src/continuedev/core
parentd80119982e9b60ca0022533a0086eb526dc7d957 (diff)
parenteab69781a3e3b5236916d9057ce29aba2e868913 (diff)
downloadsncontinue-da7827189181328baa329d2984d0fd9c6d476be3.tar.gz
sncontinue-da7827189181328baa329d2984d0fd9c6d476be3.tar.bz2
sncontinue-da7827189181328baa329d2984d0fd9c6d476be3.zip
Merge branch 'main' into ggml-server
Diffstat (limited to 'continuedev/src/continuedev/core')
-rw-r--r--continuedev/src/continuedev/core/abstract_sdk.py4
-rw-r--r--continuedev/src/continuedev/core/autopilot.py44
-rw-r--r--continuedev/src/continuedev/core/config.py9
-rw-r--r--continuedev/src/continuedev/core/policy.py7
-rw-r--r--continuedev/src/continuedev/core/sdk.py95
5 files changed, 103 insertions, 56 deletions
diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py
index 7bd3da6c..94d7be10 100644
--- a/continuedev/src/continuedev/core/abstract_sdk.py
+++ b/continuedev/src/continuedev/core/abstract_sdk.py
@@ -76,9 +76,7 @@ class AbstractContinueSDK(ABC):
async def get_user_secret(self, env_var: str, prompt: str) -> str:
pass
- @abstractproperty
- def config(self) -> ContinueConfig:
- pass
+ config: ContinueConfig
@abstractmethod
def set_loading_message(self, message: str):
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 5c3baafd..0696c360 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -1,13 +1,13 @@
from functools import cached_property
import traceback
import time
-from typing import Any, Callable, Coroutine, Dict, List
+from typing import Any, Callable, Coroutine, Dict, List, Union
import os
from aiohttp import ClientPayloadError
+from pydantic import root_validator
from ..models.filesystem import RangeInFileWithContents
from ..models.filesystem_edit import FileEditWithFullContents
-from ..libs.llm import LLM
from .observation import Observation, InternalErrorObservation
from ..server.ide_protocol import AbstractIdeProtocolServer
from ..libs.util.queue import AsyncSubscriptionQueue
@@ -16,10 +16,10 @@ from .main import Context, ContinueCustomException, HighlightedRangeContext, Pol
from ..steps.core.core import ReversibleStep, ManualEditStep, UserInputStep
from ..libs.util.telemetry import capture_event
from .sdk import ContinueSDK
-import asyncio
from ..libs.util.step_name_to_steps import get_step_from_name
from ..libs.util.traceback_parsers import get_python_traceback, get_javascript_traceback
from openai import error as openai_errors
+from ..libs.util.create_async_task import create_async_task
def get_error_title(e: Exception) -> str:
@@ -34,9 +34,11 @@ def get_error_title(e: Exception) -> str:
elif isinstance(e, ClientPayloadError):
return "The request to OpenAI failed. Please try again."
elif isinstance(e, openai_errors.APIConnectionError):
- return "The request failed. Please check your internet connection and try again."
+ return "The request failed. Please check your internet connection and try again. If this issue persists, you can use our API key for free by going to VS Code settings and changing the value of continue.OPENAI_API_KEY to \"\""
elif isinstance(e, openai_errors.InvalidRequestError):
return 'Your API key does not have access to GPT-4. You can use ours for free by going to VS Code settings and changing the value of continue.OPENAI_API_KEY to ""'
+ elif e.__str__().startswith("Cannot connect to host"):
+ return "The request failed. Please check your internet connection and try again."
return e.__str__() or e.__repr__()
@@ -45,8 +47,11 @@ class Autopilot(ContinueBaseModel):
ide: AbstractIdeProtocolServer
history: History = History.from_empty()
context: Context = Context()
+ full_state: Union[FullState, None] = None
_on_update_callbacks: List[Callable[[FullState], None]] = []
+ continue_sdk: ContinueSDK = None
+
_active: bool = False
_should_halt: bool = False
_main_user_input_queue: List[str] = []
@@ -54,16 +59,25 @@ class Autopilot(ContinueBaseModel):
_user_input_queue = AsyncSubscriptionQueue()
_retry_queue = AsyncSubscriptionQueue()
- @cached_property
- def continue_sdk(self) -> ContinueSDK:
- return ContinueSDK(self)
+ @classmethod
+ async def create(cls, policy: Policy, ide: AbstractIdeProtocolServer, full_state: FullState) -> "Autopilot":
+ autopilot = cls(ide=ide, policy=policy)
+ autopilot.continue_sdk = await ContinueSDK.create(autopilot)
+ return autopilot
class Config:
arbitrary_types_allowed = True
keep_untouched = (cached_property,)
+ @root_validator(pre=True)
+ def fill_in_values(cls, values):
+ full_state: FullState = values.get('full_state')
+ if full_state is not None:
+ values['history'] = full_state.history
+ return values
+
def get_full_state(self) -> FullState:
- return FullState(
+ full_state = FullState(
history=self.history,
active=self._active,
user_input_queue=self._main_user_input_queue,
@@ -72,6 +86,8 @@ class Autopilot(ContinueBaseModel):
slash_commands=self.get_available_slash_commands(),
adding_highlighted_code=self._adding_highlighted_code,
)
+ self.full_state = full_state
+ return full_state
def get_available_slash_commands(self) -> List[Dict]:
custom_commands = list(map(lambda x: {
@@ -207,6 +223,8 @@ class Autopilot(ContinueBaseModel):
async def delete_at_index(self, index: int):
self.history.timeline[index].step.hide = True
self.history.timeline[index].deleted = True
+ self.history.timeline[index].active = False
+
await self.update_subscribers()
async def delete_context_at_indices(self, indices: List[int]):
@@ -250,7 +268,7 @@ class Autopilot(ContinueBaseModel):
# i -= 1
capture_event(self.continue_sdk.ide.unique_id, 'step run', {
- 'step_name': step.name, 'params': step.dict()})
+ 'step_name': step.name, 'params': step.dict()})
if not is_future_step:
# Check manual edits buffer, clear out if needed by creating a ManualEditStep
@@ -284,12 +302,13 @@ class Autopilot(ContinueBaseModel):
e.__class__, ContinueCustomException)
error_string = e.message if is_continue_custom_exception else '\n'.join(
- traceback.format_tb(e.__traceback__)) + f"\n\n{e.__repr__()}"
+ traceback.format_exception(e))
error_title = e.title if is_continue_custom_exception else get_error_title(
e)
# Attach an InternalErrorObservation to the step and unhide it.
- print(f"Error while running step: \n{error_string}\n{error_title}")
+ print(
+ f"Error while running step: \n{error_string}\n{error_title}")
capture_event(self.continue_sdk.ide.unique_id, 'step error', {
'error_message': error_string, 'error_title': error_title, 'step_name': step.name, 'params': step.dict()})
@@ -341,7 +360,8 @@ class Autopilot(ContinueBaseModel):
# Update subscribers with new description
await self.update_subscribers()
- asyncio.create_task(update_description())
+ create_async_task(update_description(),
+ self.continue_sdk.ide.unique_id)
return observation
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 55f5bc60..6e430c04 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -45,6 +45,11 @@ DEFAULT_SLASH_COMMANDS = [
step_name="OpenConfigStep",
),
SlashCommand(
+ name="help",
+ description="Ask a question like '/help what is given to the llm as context?'",
+ step_name="HelpStep",
+ ),
+ SlashCommand(
name="comment",
description="Write comments for the current file or highlighted code",
step_name="CommentCodeStep",
@@ -131,7 +136,7 @@ def load_global_config() -> ContinueConfig:
config_path = os.path.join(global_dir, 'config.json')
if not os.path.exists(config_path):
with open(config_path, 'w') as f:
- json.dump(ContinueConfig().dict(), f)
+ json.dump(ContinueConfig().dict(), f, indent=4)
with open(config_path, 'r') as f:
try:
config_dict = json.load(f)
@@ -151,7 +156,7 @@ def update_global_config(config: ContinueConfig):
yaml_path = os.path.join(global_dir, 'config.yaml')
if os.path.exists(yaml_path):
with open(config_path, 'w') as f:
- yaml.dump(config.dict(), f)
+ yaml.dump(config.dict(), f, indent=4)
else:
config_path = os.path.join(global_dir, 'config.json')
with open(config_path, 'w') as f:
diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py
index b8363df2..bc897357 100644
--- a/continuedev/src/continuedev/core/policy.py
+++ b/continuedev/src/continuedev/core/policy.py
@@ -59,11 +59,8 @@ class DemoPolicy(Policy):
return (
MessageStep(name="Welcome to Continue", message=dedent("""\
- Highlight code and ask a question or give instructions
- - Use `cmd+k` (Mac) / `ctrl+k` (Windows) to open Continue
- - Use `cmd+shift+e` / `ctrl+shift+e` to open file Explorer
- - Add your own OpenAI API key to VS Code Settings with `cmd+,`
- - Use slash commands when you want fine-grained control
- - Past steps are included as part of the context by default""")) >>
+ - Use `cmd+m` (Mac) / `ctrl+m` (Windows) to open Continue
+ - Use `/help` to ask questions about how to use Continue""")) >>
WelcomeStep() >>
# SetupContinueWorkspaceStep() >>
# CreateCodebaseIndexChroma() >>
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 22393746..9389e1e9 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -1,6 +1,6 @@
import asyncio
from functools import cached_property
-from typing import Coroutine, Union
+from typing import Coroutine, Dict, Union
import os
from ..steps.core.core import DefaultModelEditCodeStep
@@ -14,7 +14,7 @@ from ..libs.llm.openai import OpenAI
from ..libs.llm.ggml import GGML
from .observation import Observation
from ..server.ide_protocol import AbstractIdeProtocolServer
-from .main import Context, ContinueCustomException, HighlightedRangeContext, History, Step, ChatMessage, ChatMessageRole
+from .main import Context, ContinueCustomException, History, Step, ChatMessage
from ..steps.core.core import *
from ..libs.llm.proxy_server import ProxyServer
@@ -23,26 +23,46 @@ class Autopilot:
pass
+ModelProvider = Literal["openai", "hf_inference_api", "ggml", "anthropic"]
+MODEL_PROVIDER_TO_ENV_VAR = {
+ "openai": "OPENAI_API_KEY",
+ "hf_inference_api": "HUGGING_FACE_TOKEN",
+ "anthropic": "ANTHROPIC_API_KEY"
+}
+
+
class Models:
- def __init__(self, sdk: "ContinueSDK"):
+ provider_keys: Dict[ModelProvider, str] = {}
+ model_providers: List[ModelProvider]
+
+ def __init__(self, sdk: "ContinueSDK", model_providers: List[ModelProvider]):
self.sdk = sdk
+ self.model_providers = model_providers
+
+ @classmethod
+ async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models":
+ models = Models(sdk, with_providers)
+ for provider in with_providers:
+ if provider in MODEL_PROVIDER_TO_ENV_VAR:
+ env_var = MODEL_PROVIDER_TO_ENV_VAR[provider]
+ models.provider_keys[provider] = await sdk.get_user_secret(
+ env_var, f'Please add your {env_var} to the .env file')
+
+ return models
def __load_openai_model(self, model: str) -> OpenAI:
- async def load_openai_model():
- api_key = await self.sdk.get_user_secret(
- 'OPENAI_API_KEY', 'Enter your OpenAI API key or press enter to try for free')
- if api_key == "":
- return ProxyServer(self.sdk.ide.unique_id, model)
- return OpenAI(api_key=api_key, default_model=model)
- return asyncio.get_event_loop().run_until_complete(load_openai_model())
+ api_key = self.provider_keys["openai"]
+ if api_key == "":
+ return ProxyServer(self.sdk.ide.unique_id, model)
+ return OpenAI(api_key=api_key, default_model=model)
+
+ def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI:
+ api_key = self.provider_keys["hf_inference_api"]
+ return HuggingFaceInferenceAPI(api_key=api_key, model=model)
@cached_property
def starcoder(self):
- async def load_starcoder():
- api_key = await self.sdk.get_user_secret(
- 'HUGGING_FACE_TOKEN', 'Please add your Hugging Face token to the .env file')
- return HuggingFaceInferenceAPI(api_key=api_key)
- return asyncio.get_event_loop().run_until_complete(load_starcoder())
+ return self.__load_hf_inference_api_model("bigcode/starcoder")
@cached_property
def gpt35(self):
@@ -80,7 +100,7 @@ class Models:
def default(self):
return self.ggml
default_model = self.sdk.config.default_model
- return self.__model_from_name(default_model) if default_model is not None else self.gpt35
+ return self.__model_from_name(default_model) if default_model is not None else self.gpt4
class ContinueSDK(AbstractContinueSDK):
@@ -93,8 +113,27 @@ class ContinueSDK(AbstractContinueSDK):
def __init__(self, autopilot: Autopilot):
self.ide = autopilot.ide
self.__autopilot = autopilot
- self.models = Models(self)
self.context = autopilot.context
+ self.config = self._load_config()
+
+ @classmethod
+ async def create(cls, autopilot: Autopilot) -> "ContinueSDK":
+ sdk = ContinueSDK(autopilot)
+ sdk.models = await Models.create(sdk)
+ return sdk
+
+ config: ContinueConfig
+
+ def _load_config(self) -> ContinueConfig:
+ dir = self.ide.workspace_directory
+ yaml_path = os.path.join(dir, '.continue', 'config.yaml')
+ json_path = os.path.join(dir, '.continue', 'config.json')
+ if os.path.exists(yaml_path):
+ return load_config(yaml_path)
+ elif os.path.exists(json_path):
+ return load_config(json_path)
+ else:
+ return load_global_config()
@property
def history(self) -> History:
@@ -172,18 +211,6 @@ class ContinueSDK(AbstractContinueSDK):
async def get_user_secret(self, env_var: str, prompt: str) -> str:
return await self.ide.getUserSecret(env_var)
- @property
- def config(self) -> ContinueConfig:
- dir = self.ide.workspace_directory
- yaml_path = os.path.join(dir, '.continue', 'config.yaml')
- json_path = os.path.join(dir, '.continue', 'config.json')
- if os.path.exists(yaml_path):
- return load_config(yaml_path)
- elif os.path.exists(json_path):
- return load_config(json_path)
- else:
- return load_global_config()
-
def get_code_context(self, only_editing: bool = False) -> List[RangeInFileWithContents]:
context = list(filter(lambda x: x.editing, self.__autopilot._highlighted_ranges)
) if only_editing else self.__autopilot._highlighted_ranges
@@ -208,14 +235,14 @@ class ContinueSDK(AbstractContinueSDK):
preface = "The following code is highlighted"
+ # If no higlighted ranges, use first file as context
if len(highlighted_code) == 0:
preface = "The following file is open"
- # Get the full contents of all open files
- files = await self.ide.getOpenFiles()
- if len(files) > 0:
- content = await self.ide.readFile(files[0])
+ visible_files = await self.ide.getVisibleFiles()
+ if len(visible_files) > 0:
+ content = await self.ide.readFile(visible_files[0])
highlighted_code = [
- RangeInFileWithContents.from_entire_file(files[0], content)]
+ RangeInFileWithContents.from_entire_file(visible_files[0], content)]
for rif in highlighted_code:
msg = ChatMessage(content=f"{preface} ({rif.filepath}):\n```\n{rif.contents}\n```",