summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CONTRIBUTING.md1
-rw-r--r--continuedev/README.md12
-rw-r--r--continuedev/src/continuedev/core/abstract_sdk.py2
-rw-r--r--continuedev/src/continuedev/core/autopilot.py24
-rw-r--r--continuedev/src/continuedev/core/config.py22
-rw-r--r--continuedev/src/continuedev/core/context.py6
-rw-r--r--continuedev/src/continuedev/core/main.py1
-rw-r--r--continuedev/src/continuedev/core/models.py65
-rw-r--r--continuedev/src/continuedev/core/sdk.py141
-rw-r--r--continuedev/src/continuedev/libs/chroma/query.py2
-rw-r--r--continuedev/src/continuedev/libs/chroma/update.py2
-rw-r--r--continuedev/src/continuedev/libs/constants/default_config.py.txt20
-rw-r--r--continuedev/src/continuedev/libs/llm/__init__.py31
-rw-r--r--continuedev/src/continuedev/libs/llm/anthropic.py47
-rw-r--r--continuedev/src/continuedev/libs/llm/ggml.py100
-rw-r--r--continuedev/src/continuedev/libs/llm/hf_inference_api.py12
-rw-r--r--continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py53
-rw-r--r--continuedev/src/continuedev/libs/llm/openai.py94
-rw-r--r--continuedev/src/continuedev/libs/llm/proxy_server.py165
-rw-r--r--continuedev/src/continuedev/libs/llm/utils.py34
-rw-r--r--continuedev/src/continuedev/libs/util/calculate_diff.py2
-rw-r--r--continuedev/src/continuedev/libs/util/count_tokens.py86
-rw-r--r--continuedev/src/continuedev/libs/util/strings.py2
-rw-r--r--continuedev/src/continuedev/models/generate_json_schema.py2
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/embeddings.py79
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/file.py3
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/google.py4
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/util.py5
-rw-r--r--continuedev/src/continuedev/plugins/policies/default.py (renamed from continuedev/src/continuedev/core/policy.py)24
-rw-r--r--continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py8
-rw-r--r--continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py2
-rw-r--r--continuedev/src/continuedev/plugins/steps/README.md2
-rw-r--r--continuedev/src/continuedev/plugins/steps/chat.py8
-rw-r--r--continuedev/src/continuedev/plugins/steps/chroma.py2
-rw-r--r--continuedev/src/continuedev/plugins/steps/core/core.py37
-rw-r--r--continuedev/src/continuedev/plugins/steps/draft/migration.py2
-rw-r--r--continuedev/src/continuedev/plugins/steps/help.py2
-rw-r--r--continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py2
-rw-r--r--continuedev/src/continuedev/plugins/steps/main.py4
-rw-r--r--continuedev/src/continuedev/plugins/steps/react.py2
-rw-r--r--continuedev/src/continuedev/plugins/steps/search_directory.py6
-rw-r--r--continuedev/src/continuedev/server/gui.py2
-rw-r--r--continuedev/src/continuedev/server/ide.py4
-rw-r--r--continuedev/src/continuedev/server/session_manager.py48
-rw-r--r--docs/docs/customization.md59
-rw-r--r--docs/docs/walkthroughs/create-a-recipe.md2
-rw-r--r--extension/media/terminal-continue.pngbin36669 -> 7891 bytes
-rw-r--r--extension/package-lock.json4
-rw-r--r--extension/package.json2
-rw-r--r--extension/src/suggestions.ts2
50 files changed, 727 insertions, 514 deletions
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index e6dea5c4..50d694f4 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -107,7 +107,6 @@ When state is updated on the server, we currently send the entirety of the objec
- `history`, a record of previously run Steps. Displayed in order in the sidebar.
- `active`, whether the autopilot is currently running a step. Displayed as a loader while step is running.
- `user_input_queue`, the queue of user inputs that have not yet been processed due to waiting for previous Steps to complete. Displayed below the `active` loader until popped from the queue.
-- `default_model`, the default model used for completions. Displayed as a toggleable button on the bottom of the GUI.
- `selected_context_items`, the ranges of code and other items (like GitHub Issues, files, etc...) that have been selected to include as context. Displayed just above the main text input.
- `slash_commands`, the list of available slash commands. Displayed in the main text input dropdown.
- `adding_highlighted_code`, whether highlighting of new code for context is locked. Displayed as a button adjacent to `highlighted_ranges`.
diff --git a/continuedev/README.md b/continuedev/README.md
index d3ead8ec..6a11ae43 100644
--- a/continuedev/README.md
+++ b/continuedev/README.md
@@ -67,9 +67,9 @@ cd continue/extension/scripts && python3 install_from_source.py
# Understanding the codebase
-- [Continue Server README](./continuedev/README.md): learn about the core of Continue, which can be downloaded as a [PyPI package](https://pypi.org/project/continuedev/)
-- [VS Code Extension README](./extension/README.md): learn about the capabilities of our extension—the first implementation of Continue's IDE Protocol—which makes it possible to use use Continue in VS Code and GitHub Codespaces
-- [Continue GUI README](./extension/react-app/): learn about the React app that lets users interact with the server and is placed adjacent to the text editor in any suppported IDE
-- [Schema README](./schema): learn about the JSON Schema types generated from Pydantic models, which we use across the `continuedev/` and `extension/` directories
-- [Continue Docs README](./docs): learn how our [docs](https://continue.dev/docs) are written and built
-- [How to debug the VS Code Extension README](./extension/src/README.md): learn how to set up the VS Code extension, so you can debug it
+- [Continue Server README](./README.md): learn about the core of Continue, which can be downloaded as a [PyPI package](https://pypi.org/project/continuedev/)
+- [VS Code Extension README](../extension/README.md): learn about the capabilities of our extension—the first implementation of Continue's IDE Protocol—which makes it possible to use use Continue in VS Code and GitHub Codespaces
+- [Continue GUI README](../extension/react-app/): learn about the React app that lets users interact with the server and is placed adjacent to the text editor in any suppported IDE
+- [Schema README](../schema/README.md): learn about the JSON Schema types generated from Pydantic models, which we use across the `continuedev/` and `extension/` directories
+- [Continue Docs README](../docs/README.md): learn how our [docs](https://continue.dev/docs) are written and built
+- [How to debug the VS Code Extension README](../extension/src/README.md): learn how to set up the VS Code extension, so you can debug it
diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py
index 94d7be10..e048f877 100644
--- a/continuedev/src/continuedev/core/abstract_sdk.py
+++ b/continuedev/src/continuedev/core/abstract_sdk.py
@@ -73,7 +73,7 @@ class AbstractContinueSDK(ABC):
pass
@abstractmethod
- async def get_user_secret(self, env_var: str, prompt: str) -> str:
+ async def get_user_secret(self, env_var: str) -> str:
pass
config: ContinueConfig
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 57e39d5c..de95a259 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -9,6 +9,7 @@ from ..models.filesystem import RangeInFileWithContents
from ..models.filesystem_edit import FileEditWithFullContents
from .observation import Observation, InternalErrorObservation
from .context import ContextManager
+from ..plugins.policies.default import DefaultPolicy
from ..plugins.context_providers.file import FileContextProvider
from ..plugins.context_providers.highlighted_code import HighlightedCodeContextProvider
from ..server.ide_protocol import AbstractIdeProtocolServer
@@ -47,8 +48,9 @@ def get_error_title(e: Exception) -> str:
class Autopilot(ContinueBaseModel):
- policy: Policy
ide: AbstractIdeProtocolServer
+
+ policy: Policy = DefaultPolicy()
history: History = History.from_empty()
context: Context = Context()
full_state: Union[FullState, None] = None
@@ -64,20 +66,19 @@ class Autopilot(ContinueBaseModel):
_user_input_queue = AsyncSubscriptionQueue()
_retry_queue = AsyncSubscriptionQueue()
- @classmethod
- async def create(cls, policy: Policy, ide: AbstractIdeProtocolServer, full_state: FullState) -> "Autopilot":
- autopilot = cls(ide=ide, policy=policy)
- autopilot.continue_sdk = await ContinueSDK.create(autopilot)
+ async def start(self):
+ self.continue_sdk = await ContinueSDK.create(self)
+ if override_policy := self.continue_sdk.config.policy_override:
+ self.policy = override_policy
# Load documents into the search index
- autopilot.context_manager = await ContextManager.create(
- autopilot.continue_sdk.config.context_providers + [
- HighlightedCodeContextProvider(ide=ide),
- FileContextProvider(workspace_dir=ide.workspace_directory)
+ self.context_manager = await ContextManager.create(
+ self.continue_sdk.config.context_providers + [
+ HighlightedCodeContextProvider(ide=self.ide),
+ FileContextProvider(workspace_dir=self.ide.workspace_directory)
])
- await autopilot.context_manager.load_index(ide.workspace_directory)
- return autopilot
+ await self.context_manager.load_index(self.ide.workspace_directory)
class Config:
arbitrary_types_allowed = True
@@ -95,7 +96,6 @@ class Autopilot(ContinueBaseModel):
history=self.history,
active=self._active,
user_input_queue=self._main_user_input_queue,
- default_model=self.continue_sdk.config.default_model,
slash_commands=self.get_available_slash_commands(),
adding_highlighted_code=self.context_manager.context_providers[
"code"].adding_highlighted_code,
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 9fbda824..84b6b10b 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -2,9 +2,13 @@ import json
import os
from .main import Step
from .context import ContextProvider
+from ..libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
+from .models import Models
from pydantic import BaseModel, validator
-from typing import List, Literal, Optional, Dict, Type, Union
-import yaml
+from typing import List, Literal, Optional, Dict, Type
+
+from .main import Policy, Step
+from .context import ContextProvider
class SlashCommand(BaseModel):
@@ -25,13 +29,6 @@ class OnTracebackSteps(BaseModel):
params: Optional[Dict] = {}
-class OpenAIServerInfo(BaseModel):
- api_base: Optional[str] = None
- engine: Optional[str] = None
- api_version: Optional[str] = None
- api_type: Literal["azure", "openai"] = "openai"
-
-
class ContinueConfig(BaseModel):
"""
A pydantic class for the continue config file.
@@ -39,8 +36,9 @@ class ContinueConfig(BaseModel):
steps_on_startup: List[Step] = []
disallowed_steps: Optional[List[str]] = []
allow_anonymous_telemetry: Optional[bool] = True
- default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k",
- "gpt-4", "claude-2", "ggml"] = 'gpt-4'
+ models: Models = Models(
+ default=MaybeProxyOpenAI(model="gpt-4"),
+ )
temperature: Optional[float] = 0.5
custom_commands: Optional[List[CustomCommand]] = [CustomCommand(
name="test",
@@ -50,7 +48,7 @@ class ContinueConfig(BaseModel):
slash_commands: Optional[List[SlashCommand]] = []
on_traceback: Optional[List[OnTracebackSteps]] = []
system_message: Optional[str] = None
- openai_server_info: Optional[OpenAIServerInfo] = None
+ policy_override: Optional[Policy] = None
context_providers: List[ContextProvider] = []
diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py
index 86522ce1..e968c35c 100644
--- a/continuedev/src/continuedev/core/context.py
+++ b/continuedev/src/continuedev/core/context.py
@@ -178,12 +178,6 @@ class ContextManager:
except Exception as e:
logger.debug(f"Error loading meilisearch index: {e}")
- # def compile_chat_messages(self, max_tokens: int) -> List[Dict]:
- # """
- # Compiles the chat prompt into a single string.
- # """
- # return compile_chat_messages(self.model, self.chat_history, max_tokens, self.prompt, self.functions, self.system_message)
-
async def select_context_item(self, id: str, query: str):
"""
Selects the ContextItem with the given id.
diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py
index df9b98ef..2553850f 100644
--- a/continuedev/src/continuedev/core/main.py
+++ b/continuedev/src/continuedev/core/main.py
@@ -258,7 +258,6 @@ class FullState(ContinueBaseModel):
history: History
active: bool
user_input_queue: List[str]
- default_model: str
slash_commands: List[SlashCommandDescription]
adding_highlighted_code: bool
selected_context_items: List[ContextItem]
diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py
new file mode 100644
index 00000000..900762b6
--- /dev/null
+++ b/continuedev/src/continuedev/core/models.py
@@ -0,0 +1,65 @@
+from typing import Optional, Any
+from pydantic import BaseModel, validator
+from ..libs.llm import LLM
+
+
+class Models(BaseModel):
+ """Main class that holds the current model configuration"""
+ default: LLM
+ small: Optional[LLM] = None
+ medium: Optional[LLM] = None
+ large: Optional[LLM] = None
+
+ # TODO namespace these away to not confuse readers,
+ # or split Models into ModelsConfig, which gets turned into Models
+ sdk: "ContinueSDK" = None
+ system_message: Any = None
+
+ """
+ Better to have sdk.llm.stream_chat(messages, model="claude-2").
+ Then you also don't care that it' async.
+ And it's easier to add more models.
+ And intermediate shared code is easier to add.
+ And you can make constants like ContinueModels.GPT35 = "gpt-3.5-turbo"
+ PromptTransformer would be a good concept: You pass a prompt or list of messages and a model, then it outputs the prompt for that model.
+ Easy to reason about, can place anywhere.
+ And you can even pass a Prompt object to sdk.llm.stream_chat maybe, and it'll automatically be transformed for the given model.
+ This can all happen inside of Models?
+
+ class Prompt:
+ def __init__(self, ...info):
+ '''take whatever info is needed to describe the prompt'''
+
+ def to_string(self, model: str) -> str:
+ '''depending on the model, return the single prompt string'''
+ """
+
+ async def start(self, sdk: "ContinueSDK"):
+ """Start each of the LLMs, or fall back to default"""
+ self.sdk = sdk
+ self.system_message = self.sdk.config.system_message
+ await sdk.start_model(self.default)
+ if self.small:
+ await sdk.start_model(self.small)
+ else:
+ self.small = self.default
+
+ if self.medium:
+ await sdk.start_model(self.medium)
+ else:
+ self.medium = self.default
+
+ if self.large:
+ await sdk.start_model(self.large)
+ else:
+ self.large = self.default
+
+ async def stop(self, sdk: "ContinueSDK"):
+ """Stop each LLM (if it's not the default, which is shared)"""
+ await self.default.stop()
+ if self.small is not self.default:
+ await self.small.stop()
+ if self.medium is not self.default:
+ await self.medium.stop()
+ if self.large is not self.default:
+ await self.large.stop()
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 4b76a121..bf22d696 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -9,17 +9,14 @@ from .abstract_sdk import AbstractContinueSDK
from .config import ContinueConfig
from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFile, AddDirectory, DeleteDirectory
from ..models.filesystem import RangeInFile
-from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
-from ..libs.llm.openai import OpenAI
-from ..libs.llm.anthropic import AnthropicLLM
-from ..libs.llm.ggml import GGML
+from ..libs.llm import LLM
from .observation import Observation
from ..server.ide_protocol import AbstractIdeProtocolServer
from .main import Context, ContinueCustomException, History, HistoryNode, Step, ChatMessage
from ..plugins.steps.core.core import *
-from ..libs.llm.proxy_server import ProxyServer
from ..libs.util.telemetry import posthog_logger
from ..libs.util.paths import getConfigFilePath
+from .models import Models
from ..libs.util.logging import logger
@@ -27,121 +24,6 @@ class Autopilot:
pass
-ModelProvider = Literal["openai", "hf_inference_api", "ggml", "anthropic"]
-MODEL_PROVIDER_TO_ENV_VAR = {
- "openai": "OPENAI_API_KEY",
- "hf_inference_api": "HUGGING_FACE_TOKEN",
- "anthropic": "ANTHROPIC_API_KEY",
-}
-
-
-class Models:
- provider_keys: Dict[ModelProvider, str] = {}
- model_providers: List[ModelProvider]
- system_message: str
-
- """
- Better to have sdk.llm.stream_chat(messages, model="claude-2").
- Then you also don't care that it' async.
- And it's easier to add more models.
- And intermediate shared code is easier to add.
- And you can make constants like ContinueModels.GPT35 = "gpt-3.5-turbo"
- PromptTransformer would be a good concept: You pass a prompt or list of messages and a model, then it outputs the prompt for that model.
- Easy to reason about, can place anywhere.
- And you can even pass a Prompt object to sdk.llm.stream_chat maybe, and it'll automatically be transformed for the given model.
- This can all happen inside of Models?
-
- class Prompt:
- def __init__(self, ...info):
- '''take whatever info is needed to describe the prompt'''
-
- def to_string(self, model: str) -> str:
- '''depending on the model, return the single prompt string'''
- """
-
- def __init__(self, sdk: "ContinueSDK", model_providers: List[ModelProvider]):
- self.sdk = sdk
- self.model_providers = model_providers
- self.system_message = sdk.config.system_message
-
- @classmethod
- async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models":
- if sdk.config.default_model == "claude-2":
- with_providers.append("anthropic")
-
- models = Models(sdk, with_providers)
- for provider in with_providers:
- if provider in MODEL_PROVIDER_TO_ENV_VAR:
- env_var = MODEL_PROVIDER_TO_ENV_VAR[provider]
- models.provider_keys[provider] = await sdk.get_user_secret(
- env_var, f'Please add your {env_var} to the .env file')
-
- return models
-
- def __load_openai_model(self, model: str) -> OpenAI:
- api_key = self.provider_keys["openai"]
- if api_key == "":
- return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message, write_log=self.sdk.write_log)
- return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, openai_server_info=self.sdk.config.openai_server_info, write_log=self.sdk.write_log)
-
- def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI:
- api_key = self.provider_keys["hf_inference_api"]
- return HuggingFaceInferenceAPI(api_key=api_key, model=model, system_message=self.system_message)
-
- def __load_anthropic_model(self, model: str) -> AnthropicLLM:
- api_key = self.provider_keys["anthropic"]
- return AnthropicLLM(api_key, model, self.system_message)
-
- @cached_property
- def claude2(self):
- return self.__load_anthropic_model("claude-2")
-
- @cached_property
- def starcoder(self):
- return self.__load_hf_inference_api_model("bigcode/starcoder")
-
- @cached_property
- def gpt35(self):
- return self.__load_openai_model("gpt-3.5-turbo")
-
- @cached_property
- def gpt350613(self):
- return self.__load_openai_model("gpt-3.5-turbo-0613")
-
- @cached_property
- def gpt3516k(self):
- return self.__load_openai_model("gpt-3.5-turbo-16k")
-
- @cached_property
- def gpt4(self):
- return self.__load_openai_model("gpt-4")
-
- @cached_property
- def ggml(self):
- return GGML(system_message=self.system_message)
-
- def __model_from_name(self, model_name: str):
- if model_name == "starcoder":
- return self.starcoder
- elif model_name == "gpt-3.5-turbo":
- return self.gpt35
- elif model_name == "gpt-3.5-turbo-16k":
- return self.gpt3516k
- elif model_name == "gpt-4":
- return self.gpt4
- elif model_name == "claude-2":
- return self.claude2
- elif model_name == "ggml":
- return self.ggml
- else:
- raise Exception(f"Unknown model {model_name}")
-
- @property
- def default(self):
- default_model = self.sdk.config.default_model
- return self.__model_from_name(default_model) if default_model is not None else self.gpt4
-
-
class ContinueSDK(AbstractContinueSDK):
"""The SDK provided as parameters to a step"""
ide: AbstractIdeProtocolServer
@@ -171,7 +53,7 @@ class ContinueSDK(AbstractContinueSDK):
formatted_err = '\n'.join(traceback.format_exception(e))
msg_step = MessageStep(
name="Invalid Continue Config File", message=formatted_err)
- msg_step.description = f"Falling back to default config settings.\n```\n{formatted_err}\n```"
+ msg_step.description = f"Falling back to default config settings.\n```\n{formatted_err}\n```\n\nIt's possible this error was caused by an update to the Continue config format. If you'd like to see the new recommended default `config.py`, check [here](https://github.com/continuedev/continue/blob/main/continuedev/src/continuedev/libs/constants/default_config.py.txt)."
sdk.history.add_node(HistoryNode(
step=msg_step,
observation=None,
@@ -179,11 +61,13 @@ class ContinueSDK(AbstractContinueSDK):
active=False
))
+ sdk.models = sdk.config.models
+ await sdk.models.start(sdk)
+
# When the config is loaded, setup posthog logger
posthog_logger.setup(
sdk.ide.unique_id, sdk.config.allow_anonymous_telemetry)
- sdk.models = await Models.create(sdk)
return sdk
@property
@@ -193,6 +77,16 @@ class ContinueSDK(AbstractContinueSDK):
def write_log(self, message: str):
self.history.timeline[self.history.current_index].logs.append(message)
+ async def start_model(self, llm: LLM):
+ kwargs = {}
+ if llm.requires_api_key:
+ kwargs["api_key"] = await self.get_user_secret(llm.requires_api_key)
+ if llm.requires_unique_id:
+ kwargs["unique_id"] = self.ide.unique_id
+ if llm.requires_write_log:
+ kwargs["write_log"] = self.write_log
+ await llm.start(**kwargs)
+
async def _ensure_absolute_path(self, path: str) -> str:
if os.path.isabs(path):
return path
@@ -262,7 +156,8 @@ class ContinueSDK(AbstractContinueSDK):
path = await self._ensure_absolute_path(path)
return await self.run_step(FileSystemEditStep(edit=DeleteDirectory(path=path)))
- async def get_user_secret(self, env_var: str, prompt: str) -> str:
+ async def get_user_secret(self, env_var: str) -> str:
+ # TODO support error prompt dynamically set on env_var
return await self.ide.getUserSecret(env_var)
_last_valid_config: ContinueConfig = None
diff --git a/continuedev/src/continuedev/libs/chroma/query.py b/continuedev/src/continuedev/libs/chroma/query.py
index f09b813a..dba4874f 100644
--- a/continuedev/src/continuedev/libs/chroma/query.py
+++ b/continuedev/src/continuedev/libs/chroma/query.py
@@ -59,7 +59,7 @@ class ChromaIndexManager:
except:
logger.warning(
f"ERROR (probably found special token): {doc.text}")
- continue
+ continue # lol
filename = doc.extra_info["filename"]
chunks[filename] = len(text_chunks)
for i, text in enumerate(text_chunks):
diff --git a/continuedev/src/continuedev/libs/chroma/update.py b/continuedev/src/continuedev/libs/chroma/update.py
index 23ed950f..d5326a06 100644
--- a/continuedev/src/continuedev/libs/chroma/update.py
+++ b/continuedev/src/continuedev/libs/chroma/update.py
@@ -23,7 +23,7 @@ def filter_ignored_files(files: List[str], root_dir: str):
"""Further filter files before indexing."""
for file in files:
if file.endswith(tuple(FILE_TYPES_TO_IGNORE)) or file.startswith('.git') or file.startswith('archive'):
- continue
+ continue # nice
yield root_dir + "/" + file
diff --git a/continuedev/src/continuedev/libs/constants/default_config.py.txt b/continuedev/src/continuedev/libs/constants/default_config.py.txt
index 1a66c847..cf8b0324 100644
--- a/continuedev/src/continuedev/libs/constants/default_config.py.txt
+++ b/continuedev/src/continuedev/libs/constants/default_config.py.txt
@@ -9,9 +9,12 @@ import subprocess
from continuedev.core.main import Step
from continuedev.core.sdk import ContinueSDK
+from continuedev.core.models import Models
from continuedev.core.config import CustomCommand, SlashCommand, ContinueConfig
from continuedev.plugins.context_providers.github import GitHubIssuesContextProvider
from continuedev.plugins.context_providers.google import GoogleContextProvider
+from continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
+from continuedev.plugins.policies.default import DefaultPolicy
from continuedev.plugins.steps.open_config import OpenConfigStep
from continuedev.plugins.steps.clear_history import ClearHistoryStep
@@ -35,9 +38,9 @@ class CommitMessageStep(Step):
diff = subprocess.check_output(
["git", "diff"], cwd=dir).decode("utf-8")
- # Ask gpt-3.5-16k to write a commit message,
+ # Ask the LLM to write a commit message,
# and set it as the description of this step
- self.description = await sdk.models.gpt3516k.complete(
+ self.description = await sdk.models.default.complete(
f"{diff}\n\nWrite a short, specific (less than 50 chars) commit message about the above changes:")
@@ -47,9 +50,10 @@ config = ContinueConfig(
# See here to learn what anonymous data we collect: https://continue.dev/docs/telemetry
allow_anonymous_telemetry=True,
- # GPT-4 is recommended for best results
- # See options here: https://continue.dev/docs/customization#change-the-default-llm
- default_model="gpt-4",
+ models=Models(
+ default=MaybeProxyOpenAI(model="gpt-4"),
+ medium=MaybeProxyOpenAI(model="gpt-3.5-turbo")
+ ),
# Set a system message with information that the LLM should always keep in mind
# E.g. "Please give concise answers. Always respond in Spanish."
@@ -114,5 +118,9 @@ config = ContinueConfig(
# GoogleContextProvider(
# serper_api_key="<your serper.dev api key>"
# )
- ]
+ ],
+
+ # Policies hold the main logic that decides which Step to take next
+ # You can use them to design agents, or deeply customize Continue
+ policy=DefaultPolicy()
)
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 2766db4b..50577993 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -1,14 +1,30 @@
-from abc import ABC
-from typing import Any, Coroutine, Dict, Generator, List, Union
+from abc import ABC, abstractproperty
+from typing import Any, Coroutine, Dict, Generator, List, Union, Optional
from ...core.main import ChatMessage
-from ...models.main import AbstractModel
-from pydantic import BaseModel
+from ...models.main import ContinueBaseModel
-class LLM(ABC):
+class LLM(ContinueBaseModel, ABC):
+ requires_api_key: Optional[str] = None
+ requires_unique_id: bool = False
+ requires_write_log: bool = False
+
system_message: Union[str, None] = None
+ @abstractproperty
+ def name(self):
+ """Return the name of the LLM."""
+ raise NotImplementedError
+
+ async def start(self, *, api_key: Optional[str] = None, **kwargs):
+ """Start the connection to the LLM."""
+ raise NotImplementedError
+
+ async def stop(self):
+ """Stop the connection to the LLM."""
+ raise NotImplementedError
+
async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
"""Return the completion of the text with the given temperature."""
raise NotImplementedError
@@ -24,3 +40,8 @@ class LLM(ABC):
def count_tokens(self, text: str):
"""Return the number of tokens in the given text."""
raise NotImplementedError
+
+ @abstractproperty
+ def context_length(self) -> int:
+ """Return the context length of the LLM in tokens, as counted by count_tokens."""
+ raise NotImplementedError
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
index 625d4e57..ec1b7e40 100644
--- a/continuedev/src/continuedev/libs/llm/anthropic.py
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -1,32 +1,39 @@
from functools import cached_property
import time
-from typing import Any, Coroutine, Dict, Generator, List, Union
+from typing import Any, Coroutine, Dict, Generator, List, Optional, Union
from ...core.main import ChatMessage
from anthropic import HUMAN_PROMPT, AI_PROMPT, AsyncAnthropic
from ..llm import LLM
-from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top
+from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens
class AnthropicLLM(LLM):
- api_key: str
- default_model: str
- async_client: AsyncAnthropic
+ model: str = "claude-2"
- def __init__(self, api_key: str, default_model: str, system_message: str = None):
- self.api_key = api_key
- self.default_model = default_model
+ requires_api_key: str = "ANTHROPIC_API_KEY"
+ _async_client: AsyncAnthropic = None
+
+ class Config:
+ arbitrary_types_allowed = True
+
+ def __init__(self, model: str, system_message: str = None):
+ self.model = model
self.system_message = system_message
- self.async_client = AsyncAnthropic(api_key=api_key)
+ async def start(self, *, api_key: Optional[str] = None, **kwargs):
+ self._async_client = AsyncAnthropic(api_key=api_key)
+
+ async def stop(self):
+ pass
@cached_property
def name(self):
- return self.default_model
+ return self.model
@property
def default_args(self):
- return {**DEFAULT_ARGS, "model": self.default_model}
+ return {**DEFAULT_ARGS, "model": self.model}
def _transform_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
args = args.copy()
@@ -40,7 +47,13 @@ class AnthropicLLM(LLM):
return args
def count_tokens(self, text: str):
- return count_tokens(self.default_model, text)
+ return count_tokens(self.model, text)
+
+ @property
+ def context_length(self):
+ if self.model == "claude-2":
+ return 100000
+ raise Exception(f"Unknown Anthropic model {self.model}")
def __messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
prompt = ""
@@ -60,7 +73,7 @@ class AnthropicLLM(LLM):
args["stream"] = True
args = self._transform_args(args)
- async for chunk in await self.async_client.completions.create(
+ async for chunk in await self._async_client.completions.create(
prompt=f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}",
**args
):
@@ -73,8 +86,8 @@ class AnthropicLLM(LLM):
args = self._transform_args(args)
messages = compile_chat_messages(
- args["model"], messages, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message)
- async for chunk in await self.async_client.completions.create(
+ args["model"], messages, self.context_length, self.context_length, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message)
+ async for chunk in await self._async_client.completions.create(
prompt=self.__messages_to_prompt(messages),
**args
):
@@ -88,8 +101,8 @@ class AnthropicLLM(LLM):
args = self._transform_args(args)
messages = compile_chat_messages(
- args["model"], with_history, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message)
- resp = (await self.async_client.completions.create(
+ args["model"], with_history, self.context_length, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message)
+ resp = (await self._async_client.completions.create(
prompt=self.__messages_to_prompt(messages),
**args
)).completion
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 4889a556..7742e8c3 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -1,6 +1,7 @@
from functools import cached_property
import json
from typing import Any, Coroutine, Dict, Generator, List, Union
+from pydantic import ConfigDict
import aiohttp
from ...core.main import ChatMessage
@@ -11,15 +12,29 @@ SERVER_URL = "http://localhost:8000"
class GGML(LLM):
+ # this is model-specific
+ max_context_length: int = 2048
- def __init__(self, system_message: str = None):
- self.system_message = system_message
+ _client_session: aiohttp.ClientSession = None
- @cached_property
+ class Config:
+ arbitrary_types_allowed = True
+
+ async def start(self, **kwargs):
+ self._client_session = aiohttp.ClientSession()
+
+ async def stop(self):
+ await self._client_session.close()
+
+ @property
def name(self):
return "ggml"
@property
+ def context_length(self):
+ return self.max_context_length
+
+ @property
def default_args(self):
return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
@@ -33,54 +48,51 @@ class GGML(LLM):
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
-
- async with aiohttp.ClientSession() as session:
- async with session.post(f"{SERVER_URL}/v1/completions", json={
- "messages": messages,
- **args
- }) as resp:
- async for line in resp.content.iter_any():
- if line:
- try:
- yield line.decode("utf-8")
- except:
- raise Exception(str(line))
+ self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
+
+ async with self._client_session.post(f"{SERVER_URL}/v1/completions", json={
+ "messages": messages,
+ **args
+ }) as resp:
+ async for line in resp.content.iter_any():
+ if line:
+ try:
+ yield line.decode("utf-8")
+ except:
+ raise Exception(str(line))
async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.name, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+ self.name, messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
args["stream"] = True
- async with aiohttp.ClientSession() as session:
- async with session.post(f"{SERVER_URL}/v1/chat/completions", json={
- "messages": messages,
- **args
- }) as resp:
- # This is streaming application/json instaed of text/event-stream
- async for line in resp.content.iter_chunks():
- if line[1]:
- try:
- json_chunk = line[0].decode("utf-8")
- if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"):
- continue
- chunks = json_chunk.split("\n")
- for chunk in chunks:
- if chunk.strip() != "":
- yield json.loads(chunk[6:])["choices"][0]["delta"]
- except:
- raise Exception(str(line[0]))
+ async with self._client_session.post(f"{SERVER_URL}/v1/chat/completions", json={
+ "messages": messages,
+ **args
+ }) as resp:
+ # This is streaming application/json instaed of text/event-stream
+ async for line in resp.content.iter_chunks():
+ if line[1]:
+ try:
+ json_chunk = line[0].decode("utf-8")
+ if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"):
+ continue
+ chunks = json_chunk.split("\n")
+ for chunk in chunks:
+ if chunk.strip() != "":
+ yield json.loads(chunk[6:])["choices"][0]["delta"]
+ except:
+ raise Exception(str(line[0]))
async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
args = {**self.default_args, **kwargs}
- async with aiohttp.ClientSession() as session:
- async with session.post(f"{SERVER_URL}/v1/completions", json={
- "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
- **args
- }) as resp:
- try:
- return await resp.text()
- except:
- raise Exception(await resp.text())
+ async with self._client_session.post(f"{SERVER_URL}/v1/completions", json={
+ "messages": compile_chat_messages(args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
+ **args
+ }) as resp:
+ try:
+ return await resp.text()
+ except:
+ raise Exception(await resp.text())
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 36f03270..49f593d8 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional
from ...core.main import ChatMessage
from ..llm import LLM
import requests
@@ -8,14 +8,18 @@ DEFAULT_MAX_TIME = 120.
class HuggingFaceInferenceAPI(LLM):
- api_key: str
model: str
- def __init__(self, api_key: str, model: str, system_message: str = None):
- self.api_key = api_key
+ requires_api_key: str = "HUGGING_FACE_TOKEN"
+ api_key: str = None
+
+ def __init__(self, model: str, system_message: str = None):
self.model = model
self.system_message = system_message # TODO: Nothing being done with this
+ async def start(self, *, api_key: Optional[str] = None, **kwargs):
+ self.api_key = api_key
+
def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs):
"""Return the completion of the text with the given temperature."""
API_URL = f"https://api-inference.huggingface.co/models/{self.model}"
diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
new file mode 100644
index 00000000..edf58fd7
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
@@ -0,0 +1,53 @@
+from typing import Any, Coroutine, Dict, Generator, List, Union, Optional, Callable
+
+from ...core.main import ChatMessage
+from . import LLM
+from .proxy_server import ProxyServer
+from .openai import OpenAI
+
+
+class MaybeProxyOpenAI(LLM):
+ model: str
+
+ requires_api_key: Optional[str] = "OPENAI_API_KEY"
+ requires_write_log: bool = True
+ requires_unique_id: bool = True
+ system_message: Union[str, None] = None
+
+ llm: Optional[LLM] = None
+
+ @property
+ def name(self):
+ return self.llm.name
+
+ @property
+ def context_length(self):
+ return self.llm.context_length
+
+ async def start(self, *, api_key: Optional[str] = None, unique_id: str, write_log: Callable[[str], None]):
+ if api_key is None or api_key.strip() == "":
+ self.llm = ProxyServer(model=self.model)
+ else:
+ self.llm = OpenAI(model=self.model)
+
+ await self.llm.start(api_key=api_key, write_log=write_log, unique_id=unique_id)
+
+ async def stop(self):
+ await self.llm.stop()
+
+ async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
+ return await self.llm.complete(prompt, with_history=with_history, **kwargs)
+
+ async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ resp = self.llm.stream_complete(
+ prompt, with_history=with_history, **kwargs)
+ async for item in resp:
+ yield item
+
+ async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ resp = self.llm.stream_chat(messages=messages, **kwargs)
+ async for item in resp:
+ yield item
+
+ def count_tokens(self, text: str):
+ return self.llm.count_tokens(text)
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 654c7326..fce6e8ab 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -1,48 +1,83 @@
from functools import cached_property
import json
-from typing import Any, Callable, Coroutine, Dict, Generator, List, Union
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union, Optional
-from ...core.main import ChatMessage
+from pydantic import BaseModel
import openai
+
+from ...core.main import ChatMessage
+from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top
from ..llm import LLM
-from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top
-from ...core.config import OpenAIServerInfo
+
+
+class OpenAIServerInfo(BaseModel):
+ api_base: Optional[str] = None
+ engine: Optional[str] = None
+ api_version: Optional[str] = None
+ api_type: Literal["azure", "openai"] = "openai"
+
+
+CHAT_MODELS = {
+ "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"
+}
+MAX_TOKENS_FOR_MODEL = {
+ "gpt-3.5-turbo": 4096,
+ "gpt-3.5-turbo-0613": 4096,
+ "gpt-3.5-turbo-16k": 16384,
+ "gpt-4": 8192,
+}
+
+
+class AzureInfo(BaseModel):
+ endpoint: str
+ engine: str
+ api_version: str
class OpenAI(LLM):
- api_key: str
- default_model: str
+ model: str
+ openai_server_info: Optional[OpenAIServerInfo] = None
- def __init__(self, api_key: str, default_model: str, system_message: str = None, openai_server_info: OpenAIServerInfo = None, write_log: Callable[[str], None] = None):
- self.api_key = api_key
- self.default_model = default_model
- self.system_message = system_message
- self.openai_server_info = openai_server_info
+ requires_api_key = "OPENAI_API_KEY"
+ requires_write_log = True
+
+ system_message: Optional[str] = None
+ azure_info: Optional[AzureInfo] = None
+ write_log: Optional[Callable[[str], None]] = None
+ api_key: str = None
+
+ async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], **kwargs):
self.write_log = write_log
+ self.api_key = api_key
+ openai.api_key = self.api_key
- openai.api_key = api_key
+ if self.openai_server_info is not None:
+ openai.api_type = self.openai_server_info.api_type
+ if self.openai_server_info.api_base is not None:
+ openai.api_base = self.openai_server_info.api_base
+ if self.openai_server_info.api_version is not None:
+ openai.api_version = self.openai_server_info.api_version
- # Using an Azure OpenAI deployment
- if openai_server_info is not None:
- openai.api_type = openai_server_info.api_type
- if openai_server_info.api_base is not None:
- openai.api_base = openai_server_info.api_base
- if openai_server_info.api_version is not None:
- openai.api_version = openai_server_info.api_version
+ async def stop(self):
+ pass
- @cached_property
+ @property
def name(self):
- return self.default_model
+ return self.model
+
+ @property
+ def context_length(self):
+ return MAX_TOKENS_FOR_MODEL[self.model]
@property
def default_args(self):
- args = {**DEFAULT_ARGS, "model": self.default_model}
+ args = {**DEFAULT_ARGS, "model": self.model}
if self.openai_server_info is not None:
args["engine"] = self.openai_server_info.engine
return args
def count_tokens(self, text: str):
- return count_tokens(self.default_model, text)
+ return count_tokens(self.model, text)
async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = self.default_args.copy()
@@ -51,7 +86,7 @@ class OpenAI(LLM):
if args["model"] in CHAT_MODELS:
messages = compile_chat_messages(
- args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+ args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
completion = ""
async for chunk in await openai.ChatCompletion.acreate(
@@ -62,7 +97,7 @@ class OpenAI(LLM):
yield chunk.choices[0].delta.content
completion += chunk.choices[0].delta.content
else:
- continue
+ continue # :)
self.write_log(f"Completion: \n\n{completion}")
else:
@@ -78,12 +113,13 @@ class OpenAI(LLM):
args = self.default_args.copy()
args.update(kwargs)
args["stream"] = True
- args["model"] = self.default_model if self.default_model in CHAT_MODELS else "gpt-3.5-turbo-0613"
+ # TODO what to do here? why should we change to gpt-3.5-turbo-0613 if the user didn't ask for it?
+ args["model"] = self.model if self.model in CHAT_MODELS else "gpt-3.5-turbo-0613"
if not args["model"].endswith("0613") and "functions" in args:
del args["functions"]
messages = compile_chat_messages(
- args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+ args["model"], messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
completion = ""
async for chunk in await openai.ChatCompletion.acreate(
@@ -100,7 +136,7 @@ class OpenAI(LLM):
if args["model"] in CHAT_MODELS:
messages = compile_chat_messages(
- args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+ args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
resp = (await openai.ChatCompletion.acreate(
messages=messages,
@@ -109,7 +145,7 @@ class OpenAI(LLM):
self.write_log(f"Completion: \n\n{resp}")
else:
prompt = prune_raw_prompt_from_top(
- args["model"], prompt, args["max_tokens"])
+ args["model"], self.context_length, prompt, args["max_tokens"])
self.write_log(f"Prompt:\n\n{prompt}")
resp = (await openai.Completion.acreate(
prompt=prompt,
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index f9e3fa01..1a48f213 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -1,7 +1,6 @@
-
import json
import traceback
-from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union, Optional
import aiohttp
from ...core.main import ChatMessage
from ..llm import LLM
@@ -16,26 +15,51 @@ ssl_context = ssl.create_default_context(cafile=ca_bundle_path)
# SERVER_URL = "http://127.0.0.1:8080"
SERVER_URL = "https://proxy-server-l6vsfbzhba-uw.a.run.app"
+MAX_TOKENS_FOR_MODEL = {
+ "gpt-3.5-turbo": 4096,
+ "gpt-3.5-turbo-0613": 4096,
+ "gpt-3.5-turbo-16k": 16384,
+ "gpt-4": 8192,
+}
+
class ProxyServer(LLM):
- unique_id: str
- name: str
- default_model: Literal["gpt-3.5-turbo", "gpt-4"]
- write_log: Callable[[str], None]
+ model: str
+ system_message: Optional[str]
- def __init__(self, unique_id: str, default_model: Literal["gpt-3.5-turbo", "gpt-4"], system_message: str = None, write_log: Callable[[str], None] = None):
- self.unique_id = unique_id
- self.default_model = default_model
- self.system_message = system_message
- self.name = default_model
+ unique_id: str = None
+ write_log: Callable[[str], None] = None
+ _client_session: aiohttp.ClientSession
+
+ requires_unique_id = True
+ requires_write_log = True
+
+ class Config:
+ arbitrary_types_allowed = True
+
+ async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], unique_id: str, **kwargs):
+ self._client_session = aiohttp.ClientSession(
+ connector=aiohttp.TCPConnector(ssl_context=ssl_context))
self.write_log = write_log
+ self.unique_id = unique_id
+
+ async def stop(self):
+ await self._client_session.close()
+
+ @property
+ def name(self):
+ return self.model
+
+ @property
+ def context_length(self):
+ return MAX_TOKENS_FOR_MODEL[self.model]
@property
def default_args(self):
- return {**DEFAULT_ARGS, "model": self.default_model}
+ return {**DEFAULT_ARGS, "model": self.model}
def count_tokens(self, text: str):
- return count_tokens(self.default_model, text)
+ return count_tokens(self.model, text)
def get_headers(self):
# headers with unique id
@@ -45,75 +69,72 @@ class ProxyServer(LLM):
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+ args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
- async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
- async with session.post(f"{SERVER_URL}/complete", json={
- "messages": messages,
- **args
- }, headers=self.get_headers()) as resp:
- if resp.status != 200:
- raise Exception(await resp.text())
-
- response_text = await resp.text()
- self.write_log(f"Completion: \n\n{response_text}")
- return response_text
+ async with self._client_session.post(f"{SERVER_URL}/complete", json={
+ "messages": messages,
+ **args
+ }, headers=self.get_headers()) as resp:
+ if resp.status != 200:
+ raise Exception(await resp.text())
+
+ response_text = await resp.text()
+ self.write_log(f"Completion: \n\n{response_text}")
+ return response_text
async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+ args["model"], messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
- async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
- async with session.post(f"{SERVER_URL}/stream_chat", json={
- "messages": messages,
- **args
- }, headers=self.get_headers()) as resp:
- # This is streaming application/json instaed of text/event-stream
- completion = ""
- if resp.status != 200:
- raise Exception(await resp.text())
- async for line in resp.content.iter_chunks():
- if line[1]:
- try:
- json_chunk = line[0].decode("utf-8")
- json_chunk = "{}" if json_chunk == "" else json_chunk
- chunks = json_chunk.split("\n")
- for chunk in chunks:
- if chunk.strip() != "":
- loaded_chunk = json.loads(chunk)
- yield loaded_chunk
- if "content" in loaded_chunk:
- completion += loaded_chunk["content"]
- except Exception as e:
- posthog_logger.capture_event(self.unique_id, "proxy_server_parse_error", {
- "error_title": "Proxy server stream_chat parsing failed", "error_message": '\n'.join(traceback.format_exception(e))})
- else:
- break
-
- self.write_log(f"Completion: \n\n{completion}")
+ async with self._client_session.post(f"{SERVER_URL}/stream_chat", json={
+ "messages": messages,
+ **args
+ }, headers=self.get_headers()) as resp:
+ # This is streaming application/json instaed of text/event-stream
+ completion = ""
+ if resp.status != 200:
+ raise Exception(await resp.text())
+ async for line in resp.content.iter_chunks():
+ if line[1]:
+ try:
+ json_chunk = line[0].decode("utf-8")
+ json_chunk = "{}" if json_chunk == "" else json_chunk
+ chunks = json_chunk.split("\n")
+ for chunk in chunks:
+ if chunk.strip() != "":
+ loaded_chunk = json.loads(chunk)
+ yield loaded_chunk
+ if "content" in loaded_chunk:
+ completion += loaded_chunk["content"]
+ except Exception as e:
+ posthog_logger.capture_event(self.unique_id, "proxy_server_parse_error", {
+ "error_title": "Proxy server stream_chat parsing failed", "error_message": '\n'.join(traceback.format_exception(e))})
+ else:
+ break
+
+ self.write_log(f"Completion: \n\n{completion}")
async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
+ self.model, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
- async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
- async with session.post(f"{SERVER_URL}/stream_complete", json={
- "messages": messages,
- **args
- }, headers=self.get_headers()) as resp:
- completion = ""
- if resp.status != 200:
- raise Exception(await resp.text())
- async for line in resp.content.iter_any():
- if line:
- try:
- decoded_line = line.decode("utf-8")
- yield decoded_line
- completion += decoded_line
- except:
- raise Exception(str(line))
- self.write_log(f"Completion: \n\n{completion}")
+ async with self._client_session.post(f"{SERVER_URL}/stream_complete", json={
+ "messages": messages,
+ **args
+ }, headers=self.get_headers()) as resp:
+ completion = ""
+ if resp.status != 200:
+ raise Exception(await resp.text())
+ async for line in resp.content.iter_any():
+ if line:
+ try:
+ decoded_line = line.decode("utf-8")
+ yield decoded_line
+ completion += decoded_line
+ except:
+ raise Exception(str(line))
+ self.write_log(f"Completion: \n\n{completion}")
diff --git a/continuedev/src/continuedev/libs/llm/utils.py b/continuedev/src/continuedev/libs/llm/utils.py
deleted file mode 100644
index 76240d4e..00000000
--- a/continuedev/src/continuedev/libs/llm/utils.py
+++ /dev/null
@@ -1,34 +0,0 @@
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from transformers import GPT2TokenizerFast
-
-gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
-def count_tokens(text: str) -> int:
- return len(gpt2_tokenizer.encode(text))
-
-prices = {
- # All prices are per 1k tokens
- "fine-tune-train": {
- "davinci": 0.03,
- "curie": 0.03,
- "babbage": 0.0006,
- "ada": 0.0004,
- },
- "completion": {
- "davinci": 0.02,
- "curie": 0.002,
- "babbage": 0.0005,
- "ada": 0.0004,
- },
- "fine-tune-completion": {
- "davinci": 0.12,
- "curie": 0.012,
- "babbage": 0.0024,
- "ada": 0.0016,
- },
- "embedding": {
- "ada": 0.0004
- }
-}
-
-def get_price(text: str, model: str="davinci", task: str="completion") -> float:
- return count_tokens(text) * prices[task][model] / 1000 \ No newline at end of file
diff --git a/continuedev/src/continuedev/libs/util/calculate_diff.py b/continuedev/src/continuedev/libs/util/calculate_diff.py
index ff0a135f..3e82bab3 100644
--- a/continuedev/src/continuedev/libs/util/calculate_diff.py
+++ b/continuedev/src/continuedev/libs/util/calculate_diff.py
@@ -92,7 +92,7 @@ def calculate_diff2(filepath: str, original: str, updated: str) -> List[FileEdit
tag, i1, i2, j1, j2 = s.get_opcodes()[edit_index]
replacement = updated[j1:j2]
if tag == "equal":
- continue
+ continue # ;)
elif tag == "delete":
edits.append(FileEdit.from_deletion(
filepath, Range.from_indices(original, i1, i2)))
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index c58ae499..6add7b1a 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -2,43 +2,47 @@ import json
from typing import Dict, List, Union
from ...core.main import ChatMessage
from .templating import render_templated_string
+from ...libs.llm import LLM
import tiktoken
+# TODO move many of these into specific LLM.properties() function that
+# contains max tokens, if its a chat model or not, default args (not all models
+# want to be run at 0.5 temp). also lets custom models made for long contexts
+# exist here (likg LLongMA)
aliases = {
"ggml": "gpt-3.5-turbo",
"claude-2": "gpt-3.5-turbo",
}
DEFAULT_MAX_TOKENS = 2048
-MAX_TOKENS_FOR_MODEL = {
- "gpt-3.5-turbo": 4096,
- "gpt-3.5-turbo-0613": 4096,
- "gpt-3.5-turbo-16k": 16384,
- "gpt-4": 8192,
- "ggml": 2048,
- "claude-2": 100000
-}
-CHAT_MODELS = {
- "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"
-}
DEFAULT_ARGS = {"max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.5, "top_p": 1,
"frequency_penalty": 0, "presence_penalty": 0}
-def encoding_for_model(model: str):
- return tiktoken.encoding_for_model(aliases.get(model, model))
+def encoding_for_model(model_name: str):
+ try:
+ return tiktoken.encoding_for_model(aliases.get(model_name, model_name))
+ except:
+ return tiktoken.encoding_for_model("gpt-3.5-turbo")
-def count_tokens(model: str, text: Union[str, None]):
+def count_tokens(model_name: str, text: Union[str, None]):
if text is None:
return 0
- encoding = encoding_for_model(model)
+ encoding = encoding_for_model(model_name)
return len(encoding.encode(text, disallowed_special=()))
-def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: int):
- max_tokens = MAX_TOKENS_FOR_MODEL.get(
- model, DEFAULT_MAX_TOKENS) - tokens_for_completion
- encoding = encoding_for_model(model)
+def count_chat_message_tokens(model_name: str, chat_message: ChatMessage) -> int:
+ # Doing simpler, safer version of what is here:
+ # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+ # every message follows <|start|>{role/name}\n{content}<|end|>\n
+ TOKENS_PER_MESSAGE = 4
+ return count_tokens(model_name, chat_message.content) + TOKENS_PER_MESSAGE
+
+
+def prune_raw_prompt_from_top(model_name: str, context_length: int, prompt: str, tokens_for_completion: int):
+ max_tokens = context_length - tokens_for_completion
+ encoding = encoding_for_model(model_name)
tokens = encoding.encode(prompt, disallowed_special=())
if len(tokens) <= max_tokens:
return prompt
@@ -46,53 +50,45 @@ def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: in
return encoding.decode(tokens[-max_tokens:])
-def count_chat_message_tokens(model: str, chat_message: ChatMessage) -> int:
- # Doing simpler, safer version of what is here:
- # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
- # every message follows <|start|>{role/name}\n{content}<|end|>\n
- TOKENS_PER_MESSAGE = 4
- return count_tokens(model, chat_message.content) + TOKENS_PER_MESSAGE
-
-
-def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int):
+def prune_chat_history(model_name: str, chat_history: List[ChatMessage], context_length: int, tokens_for_completion: int):
total_tokens = tokens_for_completion + \
- sum(count_chat_message_tokens(model, message)
+ sum(count_chat_message_tokens(model_name, message)
for message in chat_history)
# 1. Replace beyond last 5 messages with summary
i = 0
- while total_tokens > max_tokens and i < len(chat_history) - 5:
+ while total_tokens > context_length and i < len(chat_history) - 5:
message = chat_history[0]
- total_tokens -= count_tokens(model, message.content)
- total_tokens += count_tokens(model, message.summary)
+ total_tokens -= count_tokens(model_name, message.content)
+ total_tokens += count_tokens(model_name, message.summary)
message.content = message.summary
i += 1
# 2. Remove entire messages until the last 5
- while len(chat_history) > 5 and total_tokens > max_tokens and len(chat_history) > 0:
+ while len(chat_history) > 5 and total_tokens > context_length and len(chat_history) > 0:
message = chat_history.pop(0)
- total_tokens -= count_tokens(model, message.content)
+ total_tokens -= count_tokens(model_name, message.content)
# 3. Truncate message in the last 5, except last 1
i = 0
- while total_tokens > max_tokens and len(chat_history) > 0 and i < len(chat_history) - 1:
+ while total_tokens > context_length and len(chat_history) > 0 and i < len(chat_history) - 1:
message = chat_history[i]
- total_tokens -= count_tokens(model, message.content)
- total_tokens += count_tokens(model, message.summary)
+ total_tokens -= count_tokens(model_name, message.content)
+ total_tokens += count_tokens(model_name, message.summary)
message.content = message.summary
i += 1
# 4. Remove entire messages in the last 5, except last 1
- while total_tokens > max_tokens and len(chat_history) > 1:
+ while total_tokens > context_length and len(chat_history) > 1:
message = chat_history.pop(0)
- total_tokens -= count_tokens(model, message.content)
+ total_tokens -= count_tokens(model_name, message.content)
# 5. Truncate last message
- if total_tokens > max_tokens and len(chat_history) > 0:
+ if total_tokens > context_length and len(chat_history) > 0:
message = chat_history[0]
message.content = prune_raw_prompt_from_top(
- model, message.content, tokens_for_completion)
- total_tokens = max_tokens
+ model_name, context_length, message.content, tokens_for_completion)
+ total_tokens = context_length
return chat_history
@@ -101,7 +97,7 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens:
TOKEN_BUFFER_FOR_SAFETY = 100
-def compile_chat_messages(model: str, msgs: Union[List[ChatMessage], None], max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:
+def compile_chat_messages(model_name: str, msgs: Union[List[ChatMessage], None], context_length: int, max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:
"""
The total number of tokens is system_message + sum(msgs) + functions + prompt after it is converted to a message
"""
@@ -125,10 +121,10 @@ def compile_chat_messages(model: str, msgs: Union[List[ChatMessage], None], max_
function_tokens = 0
if functions is not None:
for function in functions:
- function_tokens += count_tokens(model, json.dumps(function))
+ function_tokens += count_tokens(model_name, json.dumps(function))
msgs_copy = prune_chat_history(
- model, msgs_copy, MAX_TOKENS_FOR_MODEL[model], function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY)
+ model_name, msgs_copy, context_length, function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY)
history = [msg.to_dict(with_functions=functions is not None)
for msg in msgs_copy]
diff --git a/continuedev/src/continuedev/libs/util/strings.py b/continuedev/src/continuedev/libs/util/strings.py
index f1fb8d0b..285c1e47 100644
--- a/continuedev/src/continuedev/libs/util/strings.py
+++ b/continuedev/src/continuedev/libs/util/strings.py
@@ -12,7 +12,7 @@ def dedent_and_get_common_whitespace(s: str) -> Tuple[str, str]:
for i in range(1, len(lines)):
# Empty lines are wildcards
if lines[i].strip() == "":
- continue
+ continue # hey that's us!
# Iterate through the leading whitespace characters of the current line
for j in range(0, len(lcp)):
# If it doesn't have the same whitespace as lcp, then update lcp
diff --git a/continuedev/src/continuedev/models/generate_json_schema.py b/continuedev/src/continuedev/models/generate_json_schema.py
index 51869fdd..2166bc37 100644
--- a/continuedev/src/continuedev/models/generate_json_schema.py
+++ b/continuedev/src/continuedev/models/generate_json_schema.py
@@ -39,7 +39,7 @@ def main():
json = schema_json_of(model, indent=2, title=title)
except Exception as e:
print(f"Failed to generate json schema for {title}: {e}")
- continue
+ continue # pun intended
with open(f"{SCHEMA_DIR}/{title}.json", "w") as f:
f.write(json)
diff --git a/continuedev/src/continuedev/plugins/context_providers/embeddings.py b/continuedev/src/continuedev/plugins/context_providers/embeddings.py
new file mode 100644
index 00000000..42d1f754
--- /dev/null
+++ b/continuedev/src/continuedev/plugins/context_providers/embeddings.py
@@ -0,0 +1,79 @@
+import os
+from typing import List, Optional
+import uuid
+from pydantic import BaseModel
+
+from ...core.main import ContextItemId
+from ...core.context import ContextProvider
+from ...core.main import ContextItem, ContextItemDescription, ContextItemId
+from ...libs.chroma.query import ChromaIndexManager
+from .util import remove_meilisearch_disallowed_chars
+
+
+class EmbeddingResult(BaseModel):
+ filename: str
+ content: str
+
+
+class EmbeddingsProvider(ContextProvider):
+ title = "embed"
+
+ workspace_directory: str
+
+ EMBEDDINGS_CONTEXT_ITEM_ID = "embeddings"
+
+ index_manager: Optional[ChromaIndexManager] = None
+
+ class Config:
+ arbitrary_types_allowed = True
+
+ @property
+ def index(self):
+ if self.index_manager is None:
+ self.index_manager = ChromaIndexManager(self.workspace_directory)
+ return self.index_manager
+
+ @property
+ def BASE_CONTEXT_ITEM(self):
+ return ContextItem(
+ content="",
+ description=ContextItemDescription(
+ name="Embedding Search",
+ description="Enter a query to embedding search codebase",
+ id=ContextItemId(
+ provider_title=self.title,
+ item_id=self.EMBEDDINGS_CONTEXT_ITEM_ID
+ )
+ )
+ )
+
+ async def _get_query_results(self, query: str) -> str:
+ results = self.index.query_codebase_index(query)
+
+ ret = []
+ for node in results.source_nodes:
+ resource_name = list(node.node.relationships.values())[0]
+ filepath = resource_name[:resource_name.index("::")]
+ ret.append(EmbeddingResult(
+ filename=filepath, content=node.node.text))
+
+ return ret
+
+ async def provide_context_items(self) -> List[ContextItem]:
+ self.index.create_codebase_index() # TODO Synchronous here is not ideal
+
+ return [self.BASE_CONTEXT_ITEM]
+
+ async def add_context_item(self, id: ContextItemId, query: str):
+ if not id.item_id == self.EMBEDDINGS_CONTEXT_ITEM_ID:
+ raise Exception("Invalid item id")
+
+ results = await self._get_query_results(query)
+
+ for i in range(len(results)):
+ result = results[i]
+ ctx_item = self.BASE_CONTEXT_ITEM.copy()
+ ctx_item.description.name = os.path.basename(result.filename)
+ ctx_item.content = f"{result.filename}\n```\n{result.content}\n```"
+ ctx_item.description.id.item_id = uuid.uuid4().hex
+ self.selected_items.append(ctx_item)
diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py
index 634774df..31aa5423 100644
--- a/continuedev/src/continuedev/plugins/context_providers/file.py
+++ b/continuedev/src/continuedev/plugins/context_providers/file.py
@@ -3,6 +3,7 @@ import re
from typing import List
from ...core.main import ContextItem, ContextItemDescription, ContextItemId
from ...core.context import ContextProvider
+from .util import remove_meilisearch_disallowed_chars
from fnmatch import fnmatch
@@ -79,7 +80,7 @@ class FileContextProvider(ContextProvider):
description=file,
id=ContextItemId(
provider_title=self.title,
- item_id=re.sub(r'[^0-9a-zA-Z_-]', '', file)
+ item_id=remove_meilisearch_disallowed_chars(file)
)
)
))
diff --git a/continuedev/src/continuedev/plugins/context_providers/google.py b/continuedev/src/continuedev/plugins/context_providers/google.py
index fc76fe67..4b0a59ec 100644
--- a/continuedev/src/continuedev/plugins/context_providers/google.py
+++ b/continuedev/src/continuedev/plugins/context_providers/google.py
@@ -2,6 +2,7 @@ import json
from typing import List
import aiohttp
+from .util import remove_meilisearch_disallowed_chars
from ...core.main import ContextItem, ContextItemDescription, ContextItemId
from ...core.context import ContextProvider
@@ -60,5 +61,6 @@ class GoogleContextProvider(ContextProvider):
ctx_item = self.BASE_CONTEXT_ITEM.copy()
ctx_item.content = content
- ctx_item.description.id.item_id = query
+ ctx_item.description.id.item_id = remove_meilisearch_disallowed_chars(
+ query)
return ctx_item
diff --git a/continuedev/src/continuedev/plugins/context_providers/util.py b/continuedev/src/continuedev/plugins/context_providers/util.py
new file mode 100644
index 00000000..da2e6b17
--- /dev/null
+++ b/continuedev/src/continuedev/plugins/context_providers/util.py
@@ -0,0 +1,5 @@
+import re
+
+
+def remove_meilisearch_disallowed_chars(id: str) -> str:
+ return re.sub(r'[^0-9a-zA-Z_-]', '', id)
diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/plugins/policies/default.py
index d90177b5..523c2cf4 100644
--- a/continuedev/src/continuedev/core/policy.py
+++ b/continuedev/src/continuedev/plugins/policies/default.py
@@ -1,15 +1,15 @@
from textwrap import dedent
from typing import Union
-from ..plugins.steps.chat import SimpleChatStep
-from ..plugins.steps.welcome import WelcomeStep
-from .config import ContinueConfig
-from ..plugins.steps.steps_on_startup import StepsOnStartupStep
-from .main import Step, History, Policy
-from .observation import UserInputObservation
-from ..plugins.steps.core.core import MessageStep
-from ..plugins.steps.custom_command import CustomCommandStep
-from ..plugins.steps.main import EditHighlightedCodeStep
+from ..steps.chat import SimpleChatStep
+from ..steps.welcome import WelcomeStep
+from ...core.config import ContinueConfig
+from ..steps.steps_on_startup import StepsOnStartupStep
+from ...core.main import Step, History, Policy
+from ...core.observation import UserInputObservation
+from ..steps.core.core import MessageStep
+from ..steps.custom_command import CustomCommandStep
+from ..steps.main import EditHighlightedCodeStep
def parse_slash_command(inp: str, config: ContinueConfig) -> Union[None, Step]:
@@ -45,7 +45,8 @@ def parse_custom_command(inp: str, config: ContinueConfig) -> Union[None, Step]:
class DefaultPolicy(Policy):
- ran_code_last: bool = False
+
+ default_step: Step = SimpleChatStep()
def next(self, config: ContinueConfig, history: History) -> Step:
# At the very start, run initial Steps spcecified in the config
@@ -56,7 +57,6 @@ class DefaultPolicy(Policy):
- Use `cmd+m` (Mac) / `ctrl+m` (Windows) to open Continue
- Use `/help` to ask questions about how to use Continue""")) >>
WelcomeStep() >>
- # CreateCodebaseIndexChroma() >>
StepsOnStartupStep())
observation = history.get_current().observation
@@ -75,6 +75,6 @@ class DefaultPolicy(Policy):
if user_input.startswith("/edit"):
return EditHighlightedCodeStep(user_input=user_input[5:])
- return SimpleChatStep()
+ return self.default_step.copy()
return None
diff --git a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py
index 433e309e..872f8d62 100644
--- a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py
+++ b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py
@@ -27,7 +27,7 @@ class SetupPipelineStep(Step):
async def run(self, sdk: ContinueSDK):
sdk.context.set("api_description", self.api_description)
- source_name = (await sdk.models.gpt35.complete(
+ source_name = (await sdk.models.medium.complete(
f"Write a snake_case name for the data source described by {self.api_description}: ")).strip()
filename = f'{source_name}.py'
@@ -89,7 +89,7 @@ class ValidatePipelineStep(Step):
if "Traceback" in output or "SyntaxError" in output:
output = "Traceback" + output.split("Traceback")[-1]
file_content = await sdk.ide.readFile(os.path.join(workspace_dir, filename))
- suggestion = await sdk.models.gpt35.complete(dedent(f"""\
+ suggestion = await sdk.models.medium.complete(dedent(f"""\
```python
{file_content}
```
@@ -101,7 +101,7 @@ class ValidatePipelineStep(Step):
This is a brief summary of the error followed by a suggestion on how it can be fixed by editing the resource function:"""))
- api_documentation_url = await sdk.models.gpt35.complete(dedent(f"""\
+ api_documentation_url = await sdk.models.medium.complete(dedent(f"""\
The API I am trying to call is the '{sdk.context.get('api_description')}'. I tried calling it in the @resource function like this:
```python
{file_content}
@@ -151,7 +151,7 @@ class RunQueryStep(Step):
output = await sdk.run('.env/bin/python3 query.py', name="Run test query", description="Running `.env/bin/python3 query.py` to test that the data was loaded into DuckDB as expected", handle_error=False)
if "Traceback" in output or "SyntaxError" in output:
- suggestion = await sdk.models.gpt35.complete(dedent(f"""\
+ suggestion = await sdk.models.medium.complete(dedent(f"""\
```python
{await sdk.ide.readFile(os.path.join(sdk.ide.workspace_directory, "query.py"))}
```
diff --git a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py
index 6ef5ffd6..c66cd629 100644
--- a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py
+++ b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py
@@ -42,7 +42,7 @@ class WritePytestsRecipe(Step):
"{self.user_input}"
Here is a complete set of pytest unit tests:""")
- tests = await sdk.models.gpt35.complete(prompt)
+ tests = await sdk.models.medium.complete(prompt)
await sdk.apply_filesystem_edit(AddFile(filepath=path, content=tests))
diff --git a/continuedev/src/continuedev/plugins/steps/README.md b/continuedev/src/continuedev/plugins/steps/README.md
index 12073835..3f2f804c 100644
--- a/continuedev/src/continuedev/plugins/steps/README.md
+++ b/continuedev/src/continuedev/plugins/steps/README.md
@@ -33,7 +33,7 @@ If you'd like to override the default description of your step, which is just th
- Return a static string
- Store state in a class attribute (prepend with a double underscore, which signifies (through Pydantic) that this is not a parameter for the Step, just internal state) during the run method, and then grab this in the describe method.
-- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.gpt35.complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`.
+- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.medium.complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`.
Here's an example:
diff --git a/continuedev/src/continuedev/plugins/steps/chat.py b/continuedev/src/continuedev/plugins/steps/chat.py
index f72a8ec0..455d5a13 100644
--- a/continuedev/src/continuedev/plugins/steps/chat.py
+++ b/continuedev/src/continuedev/plugins/steps/chat.py
@@ -9,6 +9,7 @@ from .core.core import DisplayErrorStep, MessageStep
from ...core.main import FunctionCall, Models
from ...core.main import ChatMessage, Step, step_to_json_schema
from ...core.sdk import ContinueSDK
+from ...libs.llm.openai import OpenAI
import openai
import os
from dotenv import load_dotenv
@@ -41,7 +42,7 @@ class SimpleChatStep(Step):
self.description += chunk["content"]
await sdk.update_ui()
- self.name = remove_quotes_and_escapes(await sdk.models.gpt35.complete(
+ self.name = remove_quotes_and_escapes(await sdk.models.medium.complete(
f"Write a short title for the following chat message: {self.description}"))
self.chat_context.append(ChatMessage(
@@ -166,7 +167,10 @@ class ChatWithFunctions(Step):
msg_content = ""
msg_step = None
- async for msg_chunk in sdk.models.gpt350613.stream_chat(await sdk.get_chat_context(), functions=functions):
+ gpt350613 = OpenAI(model="gpt-3.5-turbo-0613")
+ await sdk.start_model(gpt350613)
+
+ async for msg_chunk in gpt350613.stream_chat(await sdk.get_chat_context(), functions=functions):
if sdk.current_step_was_deleted():
return
diff --git a/continuedev/src/continuedev/plugins/steps/chroma.py b/continuedev/src/continuedev/plugins/steps/chroma.py
index dbe8363e..658cc7f3 100644
--- a/continuedev/src/continuedev/plugins/steps/chroma.py
+++ b/continuedev/src/continuedev/plugins/steps/chroma.py
@@ -56,7 +56,7 @@ class AnswerQuestionChroma(Step):
Here is the answer:""")
- answer = await sdk.models.gpt35.complete(prompt)
+ answer = await sdk.models.medium.complete(prompt)
# Make paths relative to the workspace directory
answer = answer.replace(await sdk.ide.getWorkspaceDirectory(), "")
diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py
index c80cecc3..4476c7ae 100644
--- a/continuedev/src/continuedev/plugins/steps/core/core.py
+++ b/continuedev/src/continuedev/plugins/steps/core/core.py
@@ -11,11 +11,12 @@ from pydantic import validator
from ....libs.llm.ggml import GGML
from ....models.main import Range
+from ....libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
from ....models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullContents, FileSystemEdit
from ....models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents
-from ....core.observation import Observation, TextObservation, UserInputObservation
-from ....core.main import ChatMessage, ContinueCustomException, Step
-from ....libs.util.count_tokens import MAX_TOKENS_FOR_MODEL, DEFAULT_MAX_TOKENS
+from ....core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation
+from ....core.main import ChatMessage, ContinueCustomException, Step, SequentialStep
+from ....libs.util.count_tokens import DEFAULT_MAX_TOKENS
from ....libs.util.strings import dedent_and_get_common_whitespace, remove_quotes_and_escapes
@@ -97,7 +98,7 @@ class ShellCommandsStep(Step):
return f"Error when running shell commands:\n```\n{self._err_text}\n```"
cmds_str = "\n".join(self.cmds)
- return await models.gpt35.complete(f"{cmds_str}\n\nSummarize what was done in these shell commands, using markdown bullet points:")
+ return await models.medium.complete(f"{cmds_str}\n\nSummarize what was done in these shell commands, using markdown bullet points:")
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
cwd = await sdk.ide.getWorkspaceDirectory() if self.cwd is None else self.cwd
@@ -105,7 +106,7 @@ class ShellCommandsStep(Step):
for cmd in self.cmds:
output = await sdk.ide.runCommand(cmd)
if self.handle_error and output is not None and output_contains_error(output):
- suggestion = await sdk.models.gpt35.complete(dedent(f"""\
+ suggestion = await sdk.models.medium.complete(dedent(f"""\
While running the command `{cmd}`, the following error occurred:
```ascii
@@ -185,7 +186,7 @@ class DefaultModelEditCodeStep(Step):
else:
changes = '\n'.join(difflib.ndiff(
self._previous_contents.splitlines(), self._new_contents.splitlines()))
- description = await models.gpt3516k.complete(dedent(f"""\
+ description = await models.medium.complete(dedent(f"""\
Diff summary: "{self.user_input}"
```diff
@@ -193,7 +194,7 @@ class DefaultModelEditCodeStep(Step):
```
Please give brief a description of the changes made above using markdown bullet points. Be concise:"""))
- name = await models.gpt3516k.complete(f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. This is the title:")
+ name = await models.medium.complete(f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. This is the title:")
self.name = remove_quotes_and_escapes(name)
return f"{remove_quotes_and_escapes(description)}"
@@ -203,8 +204,7 @@ class DefaultModelEditCodeStep(Step):
# We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion.
# Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need.
model_to_use = sdk.models.default
- max_tokens = int(MAX_TOKENS_FOR_MODEL.get(
- model_to_use.name, DEFAULT_MAX_TOKENS) / 2)
+ max_tokens = int(model_to_use.context_length / 2)
TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200
if model_to_use.count_tokens(rif.contents) > TOKENS_TO_BE_CONSIDERED_LARGE_RANGE:
@@ -222,8 +222,9 @@ class DefaultModelEditCodeStep(Step):
# If using 3.5 and overflows, upgrade to 3.5.16k
if model_to_use.name == "gpt-3.5-turbo":
- if total_tokens > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]:
- model_to_use = sdk.models.gpt3516k
+ if total_tokens > model_to_use.context_length:
+ model_to_use = MaybeProxyOpenAI(model="gpt-3.5-turbo-0613")
+ await sdk.start_model(model_to_use)
# Remove tokens from the end first, and then the start to clear space
# This part finds the start and end lines
@@ -233,20 +234,20 @@ class DefaultModelEditCodeStep(Step):
cur_start_line = 0
cur_end_line = len(full_file_contents_lst) - 1
- if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+ if total_tokens > model_to_use.context_length:
while cur_end_line > min_end_line:
total_tokens -= model_to_use.count_tokens(
full_file_contents_lst[cur_end_line])
cur_end_line -= 1
- if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+ if total_tokens < model_to_use.context_length:
break
- if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+ if total_tokens > model_to_use.context_length:
while cur_start_line < max_start_line:
cur_start_line += 1
total_tokens -= model_to_use.count_tokens(
full_file_contents_lst[cur_start_line])
- if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+ if total_tokens < model_to_use.context_length:
break
# Now use the found start/end lines to get the prefix and suffix strings
@@ -525,7 +526,7 @@ Please output the code to be inserted at the cursor in order to fulfill the user
# Accumulate lines
if "content" not in chunk:
- continue
+ continue # ayo
chunk = chunk["content"]
chunk_lines = chunk.split("\n")
chunk_lines[0] = unfinished_line + chunk_lines[0]
@@ -546,12 +547,12 @@ Please output the code to be inserted at the cursor in order to fulfill the user
break
# Lines that should be ignored, like the <> tags
elif self.line_to_be_ignored(chunk_lines[i], completion_lines_covered == 0):
- continue
+ continue # noice
# Check if we are currently just copying the prefix
elif (lines_of_prefix_copied > 0 or completion_lines_covered == 0) and lines_of_prefix_copied < len(file_prefix.splitlines()) and chunk_lines[i] == full_file_contents_lines[lines_of_prefix_copied]:
# This is a sketchy way of stopping it from repeating the file_prefix. Is a bug if output happens to have a matching line
lines_of_prefix_copied += 1
- continue
+ continue # also nice
# Because really short lines might be expected to be repeated, this is only a !heuristic!
# Stop when it starts copying the file_suffix
elif chunk_lines[i].strip() == line_below_highlighted_range.strip() and len(chunk_lines[i].strip()) > 4 and not (len(original_lines_below_previous_blocks) > 0 and chunk_lines[i].strip() == original_lines_below_previous_blocks[0].strip()):
diff --git a/continuedev/src/continuedev/plugins/steps/draft/migration.py b/continuedev/src/continuedev/plugins/steps/draft/migration.py
index a76d491b..c38f54dc 100644
--- a/continuedev/src/continuedev/plugins/steps/draft/migration.py
+++ b/continuedev/src/continuedev/plugins/steps/draft/migration.py
@@ -13,7 +13,7 @@ class MigrationStep(Step):
recent_edits = await sdk.ide.get_recent_edits(self.edited_file)
recent_edits_string = "\n\n".join(
map(lambda x: x.to_string(), recent_edits))
- description = await sdk.models.gpt35.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n")
+ description = await sdk.models.medium.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n")
await sdk.run([
"cd libs",
"poetry run alembic revision --autogenerate -m " + description,
diff --git a/continuedev/src/continuedev/plugins/steps/help.py b/continuedev/src/continuedev/plugins/steps/help.py
index 6997a547..ec670999 100644
--- a/continuedev/src/continuedev/plugins/steps/help.py
+++ b/continuedev/src/continuedev/plugins/steps/help.py
@@ -56,7 +56,7 @@ class HelpStep(Step):
summary="Help"
))
messages = await sdk.get_chat_context()
- generator = sdk.models.gpt4.stream_chat(messages)
+ generator = sdk.models.default.stream_chat(messages)
async for chunk in generator:
if "content" in chunk:
self.description += chunk["content"]
diff --git a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py
index b54d394a..3d8d96fb 100644
--- a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py
+++ b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py
@@ -23,6 +23,6 @@ class NLMultiselectStep(Step):
if first_try is not None:
return first_try
- gpt_parsed = await sdk.models.gpt35.complete(
+ gpt_parsed = await sdk.models.default.complete(
f"These are the available options are: [{', '.join(self.options)}]. The user requested {user_response}. This is the exact string from the options array that they selected:")
return extract_option(gpt_parsed) or self.options[0]
diff --git a/continuedev/src/continuedev/plugins/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py
index f28d9660..d2d6f4dd 100644
--- a/continuedev/src/continuedev/plugins/steps/main.py
+++ b/continuedev/src/continuedev/plugins/steps/main.py
@@ -101,7 +101,7 @@ class FasterEditHighlightedCodeStep(Step):
for rif in range_in_files:
rif_dict[rif.filepath] = rif.contents
- completion = await sdk.models.gpt35.complete(prompt)
+ completion = await sdk.models.medium.complete(prompt)
# Temporarily doing this to generate description.
self._prompt = prompt
@@ -169,7 +169,7 @@ class StarCoderEditHighlightedCodeStep(Step):
_prompt_and_completion: str = ""
async def describe(self, models: Models) -> Coroutine[str, None, None]:
- return await models.gpt35.complete(f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:")
+ return await models.medium.complete(f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:")
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
range_in_files = await sdk.get_code_context(only_editing=True)
diff --git a/continuedev/src/continuedev/plugins/steps/react.py b/continuedev/src/continuedev/plugins/steps/react.py
index 8b2e7c2e..da6acdbf 100644
--- a/continuedev/src/continuedev/plugins/steps/react.py
+++ b/continuedev/src/continuedev/plugins/steps/react.py
@@ -27,7 +27,7 @@ class NLDecisionStep(Step):
Select the step which should be taken next to satisfy the user input. Say only the name of the selected step. You must choose one:""")
- resp = (await sdk.models.gpt35.complete(prompt)).lower()
+ resp = (await sdk.models.medium.complete(prompt)).lower()
step_to_run = None
for step in self.steps:
diff --git a/continuedev/src/continuedev/plugins/steps/search_directory.py b/continuedev/src/continuedev/plugins/steps/search_directory.py
index 07b50473..456dba84 100644
--- a/continuedev/src/continuedev/plugins/steps/search_directory.py
+++ b/continuedev/src/continuedev/plugins/steps/search_directory.py
@@ -20,10 +20,10 @@ def find_all_matches_in_dir(pattern: str, dirpath: str) -> List[RangeInFile]:
for root, dirs, files in os.walk(dirpath):
dirname = os.path.basename(root)
if dirname.startswith(".") or dirname in IGNORE_DIRS:
- continue
+ continue # continue!
for file in files:
if file in IGNORE_FILES:
- continue
+ continue # pun intended
with open(os.path.join(root, file), "r") as f:
# Find the index of all occurences of the pattern in the file. Use re.
file_content = f.read()
@@ -42,7 +42,7 @@ class WriteRegexPatternStep(Step):
async def run(self, sdk: ContinueSDK):
# Ask the user for a regex pattern
- pattern = await sdk.models.gpt35.complete(dedent(f"""\
+ pattern = await sdk.models.medium.complete(dedent(f"""\
This is the user request:
{self.user_request}
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 98a5aea0..cf18c56b 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -176,7 +176,7 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we
message = json.loads(message)
if "messageType" not in message or "data" not in message:
- continue
+ continue # :o
message_type = message["messageType"]
data = message["data"]
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index e4c07029..6124f3bd 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -139,7 +139,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
msg_string = await self.websocket.receive_text()
message = json.loads(msg_string)
if "messageType" not in message or "data" not in message:
- continue
+ continue # <-- hey that's the name of this repo!
message_type = message["messageType"]
data = message["data"]
logger.debug(f"Received message while initializing {message_type}")
@@ -311,7 +311,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
def onFileEdits(self, edits: List[FileEditWithFullContents]):
if autopilot := self.__get_autopilot():
- autopilot.handle_manual_edits(edits)
+ pass
def onDeleteAtIndex(self, index: int):
if autopilot := self.__get_autopilot():
diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py
index cf46028f..b5580fe8 100644
--- a/continuedev/src/continuedev/server/session_manager.py
+++ b/continuedev/src/continuedev/server/session_manager.py
@@ -1,4 +1,5 @@
import os
+import traceback
from fastapi import WebSocket
from typing import Any, Dict, List, Union
from uuid import uuid4
@@ -6,12 +7,10 @@ import json
from fastapi.websockets import WebSocketState
-from ..plugins.steps.core.core import DisplayErrorStep
+from ..plugins.steps.core.core import DisplayErrorStep, MessageStep
from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath
from ..models.filesystem_edit import FileEditWithFullContents
-from ..libs.constants.main import CONTINUE_SESSIONS_FOLDER
-from ..core.policy import DefaultPolicy
-from ..core.main import FullState
+from ..core.main import FullState, HistoryNode
from ..core.autopilot import Autopilot
from .ide_protocol import AbstractIdeProtocolServer
from ..libs.util.create_async_task import create_async_task
@@ -31,19 +30,6 @@ class Session:
self.ws = None
-class DemoAutopilot(Autopilot):
- first_seen: bool = False
- cumulative_edit_string = ""
-
- def handle_manual_edits(self, edits: List[FileEditWithFullContents]):
- return
- for edit in edits:
- self.cumulative_edit_string += edit.fileEdit.replacement
- self._manual_edits_buffer.append(edit)
- # Note that you're storing a lot of unecessary data here. Can compress into EditDiffs on the spot, and merge.
- # self._manual_edits_buffer = merge_file_edit(self._manual_edits_buffer, edit)
-
-
class SessionManager:
sessions: Dict[str, Session] = {}
# Mapping of session_id to IDE, where the IDE is still alive
@@ -65,27 +51,47 @@ class SessionManager:
async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session:
logger.debug(f"New session: {session_id}")
+ # Load the persisted state (not being used right now)
full_state = None
if session_id is not None and os.path.exists(getSessionFilePath(session_id)):
with open(getSessionFilePath(session_id), "r") as f:
full_state = FullState(**json.load(f))
- autopilot = await DemoAutopilot.create(
- policy=DefaultPolicy(), ide=ide, full_state=full_state)
+ # Register the session and ide (do this first so that the autopilot can access the session)
+ autopilot = Autopilot(ide=ide)
session_id = session_id or str(uuid4())
ide.session_id = session_id
session = Session(session_id=session_id, autopilot=autopilot)
self.sessions[session_id] = session
self.registered_ides[session_id] = ide
+ # Set up the autopilot to update the GUI
async def on_update(state: FullState):
await session_manager.send_ws_data(session_id, "state_update", {
"state": state.dict()
})
autopilot.on_update(on_update)
- create_async_task(autopilot.run_policy(
- ), lambda e: autopilot.continue_sdk.run_step(DisplayErrorStep(e=e)))
+
+ # Start the autopilot (must be after session is added to sessions) and the policy
+ try:
+ await autopilot.start()
+ except Exception as e:
+ # Have to manually add to history because autopilot isn't started
+ formatted_err = '\n'.join(traceback.format_exception(e))
+ msg_step = MessageStep(
+ name="Error loading context manager", message=formatted_err)
+ msg_step.description = f"```\n{formatted_err}\n```"
+ autopilot.history.add_node(HistoryNode(
+ step=msg_step,
+ observation=None,
+ depth=0,
+ active=False
+ ))
+ logger.warning(f"Error loading context manager: {e}")
+
+ create_async_task(autopilot.run_policy(), lambda e: autopilot.continue_sdk.run_step(
+ DisplayErrorStep(e=e)))
return session
async def remove_session(self, session_id: str):
diff --git a/docs/docs/customization.md b/docs/docs/customization.md
index fa4d110e..60764527 100644
--- a/docs/docs/customization.md
+++ b/docs/docs/customization.md
@@ -4,11 +4,25 @@ Continue can be deeply customized by editing the `ContinueConfig` object in `~/.
## Change the default LLM
-Change the `default_model` field to any of "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "claude-2", or "ggml".
+In `config.py`, you'll find the `models` property:
+
+```python
+config = ContinueConfig(
+ ...
+ models=Models(
+ default=MaybeProxyOpenAI(model="gpt-4"),
+ medium=MaybeProxyOpenAI(model="gpt-3.5-turbo")
+ )
+)
+```
+
+The `default` model is the one used for most operations, including responding to your messages and editing code. The `medium` model is used for summarization tasks that require less quality. There are also `small` and `large` roles that can be filled, but all will fall back to `default` if not set. The values of these fields must be of the [`LLM`](https://github.com/continuedev/continue/blob/main/continuedev/src/continuedev/libs/llm/__init__.py) class, which implements methods for retrieving and streaming completions from an LLM.
+
+Below, we describe the `LLM` classes available in the Continue core library, and how they can be used.
### Adding an OpenAI API key
-New users can try out Continue with GPT-4 using a proxy server that securely makes calls to OpenAI using our API key. Continue should just work the first time you install the extension in VS Code.
+With the `MaybeProxyOpenAI` `LLM`, new users can try out Continue with GPT-4 using a proxy server that securely makes calls to OpenAI using our API key. Continue should just work the first time you install the extension in VS Code.
Once you are using Continue regularly though, you will need to add an OpenAI API key that has access to GPT-4 by following these steps:
@@ -18,34 +32,55 @@ Once you are using Continue regularly though, you will need to add an OpenAI API
4. Click Edit in settings.json under Continue: OpenAI_API_KEY" section
5. Paste your API key as the value for "continue.OPENAI_API_KEY" in settings.json
-### claude-2 and gpt-X
+The `MaybeProxyOpenAI` class will automatically switch to using your API key instead of ours. If you'd like to explicitly use one or the other, you can use the `ProxyServer` or `OpenAI` classes instead.
+
+These classes support any models available through the OpenAI API, assuming your API key has access, including "gpt-4", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", and "gpt-4-32k".
+
+### claude-2
-If you have access, simply set `default_model` to the model you would like to use, then you will be prompted for a personal API key after reloading VS Code. If using an OpenAI model, you can press enter to try with our API key for free.
+Import the `Anthropic` LLM class and set it as the default model:
+
+```python
+from continuedev.libs.llm.anthropic import Anthropic
+
+config = ContinueConfig(
+ ...
+ models=Models(
+ default=Anthropic(model="claude-2")
+ )
+)
+```
+
+Continue will automatically prompt you for your Anthropic API key, which must have access to Claude 2. You can request early access [here](https://www.anthropic.com/earlyaccess).
### Local models with ggml
See our [5 minute quickstart](https://github.com/continuedev/ggml-server-example) to run any model locally with ggml. While these models don't yet perform as well, they are free, entirely private, and run offline.
-Once the model is running on localhost:8000, set `default_model` in `~/.continue/config.py` to "ggml".
+Once the model is running on localhost:8000, import the `GGML` LLM class from `continuedev.libs.llm.ggml` and set `default=GGML(max_context_length=2048)`.
### Self-hosting an open-source model
-If you want to self-host on Colab, RunPod, Replicate, HuggingFace, Haven, or another hosting provider you will need to wire up a new LLM class. It only needs to implement 3 methods: `stream_complete`, `complete`, and `stream_chat`, and you can see examples in `continuedev/src/continuedev/libs/llm`.
+If you want to self-host on Colab, RunPod, Replicate, HuggingFace, Haven, or another hosting provider you will need to wire up a new LLM class. It only needs to implement 3 primary methods: `stream_complete`, `complete`, and `stream_chat`, and you can see examples in `continuedev/src/continuedev/libs/llm`.
If by chance the provider has the exact same API interface as OpenAI, the `GGML` class will work for you out of the box, after changing the endpoint at the top of the file.
### Azure OpenAI Service
-If you'd like to use OpenAI models but are concerned about privacy, you can use the Azure OpenAI service, which is GDPR and HIPAA compliant. After applying for access [here](https://azure.microsoft.com/en-us/products/ai-services/openai-service), you will typically hear back within only a few days. Once you have access, set `default_model` to "gpt-4", and then set the `openai_server_info` property in the `ContinueConfig` like so:
+If you'd like to use OpenAI models but are concerned about privacy, you can use the Azure OpenAI service, which is GDPR and HIPAA compliant. After applying for access [here](https://azure.microsoft.com/en-us/products/ai-services/openai-service), you will typically hear back within only a few days. Once you have access, instantiate the model like so:
```python
+from continuedev.libs.llm.openai import OpenAI, OpenAIServerInfo
+
config = ContinueConfig(
...
- openai_server_info=OpenAIServerInfo(
- api_base="https://my-azure-openai-instance.openai.azure.com/",
- engine="my-azure-openai-deployment",
- api_version="2023-03-15-preview",
- api_type="azure"
+ models=Models(
+ default=OpenAI(model="gpt-3.5-turbo", server_info=OpenAIServerInfo(
+ api_base="https://my-azure-openai-instance.openai.azure.com/"
+ engine="my-azure-openai-deployment",
+ api_version="2023-03-15-preview",
+ api_type="azure"
+ ))
)
)
```
diff --git a/docs/docs/walkthroughs/create-a-recipe.md b/docs/docs/walkthroughs/create-a-recipe.md
index 5d80d083..2cb28f77 100644
--- a/docs/docs/walkthroughs/create-a-recipe.md
+++ b/docs/docs/walkthroughs/create-a-recipe.md
@@ -31,7 +31,7 @@ If you'd like to override the default description of your steps, which is just t
- Return a static string
- Store state in a class attribute (prepend with a double underscore, which signifies (through Pydantic) that this is not a parameter for the Step, just internal state) during the run method, and then grab this in the describe method.
-- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.gpt35.complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`.
+- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.medium.complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`.
## 2. Compose steps together into a complete recipe
diff --git a/extension/media/terminal-continue.png b/extension/media/terminal-continue.png
index ef310fa3..27667fe9 100644
--- a/extension/media/terminal-continue.png
+++ b/extension/media/terminal-continue.png
Binary files differ
diff --git a/extension/package-lock.json b/extension/package-lock.json
index 6f289260..4c0b7093 100644
--- a/extension/package-lock.json
+++ b/extension/package-lock.json
@@ -1,12 +1,12 @@
{
"name": "continue",
- "version": "0.0.227",
+ "version": "0.0.228",
"lockfileVersion": 2,
"requires": true,
"packages": {
"": {
"name": "continue",
- "version": "0.0.227",
+ "version": "0.0.228",
"license": "Apache-2.0",
"dependencies": {
"@electron/rebuild": "^3.2.10",
diff --git a/extension/package.json b/extension/package.json
index 5d0ccad2..df54bb4f 100644
--- a/extension/package.json
+++ b/extension/package.json
@@ -14,7 +14,7 @@
"displayName": "Continue",
"pricing": "Free",
"description": "The open-source coding autopilot",
- "version": "0.0.227",
+ "version": "0.0.228",
"publisher": "Continue",
"engines": {
"vscode": "^1.67.0"
diff --git a/extension/src/suggestions.ts b/extension/src/suggestions.ts
index 5c2b8860..b5be341d 100644
--- a/extension/src/suggestions.ts
+++ b/extension/src/suggestions.ts
@@ -72,7 +72,7 @@ export function rerenderDecorations(editorUri: string) {
range.end.character === 0
) {
// Empty range, don't show it
- continue;
+ continue; // is great
}
newRanges.push(
new vscode.Range(