 CONTRIBUTING.md                                                    |  1
 continuedev/src/continuedev/core/autopilot.py                      |  1
 continuedev/src/continuedev/core/main.py                           |  1
 continuedev/src/continuedev/core/models.py                         | 33
 continuedev/src/continuedev/libs/constants/default_config.py.txt   |  8
 continuedev/src/continuedev/libs/llm/__init__.py                   |  5
 continuedev/src/continuedev/libs/llm/anthropic.py                  | 27
 continuedev/src/continuedev/libs/llm/ggml.py                       |  6
 continuedev/src/continuedev/libs/llm/hf_inference_api.py           |  6
 continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py         | 38
 continuedev/src/continuedev/libs/llm/openai.py                     | 12
 continuedev/src/continuedev/libs/llm/proxy_server.py               | 27
 12 files changed, 108 insertions(+), 57 deletions(-)
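For orientation before the per-file diffs: this change removes the old `default_model` string from the config and GUI state and replaces it with a `Models` object whose slots hold `LLM` instances (see the `models.py` and `default_config.py.txt` hunks below). A minimal sketch of a user config under the new scheme, assuming `Models` is importable from `continuedev.src.continuedev.core.models` (that import path is an assumption; the "gpt4" model name comes from the default-config hunk below, and keyword arguments are used here since the LLM classes are pydantic models):

```python
# Sketch only: mirrors the new default config shown in the diff below,
# with assumed import paths and keyword-argument construction.
from continuedev.src.continuedev.core.config import ContinueConfig
from continuedev.src.continuedev.core.models import Models  # assumed import path
from continuedev.src.continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI

config = ContinueConfig(
    allow_anonymous_telemetry=True,
    models=Models(
        default=MaybeProxyOpenAI(model="gpt4"),
    ),
)
```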
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index a958777f..bf39f22c 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -107,7 +107,6 @@ When state is updated on the server, we currently send the entirety of the objec
 - `history`, a record of previously run Steps. Displayed in order in the sidebar.
 - `active`, whether the autopilot is currently running a step. Displayed as a loader while step is running.
 - `user_input_queue`, the queue of user inputs that have not yet been processed due to waiting for previous Steps to complete. Displayed below the `active` loader until popped from the queue.
-- `default_model`, the default model used for completions. Displayed as a toggleable button on the bottom of the GUI.
 - `selected_context_items`, the ranges of code and other items (like GitHub Issues, files, etc...) that have been selected to include as context. Displayed just above the main text input.
 - `slash_commands`, the list of available slash commands. Displayed in the main text input dropdown.
 - `adding_highlighted_code`, whether highlighting of new code for context is locked. Displayed as a button adjacent to `highlighted_ranges`.
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 42a58423..beb40c75 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -94,7 +94,6 @@ class Autopilot(ContinueBaseModel):
             history=self.history,
             active=self._active,
             user_input_queue=self._main_user_input_queue,
-            default_model=self.continue_sdk.config.default_model,
             slash_commands=self.get_available_slash_commands(),
             adding_highlighted_code=self.context_manager.context_providers[
                 "code"].adding_highlighted_code,
diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py
index df9b98ef..2553850f 100644
--- a/continuedev/src/continuedev/core/main.py
+++ b/continuedev/src/continuedev/core/main.py
@@ -258,7 +258,6 @@ class FullState(ContinueBaseModel):
     history: History
     active: bool
     user_input_queue: List[str]
-    default_model: str
     slash_commands: List[SlashCommandDescription]
     adding_highlighted_code: bool
     selected_context_items: List[ContextItem]
diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py
index ec89d503..e4cb8ed6 100644
--- a/continuedev/src/continuedev/core/models.py
+++ b/continuedev/src/continuedev/core/models.py
@@ -1,5 +1,5 @@
 from typing import Optional, Any
-from pydantic import BaseModel
+from pydantic import BaseModel, validator

 from ..libs.llm import LLM

@@ -12,7 +12,7 @@ class Models(BaseModel):
     # TODO namespace these away to not confuse readers,
     # or split Models into ModelsConfig, which gets turned into Models
-    sdk: Any = None
+    sdk: "ContinueSDK" = None
     system_message: Any = None

     """
@@ -34,43 +34,42 @@ class Models(BaseModel):
     '''depending on the model, return the single prompt string'''
     """

-    async def _start(self, llm: LLM):
+    async def _start_llm(self, llm: LLM):
         kwargs = {}
-        if llm.required_api_key:
-            kwargs["api_key"] = await self.sdk.get_api_key(llm.required_api_key)
-        if llm.required_unique_id:
+        if llm.requires_api_key:
+            kwargs["api_key"] = await self.sdk.get_api_key(llm.requires_api_key)
+        if llm.requires_unique_id:
             kwargs["unique_id"] = self.sdk.ide.unique_id
-        if llm.required_write_log:
+        if llm.requires_write_log:
             kwargs["write_log"] = self.sdk.write_log
         await llm.start(**kwargs)

     async def start(self, sdk: "ContinueSDK"):
+        """Start each of the LLMs, or fall back to default"""
         self.sdk = sdk
         self.system_message = self.sdk.config.system_message
-        await self._start(self.default)
+        await self._start_llm(self.default)
         if self.small:
-            await self._start(self.small)
+            await self._start_llm(self.small)
         else:
             self.small = self.default

         if self.medium:
-            await self._start(self.medium)
+            await self._start_llm(self.medium)
         else:
             self.medium = self.default

         if self.large:
-            await self._start(self.large)
+            await self._start_llm(self.large)
         else:
             self.large = self.default

     async def stop(self, sdk: "ContinueSDK"):
+        """Stop each LLM (if it's not the default, which is shared)"""
         await self.default.stop()
-        if self.small:
+        if self.small is not self.default:
             await self.small.stop()
-
-        if self.medium:
+        if self.medium is not self.default:
             await self.medium.stop()
-
-        if self.large:
+        if self.large is not self.default:
             await self.large.stop()
-
diff --git a/continuedev/src/continuedev/libs/constants/default_config.py.txt b/continuedev/src/continuedev/libs/constants/default_config.py.txt
index f80a9ff0..5708747f 100644
--- a/continuedev/src/continuedev/libs/constants/default_config.py.txt
+++ b/continuedev/src/continuedev/libs/constants/default_config.py.txt
@@ -12,7 +12,7 @@ from continuedev.src.continuedev.core.sdk import ContinueSDK
 from continuedev.src.continuedev.core.config import CustomCommand, SlashCommand, ContinueConfig
 from continuedev.src.continuedev.plugins.context_providers.github import GitHubIssuesContextProvider
 from continuedev.src.continuedev.plugins.context_providers.google import GoogleContextProvider
-
+from continuedev.src.continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI

 class CommitMessageStep(Step):
     """
@@ -41,9 +41,9 @@ config = ContinueConfig(
     # See here to learn what anonymous data we collect: https://continue.dev/docs/telemetry
     allow_anonymous_telemetry=True,

-    # GPT-4 is recommended for best results
-    # See options here: https://continue.dev/docs/customization#change-the-default-llm
-    default_model="gpt-4",
+    models=Models(
+        default=MaybeProxyOpenAI("gpt4")
+    )

     # Set a system message with information that the LLM should always keep in mind
     # E.g. "Please give concise answers. Always respond in Spanish."
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 0f6b1505..21afc338 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -9,7 +9,10 @@
 from pydantic import BaseModel


 class LLM(BaseModel, ABC):
-    required_api_key: Optional[str] = None
+    requires_api_key: Optional[str] = None
+    requires_unique_id: bool = False
+    requires_write_log: bool = False
+    system_message: Union[str, None] = None

     async def start(self, *, api_key: Optional[str] = None):
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
index 846a2450..067a903b 100644
--- a/continuedev/src/continuedev/libs/llm/anthropic.py
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -9,27 +9,28 @@ from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_


 class AnthropicLLM(LLM):
-    required_api_key: str = "ANTHROPIC_API_KEY"
-    default_model: str
-    async_client: AsyncAnthropic
+    model: str

-    def __init__(self, default_model: str, system_message: str = None):
-        self.default_model = default_model
+    requires_api_key: str = "ANTHROPIC_API_KEY"
+    _async_client: AsyncAnthropic = None
+
+    def __init__(self, model: str, system_message: str = None):
+        self.model = model
         self.system_message = system_message

-    async def start(self, *, api_key):
-        self.async_client = AsyncAnthropic(api_key=api_key)
+    async def start(self, *, api_key: str):
+        self._async_client = AsyncAnthropic(api_key=api_key)

     async def stop(self):
         pass

     @cached_property
     def name(self):
-        return self.default_model
+        return self.model

     @property
     def default_args(self):
-        return {**DEFAULT_ARGS, "model": self.default_model}
+        return {**DEFAULT_ARGS, "model": self.model}

     def _transform_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
         args = args.copy()
@@ -43,7 +44,7 @@ class AnthropicLLM(LLM):
         return args

     def count_tokens(self, text: str):
-        return count_tokens(self.default_model, text)
+        return count_tokens(self.model, text)

     def __messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
         prompt = ""
@@ -63,7 +64,7 @@ class AnthropicLLM(LLM):
         args["stream"] = True
         args = self._transform_args(args)

-        async for chunk in await self.async_client.completions.create(
+        async for chunk in await self._async_client.completions.create(
             prompt=f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}",
             **args
         ):
@@ -77,7 +78,7 @@ class AnthropicLLM(LLM):
         messages = compile_chat_messages(
             args["model"], messages, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message)
-        async for chunk in await self.async_client.completions.create(
+        async for chunk in await self._async_client.completions.create(
             prompt=self.__messages_to_prompt(messages),
             **args
         ):
@@ -92,7 +93,7 @@ class AnthropicLLM(LLM):
         messages = compile_chat_messages(
             args["model"], with_history, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message)
-        resp = (await self.async_client.completions.create(
+        resp = (await self._async_client.completions.create(
             prompt=self.__messages_to_prompt(messages),
             **args
         )).completion
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 401709c9..4bcf7e54 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -16,6 +16,12 @@ class GGML(LLM):
     def __init__(self, system_message: str = None):
         self.system_message = system_message

+    async def start(self, **kwargs):
+        pass
+
+    async def stop(self):
+        pass
+
     @property
     def name(self):
         return "ggml"
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 06d37596..4ad32e0e 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -8,14 +8,16 @@ DEFAULT_MAX_TIME = 120.


 class HuggingFaceInferenceAPI(LLM):
-    required_api_key: str = "HUGGING_FACE_TOKEN"
     model: str

+    requires_api_key: str = "HUGGING_FACE_TOKEN"
+    api_key: str = None
+
     def __init__(self, model: str, system_message: str = None):
         self.model = model
         self.system_message = system_message  # TODO: Nothing being done with this

-    async def start(self, *, api_key):
+    async def start(self, *, api_key: str):
         self.api_key = api_key

     def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs):
diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
new file mode 100644
index 00000000..d2898b5c
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
@@ -0,0 +1,38 @@
+from typing import Any, Coroutine, Dict, Generator, List, Union, Optional
+
+from ...core.main import ChatMessage
+from . import LLM
+from .proxy_server import ProxyServer
+from .openai import OpenAI
+
+
+class MaybeProxyOpenAI(LLM):
+    model: str
+
+    requires_api_key: Optional[str] = "OPENAI_API_KEY"
+    requires_write_log: bool = True
+    system_message: Union[str, None] = None
+
+    llm: Optional[LLM] = None
+
+    async def start(self, *, api_key: Optional[str] = None, **kwargs):
+        if api_key is None or api_key.strip() == "":
+            self.llm = ProxyServer(
+                unique_id="", model=self.model, write_log=kwargs["write_log"])
+        else:
+            self.llm = OpenAI(model=self.model, write_log=kwargs["write_log"])
+
+    async def stop(self):
+        await self.llm.stop()
+
+    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
+        return await self.llm.complete(prompt, with_history=with_history, **kwargs)
+
+    def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        return self.llm.stream_complete(prompt, with_history=with_history, **kwargs)
+
+    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        return self.llm.stream_chat(messages=messages, **kwargs)
+
+    def count_tokens(self, text: str):
+        return self.llm.count_tokens(text)
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 5ac4d211..0c2c360b 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -17,12 +17,14 @@ class AzureInfo(BaseModel):

 class OpenAI(LLM):
     model: str
+
+    requires_api_key = "OPENAI_API_KEY"
+    requires_write_log = True
+
     system_message: Optional[str] = None
     azure_info: Optional[AzureInfo] = None
     write_log: Optional[Callable[[str], None]] = None
-
-    required_api_key = "OPENAI_API_KEY"
-    required_write_log = True
+    api_key: str = None

     async def start(self, *, api_key):
         self.api_key = api_key
@@ -31,8 +33,8 @@ class OpenAI(LLM):
         # Using an Azure OpenAI deployment
         if self.azure_info is not None:
             openai.api_type = "azure"
-            openai.api_base = azure_info.endpoint
-            openai.api_version = azure_info.api_version
+            openai.api_base = self.azure_info.endpoint
+            openai.api_version = self.azure_info.api_version

     async def stop(self):
         pass
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index 2c0e1dc4..e8f1cb46 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -1,4 +1,3 @@
-
 import json
 import traceback
 from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union, Optional
@@ -18,20 +17,24 @@ SERVER_URL = "https://proxy-server-l6vsfbzhba-uw.a.run.app"


 class ProxyServer(LLM):
-    unique_id: str
     model: str
     system_message: Optional[str]
-    write_log: Callable[[str], None]

-    required_unique_id = True
-    required_write_log = True
+    unique_id: str = None
+    write_log: Callable[[str], None] = None
+    _client_session: aiohttp.ClientSession
+
+    requires_unique_id = True
+    requires_write_log = True

-    async def start(self):
-        # TODO put ClientSession here
-        pass
+    async def start(self, **kwargs):
+        self._client_session = aiohttp.ClientSession(
+            connector=aiohttp.TCPConnector(ssl_context=ssl_context))
+        self.write_log = kwargs["write_log"]
+        self.unique_id = kwargs["unique_id"]

     async def stop(self):
-        pass
+        await self._client_session.close()

     @property
     def name(self):
@@ -54,7 +57,7 @@ class ProxyServer(LLM):
         messages = compile_chat_messages(
             args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
         self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
-        async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
+        async with self._client_session as session:
             async with session.post(f"{SERVER_URL}/complete", json={
                 "messages": messages,
                 **args
@@ -72,7 +75,7 @@ class ProxyServer(LLM):
             args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
         self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
-        async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
+        async with self._client_session as session:
             async with session.post(f"{SERVER_URL}/stream_chat", json={
                 "messages": messages,
                 **args
@@ -107,7 +110,7 @@ class ProxyServer(LLM):
             self.model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
         self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
-        async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
+        async with self._client_session as session:
             async with session.post(f"{SERVER_URL}/stream_complete", json={
                 "messages": messages,
                 **args
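Taken together, the `models.py` and `libs/llm` changes above define a small lifecycle contract: `Models.start()` reads each LLM's `requires_*` flags, gathers the matching values from the SDK, and calls `llm.start(**kwargs)`; unset `small`/`medium`/`large` slots fall back to `default`, and `stop()` skips those aliases. Below is an illustrative end-to-end sketch of that flow; the `demo` function, the assumption that `sdk` is an already-started `ContinueSDK`, and the import paths are assumptions for illustration, not part of this diff.

```python
# Illustrative driver, not from the diff: exercises the start/complete/stop
# lifecycle that Models now manages for its LLM slots.
from continuedev.src.continuedev.core.models import Models  # assumed import path
from continuedev.src.continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI


async def demo(sdk) -> None:
    # `sdk` is assumed to be a started ContinueSDK; Models.start() pulls the
    # API key, unique_id, and write_log from it as each LLM's flags require.
    models = Models(default=MaybeProxyOpenAI(model="gpt4"))
    await models.start(sdk)   # small/medium/large fall back to `default`
    print(await models.default.complete("Write a haiku about diffs"))
    await models.stop(sdk)    # stops `default` once; aliased slots are skipped
```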