Diffstat (limited to 'continuedev'): 43 files changed, 675 insertions, 496 deletions
diff --git a/continuedev/README.md b/continuedev/README.md index d3ead8ec..6a11ae43 100644 --- a/continuedev/README.md +++ b/continuedev/README.md @@ -67,9 +67,9 @@ cd continue/extension/scripts && python3 install_from_source.py # Understanding the codebase -- [Continue Server README](./continuedev/README.md): learn about the core of Continue, which can be downloaded as a [PyPI package](https://pypi.org/project/continuedev/) -- [VS Code Extension README](./extension/README.md): learn about the capabilities of our extension—the first implementation of Continue's IDE Protocol—which makes it possible to use use Continue in VS Code and GitHub Codespaces -- [Continue GUI README](./extension/react-app/): learn about the React app that lets users interact with the server and is placed adjacent to the text editor in any suppported IDE -- [Schema README](./schema): learn about the JSON Schema types generated from Pydantic models, which we use across the `continuedev/` and `extension/` directories -- [Continue Docs README](./docs): learn how our [docs](https://continue.dev/docs) are written and built -- [How to debug the VS Code Extension README](./extension/src/README.md): learn how to set up the VS Code extension, so you can debug it +- [Continue Server README](./README.md): learn about the core of Continue, which can be downloaded as a [PyPI package](https://pypi.org/project/continuedev/) +- [VS Code Extension README](../extension/README.md): learn about the capabilities of our extension—the first implementation of Continue's IDE Protocol—which makes it possible to use use Continue in VS Code and GitHub Codespaces +- [Continue GUI README](../extension/react-app/): learn about the React app that lets users interact with the server and is placed adjacent to the text editor in any suppported IDE +- [Schema README](../schema/README.md): learn about the JSON Schema types generated from Pydantic models, which we use across the `continuedev/` and `extension/` directories +- [Continue Docs README](../docs/README.md): learn how our [docs](https://continue.dev/docs) are written and built +- [How to debug the VS Code Extension README](../extension/src/README.md): learn how to set up the VS Code extension, so you can debug it diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py index 94d7be10..e048f877 100644 --- a/continuedev/src/continuedev/core/abstract_sdk.py +++ b/continuedev/src/continuedev/core/abstract_sdk.py @@ -73,7 +73,7 @@ class AbstractContinueSDK(ABC): pass @abstractmethod - async def get_user_secret(self, env_var: str, prompt: str) -> str: + async def get_user_secret(self, env_var: str) -> str: pass config: ContinueConfig diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 57e39d5c..de95a259 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -9,6 +9,7 @@ from ..models.filesystem import RangeInFileWithContents from ..models.filesystem_edit import FileEditWithFullContents from .observation import Observation, InternalErrorObservation from .context import ContextManager +from ..plugins.policies.default import DefaultPolicy from ..plugins.context_providers.file import FileContextProvider from ..plugins.context_providers.highlighted_code import HighlightedCodeContextProvider from ..server.ide_protocol import AbstractIdeProtocolServer @@ -47,8 +48,9 @@ def get_error_title(e: Exception) -> str: class 
Autopilot(ContinueBaseModel): - policy: Policy ide: AbstractIdeProtocolServer + + policy: Policy = DefaultPolicy() history: History = History.from_empty() context: Context = Context() full_state: Union[FullState, None] = None @@ -64,20 +66,19 @@ class Autopilot(ContinueBaseModel): _user_input_queue = AsyncSubscriptionQueue() _retry_queue = AsyncSubscriptionQueue() - @classmethod - async def create(cls, policy: Policy, ide: AbstractIdeProtocolServer, full_state: FullState) -> "Autopilot": - autopilot = cls(ide=ide, policy=policy) - autopilot.continue_sdk = await ContinueSDK.create(autopilot) + async def start(self): + self.continue_sdk = await ContinueSDK.create(self) + if override_policy := self.continue_sdk.config.policy_override: + self.policy = override_policy # Load documents into the search index - autopilot.context_manager = await ContextManager.create( - autopilot.continue_sdk.config.context_providers + [ - HighlightedCodeContextProvider(ide=ide), - FileContextProvider(workspace_dir=ide.workspace_directory) + self.context_manager = await ContextManager.create( + self.continue_sdk.config.context_providers + [ + HighlightedCodeContextProvider(ide=self.ide), + FileContextProvider(workspace_dir=self.ide.workspace_directory) ]) - await autopilot.context_manager.load_index(ide.workspace_directory) - return autopilot + await self.context_manager.load_index(self.ide.workspace_directory) class Config: arbitrary_types_allowed = True @@ -95,7 +96,6 @@ class Autopilot(ContinueBaseModel): history=self.history, active=self._active, user_input_queue=self._main_user_input_queue, - default_model=self.continue_sdk.config.default_model, slash_commands=self.get_available_slash_commands(), adding_highlighted_code=self.context_manager.context_providers[ "code"].adding_highlighted_code, diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index 9fbda824..84b6b10b 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -2,9 +2,13 @@ import json import os from .main import Step from .context import ContextProvider +from ..libs.llm.maybe_proxy_openai import MaybeProxyOpenAI +from .models import Models from pydantic import BaseModel, validator -from typing import List, Literal, Optional, Dict, Type, Union -import yaml +from typing import List, Literal, Optional, Dict, Type + +from .main import Policy, Step +from .context import ContextProvider class SlashCommand(BaseModel): @@ -25,13 +29,6 @@ class OnTracebackSteps(BaseModel): params: Optional[Dict] = {} -class OpenAIServerInfo(BaseModel): - api_base: Optional[str] = None - engine: Optional[str] = None - api_version: Optional[str] = None - api_type: Literal["azure", "openai"] = "openai" - - class ContinueConfig(BaseModel): """ A pydantic class for the continue config file. 
@@ -39,8 +36,9 @@ class ContinueConfig(BaseModel): steps_on_startup: List[Step] = [] disallowed_steps: Optional[List[str]] = [] allow_anonymous_telemetry: Optional[bool] = True - default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k", - "gpt-4", "claude-2", "ggml"] = 'gpt-4' + models: Models = Models( + default=MaybeProxyOpenAI(model="gpt-4"), + ) temperature: Optional[float] = 0.5 custom_commands: Optional[List[CustomCommand]] = [CustomCommand( name="test", @@ -50,7 +48,7 @@ class ContinueConfig(BaseModel): slash_commands: Optional[List[SlashCommand]] = [] on_traceback: Optional[List[OnTracebackSteps]] = [] system_message: Optional[str] = None - openai_server_info: Optional[OpenAIServerInfo] = None + policy_override: Optional[Policy] = None context_providers: List[ContextProvider] = [] diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py index 86522ce1..e968c35c 100644 --- a/continuedev/src/continuedev/core/context.py +++ b/continuedev/src/continuedev/core/context.py @@ -178,12 +178,6 @@ class ContextManager: except Exception as e: logger.debug(f"Error loading meilisearch index: {e}") - # def compile_chat_messages(self, max_tokens: int) -> List[Dict]: - # """ - # Compiles the chat prompt into a single string. - # """ - # return compile_chat_messages(self.model, self.chat_history, max_tokens, self.prompt, self.functions, self.system_message) - async def select_context_item(self, id: str, query: str): """ Selects the ContextItem with the given id. diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index df9b98ef..2553850f 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -258,7 +258,6 @@ class FullState(ContinueBaseModel): history: History active: bool user_input_queue: List[str] - default_model: str slash_commands: List[SlashCommandDescription] adding_highlighted_code: bool selected_context_items: List[ContextItem] diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py new file mode 100644 index 00000000..900762b6 --- /dev/null +++ b/continuedev/src/continuedev/core/models.py @@ -0,0 +1,65 @@ +from typing import Optional, Any +from pydantic import BaseModel, validator +from ..libs.llm import LLM + + +class Models(BaseModel): + """Main class that holds the current model configuration""" + default: LLM + small: Optional[LLM] = None + medium: Optional[LLM] = None + large: Optional[LLM] = None + + # TODO namespace these away to not confuse readers, + # or split Models into ModelsConfig, which gets turned into Models + sdk: "ContinueSDK" = None + system_message: Any = None + + """ + Better to have sdk.llm.stream_chat(messages, model="claude-2"). + Then you also don't care that it' async. + And it's easier to add more models. + And intermediate shared code is easier to add. + And you can make constants like ContinueModels.GPT35 = "gpt-3.5-turbo" + PromptTransformer would be a good concept: You pass a prompt or list of messages and a model, then it outputs the prompt for that model. + Easy to reason about, can place anywhere. + And you can even pass a Prompt object to sdk.llm.stream_chat maybe, and it'll automatically be transformed for the given model. + This can all happen inside of Models? 
+ + class Prompt: + def __init__(self, ...info): + '''take whatever info is needed to describe the prompt''' + + def to_string(self, model: str) -> str: + '''depending on the model, return the single prompt string''' + """ + + async def start(self, sdk: "ContinueSDK"): + """Start each of the LLMs, or fall back to default""" + self.sdk = sdk + self.system_message = self.sdk.config.system_message + await sdk.start_model(self.default) + if self.small: + await sdk.start_model(self.small) + else: + self.small = self.default + + if self.medium: + await sdk.start_model(self.medium) + else: + self.medium = self.default + + if self.large: + await sdk.start_model(self.large) + else: + self.large = self.default + + async def stop(self, sdk: "ContinueSDK"): + """Stop each LLM (if it's not the default, which is shared)""" + await self.default.stop() + if self.small is not self.default: + await self.small.stop() + if self.medium is not self.default: + await self.medium.stop() + if self.large is not self.default: + await self.large.stop() diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 4b76a121..bf22d696 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -9,17 +9,14 @@ from .abstract_sdk import AbstractContinueSDK from .config import ContinueConfig from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFile, AddDirectory, DeleteDirectory from ..models.filesystem import RangeInFile -from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI -from ..libs.llm.openai import OpenAI -from ..libs.llm.anthropic import AnthropicLLM -from ..libs.llm.ggml import GGML +from ..libs.llm import LLM from .observation import Observation from ..server.ide_protocol import AbstractIdeProtocolServer from .main import Context, ContinueCustomException, History, HistoryNode, Step, ChatMessage from ..plugins.steps.core.core import * -from ..libs.llm.proxy_server import ProxyServer from ..libs.util.telemetry import posthog_logger from ..libs.util.paths import getConfigFilePath +from .models import Models from ..libs.util.logging import logger @@ -27,121 +24,6 @@ class Autopilot: pass -ModelProvider = Literal["openai", "hf_inference_api", "ggml", "anthropic"] -MODEL_PROVIDER_TO_ENV_VAR = { - "openai": "OPENAI_API_KEY", - "hf_inference_api": "HUGGING_FACE_TOKEN", - "anthropic": "ANTHROPIC_API_KEY", -} - - -class Models: - provider_keys: Dict[ModelProvider, str] = {} - model_providers: List[ModelProvider] - system_message: str - - """ - Better to have sdk.llm.stream_chat(messages, model="claude-2"). - Then you also don't care that it' async. - And it's easier to add more models. - And intermediate shared code is easier to add. - And you can make constants like ContinueModels.GPT35 = "gpt-3.5-turbo" - PromptTransformer would be a good concept: You pass a prompt or list of messages and a model, then it outputs the prompt for that model. - Easy to reason about, can place anywhere. - And you can even pass a Prompt object to sdk.llm.stream_chat maybe, and it'll automatically be transformed for the given model. - This can all happen inside of Models? 
- - class Prompt: - def __init__(self, ...info): - '''take whatever info is needed to describe the prompt''' - - def to_string(self, model: str) -> str: - '''depending on the model, return the single prompt string''' - """ - - def __init__(self, sdk: "ContinueSDK", model_providers: List[ModelProvider]): - self.sdk = sdk - self.model_providers = model_providers - self.system_message = sdk.config.system_message - - @classmethod - async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models": - if sdk.config.default_model == "claude-2": - with_providers.append("anthropic") - - models = Models(sdk, with_providers) - for provider in with_providers: - if provider in MODEL_PROVIDER_TO_ENV_VAR: - env_var = MODEL_PROVIDER_TO_ENV_VAR[provider] - models.provider_keys[provider] = await sdk.get_user_secret( - env_var, f'Please add your {env_var} to the .env file') - - return models - - def __load_openai_model(self, model: str) -> OpenAI: - api_key = self.provider_keys["openai"] - if api_key == "": - return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message, write_log=self.sdk.write_log) - return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, openai_server_info=self.sdk.config.openai_server_info, write_log=self.sdk.write_log) - - def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI: - api_key = self.provider_keys["hf_inference_api"] - return HuggingFaceInferenceAPI(api_key=api_key, model=model, system_message=self.system_message) - - def __load_anthropic_model(self, model: str) -> AnthropicLLM: - api_key = self.provider_keys["anthropic"] - return AnthropicLLM(api_key, model, self.system_message) - - @cached_property - def claude2(self): - return self.__load_anthropic_model("claude-2") - - @cached_property - def starcoder(self): - return self.__load_hf_inference_api_model("bigcode/starcoder") - - @cached_property - def gpt35(self): - return self.__load_openai_model("gpt-3.5-turbo") - - @cached_property - def gpt350613(self): - return self.__load_openai_model("gpt-3.5-turbo-0613") - - @cached_property - def gpt3516k(self): - return self.__load_openai_model("gpt-3.5-turbo-16k") - - @cached_property - def gpt4(self): - return self.__load_openai_model("gpt-4") - - @cached_property - def ggml(self): - return GGML(system_message=self.system_message) - - def __model_from_name(self, model_name: str): - if model_name == "starcoder": - return self.starcoder - elif model_name == "gpt-3.5-turbo": - return self.gpt35 - elif model_name == "gpt-3.5-turbo-16k": - return self.gpt3516k - elif model_name == "gpt-4": - return self.gpt4 - elif model_name == "claude-2": - return self.claude2 - elif model_name == "ggml": - return self.ggml - else: - raise Exception(f"Unknown model {model_name}") - - @property - def default(self): - default_model = self.sdk.config.default_model - return self.__model_from_name(default_model) if default_model is not None else self.gpt4 - - class ContinueSDK(AbstractContinueSDK): """The SDK provided as parameters to a step""" ide: AbstractIdeProtocolServer @@ -171,7 +53,7 @@ class ContinueSDK(AbstractContinueSDK): formatted_err = '\n'.join(traceback.format_exception(e)) msg_step = MessageStep( name="Invalid Continue Config File", message=formatted_err) - msg_step.description = f"Falling back to default config settings.\n```\n{formatted_err}\n```" + msg_step.description = f"Falling back to default config settings.\n```\n{formatted_err}\n```\n\nIt's possible this error was 
caused by an update to the Continue config format. If you'd like to see the new recommended default `config.py`, check [here](https://github.com/continuedev/continue/blob/main/continuedev/src/continuedev/libs/constants/default_config.py.txt)." sdk.history.add_node(HistoryNode( step=msg_step, observation=None, @@ -179,11 +61,13 @@ class ContinueSDK(AbstractContinueSDK): active=False )) + sdk.models = sdk.config.models + await sdk.models.start(sdk) + # When the config is loaded, setup posthog logger posthog_logger.setup( sdk.ide.unique_id, sdk.config.allow_anonymous_telemetry) - sdk.models = await Models.create(sdk) return sdk @property @@ -193,6 +77,16 @@ class ContinueSDK(AbstractContinueSDK): def write_log(self, message: str): self.history.timeline[self.history.current_index].logs.append(message) + async def start_model(self, llm: LLM): + kwargs = {} + if llm.requires_api_key: + kwargs["api_key"] = await self.get_user_secret(llm.requires_api_key) + if llm.requires_unique_id: + kwargs["unique_id"] = self.ide.unique_id + if llm.requires_write_log: + kwargs["write_log"] = self.write_log + await llm.start(**kwargs) + async def _ensure_absolute_path(self, path: str) -> str: if os.path.isabs(path): return path @@ -262,7 +156,8 @@ class ContinueSDK(AbstractContinueSDK): path = await self._ensure_absolute_path(path) return await self.run_step(FileSystemEditStep(edit=DeleteDirectory(path=path))) - async def get_user_secret(self, env_var: str, prompt: str) -> str: + async def get_user_secret(self, env_var: str) -> str: + # TODO support error prompt dynamically set on env_var return await self.ide.getUserSecret(env_var) _last_valid_config: ContinueConfig = None diff --git a/continuedev/src/continuedev/libs/chroma/query.py b/continuedev/src/continuedev/libs/chroma/query.py index f09b813a..dba4874f 100644 --- a/continuedev/src/continuedev/libs/chroma/query.py +++ b/continuedev/src/continuedev/libs/chroma/query.py @@ -59,7 +59,7 @@ class ChromaIndexManager: except: logger.warning( f"ERROR (probably found special token): {doc.text}") - continue + continue # lol filename = doc.extra_info["filename"] chunks[filename] = len(text_chunks) for i, text in enumerate(text_chunks): diff --git a/continuedev/src/continuedev/libs/chroma/update.py b/continuedev/src/continuedev/libs/chroma/update.py index 23ed950f..d5326a06 100644 --- a/continuedev/src/continuedev/libs/chroma/update.py +++ b/continuedev/src/continuedev/libs/chroma/update.py @@ -23,7 +23,7 @@ def filter_ignored_files(files: List[str], root_dir: str): """Further filter files before indexing.""" for file in files: if file.endswith(tuple(FILE_TYPES_TO_IGNORE)) or file.startswith('.git') or file.startswith('archive'): - continue + continue # nice yield root_dir + "/" + file diff --git a/continuedev/src/continuedev/libs/constants/default_config.py.txt b/continuedev/src/continuedev/libs/constants/default_config.py.txt index 1a66c847..cf8b0324 100644 --- a/continuedev/src/continuedev/libs/constants/default_config.py.txt +++ b/continuedev/src/continuedev/libs/constants/default_config.py.txt @@ -9,9 +9,12 @@ import subprocess from continuedev.core.main import Step from continuedev.core.sdk import ContinueSDK +from continuedev.core.models import Models from continuedev.core.config import CustomCommand, SlashCommand, ContinueConfig from continuedev.plugins.context_providers.github import GitHubIssuesContextProvider from continuedev.plugins.context_providers.google import GoogleContextProvider +from continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI 
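The `start_model` hook added to `ContinueSDK` above inspects `requires_api_key`, `requires_unique_id`, and `requires_write_log` on each `LLM` before calling its `start()`. A rough sketch of a minimal custom `LLM` subclass under that contract — `EchoLLM` and its behavior are invented for illustration and are not part of this change:

```python
from typing import List

from continuedev.core.main import ChatMessage
from continuedev.libs.llm import LLM


class EchoLLM(LLM):
    """Toy model that echoes the prompt back. It leaves the base-class defaults
    (no API key, no unique id, no write_log), so ContinueSDK.start_model()
    calls start() with no extra kwargs."""

    @property
    def name(self):
        return "echo"

    @property
    def context_length(self) -> int:
        return 2048

    async def start(self, **kwargs):
        pass  # open clients / sessions here

    async def stop(self):
        pass  # and close them here

    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> str:
        return prompt  # a real subclass would call its backend here

    def count_tokens(self, text: str):
        return len(text.split())
```

Such a class could then be referenced from the config like any other entry, e.g. `Models(default=MaybeProxyOpenAI(model="gpt-4"), small=EchoLLM())`.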
+from continuedev.plugins.policies.default import DefaultPolicy from continuedev.plugins.steps.open_config import OpenConfigStep from continuedev.plugins.steps.clear_history import ClearHistoryStep @@ -35,9 +38,9 @@ class CommitMessageStep(Step): diff = subprocess.check_output( ["git", "diff"], cwd=dir).decode("utf-8") - # Ask gpt-3.5-16k to write a commit message, + # Ask the LLM to write a commit message, # and set it as the description of this step - self.description = await sdk.models.gpt3516k.complete( + self.description = await sdk.models.default.complete( f"{diff}\n\nWrite a short, specific (less than 50 chars) commit message about the above changes:") @@ -47,9 +50,10 @@ config = ContinueConfig( # See here to learn what anonymous data we collect: https://continue.dev/docs/telemetry allow_anonymous_telemetry=True, - # GPT-4 is recommended for best results - # See options here: https://continue.dev/docs/customization#change-the-default-llm - default_model="gpt-4", + models=Models( + default=MaybeProxyOpenAI(model="gpt-4"), + medium=MaybeProxyOpenAI(model="gpt-3.5-turbo") + ), # Set a system message with information that the LLM should always keep in mind # E.g. "Please give concise answers. Always respond in Spanish." @@ -114,5 +118,9 @@ config = ContinueConfig( # GoogleContextProvider( # serper_api_key="<your serper.dev api key>" # ) - ] + ], + + # Policies hold the main logic that decides which Step to take next + # You can use them to design agents, or deeply customize Continue + policy=DefaultPolicy() ) diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py index 2766db4b..50577993 100644 --- a/continuedev/src/continuedev/libs/llm/__init__.py +++ b/continuedev/src/continuedev/libs/llm/__init__.py @@ -1,14 +1,30 @@ -from abc import ABC -from typing import Any, Coroutine, Dict, Generator, List, Union +from abc import ABC, abstractproperty +from typing import Any, Coroutine, Dict, Generator, List, Union, Optional from ...core.main import ChatMessage -from ...models.main import AbstractModel -from pydantic import BaseModel +from ...models.main import ContinueBaseModel -class LLM(ABC): +class LLM(ContinueBaseModel, ABC): + requires_api_key: Optional[str] = None + requires_unique_id: bool = False + requires_write_log: bool = False + system_message: Union[str, None] = None + @abstractproperty + def name(self): + """Return the name of the LLM.""" + raise NotImplementedError + + async def start(self, *, api_key: Optional[str] = None, **kwargs): + """Start the connection to the LLM.""" + raise NotImplementedError + + async def stop(self): + """Stop the connection to the LLM.""" + raise NotImplementedError + async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: """Return the completion of the text with the given temperature.""" raise NotImplementedError @@ -24,3 +40,8 @@ class LLM(ABC): def count_tokens(self, text: str): """Return the number of tokens in the given text.""" raise NotImplementedError + + @abstractproperty + def context_length(self) -> int: + """Return the context length of the LLM in tokens, as counted by count_tokens.""" + raise NotImplementedError diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py index 625d4e57..ec1b7e40 100644 --- a/continuedev/src/continuedev/libs/llm/anthropic.py +++ b/continuedev/src/continuedev/libs/llm/anthropic.py @@ -1,32 +1,39 @@ from functools import cached_property 
import time -from typing import Any, Coroutine, Dict, Generator, List, Union +from typing import Any, Coroutine, Dict, Generator, List, Optional, Union from ...core.main import ChatMessage from anthropic import HUMAN_PROMPT, AI_PROMPT, AsyncAnthropic from ..llm import LLM -from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top +from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens class AnthropicLLM(LLM): - api_key: str - default_model: str - async_client: AsyncAnthropic + model: str = "claude-2" - def __init__(self, api_key: str, default_model: str, system_message: str = None): - self.api_key = api_key - self.default_model = default_model + requires_api_key: str = "ANTHROPIC_API_KEY" + _async_client: AsyncAnthropic = None + + class Config: + arbitrary_types_allowed = True + + def __init__(self, model: str, system_message: str = None): + self.model = model self.system_message = system_message - self.async_client = AsyncAnthropic(api_key=api_key) + async def start(self, *, api_key: Optional[str] = None, **kwargs): + self._async_client = AsyncAnthropic(api_key=api_key) + + async def stop(self): + pass @cached_property def name(self): - return self.default_model + return self.model @property def default_args(self): - return {**DEFAULT_ARGS, "model": self.default_model} + return {**DEFAULT_ARGS, "model": self.model} def _transform_args(self, args: Dict[str, Any]) -> Dict[str, Any]: args = args.copy() @@ -40,7 +47,13 @@ class AnthropicLLM(LLM): return args def count_tokens(self, text: str): - return count_tokens(self.default_model, text) + return count_tokens(self.model, text) + + @property + def context_length(self): + if self.model == "claude-2": + return 100000 + raise Exception(f"Unknown Anthropic model {self.model}") def __messages_to_prompt(self, messages: List[Dict[str, str]]) -> str: prompt = "" @@ -60,7 +73,7 @@ class AnthropicLLM(LLM): args["stream"] = True args = self._transform_args(args) - async for chunk in await self.async_client.completions.create( + async for chunk in await self._async_client.completions.create( prompt=f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}", **args ): @@ -73,8 +86,8 @@ class AnthropicLLM(LLM): args = self._transform_args(args) messages = compile_chat_messages( - args["model"], messages, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message) - async for chunk in await self.async_client.completions.create( + args["model"], messages, self.context_length, self.context_length, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message) + async for chunk in await self._async_client.completions.create( prompt=self.__messages_to_prompt(messages), **args ): @@ -88,8 +101,8 @@ class AnthropicLLM(LLM): args = self._transform_args(args) messages = compile_chat_messages( - args["model"], with_history, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message) - resp = (await self.async_client.completions.create( + args["model"], with_history, self.context_length, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message) + resp = (await self._async_client.completions.create( prompt=self.__messages_to_prompt(messages), **args )).completion diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py index 4889a556..7742e8c3 100644 --- 
a/continuedev/src/continuedev/libs/llm/ggml.py +++ b/continuedev/src/continuedev/libs/llm/ggml.py @@ -1,6 +1,7 @@ from functools import cached_property import json from typing import Any, Coroutine, Dict, Generator, List, Union +from pydantic import ConfigDict import aiohttp from ...core.main import ChatMessage @@ -11,15 +12,29 @@ SERVER_URL = "http://localhost:8000" class GGML(LLM): + # this is model-specific + max_context_length: int = 2048 - def __init__(self, system_message: str = None): - self.system_message = system_message + _client_session: aiohttp.ClientSession = None - @cached_property + class Config: + arbitrary_types_allowed = True + + async def start(self, **kwargs): + self._client_session = aiohttp.ClientSession() + + async def stop(self): + await self._client_session.close() + + @property def name(self): return "ggml" @property + def context_length(self): + return self.max_context_length + + @property def default_args(self): return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024} @@ -33,54 +48,51 @@ class GGML(LLM): args = {**self.default_args, **kwargs} messages = compile_chat_messages( - self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) - - async with aiohttp.ClientSession() as session: - async with session.post(f"{SERVER_URL}/v1/completions", json={ - "messages": messages, - **args - }) as resp: - async for line in resp.content.iter_any(): - if line: - try: - yield line.decode("utf-8") - except: - raise Exception(str(line)) + self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) + + async with self._client_session.post(f"{SERVER_URL}/v1/completions", json={ + "messages": messages, + **args + }) as resp: + async for line in resp.content.iter_any(): + if line: + try: + yield line.decode("utf-8") + except: + raise Exception(str(line)) async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: args = {**self.default_args, **kwargs} messages = compile_chat_messages( - self.name, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) + self.name, messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) args["stream"] = True - async with aiohttp.ClientSession() as session: - async with session.post(f"{SERVER_URL}/v1/chat/completions", json={ - "messages": messages, - **args - }) as resp: - # This is streaming application/json instaed of text/event-stream - async for line in resp.content.iter_chunks(): - if line[1]: - try: - json_chunk = line[0].decode("utf-8") - if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"): - continue - chunks = json_chunk.split("\n") - for chunk in chunks: - if chunk.strip() != "": - yield json.loads(chunk[6:])["choices"][0]["delta"] - except: - raise Exception(str(line[0])) + async with self._client_session.post(f"{SERVER_URL}/v1/chat/completions", json={ + "messages": messages, + **args + }) as resp: + # This is streaming application/json instaed of text/event-stream + async for line in resp.content.iter_chunks(): + if line[1]: + try: + json_chunk = line[0].decode("utf-8") + if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"): + continue + chunks = json_chunk.split("\n") + for chunk in chunks: + if chunk.strip() != "": + yield 
json.loads(chunk[6:])["choices"][0]["delta"] + except: + raise Exception(str(line[0])) async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: args = {**self.default_args, **kwargs} - async with aiohttp.ClientSession() as session: - async with session.post(f"{SERVER_URL}/v1/completions", json={ - "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message), - **args - }) as resp: - try: - return await resp.text() - except: - raise Exception(await resp.text()) + async with self._client_session.post(f"{SERVER_URL}/v1/completions", json={ + "messages": compile_chat_messages(args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message), + **args + }) as resp: + try: + return await resp.text() + except: + raise Exception(await resp.text()) diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py index 36f03270..49f593d8 100644 --- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py +++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional from ...core.main import ChatMessage from ..llm import LLM import requests @@ -8,14 +8,18 @@ DEFAULT_MAX_TIME = 120. class HuggingFaceInferenceAPI(LLM): - api_key: str model: str - def __init__(self, api_key: str, model: str, system_message: str = None): - self.api_key = api_key + requires_api_key: str = "HUGGING_FACE_TOKEN" + api_key: str = None + + def __init__(self, model: str, system_message: str = None): self.model = model self.system_message = system_message # TODO: Nothing being done with this + async def start(self, *, api_key: Optional[str] = None, **kwargs): + self.api_key = api_key + def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs): """Return the completion of the text with the given temperature.""" API_URL = f"https://api-inference.huggingface.co/models/{self.model}" diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py new file mode 100644 index 00000000..edf58fd7 --- /dev/null +++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py @@ -0,0 +1,53 @@ +from typing import Any, Coroutine, Dict, Generator, List, Union, Optional, Callable + +from ...core.main import ChatMessage +from . 
import LLM +from .proxy_server import ProxyServer +from .openai import OpenAI + + +class MaybeProxyOpenAI(LLM): + model: str + + requires_api_key: Optional[str] = "OPENAI_API_KEY" + requires_write_log: bool = True + requires_unique_id: bool = True + system_message: Union[str, None] = None + + llm: Optional[LLM] = None + + @property + def name(self): + return self.llm.name + + @property + def context_length(self): + return self.llm.context_length + + async def start(self, *, api_key: Optional[str] = None, unique_id: str, write_log: Callable[[str], None]): + if api_key is None or api_key.strip() == "": + self.llm = ProxyServer(model=self.model) + else: + self.llm = OpenAI(model=self.model) + + await self.llm.start(api_key=api_key, write_log=write_log, unique_id=unique_id) + + async def stop(self): + await self.llm.stop() + + async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: + return await self.llm.complete(prompt, with_history=with_history, **kwargs) + + async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: + resp = self.llm.stream_complete( + prompt, with_history=with_history, **kwargs) + async for item in resp: + yield item + + async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: + resp = self.llm.stream_chat(messages=messages, **kwargs) + async for item in resp: + yield item + + def count_tokens(self, text: str): + return self.llm.count_tokens(text) diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index 654c7326..fce6e8ab 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -1,48 +1,83 @@ from functools import cached_property import json -from typing import Any, Callable, Coroutine, Dict, Generator, List, Union +from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union, Optional -from ...core.main import ChatMessage +from pydantic import BaseModel import openai + +from ...core.main import ChatMessage +from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top from ..llm import LLM -from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top -from ...core.config import OpenAIServerInfo + + +class OpenAIServerInfo(BaseModel): + api_base: Optional[str] = None + engine: Optional[str] = None + api_version: Optional[str] = None + api_type: Literal["azure", "openai"] = "openai" + + +CHAT_MODELS = { + "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613" +} +MAX_TOKENS_FOR_MODEL = { + "gpt-3.5-turbo": 4096, + "gpt-3.5-turbo-0613": 4096, + "gpt-3.5-turbo-16k": 16384, + "gpt-4": 8192, +} + + +class AzureInfo(BaseModel): + endpoint: str + engine: str + api_version: str class OpenAI(LLM): - api_key: str - default_model: str + model: str + openai_server_info: Optional[OpenAIServerInfo] = None - def __init__(self, api_key: str, default_model: str, system_message: str = None, openai_server_info: OpenAIServerInfo = None, write_log: Callable[[str], None] = None): - self.api_key = api_key - self.default_model = default_model - self.system_message = system_message - self.openai_server_info = openai_server_info + requires_api_key = "OPENAI_API_KEY" + requires_write_log = True 
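`MaybeProxyOpenAI` above decides at `start()` time whether to wrap the Continue proxy server or the direct OpenAI client, depending on whether an API key was supplied, and then delegates everything else to the wrapped LLM. A hedged usage sketch — the `unique_id` string and the use of `print` as `write_log` are placeholders:

```python
import asyncio

from continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI


async def demo():
    llm = MaybeProxyOpenAI(model="gpt-4")
    # No api_key -> wraps ProxyServer; a non-empty key -> wraps the OpenAI client.
    await llm.start(api_key=None, unique_id="dev-machine", write_log=print)
    print(llm.name, llm.context_length)  # both delegated to the wrapped LLM
    await llm.stop()


asyncio.run(demo())
```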
+ + system_message: Optional[str] = None + azure_info: Optional[AzureInfo] = None + write_log: Optional[Callable[[str], None]] = None + api_key: str = None + + async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], **kwargs): self.write_log = write_log + self.api_key = api_key + openai.api_key = self.api_key - openai.api_key = api_key + if self.openai_server_info is not None: + openai.api_type = self.openai_server_info.api_type + if self.openai_server_info.api_base is not None: + openai.api_base = self.openai_server_info.api_base + if self.openai_server_info.api_version is not None: + openai.api_version = self.openai_server_info.api_version - # Using an Azure OpenAI deployment - if openai_server_info is not None: - openai.api_type = openai_server_info.api_type - if openai_server_info.api_base is not None: - openai.api_base = openai_server_info.api_base - if openai_server_info.api_version is not None: - openai.api_version = openai_server_info.api_version + async def stop(self): + pass - @cached_property + @property def name(self): - return self.default_model + return self.model + + @property + def context_length(self): + return MAX_TOKENS_FOR_MODEL[self.model] @property def default_args(self): - args = {**DEFAULT_ARGS, "model": self.default_model} + args = {**DEFAULT_ARGS, "model": self.model} if self.openai_server_info is not None: args["engine"] = self.openai_server_info.engine return args def count_tokens(self, text: str): - return count_tokens(self.default_model, text) + return count_tokens(self.model, text) async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: args = self.default_args.copy() @@ -51,7 +86,7 @@ class OpenAI(LLM): if args["model"] in CHAT_MODELS: messages = compile_chat_messages( - args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message) + args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message) self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") completion = "" async for chunk in await openai.ChatCompletion.acreate( @@ -62,7 +97,7 @@ class OpenAI(LLM): yield chunk.choices[0].delta.content completion += chunk.choices[0].delta.content else: - continue + continue # :) self.write_log(f"Completion: \n\n{completion}") else: @@ -78,12 +113,13 @@ class OpenAI(LLM): args = self.default_args.copy() args.update(kwargs) args["stream"] = True - args["model"] = self.default_model if self.default_model in CHAT_MODELS else "gpt-3.5-turbo-0613" + # TODO what to do here? why should we change to gpt-3.5-turbo-0613 if the user didn't ask for it? 
+ args["model"] = self.model if self.model in CHAT_MODELS else "gpt-3.5-turbo-0613" if not args["model"].endswith("0613") and "functions" in args: del args["functions"] messages = compile_chat_messages( - args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) + args["model"], messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") completion = "" async for chunk in await openai.ChatCompletion.acreate( @@ -100,7 +136,7 @@ class OpenAI(LLM): if args["model"] in CHAT_MODELS: messages = compile_chat_messages( - args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message) + args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message) self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") resp = (await openai.ChatCompletion.acreate( messages=messages, @@ -109,7 +145,7 @@ class OpenAI(LLM): self.write_log(f"Completion: \n\n{resp}") else: prompt = prune_raw_prompt_from_top( - args["model"], prompt, args["max_tokens"]) + args["model"], self.context_length, prompt, args["max_tokens"]) self.write_log(f"Prompt:\n\n{prompt}") resp = (await openai.Completion.acreate( prompt=prompt, diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py index f9e3fa01..1a48f213 100644 --- a/continuedev/src/continuedev/libs/llm/proxy_server.py +++ b/continuedev/src/continuedev/libs/llm/proxy_server.py @@ -1,7 +1,6 @@ - import json import traceback -from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union +from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union, Optional import aiohttp from ...core.main import ChatMessage from ..llm import LLM @@ -16,26 +15,51 @@ ssl_context = ssl.create_default_context(cafile=ca_bundle_path) # SERVER_URL = "http://127.0.0.1:8080" SERVER_URL = "https://proxy-server-l6vsfbzhba-uw.a.run.app" +MAX_TOKENS_FOR_MODEL = { + "gpt-3.5-turbo": 4096, + "gpt-3.5-turbo-0613": 4096, + "gpt-3.5-turbo-16k": 16384, + "gpt-4": 8192, +} + class ProxyServer(LLM): - unique_id: str - name: str - default_model: Literal["gpt-3.5-turbo", "gpt-4"] - write_log: Callable[[str], None] + model: str + system_message: Optional[str] - def __init__(self, unique_id: str, default_model: Literal["gpt-3.5-turbo", "gpt-4"], system_message: str = None, write_log: Callable[[str], None] = None): - self.unique_id = unique_id - self.default_model = default_model - self.system_message = system_message - self.name = default_model + unique_id: str = None + write_log: Callable[[str], None] = None + _client_session: aiohttp.ClientSession + + requires_unique_id = True + requires_write_log = True + + class Config: + arbitrary_types_allowed = True + + async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], unique_id: str, **kwargs): + self._client_session = aiohttp.ClientSession( + connector=aiohttp.TCPConnector(ssl_context=ssl_context)) self.write_log = write_log + self.unique_id = unique_id + + async def stop(self): + await self._client_session.close() + + @property + def name(self): + return self.model + + @property + def context_length(self): + return MAX_TOKENS_FOR_MODEL[self.model] @property def default_args(self): - return {**DEFAULT_ARGS, "model": 
self.default_model} + return {**DEFAULT_ARGS, "model": self.model} def count_tokens(self, text: str): - return count_tokens(self.default_model, text) + return count_tokens(self.model, text) def get_headers(self): # headers with unique id @@ -45,75 +69,72 @@ class ProxyServer(LLM): args = {**self.default_args, **kwargs} messages = compile_chat_messages( - args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message) + args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message) self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") - async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session: - async with session.post(f"{SERVER_URL}/complete", json={ - "messages": messages, - **args - }, headers=self.get_headers()) as resp: - if resp.status != 200: - raise Exception(await resp.text()) - - response_text = await resp.text() - self.write_log(f"Completion: \n\n{response_text}") - return response_text + async with self._client_session.post(f"{SERVER_URL}/complete", json={ + "messages": messages, + **args + }, headers=self.get_headers()) as resp: + if resp.status != 200: + raise Exception(await resp.text()) + + response_text = await resp.text() + self.write_log(f"Completion: \n\n{response_text}") + return response_text async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]: args = {**self.default_args, **kwargs} messages = compile_chat_messages( - args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) + args["model"], messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") - async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session: - async with session.post(f"{SERVER_URL}/stream_chat", json={ - "messages": messages, - **args - }, headers=self.get_headers()) as resp: - # This is streaming application/json instaed of text/event-stream - completion = "" - if resp.status != 200: - raise Exception(await resp.text()) - async for line in resp.content.iter_chunks(): - if line[1]: - try: - json_chunk = line[0].decode("utf-8") - json_chunk = "{}" if json_chunk == "" else json_chunk - chunks = json_chunk.split("\n") - for chunk in chunks: - if chunk.strip() != "": - loaded_chunk = json.loads(chunk) - yield loaded_chunk - if "content" in loaded_chunk: - completion += loaded_chunk["content"] - except Exception as e: - posthog_logger.capture_event(self.unique_id, "proxy_server_parse_error", { - "error_title": "Proxy server stream_chat parsing failed", "error_message": '\n'.join(traceback.format_exception(e))}) - else: - break - - self.write_log(f"Completion: \n\n{completion}") + async with self._client_session.post(f"{SERVER_URL}/stream_chat", json={ + "messages": messages, + **args + }, headers=self.get_headers()) as resp: + # This is streaming application/json instaed of text/event-stream + completion = "" + if resp.status != 200: + raise Exception(await resp.text()) + async for line in resp.content.iter_chunks(): + if line[1]: + try: + json_chunk = line[0].decode("utf-8") + json_chunk = "{}" if json_chunk == "" else json_chunk + chunks = json_chunk.split("\n") + for chunk in chunks: + if chunk.strip() != 
"": + loaded_chunk = json.loads(chunk) + yield loaded_chunk + if "content" in loaded_chunk: + completion += loaded_chunk["content"] + except Exception as e: + posthog_logger.capture_event(self.unique_id, "proxy_server_parse_error", { + "error_title": "Proxy server stream_chat parsing failed", "error_message": '\n'.join(traceback.format_exception(e))}) + else: + break + + self.write_log(f"Completion: \n\n{completion}") async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: args = {**self.default_args, **kwargs} messages = compile_chat_messages( - self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) + self.model, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") - async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session: - async with session.post(f"{SERVER_URL}/stream_complete", json={ - "messages": messages, - **args - }, headers=self.get_headers()) as resp: - completion = "" - if resp.status != 200: - raise Exception(await resp.text()) - async for line in resp.content.iter_any(): - if line: - try: - decoded_line = line.decode("utf-8") - yield decoded_line - completion += decoded_line - except: - raise Exception(str(line)) - self.write_log(f"Completion: \n\n{completion}") + async with self._client_session.post(f"{SERVER_URL}/stream_complete", json={ + "messages": messages, + **args + }, headers=self.get_headers()) as resp: + completion = "" + if resp.status != 200: + raise Exception(await resp.text()) + async for line in resp.content.iter_any(): + if line: + try: + decoded_line = line.decode("utf-8") + yield decoded_line + completion += decoded_line + except: + raise Exception(str(line)) + self.write_log(f"Completion: \n\n{completion}") diff --git a/continuedev/src/continuedev/libs/llm/utils.py b/continuedev/src/continuedev/libs/llm/utils.py deleted file mode 100644 index 76240d4e..00000000 --- a/continuedev/src/continuedev/libs/llm/utils.py +++ /dev/null @@ -1,34 +0,0 @@ -from transformers import AutoTokenizer, AutoModelForCausalLM -from transformers import GPT2TokenizerFast - -gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") -def count_tokens(text: str) -> int: - return len(gpt2_tokenizer.encode(text)) - -prices = { - # All prices are per 1k tokens - "fine-tune-train": { - "davinci": 0.03, - "curie": 0.03, - "babbage": 0.0006, - "ada": 0.0004, - }, - "completion": { - "davinci": 0.02, - "curie": 0.002, - "babbage": 0.0005, - "ada": 0.0004, - }, - "fine-tune-completion": { - "davinci": 0.12, - "curie": 0.012, - "babbage": 0.0024, - "ada": 0.0016, - }, - "embedding": { - "ada": 0.0004 - } -} - -def get_price(text: str, model: str="davinci", task: str="completion") -> float: - return count_tokens(text) * prices[task][model] / 1000
\ No newline at end of file diff --git a/continuedev/src/continuedev/libs/util/calculate_diff.py b/continuedev/src/continuedev/libs/util/calculate_diff.py index ff0a135f..3e82bab3 100644 --- a/continuedev/src/continuedev/libs/util/calculate_diff.py +++ b/continuedev/src/continuedev/libs/util/calculate_diff.py @@ -92,7 +92,7 @@ def calculate_diff2(filepath: str, original: str, updated: str) -> List[FileEdit tag, i1, i2, j1, j2 = s.get_opcodes()[edit_index] replacement = updated[j1:j2] if tag == "equal": - continue + continue # ;) elif tag == "delete": edits.append(FileEdit.from_deletion( filepath, Range.from_indices(original, i1, i2))) diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py index c58ae499..6add7b1a 100644 --- a/continuedev/src/continuedev/libs/util/count_tokens.py +++ b/continuedev/src/continuedev/libs/util/count_tokens.py @@ -2,43 +2,47 @@ import json from typing import Dict, List, Union from ...core.main import ChatMessage from .templating import render_templated_string +from ...libs.llm import LLM import tiktoken +# TODO move many of these into specific LLM.properties() function that +# contains max tokens, if its a chat model or not, default args (not all models +# want to be run at 0.5 temp). also lets custom models made for long contexts +# exist here (likg LLongMA) aliases = { "ggml": "gpt-3.5-turbo", "claude-2": "gpt-3.5-turbo", } DEFAULT_MAX_TOKENS = 2048 -MAX_TOKENS_FOR_MODEL = { - "gpt-3.5-turbo": 4096, - "gpt-3.5-turbo-0613": 4096, - "gpt-3.5-turbo-16k": 16384, - "gpt-4": 8192, - "ggml": 2048, - "claude-2": 100000 -} -CHAT_MODELS = { - "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613" -} DEFAULT_ARGS = {"max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.5, "top_p": 1, "frequency_penalty": 0, "presence_penalty": 0} -def encoding_for_model(model: str): - return tiktoken.encoding_for_model(aliases.get(model, model)) +def encoding_for_model(model_name: str): + try: + return tiktoken.encoding_for_model(aliases.get(model_name, model_name)) + except: + return tiktoken.encoding_for_model("gpt-3.5-turbo") -def count_tokens(model: str, text: Union[str, None]): +def count_tokens(model_name: str, text: Union[str, None]): if text is None: return 0 - encoding = encoding_for_model(model) + encoding = encoding_for_model(model_name) return len(encoding.encode(text, disallowed_special=())) -def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: int): - max_tokens = MAX_TOKENS_FOR_MODEL.get( - model, DEFAULT_MAX_TOKENS) - tokens_for_completion - encoding = encoding_for_model(model) +def count_chat_message_tokens(model_name: str, chat_message: ChatMessage) -> int: + # Doing simpler, safer version of what is here: + # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb + # every message follows <|start|>{role/name}\n{content}<|end|>\n + TOKENS_PER_MESSAGE = 4 + return count_tokens(model_name, chat_message.content) + TOKENS_PER_MESSAGE + + +def prune_raw_prompt_from_top(model_name: str, context_length: int, prompt: str, tokens_for_completion: int): + max_tokens = context_length - tokens_for_completion + encoding = encoding_for_model(model_name) tokens = encoding.encode(prompt, disallowed_special=()) if len(tokens) <= max_tokens: return prompt @@ -46,53 +50,45 @@ def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: in return encoding.decode(tokens[-max_tokens:]) -def count_chat_message_tokens(model: str, 
chat_message: ChatMessage) -> int: - # Doing simpler, safer version of what is here: - # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb - # every message follows <|start|>{role/name}\n{content}<|end|>\n - TOKENS_PER_MESSAGE = 4 - return count_tokens(model, chat_message.content) + TOKENS_PER_MESSAGE - - -def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int): +def prune_chat_history(model_name: str, chat_history: List[ChatMessage], context_length: int, tokens_for_completion: int): total_tokens = tokens_for_completion + \ - sum(count_chat_message_tokens(model, message) + sum(count_chat_message_tokens(model_name, message) for message in chat_history) # 1. Replace beyond last 5 messages with summary i = 0 - while total_tokens > max_tokens and i < len(chat_history) - 5: + while total_tokens > context_length and i < len(chat_history) - 5: message = chat_history[0] - total_tokens -= count_tokens(model, message.content) - total_tokens += count_tokens(model, message.summary) + total_tokens -= count_tokens(model_name, message.content) + total_tokens += count_tokens(model_name, message.summary) message.content = message.summary i += 1 # 2. Remove entire messages until the last 5 - while len(chat_history) > 5 and total_tokens > max_tokens and len(chat_history) > 0: + while len(chat_history) > 5 and total_tokens > context_length and len(chat_history) > 0: message = chat_history.pop(0) - total_tokens -= count_tokens(model, message.content) + total_tokens -= count_tokens(model_name, message.content) # 3. Truncate message in the last 5, except last 1 i = 0 - while total_tokens > max_tokens and len(chat_history) > 0 and i < len(chat_history) - 1: + while total_tokens > context_length and len(chat_history) > 0 and i < len(chat_history) - 1: message = chat_history[i] - total_tokens -= count_tokens(model, message.content) - total_tokens += count_tokens(model, message.summary) + total_tokens -= count_tokens(model_name, message.content) + total_tokens += count_tokens(model_name, message.summary) message.content = message.summary i += 1 # 4. Remove entire messages in the last 5, except last 1 - while total_tokens > max_tokens and len(chat_history) > 1: + while total_tokens > context_length and len(chat_history) > 1: message = chat_history.pop(0) - total_tokens -= count_tokens(model, message.content) + total_tokens -= count_tokens(model_name, message.content) # 5. 
Truncate last message - if total_tokens > max_tokens and len(chat_history) > 0: + if total_tokens > context_length and len(chat_history) > 0: message = chat_history[0] message.content = prune_raw_prompt_from_top( - model, message.content, tokens_for_completion) - total_tokens = max_tokens + model_name, context_length, message.content, tokens_for_completion) + total_tokens = context_length return chat_history @@ -101,7 +97,7 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens: TOKEN_BUFFER_FOR_SAFETY = 100 -def compile_chat_messages(model: str, msgs: Union[List[ChatMessage], None], max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]: +def compile_chat_messages(model_name: str, msgs: Union[List[ChatMessage], None], context_length: int, max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]: """ The total number of tokens is system_message + sum(msgs) + functions + prompt after it is converted to a message """ @@ -125,10 +121,10 @@ def compile_chat_messages(model: str, msgs: Union[List[ChatMessage], None], max_ function_tokens = 0 if functions is not None: for function in functions: - function_tokens += count_tokens(model, json.dumps(function)) + function_tokens += count_tokens(model_name, json.dumps(function)) msgs_copy = prune_chat_history( - model, msgs_copy, MAX_TOKENS_FOR_MODEL[model], function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY) + model_name, msgs_copy, context_length, function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY) history = [msg.to_dict(with_functions=functions is not None) for msg in msgs_copy] diff --git a/continuedev/src/continuedev/libs/util/strings.py b/continuedev/src/continuedev/libs/util/strings.py index f1fb8d0b..285c1e47 100644 --- a/continuedev/src/continuedev/libs/util/strings.py +++ b/continuedev/src/continuedev/libs/util/strings.py @@ -12,7 +12,7 @@ def dedent_and_get_common_whitespace(s: str) -> Tuple[str, str]: for i in range(1, len(lines)): # Empty lines are wildcards if lines[i].strip() == "": - continue + continue # hey that's us! 
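With the hard-coded `MAX_TOKENS_FOR_MODEL` table removed from `count_tokens.py`, callers of `compile_chat_messages` now pass the model's `context_length` explicitly alongside `max_tokens`. A sketch of the updated call, with placeholder message contents:

```python
from continuedev.core.main import ChatMessage
from continuedev.libs.util.count_tokens import DEFAULT_ARGS, compile_chat_messages

history = [
    ChatMessage(role="user", content="How do I use Continue?", summary="usage question"),
]
messages = compile_chat_messages(
    model_name="gpt-4",
    msgs=history,
    context_length=8192,                  # e.g. OpenAI.context_length for gpt-4
    max_tokens=DEFAULT_ARGS["max_tokens"],
    prompt=None,
    functions=None,
    system_message="Answer concisely.",
)
```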
# Iterate through the leading whitespace characters of the current line for j in range(0, len(lcp)): # If it doesn't have the same whitespace as lcp, then update lcp diff --git a/continuedev/src/continuedev/models/generate_json_schema.py b/continuedev/src/continuedev/models/generate_json_schema.py index 51869fdd..2166bc37 100644 --- a/continuedev/src/continuedev/models/generate_json_schema.py +++ b/continuedev/src/continuedev/models/generate_json_schema.py @@ -39,7 +39,7 @@ def main(): json = schema_json_of(model, indent=2, title=title) except Exception as e: print(f"Failed to generate json schema for {title}: {e}") - continue + continue # pun intended with open(f"{SCHEMA_DIR}/{title}.json", "w") as f: f.write(json) diff --git a/continuedev/src/continuedev/plugins/context_providers/embeddings.py b/continuedev/src/continuedev/plugins/context_providers/embeddings.py new file mode 100644 index 00000000..42d1f754 --- /dev/null +++ b/continuedev/src/continuedev/plugins/context_providers/embeddings.py @@ -0,0 +1,79 @@ +import os +from typing import List, Optional +import uuid +from pydantic import BaseModel + +from ...core.main import ContextItemId +from ...core.context import ContextProvider +from ...core.main import ContextItem, ContextItemDescription, ContextItemId +from ...libs.chroma.query import ChromaIndexManager +from .util import remove_meilisearch_disallowed_chars + + +class EmbeddingResult(BaseModel): + filename: str + content: str + + +class EmbeddingsProvider(ContextProvider): + title = "embed" + + workspace_directory: str + + EMBEDDINGS_CONTEXT_ITEM_ID = "embeddings" + + index_manager: Optional[ChromaIndexManager] = None + + class Config: + arbitrary_types_allowed = True + + @property + def index(self): + if self.index_manager is None: + self.index_manager = ChromaIndexManager(self.workspace_directory) + return self.index_manager + + @property + def BASE_CONTEXT_ITEM(self): + return ContextItem( + content="", + description=ContextItemDescription( + name="Embedding Search", + description="Enter a query to embedding search codebase", + id=ContextItemId( + provider_title=self.title, + item_id=self.EMBEDDINGS_CONTEXT_ITEM_ID + ) + ) + ) + + async def _get_query_results(self, query: str) -> str: + results = self.index.query_codebase_index(query) + + ret = [] + for node in results.source_nodes: + resource_name = list(node.node.relationships.values())[0] + filepath = resource_name[:resource_name.index("::")] + ret.append(EmbeddingResult( + filename=filepath, content=node.node.text)) + + return ret + + async def provide_context_items(self) -> List[ContextItem]: + self.index.create_codebase_index() # TODO Synchronous here is not ideal + + return [self.BASE_CONTEXT_ITEM] + + async def add_context_item(self, id: ContextItemId, query: str): + if not id.item_id == self.EMBEDDINGS_CONTEXT_ITEM_ID: + raise Exception("Invalid item id") + + results = await self._get_query_results(query) + + for i in range(len(results)): + result = results[i] + ctx_item = self.BASE_CONTEXT_ITEM.copy() + ctx_item.description.name = os.path.basename(result.filename) + ctx_item.content = f"{result.filename}\n```\n{result.content}\n```" + ctx_item.description.id.item_id = uuid.uuid4().hex + self.selected_items.append(ctx_item) diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py index 634774df..31aa5423 100644 --- a/continuedev/src/continuedev/plugins/context_providers/file.py +++ 
b/continuedev/src/continuedev/plugins/context_providers/file.py @@ -3,6 +3,7 @@ import re from typing import List from ...core.main import ContextItem, ContextItemDescription, ContextItemId from ...core.context import ContextProvider +from .util import remove_meilisearch_disallowed_chars from fnmatch import fnmatch @@ -79,7 +80,7 @@ class FileContextProvider(ContextProvider): description=file, id=ContextItemId( provider_title=self.title, - item_id=re.sub(r'[^0-9a-zA-Z_-]', '', file) + item_id=remove_meilisearch_disallowed_chars(file) ) ) )) diff --git a/continuedev/src/continuedev/plugins/context_providers/google.py b/continuedev/src/continuedev/plugins/context_providers/google.py index fc76fe67..4b0a59ec 100644 --- a/continuedev/src/continuedev/plugins/context_providers/google.py +++ b/continuedev/src/continuedev/plugins/context_providers/google.py @@ -2,6 +2,7 @@ import json from typing import List import aiohttp +from .util import remove_meilisearch_disallowed_chars from ...core.main import ContextItem, ContextItemDescription, ContextItemId from ...core.context import ContextProvider @@ -60,5 +61,6 @@ class GoogleContextProvider(ContextProvider): ctx_item = self.BASE_CONTEXT_ITEM.copy() ctx_item.content = content - ctx_item.description.id.item_id = query + ctx_item.description.id.item_id = remove_meilisearch_disallowed_chars( + query) return ctx_item diff --git a/continuedev/src/continuedev/plugins/context_providers/util.py b/continuedev/src/continuedev/plugins/context_providers/util.py new file mode 100644 index 00000000..da2e6b17 --- /dev/null +++ b/continuedev/src/continuedev/plugins/context_providers/util.py @@ -0,0 +1,5 @@ +import re + + +def remove_meilisearch_disallowed_chars(id: str) -> str: + return re.sub(r'[^0-9a-zA-Z_-]', '', id) diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/plugins/policies/default.py index d90177b5..523c2cf4 100644 --- a/continuedev/src/continuedev/core/policy.py +++ b/continuedev/src/continuedev/plugins/policies/default.py @@ -1,15 +1,15 @@ from textwrap import dedent from typing import Union -from ..plugins.steps.chat import SimpleChatStep -from ..plugins.steps.welcome import WelcomeStep -from .config import ContinueConfig -from ..plugins.steps.steps_on_startup import StepsOnStartupStep -from .main import Step, History, Policy -from .observation import UserInputObservation -from ..plugins.steps.core.core import MessageStep -from ..plugins.steps.custom_command import CustomCommandStep -from ..plugins.steps.main import EditHighlightedCodeStep +from ..steps.chat import SimpleChatStep +from ..steps.welcome import WelcomeStep +from ...core.config import ContinueConfig +from ..steps.steps_on_startup import StepsOnStartupStep +from ...core.main import Step, History, Policy +from ...core.observation import UserInputObservation +from ..steps.core.core import MessageStep +from ..steps.custom_command import CustomCommandStep +from ..steps.main import EditHighlightedCodeStep def parse_slash_command(inp: str, config: ContinueConfig) -> Union[None, Step]: @@ -45,7 +45,8 @@ def parse_custom_command(inp: str, config: ContinueConfig) -> Union[None, Step]: class DefaultPolicy(Policy): - ran_code_last: bool = False + + default_step: Step = SimpleChatStep() def next(self, config: ContinueConfig, history: History) -> Step: # At the very start, run initial Steps spcecified in the config @@ -56,7 +57,6 @@ class DefaultPolicy(Policy): - Use `cmd+m` (Mac) / `ctrl+m` (Windows) to open Continue - Use `/help` to ask questions about how to use 
Continue""")) >> WelcomeStep() >> - # CreateCodebaseIndexChroma() >> StepsOnStartupStep()) observation = history.get_current().observation @@ -75,6 +75,6 @@ class DefaultPolicy(Policy): if user_input.startswith("/edit"): return EditHighlightedCodeStep(user_input=user_input[5:]) - return SimpleChatStep() + return self.default_step.copy() return None diff --git a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py index 433e309e..872f8d62 100644 --- a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py +++ b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py @@ -27,7 +27,7 @@ class SetupPipelineStep(Step): async def run(self, sdk: ContinueSDK): sdk.context.set("api_description", self.api_description) - source_name = (await sdk.models.gpt35.complete( + source_name = (await sdk.models.medium.complete( f"Write a snake_case name for the data source described by {self.api_description}: ")).strip() filename = f'{source_name}.py' @@ -89,7 +89,7 @@ class ValidatePipelineStep(Step): if "Traceback" in output or "SyntaxError" in output: output = "Traceback" + output.split("Traceback")[-1] file_content = await sdk.ide.readFile(os.path.join(workspace_dir, filename)) - suggestion = await sdk.models.gpt35.complete(dedent(f"""\ + suggestion = await sdk.models.medium.complete(dedent(f"""\ ```python {file_content} ``` @@ -101,7 +101,7 @@ class ValidatePipelineStep(Step): This is a brief summary of the error followed by a suggestion on how it can be fixed by editing the resource function:""")) - api_documentation_url = await sdk.models.gpt35.complete(dedent(f"""\ + api_documentation_url = await sdk.models.medium.complete(dedent(f"""\ The API I am trying to call is the '{sdk.context.get('api_description')}'. 
I tried calling it in the @resource function like this: ```python {file_content} @@ -151,7 +151,7 @@ class RunQueryStep(Step): output = await sdk.run('.env/bin/python3 query.py', name="Run test query", description="Running `.env/bin/python3 query.py` to test that the data was loaded into DuckDB as expected", handle_error=False) if "Traceback" in output or "SyntaxError" in output: - suggestion = await sdk.models.gpt35.complete(dedent(f"""\ + suggestion = await sdk.models.medium.complete(dedent(f"""\ ```python {await sdk.ide.readFile(os.path.join(sdk.ide.workspace_directory, "query.py"))} ``` diff --git a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py index 6ef5ffd6..c66cd629 100644 --- a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py +++ b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py @@ -42,7 +42,7 @@ class WritePytestsRecipe(Step): "{self.user_input}" Here is a complete set of pytest unit tests:""") - tests = await sdk.models.gpt35.complete(prompt) + tests = await sdk.models.medium.complete(prompt) await sdk.apply_filesystem_edit(AddFile(filepath=path, content=tests)) diff --git a/continuedev/src/continuedev/plugins/steps/README.md b/continuedev/src/continuedev/plugins/steps/README.md index 12073835..3f2f804c 100644 --- a/continuedev/src/continuedev/plugins/steps/README.md +++ b/continuedev/src/continuedev/plugins/steps/README.md @@ -33,7 +33,7 @@ If you'd like to override the default description of your step, which is just th - Return a static string - Store state in a class attribute (prepend with a double underscore, which signifies (through Pydantic) that this is not a parameter for the Step, just internal state) during the run method, and then grab this in the describe method. -- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.gpt35.complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`. +- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.medium.complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`. 
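Putting those options together, a minimal custom Step might look roughly like the sketch below. The step name, parameter, and prompt are made up for illustration; the single-underscore state attribute mirrors what steps such as `StarCoderEditHighlightedCodeStep` do elsewhere in this diff, and the imports assume the step lives under `plugins/steps/`:

```python
from typing import Coroutine

from ...core.main import Models, Step
from ...core.sdk import ContinueSDK


class WriteGreetingStep(Step):        # hypothetical example step
    language: str = "python"          # a normal Step parameter
    _code_written: str = ""           # internal state, not a parameter

    async def run(self, sdk: ContinueSDK):
        code = await sdk.models.medium.complete(
            f"Write a hello-world program in {self.language}. Code only:")
        self._code_written = code
        self.description = code

    async def describe(self, models: Models) -> Coroutine[str, None, None]:
        # Autogenerate the description from state captured during run()
        return await models.medium.complete(
            f"{self._code_written}\n\nSummarize the changes made in the above code.")
```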
Here's an example: diff --git a/continuedev/src/continuedev/plugins/steps/chat.py b/continuedev/src/continuedev/plugins/steps/chat.py index f72a8ec0..455d5a13 100644 --- a/continuedev/src/continuedev/plugins/steps/chat.py +++ b/continuedev/src/continuedev/plugins/steps/chat.py @@ -9,6 +9,7 @@ from .core.core import DisplayErrorStep, MessageStep from ...core.main import FunctionCall, Models from ...core.main import ChatMessage, Step, step_to_json_schema from ...core.sdk import ContinueSDK +from ...libs.llm.openai import OpenAI import openai import os from dotenv import load_dotenv @@ -41,7 +42,7 @@ class SimpleChatStep(Step): self.description += chunk["content"] await sdk.update_ui() - self.name = remove_quotes_and_escapes(await sdk.models.gpt35.complete( + self.name = remove_quotes_and_escapes(await sdk.models.medium.complete( f"Write a short title for the following chat message: {self.description}")) self.chat_context.append(ChatMessage( @@ -166,7 +167,10 @@ class ChatWithFunctions(Step): msg_content = "" msg_step = None - async for msg_chunk in sdk.models.gpt350613.stream_chat(await sdk.get_chat_context(), functions=functions): + gpt350613 = OpenAI(model="gpt-3.5-turbo-0613") + await sdk.start_model(gpt350613) + + async for msg_chunk in gpt350613.stream_chat(await sdk.get_chat_context(), functions=functions): if sdk.current_step_was_deleted(): return diff --git a/continuedev/src/continuedev/plugins/steps/chroma.py b/continuedev/src/continuedev/plugins/steps/chroma.py index dbe8363e..658cc7f3 100644 --- a/continuedev/src/continuedev/plugins/steps/chroma.py +++ b/continuedev/src/continuedev/plugins/steps/chroma.py @@ -56,7 +56,7 @@ class AnswerQuestionChroma(Step): Here is the answer:""") - answer = await sdk.models.gpt35.complete(prompt) + answer = await sdk.models.medium.complete(prompt) # Make paths relative to the workspace directory answer = answer.replace(await sdk.ide.getWorkspaceDirectory(), "") diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py index c80cecc3..4476c7ae 100644 --- a/continuedev/src/continuedev/plugins/steps/core/core.py +++ b/continuedev/src/continuedev/plugins/steps/core/core.py @@ -11,11 +11,12 @@ from pydantic import validator from ....libs.llm.ggml import GGML from ....models.main import Range +from ....libs.llm.maybe_proxy_openai import MaybeProxyOpenAI from ....models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullContents, FileSystemEdit from ....models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents -from ....core.observation import Observation, TextObservation, UserInputObservation -from ....core.main import ChatMessage, ContinueCustomException, Step -from ....libs.util.count_tokens import MAX_TOKENS_FOR_MODEL, DEFAULT_MAX_TOKENS +from ....core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation +from ....core.main import ChatMessage, ContinueCustomException, Step, SequentialStep +from ....libs.util.count_tokens import DEFAULT_MAX_TOKENS from ....libs.util.strings import dedent_and_get_common_whitespace, remove_quotes_and_escapes @@ -97,7 +98,7 @@ class ShellCommandsStep(Step): return f"Error when running shell commands:\n```\n{self._err_text}\n```" cmds_str = "\n".join(self.cmds) - return await models.gpt35.complete(f"{cmds_str}\n\nSummarize what was done in these shell commands, using markdown bullet points:") + return await models.medium.complete(f"{cmds_str}\n\nSummarize what was done in these shell commands, 
using markdown bullet points:") async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: cwd = await sdk.ide.getWorkspaceDirectory() if self.cwd is None else self.cwd @@ -105,7 +106,7 @@ class ShellCommandsStep(Step): for cmd in self.cmds: output = await sdk.ide.runCommand(cmd) if self.handle_error and output is not None and output_contains_error(output): - suggestion = await sdk.models.gpt35.complete(dedent(f"""\ + suggestion = await sdk.models.medium.complete(dedent(f"""\ While running the command `{cmd}`, the following error occurred: ```ascii @@ -185,7 +186,7 @@ class DefaultModelEditCodeStep(Step): else: changes = '\n'.join(difflib.ndiff( self._previous_contents.splitlines(), self._new_contents.splitlines())) - description = await models.gpt3516k.complete(dedent(f"""\ + description = await models.medium.complete(dedent(f"""\ Diff summary: "{self.user_input}" ```diff @@ -193,7 +194,7 @@ class DefaultModelEditCodeStep(Step): ``` Please give brief a description of the changes made above using markdown bullet points. Be concise:""")) - name = await models.gpt3516k.complete(f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. This is the title:") + name = await models.medium.complete(f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. This is the title:") self.name = remove_quotes_and_escapes(name) return f"{remove_quotes_and_escapes(description)}" @@ -203,8 +204,7 @@ class DefaultModelEditCodeStep(Step): # We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion. # Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need. 
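The rule that comment describes — and that the hunks just below implement against the new `context_length` attribute — reduces to something like this (the helper name and return shape are illustrative only):

```python
TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200


def plan_edit_budget(model_to_use, rif_contents: str):
    # Reserve roughly half of the model's context window for the completion
    max_tokens = int(model_to_use.context_length / 2)

    # Ranges past a fixed token threshold take the "large range" warning path
    large_range = model_to_use.count_tokens(
        rif_contents) > TOKENS_TO_BE_CONSIDERED_LARGE_RANGE

    return max_tokens, large_range
```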
model_to_use = sdk.models.default - max_tokens = int(MAX_TOKENS_FOR_MODEL.get( - model_to_use.name, DEFAULT_MAX_TOKENS) / 2) + max_tokens = int(model_to_use.context_length / 2) TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200 if model_to_use.count_tokens(rif.contents) > TOKENS_TO_BE_CONSIDERED_LARGE_RANGE: @@ -222,8 +222,9 @@ class DefaultModelEditCodeStep(Step): # If using 3.5 and overflows, upgrade to 3.5.16k if model_to_use.name == "gpt-3.5-turbo": - if total_tokens > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]: - model_to_use = sdk.models.gpt3516k + if total_tokens > model_to_use.context_length: + model_to_use = MaybeProxyOpenAI(model="gpt-3.5-turbo-0613") + await sdk.start_model(model_to_use) # Remove tokens from the end first, and then the start to clear space # This part finds the start and end lines @@ -233,20 +234,20 @@ class DefaultModelEditCodeStep(Step): cur_start_line = 0 cur_end_line = len(full_file_contents_lst) - 1 - if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]: + if total_tokens > model_to_use.context_length: while cur_end_line > min_end_line: total_tokens -= model_to_use.count_tokens( full_file_contents_lst[cur_end_line]) cur_end_line -= 1 - if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]: + if total_tokens < model_to_use.context_length: break - if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]: + if total_tokens > model_to_use.context_length: while cur_start_line < max_start_line: cur_start_line += 1 total_tokens -= model_to_use.count_tokens( full_file_contents_lst[cur_start_line]) - if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]: + if total_tokens < model_to_use.context_length: break # Now use the found start/end lines to get the prefix and suffix strings @@ -525,7 +526,7 @@ Please output the code to be inserted at the cursor in order to fulfill the user # Accumulate lines if "content" not in chunk: - continue + continue # ayo chunk = chunk["content"] chunk_lines = chunk.split("\n") chunk_lines[0] = unfinished_line + chunk_lines[0] @@ -546,12 +547,12 @@ Please output the code to be inserted at the cursor in order to fulfill the user break # Lines that should be ignored, like the <> tags elif self.line_to_be_ignored(chunk_lines[i], completion_lines_covered == 0): - continue + continue # noice # Check if we are currently just copying the prefix elif (lines_of_prefix_copied > 0 or completion_lines_covered == 0) and lines_of_prefix_copied < len(file_prefix.splitlines()) and chunk_lines[i] == full_file_contents_lines[lines_of_prefix_copied]: # This is a sketchy way of stopping it from repeating the file_prefix. Is a bug if output happens to have a matching line lines_of_prefix_copied += 1 - continue + continue # also nice # Because really short lines might be expected to be repeated, this is only a !heuristic! 
# Stop when it starts copying the file_suffix elif chunk_lines[i].strip() == line_below_highlighted_range.strip() and len(chunk_lines[i].strip()) > 4 and not (len(original_lines_below_previous_blocks) > 0 and chunk_lines[i].strip() == original_lines_below_previous_blocks[0].strip()): diff --git a/continuedev/src/continuedev/plugins/steps/draft/migration.py b/continuedev/src/continuedev/plugins/steps/draft/migration.py index a76d491b..c38f54dc 100644 --- a/continuedev/src/continuedev/plugins/steps/draft/migration.py +++ b/continuedev/src/continuedev/plugins/steps/draft/migration.py @@ -13,7 +13,7 @@ class MigrationStep(Step): recent_edits = await sdk.ide.get_recent_edits(self.edited_file) recent_edits_string = "\n\n".join( map(lambda x: x.to_string(), recent_edits)) - description = await sdk.models.gpt35.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n") + description = await sdk.models.medium.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n") await sdk.run([ "cd libs", "poetry run alembic revision --autogenerate -m " + description, diff --git a/continuedev/src/continuedev/plugins/steps/help.py b/continuedev/src/continuedev/plugins/steps/help.py index 6997a547..ec670999 100644 --- a/continuedev/src/continuedev/plugins/steps/help.py +++ b/continuedev/src/continuedev/plugins/steps/help.py @@ -56,7 +56,7 @@ class HelpStep(Step): summary="Help" )) messages = await sdk.get_chat_context() - generator = sdk.models.gpt4.stream_chat(messages) + generator = sdk.models.default.stream_chat(messages) async for chunk in generator: if "content" in chunk: self.description += chunk["content"] diff --git a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py index b54d394a..3d8d96fb 100644 --- a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py +++ b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py @@ -23,6 +23,6 @@ class NLMultiselectStep(Step): if first_try is not None: return first_try - gpt_parsed = await sdk.models.gpt35.complete( + gpt_parsed = await sdk.models.default.complete( f"These are the available options are: [{', '.join(self.options)}]. The user requested {user_response}. This is the exact string from the options array that they selected:") return extract_option(gpt_parsed) or self.options[0] diff --git a/continuedev/src/continuedev/plugins/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py index f28d9660..d2d6f4dd 100644 --- a/continuedev/src/continuedev/plugins/steps/main.py +++ b/continuedev/src/continuedev/plugins/steps/main.py @@ -101,7 +101,7 @@ class FasterEditHighlightedCodeStep(Step): for rif in range_in_files: rif_dict[rif.filepath] = rif.contents - completion = await sdk.models.gpt35.complete(prompt) + completion = await sdk.models.medium.complete(prompt) # Temporarily doing this to generate description. 
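Several of the steps touched above now share the same streaming pattern against the renamed model tiers (`default`/`medium` instead of `gpt4`/`gpt35`). Pulled out of context, and with a hypothetical step name, it looks roughly like:

```python
from ...core.main import Step
from ...core.sdk import ContinueSDK


class StreamedAnswerStep(Step):       # hypothetical, for illustration
    description: str = ""

    async def run(self, sdk: ContinueSDK):
        # Stream a chat completion into the step description, as HelpStep and
        # SimpleChatStep do in this diff
        messages = await sdk.get_chat_context()
        async for chunk in sdk.models.default.stream_chat(messages):
            if "content" in chunk:
                self.description += chunk["content"]
                await sdk.update_ui()
```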
self._prompt = prompt @@ -169,7 +169,7 @@ class StarCoderEditHighlightedCodeStep(Step): _prompt_and_completion: str = "" async def describe(self, models: Models) -> Coroutine[str, None, None]: - return await models.gpt35.complete(f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:") + return await models.medium.complete(f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:") async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: range_in_files = await sdk.get_code_context(only_editing=True) diff --git a/continuedev/src/continuedev/plugins/steps/react.py b/continuedev/src/continuedev/plugins/steps/react.py index 8b2e7c2e..da6acdbf 100644 --- a/continuedev/src/continuedev/plugins/steps/react.py +++ b/continuedev/src/continuedev/plugins/steps/react.py @@ -27,7 +27,7 @@ class NLDecisionStep(Step): Select the step which should be taken next to satisfy the user input. Say only the name of the selected step. You must choose one:""") - resp = (await sdk.models.gpt35.complete(prompt)).lower() + resp = (await sdk.models.medium.complete(prompt)).lower() step_to_run = None for step in self.steps: diff --git a/continuedev/src/continuedev/plugins/steps/search_directory.py b/continuedev/src/continuedev/plugins/steps/search_directory.py index 07b50473..456dba84 100644 --- a/continuedev/src/continuedev/plugins/steps/search_directory.py +++ b/continuedev/src/continuedev/plugins/steps/search_directory.py @@ -20,10 +20,10 @@ def find_all_matches_in_dir(pattern: str, dirpath: str) -> List[RangeInFile]: for root, dirs, files in os.walk(dirpath): dirname = os.path.basename(root) if dirname.startswith(".") or dirname in IGNORE_DIRS: - continue + continue # continue! for file in files: if file in IGNORE_FILES: - continue + continue # pun intended with open(os.path.join(root, file), "r") as f: # Find the index of all occurences of the pattern in the file. Use re. file_content = f.read() @@ -42,7 +42,7 @@ class WriteRegexPatternStep(Step): async def run(self, sdk: ContinueSDK): # Ask the user for a regex pattern - pattern = await sdk.models.gpt35.complete(dedent(f"""\ + pattern = await sdk.models.medium.complete(dedent(f"""\ This is the user request: {self.user_request} diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 98a5aea0..cf18c56b 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -176,7 +176,7 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we message = json.loads(message) if "messageType" not in message or "data" not in message: - continue + continue # :o message_type = message["messageType"] data = message["data"] diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index e4c07029..6124f3bd 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -139,7 +139,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer): msg_string = await self.websocket.receive_text() message = json.loads(msg_string) if "messageType" not in message or "data" not in message: - continue + continue # <-- hey that's the name of this repo! 
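Both the GUI and IDE websocket loops above follow the same guard-then-dispatch shape; a stripped-down version, independent of the FastAPI specifics, might read:

```python
import json


async def read_valid_messages(websocket):
    # Parse each frame, skip anything missing the expected keys, then yield
    # the message type and payload for dispatch (as gui.py and ide.py do)
    while True:
        msg_string = await websocket.receive_text()
        message = json.loads(msg_string)
        if "messageType" not in message or "data" not in message:
            continue  # malformed frame, ignore it
        yield message["messageType"], message["data"]
```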
message_type = message["messageType"] data = message["data"] logger.debug(f"Received message while initializing {message_type}") @@ -311,7 +311,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer): def onFileEdits(self, edits: List[FileEditWithFullContents]): if autopilot := self.__get_autopilot(): - autopilot.handle_manual_edits(edits) + pass def onDeleteAtIndex(self, index: int): if autopilot := self.__get_autopilot(): diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index cf46028f..b5580fe8 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -1,4 +1,5 @@ import os +import traceback from fastapi import WebSocket from typing import Any, Dict, List, Union from uuid import uuid4 @@ -6,12 +7,10 @@ import json from fastapi.websockets import WebSocketState -from ..plugins.steps.core.core import DisplayErrorStep +from ..plugins.steps.core.core import DisplayErrorStep, MessageStep from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath from ..models.filesystem_edit import FileEditWithFullContents -from ..libs.constants.main import CONTINUE_SESSIONS_FOLDER -from ..core.policy import DefaultPolicy -from ..core.main import FullState +from ..core.main import FullState, HistoryNode from ..core.autopilot import Autopilot from .ide_protocol import AbstractIdeProtocolServer from ..libs.util.create_async_task import create_async_task @@ -31,19 +30,6 @@ class Session: self.ws = None -class DemoAutopilot(Autopilot): - first_seen: bool = False - cumulative_edit_string = "" - - def handle_manual_edits(self, edits: List[FileEditWithFullContents]): - return - for edit in edits: - self.cumulative_edit_string += edit.fileEdit.replacement - self._manual_edits_buffer.append(edit) - # Note that you're storing a lot of unecessary data here. Can compress into EditDiffs on the spot, and merge. 
-        #     self._manual_edits_buffer = merge_file_edit(self._manual_edits_buffer, edit)
-
-
 class SessionManager:
     sessions: Dict[str, Session] = {}
     # Mapping of session_id to IDE, where the IDE is still alive
@@ -65,27 +51,47 @@ class SessionManager:
     async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session:
         logger.debug(f"New session: {session_id}")

+        # Load the persisted state (not being used right now)
         full_state = None
         if session_id is not None and os.path.exists(getSessionFilePath(session_id)):
             with open(getSessionFilePath(session_id), "r") as f:
                 full_state = FullState(**json.load(f))

-        autopilot = await DemoAutopilot.create(
-            policy=DefaultPolicy(), ide=ide, full_state=full_state)
+        # Register the session and ide (do this first so that the autopilot can access the session)
+        autopilot = Autopilot(ide=ide)
         session_id = session_id or str(uuid4())
         ide.session_id = session_id
         session = Session(session_id=session_id, autopilot=autopilot)
         self.sessions[session_id] = session
         self.registered_ides[session_id] = ide

+        # Set up the autopilot to update the GUI
         async def on_update(state: FullState):
            await session_manager.send_ws_data(session_id, "state_update", {
                "state": state.dict()
            })

         autopilot.on_update(on_update)
-        create_async_task(autopilot.run_policy(
-        ), lambda e: autopilot.continue_sdk.run_step(DisplayErrorStep(e=e)))
+
+        # Start the autopilot (must be after session is added to sessions) and the policy
+        try:
+            await autopilot.start()
+        except Exception as e:
+            # Have to manually add to history because autopilot isn't started
+            formatted_err = '\n'.join(traceback.format_exception(e))
+            msg_step = MessageStep(
+                name="Error loading context manager", message=formatted_err)
+            msg_step.description = f"```\n{formatted_err}\n```"
+            autopilot.history.add_node(HistoryNode(
+                step=msg_step,
+                observation=None,
+                depth=0,
+                active=False
+            ))
+            logger.warning(f"Error loading context manager: {e}")
+
+        create_async_task(autopilot.run_policy(), lambda e: autopilot.continue_sdk.run_step(
+            DisplayErrorStep(e=e)))
         return session

     async def remove_session(self, session_id: str):
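Stepping back from session management to the context-provider changes earlier in this diff: the new `EmbeddingsProvider`, like the file and Google providers, follows the same `ContextProvider` shape — a static base item, `provide_context_items`, and `add_context_item` that appends to `self.selected_items`. Stripped to a skeleton, with a made-up provider name and assuming it lives in `plugins/context_providers/`, that contract looks roughly like:

```python
from typing import List

from ...core.context import ContextProvider
from ...core.main import ContextItem, ContextItemDescription, ContextItemId
from .util import remove_meilisearch_disallowed_chars


class StaticNoteProvider(ContextProvider):     # hypothetical provider
    title = "note"

    NOTE_CONTEXT_ITEM_ID = "note"

    @property
    def BASE_CONTEXT_ITEM(self):
        return ContextItem(
            content="",
            description=ContextItemDescription(
                name="Note",
                description="Attach a short note to the context",
                id=ContextItemId(
                    provider_title=self.title,
                    item_id=self.NOTE_CONTEXT_ITEM_ID
                )
            )
        )

    async def provide_context_items(self) -> List[ContextItem]:
        return [self.BASE_CONTEXT_ITEM]

    async def add_context_item(self, id: ContextItemId, query: str):
        if id.item_id != self.NOTE_CONTEXT_ITEM_ID:
            raise Exception("Invalid item id")

        ctx_item = self.BASE_CONTEXT_ITEM.copy()
        ctx_item.content = query
        # Meilisearch ids may only contain [0-9a-zA-Z_-], hence the shared helper
        ctx_item.description.id.item_id = remove_meilisearch_disallowed_chars(query)
        self.selected_items.append(ctx_item)
```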