Diffstat (limited to 'continuedev')
-rw-r--r--  continuedev/README.md | 12
-rw-r--r--  continuedev/src/continuedev/core/abstract_sdk.py | 2
-rw-r--r--  continuedev/src/continuedev/core/autopilot.py | 24
-rw-r--r--  continuedev/src/continuedev/core/config.py | 22
-rw-r--r--  continuedev/src/continuedev/core/context.py | 6
-rw-r--r--  continuedev/src/continuedev/core/main.py | 1
-rw-r--r--  continuedev/src/continuedev/core/models.py | 65
-rw-r--r--  continuedev/src/continuedev/core/sdk.py | 141
-rw-r--r--  continuedev/src/continuedev/libs/chroma/query.py | 2
-rw-r--r--  continuedev/src/continuedev/libs/chroma/update.py | 2
-rw-r--r--  continuedev/src/continuedev/libs/constants/default_config.py.txt | 20
-rw-r--r--  continuedev/src/continuedev/libs/llm/__init__.py | 31
-rw-r--r--  continuedev/src/continuedev/libs/llm/anthropic.py | 47
-rw-r--r--  continuedev/src/continuedev/libs/llm/ggml.py | 100
-rw-r--r--  continuedev/src/continuedev/libs/llm/hf_inference_api.py | 12
-rw-r--r--  continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py | 53
-rw-r--r--  continuedev/src/continuedev/libs/llm/openai.py | 94
-rw-r--r--  continuedev/src/continuedev/libs/llm/proxy_server.py | 165
-rw-r--r--  continuedev/src/continuedev/libs/llm/utils.py | 34
-rw-r--r--  continuedev/src/continuedev/libs/util/calculate_diff.py | 2
-rw-r--r--  continuedev/src/continuedev/libs/util/count_tokens.py | 86
-rw-r--r--  continuedev/src/continuedev/libs/util/strings.py | 2
-rw-r--r--  continuedev/src/continuedev/models/generate_json_schema.py | 2
-rw-r--r--  continuedev/src/continuedev/plugins/context_providers/embeddings.py | 79
-rw-r--r--  continuedev/src/continuedev/plugins/context_providers/file.py | 3
-rw-r--r--  continuedev/src/continuedev/plugins/context_providers/google.py | 4
-rw-r--r--  continuedev/src/continuedev/plugins/context_providers/util.py | 5
-rw-r--r--  continuedev/src/continuedev/plugins/policies/default.py (renamed from continuedev/src/continuedev/core/policy.py) | 24
-rw-r--r--  continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py | 8
-rw-r--r--  continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py | 2
-rw-r--r--  continuedev/src/continuedev/plugins/steps/README.md | 2
-rw-r--r--  continuedev/src/continuedev/plugins/steps/chat.py | 8
-rw-r--r--  continuedev/src/continuedev/plugins/steps/chroma.py | 2
-rw-r--r--  continuedev/src/continuedev/plugins/steps/core/core.py | 37
-rw-r--r--  continuedev/src/continuedev/plugins/steps/draft/migration.py | 2
-rw-r--r--  continuedev/src/continuedev/plugins/steps/help.py | 2
-rw-r--r--  continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py | 2
-rw-r--r--  continuedev/src/continuedev/plugins/steps/main.py | 4
-rw-r--r--  continuedev/src/continuedev/plugins/steps/react.py | 2
-rw-r--r--  continuedev/src/continuedev/plugins/steps/search_directory.py | 6
-rw-r--r--  continuedev/src/continuedev/server/gui.py | 2
-rw-r--r--  continuedev/src/continuedev/server/ide.py | 4
-rw-r--r--  continuedev/src/continuedev/server/session_manager.py | 48
43 files changed, 675 insertions, 496 deletions
diff --git a/continuedev/README.md b/continuedev/README.md
index d3ead8ec..6a11ae43 100644
--- a/continuedev/README.md
+++ b/continuedev/README.md
@@ -67,9 +67,9 @@ cd continue/extension/scripts && python3 install_from_source.py
# Understanding the codebase
-- [Continue Server README](./continuedev/README.md): learn about the core of Continue, which can be downloaded as a [PyPI package](https://pypi.org/project/continuedev/)
-- [VS Code Extension README](./extension/README.md): learn about the capabilities of our extension—the first implementation of Continue's IDE Protocol—which makes it possible to use use Continue in VS Code and GitHub Codespaces
-- [Continue GUI README](./extension/react-app/): learn about the React app that lets users interact with the server and is placed adjacent to the text editor in any suppported IDE
-- [Schema README](./schema): learn about the JSON Schema types generated from Pydantic models, which we use across the `continuedev/` and `extension/` directories
-- [Continue Docs README](./docs): learn how our [docs](https://continue.dev/docs) are written and built
-- [How to debug the VS Code Extension README](./extension/src/README.md): learn how to set up the VS Code extension, so you can debug it
+- [Continue Server README](./README.md): learn about the core of Continue, which can be downloaded as a [PyPI package](https://pypi.org/project/continuedev/)
+- [VS Code Extension README](../extension/README.md): learn about the capabilities of our extension—the first implementation of Continue's IDE Protocol—which makes it possible to use Continue in VS Code and GitHub Codespaces
+- [Continue GUI README](../extension/react-app/): learn about the React app that lets users interact with the server and is placed adjacent to the text editor in any supported IDE
+- [Schema README](../schema/README.md): learn about the JSON Schema types generated from Pydantic models, which we use across the `continuedev/` and `extension/` directories
+- [Continue Docs README](../docs/README.md): learn how our [docs](https://continue.dev/docs) are written and built
+- [How to debug the VS Code Extension README](../extension/src/README.md): learn how to set up the VS Code extension, so you can debug it
diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py
index 94d7be10..e048f877 100644
--- a/continuedev/src/continuedev/core/abstract_sdk.py
+++ b/continuedev/src/continuedev/core/abstract_sdk.py
@@ -73,7 +73,7 @@ class AbstractContinueSDK(ABC):
pass
@abstractmethod
- async def get_user_secret(self, env_var: str, prompt: str) -> str:
+ async def get_user_secret(self, env_var: str) -> str:
pass
config: ContinueConfig
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 57e39d5c..de95a259 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -9,6 +9,7 @@ from ..models.filesystem import RangeInFileWithContents
from ..models.filesystem_edit import FileEditWithFullContents
from .observation import Observation, InternalErrorObservation
from .context import ContextManager
+from ..plugins.policies.default import DefaultPolicy
from ..plugins.context_providers.file import FileContextProvider
from ..plugins.context_providers.highlighted_code import HighlightedCodeContextProvider
from ..server.ide_protocol import AbstractIdeProtocolServer
@@ -47,8 +48,9 @@ def get_error_title(e: Exception) -> str:
class Autopilot(ContinueBaseModel):
- policy: Policy
ide: AbstractIdeProtocolServer
+
+ policy: Policy = DefaultPolicy()
history: History = History.from_empty()
context: Context = Context()
full_state: Union[FullState, None] = None
@@ -64,20 +66,19 @@ class Autopilot(ContinueBaseModel):
_user_input_queue = AsyncSubscriptionQueue()
_retry_queue = AsyncSubscriptionQueue()
- @classmethod
- async def create(cls, policy: Policy, ide: AbstractIdeProtocolServer, full_state: FullState) -> "Autopilot":
- autopilot = cls(ide=ide, policy=policy)
- autopilot.continue_sdk = await ContinueSDK.create(autopilot)
+ async def start(self):
+ self.continue_sdk = await ContinueSDK.create(self)
+ if override_policy := self.continue_sdk.config.policy_override:
+ self.policy = override_policy
# Load documents into the search index
- autopilot.context_manager = await ContextManager.create(
- autopilot.continue_sdk.config.context_providers + [
- HighlightedCodeContextProvider(ide=ide),
- FileContextProvider(workspace_dir=ide.workspace_directory)
+ self.context_manager = await ContextManager.create(
+ self.continue_sdk.config.context_providers + [
+ HighlightedCodeContextProvider(ide=self.ide),
+ FileContextProvider(workspace_dir=self.ide.workspace_directory)
])
- await autopilot.context_manager.load_index(ide.workspace_directory)
- return autopilot
+ await self.context_manager.load_index(self.ide.workspace_directory)
class Config:
arbitrary_types_allowed = True
@@ -95,7 +96,6 @@ class Autopilot(ContinueBaseModel):
history=self.history,
active=self._active,
user_input_queue=self._main_user_input_queue,
- default_model=self.continue_sdk.config.default_model,
slash_commands=self.get_available_slash_commands(),
adding_highlighted_code=self.context_manager.context_providers[
"code"].adding_highlighted_code,
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 9fbda824..84b6b10b 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -2,9 +2,13 @@ import json
import os
from .main import Step
from .context import ContextProvider
+from ..libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
+from .models import Models
from pydantic import BaseModel, validator
-from typing import List, Literal, Optional, Dict, Type, Union
-import yaml
+from typing import List, Literal, Optional, Dict, Type
+
+from .main import Policy, Step
+from .context import ContextProvider
class SlashCommand(BaseModel):
@@ -25,13 +29,6 @@ class OnTracebackSteps(BaseModel):
params: Optional[Dict] = {}
-class OpenAIServerInfo(BaseModel):
- api_base: Optional[str] = None
- engine: Optional[str] = None
- api_version: Optional[str] = None
- api_type: Literal["azure", "openai"] = "openai"
-
-
class ContinueConfig(BaseModel):
"""
A pydantic class for the continue config file.
@@ -39,8 +36,9 @@ class ContinueConfig(BaseModel):
steps_on_startup: List[Step] = []
disallowed_steps: Optional[List[str]] = []
allow_anonymous_telemetry: Optional[bool] = True
- default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k",
- "gpt-4", "claude-2", "ggml"] = 'gpt-4'
+ models: Models = Models(
+ default=MaybeProxyOpenAI(model="gpt-4"),
+ )
temperature: Optional[float] = 0.5
custom_commands: Optional[List[CustomCommand]] = [CustomCommand(
name="test",
@@ -50,7 +48,7 @@ class ContinueConfig(BaseModel):
slash_commands: Optional[List[SlashCommand]] = []
on_traceback: Optional[List[OnTracebackSteps]] = []
system_message: Optional[str] = None
- openai_server_info: Optional[OpenAIServerInfo] = None
+ policy_override: Optional[Policy] = None
context_providers: List[ContextProvider] = []
diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py
index 86522ce1..e968c35c 100644
--- a/continuedev/src/continuedev/core/context.py
+++ b/continuedev/src/continuedev/core/context.py
@@ -178,12 +178,6 @@ class ContextManager:
except Exception as e:
logger.debug(f"Error loading meilisearch index: {e}")
- # def compile_chat_messages(self, max_tokens: int) -> List[Dict]:
- # """
- # Compiles the chat prompt into a single string.
- # """
- # return compile_chat_messages(self.model, self.chat_history, max_tokens, self.prompt, self.functions, self.system_message)
-
async def select_context_item(self, id: str, query: str):
"""
Selects the ContextItem with the given id.
diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py
index df9b98ef..2553850f 100644
--- a/continuedev/src/continuedev/core/main.py
+++ b/continuedev/src/continuedev/core/main.py
@@ -258,7 +258,6 @@ class FullState(ContinueBaseModel):
history: History
active: bool
user_input_queue: List[str]
- default_model: str
slash_commands: List[SlashCommandDescription]
adding_highlighted_code: bool
selected_context_items: List[ContextItem]
diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py
new file mode 100644
index 00000000..900762b6
--- /dev/null
+++ b/continuedev/src/continuedev/core/models.py
@@ -0,0 +1,65 @@
+from typing import Optional, Any
+from pydantic import BaseModel, validator
+from ..libs.llm import LLM
+
+
+class Models(BaseModel):
+ """Main class that holds the current model configuration"""
+ default: LLM
+ small: Optional[LLM] = None
+ medium: Optional[LLM] = None
+ large: Optional[LLM] = None
+
+ # TODO namespace these away to not confuse readers,
+ # or split Models into ModelsConfig, which gets turned into Models
+ sdk: "ContinueSDK" = None
+ system_message: Any = None
+
+ """
+ Better to have sdk.llm.stream_chat(messages, model="claude-2").
+ Then you also don't care that it' async.
+ And it's easier to add more models.
+ And intermediate shared code is easier to add.
+ And you can make constants like ContinueModels.GPT35 = "gpt-3.5-turbo"
+ PromptTransformer would be a good concept: You pass a prompt or list of messages and a model, then it outputs the prompt for that model.
+ Easy to reason about, can place anywhere.
+ And you can even pass a Prompt object to sdk.llm.stream_chat maybe, and it'll automatically be transformed for the given model.
+ This can all happen inside of Models?
+
+ class Prompt:
+ def __init__(self, ...info):
+ '''take whatever info is needed to describe the prompt'''
+
+ def to_string(self, model: str) -> str:
+ '''depending on the model, return the single prompt string'''
+ """
+
+ async def start(self, sdk: "ContinueSDK"):
+ """Start each of the LLMs, or fall back to default"""
+ self.sdk = sdk
+ self.system_message = self.sdk.config.system_message
+ await sdk.start_model(self.default)
+ if self.small:
+ await sdk.start_model(self.small)
+ else:
+ self.small = self.default
+
+ if self.medium:
+ await sdk.start_model(self.medium)
+ else:
+ self.medium = self.default
+
+ if self.large:
+ await sdk.start_model(self.large)
+ else:
+ self.large = self.default
+
+ async def stop(self, sdk: "ContinueSDK"):
+ """Stop each LLM (if it's not the default, which is shared)"""
+ await self.default.stop()
+ if self.small is not self.default:
+ await self.small.stop()
+ if self.medium is not self.default:
+ await self.medium.stop()
+ if self.large is not self.default:
+ await self.large.stop()
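
The new `Models` container is configured directly from `config.py`. Below is a minimal sketch of that usage, mirroring the `default_config.py.txt` change later in this diff; any tier left unset (`small`, `medium`, `large`) falls back to `default` when `Models.start()` runs.

```python
from continuedev.core.config import ContinueConfig
from continuedev.core.models import Models
from continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI

config = ContinueConfig(
    models=Models(
        default=MaybeProxyOpenAI(model="gpt-4"),
        # unset tiers (small, medium, large) fall back to `default` on start()
        medium=MaybeProxyOpenAI(model="gpt-3.5-turbo"),
    ),
)
```
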
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 4b76a121..bf22d696 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -9,17 +9,14 @@ from .abstract_sdk import AbstractContinueSDK
from .config import ContinueConfig
from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFile, AddDirectory, DeleteDirectory
from ..models.filesystem import RangeInFile
-from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
-from ..libs.llm.openai import OpenAI
-from ..libs.llm.anthropic import AnthropicLLM
-from ..libs.llm.ggml import GGML
+from ..libs.llm import LLM
from .observation import Observation
from ..server.ide_protocol import AbstractIdeProtocolServer
from .main import Context, ContinueCustomException, History, HistoryNode, Step, ChatMessage
from ..plugins.steps.core.core import *
-from ..libs.llm.proxy_server import ProxyServer
from ..libs.util.telemetry import posthog_logger
from ..libs.util.paths import getConfigFilePath
+from .models import Models
from ..libs.util.logging import logger
@@ -27,121 +24,6 @@ class Autopilot:
pass
-ModelProvider = Literal["openai", "hf_inference_api", "ggml", "anthropic"]
-MODEL_PROVIDER_TO_ENV_VAR = {
- "openai": "OPENAI_API_KEY",
- "hf_inference_api": "HUGGING_FACE_TOKEN",
- "anthropic": "ANTHROPIC_API_KEY",
-}
-
-
-class Models:
- provider_keys: Dict[ModelProvider, str] = {}
- model_providers: List[ModelProvider]
- system_message: str
-
- """
- Better to have sdk.llm.stream_chat(messages, model="claude-2").
- Then you also don't care that it' async.
- And it's easier to add more models.
- And intermediate shared code is easier to add.
- And you can make constants like ContinueModels.GPT35 = "gpt-3.5-turbo"
- PromptTransformer would be a good concept: You pass a prompt or list of messages and a model, then it outputs the prompt for that model.
- Easy to reason about, can place anywhere.
- And you can even pass a Prompt object to sdk.llm.stream_chat maybe, and it'll automatically be transformed for the given model.
- This can all happen inside of Models?
-
- class Prompt:
- def __init__(self, ...info):
- '''take whatever info is needed to describe the prompt'''
-
- def to_string(self, model: str) -> str:
- '''depending on the model, return the single prompt string'''
- """
-
- def __init__(self, sdk: "ContinueSDK", model_providers: List[ModelProvider]):
- self.sdk = sdk
- self.model_providers = model_providers
- self.system_message = sdk.config.system_message
-
- @classmethod
- async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models":
- if sdk.config.default_model == "claude-2":
- with_providers.append("anthropic")
-
- models = Models(sdk, with_providers)
- for provider in with_providers:
- if provider in MODEL_PROVIDER_TO_ENV_VAR:
- env_var = MODEL_PROVIDER_TO_ENV_VAR[provider]
- models.provider_keys[provider] = await sdk.get_user_secret(
- env_var, f'Please add your {env_var} to the .env file')
-
- return models
-
- def __load_openai_model(self, model: str) -> OpenAI:
- api_key = self.provider_keys["openai"]
- if api_key == "":
- return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message, write_log=self.sdk.write_log)
- return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, openai_server_info=self.sdk.config.openai_server_info, write_log=self.sdk.write_log)
-
- def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI:
- api_key = self.provider_keys["hf_inference_api"]
- return HuggingFaceInferenceAPI(api_key=api_key, model=model, system_message=self.system_message)
-
- def __load_anthropic_model(self, model: str) -> AnthropicLLM:
- api_key = self.provider_keys["anthropic"]
- return AnthropicLLM(api_key, model, self.system_message)
-
- @cached_property
- def claude2(self):
- return self.__load_anthropic_model("claude-2")
-
- @cached_property
- def starcoder(self):
- return self.__load_hf_inference_api_model("bigcode/starcoder")
-
- @cached_property
- def gpt35(self):
- return self.__load_openai_model("gpt-3.5-turbo")
-
- @cached_property
- def gpt350613(self):
- return self.__load_openai_model("gpt-3.5-turbo-0613")
-
- @cached_property
- def gpt3516k(self):
- return self.__load_openai_model("gpt-3.5-turbo-16k")
-
- @cached_property
- def gpt4(self):
- return self.__load_openai_model("gpt-4")
-
- @cached_property
- def ggml(self):
- return GGML(system_message=self.system_message)
-
- def __model_from_name(self, model_name: str):
- if model_name == "starcoder":
- return self.starcoder
- elif model_name == "gpt-3.5-turbo":
- return self.gpt35
- elif model_name == "gpt-3.5-turbo-16k":
- return self.gpt3516k
- elif model_name == "gpt-4":
- return self.gpt4
- elif model_name == "claude-2":
- return self.claude2
- elif model_name == "ggml":
- return self.ggml
- else:
- raise Exception(f"Unknown model {model_name}")
-
- @property
- def default(self):
- default_model = self.sdk.config.default_model
- return self.__model_from_name(default_model) if default_model is not None else self.gpt4
-
-
class ContinueSDK(AbstractContinueSDK):
"""The SDK provided as parameters to a step"""
ide: AbstractIdeProtocolServer
@@ -171,7 +53,7 @@ class ContinueSDK(AbstractContinueSDK):
formatted_err = '\n'.join(traceback.format_exception(e))
msg_step = MessageStep(
name="Invalid Continue Config File", message=formatted_err)
- msg_step.description = f"Falling back to default config settings.\n```\n{formatted_err}\n```"
+ msg_step.description = f"Falling back to default config settings.\n```\n{formatted_err}\n```\n\nIt's possible this error was caused by an update to the Continue config format. If you'd like to see the new recommended default `config.py`, check [here](https://github.com/continuedev/continue/blob/main/continuedev/src/continuedev/libs/constants/default_config.py.txt)."
sdk.history.add_node(HistoryNode(
step=msg_step,
observation=None,
@@ -179,11 +61,13 @@ class ContinueSDK(AbstractContinueSDK):
active=False
))
+ sdk.models = sdk.config.models
+ await sdk.models.start(sdk)
+
# When the config is loaded, setup posthog logger
posthog_logger.setup(
sdk.ide.unique_id, sdk.config.allow_anonymous_telemetry)
- sdk.models = await Models.create(sdk)
return sdk
@property
@@ -193,6 +77,16 @@ class ContinueSDK(AbstractContinueSDK):
def write_log(self, message: str):
self.history.timeline[self.history.current_index].logs.append(message)
+ async def start_model(self, llm: LLM):
+ kwargs = {}
+ if llm.requires_api_key:
+ kwargs["api_key"] = await self.get_user_secret(llm.requires_api_key)
+ if llm.requires_unique_id:
+ kwargs["unique_id"] = self.ide.unique_id
+ if llm.requires_write_log:
+ kwargs["write_log"] = self.write_log
+ await llm.start(**kwargs)
+
async def _ensure_absolute_path(self, path: str) -> str:
if os.path.isabs(path):
return path
@@ -262,7 +156,8 @@ class ContinueSDK(AbstractContinueSDK):
path = await self._ensure_absolute_path(path)
return await self.run_step(FileSystemEditStep(edit=DeleteDirectory(path=path)))
- async def get_user_secret(self, env_var: str, prompt: str) -> str:
+ async def get_user_secret(self, env_var: str) -> str:
+ # TODO support error prompt dynamically set on env_var
return await self.ide.getUserSecret(env_var)
_last_valid_config: ContinueConfig = None
diff --git a/continuedev/src/continuedev/libs/chroma/query.py b/continuedev/src/continuedev/libs/chroma/query.py
index f09b813a..dba4874f 100644
--- a/continuedev/src/continuedev/libs/chroma/query.py
+++ b/continuedev/src/continuedev/libs/chroma/query.py
@@ -59,7 +59,7 @@ class ChromaIndexManager:
except:
logger.warning(
f"ERROR (probably found special token): {doc.text}")
- continue
+ continue # skip this document and move on
filename = doc.extra_info["filename"]
chunks[filename] = len(text_chunks)
for i, text in enumerate(text_chunks):
diff --git a/continuedev/src/continuedev/libs/chroma/update.py b/continuedev/src/continuedev/libs/chroma/update.py
index 23ed950f..d5326a06 100644
--- a/continuedev/src/continuedev/libs/chroma/update.py
+++ b/continuedev/src/continuedev/libs/chroma/update.py
@@ -23,7 +23,7 @@ def filter_ignored_files(files: List[str], root_dir: str):
"""Further filter files before indexing."""
for file in files:
if file.endswith(tuple(FILE_TYPES_TO_IGNORE)) or file.startswith('.git') or file.startswith('archive'):
- continue
+ continue # skip ignored or archived files
yield root_dir + "/" + file
diff --git a/continuedev/src/continuedev/libs/constants/default_config.py.txt b/continuedev/src/continuedev/libs/constants/default_config.py.txt
index 1a66c847..cf8b0324 100644
--- a/continuedev/src/continuedev/libs/constants/default_config.py.txt
+++ b/continuedev/src/continuedev/libs/constants/default_config.py.txt
@@ -9,9 +9,12 @@ import subprocess
from continuedev.core.main import Step
from continuedev.core.sdk import ContinueSDK
+from continuedev.core.models import Models
from continuedev.core.config import CustomCommand, SlashCommand, ContinueConfig
from continuedev.plugins.context_providers.github import GitHubIssuesContextProvider
from continuedev.plugins.context_providers.google import GoogleContextProvider
+from continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
+from continuedev.plugins.policies.default import DefaultPolicy
from continuedev.plugins.steps.open_config import OpenConfigStep
from continuedev.plugins.steps.clear_history import ClearHistoryStep
@@ -35,9 +38,9 @@ class CommitMessageStep(Step):
diff = subprocess.check_output(
["git", "diff"], cwd=dir).decode("utf-8")
- # Ask gpt-3.5-16k to write a commit message,
+ # Ask the LLM to write a commit message,
# and set it as the description of this step
- self.description = await sdk.models.gpt3516k.complete(
+ self.description = await sdk.models.default.complete(
f"{diff}\n\nWrite a short, specific (less than 50 chars) commit message about the above changes:")
@@ -47,9 +50,10 @@ config = ContinueConfig(
# See here to learn what anonymous data we collect: https://continue.dev/docs/telemetry
allow_anonymous_telemetry=True,
- # GPT-4 is recommended for best results
- # See options here: https://continue.dev/docs/customization#change-the-default-llm
- default_model="gpt-4",
+ models=Models(
+ default=MaybeProxyOpenAI(model="gpt-4"),
+ medium=MaybeProxyOpenAI(model="gpt-3.5-turbo")
+ ),
# Set a system message with information that the LLM should always keep in mind
# E.g. "Please give concise answers. Always respond in Spanish."
@@ -114,5 +118,9 @@ config = ContinueConfig(
# GoogleContextProvider(
# serper_api_key="<your serper.dev api key>"
# )
- ]
+ ],
+
+ # Policies hold the main logic that decides which Step to take next
+ # You can use them to design agents, or deeply customize Continue
+ policy=DefaultPolicy()
)
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 2766db4b..50577993 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -1,14 +1,30 @@
-from abc import ABC
-from typing import Any, Coroutine, Dict, Generator, List, Union
+from abc import ABC, abstractproperty
+from typing import Any, Coroutine, Dict, Generator, List, Union, Optional
from ...core.main import ChatMessage
-from ...models.main import AbstractModel
-from pydantic import BaseModel
+from ...models.main import ContinueBaseModel
-class LLM(ABC):
+class LLM(ContinueBaseModel, ABC):
+ requires_api_key: Optional[str] = None
+ requires_unique_id: bool = False
+ requires_write_log: bool = False
+
system_message: Union[str, None] = None
+ @abstractproperty
+ def name(self):
+ """Return the name of the LLM."""
+ raise NotImplementedError
+
+ async def start(self, *, api_key: Optional[str] = None, **kwargs):
+ """Start the connection to the LLM."""
+ raise NotImplementedError
+
+ async def stop(self):
+ """Stop the connection to the LLM."""
+ raise NotImplementedError
+
async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
"""Return the completion of the text with the given temperature."""
raise NotImplementedError
@@ -24,3 +40,8 @@ class LLM(ABC):
def count_tokens(self, text: str):
"""Return the number of tokens in the given text."""
raise NotImplementedError
+
+ @abstractproperty
+ def context_length(self) -> int:
+ """Return the context length of the LLM in tokens, as counted by count_tokens."""
+ raise NotImplementedError
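
For illustration, a hypothetical subclass of the reworked `LLM` base class might look like the sketch below: declare which keyword arguments `start()` needs via the `requires_*` fields, implement `name` and `context_length`, and override the completion methods you support. `EchoLLM` and its whitespace token counter are invented for this example and are not part of the codebase.

```python
from typing import List, Optional

from continuedev.core.main import ChatMessage
from continuedev.libs.llm import LLM


class EchoLLM(LLM):
    model: str = "echo"
    requires_api_key: Optional[str] = None  # this toy model needs no secret

    @property
    def name(self):
        return self.model

    @property
    def context_length(self) -> int:
        return 2048

    async def start(self, **kwargs):
        pass  # open any clients or sessions here

    async def stop(self):
        pass  # and close them here

    def count_tokens(self, text: str):
        return len(text.split())  # crude stand-in for a real tokenizer

    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> str:
        return prompt  # trivially echo the prompt back
```
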
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
index 625d4e57..ec1b7e40 100644
--- a/continuedev/src/continuedev/libs/llm/anthropic.py
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -1,32 +1,39 @@
from functools import cached_property
import time
-from typing import Any, Coroutine, Dict, Generator, List, Union
+from typing import Any, Coroutine, Dict, Generator, List, Optional, Union
from ...core.main import ChatMessage
from anthropic import HUMAN_PROMPT, AI_PROMPT, AsyncAnthropic
from ..llm import LLM
-from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top
+from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens
class AnthropicLLM(LLM):
- api_key: str
- default_model: str
- async_client: AsyncAnthropic
+ model: str = "claude-2"
- def __init__(self, api_key: str, default_model: str, system_message: str = None):
- self.api_key = api_key
- self.default_model = default_model
+ requires_api_key: str = "ANTHROPIC_API_KEY"
+ _async_client: AsyncAnthropic = None
+
+ class Config:
+ arbitrary_types_allowed = True
+
+ def __init__(self, model: str, system_message: str = None):
+ self.model = model
self.system_message = system_message
- self.async_client = AsyncAnthropic(api_key=api_key)
+ async def start(self, *, api_key: Optional[str] = None, **kwargs):
+ self._async_client = AsyncAnthropic(api_key=api_key)
+
+ async def stop(self):
+ pass
@cached_property
def name(self):
- return self.default_model
+ return self.model
@property
def default_args(self):
- return {**DEFAULT_ARGS, "model": self.default_model}
+ return {**DEFAULT_ARGS, "model": self.model}
def _transform_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
args = args.copy()
@@ -40,7 +47,13 @@ class AnthropicLLM(LLM):
return args
def count_tokens(self, text: str):
- return count_tokens(self.default_model, text)
+ return count_tokens(self.model, text)
+
+ @property
+ def context_length(self):
+ if self.model == "claude-2":
+ return 100000
+ raise Exception(f"Unknown Anthropic model {self.model}")
def __messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
prompt = ""
@@ -60,7 +73,7 @@ class AnthropicLLM(LLM):
args["stream"] = True
args = self._transform_args(args)
- async for chunk in await self.async_client.completions.create(
+ async for chunk in await self._async_client.completions.create(
prompt=f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}",
**args
):
@@ -73,8 +86,8 @@ class AnthropicLLM(LLM):
args = self._transform_args(args)
messages = compile_chat_messages(
- args["model"], messages, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message)
- async for chunk in await self.async_client.completions.create(
+ args["model"], messages, self.context_length, self.context_length, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message)
+ async for chunk in await self._async_client.completions.create(
prompt=self.__messages_to_prompt(messages),
**args
):
@@ -88,8 +101,8 @@ class AnthropicLLM(LLM):
args = self._transform_args(args)
messages = compile_chat_messages(
- args["model"], with_history, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message)
- resp = (await self.async_client.completions.create(
+ args["model"], with_history, self.context_length, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message)
+ resp = (await self._async_client.completions.create(
prompt=self.__messages_to_prompt(messages),
**args
)).completion
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 4889a556..7742e8c3 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -1,6 +1,7 @@
from functools import cached_property
import json
from typing import Any, Coroutine, Dict, Generator, List, Union
+from pydantic import ConfigDict
import aiohttp
from ...core.main import ChatMessage
@@ -11,15 +12,29 @@ SERVER_URL = "http://localhost:8000"
class GGML(LLM):
+ # this is model-specific
+ max_context_length: int = 2048
- def __init__(self, system_message: str = None):
- self.system_message = system_message
+ _client_session: aiohttp.ClientSession = None
- @cached_property
+ class Config:
+ arbitrary_types_allowed = True
+
+ async def start(self, **kwargs):
+ self._client_session = aiohttp.ClientSession()
+
+ async def stop(self):
+ await self._client_session.close()
+
+ @property
def name(self):
return "ggml"
@property
+ def context_length(self):
+ return self.max_context_length
+
+ @property
def default_args(self):
return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
@@ -33,54 +48,51 @@ class GGML(LLM):
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
-
- async with aiohttp.ClientSession() as session:
- async with session.post(f"{SERVER_URL}/v1/completions", json={
- "messages": messages,
- **args
- }) as resp:
- async for line in resp.content.iter_any():
- if line:
- try:
- yield line.decode("utf-8")
- except:
- raise Exception(str(line))
+ self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
+
+ async with self._client_session.post(f"{SERVER_URL}/v1/completions", json={
+ "messages": messages,
+ **args
+ }) as resp:
+ async for line in resp.content.iter_any():
+ if line:
+ try:
+ yield line.decode("utf-8")
+ except:
+ raise Exception(str(line))
async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.name, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+ self.name, messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
args["stream"] = True
- async with aiohttp.ClientSession() as session:
- async with session.post(f"{SERVER_URL}/v1/chat/completions", json={
- "messages": messages,
- **args
- }) as resp:
- # This is streaming application/json instaed of text/event-stream
- async for line in resp.content.iter_chunks():
- if line[1]:
- try:
- json_chunk = line[0].decode("utf-8")
- if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"):
- continue
- chunks = json_chunk.split("\n")
- for chunk in chunks:
- if chunk.strip() != "":
- yield json.loads(chunk[6:])["choices"][0]["delta"]
- except:
- raise Exception(str(line[0]))
+ async with self._client_session.post(f"{SERVER_URL}/v1/chat/completions", json={
+ "messages": messages,
+ **args
+ }) as resp:
+ # This is streaming application/json instead of text/event-stream
+ async for line in resp.content.iter_chunks():
+ if line[1]:
+ try:
+ json_chunk = line[0].decode("utf-8")
+ if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"):
+ continue
+ chunks = json_chunk.split("\n")
+ for chunk in chunks:
+ if chunk.strip() != "":
+ yield json.loads(chunk[6:])["choices"][0]["delta"]
+ except:
+ raise Exception(str(line[0]))
async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
args = {**self.default_args, **kwargs}
- async with aiohttp.ClientSession() as session:
- async with session.post(f"{SERVER_URL}/v1/completions", json={
- "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
- **args
- }) as resp:
- try:
- return await resp.text()
- except:
- raise Exception(await resp.text())
+ async with self._client_session.post(f"{SERVER_URL}/v1/completions", json={
+ "messages": compile_chat_messages(args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
+ **args
+ }) as resp:
+ try:
+ return await resp.text()
+ except:
+ raise Exception(await resp.text())
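
The change above (repeated in `ProxyServer` further down) replaces a per-request `aiohttp.ClientSession` with one session owned by the LLM object: created in `start()`, reused for every call, and closed in `stop()`. A standalone sketch of that lifecycle, with names invented for illustration:

```python
import aiohttp


class SessionBackedClient:
    """Toy illustration of the start()/stop() session lifecycle GGML now uses."""

    _client_session: aiohttp.ClientSession = None

    async def start(self):
        # one session, created once and reused across requests
        self._client_session = aiohttp.ClientSession()

    async def stop(self):
        # closed explicitly, since it is no longer a per-call context manager
        await self._client_session.close()

    async def get_text(self, url: str) -> str:
        async with self._client_session.get(url) as resp:
            return await resp.text()
```
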
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 36f03270..49f593d8 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional
from ...core.main import ChatMessage
from ..llm import LLM
import requests
@@ -8,14 +8,18 @@ DEFAULT_MAX_TIME = 120.
class HuggingFaceInferenceAPI(LLM):
- api_key: str
model: str
- def __init__(self, api_key: str, model: str, system_message: str = None):
- self.api_key = api_key
+ requires_api_key: str = "HUGGING_FACE_TOKEN"
+ api_key: str = None
+
+ def __init__(self, model: str, system_message: str = None):
self.model = model
self.system_message = system_message # TODO: Nothing being done with this
+ async def start(self, *, api_key: Optional[str] = None, **kwargs):
+ self.api_key = api_key
+
def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs):
"""Return the completion of the text with the given temperature."""
API_URL = f"https://api-inference.huggingface.co/models/{self.model}"
diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
new file mode 100644
index 00000000..edf58fd7
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
@@ -0,0 +1,53 @@
+from typing import Any, Coroutine, Dict, Generator, List, Union, Optional, Callable
+
+from ...core.main import ChatMessage
+from . import LLM
+from .proxy_server import ProxyServer
+from .openai import OpenAI
+
+
+class MaybeProxyOpenAI(LLM):
+ model: str
+
+ requires_api_key: Optional[str] = "OPENAI_API_KEY"
+ requires_write_log: bool = True
+ requires_unique_id: bool = True
+ system_message: Union[str, None] = None
+
+ llm: Optional[LLM] = None
+
+ @property
+ def name(self):
+ return self.llm.name
+
+ @property
+ def context_length(self):
+ return self.llm.context_length
+
+ async def start(self, *, api_key: Optional[str] = None, unique_id: str, write_log: Callable[[str], None]):
+ if api_key is None or api_key.strip() == "":
+ self.llm = ProxyServer(model=self.model)
+ else:
+ self.llm = OpenAI(model=self.model)
+
+ await self.llm.start(api_key=api_key, write_log=write_log, unique_id=unique_id)
+
+ async def stop(self):
+ await self.llm.stop()
+
+ async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
+ return await self.llm.complete(prompt, with_history=with_history, **kwargs)
+
+ async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ resp = self.llm.stream_complete(
+ prompt, with_history=with_history, **kwargs)
+ async for item in resp:
+ yield item
+
+ async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ resp = self.llm.stream_chat(messages=messages, **kwargs)
+ async for item in resp:
+ yield item
+
+ def count_tokens(self, text: str):
+ return self.llm.count_tokens(text)
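
To illustrate the fallback: when no `OPENAI_API_KEY` is available, `MaybeProxyOpenAI.start()` wraps a `ProxyServer`; with a key it wraps `OpenAI`. In the real flow `unique_id` and `write_log` come from `ContinueSDK.start_model()`; the values below are placeholders.

```python
import asyncio

from continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI


async def main():
    llm = MaybeProxyOpenAI(model="gpt-3.5-turbo")
    # no api_key -> requests are routed through Continue's proxy server
    await llm.start(api_key=None, unique_id="example-id", write_log=print)
    print(type(llm.llm).__name__)  # ProxyServer
    await llm.stop()

asyncio.run(main())
```
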
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 654c7326..fce6e8ab 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -1,48 +1,83 @@
from functools import cached_property
import json
-from typing import Any, Callable, Coroutine, Dict, Generator, List, Union
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union, Optional
-from ...core.main import ChatMessage
+from pydantic import BaseModel
import openai
+
+from ...core.main import ChatMessage
+from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top
from ..llm import LLM
-from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top
-from ...core.config import OpenAIServerInfo
+
+
+class OpenAIServerInfo(BaseModel):
+ api_base: Optional[str] = None
+ engine: Optional[str] = None
+ api_version: Optional[str] = None
+ api_type: Literal["azure", "openai"] = "openai"
+
+
+CHAT_MODELS = {
+ "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"
+}
+MAX_TOKENS_FOR_MODEL = {
+ "gpt-3.5-turbo": 4096,
+ "gpt-3.5-turbo-0613": 4096,
+ "gpt-3.5-turbo-16k": 16384,
+ "gpt-4": 8192,
+}
+
+
+class AzureInfo(BaseModel):
+ endpoint: str
+ engine: str
+ api_version: str
class OpenAI(LLM):
- api_key: str
- default_model: str
+ model: str
+ openai_server_info: Optional[OpenAIServerInfo] = None
- def __init__(self, api_key: str, default_model: str, system_message: str = None, openai_server_info: OpenAIServerInfo = None, write_log: Callable[[str], None] = None):
- self.api_key = api_key
- self.default_model = default_model
- self.system_message = system_message
- self.openai_server_info = openai_server_info
+ requires_api_key = "OPENAI_API_KEY"
+ requires_write_log = True
+
+ system_message: Optional[str] = None
+ azure_info: Optional[AzureInfo] = None
+ write_log: Optional[Callable[[str], None]] = None
+ api_key: str = None
+
+ async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], **kwargs):
self.write_log = write_log
+ self.api_key = api_key
+ openai.api_key = self.api_key
- openai.api_key = api_key
+ if self.openai_server_info is not None:
+ openai.api_type = self.openai_server_info.api_type
+ if self.openai_server_info.api_base is not None:
+ openai.api_base = self.openai_server_info.api_base
+ if self.openai_server_info.api_version is not None:
+ openai.api_version = self.openai_server_info.api_version
- # Using an Azure OpenAI deployment
- if openai_server_info is not None:
- openai.api_type = openai_server_info.api_type
- if openai_server_info.api_base is not None:
- openai.api_base = openai_server_info.api_base
- if openai_server_info.api_version is not None:
- openai.api_version = openai_server_info.api_version
+ async def stop(self):
+ pass
- @cached_property
+ @property
def name(self):
- return self.default_model
+ return self.model
+
+ @property
+ def context_length(self):
+ return MAX_TOKENS_FOR_MODEL[self.model]
@property
def default_args(self):
- args = {**DEFAULT_ARGS, "model": self.default_model}
+ args = {**DEFAULT_ARGS, "model": self.model}
if self.openai_server_info is not None:
args["engine"] = self.openai_server_info.engine
return args
def count_tokens(self, text: str):
- return count_tokens(self.default_model, text)
+ return count_tokens(self.model, text)
async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = self.default_args.copy()
@@ -51,7 +86,7 @@ class OpenAI(LLM):
if args["model"] in CHAT_MODELS:
messages = compile_chat_messages(
- args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+ args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
completion = ""
async for chunk in await openai.ChatCompletion.acreate(
@@ -62,7 +97,7 @@ class OpenAI(LLM):
yield chunk.choices[0].delta.content
completion += chunk.choices[0].delta.content
else:
- continue
+ continue # skip chunks with no content delta
self.write_log(f"Completion: \n\n{completion}")
else:
@@ -78,12 +113,13 @@ class OpenAI(LLM):
args = self.default_args.copy()
args.update(kwargs)
args["stream"] = True
- args["model"] = self.default_model if self.default_model in CHAT_MODELS else "gpt-3.5-turbo-0613"
+ # TODO what to do here? why should we change to gpt-3.5-turbo-0613 if the user didn't ask for it?
+ args["model"] = self.model if self.model in CHAT_MODELS else "gpt-3.5-turbo-0613"
if not args["model"].endswith("0613") and "functions" in args:
del args["functions"]
messages = compile_chat_messages(
- args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+ args["model"], messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
completion = ""
async for chunk in await openai.ChatCompletion.acreate(
@@ -100,7 +136,7 @@ class OpenAI(LLM):
if args["model"] in CHAT_MODELS:
messages = compile_chat_messages(
- args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+ args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
resp = (await openai.ChatCompletion.acreate(
messages=messages,
@@ -109,7 +145,7 @@ class OpenAI(LLM):
self.write_log(f"Completion: \n\n{resp}")
else:
prompt = prune_raw_prompt_from_top(
- args["model"], prompt, args["max_tokens"])
+ args["model"], self.context_length, prompt, args["max_tokens"])
self.write_log(f"Prompt:\n\n{prompt}")
resp = (await openai.Completion.acreate(
prompt=prompt,
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index f9e3fa01..1a48f213 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -1,7 +1,6 @@
-
import json
import traceback
-from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union, Optional
import aiohttp
from ...core.main import ChatMessage
from ..llm import LLM
@@ -16,26 +15,51 @@ ssl_context = ssl.create_default_context(cafile=ca_bundle_path)
# SERVER_URL = "http://127.0.0.1:8080"
SERVER_URL = "https://proxy-server-l6vsfbzhba-uw.a.run.app"
+MAX_TOKENS_FOR_MODEL = {
+ "gpt-3.5-turbo": 4096,
+ "gpt-3.5-turbo-0613": 4096,
+ "gpt-3.5-turbo-16k": 16384,
+ "gpt-4": 8192,
+}
+
class ProxyServer(LLM):
- unique_id: str
- name: str
- default_model: Literal["gpt-3.5-turbo", "gpt-4"]
- write_log: Callable[[str], None]
+ model: str
+ system_message: Optional[str]
- def __init__(self, unique_id: str, default_model: Literal["gpt-3.5-turbo", "gpt-4"], system_message: str = None, write_log: Callable[[str], None] = None):
- self.unique_id = unique_id
- self.default_model = default_model
- self.system_message = system_message
- self.name = default_model
+ unique_id: str = None
+ write_log: Callable[[str], None] = None
+ _client_session: aiohttp.ClientSession
+
+ requires_unique_id = True
+ requires_write_log = True
+
+ class Config:
+ arbitrary_types_allowed = True
+
+ async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], unique_id: str, **kwargs):
+ self._client_session = aiohttp.ClientSession(
+ connector=aiohttp.TCPConnector(ssl_context=ssl_context))
self.write_log = write_log
+ self.unique_id = unique_id
+
+ async def stop(self):
+ await self._client_session.close()
+
+ @property
+ def name(self):
+ return self.model
+
+ @property
+ def context_length(self):
+ return MAX_TOKENS_FOR_MODEL[self.model]
@property
def default_args(self):
- return {**DEFAULT_ARGS, "model": self.default_model}
+ return {**DEFAULT_ARGS, "model": self.model}
def count_tokens(self, text: str):
- return count_tokens(self.default_model, text)
+ return count_tokens(self.model, text)
def get_headers(self):
# headers with unique id
@@ -45,75 +69,72 @@ class ProxyServer(LLM):
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+ args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
- async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
- async with session.post(f"{SERVER_URL}/complete", json={
- "messages": messages,
- **args
- }, headers=self.get_headers()) as resp:
- if resp.status != 200:
- raise Exception(await resp.text())
-
- response_text = await resp.text()
- self.write_log(f"Completion: \n\n{response_text}")
- return response_text
+ async with self._client_session.post(f"{SERVER_URL}/complete", json={
+ "messages": messages,
+ **args
+ }, headers=self.get_headers()) as resp:
+ if resp.status != 200:
+ raise Exception(await resp.text())
+
+ response_text = await resp.text()
+ self.write_log(f"Completion: \n\n{response_text}")
+ return response_text
async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+ args["model"], messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
- async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
- async with session.post(f"{SERVER_URL}/stream_chat", json={
- "messages": messages,
- **args
- }, headers=self.get_headers()) as resp:
- # This is streaming application/json instaed of text/event-stream
- completion = ""
- if resp.status != 200:
- raise Exception(await resp.text())
- async for line in resp.content.iter_chunks():
- if line[1]:
- try:
- json_chunk = line[0].decode("utf-8")
- json_chunk = "{}" if json_chunk == "" else json_chunk
- chunks = json_chunk.split("\n")
- for chunk in chunks:
- if chunk.strip() != "":
- loaded_chunk = json.loads(chunk)
- yield loaded_chunk
- if "content" in loaded_chunk:
- completion += loaded_chunk["content"]
- except Exception as e:
- posthog_logger.capture_event(self.unique_id, "proxy_server_parse_error", {
- "error_title": "Proxy server stream_chat parsing failed", "error_message": '\n'.join(traceback.format_exception(e))})
- else:
- break
-
- self.write_log(f"Completion: \n\n{completion}")
+ async with self._client_session.post(f"{SERVER_URL}/stream_chat", json={
+ "messages": messages,
+ **args
+ }, headers=self.get_headers()) as resp:
+ # This is streaming application/json instead of text/event-stream
+ completion = ""
+ if resp.status != 200:
+ raise Exception(await resp.text())
+ async for line in resp.content.iter_chunks():
+ if line[1]:
+ try:
+ json_chunk = line[0].decode("utf-8")
+ json_chunk = "{}" if json_chunk == "" else json_chunk
+ chunks = json_chunk.split("\n")
+ for chunk in chunks:
+ if chunk.strip() != "":
+ loaded_chunk = json.loads(chunk)
+ yield loaded_chunk
+ if "content" in loaded_chunk:
+ completion += loaded_chunk["content"]
+ except Exception as e:
+ posthog_logger.capture_event(self.unique_id, "proxy_server_parse_error", {
+ "error_title": "Proxy server stream_chat parsing failed", "error_message": '\n'.join(traceback.format_exception(e))})
+ else:
+ break
+
+ self.write_log(f"Completion: \n\n{completion}")
async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
+ self.model, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
- async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
- async with session.post(f"{SERVER_URL}/stream_complete", json={
- "messages": messages,
- **args
- }, headers=self.get_headers()) as resp:
- completion = ""
- if resp.status != 200:
- raise Exception(await resp.text())
- async for line in resp.content.iter_any():
- if line:
- try:
- decoded_line = line.decode("utf-8")
- yield decoded_line
- completion += decoded_line
- except:
- raise Exception(str(line))
- self.write_log(f"Completion: \n\n{completion}")
+ async with self._client_session.post(f"{SERVER_URL}/stream_complete", json={
+ "messages": messages,
+ **args
+ }, headers=self.get_headers()) as resp:
+ completion = ""
+ if resp.status != 200:
+ raise Exception(await resp.text())
+ async for line in resp.content.iter_any():
+ if line:
+ try:
+ decoded_line = line.decode("utf-8")
+ yield decoded_line
+ completion += decoded_line
+ except:
+ raise Exception(str(line))
+ self.write_log(f"Completion: \n\n{completion}")
diff --git a/continuedev/src/continuedev/libs/llm/utils.py b/continuedev/src/continuedev/libs/llm/utils.py
deleted file mode 100644
index 76240d4e..00000000
--- a/continuedev/src/continuedev/libs/llm/utils.py
+++ /dev/null
@@ -1,34 +0,0 @@
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from transformers import GPT2TokenizerFast
-
-gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
-def count_tokens(text: str) -> int:
- return len(gpt2_tokenizer.encode(text))
-
-prices = {
- # All prices are per 1k tokens
- "fine-tune-train": {
- "davinci": 0.03,
- "curie": 0.03,
- "babbage": 0.0006,
- "ada": 0.0004,
- },
- "completion": {
- "davinci": 0.02,
- "curie": 0.002,
- "babbage": 0.0005,
- "ada": 0.0004,
- },
- "fine-tune-completion": {
- "davinci": 0.12,
- "curie": 0.012,
- "babbage": 0.0024,
- "ada": 0.0016,
- },
- "embedding": {
- "ada": 0.0004
- }
-}
-
-def get_price(text: str, model: str="davinci", task: str="completion") -> float:
- return count_tokens(text) * prices[task][model] / 1000
\ No newline at end of file
diff --git a/continuedev/src/continuedev/libs/util/calculate_diff.py b/continuedev/src/continuedev/libs/util/calculate_diff.py
index ff0a135f..3e82bab3 100644
--- a/continuedev/src/continuedev/libs/util/calculate_diff.py
+++ b/continuedev/src/continuedev/libs/util/calculate_diff.py
@@ -92,7 +92,7 @@ def calculate_diff2(filepath: str, original: str, updated: str) -> List[FileEdit
tag, i1, i2, j1, j2 = s.get_opcodes()[edit_index]
replacement = updated[j1:j2]
if tag == "equal":
- continue
+ continue # equal spans need no edit
elif tag == "delete":
edits.append(FileEdit.from_deletion(
filepath, Range.from_indices(original, i1, i2)))
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index c58ae499..6add7b1a 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -2,43 +2,47 @@ import json
from typing import Dict, List, Union
from ...core.main import ChatMessage
from .templating import render_templated_string
+from ...libs.llm import LLM
import tiktoken
+# TODO move many of these into a specific LLM.properties() function that
+# contains max tokens, whether it's a chat model or not, and default args (not all
+# models want to be run at 0.5 temp). This also lets custom models made for long
+# contexts (like LLongMA) live here
aliases = {
"ggml": "gpt-3.5-turbo",
"claude-2": "gpt-3.5-turbo",
}
DEFAULT_MAX_TOKENS = 2048
-MAX_TOKENS_FOR_MODEL = {
- "gpt-3.5-turbo": 4096,
- "gpt-3.5-turbo-0613": 4096,
- "gpt-3.5-turbo-16k": 16384,
- "gpt-4": 8192,
- "ggml": 2048,
- "claude-2": 100000
-}
-CHAT_MODELS = {
- "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"
-}
DEFAULT_ARGS = {"max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.5, "top_p": 1,
"frequency_penalty": 0, "presence_penalty": 0}
-def encoding_for_model(model: str):
- return tiktoken.encoding_for_model(aliases.get(model, model))
+def encoding_for_model(model_name: str):
+ try:
+ return tiktoken.encoding_for_model(aliases.get(model_name, model_name))
+ except:
+ return tiktoken.encoding_for_model("gpt-3.5-turbo")
-def count_tokens(model: str, text: Union[str, None]):
+def count_tokens(model_name: str, text: Union[str, None]):
if text is None:
return 0
- encoding = encoding_for_model(model)
+ encoding = encoding_for_model(model_name)
return len(encoding.encode(text, disallowed_special=()))
-def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: int):
- max_tokens = MAX_TOKENS_FOR_MODEL.get(
- model, DEFAULT_MAX_TOKENS) - tokens_for_completion
- encoding = encoding_for_model(model)
+def count_chat_message_tokens(model_name: str, chat_message: ChatMessage) -> int:
+ # Doing simpler, safer version of what is here:
+ # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+ # every message follows <|start|>{role/name}\n{content}<|end|>\n
+ TOKENS_PER_MESSAGE = 4
+ return count_tokens(model_name, chat_message.content) + TOKENS_PER_MESSAGE
+
+
+def prune_raw_prompt_from_top(model_name: str, context_length: int, prompt: str, tokens_for_completion: int):
+ max_tokens = context_length - tokens_for_completion
+ encoding = encoding_for_model(model_name)
tokens = encoding.encode(prompt, disallowed_special=())
if len(tokens) <= max_tokens:
return prompt
@@ -46,53 +50,45 @@ def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: in
return encoding.decode(tokens[-max_tokens:])
-def count_chat_message_tokens(model: str, chat_message: ChatMessage) -> int:
- # Doing simpler, safer version of what is here:
- # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
- # every message follows <|start|>{role/name}\n{content}<|end|>\n
- TOKENS_PER_MESSAGE = 4
- return count_tokens(model, chat_message.content) + TOKENS_PER_MESSAGE
-
-
-def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int):
+def prune_chat_history(model_name: str, chat_history: List[ChatMessage], context_length: int, tokens_for_completion: int):
total_tokens = tokens_for_completion + \
- sum(count_chat_message_tokens(model, message)
+ sum(count_chat_message_tokens(model_name, message)
for message in chat_history)
# 1. Replace beyond last 5 messages with summary
i = 0
- while total_tokens > max_tokens and i < len(chat_history) - 5:
+ while total_tokens > context_length and i < len(chat_history) - 5:
message = chat_history[0]
- total_tokens -= count_tokens(model, message.content)
- total_tokens += count_tokens(model, message.summary)
+ total_tokens -= count_tokens(model_name, message.content)
+ total_tokens += count_tokens(model_name, message.summary)
message.content = message.summary
i += 1
# 2. Remove entire messages until the last 5
- while len(chat_history) > 5 and total_tokens > max_tokens and len(chat_history) > 0:
+ while len(chat_history) > 5 and total_tokens > context_length and len(chat_history) > 0:
message = chat_history.pop(0)
- total_tokens -= count_tokens(model, message.content)
+ total_tokens -= count_tokens(model_name, message.content)
# 3. Truncate message in the last 5, except last 1
i = 0
- while total_tokens > max_tokens and len(chat_history) > 0 and i < len(chat_history) - 1:
+ while total_tokens > context_length and len(chat_history) > 0 and i < len(chat_history) - 1:
message = chat_history[i]
- total_tokens -= count_tokens(model, message.content)
- total_tokens += count_tokens(model, message.summary)
+ total_tokens -= count_tokens(model_name, message.content)
+ total_tokens += count_tokens(model_name, message.summary)
message.content = message.summary
i += 1
# 4. Remove entire messages in the last 5, except last 1
- while total_tokens > max_tokens and len(chat_history) > 1:
+ while total_tokens > context_length and len(chat_history) > 1:
message = chat_history.pop(0)
- total_tokens -= count_tokens(model, message.content)
+ total_tokens -= count_tokens(model_name, message.content)
# 5. Truncate last message
- if total_tokens > max_tokens and len(chat_history) > 0:
+ if total_tokens > context_length and len(chat_history) > 0:
message = chat_history[0]
message.content = prune_raw_prompt_from_top(
- model, message.content, tokens_for_completion)
- total_tokens = max_tokens
+ model_name, context_length, message.content, tokens_for_completion)
+ total_tokens = context_length
return chat_history
@@ -101,7 +97,7 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens:
TOKEN_BUFFER_FOR_SAFETY = 100
-def compile_chat_messages(model: str, msgs: Union[List[ChatMessage], None], max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:
+def compile_chat_messages(model_name: str, msgs: Union[List[ChatMessage], None], context_length: int, max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:
"""
The total number of tokens is system_message + sum(msgs) + functions + prompt after it is converted to a message
"""
@@ -125,10 +121,10 @@ def compile_chat_messages(model: str, msgs: Union[List[ChatMessage], None], max_
function_tokens = 0
if functions is not None:
for function in functions:
- function_tokens += count_tokens(model, json.dumps(function))
+ function_tokens += count_tokens(model_name, json.dumps(function))
msgs_copy = prune_chat_history(
- model, msgs_copy, MAX_TOKENS_FOR_MODEL[model], function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY)
+ model_name, msgs_copy, context_length, function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY)
history = [msg.to_dict(with_functions=functions is not None)
for msg in msgs_copy]
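
The hunk above replaces the hard-coded `MAX_TOKENS_FOR_MODEL` table with a `context_length` passed in from the LLM object. A minimal standalone sketch of the same budgeting arithmetic, assuming only that `tiktoken` is installed; the numbers are illustrative, not taken from the patch:

```python
import tiktoken

context_length = 4096           # would normally come from the LLM object
tokens_for_completion = 1024    # reserved for the model's reply
TOKEN_BUFFER_FOR_SAFETY = 100   # same safety margin as in count_tokens.py

encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
prompt = "def add(a, b):\n    return a + b\n" * 500  # deliberately oversized

budget = context_length - tokens_for_completion - TOKEN_BUFFER_FOR_SAFETY
tokens = encoding.encode(prompt, disallowed_special=())
if len(tokens) > budget:
    # Keep the end of the prompt, mirroring prune_raw_prompt_from_top
    prompt = encoding.decode(tokens[-budget:])

assert len(encoding.encode(prompt, disallowed_special=())) <= budget
```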
diff --git a/continuedev/src/continuedev/libs/util/strings.py b/continuedev/src/continuedev/libs/util/strings.py
index f1fb8d0b..285c1e47 100644
--- a/continuedev/src/continuedev/libs/util/strings.py
+++ b/continuedev/src/continuedev/libs/util/strings.py
@@ -12,7 +12,7 @@ def dedent_and_get_common_whitespace(s: str) -> Tuple[str, str]:
for i in range(1, len(lines)):
# Empty lines are wildcards
if lines[i].strip() == "":
- continue
+ continue # hey that's us!
# Iterate through the leading whitespace characters of the current line
for j in range(0, len(lcp)):
# If it doesn't have the same whitespace as lcp, then update lcp
diff --git a/continuedev/src/continuedev/models/generate_json_schema.py b/continuedev/src/continuedev/models/generate_json_schema.py
index 51869fdd..2166bc37 100644
--- a/continuedev/src/continuedev/models/generate_json_schema.py
+++ b/continuedev/src/continuedev/models/generate_json_schema.py
@@ -39,7 +39,7 @@ def main():
json = schema_json_of(model, indent=2, title=title)
except Exception as e:
print(f"Failed to generate json schema for {title}: {e}")
- continue
+ continue # pun intended
with open(f"{SCHEMA_DIR}/{title}.json", "w") as f:
f.write(json)
diff --git a/continuedev/src/continuedev/plugins/context_providers/embeddings.py b/continuedev/src/continuedev/plugins/context_providers/embeddings.py
new file mode 100644
index 00000000..42d1f754
--- /dev/null
+++ b/continuedev/src/continuedev/plugins/context_providers/embeddings.py
@@ -0,0 +1,79 @@
+import os
+from typing import List, Optional
+import uuid
+from pydantic import BaseModel
+
+from ...core.main import ContextItemId
+from ...core.context import ContextProvider
+from ...core.main import ContextItem, ContextItemDescription, ContextItemId
+from ...libs.chroma.query import ChromaIndexManager
+from .util import remove_meilisearch_disallowed_chars
+
+
+class EmbeddingResult(BaseModel):
+ filename: str
+ content: str
+
+
+class EmbeddingsProvider(ContextProvider):
+ title = "embed"
+
+ workspace_directory: str
+
+ EMBEDDINGS_CONTEXT_ITEM_ID = "embeddings"
+
+ index_manager: Optional[ChromaIndexManager] = None
+
+ class Config:
+ arbitrary_types_allowed = True
+
+ @property
+ def index(self):
+ if self.index_manager is None:
+ self.index_manager = ChromaIndexManager(self.workspace_directory)
+ return self.index_manager
+
+ @property
+ def BASE_CONTEXT_ITEM(self):
+ return ContextItem(
+ content="",
+ description=ContextItemDescription(
+ name="Embedding Search",
+                description="Enter a query to search the codebase with embeddings",
+ id=ContextItemId(
+ provider_title=self.title,
+ item_id=self.EMBEDDINGS_CONTEXT_ITEM_ID
+ )
+ )
+ )
+
+    async def _get_query_results(self, query: str) -> List[EmbeddingResult]:
+ results = self.index.query_codebase_index(query)
+
+ ret = []
+ for node in results.source_nodes:
+ resource_name = list(node.node.relationships.values())[0]
+ filepath = resource_name[:resource_name.index("::")]
+ ret.append(EmbeddingResult(
+ filename=filepath, content=node.node.text))
+
+ return ret
+
+ async def provide_context_items(self) -> List[ContextItem]:
+        self.index.create_codebase_index()  # TODO: Building the index synchronously here is not ideal
+
+ return [self.BASE_CONTEXT_ITEM]
+
+ async def add_context_item(self, id: ContextItemId, query: str):
+        if id.item_id != self.EMBEDDINGS_CONTEXT_ITEM_ID:
+ raise Exception("Invalid item id")
+
+ results = await self._get_query_results(query)
+
+ for i in range(len(results)):
+ result = results[i]
+ ctx_item = self.BASE_CONTEXT_ITEM.copy()
+ ctx_item.description.name = os.path.basename(result.filename)
+ ctx_item.content = f"{result.filename}\n```\n{result.content}\n```"
+ ctx_item.description.id.item_id = uuid.uuid4().hex
+ self.selected_items.append(ctx_item)
diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py
index 634774df..31aa5423 100644
--- a/continuedev/src/continuedev/plugins/context_providers/file.py
+++ b/continuedev/src/continuedev/plugins/context_providers/file.py
@@ -3,6 +3,7 @@ import re
from typing import List
from ...core.main import ContextItem, ContextItemDescription, ContextItemId
from ...core.context import ContextProvider
+from .util import remove_meilisearch_disallowed_chars
from fnmatch import fnmatch
@@ -79,7 +80,7 @@ class FileContextProvider(ContextProvider):
description=file,
id=ContextItemId(
provider_title=self.title,
- item_id=re.sub(r'[^0-9a-zA-Z_-]', '', file)
+ item_id=remove_meilisearch_disallowed_chars(file)
)
)
))
diff --git a/continuedev/src/continuedev/plugins/context_providers/google.py b/continuedev/src/continuedev/plugins/context_providers/google.py
index fc76fe67..4b0a59ec 100644
--- a/continuedev/src/continuedev/plugins/context_providers/google.py
+++ b/continuedev/src/continuedev/plugins/context_providers/google.py
@@ -2,6 +2,7 @@ import json
from typing import List
import aiohttp
+from .util import remove_meilisearch_disallowed_chars
from ...core.main import ContextItem, ContextItemDescription, ContextItemId
from ...core.context import ContextProvider
@@ -60,5 +61,6 @@ class GoogleContextProvider(ContextProvider):
ctx_item = self.BASE_CONTEXT_ITEM.copy()
ctx_item.content = content
- ctx_item.description.id.item_id = query
+ ctx_item.description.id.item_id = remove_meilisearch_disallowed_chars(
+ query)
return ctx_item
diff --git a/continuedev/src/continuedev/plugins/context_providers/util.py b/continuedev/src/continuedev/plugins/context_providers/util.py
new file mode 100644
index 00000000..da2e6b17
--- /dev/null
+++ b/continuedev/src/continuedev/plugins/context_providers/util.py
@@ -0,0 +1,5 @@
+import re
+
+
+def remove_meilisearch_disallowed_chars(id: str) -> str:
+ return re.sub(r'[^0-9a-zA-Z_-]', '', id)
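
The new helper simply strips anything Meilisearch would reject in a document id. A quick illustration (the sample path is invented):

```python
import re

def remove_meilisearch_disallowed_chars(id: str) -> str:
    return re.sub(r'[^0-9a-zA-Z_-]', '', id)

# Slashes and dots are stripped, leaving an id Meilisearch will accept.
print(remove_meilisearch_disallowed_chars("src/continuedev/core/main.py"))
# srccontinuedevcoremainpy
```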
diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/plugins/policies/default.py
index d90177b5..523c2cf4 100644
--- a/continuedev/src/continuedev/core/policy.py
+++ b/continuedev/src/continuedev/plugins/policies/default.py
@@ -1,15 +1,15 @@
from textwrap import dedent
from typing import Union
-from ..plugins.steps.chat import SimpleChatStep
-from ..plugins.steps.welcome import WelcomeStep
-from .config import ContinueConfig
-from ..plugins.steps.steps_on_startup import StepsOnStartupStep
-from .main import Step, History, Policy
-from .observation import UserInputObservation
-from ..plugins.steps.core.core import MessageStep
-from ..plugins.steps.custom_command import CustomCommandStep
-from ..plugins.steps.main import EditHighlightedCodeStep
+from ..steps.chat import SimpleChatStep
+from ..steps.welcome import WelcomeStep
+from ...core.config import ContinueConfig
+from ..steps.steps_on_startup import StepsOnStartupStep
+from ...core.main import Step, History, Policy
+from ...core.observation import UserInputObservation
+from ..steps.core.core import MessageStep
+from ..steps.custom_command import CustomCommandStep
+from ..steps.main import EditHighlightedCodeStep
def parse_slash_command(inp: str, config: ContinueConfig) -> Union[None, Step]:
@@ -45,7 +45,8 @@ def parse_custom_command(inp: str, config: ContinueConfig) -> Union[None, Step]:
class DefaultPolicy(Policy):
- ran_code_last: bool = False
+
+ default_step: Step = SimpleChatStep()
def next(self, config: ContinueConfig, history: History) -> Step:
        # At the very start, run initial Steps specified in the config
@@ -56,7 +57,6 @@ class DefaultPolicy(Policy):
- Use `cmd+m` (Mac) / `ctrl+m` (Windows) to open Continue
- Use `/help` to ask questions about how to use Continue""")) >>
WelcomeStep() >>
- # CreateCodebaseIndexChroma() >>
StepsOnStartupStep())
observation = history.get_current().observation
@@ -75,6 +75,6 @@ class DefaultPolicy(Policy):
if user_input.startswith("/edit"):
return EditHighlightedCodeStep(user_input=user_input[5:])
- return SimpleChatStep()
+ return self.default_step.copy()
return None
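
With the move to `plugins/policies/default.py`, the policy drops the unused `ran_code_last` flag and gains a configurable `default_step` that is copied whenever user input matches no slash command or custom command. A hedged sketch of overriding it, assuming `DefaultPolicy` behaves like an ordinary pydantic model whose fields can be set at construction and that the module path mirrors the new file location:

```python
from continuedev.plugins.policies.default import DefaultPolicy
from continuedev.plugins.steps.chat import SimpleChatStep

# SimpleChatStep is already the default; pass any other Step to change the
# fallback behavior for plain chat input.
policy = DefaultPolicy(default_step=SimpleChatStep())
```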
diff --git a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py
index 433e309e..872f8d62 100644
--- a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py
+++ b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py
@@ -27,7 +27,7 @@ class SetupPipelineStep(Step):
async def run(self, sdk: ContinueSDK):
sdk.context.set("api_description", self.api_description)
- source_name = (await sdk.models.gpt35.complete(
+ source_name = (await sdk.models.medium.complete(
f"Write a snake_case name for the data source described by {self.api_description}: ")).strip()
filename = f'{source_name}.py'
@@ -89,7 +89,7 @@ class ValidatePipelineStep(Step):
if "Traceback" in output or "SyntaxError" in output:
output = "Traceback" + output.split("Traceback")[-1]
file_content = await sdk.ide.readFile(os.path.join(workspace_dir, filename))
- suggestion = await sdk.models.gpt35.complete(dedent(f"""\
+ suggestion = await sdk.models.medium.complete(dedent(f"""\
```python
{file_content}
```
@@ -101,7 +101,7 @@ class ValidatePipelineStep(Step):
This is a brief summary of the error followed by a suggestion on how it can be fixed by editing the resource function:"""))
- api_documentation_url = await sdk.models.gpt35.complete(dedent(f"""\
+ api_documentation_url = await sdk.models.medium.complete(dedent(f"""\
The API I am trying to call is the '{sdk.context.get('api_description')}'. I tried calling it in the @resource function like this:
```python
{file_content}
@@ -151,7 +151,7 @@ class RunQueryStep(Step):
output = await sdk.run('.env/bin/python3 query.py', name="Run test query", description="Running `.env/bin/python3 query.py` to test that the data was loaded into DuckDB as expected", handle_error=False)
if "Traceback" in output or "SyntaxError" in output:
- suggestion = await sdk.models.gpt35.complete(dedent(f"""\
+ suggestion = await sdk.models.medium.complete(dedent(f"""\
```python
{await sdk.ide.readFile(os.path.join(sdk.ide.workspace_directory, "query.py"))}
```
diff --git a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py
index 6ef5ffd6..c66cd629 100644
--- a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py
+++ b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py
@@ -42,7 +42,7 @@ class WritePytestsRecipe(Step):
"{self.user_input}"
Here is a complete set of pytest unit tests:""")
- tests = await sdk.models.gpt35.complete(prompt)
+ tests = await sdk.models.medium.complete(prompt)
await sdk.apply_filesystem_edit(AddFile(filepath=path, content=tests))
diff --git a/continuedev/src/continuedev/plugins/steps/README.md b/continuedev/src/continuedev/plugins/steps/README.md
index 12073835..3f2f804c 100644
--- a/continuedev/src/continuedev/plugins/steps/README.md
+++ b/continuedev/src/continuedev/plugins/steps/README.md
@@ -33,7 +33,7 @@ If you'd like to override the default description of your step, which is just th
- Return a static string
- Store state in a class attribute during the run method (prefix it with a double underscore, which signifies through Pydantic that it is internal state rather than a parameter for the Step), and then read it in the describe method.
-- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.gpt35.complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`.
+- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.medium.complete(f"{self.__code_written}\n\nSummarize the changes made in the above code.")`.
Here's an example:
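
A rough sketch of the `describe` pattern from the bullet above. `CreateTableStep`, `__sql_written`, and the shell command are invented names; the imports follow what other Steps in this directory use, and it assumes the Step base class permits the double-underscore internal attribute the README describes:

```python
from ...core.main import Step, Models
from ...core.sdk import ContinueSDK


class CreateTableStep(Step):
    table_name: str
    __sql_written: str = ""

    async def run(self, sdk: ContinueSDK):
        sql = f"CREATE TABLE {self.table_name} (id INTEGER PRIMARY KEY);"
        self.__sql_written = sql  # internal state, not a Step parameter
        await sdk.run(f'sqlite3 my.db "{sql}"')

    async def describe(self, models: Models):
        return await models.medium.complete(
            f"{self.__sql_written}\n\nSummarize the changes made in the above SQL.")
```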
diff --git a/continuedev/src/continuedev/plugins/steps/chat.py b/continuedev/src/continuedev/plugins/steps/chat.py
index f72a8ec0..455d5a13 100644
--- a/continuedev/src/continuedev/plugins/steps/chat.py
+++ b/continuedev/src/continuedev/plugins/steps/chat.py
@@ -9,6 +9,7 @@ from .core.core import DisplayErrorStep, MessageStep
from ...core.main import FunctionCall, Models
from ...core.main import ChatMessage, Step, step_to_json_schema
from ...core.sdk import ContinueSDK
+from ...libs.llm.openai import OpenAI
import openai
import os
from dotenv import load_dotenv
@@ -41,7 +42,7 @@ class SimpleChatStep(Step):
self.description += chunk["content"]
await sdk.update_ui()
- self.name = remove_quotes_and_escapes(await sdk.models.gpt35.complete(
+ self.name = remove_quotes_and_escapes(await sdk.models.medium.complete(
f"Write a short title for the following chat message: {self.description}"))
self.chat_context.append(ChatMessage(
@@ -166,7 +167,10 @@ class ChatWithFunctions(Step):
msg_content = ""
msg_step = None
- async for msg_chunk in sdk.models.gpt350613.stream_chat(await sdk.get_chat_context(), functions=functions):
+ gpt350613 = OpenAI(model="gpt-3.5-turbo-0613")
+ await sdk.start_model(gpt350613)
+
+ async for msg_chunk in gpt350613.stream_chat(await sdk.get_chat_context(), functions=functions):
if sdk.current_step_was_deleted():
return
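
The hunk above replaces the SDK-provided `gpt350613` attribute with a model instantiated inside the step and registered via `sdk.start_model` before use. A hedged sketch of that pattern in a Step of its own; `AskGPT4Step` is an invented name and the model choice is illustrative:

```python
from ...core.main import Step
from ...core.sdk import ContinueSDK
from ...libs.llm.openai import OpenAI


class AskGPT4Step(Step):
    name: str = "Ask GPT-4"

    async def run(self, sdk: ContinueSDK):
        gpt4 = OpenAI(model="gpt-4")
        await sdk.start_model(gpt4)  # register the model with the SDK before streaming

        async for chunk in gpt4.stream_chat(await sdk.get_chat_context()):
            if sdk.current_step_was_deleted():
                return
            if "content" in chunk:
                self.description += chunk["content"]
                await sdk.update_ui()
```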
diff --git a/continuedev/src/continuedev/plugins/steps/chroma.py b/continuedev/src/continuedev/plugins/steps/chroma.py
index dbe8363e..658cc7f3 100644
--- a/continuedev/src/continuedev/plugins/steps/chroma.py
+++ b/continuedev/src/continuedev/plugins/steps/chroma.py
@@ -56,7 +56,7 @@ class AnswerQuestionChroma(Step):
Here is the answer:""")
- answer = await sdk.models.gpt35.complete(prompt)
+ answer = await sdk.models.medium.complete(prompt)
# Make paths relative to the workspace directory
answer = answer.replace(await sdk.ide.getWorkspaceDirectory(), "")
diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py
index c80cecc3..4476c7ae 100644
--- a/continuedev/src/continuedev/plugins/steps/core/core.py
+++ b/continuedev/src/continuedev/plugins/steps/core/core.py
@@ -11,11 +11,12 @@ from pydantic import validator
from ....libs.llm.ggml import GGML
from ....models.main import Range
+from ....libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
from ....models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullContents, FileSystemEdit
from ....models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents
-from ....core.observation import Observation, TextObservation, UserInputObservation
-from ....core.main import ChatMessage, ContinueCustomException, Step
-from ....libs.util.count_tokens import MAX_TOKENS_FOR_MODEL, DEFAULT_MAX_TOKENS
+from ....core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation
+from ....core.main import ChatMessage, ContinueCustomException, Step, SequentialStep
+from ....libs.util.count_tokens import DEFAULT_MAX_TOKENS
from ....libs.util.strings import dedent_and_get_common_whitespace, remove_quotes_and_escapes
@@ -97,7 +98,7 @@ class ShellCommandsStep(Step):
return f"Error when running shell commands:\n```\n{self._err_text}\n```"
cmds_str = "\n".join(self.cmds)
- return await models.gpt35.complete(f"{cmds_str}\n\nSummarize what was done in these shell commands, using markdown bullet points:")
+ return await models.medium.complete(f"{cmds_str}\n\nSummarize what was done in these shell commands, using markdown bullet points:")
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
cwd = await sdk.ide.getWorkspaceDirectory() if self.cwd is None else self.cwd
@@ -105,7 +106,7 @@ class ShellCommandsStep(Step):
for cmd in self.cmds:
output = await sdk.ide.runCommand(cmd)
if self.handle_error and output is not None and output_contains_error(output):
- suggestion = await sdk.models.gpt35.complete(dedent(f"""\
+ suggestion = await sdk.models.medium.complete(dedent(f"""\
While running the command `{cmd}`, the following error occurred:
```ascii
@@ -185,7 +186,7 @@ class DefaultModelEditCodeStep(Step):
else:
changes = '\n'.join(difflib.ndiff(
self._previous_contents.splitlines(), self._new_contents.splitlines()))
- description = await models.gpt3516k.complete(dedent(f"""\
+ description = await models.medium.complete(dedent(f"""\
Diff summary: "{self.user_input}"
```diff
@@ -193,7 +194,7 @@ class DefaultModelEditCodeStep(Step):
```
Please give brief a description of the changes made above using markdown bullet points. Be concise:"""))
- name = await models.gpt3516k.complete(f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. This is the title:")
+ name = await models.medium.complete(f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. This is the title:")
self.name = remove_quotes_and_escapes(name)
return f"{remove_quotes_and_escapes(description)}"
@@ -203,8 +204,7 @@ class DefaultModelEditCodeStep(Step):
# We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion.
        # Overflow won't happen, but prune_chat_history in count_tokens.py would cut out this whole thing instead of us cutting out only as many lines as we need.
model_to_use = sdk.models.default
- max_tokens = int(MAX_TOKENS_FOR_MODEL.get(
- model_to_use.name, DEFAULT_MAX_TOKENS) / 2)
+ max_tokens = int(model_to_use.context_length / 2)
TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200
if model_to_use.count_tokens(rif.contents) > TOKENS_TO_BE_CONSIDERED_LARGE_RANGE:
@@ -222,8 +222,9 @@ class DefaultModelEditCodeStep(Step):
# If using 3.5 and overflows, upgrade to 3.5.16k
if model_to_use.name == "gpt-3.5-turbo":
- if total_tokens > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]:
- model_to_use = sdk.models.gpt3516k
+ if total_tokens > model_to_use.context_length:
+ model_to_use = MaybeProxyOpenAI(model="gpt-3.5-turbo-0613")
+ await sdk.start_model(model_to_use)
# Remove tokens from the end first, and then the start to clear space
# This part finds the start and end lines
@@ -233,20 +234,20 @@ class DefaultModelEditCodeStep(Step):
cur_start_line = 0
cur_end_line = len(full_file_contents_lst) - 1
- if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+ if total_tokens > model_to_use.context_length:
while cur_end_line > min_end_line:
total_tokens -= model_to_use.count_tokens(
full_file_contents_lst[cur_end_line])
cur_end_line -= 1
- if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+ if total_tokens < model_to_use.context_length:
break
- if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+ if total_tokens > model_to_use.context_length:
while cur_start_line < max_start_line:
cur_start_line += 1
total_tokens -= model_to_use.count_tokens(
full_file_contents_lst[cur_start_line])
- if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+ if total_tokens < model_to_use.context_length:
break
# Now use the found start/end lines to get the prefix and suffix strings
@@ -525,7 +526,7 @@ Please output the code to be inserted at the cursor in order to fulfill the user
# Accumulate lines
if "content" not in chunk:
- continue
+ continue # ayo
chunk = chunk["content"]
chunk_lines = chunk.split("\n")
chunk_lines[0] = unfinished_line + chunk_lines[0]
@@ -546,12 +547,12 @@ Please output the code to be inserted at the cursor in order to fulfill the user
break
# Lines that should be ignored, like the <> tags
elif self.line_to_be_ignored(chunk_lines[i], completion_lines_covered == 0):
- continue
+ continue # noice
# Check if we are currently just copying the prefix
elif (lines_of_prefix_copied > 0 or completion_lines_covered == 0) and lines_of_prefix_copied < len(file_prefix.splitlines()) and chunk_lines[i] == full_file_contents_lines[lines_of_prefix_copied]:
                    # This is a sketchy way of stopping it from repeating the file_prefix; it breaks if the output happens to contain a matching line
lines_of_prefix_copied += 1
- continue
+ continue # also nice
# Because really short lines might be expected to be repeated, this is only a !heuristic!
# Stop when it starts copying the file_suffix
elif chunk_lines[i].strip() == line_below_highlighted_range.strip() and len(chunk_lines[i].strip()) > 4 and not (len(original_lines_below_previous_blocks) > 0 and chunk_lines[i].strip() == original_lines_below_previous_blocks[0].strip()):
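
`DefaultModelEditCodeStep` now asks the model itself for its `context_length` instead of consulting `MAX_TOKENS_FOR_MODEL`. A standalone sketch of the trimming strategy the hunk updates (drop file lines from the end, then from the start, until the prompt fits), with `tiktoken` standing in for `model_to_use.count_tokens`; the file contents and line bounds are invented:

```python
import tiktoken

encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")

def count_tokens(s: str) -> int:
    return len(encoding.encode(s, disallowed_special=()))

context_length = 4096                         # supplied by the LLM object in the patch
lines = [f"print({i})" for i in range(4000)]  # stand-in for full_file_contents_lst
max_start_line, min_end_line = 1500, 2500     # bounds around the highlighted range

total_tokens = sum(count_tokens(line) for line in lines)
cur_start_line, cur_end_line = 0, len(lines) - 1

# Remove tokens from the end first...
while total_tokens > context_length and cur_end_line > min_end_line:
    total_tokens -= count_tokens(lines[cur_end_line])
    cur_end_line -= 1

# ...and then from the start, never crossing into the highlighted range.
while total_tokens > context_length and cur_start_line < max_start_line:
    cur_start_line += 1
    total_tokens -= count_tokens(lines[cur_start_line])

print(cur_start_line, cur_end_line, total_tokens)  # the window that fits the budget
```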
diff --git a/continuedev/src/continuedev/plugins/steps/draft/migration.py b/continuedev/src/continuedev/plugins/steps/draft/migration.py
index a76d491b..c38f54dc 100644
--- a/continuedev/src/continuedev/plugins/steps/draft/migration.py
+++ b/continuedev/src/continuedev/plugins/steps/draft/migration.py
@@ -13,7 +13,7 @@ class MigrationStep(Step):
recent_edits = await sdk.ide.get_recent_edits(self.edited_file)
recent_edits_string = "\n\n".join(
map(lambda x: x.to_string(), recent_edits))
- description = await sdk.models.gpt35.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n")
+ description = await sdk.models.medium.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n")
await sdk.run([
"cd libs",
"poetry run alembic revision --autogenerate -m " + description,
diff --git a/continuedev/src/continuedev/plugins/steps/help.py b/continuedev/src/continuedev/plugins/steps/help.py
index 6997a547..ec670999 100644
--- a/continuedev/src/continuedev/plugins/steps/help.py
+++ b/continuedev/src/continuedev/plugins/steps/help.py
@@ -56,7 +56,7 @@ class HelpStep(Step):
summary="Help"
))
messages = await sdk.get_chat_context()
- generator = sdk.models.gpt4.stream_chat(messages)
+ generator = sdk.models.default.stream_chat(messages)
async for chunk in generator:
if "content" in chunk:
self.description += chunk["content"]
diff --git a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py
index b54d394a..3d8d96fb 100644
--- a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py
+++ b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py
@@ -23,6 +23,6 @@ class NLMultiselectStep(Step):
if first_try is not None:
return first_try
- gpt_parsed = await sdk.models.gpt35.complete(
+ gpt_parsed = await sdk.models.default.complete(
f"These are the available options are: [{', '.join(self.options)}]. The user requested {user_response}. This is the exact string from the options array that they selected:")
return extract_option(gpt_parsed) or self.options[0]
diff --git a/continuedev/src/continuedev/plugins/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py
index f28d9660..d2d6f4dd 100644
--- a/continuedev/src/continuedev/plugins/steps/main.py
+++ b/continuedev/src/continuedev/plugins/steps/main.py
@@ -101,7 +101,7 @@ class FasterEditHighlightedCodeStep(Step):
for rif in range_in_files:
rif_dict[rif.filepath] = rif.contents
- completion = await sdk.models.gpt35.complete(prompt)
+ completion = await sdk.models.medium.complete(prompt)
# Temporarily doing this to generate description.
self._prompt = prompt
@@ -169,7 +169,7 @@ class StarCoderEditHighlightedCodeStep(Step):
_prompt_and_completion: str = ""
async def describe(self, models: Models) -> Coroutine[str, None, None]:
- return await models.gpt35.complete(f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:")
+ return await models.medium.complete(f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:")
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
range_in_files = await sdk.get_code_context(only_editing=True)
diff --git a/continuedev/src/continuedev/plugins/steps/react.py b/continuedev/src/continuedev/plugins/steps/react.py
index 8b2e7c2e..da6acdbf 100644
--- a/continuedev/src/continuedev/plugins/steps/react.py
+++ b/continuedev/src/continuedev/plugins/steps/react.py
@@ -27,7 +27,7 @@ class NLDecisionStep(Step):
Select the step which should be taken next to satisfy the user input. Say only the name of the selected step. You must choose one:""")
- resp = (await sdk.models.gpt35.complete(prompt)).lower()
+ resp = (await sdk.models.medium.complete(prompt)).lower()
step_to_run = None
for step in self.steps:
diff --git a/continuedev/src/continuedev/plugins/steps/search_directory.py b/continuedev/src/continuedev/plugins/steps/search_directory.py
index 07b50473..456dba84 100644
--- a/continuedev/src/continuedev/plugins/steps/search_directory.py
+++ b/continuedev/src/continuedev/plugins/steps/search_directory.py
@@ -20,10 +20,10 @@ def find_all_matches_in_dir(pattern: str, dirpath: str) -> List[RangeInFile]:
for root, dirs, files in os.walk(dirpath):
dirname = os.path.basename(root)
if dirname.startswith(".") or dirname in IGNORE_DIRS:
- continue
+ continue # continue!
for file in files:
if file in IGNORE_FILES:
- continue
+ continue # pun intended
with open(os.path.join(root, file), "r") as f:
                # Find the indices of all occurrences of the pattern in the file, using re.
file_content = f.read()
@@ -42,7 +42,7 @@ class WriteRegexPatternStep(Step):
async def run(self, sdk: ContinueSDK):
# Ask the user for a regex pattern
- pattern = await sdk.models.gpt35.complete(dedent(f"""\
+ pattern = await sdk.models.medium.complete(dedent(f"""\
This is the user request:
{self.user_request}
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 98a5aea0..cf18c56b 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -176,7 +176,7 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we
message = json.loads(message)
if "messageType" not in message or "data" not in message:
- continue
+ continue # :o
message_type = message["messageType"]
data = message["data"]
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index e4c07029..6124f3bd 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -139,7 +139,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
msg_string = await self.websocket.receive_text()
message = json.loads(msg_string)
if "messageType" not in message or "data" not in message:
- continue
+ continue # <-- hey that's the name of this repo!
message_type = message["messageType"]
data = message["data"]
logger.debug(f"Received message while initializing {message_type}")
@@ -311,7 +311,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
def onFileEdits(self, edits: List[FileEditWithFullContents]):
if autopilot := self.__get_autopilot():
- autopilot.handle_manual_edits(edits)
+ pass
def onDeleteAtIndex(self, index: int):
if autopilot := self.__get_autopilot():
diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py
index cf46028f..b5580fe8 100644
--- a/continuedev/src/continuedev/server/session_manager.py
+++ b/continuedev/src/continuedev/server/session_manager.py
@@ -1,4 +1,5 @@
import os
+import traceback
from fastapi import WebSocket
from typing import Any, Dict, List, Union
from uuid import uuid4
@@ -6,12 +7,10 @@ import json
from fastapi.websockets import WebSocketState
-from ..plugins.steps.core.core import DisplayErrorStep
+from ..plugins.steps.core.core import DisplayErrorStep, MessageStep
from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath
from ..models.filesystem_edit import FileEditWithFullContents
-from ..libs.constants.main import CONTINUE_SESSIONS_FOLDER
-from ..core.policy import DefaultPolicy
-from ..core.main import FullState
+from ..core.main import FullState, HistoryNode
from ..core.autopilot import Autopilot
from .ide_protocol import AbstractIdeProtocolServer
from ..libs.util.create_async_task import create_async_task
@@ -31,19 +30,6 @@ class Session:
self.ws = None
-class DemoAutopilot(Autopilot):
- first_seen: bool = False
- cumulative_edit_string = ""
-
- def handle_manual_edits(self, edits: List[FileEditWithFullContents]):
- return
- for edit in edits:
- self.cumulative_edit_string += edit.fileEdit.replacement
- self._manual_edits_buffer.append(edit)
- # Note that you're storing a lot of unecessary data here. Can compress into EditDiffs on the spot, and merge.
- # self._manual_edits_buffer = merge_file_edit(self._manual_edits_buffer, edit)
-
-
class SessionManager:
sessions: Dict[str, Session] = {}
# Mapping of session_id to IDE, where the IDE is still alive
@@ -65,27 +51,47 @@ class SessionManager:
async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session:
logger.debug(f"New session: {session_id}")
+ # Load the persisted state (not being used right now)
full_state = None
if session_id is not None and os.path.exists(getSessionFilePath(session_id)):
with open(getSessionFilePath(session_id), "r") as f:
full_state = FullState(**json.load(f))
- autopilot = await DemoAutopilot.create(
- policy=DefaultPolicy(), ide=ide, full_state=full_state)
+ # Register the session and ide (do this first so that the autopilot can access the session)
+ autopilot = Autopilot(ide=ide)
session_id = session_id or str(uuid4())
ide.session_id = session_id
session = Session(session_id=session_id, autopilot=autopilot)
self.sessions[session_id] = session
self.registered_ides[session_id] = ide
+ # Set up the autopilot to update the GUI
async def on_update(state: FullState):
await session_manager.send_ws_data(session_id, "state_update", {
"state": state.dict()
})
autopilot.on_update(on_update)
- create_async_task(autopilot.run_policy(
- ), lambda e: autopilot.continue_sdk.run_step(DisplayErrorStep(e=e)))
+
+        # Start the autopilot (must happen after the session is added to sessions); the policy is started below
+ try:
+ await autopilot.start()
+ except Exception as e:
+ # Have to manually add to history because autopilot isn't started
+            formatted_err = '\n'.join(traceback.format_exception(type(e), e, e.__traceback__))
+ msg_step = MessageStep(
+ name="Error loading context manager", message=formatted_err)
+ msg_step.description = f"```\n{formatted_err}\n```"
+ autopilot.history.add_node(HistoryNode(
+ step=msg_step,
+ observation=None,
+ depth=0,
+ active=False
+ ))
+ logger.warning(f"Error loading context manager: {e}")
+
+ create_async_task(autopilot.run_policy(), lambda e: autopilot.continue_sdk.run_step(
+ DisplayErrorStep(e=e)))
return session
async def remove_session(self, session_id: str):