From 902ac7f67c2bb7b9e7c17ed15582414db42b5c52 Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 28 Jul 2023 00:41:05 -0300 Subject: make Config receive LLM objects --- continuedev/src/continuedev/core/config.py | 5 +- continuedev/src/continuedev/core/sdk.py | 66 +++------------------- continuedev/src/continuedev/libs/llm/__init__.py | 8 +++ continuedev/src/continuedev/libs/llm/anthropic.py | 10 ++-- continuedev/src/continuedev/libs/llm/ggml.py | 1 + .../src/continuedev/libs/llm/hf_inference_api.py | 7 ++- continuedev/src/continuedev/libs/llm/openai.py | 13 +++-- .../src/continuedev/libs/llm/proxy_server.py | 6 ++ 8 files changed, 43 insertions(+), 73 deletions(-) (limited to 'continuedev/src') diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index cb9c8977..23f4fe65 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -2,6 +2,7 @@ import json import os from .main import Step from .context import ContextProvider +from ..libs.llm.openai import OpenAI from pydantic import BaseModel, validator from typing import List, Literal, Optional, Dict, Type, Union import yaml @@ -38,8 +39,7 @@ class ContinueConfig(BaseModel): steps_on_startup: List[Step] = [] disallowed_steps: Optional[List[str]] = [] allow_anonymous_telemetry: Optional[bool] = True - default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k", - "gpt-4", "claude-2", "ggml"] = 'gpt-4' + llm: LLM = OpenAI(default_model="gpt-4") temperature: Optional[float] = 0.5 custom_commands: Optional[List[CustomCommand]] = [CustomCommand( name="test", @@ -49,7 +49,6 @@ class ContinueConfig(BaseModel): slash_commands: Optional[List[SlashCommand]] = [] on_traceback: Optional[List[OnTracebackSteps]] = [] system_message: Optional[str] = None - azure_openai_info: Optional[AzureInfo] = None context_providers: List[ContextProvider] = [] diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 992bc1cf..92a72b23 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -63,18 +63,8 @@ class Models: self.system_message = sdk.config.system_message @classmethod - async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models": - if sdk.config.default_model == "claude-2": - with_providers.append("anthropic") - - models = Models(sdk, with_providers) - for provider in with_providers: - if provider in MODEL_PROVIDER_TO_ENV_VAR: - env_var = MODEL_PROVIDER_TO_ENV_VAR[provider] - models.provider_keys[provider] = await sdk.get_user_secret( - env_var, f'Please add your {env_var} to the .env file') - - return models + async def create(cls, sdk: "ContinueSDK") -> "Models": + return self.default def __load_openai_model(self, model: str) -> OpenAI: api_key = self.provider_keys["openai"] @@ -90,54 +80,9 @@ class Models: api_key = self.provider_keys["anthropic"] return AnthropicLLM(api_key, model, self.system_message) - @cached_property - def claude2(self): - return self.__load_anthropic_model("claude-2") - - @cached_property - def starcoder(self): - return self.__load_hf_inference_api_model("bigcode/starcoder") - - @cached_property - def gpt35(self): - return self.__load_openai_model("gpt-3.5-turbo") - - @cached_property - def gpt350613(self): - return self.__load_openai_model("gpt-3.5-turbo-0613") - - @cached_property - def gpt3516k(self): - return self.__load_openai_model("gpt-3.5-turbo-16k") - - @cached_property - def gpt4(self): - return 
self.__load_openai_model("gpt-4") - - @cached_property - def ggml(self): - return GGML(system_message=self.system_message) - - def __model_from_name(self, model_name: str): - if model_name == "starcoder": - return self.starcoder - elif model_name == "gpt-3.5-turbo": - return self.gpt35 - elif model_name == "gpt-3.5-turbo-16k": - return self.gpt3516k - elif model_name == "gpt-4": - return self.gpt4 - elif model_name == "claude-2": - return self.claude2 - elif model_name == "ggml": - return self.ggml - else: - raise Exception(f"Unknown model {model_name}") - @property def default(self): - default_model = self.sdk.config.default_model - return self.__model_from_name(default_model) if default_model is not None else self.gpt4 + return sdk.config.llm if sdk.config.llm is not None else ProxyServer(default_model="gpt-4") class ContinueSDK(AbstractContinueSDK): @@ -174,6 +119,7 @@ class ContinueSDK(AbstractContinueSDK): )) sdk.models = await Models.create(sdk) + await sdk.models.start() return sdk @property @@ -252,6 +198,10 @@ class ContinueSDK(AbstractContinueSDK): path = await self._ensure_absolute_path(path) return await self.run_step(FileSystemEditStep(edit=DeleteDirectory(path=path))) + async def get_api_key(self, env_var: str) -> str: + # TODO support error prompt dynamically set on env_var + return await self.ide.getUserSecret(env_var) + async def get_user_secret(self, env_var: str, prompt: str) -> str: return await self.ide.getUserSecret(env_var) diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py index 2766db4b..5641d8a9 100644 --- a/continuedev/src/continuedev/libs/llm/__init__.py +++ b/continuedev/src/continuedev/libs/llm/__init__.py @@ -9,6 +9,14 @@ from pydantic import BaseModel class LLM(ABC): system_message: Union[str, None] = None + async def start(self): + """Start the connection to the LLM.""" + raise NotImplementedError + + async def stop(self): + """Stop the connection to the LLM.""" + raise NotImplementedError + async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: """Return the completion of the text with the given temperature.""" raise NotImplementedError diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py index 625d4e57..0067ce3a 100644 --- a/continuedev/src/continuedev/libs/llm/anthropic.py +++ b/continuedev/src/continuedev/libs/llm/anthropic.py @@ -9,16 +9,18 @@ from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_ class AnthropicLLM(LLM): - api_key: str default_model: str async_client: AsyncAnthropic - def __init__(self, api_key: str, default_model: str, system_message: str = None): - self.api_key = api_key + def __init__(self, default_model: str, system_message: str = None): self.default_model = default_model self.system_message = system_message - self.async_client = AsyncAnthropic(api_key=api_key) + async def start(self): + self.async_client = AsyncAnthropic(api_key=await self.sdk.get_api_key("ANTHROPIC_API_KEY")) + + async def stop(self): + pass @cached_property def name(self): diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py index 4889a556..ef771a2e 100644 --- a/continuedev/src/continuedev/libs/llm/ggml.py +++ b/continuedev/src/continuedev/libs/llm/ggml.py @@ -35,6 +35,7 @@ class GGML(LLM): messages = compile_chat_messages( self.name, with_history, args["max_tokens"], prompt, 
functions=args.get("functions", None), system_message=self.system_message) + # TODO move to single self.session variable (proxy setting etc) async with aiohttp.ClientSession() as session: async with session.post(f"{SERVER_URL}/v1/completions", json={ "messages": messages, diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py index 36f03270..39b54f0f 100644 --- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py +++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py @@ -8,14 +8,15 @@ DEFAULT_MAX_TIME = 120. class HuggingFaceInferenceAPI(LLM): - api_key: str model: str - def __init__(self, api_key: str, model: str, system_message: str = None): - self.api_key = api_key + def __init__(self, model: str, system_message: str = None): self.model = model self.system_message = system_message # TODO: Nothing being done with this + async def start(self): + self.api_key = await self.sdk.get_api_key("HUGGING_FACE_TOKEN")) + def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs): """Return the completion of the text with the given temperature.""" API_URL = f"https://api-inference.huggingface.co/models/{self.model}" diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index a0773c1d..1a48fa86 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -10,24 +10,27 @@ from ...core.config import AzureInfo class OpenAI(LLM): - api_key: str default_model: str - def __init__(self, api_key: str, default_model: str, system_message: str = None, azure_info: AzureInfo = None, write_log: Callable[[str], None] = None): - self.api_key = api_key + def __init__(self, default_model: str, system_message: str = None, azure_info: AzureInfo = None, write_log: Callable[[str], None] = None): self.default_model = default_model self.system_message = system_message self.azure_info = azure_info self.write_log = write_log - openai.api_key = api_key + async def start(self): + self.api_key = await sdk.get_api_key("OPENAI_API_KEY") + openai.api_key = self.api_key # Using an Azure OpenAI deployment - if azure_info is not None: + if self.azure_info is not None: openai.api_type = "azure" openai.api_base = azure_info.endpoint openai.api_version = azure_info.api_version + async def stop(self): + pass + @cached_property def name(self): return self.default_model diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py index f9e3fa01..c0e2a403 100644 --- a/continuedev/src/continuedev/libs/llm/proxy_server.py +++ b/continuedev/src/continuedev/libs/llm/proxy_server.py @@ -30,6 +30,12 @@ class ProxyServer(LLM): self.name = default_model self.write_log = write_log + async def start(self): + pass + + async def stop(self): + pass + @property def default_args(self): return {**DEFAULT_ARGS, "model": self.default_model} -- cgit v1.2.3-70-g09d2 From cde2cc05a75f1ae98d0ef95f8495e52ee3c6f163 Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 28 Jul 2023 22:35:49 -0300 Subject: use large/medium/small role separation for llm config --- continuedev/src/continuedev/core/config.py | 4 +- continuedev/src/continuedev/core/sdk.py | 75 +++++++++++++++--------------- 2 files changed, 40 insertions(+), 39 deletions(-) (limited to 'continuedev/src') diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index 23f4fe65..6957ae5e 100644 
--- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -39,7 +39,9 @@ class ContinueConfig(BaseModel): steps_on_startup: List[Step] = [] disallowed_steps: Optional[List[str]] = [] allow_anonymous_telemetry: Optional[bool] = True - llm: LLM = OpenAI(default_model="gpt-4") + models: Models = Models( + default=MaybeProxyOpenAI(model="gpt-4"), + ) temperature: Optional[float] = 0.5 custom_commands: Optional[List[CustomCommand]] = [CustomCommand( name="test", diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 92a72b23..183518ac 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -8,15 +8,11 @@ from .abstract_sdk import AbstractContinueSDK from .config import ContinueConfig from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFile, AddDirectory, DeleteDirectory from ..models.filesystem import RangeInFile -from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI -from ..libs.llm.openai import OpenAI -from ..libs.llm.anthropic import AnthropicLLM -from ..libs.llm.ggml import GGML +from ..libs.llm import LLM from .observation import Observation from ..server.ide_protocol import AbstractIdeProtocolServer from .main import Context, ContinueCustomException, History, HistoryNode, Step, ChatMessage from ..plugins.steps.core.core import * -from ..libs.llm.proxy_server import ProxyServer from ..libs.util.telemetry import posthog_logger from ..libs.util.paths import getConfigFilePath @@ -25,18 +21,12 @@ class Autopilot: pass -ModelProvider = Literal["openai", "hf_inference_api", "ggml", "anthropic"] -MODEL_PROVIDER_TO_ENV_VAR = { - "openai": "OPENAI_API_KEY", - "hf_inference_api": "HUGGING_FACE_TOKEN", - "anthropic": "ANTHROPIC_API_KEY", -} - - class Models: - provider_keys: Dict[ModelProvider, str] = {} - model_providers: List[ModelProvider] - system_message: str + """Main class that holds the current model configuration""" + default: LLM + small: Optional[LLM] = None + medium: Optional[LLM] = None + large: Optional[LLM] = None """ Better to have sdk.llm.stream_chat(messages, model="claude-2"). 
@@ -57,32 +47,41 @@ class Models: '''depending on the model, return the single prompt string''' """ - def __init__(self, sdk: "ContinueSDK", model_providers: List[ModelProvider]): - self.sdk = sdk - self.model_providers = model_providers + def __init__(self, *, default, small=None, medium=None, large=None, custom=None): + self.default = default + self.small = small + self.medium = medium + self.large = large self.system_message = sdk.config.system_message - @classmethod - async def create(cls, sdk: "ContinueSDK") -> "Models": - return self.default + async def start(sdk: "ContinueSDK"): + self.sdk = sdk + await self.default.start() + if self.small: + await self.small.start() + else: + self.small = self.default - def __load_openai_model(self, model: str) -> OpenAI: - api_key = self.provider_keys["openai"] - if api_key == "": - return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message, write_log=self.sdk.write_log) - return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, azure_info=self.sdk.config.azure_openai_info, write_log=self.sdk.write_log) + if self.medium: + await self.medium.start() + else: + self.medium = self.default - def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI: - api_key = self.provider_keys["hf_inference_api"] - return HuggingFaceInferenceAPI(api_key=api_key, model=model, system_message=self.system_message) + if self.large: + await self.large.start() + else: + self.large = self.default - def __load_anthropic_model(self, model: str) -> AnthropicLLM: - api_key = self.provider_keys["anthropic"] - return AnthropicLLM(api_key, model, self.system_message) + async def stop(sdk: "ContinueSDK"): + await self.default.stop() + if self.small: + await self.small.stop() - @property - def default(self): - return sdk.config.llm if sdk.config.llm is not None else ProxyServer(default_model="gpt-4") + if self.medium: + await self.medium.stop() + + if self.large: + await self.large.stop() class ContinueSDK(AbstractContinueSDK): @@ -118,8 +117,8 @@ class ContinueSDK(AbstractContinueSDK): active=False )) - sdk.models = await Models.create(sdk) - await sdk.models.start() + sdk.models = sdk.config.models + await sdk.models.start(sdk) return sdk @property -- cgit v1.2.3-70-g09d2 From 2b651d2504638ea9db97ba612f702356e38a805e Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 28 Jul 2023 23:50:47 -0300 Subject: make Models fetch LLM secret field declaratively --- continuedev/src/continuedev/core/sdk.py | 14 ++++++++++---- continuedev/src/continuedev/libs/llm/__init__.py | 3 ++- continuedev/src/continuedev/libs/llm/anthropic.py | 5 +++-- continuedev/src/continuedev/libs/llm/hf_inference_api.py | 5 +++-- continuedev/src/continuedev/libs/llm/openai.py | 5 +++-- continuedev/src/continuedev/libs/llm/utils.py | 1 + continuedev/src/continuedev/libs/util/count_tokens.py | 4 ++++ 7 files changed, 26 insertions(+), 11 deletions(-) (limited to 'continuedev/src') diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 183518ac..784f8ed1 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -54,21 +54,27 @@ class Models: self.large = large self.system_message = sdk.config.system_message + async def _start(llm: LLM): + kwargs = {} + if llm.required_api_key: + kwargs["api_key"] = await self.sdk.get_api_secret(llm.required_api_key) + await llm.start(**kwargs) + async def start(sdk: "ContinueSDK"): self.sdk = sdk - await self.default.start() 
+ await self._start(self.default) if self.small: - await self.small.start() + await self._start(self.small) else: self.small = self.default if self.medium: - await self.medium.start() + await self._start(self.medium) else: self.medium = self.default if self.large: - await self.large.start() + await self._start(self.large) else: self.large = self.default diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py index 5641d8a9..6ae3dd46 100644 --- a/continuedev/src/continuedev/libs/llm/__init__.py +++ b/continuedev/src/continuedev/libs/llm/__init__.py @@ -7,9 +7,10 @@ from pydantic import BaseModel class LLM(ABC): + required_api_key: Optional[str] = None system_message: Union[str, None] = None - async def start(self): + async def start(self, *, api_key: Optional[str] = None): """Start the connection to the LLM.""" raise NotImplementedError diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py index 0067ce3a..846a2450 100644 --- a/continuedev/src/continuedev/libs/llm/anthropic.py +++ b/continuedev/src/continuedev/libs/llm/anthropic.py @@ -9,6 +9,7 @@ from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_ class AnthropicLLM(LLM): + required_api_key: str = "ANTHROPIC_API_KEY" default_model: str async_client: AsyncAnthropic @@ -16,8 +17,8 @@ class AnthropicLLM(LLM): self.default_model = default_model self.system_message = system_message - async def start(self): - self.async_client = AsyncAnthropic(api_key=await self.sdk.get_api_key("ANTHROPIC_API_KEY")) + async def start(self, *, api_key): + self.async_client = AsyncAnthropic(api_key=api_key) async def stop(self): pass diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py index 39b54f0f..06d37596 100644 --- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py +++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py @@ -8,14 +8,15 @@ DEFAULT_MAX_TIME = 120. 
class HuggingFaceInferenceAPI(LLM): + required_api_key: str = "HUGGING_FACE_TOKEN" model: str def __init__(self, model: str, system_message: str = None): self.model = model self.system_message = system_message # TODO: Nothing being done with this - async def start(self): - self.api_key = await self.sdk.get_api_key("HUGGING_FACE_TOKEN")) + async def start(self, *, api_key): + self.api_key = api_key def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs): """Return the completion of the text with the given temperature.""" diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index 1a48fa86..c8de90a8 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -10,6 +10,7 @@ from ...core.config import AzureInfo class OpenAI(LLM): + required_api_key = "OPENAI_API_KEY" default_model: str def __init__(self, default_model: str, system_message: str = None, azure_info: AzureInfo = None, write_log: Callable[[str], None] = None): @@ -18,8 +19,8 @@ class OpenAI(LLM): self.azure_info = azure_info self.write_log = write_log - async def start(self): - self.api_key = await sdk.get_api_key("OPENAI_API_KEY") + async def start(self, *, api_key): + self.api_key = api_key openai.api_key = self.api_key # Using an Azure OpenAI deployment diff --git a/continuedev/src/continuedev/libs/llm/utils.py b/continuedev/src/continuedev/libs/llm/utils.py index 76240d4e..4ea45b7b 100644 --- a/continuedev/src/continuedev/libs/llm/utils.py +++ b/continuedev/src/continuedev/libs/llm/utils.py @@ -5,6 +5,7 @@ gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") def count_tokens(text: str) -> int: return len(gpt2_tokenizer.encode(text)) +# TODO move this to LLM class itself (especially as prices may change in the future) prices = { # All prices are per 1k tokens "fine-tune-train": { diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py index c58ae499..f6c7cb00 100644 --- a/continuedev/src/continuedev/libs/util/count_tokens.py +++ b/continuedev/src/continuedev/libs/util/count_tokens.py @@ -4,6 +4,10 @@ from ...core.main import ChatMessage from .templating import render_templated_string import tiktoken +# TODO move many of these into specific LLM.properties() function that +# contains max tokens, if its a chat model or not, default args (not all models +# want to be run at 0.5 temp). 
also lets custom models made for long contexts +# exist here (likg LLongMA) aliases = { "ggml": "gpt-3.5-turbo", "claude-2": "gpt-3.5-turbo", -- cgit v1.2.3-70-g09d2 From 714867f9a0d99548eef30c870b32384454b873ed Mon Sep 17 00:00:00 2001 From: Luna Date: Sat, 29 Jul 2023 15:10:59 -0300 Subject: turn Models and LLM into pydantic-compatible classes required as they're part of the config class --- continuedev/src/continuedev/core/config.py | 9 +--- continuedev/src/continuedev/core/models.py | 67 +++++++++++++++++++++++ continuedev/src/continuedev/core/sdk.py | 68 ------------------------ continuedev/src/continuedev/libs/llm/__init__.py | 6 ++- continuedev/src/continuedev/libs/llm/ggml.py | 2 +- continuedev/src/continuedev/libs/llm/openai.py | 9 +++- 6 files changed, 81 insertions(+), 80 deletions(-) create mode 100644 continuedev/src/continuedev/core/models.py (limited to 'continuedev/src') diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index 6957ae5e..af37264d 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -2,7 +2,8 @@ import json import os from .main import Step from .context import ContextProvider -from ..libs.llm.openai import OpenAI +from ..libs.llm.maybe_proxy_openai import MaybeProxyOpenAI +from .models import Models from pydantic import BaseModel, validator from typing import List, Literal, Optional, Dict, Type, Union import yaml @@ -26,12 +27,6 @@ class OnTracebackSteps(BaseModel): params: Optional[Dict] = {} -class AzureInfo(BaseModel): - endpoint: str - engine: str - api_version: str - - class ContinueConfig(BaseModel): """ A pydantic class for the continue config file. diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py new file mode 100644 index 00000000..c939d504 --- /dev/null +++ b/continuedev/src/continuedev/core/models.py @@ -0,0 +1,67 @@ +from typing import Optional +from pydantic import BaseModel +from ..libs.llm import LLM + + +class Models(BaseModel): + """Main class that holds the current model configuration""" + default: LLM + small: Optional[LLM] = None + medium: Optional[LLM] = None + large: Optional[LLM] = None + + """ + Better to have sdk.llm.stream_chat(messages, model="claude-2"). + Then you also don't care that it' async. + And it's easier to add more models. + And intermediate shared code is easier to add. + And you can make constants like ContinueModels.GPT35 = "gpt-3.5-turbo" + PromptTransformer would be a good concept: You pass a prompt or list of messages and a model, then it outputs the prompt for that model. + Easy to reason about, can place anywhere. + And you can even pass a Prompt object to sdk.llm.stream_chat maybe, and it'll automatically be transformed for the given model. + This can all happen inside of Models? 
+ + class Prompt: + def __init__(self, ...info): + '''take whatever info is needed to describe the prompt''' + + def to_string(self, model: str) -> str: + '''depending on the model, return the single prompt string''' + """ + + async def _start(llm: LLM): + kwargs = {} + if llm.required_api_key: + kwargs["api_key"] = await self.sdk.get_api_secret(llm.required_api_key) + await llm.start(**kwargs) + + async def start(sdk: "ContinueSDK"): + self.sdk = sdk + self.system_message = self.sdk.config.system_message + await self._start(self.default) + if self.small: + await self._start(self.small) + else: + self.small = self.default + + if self.medium: + await self._start(self.medium) + else: + self.medium = self.default + + if self.large: + await self._start(self.large) + else: + self.large = self.default + + async def stop(sdk: "ContinueSDK"): + await self.default.stop() + if self.small: + await self.small.stop() + + if self.medium: + await self.medium.stop() + + if self.large: + await self.large.stop() + diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 784f8ed1..b0f7d40a 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -21,74 +21,6 @@ class Autopilot: pass -class Models: - """Main class that holds the current model configuration""" - default: LLM - small: Optional[LLM] = None - medium: Optional[LLM] = None - large: Optional[LLM] = None - - """ - Better to have sdk.llm.stream_chat(messages, model="claude-2"). - Then you also don't care that it' async. - And it's easier to add more models. - And intermediate shared code is easier to add. - And you can make constants like ContinueModels.GPT35 = "gpt-3.5-turbo" - PromptTransformer would be a good concept: You pass a prompt or list of messages and a model, then it outputs the prompt for that model. - Easy to reason about, can place anywhere. - And you can even pass a Prompt object to sdk.llm.stream_chat maybe, and it'll automatically be transformed for the given model. - This can all happen inside of Models? 
- - class Prompt: - def __init__(self, ...info): - '''take whatever info is needed to describe the prompt''' - - def to_string(self, model: str) -> str: - '''depending on the model, return the single prompt string''' - """ - - def __init__(self, *, default, small=None, medium=None, large=None, custom=None): - self.default = default - self.small = small - self.medium = medium - self.large = large - self.system_message = sdk.config.system_message - - async def _start(llm: LLM): - kwargs = {} - if llm.required_api_key: - kwargs["api_key"] = await self.sdk.get_api_secret(llm.required_api_key) - await llm.start(**kwargs) - - async def start(sdk: "ContinueSDK"): - self.sdk = sdk - await self._start(self.default) - if self.small: - await self._start(self.small) - else: - self.small = self.default - - if self.medium: - await self._start(self.medium) - else: - self.medium = self.default - - if self.large: - await self._start(self.large) - else: - self.large = self.default - - async def stop(sdk: "ContinueSDK"): - await self.default.stop() - if self.small: - await self.small.stop() - - if self.medium: - await self.medium.stop() - - if self.large: - await self.large.stop() - class ContinueSDK(AbstractContinueSDK): """The SDK provided as parameters to a step""" diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py index 6ae3dd46..0f6b1505 100644 --- a/continuedev/src/continuedev/libs/llm/__init__.py +++ b/continuedev/src/continuedev/libs/llm/__init__.py @@ -1,12 +1,14 @@ +import functools from abc import ABC -from typing import Any, Coroutine, Dict, Generator, List, Union +from pydantic import BaseModel, ConfigDict +from typing import Any, Coroutine, Dict, Generator, List, Union, Optional from ...core.main import ChatMessage from ...models.main import AbstractModel from pydantic import BaseModel -class LLM(ABC): +class LLM(BaseModel, ABC): required_api_key: Optional[str] = None system_message: Union[str, None] = None diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py index ef771a2e..52e44bfe 100644 --- a/continuedev/src/continuedev/libs/llm/ggml.py +++ b/continuedev/src/continuedev/libs/llm/ggml.py @@ -1,6 +1,7 @@ from functools import cached_property import json from typing import Any, Coroutine, Dict, Generator, List, Union +from pydantic import ConfigDict import aiohttp from ...core.main import ChatMessage @@ -15,7 +16,6 @@ class GGML(LLM): def __init__(self, system_message: str = None): self.system_message = system_message - @cached_property def name(self): return "ggml" diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index c8de90a8..ef8830a6 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -2,11 +2,17 @@ from functools import cached_property import json from typing import Any, Callable, Coroutine, Dict, Generator, List, Union +from pydantic import BaseModel from ...core.main import ChatMessage import openai from ..llm import LLM from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top -from ...core.config import AzureInfo + + +class AzureInfo(BaseModel): + endpoint: str + engine: str + api_version: str class OpenAI(LLM): @@ -32,7 +38,6 @@ class OpenAI(LLM): async def stop(self): pass - @cached_property def name(self): return self.default_model -- cgit v1.2.3-70-g09d2 From 
89dc1a0d4f8c80ae80532923db4dd3f0469e255e Mon Sep 17 00:00:00 2001 From: Luna Date: Sat, 29 Jul 2023 15:16:25 -0300 Subject: fix missing attributes in Models --- continuedev/src/continuedev/core/models.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) (limited to 'continuedev/src') diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py index c939d504..8b1b1f00 100644 --- a/continuedev/src/continuedev/core/models.py +++ b/continuedev/src/continuedev/core/models.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Any from pydantic import BaseModel from ..libs.llm import LLM @@ -10,6 +10,11 @@ class Models(BaseModel): medium: Optional[LLM] = None large: Optional[LLM] = None + # TODO namespace these away to not confuse readers, + # or split Models into ModelsConfig, which gets turned into Models + sdk: Any = None + system_message: Any = None + """ Better to have sdk.llm.stream_chat(messages, model="claude-2"). Then you also don't care that it' async. @@ -29,13 +34,13 @@ class Models(BaseModel): '''depending on the model, return the single prompt string''' """ - async def _start(llm: LLM): + async def _start(self, llm: LLM): kwargs = {} if llm.required_api_key: - kwargs["api_key"] = await self.sdk.get_api_secret(llm.required_api_key) + kwargs["api_key"] = await self.sdk.get_api_key(llm.required_api_key) await llm.start(**kwargs) - async def start(sdk: "ContinueSDK"): + async def start(self, sdk: "ContinueSDK"): self.sdk = sdk self.system_message = self.sdk.config.system_message await self._start(self.default) @@ -54,7 +59,7 @@ class Models(BaseModel): else: self.large = self.default - async def stop(sdk: "ContinueSDK"): + async def stop(self, sdk: "ContinueSDK"): await self.default.stop() if self.small: await self.small.stop() -- cgit v1.2.3-70-g09d2 From c57182b8533a2c86d465bbf21e3a357bda13bb41 Mon Sep 17 00:00:00 2001 From: Luna Date: Sat, 29 Jul 2023 15:49:10 -0300 Subject: make the MaybeProxy correctly instantiate the real LLM --- continuedev/src/continuedev/core/models.py | 4 ++++ continuedev/src/continuedev/libs/llm/ggml.py | 1 + continuedev/src/continuedev/libs/llm/openai.py | 25 +++++++++++----------- .../src/continuedev/libs/llm/proxy_server.py | 25 +++++++++++----------- 4 files changed, 31 insertions(+), 24 deletions(-) (limited to 'continuedev/src') diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py index 8b1b1f00..ec89d503 100644 --- a/continuedev/src/continuedev/core/models.py +++ b/continuedev/src/continuedev/core/models.py @@ -38,6 +38,10 @@ class Models(BaseModel): kwargs = {} if llm.required_api_key: kwargs["api_key"] = await self.sdk.get_api_key(llm.required_api_key) + if llm.required_unique_id: + kwargs["unique_id"] = self.sdk.ide.unique_id + if llm.required_write_log: + kwargs["write_log"] = self.sdk.write_log await llm.start(**kwargs) async def start(self, sdk: "ContinueSDK"): diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py index 52e44bfe..401709c9 100644 --- a/continuedev/src/continuedev/libs/llm/ggml.py +++ b/continuedev/src/continuedev/libs/llm/ggml.py @@ -16,6 +16,7 @@ class GGML(LLM): def __init__(self, system_message: str = None): self.system_message = system_message + @property def name(self): return "ggml" diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index ef8830a6..5ac4d211 100644 --- 
a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -1,6 +1,6 @@ from functools import cached_property import json -from typing import Any, Callable, Coroutine, Dict, Generator, List, Union +from typing import Any, Callable, Coroutine, Dict, Generator, List, Union, Optional from pydantic import BaseModel from ...core.main import ChatMessage @@ -16,14 +16,13 @@ class AzureInfo(BaseModel): class OpenAI(LLM): + model: str + system_message: Optional[str] = None + azure_info: Optional[AzureInfo] = None + write_log: Optional[Callable[[str], None]] = None + required_api_key = "OPENAI_API_KEY" - default_model: str - - def __init__(self, default_model: str, system_message: str = None, azure_info: AzureInfo = None, write_log: Callable[[str], None] = None): - self.default_model = default_model - self.system_message = system_message - self.azure_info = azure_info - self.write_log = write_log + required_write_log = True async def start(self, *, api_key): self.api_key = api_key @@ -38,18 +37,19 @@ class OpenAI(LLM): async def stop(self): pass + @property def name(self): - return self.default_model + return self.model @property def default_args(self): - args = {**DEFAULT_ARGS, "model": self.default_model} + args = {**DEFAULT_ARGS, "model": self.model} if self.azure_info is not None: args["engine"] = self.azure_info.engine return args def count_tokens(self, text: str): - return count_tokens(self.default_model, text) + return count_tokens(self.model, text) async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: args = self.default_args.copy() @@ -85,7 +85,8 @@ class OpenAI(LLM): args = self.default_args.copy() args.update(kwargs) args["stream"] = True - args["model"] = self.default_model if self.default_model in CHAT_MODELS else "gpt-3.5-turbo-0613" + # TODO what to do here? why should we change to gpt-3.5-turbo-0613 if the user didn't ask for it? 
+ args["model"] = self.model if self.model in CHAT_MODELS else "gpt-3.5-turbo-0613" if not args["model"].endswith("0613") and "functions" in args: del args["functions"] diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py index c0e2a403..2c0e1dc4 100644 --- a/continuedev/src/continuedev/libs/llm/proxy_server.py +++ b/continuedev/src/continuedev/libs/llm/proxy_server.py @@ -1,7 +1,7 @@ import json import traceback -from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union +from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union, Optional import aiohttp from ...core.main import ChatMessage from ..llm import LLM @@ -19,29 +19,30 @@ SERVER_URL = "https://proxy-server-l6vsfbzhba-uw.a.run.app" class ProxyServer(LLM): unique_id: str - name: str - default_model: Literal["gpt-3.5-turbo", "gpt-4"] + model: str + system_message: Optional[str] write_log: Callable[[str], None] - def __init__(self, unique_id: str, default_model: Literal["gpt-3.5-turbo", "gpt-4"], system_message: str = None, write_log: Callable[[str], None] = None): - self.unique_id = unique_id - self.default_model = default_model - self.system_message = system_message - self.name = default_model - self.write_log = write_log + required_unique_id = True + required_write_log = True async def start(self): + # TODO put ClientSession here pass async def stop(self): pass + @property + def name(self): + return self.model + @property def default_args(self): - return {**DEFAULT_ARGS, "model": self.default_model} + return {**DEFAULT_ARGS, "model": self.model} def count_tokens(self, text: str): - return count_tokens(self.default_model, text) + return count_tokens(self.model, text) def get_headers(self): # headers with unique id @@ -103,7 +104,7 @@ class ProxyServer(LLM): async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: args = {**self.default_args, **kwargs} messages = compile_chat_messages( - self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) + self.model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session: -- cgit v1.2.3-70-g09d2 From 11d7f0a9d178b7ae8f913a2ad5e70d623ce4b11e Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Sun, 30 Jul 2023 15:53:53 -0700 Subject: refactor: :construction: refactor so server runs until requesting model --- CONTRIBUTING.md | 1 - continuedev/src/continuedev/core/autopilot.py | 1 - continuedev/src/continuedev/core/main.py | 1 - continuedev/src/continuedev/core/models.py | 33 +++++++++---------- .../libs/constants/default_config.py.txt | 8 ++--- continuedev/src/continuedev/libs/llm/__init__.py | 5 ++- continuedev/src/continuedev/libs/llm/anthropic.py | 27 +++++++-------- continuedev/src/continuedev/libs/llm/ggml.py | 6 ++++ .../src/continuedev/libs/llm/hf_inference_api.py | 6 ++-- .../src/continuedev/libs/llm/maybe_proxy_openai.py | 38 ++++++++++++++++++++++ continuedev/src/continuedev/libs/llm/openai.py | 12 ++++--- .../src/continuedev/libs/llm/proxy_server.py | 27 ++++++++------- 12 files changed, 108 insertions(+), 57 deletions(-) create mode 100644 
continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py (limited to 'continuedev/src') diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a958777f..bf39f22c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -107,7 +107,6 @@ When state is updated on the server, we currently send the entirety of the objec - `history`, a record of previously run Steps. Displayed in order in the sidebar. - `active`, whether the autopilot is currently running a step. Displayed as a loader while step is running. - `user_input_queue`, the queue of user inputs that have not yet been processed due to waiting for previous Steps to complete. Displayed below the `active` loader until popped from the queue. -- `default_model`, the default model used for completions. Displayed as a toggleable button on the bottom of the GUI. - `selected_context_items`, the ranges of code and other items (like GitHub Issues, files, etc...) that have been selected to include as context. Displayed just above the main text input. - `slash_commands`, the list of available slash commands. Displayed in the main text input dropdown. - `adding_highlighted_code`, whether highlighting of new code for context is locked. Displayed as a button adjacent to `highlighted_ranges`. diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 42a58423..beb40c75 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -94,7 +94,6 @@ class Autopilot(ContinueBaseModel): history=self.history, active=self._active, user_input_queue=self._main_user_input_queue, - default_model=self.continue_sdk.config.default_model, slash_commands=self.get_available_slash_commands(), adding_highlighted_code=self.context_manager.context_providers[ "code"].adding_highlighted_code, diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index df9b98ef..2553850f 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -258,7 +258,6 @@ class FullState(ContinueBaseModel): history: History active: bool user_input_queue: List[str] - default_model: str slash_commands: List[SlashCommandDescription] adding_highlighted_code: bool selected_context_items: List[ContextItem] diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py index ec89d503..e4cb8ed6 100644 --- a/continuedev/src/continuedev/core/models.py +++ b/continuedev/src/continuedev/core/models.py @@ -1,5 +1,5 @@ from typing import Optional, Any -from pydantic import BaseModel +from pydantic import BaseModel, validator from ..libs.llm import LLM @@ -12,7 +12,7 @@ class Models(BaseModel): # TODO namespace these away to not confuse readers, # or split Models into ModelsConfig, which gets turned into Models - sdk: Any = None + sdk: "ContinueSDK" = None system_message: Any = None """ @@ -34,43 +34,42 @@ class Models(BaseModel): '''depending on the model, return the single prompt string''' """ - async def _start(self, llm: LLM): + async def _start_llm(self, llm: LLM): kwargs = {} - if llm.required_api_key: - kwargs["api_key"] = await self.sdk.get_api_key(llm.required_api_key) - if llm.required_unique_id: + if llm.requires_api_key: + kwargs["api_key"] = await self.sdk.get_api_key(llm.requires_api_key) + if llm.requires_unique_id: kwargs["unique_id"] = self.sdk.ide.unique_id - if llm.required_write_log: + if llm.requires_write_log: kwargs["write_log"] = self.sdk.write_log await 
llm.start(**kwargs) async def start(self, sdk: "ContinueSDK"): + """Start each of the LLMs, or fall back to default""" self.sdk = sdk self.system_message = self.sdk.config.system_message - await self._start(self.default) + await self._start_llm(self.default) if self.small: - await self._start(self.small) + await self._start_llm(self.small) else: self.small = self.default if self.medium: - await self._start(self.medium) + await self._start_llm(self.medium) else: self.medium = self.default if self.large: - await self._start(self.large) + await self._start_llm(self.large) else: self.large = self.default async def stop(self, sdk: "ContinueSDK"): + """Stop each LLM (if it's not the default, which is shared)""" await self.default.stop() - if self.small: + if self.small is not self.default: await self.small.stop() - - if self.medium: + if self.medium is not self.default: await self.medium.stop() - - if self.large: + if self.large is not self.default: await self.large.stop() - diff --git a/continuedev/src/continuedev/libs/constants/default_config.py.txt b/continuedev/src/continuedev/libs/constants/default_config.py.txt index f80a9ff0..5708747f 100644 --- a/continuedev/src/continuedev/libs/constants/default_config.py.txt +++ b/continuedev/src/continuedev/libs/constants/default_config.py.txt @@ -12,7 +12,7 @@ from continuedev.src.continuedev.core.sdk import ContinueSDK from continuedev.src.continuedev.core.config import CustomCommand, SlashCommand, ContinueConfig from continuedev.src.continuedev.plugins.context_providers.github import GitHubIssuesContextProvider from continuedev.src.continuedev.plugins.context_providers.google import GoogleContextProvider - +from continuedev.src.continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI class CommitMessageStep(Step): """ @@ -41,9 +41,9 @@ config = ContinueConfig( # See here to learn what anonymous data we collect: https://continue.dev/docs/telemetry allow_anonymous_telemetry=True, - # GPT-4 is recommended for best results - # See options here: https://continue.dev/docs/customization#change-the-default-llm - default_model="gpt-4", + models=Models( + default=MaybeProxyOpenAI("gpt4") + ) # Set a system message with information that the LLM should always keep in mind # E.g. "Please give concise answers. Always respond in Spanish." 
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py index 0f6b1505..21afc338 100644 --- a/continuedev/src/continuedev/libs/llm/__init__.py +++ b/continuedev/src/continuedev/libs/llm/__init__.py @@ -9,7 +9,10 @@ from pydantic import BaseModel class LLM(BaseModel, ABC): - required_api_key: Optional[str] = None + requires_api_key: Optional[str] = None + requires_unique_id: bool = False + requires_write_log: bool = False + system_message: Union[str, None] = None async def start(self, *, api_key: Optional[str] = None): diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py index 846a2450..067a903b 100644 --- a/continuedev/src/continuedev/libs/llm/anthropic.py +++ b/continuedev/src/continuedev/libs/llm/anthropic.py @@ -9,27 +9,28 @@ from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_ class AnthropicLLM(LLM): - required_api_key: str = "ANTHROPIC_API_KEY" - default_model: str - async_client: AsyncAnthropic + model: str - def __init__(self, default_model: str, system_message: str = None): - self.default_model = default_model + requires_api_key: str = "ANTHROPIC_API_KEY" + _async_client: AsyncAnthropic = None + + def __init__(self, model: str, system_message: str = None): + self.model = model self.system_message = system_message - async def start(self, *, api_key): - self.async_client = AsyncAnthropic(api_key=api_key) + async def start(self, *, api_key: str): + self._async_client = AsyncAnthropic(api_key=api_key) async def stop(self): pass @cached_property def name(self): - return self.default_model + return self.model @property def default_args(self): - return {**DEFAULT_ARGS, "model": self.default_model} + return {**DEFAULT_ARGS, "model": self.model} def _transform_args(self, args: Dict[str, Any]) -> Dict[str, Any]: args = args.copy() @@ -43,7 +44,7 @@ class AnthropicLLM(LLM): return args def count_tokens(self, text: str): - return count_tokens(self.default_model, text) + return count_tokens(self.model, text) def __messages_to_prompt(self, messages: List[Dict[str, str]]) -> str: prompt = "" @@ -63,7 +64,7 @@ class AnthropicLLM(LLM): args["stream"] = True args = self._transform_args(args) - async for chunk in await self.async_client.completions.create( + async for chunk in await self._async_client.completions.create( prompt=f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}", **args ): @@ -77,7 +78,7 @@ class AnthropicLLM(LLM): messages = compile_chat_messages( args["model"], messages, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message) - async for chunk in await self.async_client.completions.create( + async for chunk in await self._async_client.completions.create( prompt=self.__messages_to_prompt(messages), **args ): @@ -92,7 +93,7 @@ class AnthropicLLM(LLM): messages = compile_chat_messages( args["model"], with_history, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message) - resp = (await self.async_client.completions.create( + resp = (await self._async_client.completions.create( prompt=self.__messages_to_prompt(messages), **args )).completion diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py index 401709c9..4bcf7e54 100644 --- a/continuedev/src/continuedev/libs/llm/ggml.py +++ b/continuedev/src/continuedev/libs/llm/ggml.py @@ -16,6 +16,12 @@ class GGML(LLM): def __init__(self, system_message: str = None): 
self.system_message = system_message + async def start(self, **kwargs): + pass + + async def stop(self): + pass + @property def name(self): return "ggml" diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py index 06d37596..4ad32e0e 100644 --- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py +++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py @@ -8,14 +8,16 @@ DEFAULT_MAX_TIME = 120. class HuggingFaceInferenceAPI(LLM): - required_api_key: str = "HUGGING_FACE_TOKEN" model: str + requires_api_key: str = "HUGGING_FACE_TOKEN" + api_key: str = None + def __init__(self, model: str, system_message: str = None): self.model = model self.system_message = system_message # TODO: Nothing being done with this - async def start(self, *, api_key): + async def start(self, *, api_key: str): self.api_key = api_key def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs): diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py new file mode 100644 index 00000000..d2898b5c --- /dev/null +++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py @@ -0,0 +1,38 @@ +from typing import Any, Coroutine, Dict, Generator, List, Union, Optional + +from ...core.main import ChatMessage +from . import LLM +from .proxy_server import ProxyServer +from .openai import OpenAI + + +class MaybeProxyOpenAI(LLM): + model: str + + requires_api_key: Optional[str] = "OPENAI_API_KEY" + requires_write_log: bool = True + system_message: Union[str, None] = None + + llm: Optional[LLM] = None + + async def start(self, *, api_key: Optional[str] = None, **kwargs): + if api_key is None or api_key.strip() == "": + self.llm = ProxyServer( + unique_id="", model=self.model, write_log=kwargs["write_log"]) + else: + self.llm = OpenAI(model=self.model, write_log=kwargs["write_log"]) + + async def stop(self): + await self.llm.stop() + + async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: + return await self.llm.complete(prompt, with_history=with_history, **kwargs) + + def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: + return self.llm.stream_complete(prompt, with_history=with_history, **kwargs) + + async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: + return self.llm.stream_chat(messages=messages, **kwargs) + + def count_tokens(self, text: str): + return self.llm.count_tokens(text) diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index 5ac4d211..0c2c360b 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -17,12 +17,14 @@ class AzureInfo(BaseModel): class OpenAI(LLM): model: str + + requires_api_key = "OPENAI_API_KEY" + requires_write_log = True + system_message: Optional[str] = None azure_info: Optional[AzureInfo] = None write_log: Optional[Callable[[str], None]] = None - - required_api_key = "OPENAI_API_KEY" - required_write_log = True + api_key: str = None async def start(self, *, api_key): self.api_key = api_key @@ -31,8 +33,8 @@ class OpenAI(LLM): # Using an Azure OpenAI deployment if self.azure_info is not None: openai.api_type = "azure" - openai.api_base = azure_info.endpoint - openai.api_version = 
azure_info.api_version + openai.api_base = self.azure_info.endpoint + openai.api_version = self.azure_info.api_version async def stop(self): pass diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py index 2c0e1dc4..e8f1cb46 100644 --- a/continuedev/src/continuedev/libs/llm/proxy_server.py +++ b/continuedev/src/continuedev/libs/llm/proxy_server.py @@ -1,4 +1,3 @@ - import json import traceback from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union, Optional @@ -18,20 +17,24 @@ SERVER_URL = "https://proxy-server-l6vsfbzhba-uw.a.run.app" class ProxyServer(LLM): - unique_id: str model: str system_message: Optional[str] - write_log: Callable[[str], None] - required_unique_id = True - required_write_log = True + unique_id: str = None + write_log: Callable[[str], None] = None + _client_session: aiohttp.ClientSession + + requires_unique_id = True + requires_write_log = True - async def start(self): - # TODO put ClientSession here - pass + async def start(self, **kwargs): + self._client_session = aiohttp.ClientSession( + connector=aiohttp.TCPConnector(ssl_context=ssl_context)) + self.write_log = kwargs["write_log"] + self.unique_id = kwargs["unique_id"] async def stop(self): - pass + await self._client_session.close() @property def name(self): @@ -54,7 +57,7 @@ class ProxyServer(LLM): messages = compile_chat_messages( args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message) self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") - async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session: + async with self._client_session as session: async with session.post(f"{SERVER_URL}/complete", json={ "messages": messages, **args @@ -72,7 +75,7 @@ class ProxyServer(LLM): args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") - async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session: + async with self._client_session as session: async with session.post(f"{SERVER_URL}/stream_chat", json={ "messages": messages, **args @@ -107,7 +110,7 @@ class ProxyServer(LLM): self.model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") - async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session: + async with self._client_session as session: async with session.post(f"{SERVER_URL}/stream_complete", json={ "messages": messages, **args -- cgit v1.2.3-70-g09d2 From e37996002d848fc71c82375199dc9a704f2c9b05 Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Sun, 30 Jul 2023 16:22:16 -0700 Subject: refactor: :construction: replace all sdk.models.gpt35/etc. with sdk.models.medium/etc. 
--- continuedev/src/continuedev/core/models.py | 18 ++++-------------- continuedev/src/continuedev/core/sdk.py | 12 +++++++++++- .../continuedev/libs/constants/default_config.py.txt | 5 ++++- continuedev/src/continuedev/libs/llm/__init__.py | 9 +++++++-- continuedev/src/continuedev/libs/llm/anthropic.py | 4 ++-- .../src/continuedev/libs/llm/hf_inference_api.py | 4 ++-- .../src/continuedev/libs/llm/maybe_proxy_openai.py | 17 ++++++++++++++--- continuedev/src/continuedev/libs/llm/openai.py | 2 +- continuedev/src/continuedev/libs/llm/proxy_server.py | 2 +- .../plugins/recipes/CreatePipelineRecipe/steps.py | 8 ++++---- .../plugins/recipes/WritePytestsRecipe/main.py | 2 +- continuedev/src/continuedev/plugins/steps/chat.py | 8 ++++++-- continuedev/src/continuedev/plugins/steps/chroma.py | 2 +- continuedev/src/continuedev/plugins/steps/core/core.py | 7 ++++--- .../src/continuedev/plugins/steps/draft/migration.py | 2 +- continuedev/src/continuedev/plugins/steps/help.py | 2 +- .../continuedev/plugins/steps/input/nl_multiselect.py | 2 +- continuedev/src/continuedev/plugins/steps/main.py | 2 +- continuedev/src/continuedev/plugins/steps/react.py | 2 +- .../src/continuedev/plugins/steps/search_directory.py | 2 +- 20 files changed, 68 insertions(+), 44 deletions(-) (limited to 'continuedev/src') diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py index e4cb8ed6..900762b6 100644 --- a/continuedev/src/continuedev/core/models.py +++ b/continuedev/src/continuedev/core/models.py @@ -34,33 +34,23 @@ class Models(BaseModel): '''depending on the model, return the single prompt string''' """ - async def _start_llm(self, llm: LLM): - kwargs = {} - if llm.requires_api_key: - kwargs["api_key"] = await self.sdk.get_api_key(llm.requires_api_key) - if llm.requires_unique_id: - kwargs["unique_id"] = self.sdk.ide.unique_id - if llm.requires_write_log: - kwargs["write_log"] = self.sdk.write_log - await llm.start(**kwargs) - async def start(self, sdk: "ContinueSDK"): """Start each of the LLMs, or fall back to default""" self.sdk = sdk self.system_message = self.sdk.config.system_message - await self._start_llm(self.default) + await sdk.start_model(self.default) if self.small: - await self._start_llm(self.small) + await sdk.start_model(self.small) else: self.small = self.default if self.medium: - await self._start_llm(self.medium) + await sdk.start_model(self.medium) else: self.medium = self.default if self.large: - await self._start_llm(self.large) + await sdk.start_model(self.large) else: self.large = self.default diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index b0f7d40a..7febb932 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -15,13 +15,13 @@ from .main import Context, ContinueCustomException, History, HistoryNode, Step, from ..plugins.steps.core.core import * from ..libs.util.telemetry import posthog_logger from ..libs.util.paths import getConfigFilePath +from .models import Models class Autopilot: pass - class ContinueSDK(AbstractContinueSDK): """The SDK provided as parameters to a step""" ide: AbstractIdeProtocolServer @@ -66,6 +66,16 @@ class ContinueSDK(AbstractContinueSDK): def write_log(self, message: str): self.history.timeline[self.history.current_index].logs.append(message) + async def start_model(self, llm: LLM): + kwargs = {} + if llm.requires_api_key: + kwargs["api_key"] = await self.get_api_key(llm.requires_api_key) + if llm.requires_unique_id: + 
kwargs["unique_id"] = self.ide.unique_id + if llm.requires_write_log: + kwargs["write_log"] = self.write_log + await llm.start(**kwargs) + async def _ensure_absolute_path(self, path: str) -> str: if os.path.isabs(path): return path diff --git a/continuedev/src/continuedev/libs/constants/default_config.py.txt b/continuedev/src/continuedev/libs/constants/default_config.py.txt index 5708747f..7cd2226a 100644 --- a/continuedev/src/continuedev/libs/constants/default_config.py.txt +++ b/continuedev/src/continuedev/libs/constants/default_config.py.txt @@ -31,7 +31,10 @@ class CommitMessageStep(Step): # Ask gpt-3.5-16k to write a commit message, # and set it as the description of this step - self.description = await sdk.models.gpt3516k.complete( + gpt3516k = OpenAI(model="gpt-3.5-turbo-0613") + await sdk.start_model(gpt3516k) + + self.description = await gpt3516k.complete( f"{diff}\n\nWrite a short, specific (less than 50 chars) commit message about the above changes:") diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py index 21afc338..58572634 100644 --- a/continuedev/src/continuedev/libs/llm/__init__.py +++ b/continuedev/src/continuedev/libs/llm/__init__.py @@ -1,5 +1,5 @@ import functools -from abc import ABC +from abc import ABC, abstractproperty from pydantic import BaseModel, ConfigDict from typing import Any, Coroutine, Dict, Generator, List, Union, Optional @@ -15,7 +15,12 @@ class LLM(BaseModel, ABC): system_message: Union[str, None] = None - async def start(self, *, api_key: Optional[str] = None): + @abstractproperty + def name(self): + """Return the name of the LLM.""" + raise NotImplementedError + + async def start(self, *, api_key: Optional[str] = None, **kwargs): """Start the connection to the LLM.""" raise NotImplementedError diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py index 067a903b..c9c8e9db 100644 --- a/continuedev/src/continuedev/libs/llm/anthropic.py +++ b/continuedev/src/continuedev/libs/llm/anthropic.py @@ -1,7 +1,7 @@ from functools import cached_property import time -from typing import Any, Coroutine, Dict, Generator, List, Union +from typing import Any, Coroutine, Dict, Generator, List, Optional, Union from ...core.main import ChatMessage from anthropic import HUMAN_PROMPT, AI_PROMPT, AsyncAnthropic from ..llm import LLM @@ -18,7 +18,7 @@ class AnthropicLLM(LLM): self.model = model self.system_message = system_message - async def start(self, *, api_key: str): + async def start(self, *, api_key: Optional[str] = None, **kwargs): self._async_client = AsyncAnthropic(api_key=api_key) async def stop(self): diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py index 4ad32e0e..49f593d8 100644 --- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py +++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional from ...core.main import ChatMessage from ..llm import LLM import requests @@ -17,7 +17,7 @@ class HuggingFaceInferenceAPI(LLM): self.model = model self.system_message = system_message # TODO: Nothing being done with this - async def start(self, *, api_key: str): + async def start(self, *, api_key: Optional[str] = None, **kwargs): self.api_key = api_key def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs): diff --git 
a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py index d2898b5c..121ae99e 100644 --- a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py +++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py @@ -15,6 +15,10 @@ class MaybeProxyOpenAI(LLM): llm: Optional[LLM] = None + @property + def name(self): + return self.llm.name + async def start(self, *, api_key: Optional[str] = None, **kwargs): if api_key is None or api_key.strip() == "": self.llm = ProxyServer( @@ -22,17 +26,24 @@ class MaybeProxyOpenAI(LLM): else: self.llm = OpenAI(model=self.model, write_log=kwargs["write_log"]) + await self.llm.start(api_key=api_key, **kwargs) + async def stop(self): await self.llm.stop() async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: return await self.llm.complete(prompt, with_history=with_history, **kwargs) - def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: - return self.llm.stream_complete(prompt, with_history=with_history, **kwargs) + async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: + resp = self.llm.stream_complete( + prompt, with_history=with_history, **kwargs) + async for item in resp: + yield item async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: - return self.llm.stream_chat(messages=messages, **kwargs) + resp = self.llm.stream_chat(messages=messages, **kwargs) + async for item in resp: + yield item def count_tokens(self, text: str): return self.llm.count_tokens(text) diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index 0c2c360b..de02a614 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -26,7 +26,7 @@ class OpenAI(LLM): write_log: Optional[Callable[[str], None]] = None api_key: str = None - async def start(self, *, api_key): + async def start(self, *, api_key: Optional[str] = None, **kwargs): self.api_key = api_key openai.api_key = self.api_key diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py index e8f1cb46..1c942523 100644 --- a/continuedev/src/continuedev/libs/llm/proxy_server.py +++ b/continuedev/src/continuedev/libs/llm/proxy_server.py @@ -27,7 +27,7 @@ class ProxyServer(LLM): requires_unique_id = True requires_write_log = True - async def start(self, **kwargs): + async def start(self, *, api_key: Optional[str] = None, **kwargs): self._client_session = aiohttp.ClientSession( connector=aiohttp.TCPConnector(ssl_context=ssl_context)) self.write_log = kwargs["write_log"] diff --git a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py index 433e309e..872f8d62 100644 --- a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py +++ b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py @@ -27,7 +27,7 @@ class SetupPipelineStep(Step): async def run(self, sdk: ContinueSDK): sdk.context.set("api_description", self.api_description) - source_name = (await sdk.models.gpt35.complete( + source_name = (await sdk.models.medium.complete( f"Write a snake_case name for the data 
source described by {self.api_description}: ")).strip() filename = f'{source_name}.py' @@ -89,7 +89,7 @@ class ValidatePipelineStep(Step): if "Traceback" in output or "SyntaxError" in output: output = "Traceback" + output.split("Traceback")[-1] file_content = await sdk.ide.readFile(os.path.join(workspace_dir, filename)) - suggestion = await sdk.models.gpt35.complete(dedent(f"""\ + suggestion = await sdk.models.medium.complete(dedent(f"""\ ```python {file_content} ``` @@ -101,7 +101,7 @@ class ValidatePipelineStep(Step): This is a brief summary of the error followed by a suggestion on how it can be fixed by editing the resource function:""")) - api_documentation_url = await sdk.models.gpt35.complete(dedent(f"""\ + api_documentation_url = await sdk.models.medium.complete(dedent(f"""\ The API I am trying to call is the '{sdk.context.get('api_description')}'. I tried calling it in the @resource function like this: ```python {file_content} @@ -151,7 +151,7 @@ class RunQueryStep(Step): output = await sdk.run('.env/bin/python3 query.py', name="Run test query", description="Running `.env/bin/python3 query.py` to test that the data was loaded into DuckDB as expected", handle_error=False) if "Traceback" in output or "SyntaxError" in output: - suggestion = await sdk.models.gpt35.complete(dedent(f"""\ + suggestion = await sdk.models.medium.complete(dedent(f"""\ ```python {await sdk.ide.readFile(os.path.join(sdk.ide.workspace_directory, "query.py"))} ``` diff --git a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py index 6ef5ffd6..c66cd629 100644 --- a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py +++ b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py @@ -42,7 +42,7 @@ class WritePytestsRecipe(Step): "{self.user_input}" Here is a complete set of pytest unit tests:""") - tests = await sdk.models.gpt35.complete(prompt) + tests = await sdk.models.medium.complete(prompt) await sdk.apply_filesystem_edit(AddFile(filepath=path, content=tests)) diff --git a/continuedev/src/continuedev/plugins/steps/chat.py b/continuedev/src/continuedev/plugins/steps/chat.py index 2c662459..0a0fbca2 100644 --- a/continuedev/src/continuedev/plugins/steps/chat.py +++ b/continuedev/src/continuedev/plugins/steps/chat.py @@ -9,6 +9,7 @@ from .core.core import MessageStep from ...core.main import FunctionCall, Models from ...core.main import ChatMessage, Step, step_to_json_schema from ...core.sdk import ContinueSDK +from ...libs.llm.openai import OpenAI import openai import os from dotenv import load_dotenv @@ -43,7 +44,7 @@ class SimpleChatStep(Step): completion += chunk["content"] await sdk.update_ui() finally: - self.name = remove_quotes_and_escapes(await sdk.models.gpt35.complete( + self.name = remove_quotes_and_escapes(await sdk.models.medium.complete( f"Write a short title for the following chat message: {self.description}")) self.chat_context.append(ChatMessage( @@ -168,7 +169,10 @@ class ChatWithFunctions(Step): msg_content = "" msg_step = None - async for msg_chunk in sdk.models.gpt350613.stream_chat(await sdk.get_chat_context(), functions=functions): + gpt350613 = OpenAI(model="gpt-3.5-turbo-0613") + await sdk.start_model(gpt350613) + + async for msg_chunk in gpt350613.stream_chat(await sdk.get_chat_context(), functions=functions): if sdk.current_step_was_deleted(): return diff --git a/continuedev/src/continuedev/plugins/steps/chroma.py 
b/continuedev/src/continuedev/plugins/steps/chroma.py index dbe8363e..658cc7f3 100644 --- a/continuedev/src/continuedev/plugins/steps/chroma.py +++ b/continuedev/src/continuedev/plugins/steps/chroma.py @@ -56,7 +56,7 @@ class AnswerQuestionChroma(Step): Here is the answer:""") - answer = await sdk.models.gpt35.complete(prompt) + answer = await sdk.models.medium.complete(prompt) # Make paths relative to the workspace directory answer = answer.replace(await sdk.ide.getWorkspaceDirectory(), "") diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py index 5a81e5ee..b9f27fe5 100644 --- a/continuedev/src/continuedev/plugins/steps/core/core.py +++ b/continuedev/src/continuedev/plugins/steps/core/core.py @@ -7,7 +7,7 @@ from typing import Coroutine, List, Literal, Union from ....libs.llm.ggml import GGML from ....models.main import Range -from ....libs.llm.prompt_utils import MarkdownStyleEncoderDecoder +from ....libs.llm.maybe_proxy_openai import MaybeProxyOpenAI from ....models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullContents, FileSystemEdit from ....models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents from ....core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation @@ -84,7 +84,7 @@ class ShellCommandsStep(Step): for cmd in self.cmds: output = await sdk.ide.runCommand(cmd) if self.handle_error and output is not None and output_contains_error(output): - suggestion = await sdk.models.gpt35.complete(dedent(f"""\ + suggestion = await sdk.models.medium.complete(dedent(f"""\ While running the command `{cmd}`, the following error occurred: ```ascii @@ -202,7 +202,8 @@ class DefaultModelEditCodeStep(Step): # If using 3.5 and overflows, upgrade to 3.5.16k if model_to_use.name == "gpt-3.5-turbo": if total_tokens > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]: - model_to_use = sdk.models.gpt3516k + model_to_use = MaybeProxyOpenAI(model="gpt-3.5-turbo-0613") + await sdk.start_model(model_to_use) # Remove tokens from the end first, and then the start to clear space # This part finds the start and end lines diff --git a/continuedev/src/continuedev/plugins/steps/draft/migration.py b/continuedev/src/continuedev/plugins/steps/draft/migration.py index a76d491b..c38f54dc 100644 --- a/continuedev/src/continuedev/plugins/steps/draft/migration.py +++ b/continuedev/src/continuedev/plugins/steps/draft/migration.py @@ -13,7 +13,7 @@ class MigrationStep(Step): recent_edits = await sdk.ide.get_recent_edits(self.edited_file) recent_edits_string = "\n\n".join( map(lambda x: x.to_string(), recent_edits)) - description = await sdk.models.gpt35.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n") + description = await sdk.models.medium.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n") await sdk.run([ "cd libs", "poetry run alembic revision --autogenerate -m " + description, diff --git a/continuedev/src/continuedev/plugins/steps/help.py b/continuedev/src/continuedev/plugins/steps/help.py index d3807706..4d75af30 100644 --- a/continuedev/src/continuedev/plugins/steps/help.py +++ b/continuedev/src/continuedev/plugins/steps/help.py @@ -49,7 +49,7 @@ class HelpStep(Step): summary="Help" )) messages = await sdk.get_chat_context() - generator = sdk.models.gpt4.stream_chat(messages) + generator = sdk.models.default.stream_chat(messages) async for chunk in generator: if 
"content" in chunk: self.description += chunk["content"] diff --git a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py index b54d394a..3d8d96fb 100644 --- a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py +++ b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py @@ -23,6 +23,6 @@ class NLMultiselectStep(Step): if first_try is not None: return first_try - gpt_parsed = await sdk.models.gpt35.complete( + gpt_parsed = await sdk.models.default.complete( f"These are the available options are: [{', '.join(self.options)}]. The user requested {user_response}. This is the exact string from the options array that they selected:") return extract_option(gpt_parsed) or self.options[0] diff --git a/continuedev/src/continuedev/plugins/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py index a8752df2..26c1cabd 100644 --- a/continuedev/src/continuedev/plugins/steps/main.py +++ b/continuedev/src/continuedev/plugins/steps/main.py @@ -100,7 +100,7 @@ class FasterEditHighlightedCodeStep(Step): for rif in range_in_files: rif_dict[rif.filepath] = rif.contents - completion = await sdk.models.gpt35.complete(prompt) + completion = await sdk.models.medium.complete(prompt) # Temporarily doing this to generate description. self._prompt = prompt diff --git a/continuedev/src/continuedev/plugins/steps/react.py b/continuedev/src/continuedev/plugins/steps/react.py index 8b2e7c2e..da6acdbf 100644 --- a/continuedev/src/continuedev/plugins/steps/react.py +++ b/continuedev/src/continuedev/plugins/steps/react.py @@ -27,7 +27,7 @@ class NLDecisionStep(Step): Select the step which should be taken next to satisfy the user input. Say only the name of the selected step. 
You must choose one:""") - resp = (await sdk.models.gpt35.complete(prompt)).lower() + resp = (await sdk.models.medium.complete(prompt)).lower() step_to_run = None for step in self.steps: diff --git a/continuedev/src/continuedev/plugins/steps/search_directory.py b/continuedev/src/continuedev/plugins/steps/search_directory.py index 7d02d6fa..c13047d6 100644 --- a/continuedev/src/continuedev/plugins/steps/search_directory.py +++ b/continuedev/src/continuedev/plugins/steps/search_directory.py @@ -42,7 +42,7 @@ class WriteRegexPatternStep(Step): async def run(self, sdk: ContinueSDK): # Ask the user for a regex pattern - pattern = await sdk.models.gpt35.complete(dedent(f"""\ + pattern = await sdk.models.medium.complete(dedent(f"""\ This is the user request: {self.user_request} -- cgit v1.2.3-70-g09d2 From 6ead69fb71ea01e8b0ab6964d17c5dd058244883 Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Sun, 30 Jul 2023 16:26:31 -0700 Subject: refactor: :zap: turn GGML's aiohttp.AsyncSession into an instance attribute --- .../src/continuedev/libs/constants/default_config.py.txt | 3 ++- continuedev/src/continuedev/libs/llm/ggml.py | 10 ++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) (limited to 'continuedev/src') diff --git a/continuedev/src/continuedev/libs/constants/default_config.py.txt b/continuedev/src/continuedev/libs/constants/default_config.py.txt index 7cd2226a..7c7f495e 100644 --- a/continuedev/src/continuedev/libs/constants/default_config.py.txt +++ b/continuedev/src/continuedev/libs/constants/default_config.py.txt @@ -45,7 +45,8 @@ config = ContinueConfig( allow_anonymous_telemetry=True, models=Models( - default=MaybeProxyOpenAI("gpt4") + default=MaybeProxyOpenAI("gpt4"), + medium=MaybeProxyOpenAI("gpt-3.5-turbo") ) # Set a system message with information that the LLM should always keep in mind diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py index 4bcf7e54..990f35bc 100644 --- a/continuedev/src/continuedev/libs/llm/ggml.py +++ b/continuedev/src/continuedev/libs/llm/ggml.py @@ -13,11 +13,13 @@ SERVER_URL = "http://localhost:8000" class GGML(LLM): + _client_session: aiohttp.ClientSession + def __init__(self, system_message: str = None): self.system_message = system_message async def start(self, **kwargs): - pass + self._client_session = aiohttp.ClientSession() async def stop(self): pass @@ -43,7 +45,7 @@ class GGML(LLM): self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) # TODO move to single self.session variable (proxy setting etc) - async with aiohttp.ClientSession() as session: + async with self._client_session as session: async with session.post(f"{SERVER_URL}/v1/completions", json={ "messages": messages, **args @@ -61,7 +63,7 @@ class GGML(LLM): self.name, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) args["stream"] = True - async with aiohttp.ClientSession() as session: + async with self._client_session as session: async with session.post(f"{SERVER_URL}/v1/chat/completions", json={ "messages": messages, **args @@ -83,7 +85,7 @@ class GGML(LLM): async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: args = {**self.default_args, **kwargs} - async with aiohttp.ClientSession() as session: + async with self._client_session as session: async with session.post(f"{SERVER_URL}/v1/completions", json={ "messages": 
compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message), **args -- cgit v1.2.3-70-g09d2 From 798e94f62b2c64762e2e6f79645e9334013ac7a8 Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Sun, 30 Jul 2023 16:30:06 -0700 Subject: refactor: :construction: rename get_api_key -> get_user_secret --- continuedev/src/continuedev/core/abstract_sdk.py | 2 +- continuedev/src/continuedev/core/sdk.py | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) (limited to 'continuedev/src') diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py index 94d7be10..e048f877 100644 --- a/continuedev/src/continuedev/core/abstract_sdk.py +++ b/continuedev/src/continuedev/core/abstract_sdk.py @@ -73,7 +73,7 @@ class AbstractContinueSDK(ABC): pass @abstractmethod - async def get_user_secret(self, env_var: str, prompt: str) -> str: + async def get_user_secret(self, env_var: str) -> str: pass config: ContinueConfig diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 7febb932..1dd4b857 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -69,7 +69,7 @@ class ContinueSDK(AbstractContinueSDK): async def start_model(self, llm: LLM): kwargs = {} if llm.requires_api_key: - kwargs["api_key"] = await self.get_api_key(llm.requires_api_key) + kwargs["api_key"] = await self.get_user_secret(llm.requires_api_key) if llm.requires_unique_id: kwargs["unique_id"] = self.ide.unique_id if llm.requires_write_log: @@ -145,13 +145,10 @@ class ContinueSDK(AbstractContinueSDK): path = await self._ensure_absolute_path(path) return await self.run_step(FileSystemEditStep(edit=DeleteDirectory(path=path))) - async def get_api_key(self, env_var: str) -> str: + async def get_user_secret(self, env_var: str) -> str: # TODO support error prompt dynamically set on env_var return await self.ide.getUserSecret(env_var) - async def get_user_secret(self, env_var: str, prompt: str) -> str: - return await self.ide.getUserSecret(env_var) - _last_valid_config: ContinueConfig = None def _load_config_dot_py(self) -> ContinueConfig: -- cgit v1.2.3-70-g09d2 From 39076efbd74106ad59ad65e31d52b8d591c1d485 Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Sun, 30 Jul 2023 17:47:17 -0700 Subject: refactor: :recycle: clean up LLM-specific constants from util files --- continuedev/src/continuedev/core/context.py | 6 -- continuedev/src/continuedev/libs/llm/__init__.py | 5 ++ continuedev/src/continuedev/libs/llm/anthropic.py | 12 +++- continuedev/src/continuedev/libs/llm/ggml.py | 10 ++- .../src/continuedev/libs/llm/maybe_proxy_openai.py | 4 ++ continuedev/src/continuedev/libs/llm/openai.py | 24 +++++-- .../src/continuedev/libs/llm/proxy_server.py | 17 ++++- continuedev/src/continuedev/libs/llm/utils.py | 35 --------- .../src/continuedev/libs/util/count_tokens.py | 82 ++++++++++------------ .../src/continuedev/plugins/steps/core/core.py | 15 ++-- 10 files changed, 102 insertions(+), 108 deletions(-) delete mode 100644 continuedev/src/continuedev/libs/llm/utils.py (limited to 'continuedev/src') diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py index f81fa57a..8afbd610 100644 --- a/continuedev/src/continuedev/core/context.py +++ b/continuedev/src/continuedev/core/context.py @@ -169,12 +169,6 @@ class ContextManager: async with Client('http://localhost:7700') as search_client: await 
search_client.index(SEARCH_INDEX_NAME).add_documents(documents) - # def compile_chat_messages(self, max_tokens: int) -> List[Dict]: - # """ - # Compiles the chat prompt into a single string. - # """ - # return compile_chat_messages(self.model, self.chat_history, max_tokens, self.prompt, self.functions, self.system_message) - async def select_context_item(self, id: str, query: str): """ Selects the ContextItem with the given id. diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py index 58572634..96e88383 100644 --- a/continuedev/src/continuedev/libs/llm/__init__.py +++ b/continuedev/src/continuedev/libs/llm/__init__.py @@ -43,3 +43,8 @@ class LLM(BaseModel, ABC): def count_tokens(self, text: str): """Return the number of tokens in the given text.""" raise NotImplementedError + + @abstractproperty + def context_length(self) -> int: + """Return the context length of the LLM in tokens, as counted by count_tokens.""" + raise NotImplementedError diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py index c9c8e9db..4444fd1b 100644 --- a/continuedev/src/continuedev/libs/llm/anthropic.py +++ b/continuedev/src/continuedev/libs/llm/anthropic.py @@ -5,7 +5,7 @@ from typing import Any, Coroutine, Dict, Generator, List, Optional, Union from ...core.main import ChatMessage from anthropic import HUMAN_PROMPT, AI_PROMPT, AsyncAnthropic from ..llm import LLM -from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top +from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens class AnthropicLLM(LLM): @@ -46,6 +46,12 @@ class AnthropicLLM(LLM): def count_tokens(self, text: str): return count_tokens(self.model, text) + @property + def context_length(self): + if self.model == "claude-2": + return 100000 + raise Exception(f"Unknown Anthropic model {self.model}") + def __messages_to_prompt(self, messages: List[Dict[str, str]]) -> str: prompt = "" @@ -77,7 +83,7 @@ class AnthropicLLM(LLM): args = self._transform_args(args) messages = compile_chat_messages( - args["model"], messages, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message) + args["model"], messages, self.context_length, self.context_length, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message) async for chunk in await self._async_client.completions.create( prompt=self.__messages_to_prompt(messages), **args @@ -92,7 +98,7 @@ class AnthropicLLM(LLM): args = self._transform_args(args) messages = compile_chat_messages( - args["model"], with_history, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message) + args["model"], with_history, self.context_length, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message) resp = (await self._async_client.completions.create( prompt=self.__messages_to_prompt(messages), **args diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py index 990f35bc..7fa51e34 100644 --- a/continuedev/src/continuedev/libs/llm/ggml.py +++ b/continuedev/src/continuedev/libs/llm/ggml.py @@ -28,6 +28,10 @@ class GGML(LLM): def name(self): return "ggml" + @property + def context_length(self): + return 2048 + @property def default_args(self): return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024} 
@@ -42,7 +46,7 @@ class GGML(LLM): args = {**self.default_args, **kwargs} messages = compile_chat_messages( - self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) + self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) # TODO move to single self.session variable (proxy setting etc) async with self._client_session as session: @@ -60,7 +64,7 @@ class GGML(LLM): async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: args = {**self.default_args, **kwargs} messages = compile_chat_messages( - self.name, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) + self.name, messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) args["stream"] = True async with self._client_session as session: @@ -87,7 +91,7 @@ class GGML(LLM): async with self._client_session as session: async with session.post(f"{SERVER_URL}/v1/completions", json={ - "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message), + "messages": compile_chat_messages(args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message), **args }) as resp: try: diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py index 121ae99e..f5b3c18c 100644 --- a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py +++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py @@ -19,6 +19,10 @@ class MaybeProxyOpenAI(LLM): def name(self): return self.llm.name + @property + def context_length(self): + return self.llm.context_length + async def start(self, *, api_key: Optional[str] = None, **kwargs): if api_key is None or api_key.strip() == "": self.llm = ProxyServer( diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index de02a614..deb6df4c 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -6,7 +6,17 @@ from pydantic import BaseModel from ...core.main import ChatMessage import openai from ..llm import LLM -from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top +from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top + +CHAT_MODELS = { + "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613" +} +MAX_TOKENS_FOR_MODEL = { + "gpt-3.5-turbo": 4096, + "gpt-3.5-turbo-0613": 4096, + "gpt-3.5-turbo-16k": 16384, + "gpt-4": 8192, +} class AzureInfo(BaseModel): @@ -43,6 +53,10 @@ class OpenAI(LLM): def name(self): return self.model + @property + def context_length(self): + return MAX_TOKENS_FOR_MODEL[self.model] + @property def default_args(self): args = {**DEFAULT_ARGS, "model": self.model} @@ -60,7 +74,7 @@ class OpenAI(LLM): if args["model"] in CHAT_MODELS: messages = compile_chat_messages( - args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message) + args["model"], with_history, self.context_length, args["max_tokens"], 
prompt, functions=None, system_message=self.system_message) self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") completion = "" async for chunk in await openai.ChatCompletion.acreate( @@ -93,7 +107,7 @@ class OpenAI(LLM): del args["functions"] messages = compile_chat_messages( - args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) + args["model"], messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") completion = "" async for chunk in await openai.ChatCompletion.acreate( @@ -110,7 +124,7 @@ class OpenAI(LLM): if args["model"] in CHAT_MODELS: messages = compile_chat_messages( - args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message) + args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message) self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") resp = (await openai.ChatCompletion.acreate( messages=messages, @@ -119,7 +133,7 @@ class OpenAI(LLM): self.write_log(f"Completion: \n\n{resp}") else: prompt = prune_raw_prompt_from_top( - args["model"], prompt, args["max_tokens"]) + args["model"], self.context_length, prompt, args["max_tokens"]) self.write_log(f"Prompt:\n\n{prompt}") resp = (await openai.Completion.acreate( prompt=prompt, diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py index 1c942523..56b123db 100644 --- a/continuedev/src/continuedev/libs/llm/proxy_server.py +++ b/continuedev/src/continuedev/libs/llm/proxy_server.py @@ -15,6 +15,13 @@ ssl_context = ssl.create_default_context(cafile=ca_bundle_path) # SERVER_URL = "http://127.0.0.1:8080" SERVER_URL = "https://proxy-server-l6vsfbzhba-uw.a.run.app" +MAX_TOKENS_FOR_MODEL = { + "gpt-3.5-turbo": 4096, + "gpt-3.5-turbo-0613": 4096, + "gpt-3.5-turbo-16k": 16384, + "gpt-4": 8192, +} + class ProxyServer(LLM): model: str @@ -40,6 +47,10 @@ class ProxyServer(LLM): def name(self): return self.model + @property + def context_length(self): + return MAX_TOKENS_FOR_MODEL[self.model] + @property def default_args(self): return {**DEFAULT_ARGS, "model": self.model} @@ -55,7 +66,7 @@ class ProxyServer(LLM): args = {**self.default_args, **kwargs} messages = compile_chat_messages( - args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message) + args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message) self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") async with self._client_session as session: async with session.post(f"{SERVER_URL}/complete", json={ @@ -72,7 +83,7 @@ class ProxyServer(LLM): async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]: args = {**self.default_args, **kwargs} messages = compile_chat_messages( - args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) + args["model"], messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") async with self._client_session as session: @@ -107,7 +118,7 
@@ class ProxyServer(LLM): async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: args = {**self.default_args, **kwargs} messages = compile_chat_messages( - self.model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) + self.model, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") async with self._client_session as session: diff --git a/continuedev/src/continuedev/libs/llm/utils.py b/continuedev/src/continuedev/libs/llm/utils.py deleted file mode 100644 index 4ea45b7b..00000000 --- a/continuedev/src/continuedev/libs/llm/utils.py +++ /dev/null @@ -1,35 +0,0 @@ -from transformers import AutoTokenizer, AutoModelForCausalLM -from transformers import GPT2TokenizerFast - -gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") -def count_tokens(text: str) -> int: - return len(gpt2_tokenizer.encode(text)) - -# TODO move this to LLM class itself (especially as prices may change in the future) -prices = { - # All prices are per 1k tokens - "fine-tune-train": { - "davinci": 0.03, - "curie": 0.03, - "babbage": 0.0006, - "ada": 0.0004, - }, - "completion": { - "davinci": 0.02, - "curie": 0.002, - "babbage": 0.0005, - "ada": 0.0004, - }, - "fine-tune-completion": { - "davinci": 0.12, - "curie": 0.012, - "babbage": 0.0024, - "ada": 0.0016, - }, - "embedding": { - "ada": 0.0004 - } -} - -def get_price(text: str, model: str="davinci", task: str="completion") -> float: - return count_tokens(text) * prices[task][model] / 1000 \ No newline at end of file diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py index f6c7cb00..6add7b1a 100644 --- a/continuedev/src/continuedev/libs/util/count_tokens.py +++ b/continuedev/src/continuedev/libs/util/count_tokens.py @@ -2,6 +2,7 @@ import json from typing import Dict, List, Union from ...core.main import ChatMessage from .templating import render_templated_string +from ...libs.llm import LLM import tiktoken # TODO move many of these into specific LLM.properties() function that @@ -13,36 +14,35 @@ aliases = { "claude-2": "gpt-3.5-turbo", } DEFAULT_MAX_TOKENS = 2048 -MAX_TOKENS_FOR_MODEL = { - "gpt-3.5-turbo": 4096, - "gpt-3.5-turbo-0613": 4096, - "gpt-3.5-turbo-16k": 16384, - "gpt-4": 8192, - "ggml": 2048, - "claude-2": 100000 -} -CHAT_MODELS = { - "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613" -} DEFAULT_ARGS = {"max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.5, "top_p": 1, "frequency_penalty": 0, "presence_penalty": 0} -def encoding_for_model(model: str): - return tiktoken.encoding_for_model(aliases.get(model, model)) +def encoding_for_model(model_name: str): + try: + return tiktoken.encoding_for_model(aliases.get(model_name, model_name)) + except: + return tiktoken.encoding_for_model("gpt-3.5-turbo") -def count_tokens(model: str, text: Union[str, None]): +def count_tokens(model_name: str, text: Union[str, None]): if text is None: return 0 - encoding = encoding_for_model(model) + encoding = encoding_for_model(model_name) return len(encoding.encode(text, disallowed_special=())) -def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: int): - max_tokens = MAX_TOKENS_FOR_MODEL.get( - model, DEFAULT_MAX_TOKENS) - tokens_for_completion - encoding = 
encoding_for_model(model) +def count_chat_message_tokens(model_name: str, chat_message: ChatMessage) -> int: + # Doing simpler, safer version of what is here: + # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb + # every message follows <|start|>{role/name}\n{content}<|end|>\n + TOKENS_PER_MESSAGE = 4 + return count_tokens(model_name, chat_message.content) + TOKENS_PER_MESSAGE + + +def prune_raw_prompt_from_top(model_name: str, context_length: int, prompt: str, tokens_for_completion: int): + max_tokens = context_length - tokens_for_completion + encoding = encoding_for_model(model_name) tokens = encoding.encode(prompt, disallowed_special=()) if len(tokens) <= max_tokens: return prompt @@ -50,53 +50,45 @@ def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: in return encoding.decode(tokens[-max_tokens:]) -def count_chat_message_tokens(model: str, chat_message: ChatMessage) -> int: - # Doing simpler, safer version of what is here: - # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb - # every message follows <|start|>{role/name}\n{content}<|end|>\n - TOKENS_PER_MESSAGE = 4 - return count_tokens(model, chat_message.content) + TOKENS_PER_MESSAGE - - -def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int): +def prune_chat_history(model_name: str, chat_history: List[ChatMessage], context_length: int, tokens_for_completion: int): total_tokens = tokens_for_completion + \ - sum(count_chat_message_tokens(model, message) + sum(count_chat_message_tokens(model_name, message) for message in chat_history) # 1. Replace beyond last 5 messages with summary i = 0 - while total_tokens > max_tokens and i < len(chat_history) - 5: + while total_tokens > context_length and i < len(chat_history) - 5: message = chat_history[0] - total_tokens -= count_tokens(model, message.content) - total_tokens += count_tokens(model, message.summary) + total_tokens -= count_tokens(model_name, message.content) + total_tokens += count_tokens(model_name, message.summary) message.content = message.summary i += 1 # 2. Remove entire messages until the last 5 - while len(chat_history) > 5 and total_tokens > max_tokens and len(chat_history) > 0: + while len(chat_history) > 5 and total_tokens > context_length and len(chat_history) > 0: message = chat_history.pop(0) - total_tokens -= count_tokens(model, message.content) + total_tokens -= count_tokens(model_name, message.content) # 3. Truncate message in the last 5, except last 1 i = 0 - while total_tokens > max_tokens and len(chat_history) > 0 and i < len(chat_history) - 1: + while total_tokens > context_length and len(chat_history) > 0 and i < len(chat_history) - 1: message = chat_history[i] - total_tokens -= count_tokens(model, message.content) - total_tokens += count_tokens(model, message.summary) + total_tokens -= count_tokens(model_name, message.content) + total_tokens += count_tokens(model_name, message.summary) message.content = message.summary i += 1 # 4. Remove entire messages in the last 5, except last 1 - while total_tokens > max_tokens and len(chat_history) > 1: + while total_tokens > context_length and len(chat_history) > 1: message = chat_history.pop(0) - total_tokens -= count_tokens(model, message.content) + total_tokens -= count_tokens(model_name, message.content) # 5. 
Truncate last message - if total_tokens > max_tokens and len(chat_history) > 0: + if total_tokens > context_length and len(chat_history) > 0: message = chat_history[0] message.content = prune_raw_prompt_from_top( - model, message.content, tokens_for_completion) - total_tokens = max_tokens + model_name, context_length, message.content, tokens_for_completion) + total_tokens = context_length return chat_history @@ -105,7 +97,7 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens: TOKEN_BUFFER_FOR_SAFETY = 100 -def compile_chat_messages(model: str, msgs: Union[List[ChatMessage], None], max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]: +def compile_chat_messages(model_name: str, msgs: Union[List[ChatMessage], None], context_length: int, max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]: """ The total number of tokens is system_message + sum(msgs) + functions + prompt after it is converted to a message """ @@ -129,10 +121,10 @@ def compile_chat_messages(model: str, msgs: Union[List[ChatMessage], None], max_ function_tokens = 0 if functions is not None: for function in functions: - function_tokens += count_tokens(model, json.dumps(function)) + function_tokens += count_tokens(model_name, json.dumps(function)) msgs_copy = prune_chat_history( - model, msgs_copy, MAX_TOKENS_FOR_MODEL[model], function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY) + model_name, msgs_copy, context_length, function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY) history = [msg.to_dict(with_functions=functions is not None) for msg in msgs_copy] diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py index b9f27fe5..4c5303fb 100644 --- a/continuedev/src/continuedev/plugins/steps/core/core.py +++ b/continuedev/src/continuedev/plugins/steps/core/core.py @@ -12,7 +12,7 @@ from ....models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullConte from ....models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents from ....core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation from ....core.main import ChatMessage, ContinueCustomException, Step, SequentialStep -from ....libs.util.count_tokens import MAX_TOKENS_FOR_MODEL, DEFAULT_MAX_TOKENS +from ....libs.util.count_tokens import DEFAULT_MAX_TOKENS from ....libs.util.strings import dedent_and_get_common_whitespace, remove_quotes_and_escapes import difflib @@ -182,8 +182,7 @@ class DefaultModelEditCodeStep(Step): # We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion. # Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need. 
model_to_use = sdk.models.default - max_tokens = int(MAX_TOKENS_FOR_MODEL.get( - model_to_use.name, DEFAULT_MAX_TOKENS) / 2) + max_tokens = int(model_to_use.context_length / 2) TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200 if model_to_use.count_tokens(rif.contents) > TOKENS_TO_BE_CONSIDERED_LARGE_RANGE: @@ -201,7 +200,7 @@ class DefaultModelEditCodeStep(Step): # If using 3.5 and overflows, upgrade to 3.5.16k if model_to_use.name == "gpt-3.5-turbo": - if total_tokens > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]: + if total_tokens > model_to_use.context_length: model_to_use = MaybeProxyOpenAI(model="gpt-3.5-turbo-0613") await sdk.start_model(model_to_use) @@ -213,20 +212,20 @@ class DefaultModelEditCodeStep(Step): cur_start_line = 0 cur_end_line = len(full_file_contents_lst) - 1 - if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]: + if total_tokens > model_to_use.context_length: while cur_end_line > min_end_line: total_tokens -= model_to_use.count_tokens( full_file_contents_lst[cur_end_line]) cur_end_line -= 1 - if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]: + if total_tokens < model_to_use.context_length: break - if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]: + if total_tokens > model_to_use.context_length: while cur_start_line < max_start_line: cur_start_line += 1 total_tokens -= model_to_use.count_tokens( full_file_contents_lst[cur_start_line]) - if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]: + if total_tokens < model_to_use.context_length: break # Now use the found start/end lines to get the prefix and suffix strings -- cgit v1.2.3-70-g09d2 From 374058e07ca699b5a345b270067636f6785df3af Mon Sep 17 00:00:00 2001 From: Luna Date: Sun, 30 Jul 2023 23:23:50 -0300 Subject: fix GGML client session usage --- continuedev/src/continuedev/libs/llm/ggml.py | 78 +++++++++++++--------------- 1 file changed, 37 insertions(+), 41 deletions(-) (limited to 'continuedev/src') diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py index 7fa51e34..a760f7fb 100644 --- a/continuedev/src/continuedev/libs/llm/ggml.py +++ b/continuedev/src/continuedev/libs/llm/ggml.py @@ -13,7 +13,7 @@ SERVER_URL = "http://localhost:8000" class GGML(LLM): - _client_session: aiohttp.ClientSession + _client_session: aiohttp.ClientSession = None def __init__(self, system_message: str = None): self.system_message = system_message @@ -22,7 +22,7 @@ class GGML(LLM): self._client_session = aiohttp.ClientSession() async def stop(self): - pass + await self._client_session.close() @property def name(self): @@ -48,18 +48,16 @@ class GGML(LLM): messages = compile_chat_messages( self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) - # TODO move to single self.session variable (proxy setting etc) - async with self._client_session as session: - async with session.post(f"{SERVER_URL}/v1/completions", json={ - "messages": messages, - **args - }) as resp: - async for line in resp.content.iter_any(): - if line: - try: - yield line.decode("utf-8") - except: - raise Exception(str(line)) + async with self._client_session.post(f"{SERVER_URL}/v1/completions", json={ + "messages": messages, + **args + }) as resp: + async for line in resp.content.iter_any(): + if line: + try: + yield line.decode("utf-8") + except: + raise Exception(str(line)) async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: args = 
{**self.default_args, **kwargs} @@ -67,34 +65,32 @@ class GGML(LLM): self.name, messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) args["stream"] = True - async with self._client_session as session: - async with session.post(f"{SERVER_URL}/v1/chat/completions", json={ - "messages": messages, - **args - }) as resp: - # This is streaming application/json instaed of text/event-stream - async for line in resp.content.iter_chunks(): - if line[1]: - try: - json_chunk = line[0].decode("utf-8") - if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"): - continue - chunks = json_chunk.split("\n") - for chunk in chunks: - if chunk.strip() != "": - yield json.loads(chunk[6:])["choices"][0]["delta"] - except: - raise Exception(str(line[0])) + async with self._client_session.post(f"{SERVER_URL}/v1/chat/completions", json={ + "messages": messages, + **args + }) as resp: + # This is streaming application/json instaed of text/event-stream + async for line in resp.content.iter_chunks(): + if line[1]: + try: + json_chunk = line[0].decode("utf-8") + if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"): + continue + chunks = json_chunk.split("\n") + for chunk in chunks: + if chunk.strip() != "": + yield json.loads(chunk[6:])["choices"][0]["delta"] + except: + raise Exception(str(line[0])) async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: args = {**self.default_args, **kwargs} - async with self._client_session as session: - async with session.post(f"{SERVER_URL}/v1/completions", json={ - "messages": compile_chat_messages(args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message), - **args - }) as resp: - try: - return await resp.text() - except: - raise Exception(await resp.text()) + async with self._client_session.post(f"{SERVER_URL}/v1/completions", json={ + "messages": compile_chat_messages(args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message), + **args + }) as resp: + try: + return await resp.text() + except: + raise Exception(await resp.text()) -- cgit v1.2.3-70-g09d2 From 0174a769f10f5ea8b1ec06787fc75eca8c45a1f1 Mon Sep 17 00:00:00 2001 From: Luna Date: Sun, 30 Jul 2023 23:27:24 -0300 Subject: fix ProxyServer client session usage --- .../src/continuedev/libs/llm/proxy_server.py | 109 ++++++++++----------- 1 file changed, 53 insertions(+), 56 deletions(-) (limited to 'continuedev/src') diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py index 56b123db..44734b1c 100644 --- a/continuedev/src/continuedev/libs/llm/proxy_server.py +++ b/continuedev/src/continuedev/libs/llm/proxy_server.py @@ -68,17 +68,16 @@ class ProxyServer(LLM): messages = compile_chat_messages( args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message) self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") - async with self._client_session as session: - async with session.post(f"{SERVER_URL}/complete", json={ - "messages": messages, - **args - }, headers=self.get_headers()) as resp: - if resp.status != 200: - raise Exception(await resp.text()) - - response_text = await resp.text() - self.write_log(f"Completion: \n\n{response_text}") - return response_text 
+ async with self._client_session.post(f"{SERVER_URL}/complete", json={ + "messages": messages, + **args + }, headers=self.get_headers()) as resp: + if resp.status != 200: + raise Exception(await resp.text()) + + response_text = await resp.text() + self.write_log(f"Completion: \n\n{response_text}") + return response_text async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]: args = {**self.default_args, **kwargs} @@ -86,34 +85,33 @@ class ProxyServer(LLM): args["model"], messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") - async with self._client_session as session: - async with session.post(f"{SERVER_URL}/stream_chat", json={ - "messages": messages, - **args - }, headers=self.get_headers()) as resp: - # This is streaming application/json instaed of text/event-stream - completion = "" - if resp.status != 200: - raise Exception(await resp.text()) - async for line in resp.content.iter_chunks(): - if line[1]: - try: - json_chunk = line[0].decode("utf-8") - json_chunk = "{}" if json_chunk == "" else json_chunk - chunks = json_chunk.split("\n") - for chunk in chunks: - if chunk.strip() != "": - loaded_chunk = json.loads(chunk) - yield loaded_chunk - if "content" in loaded_chunk: - completion += loaded_chunk["content"] - except Exception as e: - posthog_logger.capture_event(self.unique_id, "proxy_server_parse_error", { - "error_title": "Proxy server stream_chat parsing failed", "error_message": '\n'.join(traceback.format_exception(e))}) - else: - break - - self.write_log(f"Completion: \n\n{completion}") + async with self._client_session.post(f"{SERVER_URL}/stream_chat", json={ + "messages": messages, + **args + }, headers=self.get_headers()) as resp: + # This is streaming application/json instaed of text/event-stream + completion = "" + if resp.status != 200: + raise Exception(await resp.text()) + async for line in resp.content.iter_chunks(): + if line[1]: + try: + json_chunk = line[0].decode("utf-8") + json_chunk = "{}" if json_chunk == "" else json_chunk + chunks = json_chunk.split("\n") + for chunk in chunks: + if chunk.strip() != "": + loaded_chunk = json.loads(chunk) + yield loaded_chunk + if "content" in loaded_chunk: + completion += loaded_chunk["content"] + except Exception as e: + posthog_logger.capture_event(self.unique_id, "proxy_server_parse_error", { + "error_title": "Proxy server stream_chat parsing failed", "error_message": '\n'.join(traceback.format_exception(e))}) + else: + break + + self.write_log(f"Completion: \n\n{completion}") async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: args = {**self.default_args, **kwargs} @@ -121,20 +119,19 @@ class ProxyServer(LLM): self.model, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") - async with self._client_session as session: - async with session.post(f"{SERVER_URL}/stream_complete", json={ - "messages": messages, - **args - }, headers=self.get_headers()) as resp: - completion = "" - if resp.status != 200: - raise Exception(await resp.text()) - async for line in resp.content.iter_any(): - if line: - try: - decoded_line = line.decode("utf-8") - yield decoded_line - 
completion += decoded_line - except: - raise Exception(str(line)) - self.write_log(f"Completion: \n\n{completion}") + async with self._client_session.post(f"{SERVER_URL}/stream_complete", json={ + "messages": messages, + **args + }, headers=self.get_headers()) as resp: + completion = "" + if resp.status != 200: + raise Exception(await resp.text()) + async for line in resp.content.iter_any(): + if line: + try: + decoded_line = line.decode("utf-8") + yield decoded_line + completion += decoded_line + except: + raise Exception(str(line)) + self.write_log(f"Completion: \n\n{completion}") -- cgit v1.2.3-70-g09d2 From d451b92810d73ccbdbec64b4421b0dd172f28f7f Mon Sep 17 00:00:00 2001 From: Luna Date: Sun, 30 Jul 2023 23:28:32 -0300 Subject: fix argument passing from MaybeServer to real LLM --- continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py | 11 +++++------ continuedev/src/continuedev/libs/llm/openai.py | 3 ++- continuedev/src/continuedev/libs/llm/proxy_server.py | 6 +++--- 3 files changed, 10 insertions(+), 10 deletions(-) (limited to 'continuedev/src') diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py index f5b3c18c..a0f46fa9 100644 --- a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py +++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py @@ -1,4 +1,4 @@ -from typing import Any, Coroutine, Dict, Generator, List, Union, Optional +from typing import Any, Coroutine, Dict, Generator, List, Union, Optional, Callable from ...core.main import ChatMessage from . import LLM @@ -23,14 +23,13 @@ class MaybeProxyOpenAI(LLM): def context_length(self): return self.llm.context_length - async def start(self, *, api_key: Optional[str] = None, **kwargs): + async def start(self, *, api_key: Optional[str] = None, unique_id: str, write_log: Callable[[str], None]): if api_key is None or api_key.strip() == "": - self.llm = ProxyServer( - unique_id="", model=self.model, write_log=kwargs["write_log"]) + self.llm = ProxyServer(model=self.model) else: - self.llm = OpenAI(model=self.model, write_log=kwargs["write_log"]) + self.llm = OpenAI(model=self.model) - await self.llm.start(api_key=api_key, **kwargs) + await self.llm.start(api_key=api_key, write_log=write_log, unique_id=unique_id) async def stop(self): await self.llm.stop() diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index deb6df4c..16428d4e 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -36,7 +36,8 @@ class OpenAI(LLM): write_log: Optional[Callable[[str], None]] = None api_key: str = None - async def start(self, *, api_key: Optional[str] = None, **kwargs): + async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], **kwargs): + self.write_log = write_log self.api_key = api_key openai.api_key = self.api_key diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py index 44734b1c..5ee8ad90 100644 --- a/continuedev/src/continuedev/libs/llm/proxy_server.py +++ b/continuedev/src/continuedev/libs/llm/proxy_server.py @@ -34,11 +34,11 @@ class ProxyServer(LLM): requires_unique_id = True requires_write_log = True - async def start(self, *, api_key: Optional[str] = None, **kwargs): + async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], unique_id: str, **kwargs): self._client_session = 
aiohttp.ClientSession( connector=aiohttp.TCPConnector(ssl_context=ssl_context)) - self.write_log = kwargs["write_log"] - self.unique_id = kwargs["unique_id"] + self.write_log = write_log + self.unique_id = unique_id async def stop(self): await self._client_session.close() -- cgit v1.2.3-70-g09d2 From 685b1ce062a989490758c86f1cf2c187efafded8 Mon Sep 17 00:00:00 2001 From: Luna Date: Sun, 30 Jul 2023 23:32:39 -0300 Subject: let content_length be customized on GGML --- continuedev/src/continuedev/libs/llm/ggml.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'continuedev/src') diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py index a760f7fb..378ec106 100644 --- a/continuedev/src/continuedev/libs/llm/ggml.py +++ b/continuedev/src/continuedev/libs/llm/ggml.py @@ -12,10 +12,13 @@ SERVER_URL = "http://localhost:8000" class GGML(LLM): + # this is model-specific + max_context_length: int _client_session: aiohttp.ClientSession = None - def __init__(self, system_message: str = None): + def __init__(self, max_context_length: int = 2048, system_message: str = None): + self.max_context_length = max_context_length self.system_message = system_message async def start(self, **kwargs): @@ -30,7 +33,7 @@ class GGML(LLM): @property def context_length(self): - return 2048 + return self.max_context_length @property def default_args(self): -- cgit v1.2.3-70-g09d2 From 1bc5777ed168e47e2ef2ab1b33eecf6cbd170a61 Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Sun, 30 Jul 2023 22:07:41 -0700 Subject: fix: :children_crossing: use default model in default config.py --- continuedev/src/continuedev/libs/constants/default_config.py.txt | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) (limited to 'continuedev/src') diff --git a/continuedev/src/continuedev/libs/constants/default_config.py.txt b/continuedev/src/continuedev/libs/constants/default_config.py.txt index 7c7f495e..e40a2684 100644 --- a/continuedev/src/continuedev/libs/constants/default_config.py.txt +++ b/continuedev/src/continuedev/libs/constants/default_config.py.txt @@ -29,12 +29,9 @@ class CommitMessageStep(Step): diff = subprocess.check_output( ["git", "diff"], cwd=dir).decode("utf-8") - # Ask gpt-3.5-16k to write a commit message, + # Ask the LLM to write a commit message, # and set it as the description of this step - gpt3516k = OpenAI(model="gpt-3.5-turbo-0613") - await sdk.start_model(gpt3516k) - - self.description = await gpt3516k.complete( + self.description = await sdk.models.default.complete( f"{diff}\n\nWrite a short, specific (less than 50 chars) commit message about the above changes:") -- cgit v1.2.3-70-g09d2 From 96379a7bf5b576a2338142b10932d98cbc865d59 Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Sun, 30 Jul 2023 22:47:23 -0700 Subject: fix: :bug: post-merge fixes --- continuedev/src/continuedev/libs/constants/default_config.py.txt | 7 ++++--- continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py | 1 + continuedev/src/continuedev/plugins/steps/README.md | 2 +- continuedev/src/continuedev/plugins/steps/core/core.py | 6 +++--- continuedev/src/continuedev/plugins/steps/main.py | 2 +- docs/docs/walkthroughs/create-a-recipe.md | 2 +- 6 files changed, 11 insertions(+), 9 deletions(-) (limited to 'continuedev/src') diff --git a/continuedev/src/continuedev/libs/constants/default_config.py.txt b/continuedev/src/continuedev/libs/constants/default_config.py.txt index be978fd3..0eae86e6 100644 --- 
a/continuedev/src/continuedev/libs/constants/default_config.py.txt +++ b/continuedev/src/continuedev/libs/constants/default_config.py.txt @@ -9,6 +9,7 @@ import subprocess from continuedev.core.main import Step from continuedev.core.sdk import ContinueSDK +from continuedev.core.models import Models from continuedev.core.config import CustomCommand, SlashCommand, ContinueConfig from continuedev.plugins.context_providers.github import GitHubIssuesContextProvider from continuedev.plugins.context_providers.google import GoogleContextProvider @@ -48,9 +49,9 @@ config = ContinueConfig( allow_anonymous_telemetry=True, models=Models( - default=MaybeProxyOpenAI("gpt4"), - medium=MaybeProxyOpenAI("gpt-3.5-turbo") - ) + default=MaybeProxyOpenAI(model="gpt-4"), + medium=MaybeProxyOpenAI(model="gpt-3.5-turbo") + ), # Set a system message with information that the LLM should always keep in mind # E.g. "Please give concise answers. Always respond in Spanish." diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py index a0f46fa9..edf58fd7 100644 --- a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py +++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py @@ -11,6 +11,7 @@ class MaybeProxyOpenAI(LLM): requires_api_key: Optional[str] = "OPENAI_API_KEY" requires_write_log: bool = True + requires_unique_id: bool = True system_message: Union[str, None] = None llm: Optional[LLM] = None diff --git a/continuedev/src/continuedev/plugins/steps/README.md b/continuedev/src/continuedev/plugins/steps/README.md index 12073835..3f2f804c 100644 --- a/continuedev/src/continuedev/plugins/steps/README.md +++ b/continuedev/src/continuedev/plugins/steps/README.md @@ -33,7 +33,7 @@ If you'd like to override the default description of your step, which is just th - Return a static string - Store state in a class attribute (prepend with a double underscore, which signifies (through Pydantic) that this is not a parameter for the Step, just internal state) during the run method, and then grab this in the describe method. -- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.gpt35.complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`. +- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.medium.complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`. 
Here's an example: diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py index fb9ea029..cf5a7510 100644 --- a/continuedev/src/continuedev/plugins/steps/core/core.py +++ b/continuedev/src/continuedev/plugins/steps/core/core.py @@ -98,7 +98,7 @@ class ShellCommandsStep(Step): return f"Error when running shell commands:\n```\n{self._err_text}\n```" cmds_str = "\n".join(self.cmds) - return await models.gpt35.complete(f"{cmds_str}\n\nSummarize what was done in these shell commands, using markdown bullet points:") + return await models.medium.complete(f"{cmds_str}\n\nSummarize what was done in these shell commands, using markdown bullet points:") async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: cwd = await sdk.ide.getWorkspaceDirectory() if self.cwd is None else self.cwd @@ -186,7 +186,7 @@ class DefaultModelEditCodeStep(Step): else: changes = '\n'.join(difflib.ndiff( self._previous_contents.splitlines(), self._new_contents.splitlines())) - description = await models.gpt3516k.complete(dedent(f"""\ + description = await models.medium.complete(dedent(f"""\ Diff summary: "{self.user_input}" ```diff @@ -194,7 +194,7 @@ class DefaultModelEditCodeStep(Step): ``` Please give brief a description of the changes made above using markdown bullet points. Be concise:""")) - name = await models.gpt3516k.complete(f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. This is the title:") + name = await models.medium.complete(f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. This is the title:") self.name = remove_quotes_and_escapes(name) return f"{remove_quotes_and_escapes(description)}" diff --git a/continuedev/src/continuedev/plugins/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py index 2c3d34fc..d2d6f4dd 100644 --- a/continuedev/src/continuedev/plugins/steps/main.py +++ b/continuedev/src/continuedev/plugins/steps/main.py @@ -169,7 +169,7 @@ class StarCoderEditHighlightedCodeStep(Step): _prompt_and_completion: str = "" async def describe(self, models: Models) -> Coroutine[str, None, None]: - return await models.gpt35.complete(f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:") + return await models.medium.complete(f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:") async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: range_in_files = await sdk.get_code_context(only_editing=True) diff --git a/docs/docs/walkthroughs/create-a-recipe.md b/docs/docs/walkthroughs/create-a-recipe.md index 5d80d083..2cb28f77 100644 --- a/docs/docs/walkthroughs/create-a-recipe.md +++ b/docs/docs/walkthroughs/create-a-recipe.md @@ -31,7 +31,7 @@ If you'd like to override the default description of your steps, which is just t - Return a static string - Store state in a class attribute (prepend with a double underscore, which signifies (through Pydantic) that this is not a parameter for the Step, just internal state) during the run method, and then grab this in the describe method. -- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. 
For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.gpt35.complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`. +- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.medium.complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`. ## 2. Compose steps together into a complete recipe -- cgit v1.2.3-70-g09d2 From 72e83325a8eb5032c448a5e891c157987921ced2 Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Sun, 30 Jul 2023 23:03:31 -0700 Subject: fix: :bug: test and fix small issues with GGML --- continuedev/src/continuedev/core/policy.py | 2 -- continuedev/src/continuedev/libs/llm/__init__.py | 7 ++----- continuedev/src/continuedev/libs/llm/anthropic.py | 3 +++ continuedev/src/continuedev/libs/llm/ggml.py | 5 ++--- continuedev/src/continuedev/libs/llm/proxy_server.py | 3 +++ 5 files changed, 10 insertions(+), 10 deletions(-) (limited to 'continuedev/src') diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py index d90177b5..7c2a8ce0 100644 --- a/continuedev/src/continuedev/core/policy.py +++ b/continuedev/src/continuedev/core/policy.py @@ -45,8 +45,6 @@ def parse_custom_command(inp: str, config: ContinueConfig) -> Union[None, Step]: class DefaultPolicy(Policy): - ran_code_last: bool = False - def next(self, config: ContinueConfig, history: History) -> Step: # At the very start, run initial Steps spcecified in the config if history.get_current() is None: diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py index 96e88383..50577993 100644 --- a/continuedev/src/continuedev/libs/llm/__init__.py +++ b/continuedev/src/continuedev/libs/llm/__init__.py @@ -1,14 +1,11 @@ -import functools from abc import ABC, abstractproperty -from pydantic import BaseModel, ConfigDict from typing import Any, Coroutine, Dict, Generator, List, Union, Optional from ...core.main import ChatMessage -from ...models.main import AbstractModel -from pydantic import BaseModel +from ...models.main import ContinueBaseModel -class LLM(BaseModel, ABC): +class LLM(ContinueBaseModel, ABC): requires_api_key: Optional[str] = None requires_unique_id: bool = False requires_write_log: bool = False diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py index 4444fd1b..b01a84cd 100644 --- a/continuedev/src/continuedev/libs/llm/anthropic.py +++ b/continuedev/src/continuedev/libs/llm/anthropic.py @@ -14,6 +14,9 @@ class AnthropicLLM(LLM): requires_api_key: str = "ANTHROPIC_API_KEY" _async_client: AsyncAnthropic = None + class Config: + arbitrary_types_allowed = True + def __init__(self, model: str, system_message: str = None): self.model = model self.system_message = system_message diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py index 378ec106..2b56a51c 100644 --- a/continuedev/src/continuedev/libs/llm/ggml.py +++ b/continuedev/src/continuedev/libs/llm/ggml.py @@ -17,9 +17,8 @@ class GGML(LLM): _client_session: aiohttp.ClientSession = None - def __init__(self, max_context_length: int = 2048, 
system_message: str = None): - self.max_context_length = max_context_length - self.system_message = system_message + class Config: + arbitrary_types_allowed = True async def start(self, **kwargs): self._client_session = aiohttp.ClientSession() diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py index 5ee8ad90..1a48f213 100644 --- a/continuedev/src/continuedev/libs/llm/proxy_server.py +++ b/continuedev/src/continuedev/libs/llm/proxy_server.py @@ -34,6 +34,9 @@ class ProxyServer(LLM): requires_unique_id = True requires_write_log = True + class Config: + arbitrary_types_allowed = True + async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], unique_id: str, **kwargs): self._client_session = aiohttp.ClientSession( connector=aiohttp.TCPConnector(ssl_context=ssl_context)) -- cgit v1.2.3-70-g09d2 From 490d30e11f7f7cee2b6b8ac2dd48c55dacffa36d Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Mon, 31 Jul 2023 01:23:17 -0700 Subject: docs: :memo: update documentation for LLMs in config.py --- continuedev/src/continuedev/libs/llm/anthropic.py | 2 +- continuedev/src/continuedev/libs/llm/ggml.py | 2 +- docs/docs/customization.md | 59 ++++++++++++++++++----- 3 files changed, 49 insertions(+), 14 deletions(-) (limited to 'continuedev/src') diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py index b01a84cd..ec1b7e40 100644 --- a/continuedev/src/continuedev/libs/llm/anthropic.py +++ b/continuedev/src/continuedev/libs/llm/anthropic.py @@ -9,7 +9,7 @@ from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_token class AnthropicLLM(LLM): - model: str + model: str = "claude-2" requires_api_key: str = "ANTHROPIC_API_KEY" _async_client: AsyncAnthropic = None diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py index 2b56a51c..7742e8c3 100644 --- a/continuedev/src/continuedev/libs/llm/ggml.py +++ b/continuedev/src/continuedev/libs/llm/ggml.py @@ -13,7 +13,7 @@ SERVER_URL = "http://localhost:8000" class GGML(LLM): # this is model-specific - max_context_length: int + max_context_length: int = 2048 _client_session: aiohttp.ClientSession = None diff --git a/docs/docs/customization.md b/docs/docs/customization.md index fa4d110e..06183c4a 100644 --- a/docs/docs/customization.md +++ b/docs/docs/customization.md @@ -4,11 +4,25 @@ Continue can be deeply customized by editing the `ContinueConfig` object in `~/. ## Change the default LLM -Change the `default_model` field to any of "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "claude-2", or "ggml". +In `config.py`, you'll find the `models` property: + +```python +config = ContinueConfig( + ... + models=Models( + default=MaybeProxyOpenAI(model="gpt-4"), + medium=MaybeProxyOpenAI(model="gpt-3.5-turbo") + ) +) +``` + +The `default` model is the one used for most operations, including responding to your messages and editing code. The `medium` model is used for summarization tasks that require less quality. The values of these fields are both of the [`LLM`](https://github.com/continuedev/continue/blob/main/continuedev/src/continuedev/libs/llm/__init__.py) class, which implements methods for retreiving and streaming completions from an LLM. + +Below, we describe the `LLM` classes available in the Continue core library, and how they can be used. 
### Adding an OpenAI API key -New users can try out Continue with GPT-4 using a proxy server that securely makes calls to OpenAI using our API key. Continue should just work the first time you install the extension in VS Code. +With the `MaybeProxyOpenAI` `LLM`, new users can try out Continue with GPT-4 using a proxy server that securely makes calls to OpenAI using our API key. Continue should just work the first time you install the extension in VS Code. Once you are using Continue regularly though, you will need to add an OpenAI API key that has access to GPT-4 by following these steps: @@ -18,34 +32,55 @@ Once you are using Continue regularly though, you will need to add an OpenAI API 4. Click Edit in settings.json under Continue: OpenAI_API_KEY" section 5. Paste your API key as the value for "continue.OPENAI_API_KEY" in settings.json -### claude-2 and gpt-X +The `MaybeProxyOpenAI` class will automatically switch to using your API key instead of ours. If you'd like to explicitly use one or the other, you can use the `ProxyServer` or `OpenAI` classes instead. + +These classes support any models available through the OpenAI API, assuming your API key has access, including "gpt-4", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", and "gpt-4-32k". + +### claude-2 -If you have access, simply set `default_model` to the model you would like to use, then you will be prompted for a personal API key after reloading VS Code. If using an OpenAI model, you can press enter to try with our API key for free. +Import the `Anthropic` LLM class and set it as the default model: + +```python +from continuedev.libs.llm.anthropic import Anthropic + +config = ContinueConfig( + ... + models=Models( + default=Anthropic(model="claude-2") + ) +) +``` + +Continue will automatically prompt you for your Anthropic API key, which must have access to Claude 2. You can request early access [here](https://www.anthropic.com/earlyaccess). ### Local models with ggml See our [5 minute quickstart](https://github.com/continuedev/ggml-server-example) to run any model locally with ggml. While these models don't yet perform as well, they are free, entirely private, and run offline. -Once the model is running on localhost:8000, set `default_model` in `~/.continue/config.py` to "ggml". +Once the model is running on localhost:8000, import the `GGML` LLM class from `continuedev.libs.llm.ggml` and set `default=GGML(max_context_length=2048)`. ### Self-hosting an open-source model -If you want to self-host on Colab, RunPod, Replicate, HuggingFace, Haven, or another hosting provider you will need to wire up a new LLM class. It only needs to implement 3 methods: `stream_complete`, `complete`, and `stream_chat`, and you can see examples in `continuedev/src/continuedev/libs/llm`. +If you want to self-host on Colab, RunPod, Replicate, HuggingFace, Haven, or another hosting provider you will need to wire up a new LLM class. It only needs to implement 3 primary methods: `stream_complete`, `complete`, and `stream_chat`, and you can see examples in `continuedev/src/continuedev/libs/llm`. If by chance the provider has the exact same API interface as OpenAI, the `GGML` class will work for you out of the box, after changing the endpoint at the top of the file. ### Azure OpenAI Service -If you'd like to use OpenAI models but are concerned about privacy, you can use the Azure OpenAI service, which is GDPR and HIPAA compliant. 
After applying for access [here](https://azure.microsoft.com/en-us/products/ai-services/openai-service), you will typically hear back within only a few days. Once you have access, set `default_model` to "gpt-4", and then set the `openai_server_info` property in the `ContinueConfig` like so:
+If you'd like to use OpenAI models but are concerned about privacy, you can use the Azure OpenAI service, which is GDPR and HIPAA compliant. After applying for access [here](https://azure.microsoft.com/en-us/products/ai-services/openai-service), you will typically hear back within only a few days. Once you have access, instantiate the model like so:

 ```python
+from continuedev.libs.llm.openai import OpenAI, OpenAIServerInfo
+
 config = ContinueConfig(
     ...
-    openai_server_info=OpenAIServerInfo(
-        api_base="https://my-azure-openai-instance.openai.azure.com/",
-        engine="my-azure-openai-deployment",
-        api_version="2023-03-15-preview",
-        api_type="azure"
+    models=Models(
+        default=OpenAI(model="gpt-3.5-turbo", server_info=OpenAIServerInfo(
+            api_base="https://my-azure-openai-instance.openai.azure.com/",
+            engine="my-azure-openai-deployment",
+            api_version="2023-03-15-preview",
+            api_type="azure"
+        ))
     )
 )
 ```
-- cgit v1.2.3-70-g09d2
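
The customization docs above note that a self-hosted provider exposing an OpenAI-compatible API needs only a thin async client, which is the same pattern the `GGML` class uses against `http://localhost:8000`. The following is a rough, standalone sketch of that pattern under stated assumptions: the `LocalOpenAICompatibleLLM` name, the `/v1/chat/completions` endpoint path, and the SSE `data:` streaming format are illustrative guesses about such a server, not definitions from these patches or the Continue codebase.

```python
# Illustrative sketch only: a minimal async client for an assumed
# OpenAI-compatible server at localhost:8000. Class name, endpoint path,
# and streaming format are assumptions, not definitions from this repo.
import asyncio
import json
from typing import AsyncGenerator, Dict, List, Optional

import aiohttp

SERVER_URL = "http://localhost:8000"  # assumed local OpenAI-compatible server


class LocalOpenAICompatibleLLM:
    def __init__(self, model: str = "ggml", system_message: Optional[str] = None):
        self.model = model
        self.system_message = system_message
        self._client_session: Optional[aiohttp.ClientSession] = None

    async def start(self) -> None:
        # Open one HTTP session up front, reused by every request.
        self._client_session = aiohttp.ClientSession()

    async def stop(self) -> None:
        # Close the session when the model is no longer needed.
        await self._client_session.close()

    def _messages(self, prompt: str) -> List[Dict[str, str]]:
        # Wrap a plain prompt (plus optional system message) as chat messages.
        messages = []
        if self.system_message is not None:
            messages.append({"role": "system", "content": self.system_message})
        messages.append({"role": "user", "content": prompt})
        return messages

    async def complete(self, prompt: str) -> str:
        # Single, non-streaming completion.
        async with self._client_session.post(
            f"{SERVER_URL}/v1/chat/completions",
            json={"model": self.model, "messages": self._messages(prompt)},
        ) as resp:
            if resp.status != 200:
                raise Exception(await resp.text())
            body = await resp.json()
            return body["choices"][0]["message"]["content"]

    async def stream_complete(self, prompt: str) -> AsyncGenerator[str, None]:
        # Stream tokens as they arrive, assuming SSE lines of the form
        # "data: {...json chunk...}" terminated by "data: [DONE]".
        async with self._client_session.post(
            f"{SERVER_URL}/v1/chat/completions",
            json={
                "model": self.model,
                "messages": self._messages(prompt),
                "stream": True,
            },
        ) as resp:
            if resp.status != 200:
                raise Exception(await resp.text())
            async for raw_line in resp.content:
                line = raw_line.decode("utf-8").strip()
                if not line.startswith("data: ") or line == "data: [DONE]":
                    continue
                chunk = json.loads(line[len("data: "):])
                delta = chunk["choices"][0].get("delta", {})
                if delta.get("content"):
                    yield delta["content"]


async def main() -> None:
    llm = LocalOpenAICompatibleLLM()
    await llm.start()
    try:
        print(await llm.complete("Write one sentence about code review."))
        async for token in llm.stream_complete("Summarize that sentence."):
            print(token, end="", flush=True)
        print()
    finally:
        await llm.stop()


if __name__ == "__main__":
    asyncio.run(main())
```

The `start()`/`stop()` pair in this sketch mirrors the connection lifecycle the patches above add to the `LLM` base class: the HTTP session is opened once, reused by `complete` and `stream_complete`, and closed explicitly when the model is shut down.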