diff options
Diffstat (limited to 'continuedev/src')
32 files changed, 283 insertions, 601 deletions
diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py index 98730d38..fdb99d47 100644 --- a/continuedev/src/continuedev/core/abstract_sdk.py +++ b/continuedev/src/continuedev/core/abstract_sdk.py @@ -71,10 +71,6 @@ class AbstractContinueSDK(ABC): async def delete_directory(self, path: str): pass - @abstractmethod - async def get_user_secret(self, env_var: str) -> str: - pass - config: ContinueConfig @abstractmethod diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index bae82739..de0b8c53 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -507,7 +507,7 @@ class Autopilot(ContinueBaseModel): if self.session_info is None: async def create_title(): - title = await self.continue_sdk.models.medium.complete( + title = await self.continue_sdk.models.medium._complete( f'Give a short title to describe the current chat session. Do not put quotes around the title. The first message was: "{user_input}". Do not use more than 10 words. The title is: ', max_tokens=20, ) diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py index bfb89561..bb2c43dc 100644 --- a/continuedev/src/continuedev/core/context.py +++ b/continuedev/src/continuedev/core/context.py @@ -225,7 +225,7 @@ class ContextManager: await self.load_index(sdk.ide.workspace_directory) logger.debug("Loaded Meilisearch index") except asyncio.TimeoutError: - logger.warning("Meilisearch is not running. As of now, Continue does not attempt to download Meilisearch on Windows because the download process is more involved. If you'd like install Meilisearch (which allows you to reference context by typing '@' (e.g. files, GitHub issues, etc.)), follow the instructions here: https://www.meilisearch.com/docs/learn/getting_started/installation. Alternatively, you can track our progress on support for Meilisearch on Windows here: https://github.com/continuedev/continue/issues/408.") + logger.warning("Meilisearch is not running.") create_async_task(start_meilisearch(context_providers)) diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 37992b67..9b1c2cd0 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -94,14 +94,7 @@ class ContinueSDK(AbstractContinueSDK): self.history.timeline[self.history.current_index].logs.append(message) async def start_model(self, llm: LLM): - kwargs = {} - if llm.requires_api_key: - kwargs["api_key"] = await self.get_user_secret(llm.requires_api_key) - if llm.requires_unique_id: - kwargs["unique_id"] = self.ide.unique_id - if llm.requires_write_log: - kwargs["write_log"] = self.write_log - await llm.start(**kwargs) + await llm.start(unique_id=self.ide.unique_id, write_log=self.write_log) async def _ensure_absolute_path(self, path: str) -> str: if os.path.isabs(path): @@ -211,10 +204,6 @@ class ContinueSDK(AbstractContinueSDK): path = await self._ensure_absolute_path(path) return await self.run_step(FileSystemEditStep(edit=DeleteDirectory(path=path))) - async def get_user_secret(self, env_var: str) -> str: - # TODO support error prompt dynamically set on env_var - return await self.ide.getUserSecret(env_var) - _last_valid_config: ContinueConfig = None def _load_config_dot_py(self) -> ContinueConfig: diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py index 90ef7934..6a321a41 100644 --- a/continuedev/src/continuedev/libs/llm/__init__.py +++ b/continuedev/src/continuedev/libs/llm/__init__.py @@ -1,19 +1,32 @@ -from abc import ABC, abstractproperty -from typing import Any, Coroutine, Dict, Generator, List, Optional, Union +from abc import ABC +from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union from ...core.main import ChatMessage from ...models.main import ContinueBaseModel +from ..util.count_tokens import DEFAULT_ARGS, count_tokens class LLM(ContinueBaseModel, ABC): - requires_api_key: Optional[str] = None - requires_unique_id: bool = False - requires_write_log: bool = False - + title: Optional[str] = None system_message: Optional[str] = None + context_length: int = 2048 + "The maximum context length of the LLM in tokens, as counted by count_tokens." + + unique_id: Optional[str] = None + "The unique ID of the user." + + model: str + "The model name" + prompt_templates: dict = {} + write_log: Optional[Callable[[str], None]] = None + "A function that takes a string and writes it to the log." + + api_key: Optional[str] = None + "The API key for the LLM provider." + class Config: arbitrary_types_allowed = True extra = "allow" @@ -21,36 +34,39 @@ class LLM(ContinueBaseModel, ABC): def dict(self, **kwargs): original_dict = super().dict(**kwargs) original_dict.pop("write_log", None) - original_dict["name"] = self.name original_dict["class_name"] = self.__class__.__name__ return original_dict - @abstractproperty - def name(self): - """Return the name of the LLM.""" - raise NotImplementedError + def collect_args(self, **kwargs) -> Any: + """Collect the arguments for the LLM.""" + args = {**DEFAULT_ARGS.copy(), "model": self.model, "max_tokens": 1024} + args.update(kwargs) + return args - async def start(self, *, api_key: Optional[str] = None, **kwargs): + async def start( + self, write_log: Callable[[str], None] = None, unique_id: Optional[str] = None + ): """Start the connection to the LLM.""" - raise NotImplementedError + self.write_log = write_log + self.unique_id = unique_id async def stop(self): """Stop the connection to the LLM.""" - raise NotImplementedError + pass - async def complete( + async def _complete( self, prompt: str, with_history: List[ChatMessage] = None, **kwargs ) -> Coroutine[Any, Any, str]: """Return the completion of the text with the given temperature.""" raise NotImplementedError - def stream_complete( + def _stream_complete( self, prompt, with_history: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: """Stream the completion through generator.""" raise NotImplementedError - async def stream_chat( + async def _stream_chat( self, messages: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: """Stream the chat through generator.""" @@ -58,9 +74,4 @@ class LLM(ContinueBaseModel, ABC): def count_tokens(self, text: str): """Return the number of tokens in the given text.""" - raise NotImplementedError - - @abstractproperty - def context_length(self) -> int: - """Return the context length of the LLM in tokens, as counted by count_tokens.""" - raise NotImplementedError + return count_tokens(self.model, text) diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py index 16bc2fce..b5aff63a 100644 --- a/continuedev/src/continuedev/libs/llm/anthropic.py +++ b/continuedev/src/continuedev/libs/llm/anthropic.py @@ -1,47 +1,36 @@ -from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union +from typing import Any, Coroutine, Dict, Generator, List, Union from anthropic import AI_PROMPT, HUMAN_PROMPT, AsyncAnthropic from ...core.main import ChatMessage from ..llm import LLM -from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens +from ..util.count_tokens import compile_chat_messages class AnthropicLLM(LLM): api_key: str + "Anthropic API key" + model: str = "claude-2" - requires_write_log = True _async_client: AsyncAnthropic = None class Config: arbitrary_types_allowed = True - write_log: Optional[Callable[[str], None]] = None - async def start( self, - *, - api_key: Optional[str] = None, - write_log: Callable[[str], None], **kwargs, ): - self.write_log = write_log + await super().start(**kwargs) self._async_client = AsyncAnthropic(api_key=self.api_key) - async def stop(self): - pass - - @property - def name(self): - return self.model + if self.model == "claude-2": + self.context_length = 100_000 - @property - def default_args(self): - return {**DEFAULT_ARGS, "model": self.model} + def collect_args(self, **kwargs) -> Any: + args = super().collect_args(**kwargs) - def _transform_args(self, args: Dict[str, Any]) -> Dict[str, Any]: - args = args.copy() if "max_tokens" in args: args["max_tokens_to_sample"] = args["max_tokens"] del args["max_tokens"] @@ -51,15 +40,6 @@ class AnthropicLLM(LLM): del args["presence_penalty"] return args - def count_tokens(self, text: str): - return count_tokens(self.model, text) - - @property - def context_length(self): - if self.model == "claude-2": - return 100000 - raise Exception(f"Unknown Anthropic model {self.model}") - def __messages_to_prompt(self, messages: List[Dict[str, str]]) -> str: prompt = "" @@ -76,13 +56,11 @@ class AnthropicLLM(LLM): prompt += AI_PROMPT return prompt - async def stream_complete( + async def _stream_complete( self, prompt, with_history: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.default_args.copy() - args.update(kwargs) + args = self.collect_args(**kwargs) args["stream"] = True - args = self._transform_args(args) prompt = f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}" self.write_log(f"Prompt: \n\n{prompt}") @@ -95,13 +73,11 @@ class AnthropicLLM(LLM): self.write_log(f"Completion: \n\n{completion}") - async def stream_chat( + async def _stream_chat( self, messages: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.default_args.copy() - args.update(kwargs) + args = self.collect_args(**kwargs) args["stream"] = True - args = self._transform_args(args) messages = compile_chat_messages( args["model"], @@ -123,11 +99,10 @@ class AnthropicLLM(LLM): self.write_log(f"Completion: \n\n{completion}") - async def complete( + async def _complete( self, prompt: str, with_history: List[ChatMessage] = None, **kwargs ) -> Coroutine[Any, Any, str]: - args = {**self.default_args, **kwargs} - args = self._transform_args(args) + args = self.collect_args(**kwargs) messages = compile_chat_messages( args["model"], diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py index db3aaed7..1668fb65 100644 --- a/continuedev/src/continuedev/libs/llm/ggml.py +++ b/continuedev/src/continuedev/libs/llm/ggml.py @@ -1,62 +1,29 @@ import json -from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union +from typing import Any, Coroutine, Dict, Generator, List, Optional, Union import aiohttp from ...core.main import ChatMessage from ..llm import LLM -from ..util.count_tokens import ( - DEFAULT_ARGS, - compile_chat_messages, - count_tokens, - format_chat_messages, -) +from ..util.count_tokens import compile_chat_messages, format_chat_messages class GGML(LLM): - # this is model-specific - max_context_length: int = 2048 server_url: str = "http://localhost:8000" verify_ssl: Optional[bool] = None - - requires_write_log = True - - write_log: Optional[Callable[[str], None]] = None + model: str = "ggml" class Config: arbitrary_types_allowed = True - async def start(self, write_log: Callable[[str], None], **kwargs): - self.write_log = write_log - - async def stop(self): - pass - - @property - def name(self): - return "ggml" - - @property - def context_length(self): - return self.max_context_length - - @property - def default_args(self): - return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024} - - def count_tokens(self, text: str): - return count_tokens(self.name, text) - - async def stream_complete( + async def _stream_complete( self, prompt, with_history: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.default_args.copy() - args.update(kwargs) + args = self.collect_args(**kwargs) args["stream"] = True - args = {**self.default_args, **kwargs} messages = compile_chat_messages( - self.name, + self.model, with_history, self.context_length, args["max_tokens"], @@ -84,12 +51,12 @@ class GGML(LLM): self.write_log(f"Completion: \n\n{completion}") - async def stream_chat( + async def _stream_chat( self, messages: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) messages = compile_chat_messages( - self.name, + self.model, messages, self.context_length, args["max_tokens"], @@ -142,10 +109,10 @@ class GGML(LLM): self.write_log(f"Completion: \n\n{completion}") - async def complete( + async def _complete( self, prompt: str, with_history: List[ChatMessage] = None, **kwargs ) -> Coroutine[Any, Any, str]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) self.write_log(f"Prompt: \n\n{prompt}") async with aiohttp.ClientSession( diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py index 4b78a247..3a586a43 100644 --- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py +++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py @@ -5,17 +5,14 @@ import requests from ...core.main import ChatMessage from ..llm import LLM -from ..util.count_tokens import DEFAULT_ARGS, count_tokens DEFAULT_MAX_TIME = 120.0 class HuggingFaceInferenceAPI(LLM): - model: str hf_token: str self_hosted_url: str = None - max_context_length: int = 2048 verify_ssl: Optional[bool] = None _client_session: aiohttp.ClientSession = None @@ -24,6 +21,7 @@ class HuggingFaceInferenceAPI(LLM): arbitrary_types_allowed = True async def start(self, **kwargs): + await super().start(**kwargs) self._client_session = aiohttp.ClientSession( connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl) ) @@ -31,22 +29,7 @@ class HuggingFaceInferenceAPI(LLM): async def stop(self): await self._client_session.close() - @property - def name(self): - return self.model - - @property - def context_length(self): - return self.max_context_length - - @property - def default_args(self): - return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024} - - def count_tokens(self, text: str): - return count_tokens(self.name, text) - - async def complete( + async def _complete( self, prompt: str, with_history: List[ChatMessage] = None, **kwargs ): """Return the completion of the text with the given temperature.""" @@ -77,14 +60,14 @@ class HuggingFaceInferenceAPI(LLM): return data[0]["generated_text"] - async def stream_chat( + async def _stream_chat( self, messages: List[ChatMessage] = None, **kwargs ) -> Coroutine[Any, Any, Generator[Any | List | Dict, None, None]]: - response = await self.complete(messages[-1].content, messages[:-1]) + response = await self._complete(messages[-1].content, messages[:-1]) yield {"content": response, "role": "assistant"} - async def stream_complete( + async def _stream_complete( self, prompt, with_history: List[ChatMessage] = None, **kwargs ) -> Generator[Any | List | Dict, None, None]: - response = await self.complete(prompt, with_history) + response = await self._complete(prompt, with_history) yield response diff --git a/continuedev/src/continuedev/libs/llm/hf_tgi.py b/continuedev/src/continuedev/libs/llm/hf_tgi.py index f04e700d..f106f83f 100644 --- a/continuedev/src/continuedev/libs/llm/hf_tgi.py +++ b/continuedev/src/continuedev/libs/llm/hf_tgi.py @@ -5,44 +5,22 @@ import aiohttp from ...core.main import ChatMessage from ..llm import LLM -from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens +from ..util.count_tokens import compile_chat_messages from .prompts.chat import code_llama_template_messages class HuggingFaceTGI(LLM): model: str = "huggingface-tgi" - max_context_length: int = 2048 server_url: str = "http://localhost:8080" verify_ssl: Optional[bool] = None template_messages: Callable[[List[ChatMessage]], str] = code_llama_template_messages - requires_write_log = True - - write_log: Optional[Callable[[str], None]] = None - class Config: arbitrary_types_allowed = True - async def start(self, write_log: Callable[[str], None], **kwargs): - self.write_log = write_log - - async def stop(self): - pass - - @property - def name(self): - return self.model - - @property - def context_length(self): - return self.max_context_length - - @property - def default_args(self): - return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024} - - def _transform_args(self, args): + def collect_args(self, **kwargs) -> Any: + args = super().collect_args(**kwargs) args = { **args, "max_new_tokens": args.get("max_tokens", 1024), @@ -50,19 +28,14 @@ class HuggingFaceTGI(LLM): args.pop("max_tokens", None) return args - def count_tokens(self, text: str): - return count_tokens(self.name, text) - - async def stream_complete( + async def _stream_complete( self, prompt, with_history: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.default_args.copy() - args.update(kwargs) + args = self.collect_args(**kwargs) args["stream"] = True - args = {**self.default_args, **kwargs} messages = compile_chat_messages( - self.name, + self.model, with_history, self.context_length, args["max_tokens"], @@ -93,12 +66,12 @@ class HuggingFaceTGI(LLM): self.write_log(f"Completion: \n\n{completion}") - async def stream_chat( + async def _stream_chat( self, messages: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) messages = compile_chat_messages( - self.name, + self.model, messages, self.context_length, args["max_tokens"], @@ -107,7 +80,7 @@ class HuggingFaceTGI(LLM): system_message=self.system_message, ) - async for chunk in self.stream_complete( + async for chunk in self._stream_complete( None, self.template_messages(messages), **args ): yield { @@ -115,13 +88,13 @@ class HuggingFaceTGI(LLM): "content": chunk, } - async def complete( + async def _complete( self, prompt: str, with_history: List[ChatMessage] = None, **kwargs ) -> Coroutine[Any, Any, str]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) completion = "" - async for chunk in self.stream_complete(prompt, with_history, **args): + async for chunk in self._stream_complete(prompt, with_history, **args): completion += chunk return completion diff --git a/continuedev/src/continuedev/libs/llm/llamacpp.py b/continuedev/src/continuedev/libs/llm/llamacpp.py index e6f38cd0..7940c4c9 100644 --- a/continuedev/src/continuedev/libs/llm/llamacpp.py +++ b/continuedev/src/continuedev/libs/llm/llamacpp.py @@ -6,12 +6,12 @@ import aiohttp from ...core.main import ChatMessage from ..llm import LLM -from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens +from ..util.count_tokens import compile_chat_messages from .prompts.chat import llama2_template_messages class LlamaCpp(LLM): - max_context_length: int = 2048 + model: str = "llamacpp" server_url: str = "http://localhost:8080" verify_ssl: Optional[bool] = None @@ -20,9 +20,6 @@ class LlamaCpp(LLM): use_command: Optional[str] = None - requires_write_log = True - write_log: Optional[Callable[[str], None]] = None - class Config: arbitrary_types_allowed = True @@ -31,29 +28,8 @@ class LlamaCpp(LLM): d.pop("template_messages") return d - async def start(self, write_log: Callable[[str], None], **kwargs): - self.write_log = write_log - - async def stop(self): - await self._client_session.close() - - @property - def name(self): - return "llamacpp" - - @property - def context_length(self): - return self.max_context_length - - @property - def default_args(self): - return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024} - - def count_tokens(self, text: str): - return count_tokens(self.name, text) - - def _transform_args(self, args: Dict[str, Any]) -> Dict[str, Any]: - args = args.copy() + def collect_args(self, **kwargs) -> Any: + args = super().collect_args(**kwargs) if "max_tokens" in args: args["n_predict"] = args["max_tokens"] del args["max_tokens"] @@ -85,16 +61,14 @@ class LlamaCpp(LLM): await process.wait() - async def stream_complete( + async def _stream_complete( self, prompt, with_history: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.default_args.copy() - args.update(kwargs) + args = self.collect_args(**kwargs) args["stream"] = True - args = {**self.default_args, **kwargs} messages = compile_chat_messages( - self.name, + self.model, with_history, self.context_length, args["max_tokens"], @@ -125,12 +99,12 @@ class LlamaCpp(LLM): self.write_log(f"Completion: \n\n{completion}") - async def stream_chat( + async def _stream_chat( self, messages: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) messages = compile_chat_messages( - self.name, + self.model, messages, self.context_length, args["max_tokens"], @@ -177,10 +151,10 @@ class LlamaCpp(LLM): self.write_log(f"Completion: \n\n{completion}") - async def complete( + async def _complete( self, prompt: str, with_history: List[ChatMessage] = None, **kwargs ) -> Coroutine[Any, Any, str]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) self.write_log(f"Prompt: \n\n{prompt}") diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py index daffe41f..99b7c47f 100644 --- a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py +++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union +from typing import Any, Coroutine, Dict, Generator, List, Optional, Union from ...core.main import ChatMessage from . import LLM @@ -10,63 +10,42 @@ class MaybeProxyOpenAI(LLM): model: str api_key: Optional[str] = None - requires_write_log: bool = True - requires_unique_id: bool = True - system_message: Union[str, None] = None - llm: Optional[LLM] = None def update_llm_properties(self): if self.llm is not None: self.llm.system_message = self.system_message - @property - def name(self): - if self.llm is not None: - return self.llm.name - else: - return None - - @property - def context_length(self): - return self.llm.context_length - - async def start( - self, - *, - api_key: Optional[str] = None, - unique_id: str, - write_log: Callable[[str], None] - ): + async def start(self, **kwargs): if self.api_key is None or self.api_key.strip() == "": self.llm = ProxyServer(model=self.model) else: self.llm = OpenAI(api_key=self.api_key, model=self.model) - await self.llm.start(write_log=write_log, unique_id=unique_id) + await self.llm.start(**kwargs) async def stop(self): await self.llm.stop() - async def complete( + async def _complete( self, prompt: str, with_history: List[ChatMessage] = None, **kwargs ) -> Coroutine[Any, Any, str]: self.update_llm_properties() - return await self.llm.complete(prompt, with_history=with_history, **kwargs) + return await self.llm._complete(prompt, with_history=with_history, **kwargs) - async def stream_complete( + async def _stream_complete( self, prompt, with_history: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: self.update_llm_properties() - resp = self.llm.stream_complete(prompt, with_history=with_history, **kwargs) + resp = self.llm._stream_complete(prompt, with_history=with_history, **kwargs) async for item in resp: yield item - async def stream_chat( + async def _stream_chat( self, messages: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: self.update_llm_properties() - resp = self.llm.stream_chat(messages=messages, **kwargs) + resp = self.llm._stream_chat(messages=messages, **kwargs) async for item in resp: yield item diff --git a/continuedev/src/continuedev/libs/llm/ollama.py b/continuedev/src/continuedev/libs/llm/ollama.py index 03300435..ef8ed47b 100644 --- a/continuedev/src/continuedev/libs/llm/ollama.py +++ b/continuedev/src/continuedev/libs/llm/ollama.py @@ -7,17 +7,15 @@ import aiohttp from ...core.main import ChatMessage from ..llm import LLM -from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens +from ..util.count_tokens import compile_chat_messages from .prompts.chat import llama2_template_messages class Ollama(LLM): model: str = "llama2" server_url: str = "http://localhost:11434" - max_context_length: int = 2048 _client_session: aiohttp.ClientSession = None - requires_write_log = True prompt_templates = { "edit": dedent( @@ -36,34 +34,19 @@ class Ollama(LLM): class Config: arbitrary_types_allowed = True - async def start(self, write_log, **kwargs): + async def start(self, **kwargs): + await super().start(**kwargs) self._client_session = aiohttp.ClientSession() - self.write_log = write_log async def stop(self): await self._client_session.close() - @property - def name(self): - return self.model - - @property - def context_length(self) -> int: - return self.max_context_length - - @property - def default_args(self): - return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024} - - def count_tokens(self, text: str): - return count_tokens(self.name, text) - - async def stream_complete( + async def _stream_complete( self, prompt, with_history: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) messages = compile_chat_messages( - self.name, + self.model, with_history, self.context_length, args["max_tokens"], @@ -102,12 +85,12 @@ class Ollama(LLM): yield urllib.parse.unquote(url_decode_buffer) url_decode_buffer = "" - async def stream_chat( + async def _stream_chat( self, messages: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) messages = compile_chat_messages( - self.name, + self.model, messages, self.context_length, args["max_tokens"], @@ -143,11 +126,11 @@ class Ollama(LLM): completion += j["response"] self.write_log(f"Completion:\n{completion}") - async def complete( + async def _complete( self, prompt: str, with_history: List[ChatMessage] = None, **kwargs ) -> Coroutine[Any, Any, str]: completion = "" - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) async with self._client_session.post( f"{self.server_url}/api/generate", json={ diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index a61103b9..a017af22 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -12,26 +12,15 @@ from typing import ( import certifi import openai -from pydantic import BaseModel from ...core.main import ChatMessage from ..llm import LLM from ..util.count_tokens import ( - DEFAULT_ARGS, compile_chat_messages, - count_tokens, format_chat_messages, prune_raw_prompt_from_top, ) - -class OpenAIServerInfo(BaseModel): - api_base: Optional[str] = None - engine: Optional[str] = None - api_version: Optional[str] = None - api_type: Literal["azure", "openai"] = "openai" - - CHAT_MODELS = {"gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"} MAX_TOKENS_FOR_MODEL = { "gpt-3.5-turbo": 4096, @@ -47,32 +36,43 @@ MAX_TOKENS_FOR_MODEL = { class OpenAI(LLM): api_key: str - model: str - openai_server_info: Optional[OpenAIServerInfo] = None + "OpenAI API key" + verify_ssl: Optional[bool] = None + "Whether to verify SSL certificates for requests." + ca_bundle_path: Optional[str] = None + "Path to CA bundle to use for requests." + proxy: Optional[str] = None + "Proxy URL to use for requests." + + api_base: Optional[str] = None + "OpenAI API base URL." - requires_write_log = True + api_type: Optional[Literal["azure", "openai"]] = None + "OpenAI API type." - write_log: Optional[Callable[[str], None]] = None + api_version: Optional[str] = None + "OpenAI API version. For use with Azure OpenAI Service." + + engine: Optional[str] = None + "OpenAI engine. For use with Azure OpenAI Service." async def start( - self, - *, - api_key: Optional[str] = None, - write_log: Callable[[str], None], - **kwargs, + self, unique_id: Optional[str] = None, write_log: Callable[[str], None] = None ): - self.write_log = write_log - openai.api_key = self.api_key + await super().start(write_log=write_log, unique_id=unique_id) - if self.openai_server_info is not None: - openai.api_type = self.openai_server_info.api_type - if self.openai_server_info.api_base is not None: - openai.api_base = self.openai_server_info.api_base - if self.openai_server_info.api_version is not None: - openai.api_version = self.openai_server_info.api_version + self.context_length = MAX_TOKENS_FOR_MODEL.get(self.model, 4096) + + openai.api_key = self.api_key + if self.api_type is not None: + openai.api_type = self.api_type + if self.api_base is not None: + openai.api_base = self.api_base + if self.api_version is not None: + openai.api_version = self.api_version if self.verify_ssl is not None and self.verify_ssl is False: openai.verify_ssl_certs = False @@ -82,32 +82,16 @@ class OpenAI(LLM): openai.ca_bundle_path = self.ca_bundle_path or certifi.where() - async def stop(self): - pass - - @property - def name(self): - return self.model - - @property - def context_length(self): - return MAX_TOKENS_FOR_MODEL.get(self.model, 4096) - - @property - def default_args(self): - args = {**DEFAULT_ARGS, "model": self.model} - if self.openai_server_info is not None: - args["engine"] = self.openai_server_info.engine + def collect_args(self, **kwargs): + args = super().collect_args() + if self.engine is not None: + args["engine"] = self.engine return args - def count_tokens(self, text: str): - return count_tokens(self.model, text) - - async def stream_complete( + async def _stream_complete( self, prompt, with_history: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.default_args.copy() - args.update(kwargs) + args = self.collect_args(**kwargs) args["stream"] = True if args["model"] in CHAT_MODELS: @@ -142,11 +126,10 @@ class OpenAI(LLM): self.write_log(f"Completion:\n\n{completion}") - async def stream_chat( + async def _stream_chat( self, messages: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.default_args.copy() - args.update(kwargs) + args = self.collect_args(**kwargs) args["stream"] = True if not args["model"].endswith("0613") and "functions" in args: @@ -174,10 +157,10 @@ class OpenAI(LLM): completion += chunk.choices[0].delta.content self.write_log(f"Completion: \n\n{completion}") - async def complete( + async def _complete( self, prompt: str, with_history: List[ChatMessage] = None, **kwargs ) -> Coroutine[Any, Any, str]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) if args["model"] in CHAT_MODELS: messages = compile_chat_messages( @@ -190,23 +173,18 @@ class OpenAI(LLM): system_message=self.system_message, ) self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") - resp = ( - ( - await openai.ChatCompletion.acreate( - messages=messages, - **args, - ) - ) - .choices[0] - .message.content + resp = await openai.ChatCompletion.acreate( + messages=messages, + **args, ) - self.write_log(f"Completion: \n\n{resp}") + completion = resp.choices[0].message.content + self.write_log(f"Completion: \n\n{completion}") else: prompt = prune_raw_prompt_from_top( args["model"], self.context_length, prompt, args["max_tokens"] ) self.write_log(f"Prompt:\n\n{prompt}") - resp = ( + completion = ( ( await openai.Completion.acreate( prompt=prompt, @@ -216,6 +194,6 @@ class OpenAI(LLM): .choices[0] .text ) - self.write_log(f"Completion:\n\n{resp}") + self.write_log(f"Completion:\n\n{completion}") - return resp + return completion diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py index fa77a22a..3ac6371f 100644 --- a/continuedev/src/continuedev/libs/llm/proxy_server.py +++ b/continuedev/src/continuedev/libs/llm/proxy_server.py @@ -1,19 +1,14 @@ import json import ssl import traceback -from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union +from typing import Any, Coroutine, Dict, Generator, List, Union import aiohttp import certifi from ...core.main import ChatMessage from ..llm import LLM -from ..util.count_tokens import ( - DEFAULT_ARGS, - compile_chat_messages, - count_tokens, - format_chat_messages, -) +from ..util.count_tokens import compile_chat_messages, format_chat_messages from ..util.telemetry import posthog_logger ca_bundle_path = certifi.where() @@ -31,59 +26,32 @@ MAX_TOKENS_FOR_MODEL = { class ProxyServer(LLM): - model: str - system_message: Optional[str] - - unique_id: str = None - write_log: Callable[[str], None] = None _client_session: aiohttp.ClientSession - requires_unique_id = True - requires_write_log = True - class Config: arbitrary_types_allowed = True async def start( self, - *, - api_key: Optional[str] = None, - write_log: Callable[[str], None], - unique_id: str, **kwargs, ): + await super().start(**kwargs) self._client_session = aiohttp.ClientSession( connector=aiohttp.TCPConnector(ssl_context=ssl_context) ) - self.write_log = write_log - self.unique_id = unique_id + self.context_length = MAX_TOKENS_FOR_MODEL[self.model] async def stop(self): await self._client_session.close() - @property - def name(self): - return self.model - - @property - def context_length(self): - return MAX_TOKENS_FOR_MODEL[self.model] - - @property - def default_args(self): - return {**DEFAULT_ARGS, "model": self.model} - - def count_tokens(self, text: str): - return count_tokens(self.model, text) - def get_headers(self): # headers with unique id return {"unique_id": self.unique_id} - async def complete( + async def _complete( self, prompt: str, with_history: List[ChatMessage] = None, **kwargs ) -> Coroutine[Any, Any, str]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) messages = compile_chat_messages( args["model"], @@ -107,10 +75,10 @@ class ProxyServer(LLM): self.write_log(f"Completion: \n\n{response_text}") return response_text - async def stream_chat( + async def _stream_chat( self, messages: List[ChatMessage] = None, **kwargs ) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) messages = compile_chat_messages( args["model"], messages, @@ -158,10 +126,10 @@ class ProxyServer(LLM): self.write_log(f"Completion: \n\n{completion}") - async def stream_complete( + async def _stream_complete( self, prompt, with_history: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) messages = compile_chat_messages( self.model, with_history, diff --git a/continuedev/src/continuedev/libs/llm/replicate.py b/continuedev/src/continuedev/libs/llm/replicate.py index 0424d827..fb0d3f5c 100644 --- a/continuedev/src/continuedev/libs/llm/replicate.py +++ b/continuedev/src/continuedev/libs/llm/replicate.py @@ -4,39 +4,22 @@ from typing import List import replicate from ...core.main import ChatMessage -from ..util.count_tokens import DEFAULT_ARGS, count_tokens from . import LLM class ReplicateLLM(LLM): api_key: str + "Replicate API key" + model: str = "replicate/llama-2-70b-chat:58d078176e02c219e11eb4da5a02a7830a283b14cf8f94537af893ccff5ee781" - max_context_length: int = 2048 _client: replicate.Client = None - @property - def name(self): - return self.model - - @property - def context_length(self): - return self.max_context_length - - @property - def default_args(self): - return {**DEFAULT_ARGS, "model": self.model, "max_tokens": 1024} - - def count_tokens(self, text: str): - return count_tokens(self.name, text) - - async def start(self): + async def start(self, **kwargs): + await super().start(**kwargs) self._client = replicate.Client(api_token=self.api_key) - async def stop(self): - pass - - async def complete( + async def _complete( self, prompt: str, with_history: List[ChatMessage] = None, **kwargs ): def helper(): @@ -55,7 +38,7 @@ class ReplicateLLM(LLM): return completion - async def stream_complete( + async def _stream_complete( self, prompt, with_history: List[ChatMessage] = None, **kwargs ): for item in self._client.run( @@ -63,7 +46,7 @@ class ReplicateLLM(LLM): ): yield item - async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs): + async def _stream_chat(self, messages: List[ChatMessage] = None, **kwargs): for item in self._client.run( self.model, input={"message": messages[-1].content, "prompt": messages[-1].content}, diff --git a/continuedev/src/continuedev/libs/llm/text_gen_interface.py b/continuedev/src/continuedev/libs/llm/text_gen_interface.py index 380f7b48..59627629 100644 --- a/continuedev/src/continuedev/libs/llm/text_gen_interface.py +++ b/continuedev/src/continuedev/libs/llm/text_gen_interface.py @@ -1,51 +1,23 @@ import json -from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union +from typing import Any, Coroutine, Dict, Generator, List, Optional, Union import websockets from ...core.main import ChatMessage -from ..util.count_tokens import ( - DEFAULT_ARGS, - compile_chat_messages, - count_tokens, - format_chat_messages, -) +from ..util.count_tokens import compile_chat_messages, format_chat_messages from . import LLM class TextGenUI(LLM): # this is model-specific model: str = "text-gen-ui" - max_context_length: int = 2048 server_url: str = "http://localhost:5000" streaming_url: str = "http://localhost:5005" verify_ssl: Optional[bool] = None - requires_write_log = True - - write_log: Optional[Callable[[str], None]] = None - class Config: arbitrary_types_allowed = True - async def start(self, write_log: Callable[[str], None], **kwargs): - self.write_log = write_log - - async def stop(self): - pass - - @property - def name(self): - return self.model - - @property - def context_length(self): - return self.max_context_length - - @property - def default_args(self): - return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024} - def _transform_args(self, args): args = { **args, @@ -54,18 +26,12 @@ class TextGenUI(LLM): args.pop("max_tokens", None) return args - def count_tokens(self, text: str): - return count_tokens(self.name, text) - - async def stream_complete( + async def _stream_complete( self, prompt, with_history: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.default_args.copy() - args.update(kwargs) + args = self.collect_args(**kwargs) args["stream"] = True - args = {**self.default_args, **kwargs} - self.write_log(f"Prompt: \n\n{prompt}") completion = "" @@ -89,12 +55,12 @@ class TextGenUI(LLM): self.write_log(f"Completion: \n\n{completion}") - async def stream_chat( + async def _stream_chat( self, messages: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) messages = compile_chat_messages( - self.name, + self.model, messages, self.context_length, args["max_tokens"], @@ -146,10 +112,10 @@ class TextGenUI(LLM): self.write_log(f"Completion: \n\n{completion}") - async def complete( + async def _complete( self, prompt: str, with_history: List[ChatMessage] = None, **kwargs ) -> Coroutine[Any, Any, str]: - generator = self.stream_chat( + generator = self._stream_chat( [ChatMessage(role="user", content=prompt, summary=prompt)], **kwargs ) diff --git a/continuedev/src/continuedev/libs/llm/together.py b/continuedev/src/continuedev/libs/llm/together.py index ddae91a9..d8c7334b 100644 --- a/continuedev/src/continuedev/libs/llm/together.py +++ b/continuedev/src/continuedev/libs/llm/together.py @@ -5,21 +5,23 @@ import aiohttp from ...core.main import ChatMessage from ..llm import LLM -from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens +from ..util.count_tokens import compile_chat_messages from .prompts.chat import llama2_template_messages class TogetherLLM(LLM): # this is model-specific api_key: str + "Together API key" + model: str = "togethercomputer/RedPajama-INCITE-7B-Instruct" - max_context_length: int = 2048 base_url: str = "https://api.together.xyz" verify_ssl: Optional[bool] = None _client_session: aiohttp.ClientSession = None async def start(self, **kwargs): + await super().start(**kwargs) self._client_session = aiohttp.ClientSession( connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl) ) @@ -27,31 +29,14 @@ class TogetherLLM(LLM): async def stop(self): await self._client_session.close() - @property - def name(self): - return self.model - - @property - def context_length(self): - return self.max_context_length - - @property - def default_args(self): - return {**DEFAULT_ARGS, "model": self.model, "max_tokens": 1024} - - def count_tokens(self, text: str): - return count_tokens(self.name, text) - - async def stream_complete( + async def _stream_complete( self, prompt, with_history: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.default_args.copy() - args.update(kwargs) + args = self.collect_args(**kwargs) args["stream_tokens"] = True - args = {**self.default_args, **kwargs} messages = compile_chat_messages( - self.name, + self.model, with_history, self.context_length, args["max_tokens"], @@ -72,12 +57,12 @@ class TogetherLLM(LLM): except: raise Exception(str(line)) - async def stream_chat( + async def _stream_chat( self, messages: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) messages = compile_chat_messages( - self.name, + self.model, messages, self.context_length, args["max_tokens"], @@ -112,10 +97,10 @@ class TogetherLLM(LLM): "content": json_chunk["choices"][0]["text"], } - async def complete( + async def _complete( self, prompt: str, with_history: List[ChatMessage] = None, **kwargs ) -> Coroutine[Any, Any, str]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) messages = compile_chat_messages( args["model"], diff --git a/continuedev/src/continuedev/libs/util/edit_config.py b/continuedev/src/continuedev/libs/util/edit_config.py index eed43054..45a4a599 100644 --- a/continuedev/src/continuedev/libs/util/edit_config.py +++ b/continuedev/src/continuedev/libs/util/edit_config.py @@ -74,9 +74,6 @@ def add_config_import(line: str): filtered_attrs = { - "requires_api_key", - "requires_unique_id", - "requires_write_log", "class_name", "name", "llm", diff --git a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py index 43a2b800..fe049268 100644 --- a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py +++ b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py @@ -30,7 +30,7 @@ class SetupPipelineStep(Step): sdk.context.set("api_description", self.api_description) source_name = ( - await sdk.models.medium.complete( + await sdk.models.medium._complete( f"Write a snake_case name for the data source described by {self.api_description}: " ) ).strip() @@ -115,7 +115,7 @@ class ValidatePipelineStep(Step): if "Traceback" in output or "SyntaxError" in output: output = "Traceback" + output.split("Traceback")[-1] file_content = await sdk.ide.readFile(os.path.join(workspace_dir, filename)) - suggestion = await sdk.models.medium.complete( + suggestion = await sdk.models.medium._complete( dedent( f"""\ ```python @@ -131,7 +131,7 @@ class ValidatePipelineStep(Step): ) ) - api_documentation_url = await sdk.models.medium.complete( + api_documentation_url = await sdk.models.medium._complete( dedent( f"""\ The API I am trying to call is the '{sdk.context.get('api_description')}'. I tried calling it in the @resource function like this: @@ -216,7 +216,7 @@ class RunQueryStep(Step): ) if "Traceback" in output or "SyntaxError" in output: - suggestion = await sdk.models.medium.complete( + suggestion = await sdk.models.medium._complete( dedent( f"""\ ```python diff --git a/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py index d6769148..44065d22 100644 --- a/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py +++ b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py @@ -92,7 +92,7 @@ class LoadDataStep(Step): docs = f.read() output = "Traceback" + output.split("Traceback")[-1] - suggestion = await sdk.models.default.complete( + suggestion = await sdk.models.default._complete( dedent( f"""\ When trying to load data into BigQuery, the following error occurred: diff --git a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py index e2712746..4727c994 100644 --- a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py +++ b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py @@ -45,7 +45,7 @@ class WritePytestsRecipe(Step): Here is a complete set of pytest unit tests:""" ) - tests = await sdk.models.medium.complete(prompt) + tests = await sdk.models.medium._complete(prompt) await sdk.apply_filesystem_edit(AddFile(filepath=path, content=tests)) diff --git a/continuedev/src/continuedev/plugins/steps/chat.py b/continuedev/src/continuedev/plugins/steps/chat.py index 857183bc..d580f886 100644 --- a/continuedev/src/continuedev/plugins/steps/chat.py +++ b/continuedev/src/continuedev/plugins/steps/chat.py @@ -83,11 +83,17 @@ class SimpleChatStep(Step): messages = self.messages or await sdk.get_chat_context() - generator = sdk.models.chat.stream_chat( + generator = sdk.models.chat._stream_chat( messages, temperature=sdk.config.temperature ) - posthog_logger.capture_event("model_use", {"model": sdk.models.default.name}) + posthog_logger.capture_event( + "model_use", + { + "model": sdk.models.default.model, + "provider": sdk.models.default.__class__.__name__, + }, + ) async for chunk in generator: if sdk.current_step_was_deleted(): @@ -112,7 +118,7 @@ class SimpleChatStep(Step): await sdk.update_ui() self.name = add_ellipsis( remove_quotes_and_escapes( - await sdk.models.medium.complete( + await sdk.models.medium._complete( f'"{self.description}"\n\nPlease write a short title summarizing the message quoted above. Use no more than 10 words:', max_tokens=20, ) @@ -254,7 +260,7 @@ class ChatWithFunctions(Step): gpt350613 = OpenAI(model="gpt-3.5-turbo-0613") await sdk.start_model(gpt350613) - async for msg_chunk in gpt350613.stream_chat( + async for msg_chunk in gpt350613._stream_chat( await sdk.get_chat_context(), functions=functions ): if sdk.current_step_was_deleted(): diff --git a/continuedev/src/continuedev/plugins/steps/chroma.py b/continuedev/src/continuedev/plugins/steps/chroma.py index 25633942..9ee2a48d 100644 --- a/continuedev/src/continuedev/plugins/steps/chroma.py +++ b/continuedev/src/continuedev/plugins/steps/chroma.py @@ -58,7 +58,7 @@ class AnswerQuestionChroma(Step): Here is the answer:""" ) - answer = await sdk.models.medium.complete(prompt) + answer = await sdk.models.medium._complete(prompt) # Make paths relative to the workspace directory answer = answer.replace(await sdk.ide.getWorkspaceDirectory(), "") diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py index 1529fe1b..9d40822b 100644 --- a/continuedev/src/continuedev/plugins/steps/core/core.py +++ b/continuedev/src/continuedev/plugins/steps/core/core.py @@ -208,6 +208,8 @@ class DefaultModelEditCodeStep(Step): _new_contents: str = "" _prompt_and_completion: str = "" + summary_prompt: str = "Please give brief a description of the changes made above using markdown bullet points. Be concise:" + async def describe(self, models: Models) -> Coroutine[str, None, None]: if self._previous_contents.strip() == self._new_contents.strip(): description = "No edits were made" @@ -227,7 +229,7 @@ class DefaultModelEditCodeStep(Step): {changes} ``` - Please give brief a description of the changes made above using markdown bullet points. Be concise:""" + {self.summary_prompt}""" ) ) name = await models.medium.complete( @@ -273,7 +275,7 @@ class DefaultModelEditCodeStep(Step): ) # If using 3.5 and overflows, upgrade to 3.5.16k - if model_to_use.name == "gpt-3.5-turbo": + if model_to_use.model == "gpt-3.5-turbo": if total_tokens > model_to_use.context_length: model_to_use = MaybeProxyOpenAI(model="gpt-3.5-turbo-0613") await sdk.start_model(model_to_use) @@ -661,11 +663,14 @@ Please output the code to be inserted at the cursor in order to fulfill the user else: messages = rendered - generator = model_to_use.stream_chat( + generator = model_to_use._stream_chat( messages, temperature=sdk.config.temperature, max_tokens=max_tokens ) - posthog_logger.capture_event("model_use", {"model": model_to_use.name}) + posthog_logger.capture_event( + "model_use", + {"model": model_to_use.model, "provider": model_to_use.__class__.__name__}, + ) try: async for chunk in generator: diff --git a/continuedev/src/continuedev/plugins/steps/help.py b/continuedev/src/continuedev/plugins/steps/help.py index 148dddb8..c73d7eef 100644 --- a/continuedev/src/continuedev/plugins/steps/help.py +++ b/continuedev/src/continuedev/plugins/steps/help.py @@ -59,7 +59,7 @@ class HelpStep(Step): ChatMessage(role="user", content=prompt, summary="Help") ) messages = await sdk.get_chat_context() - generator = sdk.models.default.stream_chat(messages) + generator = sdk.models.default._stream_chat(messages) async for chunk in generator: if "content" in chunk: self.description += chunk["content"] diff --git a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py index 721f1306..001876d0 100644 --- a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py +++ b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py @@ -26,7 +26,7 @@ class NLMultiselectStep(Step): if first_try is not None: return first_try - gpt_parsed = await sdk.models.default.complete( + gpt_parsed = await sdk.models.default._complete( f"These are the available options are: [{', '.join(self.options)}]. The user requested {user_response}. This is the exact string from the options array that they selected:" ) return extract_option(gpt_parsed) or self.options[0] diff --git a/continuedev/src/continuedev/plugins/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py index 2ceb82c5..7762666c 100644 --- a/continuedev/src/continuedev/plugins/steps/main.py +++ b/continuedev/src/continuedev/plugins/steps/main.py @@ -105,7 +105,7 @@ class FasterEditHighlightedCodeStep(Step): for rif in range_in_files: rif_dict[rif.filepath] = rif.contents - completion = await sdk.models.medium.complete(prompt) + completion = await sdk.models.medium._complete(prompt) # Temporarily doing this to generate description. self._prompt = prompt @@ -180,7 +180,7 @@ class StarCoderEditHighlightedCodeStep(Step): _prompt_and_completion: str = "" async def describe(self, models: Models) -> Coroutine[str, None, None]: - return await models.medium.complete( + return await models.medium._complete( f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:" ) @@ -245,6 +245,8 @@ class EditHighlightedCodeStep(Step): hide = True description: str = "Change the contents of the currently highlighted code or open file. You should call this function if the user asks seems to be asking for a code change." + summary_prompt: Optional[str] = None + async def describe(self, models: Models) -> Coroutine[str, None, None]: return "Editing code" @@ -293,13 +295,15 @@ class EditHighlightedCodeStep(Step): self.description = "Please accept or reject the change before making another edit in this file." return - await sdk.run_step( - DefaultModelEditCodeStep( - user_input=self.user_input, - range_in_files=range_in_files, - model=self.model, - ) - ) + args = { + "user_input": self.user_input, + "range_in_files": range_in_files, + "model": self.model, + } + if self.summary_prompt: + args["summary_prompt"] = self.summary_prompt + + await sdk.run_step(DefaultModelEditCodeStep(**args)) class UserInputStep(Step): diff --git a/continuedev/src/continuedev/plugins/steps/react.py b/continuedev/src/continuedev/plugins/steps/react.py index a2612731..2ed2d3d7 100644 --- a/continuedev/src/continuedev/plugins/steps/react.py +++ b/continuedev/src/continuedev/plugins/steps/react.py @@ -29,7 +29,7 @@ class NLDecisionStep(Step): Select the step which should be taken next to satisfy the user input. Say only the name of the selected step. You must choose one:""" ) - resp = (await sdk.models.medium.complete(prompt)).lower() + resp = (await sdk.models.medium._complete(prompt)).lower() step_to_run = None for step in self.steps: diff --git a/continuedev/src/continuedev/plugins/steps/search_directory.py b/continuedev/src/continuedev/plugins/steps/search_directory.py index 04fb98b7..9317bfe1 100644 --- a/continuedev/src/continuedev/plugins/steps/search_directory.py +++ b/continuedev/src/continuedev/plugins/steps/search_directory.py @@ -46,7 +46,7 @@ class WriteRegexPatternStep(Step): async def run(self, sdk: ContinueSDK): # Ask the user for a regex pattern - pattern = await sdk.models.medium.complete( + pattern = await sdk.models.medium._complete( dedent( f"""\ This is the user request: diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 2c4f2e4d..49541b76 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -23,7 +23,6 @@ from ..libs.util.queue import AsyncSubscriptionQueue from ..libs.util.telemetry import posthog_logger from ..plugins.steps.core.core import DisplayErrorStep from ..plugins.steps.setup_model import SetupModelStep -from .gui_protocol import AbstractGUIProtocolServer from .session_manager import Session, session_manager router = APIRouter(prefix="/gui", tags=["gui"]) @@ -54,7 +53,7 @@ T = TypeVar("T", bound=BaseModel) # You should probably abstract away the websocket stuff into a separate class -class GUIProtocolServer(AbstractGUIProtocolServer): +class GUIProtocolServer: websocket: WebSocket session: Session sub_queue: AsyncSubscriptionQueue = AsyncSubscriptionQueue() @@ -118,8 +117,10 @@ class GUIProtocolServer(AbstractGUIProtocolServer): self.set_system_message(data["message"]) elif message_type == "set_temperature": self.set_temperature(float(data["temperature"])) - elif message_type == "set_model_for_role": - self.set_model_for_role(data["role"], data["model_class"], data["model"]) + elif message_type == "add_model_for_role": + self.add_model_for_role(data["role"], data["model_class"], data["model"]) + elif message_type == "set_model_for_role_from_index": + self.set_model_for_role_from_index(data["role"], data["index"]) elif message_type == "save_context_group": self.save_context_group( data["title"], [ContextItem(**item) for item in data["context_items"]] @@ -230,51 +231,50 @@ class GUIProtocolServer(AbstractGUIProtocolServer): self.on_error, ) - def set_model_for_role(self, role: str, model_class: str, model: Any): + def set_model_for_role_from_index(self, role: str, index: int): + async def async_stuff(): + models = self.session.autopilot.continue_sdk.config.models + + # Set models in SDK + temp = models.default + models.default = models.unused[index] + models.unused[index] = temp + await self.session.autopilot.continue_sdk.start_model(models.default) + + # Set models in config.py + JOINER = ", " + models_args = { + "unused": f"[{JOINER.join([display_llm_class(llm) for llm in models.unused])}]", + ("default" if role == "*" else role): display_llm_class(models.default), + } + + await self.session.autopilot.set_config_attr( + ["models"], + create_obj_node("Models", models_args), + ) + + for other_role in ALL_MODEL_ROLES: + if other_role != "default": + models.__setattr__(other_role, models.default) + + await self.session.autopilot.continue_sdk.update_ui() + + create_async_task(async_stuff(), self.on_error) + + def add_model_for_role(self, role: str, model_class: str, model: Any): models = self.session.autopilot.continue_sdk.config.models unused_models = models.unused if role == "*": async def async_stuff(): - # Clear all of the models and store them in unused_models - # NOTE: There will be duplicates for role in ALL_MODEL_ROLES: - prev_model = models.__getattribute__(role) models.__setattr__(role, None) - if prev_model is not None: - exists = False - for other in unused_models: - if ( - prev_model.__class__.__name__ - == other.__class__.__name__ - and ( - other.name is not None - and ( - not other.name.startswith("gpt") - or prev_model.name == other.name - ) - ) - ): - exists = True - break - if not exists: - unused_models.append(prev_model) - - # Replace default with either new one or existing from unused_models - for unused_model in unused_models: - if model_class == unused_model.__class__.__name__ and ( - "model" not in model - or model["model"] == unused_model.model - and model["model"].startswith("gpt") - ): - models.default = unused_model # Set and start the default model if didn't already exist from unused - if models.default is None: - models.default = MODEL_CLASSES[model_class](**model) - await self.session.autopilot.continue_sdk.run_step( - SetupModelStep(model_class=model_class) - ) + models.default = MODEL_CLASSES[model_class](**model) + await self.session.autopilot.continue_sdk.run_step( + SetupModelStep(model_class=model_class) + ) await self.session.autopilot.continue_sdk.start_model(models.default) diff --git a/continuedev/src/continuedev/server/gui_protocol.py b/continuedev/src/continuedev/server/gui_protocol.py deleted file mode 100644 index d079475c..00000000 --- a/continuedev/src/continuedev/server/gui_protocol.py +++ /dev/null @@ -1,40 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any - - -class AbstractGUIProtocolServer(ABC): - @abstractmethod - async def handle_json(self, data: Any): - """Handle a json message""" - - @abstractmethod - def on_main_input(self, input: str): - """Called when the user inputs something""" - - @abstractmethod - def on_reverse_to_index(self, index: int): - """Called when the user requests reverse to a previous index""" - - @abstractmethod - def on_refinement_input(self, input: str, index: int): - """Called when the user inputs a refinement""" - - @abstractmethod - def on_step_user_input(self, input: str, index: int): - """Called when the user inputs a step""" - - @abstractmethod - def on_retry_at_index(self, index: int): - """Called when the user requests a retry at a previous index""" - - @abstractmethod - def on_clear_history(self): - """Called when the user requests to clear the history""" - - @abstractmethod - def on_delete_at_index(self, index: int): - """Called when the user requests to delete a step at a given index""" - - @abstractmethod - def select_context_item(self, id: str, query: str): - """Called when user selects an item from the dropdown""" diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py index 390eeb50..e64588e7 100644 --- a/continuedev/src/continuedev/server/meilisearch_server.py +++ b/continuedev/src/continuedev/server/meilisearch_server.py @@ -69,7 +69,7 @@ async def ensure_meilisearch_installed() -> bool: else: non_existing_paths.add(path) - if len(non_existing_paths) > 0: + if len(non_existing_paths) > 0 and os.name != "nt": # Clear the meilisearch binary if meilisearchPath in existing_paths: os.remove(meilisearchPath) @@ -134,5 +134,5 @@ async def start_meilisearch(): stderr=subprocess.STDOUT, close_fds=True, start_new_session=True, - shell=True + shell=True, ) |