30 files changed, 242 insertions, 513 deletions
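The bulk of this commit consolidates per-provider boilerplate (the `name`, `context_length`, and `default_args` properties, `count_tokens`, and the `requires_api_key`/`requires_unique_id`/`requires_write_log` flags) into the shared `LLM` base class, and renames the public `complete`/`stream_complete`/`stream_chat` methods to underscore-prefixed internal variants. The sketch below is a simplified, self-contained illustration of that pattern, not the repository code: the real base class is a pydantic-style `ContinueBaseModel` and delegates token counting to `libs/util/count_tokens.py`, both of which are stubbed out here, and `ExampleProvider` is a hypothetical subclass invented for the example.

```python
import asyncio
from typing import Callable, Dict, Optional

# Stand-in for DEFAULT_ARGS from libs/util/count_tokens.py
DEFAULT_ARGS = {"temperature": 0.5}


class LLM:
    """Simplified sketch of the consolidated base class this diff introduces."""

    def __init__(self, model: str, api_key: Optional[str] = None, context_length: int = 2048):
        self.model = model              # providers now expose `model` instead of a `name` property
        self.api_key = api_key          # plain field, no more requires_api_key / getUserSecret
        self.context_length = context_length
        self.write_log: Optional[Callable[[str], None]] = None
        self.unique_id: Optional[str] = None

    async def start(self, write_log: Callable[[str], None] = None,
                    unique_id: Optional[str] = None) -> None:
        # The base start() stores the shared state that subclasses previously
        # requested via requires_write_log / requires_unique_id flags.
        self.write_log = write_log
        self.unique_id = unique_id

    def collect_args(self, **kwargs) -> Dict:
        # Replaces the per-provider `default_args` property and the ad-hoc
        # `args = {**self.default_args, **kwargs}` merging at every call site.
        args = {**DEFAULT_ARGS, "model": self.model, "max_tokens": 1024}
        args.update(kwargs)
        return args

    def count_tokens(self, text: str) -> int:
        # The real method calls count_tokens(self.model, text) from the util
        # module; a whitespace split keeps this sketch runnable on its own.
        return len(text.split())

    async def _complete(self, prompt: str, **kwargs) -> str:
        raise NotImplementedError


class ExampleProvider(LLM):
    """Hypothetical provider: subclasses now only override collect_args and _complete/_stream_*."""

    def collect_args(self, **kwargs) -> Dict:
        args = super().collect_args(**kwargs)
        # Provider-specific renames live here, e.g. Anthropic's max_tokens_to_sample.
        args["max_tokens_to_sample"] = args.pop("max_tokens")
        return args

    async def _complete(self, prompt: str, **kwargs) -> str:
        return f"(fake completion for {prompt!r} using {self.collect_args(**kwargs)})"


if __name__ == "__main__":
    llm = ExampleProvider(model="claude-2", api_key="sk-example")
    asyncio.run(llm.start(write_log=print, unique_id="example-ide"))
    print(asyncio.run(llm._complete("Hello")))
```

With everything on the base class, the `requires_*` attributes also drop out of `filtered_attrs` in `libs/util/edit_config.py` further down the diff.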
diff --git a/.github/ISSUE_TEMPLATE/bug-report-🐛.md b/.github/ISSUE_TEMPLATE/bug-report-🐛.md
index ab37cfbe..1070d7e2 100644
--- a/.github/ISSUE_TEMPLATE/bug-report-🐛.md
+++ b/.github/ISSUE_TEMPLATE/bug-report-🐛.md
@@ -1,10 +1,9 @@
 ---
 name: "Bug report \U0001F41B"
 about: Create a report to help us fix your bug
-title: ''
+title: ""
 labels: bug
-assignees: ''
-
+assignees: ""
 ---
 
 **Describe the bug**
@@ -12,32 +11,45 @@ A clear and concise description of what the bug is.
 
 **To Reproduce**
 Steps to reproduce the behavior:
+
 1. Go to '...'
 2. Click on '....'
 3. Scroll down to '....'
 4. See error
 
-**Expected behavior**
-A clear and concise description of what you expected to happen.
-
-**Screenshots**
-If applicable, add screenshots to help explain your problem.
-
 **Environment**
+
 - Operating System: [e.g. MacOS]
 - Python Version: [e.g. 3.10.6]
 - Continue Version: [e.g. v0.0.207]
 
-**Console logs**
+**Logs**
+
 ```
 REPLACE THIS SECTION WITH CONSOLE LOGS OR A SCREENSHOT...
+```
+
+To get the Continue server logs:
+
+1. cmd+shift+p (MacOS) / ctrl+shift+p (Windows)
+2. Search for and then select "Continue: View Continue Server Logs"
+3. Scroll to the bottom of `continue.log` and copy the last 100 lines or so
+
+To get the VS Code console logs:
 
-To get the console logs in VS Code:
 1. cmd+shift+p (MacOS) / ctrl+shift+p (Windows)
 2. Search for and then select "Developer: Toggle Developer Tools"
 3. Select Console
 4. Read the console logs
-```
+
+If the problem is related to LLM prompting:
+
+1. Hover the problematic response in the Continue UI
+2. Click the "magnifying glass" icon
+3. Copy the contents of the `continue_logs.txt` file that opens
+
+**Screenshots**
+If applicable, add screenshots to help explain your problem.
 
 **Additional context**
 Add any other context about the problem here.
diff --git a/.github/ISSUE_TEMPLATE/feature-request-💪.md b/.github/ISSUE_TEMPLATE/feature-request-💪.md
index 2b138a9a..b356d488 100644
--- a/.github/ISSUE_TEMPLATE/feature-request-💪.md
+++ b/.github/ISSUE_TEMPLATE/feature-request-💪.md
@@ -1,10 +1,9 @@
 ---
 name: "Feature request \U0001F4AA"
 about: Suggest an idea for this project
-title: ''
+title: ""
 labels: enhancement
 assignees: TyDunn
-
 ---
 
 **Is your feature request related to a problem? Please describe.**
@@ -13,8 +12,5 @@ A clear and concise description of what the problem is. Ex. I'm always frustrate
 **Describe the solution you'd like**
 A clear and concise description of what you want to happen.
 
-**Describe alternatives you've considered**
-A clear and concise description of any alternative solutions or features you've considered.
-
 **Additional context**
 Add any other context or screenshots about the feature request here.
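Before the core hunks that follow, note the calling convention they converge on: `AbstractContinueSDK.get_user_secret` is deleted, and `ContinueSDK.start_model` no longer inspects `requires_*` flags or fetches secrets through the IDE; it simply forwards the IDE's unique id and the log writer to `llm.start()`. Below is a hedged sketch of that call path. `StubLLM`, `StubIDE`, and `StubSDK` are stand-ins invented for the example (the real classes carry far more state), and only the `start_model` body mirrors the diff.

```python
import asyncio
from typing import Callable, Optional


class StubLLM:
    """Minimal model object: after this commit, api_key is an ordinary field set from config."""

    def __init__(self, model: str, api_key: Optional[str] = None):
        self.model = model
        self.api_key = api_key
        self.write_log: Optional[Callable[[str], None]] = None
        self.unique_id: Optional[str] = None

    async def start(self, write_log: Callable[[str], None] = None,
                    unique_id: Optional[str] = None) -> None:
        self.write_log = write_log
        self.unique_id = unique_id


class StubIDE:
    unique_id = "vscode-0000"  # placeholder for the real IDE's unique id


class StubSDK:
    def __init__(self) -> None:
        self.ide = StubIDE()

    def write_log(self, message: str) -> None:
        print(message)

    async def start_model(self, llm: StubLLM) -> None:
        # Mirrors the simplified ContinueSDK.start_model in this diff:
        # no requires_* inspection, no ide.getUserSecret round trip.
        await llm.start(unique_id=self.ide.unique_id, write_log=self.write_log)


async def main() -> None:
    llm = StubLLM(model="gpt-3.5-turbo", api_key="sk-example")
    await StubSDK().start_model(llm)
    print(llm.unique_id)  # -> vscode-0000


if __name__ == "__main__":
    asyncio.run(main())
```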
diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py
index 98730d38..fdb99d47 100644
--- a/continuedev/src/continuedev/core/abstract_sdk.py
+++ b/continuedev/src/continuedev/core/abstract_sdk.py
@@ -71,10 +71,6 @@ class AbstractContinueSDK(ABC):
     async def delete_directory(self, path: str):
         pass
 
-    @abstractmethod
-    async def get_user_secret(self, env_var: str) -> str:
-        pass
-
     config: ContinueConfig
 
     @abstractmethod
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index bae82739..de0b8c53 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -507,7 +507,7 @@ class Autopilot(ContinueBaseModel):
         if self.session_info is None:
 
             async def create_title():
-                title = await self.continue_sdk.models.medium.complete(
+                title = await self.continue_sdk.models.medium._complete(
                     f'Give a short title to describe the current chat session. Do not put quotes around the title. The first message was: "{user_input}". Do not use more than 10 words. The title is: ',
                     max_tokens=20,
                 )
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 37992b67..9b1c2cd0 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -94,14 +94,7 @@ class ContinueSDK(AbstractContinueSDK):
         self.history.timeline[self.history.current_index].logs.append(message)
 
     async def start_model(self, llm: LLM):
-        kwargs = {}
-        if llm.requires_api_key:
-            kwargs["api_key"] = await self.get_user_secret(llm.requires_api_key)
-        if llm.requires_unique_id:
-            kwargs["unique_id"] = self.ide.unique_id
-        if llm.requires_write_log:
-            kwargs["write_log"] = self.write_log
-        await llm.start(**kwargs)
+        await llm.start(unique_id=self.ide.unique_id, write_log=self.write_log)
 
     async def _ensure_absolute_path(self, path: str) -> str:
         if os.path.isabs(path):
@@ -211,10 +204,6 @@ class ContinueSDK(AbstractContinueSDK):
         path = await self._ensure_absolute_path(path)
         return await self.run_step(FileSystemEditStep(edit=DeleteDirectory(path=path)))
 
-    async def get_user_secret(self, env_var: str) -> str:
-        # TODO support error prompt dynamically set on env_var
-        return await self.ide.getUserSecret(env_var)
-
     _last_valid_config: ContinueConfig = None
 
     def _load_config_dot_py(self) -> ContinueConfig:
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 1e77a691..6a321a41 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -1,19 +1,32 @@
-from abc import ABC, abstractproperty
-from typing import Any, Coroutine, Dict, Generator, List, Optional, Union
+from abc import ABC
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union
 
 from ...core.main import ChatMessage
 from ...models.main import ContinueBaseModel
+from ..util.count_tokens import DEFAULT_ARGS, count_tokens
 
 
 class LLM(ContinueBaseModel, ABC):
-    requires_api_key: Optional[str] = None
-    requires_unique_id: bool = False
-    requires_write_log: bool = False
     title: Optional[str] = None
     system_message: Optional[str] = None
 
+    context_length: int = 2048
+    "The maximum context length of the LLM in tokens, as counted by count_tokens."
+
+    unique_id: Optional[str] = None
+    "The unique ID of the user."
+ + model: str + "The model name" + prompt_templates: dict = {} + write_log: Optional[Callable[[str], None]] = None + "A function that takes a string and writes it to the log." + + api_key: Optional[str] = None + "The API key for the LLM provider." + class Config: arbitrary_types_allowed = True extra = "allow" @@ -21,36 +34,39 @@ class LLM(ContinueBaseModel, ABC): def dict(self, **kwargs): original_dict = super().dict(**kwargs) original_dict.pop("write_log", None) - original_dict["name"] = self.name original_dict["class_name"] = self.__class__.__name__ return original_dict - @abstractproperty - def name(self): - """Return the name of the LLM.""" - raise NotImplementedError + def collect_args(self, **kwargs) -> Any: + """Collect the arguments for the LLM.""" + args = {**DEFAULT_ARGS.copy(), "model": self.model, "max_tokens": 1024} + args.update(kwargs) + return args - async def start(self, *, api_key: Optional[str] = None, **kwargs): + async def start( + self, write_log: Callable[[str], None] = None, unique_id: Optional[str] = None + ): """Start the connection to the LLM.""" - raise NotImplementedError + self.write_log = write_log + self.unique_id = unique_id async def stop(self): """Stop the connection to the LLM.""" - raise NotImplementedError + pass - async def complete( + async def _complete( self, prompt: str, with_history: List[ChatMessage] = None, **kwargs ) -> Coroutine[Any, Any, str]: """Return the completion of the text with the given temperature.""" raise NotImplementedError - def stream_complete( + def _stream_complete( self, prompt, with_history: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: """Stream the completion through generator.""" raise NotImplementedError - async def stream_chat( + async def _stream_chat( self, messages: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: """Stream the chat through generator.""" @@ -58,9 +74,4 @@ class LLM(ContinueBaseModel, ABC): def count_tokens(self, text: str): """Return the number of tokens in the given text.""" - raise NotImplementedError - - @abstractproperty - def context_length(self) -> int: - """Return the context length of the LLM in tokens, as counted by count_tokens.""" - raise NotImplementedError + return count_tokens(self.model, text) diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py index 16bc2fce..b5aff63a 100644 --- a/continuedev/src/continuedev/libs/llm/anthropic.py +++ b/continuedev/src/continuedev/libs/llm/anthropic.py @@ -1,47 +1,36 @@ -from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union +from typing import Any, Coroutine, Dict, Generator, List, Union from anthropic import AI_PROMPT, HUMAN_PROMPT, AsyncAnthropic from ...core.main import ChatMessage from ..llm import LLM -from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens +from ..util.count_tokens import compile_chat_messages class AnthropicLLM(LLM): api_key: str + "Anthropic API key" + model: str = "claude-2" - requires_write_log = True _async_client: AsyncAnthropic = None class Config: arbitrary_types_allowed = True - write_log: Optional[Callable[[str], None]] = None - async def start( self, - *, - api_key: Optional[str] = None, - write_log: Callable[[str], None], **kwargs, ): - self.write_log = write_log + await super().start(**kwargs) self._async_client = AsyncAnthropic(api_key=self.api_key) - async def stop(self): - pass - - @property - def name(self): - return 
self.model + if self.model == "claude-2": + self.context_length = 100_000 - @property - def default_args(self): - return {**DEFAULT_ARGS, "model": self.model} + def collect_args(self, **kwargs) -> Any: + args = super().collect_args(**kwargs) - def _transform_args(self, args: Dict[str, Any]) -> Dict[str, Any]: - args = args.copy() if "max_tokens" in args: args["max_tokens_to_sample"] = args["max_tokens"] del args["max_tokens"] @@ -51,15 +40,6 @@ class AnthropicLLM(LLM): del args["presence_penalty"] return args - def count_tokens(self, text: str): - return count_tokens(self.model, text) - - @property - def context_length(self): - if self.model == "claude-2": - return 100000 - raise Exception(f"Unknown Anthropic model {self.model}") - def __messages_to_prompt(self, messages: List[Dict[str, str]]) -> str: prompt = "" @@ -76,13 +56,11 @@ class AnthropicLLM(LLM): prompt += AI_PROMPT return prompt - async def stream_complete( + async def _stream_complete( self, prompt, with_history: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.default_args.copy() - args.update(kwargs) + args = self.collect_args(**kwargs) args["stream"] = True - args = self._transform_args(args) prompt = f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}" self.write_log(f"Prompt: \n\n{prompt}") @@ -95,13 +73,11 @@ class AnthropicLLM(LLM): self.write_log(f"Completion: \n\n{completion}") - async def stream_chat( + async def _stream_chat( self, messages: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.default_args.copy() - args.update(kwargs) + args = self.collect_args(**kwargs) args["stream"] = True - args = self._transform_args(args) messages = compile_chat_messages( args["model"], @@ -123,11 +99,10 @@ class AnthropicLLM(LLM): self.write_log(f"Completion: \n\n{completion}") - async def complete( + async def _complete( self, prompt: str, with_history: List[ChatMessage] = None, **kwargs ) -> Coroutine[Any, Any, str]: - args = {**self.default_args, **kwargs} - args = self._transform_args(args) + args = self.collect_args(**kwargs) messages = compile_chat_messages( args["model"], diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py index db3aaed7..1668fb65 100644 --- a/continuedev/src/continuedev/libs/llm/ggml.py +++ b/continuedev/src/continuedev/libs/llm/ggml.py @@ -1,62 +1,29 @@ import json -from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union +from typing import Any, Coroutine, Dict, Generator, List, Optional, Union import aiohttp from ...core.main import ChatMessage from ..llm import LLM -from ..util.count_tokens import ( - DEFAULT_ARGS, - compile_chat_messages, - count_tokens, - format_chat_messages, -) +from ..util.count_tokens import compile_chat_messages, format_chat_messages class GGML(LLM): - # this is model-specific - max_context_length: int = 2048 server_url: str = "http://localhost:8000" verify_ssl: Optional[bool] = None - - requires_write_log = True - - write_log: Optional[Callable[[str], None]] = None + model: str = "ggml" class Config: arbitrary_types_allowed = True - async def start(self, write_log: Callable[[str], None], **kwargs): - self.write_log = write_log - - async def stop(self): - pass - - @property - def name(self): - return "ggml" - - @property - def context_length(self): - return self.max_context_length - - @property - def default_args(self): - return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024} - - def count_tokens(self, 
text: str): - return count_tokens(self.name, text) - - async def stream_complete( + async def _stream_complete( self, prompt, with_history: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.default_args.copy() - args.update(kwargs) + args = self.collect_args(**kwargs) args["stream"] = True - args = {**self.default_args, **kwargs} messages = compile_chat_messages( - self.name, + self.model, with_history, self.context_length, args["max_tokens"], @@ -84,12 +51,12 @@ class GGML(LLM): self.write_log(f"Completion: \n\n{completion}") - async def stream_chat( + async def _stream_chat( self, messages: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) messages = compile_chat_messages( - self.name, + self.model, messages, self.context_length, args["max_tokens"], @@ -142,10 +109,10 @@ class GGML(LLM): self.write_log(f"Completion: \n\n{completion}") - async def complete( + async def _complete( self, prompt: str, with_history: List[ChatMessage] = None, **kwargs ) -> Coroutine[Any, Any, str]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) self.write_log(f"Prompt: \n\n{prompt}") async with aiohttp.ClientSession( diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py index 4b78a247..3a586a43 100644 --- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py +++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py @@ -5,17 +5,14 @@ import requests from ...core.main import ChatMessage from ..llm import LLM -from ..util.count_tokens import DEFAULT_ARGS, count_tokens DEFAULT_MAX_TIME = 120.0 class HuggingFaceInferenceAPI(LLM): - model: str hf_token: str self_hosted_url: str = None - max_context_length: int = 2048 verify_ssl: Optional[bool] = None _client_session: aiohttp.ClientSession = None @@ -24,6 +21,7 @@ class HuggingFaceInferenceAPI(LLM): arbitrary_types_allowed = True async def start(self, **kwargs): + await super().start(**kwargs) self._client_session = aiohttp.ClientSession( connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl) ) @@ -31,22 +29,7 @@ class HuggingFaceInferenceAPI(LLM): async def stop(self): await self._client_session.close() - @property - def name(self): - return self.model - - @property - def context_length(self): - return self.max_context_length - - @property - def default_args(self): - return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024} - - def count_tokens(self, text: str): - return count_tokens(self.name, text) - - async def complete( + async def _complete( self, prompt: str, with_history: List[ChatMessage] = None, **kwargs ): """Return the completion of the text with the given temperature.""" @@ -77,14 +60,14 @@ class HuggingFaceInferenceAPI(LLM): return data[0]["generated_text"] - async def stream_chat( + async def _stream_chat( self, messages: List[ChatMessage] = None, **kwargs ) -> Coroutine[Any, Any, Generator[Any | List | Dict, None, None]]: - response = await self.complete(messages[-1].content, messages[:-1]) + response = await self._complete(messages[-1].content, messages[:-1]) yield {"content": response, "role": "assistant"} - async def stream_complete( + async def _stream_complete( self, prompt, with_history: List[ChatMessage] = None, **kwargs ) -> Generator[Any | List | Dict, None, None]: - response = await self.complete(prompt, with_history) + response = await 
self._complete(prompt, with_history) yield response diff --git a/continuedev/src/continuedev/libs/llm/hf_tgi.py b/continuedev/src/continuedev/libs/llm/hf_tgi.py index f04e700d..f106f83f 100644 --- a/continuedev/src/continuedev/libs/llm/hf_tgi.py +++ b/continuedev/src/continuedev/libs/llm/hf_tgi.py @@ -5,44 +5,22 @@ import aiohttp from ...core.main import ChatMessage from ..llm import LLM -from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens +from ..util.count_tokens import compile_chat_messages from .prompts.chat import code_llama_template_messages class HuggingFaceTGI(LLM): model: str = "huggingface-tgi" - max_context_length: int = 2048 server_url: str = "http://localhost:8080" verify_ssl: Optional[bool] = None template_messages: Callable[[List[ChatMessage]], str] = code_llama_template_messages - requires_write_log = True - - write_log: Optional[Callable[[str], None]] = None - class Config: arbitrary_types_allowed = True - async def start(self, write_log: Callable[[str], None], **kwargs): - self.write_log = write_log - - async def stop(self): - pass - - @property - def name(self): - return self.model - - @property - def context_length(self): - return self.max_context_length - - @property - def default_args(self): - return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024} - - def _transform_args(self, args): + def collect_args(self, **kwargs) -> Any: + args = super().collect_args(**kwargs) args = { **args, "max_new_tokens": args.get("max_tokens", 1024), @@ -50,19 +28,14 @@ class HuggingFaceTGI(LLM): args.pop("max_tokens", None) return args - def count_tokens(self, text: str): - return count_tokens(self.name, text) - - async def stream_complete( + async def _stream_complete( self, prompt, with_history: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.default_args.copy() - args.update(kwargs) + args = self.collect_args(**kwargs) args["stream"] = True - args = {**self.default_args, **kwargs} messages = compile_chat_messages( - self.name, + self.model, with_history, self.context_length, args["max_tokens"], @@ -93,12 +66,12 @@ class HuggingFaceTGI(LLM): self.write_log(f"Completion: \n\n{completion}") - async def stream_chat( + async def _stream_chat( self, messages: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) messages = compile_chat_messages( - self.name, + self.model, messages, self.context_length, args["max_tokens"], @@ -107,7 +80,7 @@ class HuggingFaceTGI(LLM): system_message=self.system_message, ) - async for chunk in self.stream_complete( + async for chunk in self._stream_complete( None, self.template_messages(messages), **args ): yield { @@ -115,13 +88,13 @@ class HuggingFaceTGI(LLM): "content": chunk, } - async def complete( + async def _complete( self, prompt: str, with_history: List[ChatMessage] = None, **kwargs ) -> Coroutine[Any, Any, str]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) completion = "" - async for chunk in self.stream_complete(prompt, with_history, **args): + async for chunk in self._stream_complete(prompt, with_history, **args): completion += chunk return completion diff --git a/continuedev/src/continuedev/libs/llm/llamacpp.py b/continuedev/src/continuedev/libs/llm/llamacpp.py index e6f38cd0..7940c4c9 100644 --- a/continuedev/src/continuedev/libs/llm/llamacpp.py +++ b/continuedev/src/continuedev/libs/llm/llamacpp.py @@ -6,12 +6,12 
@@ import aiohttp from ...core.main import ChatMessage from ..llm import LLM -from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens +from ..util.count_tokens import compile_chat_messages from .prompts.chat import llama2_template_messages class LlamaCpp(LLM): - max_context_length: int = 2048 + model: str = "llamacpp" server_url: str = "http://localhost:8080" verify_ssl: Optional[bool] = None @@ -20,9 +20,6 @@ class LlamaCpp(LLM): use_command: Optional[str] = None - requires_write_log = True - write_log: Optional[Callable[[str], None]] = None - class Config: arbitrary_types_allowed = True @@ -31,29 +28,8 @@ class LlamaCpp(LLM): d.pop("template_messages") return d - async def start(self, write_log: Callable[[str], None], **kwargs): - self.write_log = write_log - - async def stop(self): - await self._client_session.close() - - @property - def name(self): - return "llamacpp" - - @property - def context_length(self): - return self.max_context_length - - @property - def default_args(self): - return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024} - - def count_tokens(self, text: str): - return count_tokens(self.name, text) - - def _transform_args(self, args: Dict[str, Any]) -> Dict[str, Any]: - args = args.copy() + def collect_args(self, **kwargs) -> Any: + args = super().collect_args(**kwargs) if "max_tokens" in args: args["n_predict"] = args["max_tokens"] del args["max_tokens"] @@ -85,16 +61,14 @@ class LlamaCpp(LLM): await process.wait() - async def stream_complete( + async def _stream_complete( self, prompt, with_history: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.default_args.copy() - args.update(kwargs) + args = self.collect_args(**kwargs) args["stream"] = True - args = {**self.default_args, **kwargs} messages = compile_chat_messages( - self.name, + self.model, with_history, self.context_length, args["max_tokens"], @@ -125,12 +99,12 @@ class LlamaCpp(LLM): self.write_log(f"Completion: \n\n{completion}") - async def stream_chat( + async def _stream_chat( self, messages: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) messages = compile_chat_messages( - self.name, + self.model, messages, self.context_length, args["max_tokens"], @@ -177,10 +151,10 @@ class LlamaCpp(LLM): self.write_log(f"Completion: \n\n{completion}") - async def complete( + async def _complete( self, prompt: str, with_history: List[ChatMessage] = None, **kwargs ) -> Coroutine[Any, Any, str]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) self.write_log(f"Prompt: \n\n{prompt}") diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py index daffe41f..99b7c47f 100644 --- a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py +++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union +from typing import Any, Coroutine, Dict, Generator, List, Optional, Union from ...core.main import ChatMessage from . 
import LLM @@ -10,63 +10,42 @@ class MaybeProxyOpenAI(LLM): model: str api_key: Optional[str] = None - requires_write_log: bool = True - requires_unique_id: bool = True - system_message: Union[str, None] = None - llm: Optional[LLM] = None def update_llm_properties(self): if self.llm is not None: self.llm.system_message = self.system_message - @property - def name(self): - if self.llm is not None: - return self.llm.name - else: - return None - - @property - def context_length(self): - return self.llm.context_length - - async def start( - self, - *, - api_key: Optional[str] = None, - unique_id: str, - write_log: Callable[[str], None] - ): + async def start(self, **kwargs): if self.api_key is None or self.api_key.strip() == "": self.llm = ProxyServer(model=self.model) else: self.llm = OpenAI(api_key=self.api_key, model=self.model) - await self.llm.start(write_log=write_log, unique_id=unique_id) + await self.llm.start(**kwargs) async def stop(self): await self.llm.stop() - async def complete( + async def _complete( self, prompt: str, with_history: List[ChatMessage] = None, **kwargs ) -> Coroutine[Any, Any, str]: self.update_llm_properties() - return await self.llm.complete(prompt, with_history=with_history, **kwargs) + return await self.llm._complete(prompt, with_history=with_history, **kwargs) - async def stream_complete( + async def _stream_complete( self, prompt, with_history: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: self.update_llm_properties() - resp = self.llm.stream_complete(prompt, with_history=with_history, **kwargs) + resp = self.llm._stream_complete(prompt, with_history=with_history, **kwargs) async for item in resp: yield item - async def stream_chat( + async def _stream_chat( self, messages: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: self.update_llm_properties() - resp = self.llm.stream_chat(messages=messages, **kwargs) + resp = self.llm._stream_chat(messages=messages, **kwargs) async for item in resp: yield item diff --git a/continuedev/src/continuedev/libs/llm/ollama.py b/continuedev/src/continuedev/libs/llm/ollama.py index 03300435..ef8ed47b 100644 --- a/continuedev/src/continuedev/libs/llm/ollama.py +++ b/continuedev/src/continuedev/libs/llm/ollama.py @@ -7,17 +7,15 @@ import aiohttp from ...core.main import ChatMessage from ..llm import LLM -from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens +from ..util.count_tokens import compile_chat_messages from .prompts.chat import llama2_template_messages class Ollama(LLM): model: str = "llama2" server_url: str = "http://localhost:11434" - max_context_length: int = 2048 _client_session: aiohttp.ClientSession = None - requires_write_log = True prompt_templates = { "edit": dedent( @@ -36,34 +34,19 @@ class Ollama(LLM): class Config: arbitrary_types_allowed = True - async def start(self, write_log, **kwargs): + async def start(self, **kwargs): + await super().start(**kwargs) self._client_session = aiohttp.ClientSession() - self.write_log = write_log async def stop(self): await self._client_session.close() - @property - def name(self): - return self.model - - @property - def context_length(self) -> int: - return self.max_context_length - - @property - def default_args(self): - return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024} - - def count_tokens(self, text: str): - return count_tokens(self.name, text) - - async def stream_complete( + async def _stream_complete( self, prompt, with_history: List[ChatMessage] 
= None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) messages = compile_chat_messages( - self.name, + self.model, with_history, self.context_length, args["max_tokens"], @@ -102,12 +85,12 @@ class Ollama(LLM): yield urllib.parse.unquote(url_decode_buffer) url_decode_buffer = "" - async def stream_chat( + async def _stream_chat( self, messages: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) messages = compile_chat_messages( - self.name, + self.model, messages, self.context_length, args["max_tokens"], @@ -143,11 +126,11 @@ class Ollama(LLM): completion += j["response"] self.write_log(f"Completion:\n{completion}") - async def complete( + async def _complete( self, prompt: str, with_history: List[ChatMessage] = None, **kwargs ) -> Coroutine[Any, Any, str]: completion = "" - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) async with self._client_session.post( f"{self.server_url}/api/generate", json={ diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index c5d19ed2..a017af22 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -12,26 +12,15 @@ from typing import ( import certifi import openai -from pydantic import BaseModel from ...core.main import ChatMessage from ..llm import LLM from ..util.count_tokens import ( - DEFAULT_ARGS, compile_chat_messages, - count_tokens, format_chat_messages, prune_raw_prompt_from_top, ) - -class OpenAIServerInfo(BaseModel): - api_base: Optional[str] = None - engine: Optional[str] = None - api_version: Optional[str] = None - api_type: Literal["azure", "openai"] = "openai" - - CHAT_MODELS = {"gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"} MAX_TOKENS_FOR_MODEL = { "gpt-3.5-turbo": 4096, @@ -47,32 +36,43 @@ MAX_TOKENS_FOR_MODEL = { class OpenAI(LLM): api_key: str - model: str - openai_server_info: Optional[OpenAIServerInfo] = None + "OpenAI API key" + verify_ssl: Optional[bool] = None + "Whether to verify SSL certificates for requests." + ca_bundle_path: Optional[str] = None + "Path to CA bundle to use for requests." + proxy: Optional[str] = None + "Proxy URL to use for requests." - requires_write_log = True + api_base: Optional[str] = None + "OpenAI API base URL." - write_log: Optional[Callable[[str], None]] = None + api_type: Optional[Literal["azure", "openai"]] = None + "OpenAI API type." + + api_version: Optional[str] = None + "OpenAI API version. For use with Azure OpenAI Service." + + engine: Optional[str] = None + "OpenAI engine. For use with Azure OpenAI Service." 
async def start( - self, - *, - api_key: Optional[str] = None, - write_log: Callable[[str], None], - **kwargs, + self, unique_id: Optional[str] = None, write_log: Callable[[str], None] = None ): - self.write_log = write_log - openai.api_key = self.api_key + await super().start(write_log=write_log, unique_id=unique_id) + + self.context_length = MAX_TOKENS_FOR_MODEL.get(self.model, 4096) - if self.openai_server_info is not None: - openai.api_type = self.openai_server_info.api_type - if self.openai_server_info.api_base is not None: - openai.api_base = self.openai_server_info.api_base - if self.openai_server_info.api_version is not None: - openai.api_version = self.openai_server_info.api_version + openai.api_key = self.api_key + if self.api_type is not None: + openai.api_type = self.api_type + if self.api_base is not None: + openai.api_base = self.api_base + if self.api_version is not None: + openai.api_version = self.api_version if self.verify_ssl is not None and self.verify_ssl is False: openai.verify_ssl_certs = False @@ -82,32 +82,16 @@ class OpenAI(LLM): openai.ca_bundle_path = self.ca_bundle_path or certifi.where() - async def stop(self): - pass - - @property - def name(self): - return self.model - - @property - def context_length(self): - return MAX_TOKENS_FOR_MODEL.get(self.model, 4096) - - @property - def default_args(self): - args = {**DEFAULT_ARGS, "model": self.model} - if self.openai_server_info is not None: - args["engine"] = self.openai_server_info.engine + def collect_args(self, **kwargs): + args = super().collect_args() + if self.engine is not None: + args["engine"] = self.engine return args - def count_tokens(self, text: str): - return count_tokens(self.model, text) - - async def stream_complete( + async def _stream_complete( self, prompt, with_history: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.default_args.copy() - args.update(kwargs) + args = self.collect_args(**kwargs) args["stream"] = True if args["model"] in CHAT_MODELS: @@ -142,11 +126,10 @@ class OpenAI(LLM): self.write_log(f"Completion:\n\n{completion}") - async def stream_chat( + async def _stream_chat( self, messages: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.default_args.copy() - args.update(kwargs) + args = self.collect_args(**kwargs) args["stream"] = True if not args["model"].endswith("0613") and "functions" in args: @@ -174,10 +157,10 @@ class OpenAI(LLM): completion += chunk.choices[0].delta.content self.write_log(f"Completion: \n\n{completion}") - async def complete( + async def _complete( self, prompt: str, with_history: List[ChatMessage] = None, **kwargs ) -> Coroutine[Any, Any, str]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) if args["model"] in CHAT_MODELS: messages = compile_chat_messages( diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py index fa77a22a..3ac6371f 100644 --- a/continuedev/src/continuedev/libs/llm/proxy_server.py +++ b/continuedev/src/continuedev/libs/llm/proxy_server.py @@ -1,19 +1,14 @@ import json import ssl import traceback -from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union +from typing import Any, Coroutine, Dict, Generator, List, Union import aiohttp import certifi from ...core.main import ChatMessage from ..llm import LLM -from ..util.count_tokens import ( - DEFAULT_ARGS, - compile_chat_messages, - count_tokens, - 
format_chat_messages, -) +from ..util.count_tokens import compile_chat_messages, format_chat_messages from ..util.telemetry import posthog_logger ca_bundle_path = certifi.where() @@ -31,59 +26,32 @@ MAX_TOKENS_FOR_MODEL = { class ProxyServer(LLM): - model: str - system_message: Optional[str] - - unique_id: str = None - write_log: Callable[[str], None] = None _client_session: aiohttp.ClientSession - requires_unique_id = True - requires_write_log = True - class Config: arbitrary_types_allowed = True async def start( self, - *, - api_key: Optional[str] = None, - write_log: Callable[[str], None], - unique_id: str, **kwargs, ): + await super().start(**kwargs) self._client_session = aiohttp.ClientSession( connector=aiohttp.TCPConnector(ssl_context=ssl_context) ) - self.write_log = write_log - self.unique_id = unique_id + self.context_length = MAX_TOKENS_FOR_MODEL[self.model] async def stop(self): await self._client_session.close() - @property - def name(self): - return self.model - - @property - def context_length(self): - return MAX_TOKENS_FOR_MODEL[self.model] - - @property - def default_args(self): - return {**DEFAULT_ARGS, "model": self.model} - - def count_tokens(self, text: str): - return count_tokens(self.model, text) - def get_headers(self): # headers with unique id return {"unique_id": self.unique_id} - async def complete( + async def _complete( self, prompt: str, with_history: List[ChatMessage] = None, **kwargs ) -> Coroutine[Any, Any, str]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) messages = compile_chat_messages( args["model"], @@ -107,10 +75,10 @@ class ProxyServer(LLM): self.write_log(f"Completion: \n\n{response_text}") return response_text - async def stream_chat( + async def _stream_chat( self, messages: List[ChatMessage] = None, **kwargs ) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) messages = compile_chat_messages( args["model"], messages, @@ -158,10 +126,10 @@ class ProxyServer(LLM): self.write_log(f"Completion: \n\n{completion}") - async def stream_complete( + async def _stream_complete( self, prompt, with_history: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) messages = compile_chat_messages( self.model, with_history, diff --git a/continuedev/src/continuedev/libs/llm/replicate.py b/continuedev/src/continuedev/libs/llm/replicate.py index 0424d827..fb0d3f5c 100644 --- a/continuedev/src/continuedev/libs/llm/replicate.py +++ b/continuedev/src/continuedev/libs/llm/replicate.py @@ -4,39 +4,22 @@ from typing import List import replicate from ...core.main import ChatMessage -from ..util.count_tokens import DEFAULT_ARGS, count_tokens from . 
import LLM class ReplicateLLM(LLM): api_key: str + "Replicate API key" + model: str = "replicate/llama-2-70b-chat:58d078176e02c219e11eb4da5a02a7830a283b14cf8f94537af893ccff5ee781" - max_context_length: int = 2048 _client: replicate.Client = None - @property - def name(self): - return self.model - - @property - def context_length(self): - return self.max_context_length - - @property - def default_args(self): - return {**DEFAULT_ARGS, "model": self.model, "max_tokens": 1024} - - def count_tokens(self, text: str): - return count_tokens(self.name, text) - - async def start(self): + async def start(self, **kwargs): + await super().start(**kwargs) self._client = replicate.Client(api_token=self.api_key) - async def stop(self): - pass - - async def complete( + async def _complete( self, prompt: str, with_history: List[ChatMessage] = None, **kwargs ): def helper(): @@ -55,7 +38,7 @@ class ReplicateLLM(LLM): return completion - async def stream_complete( + async def _stream_complete( self, prompt, with_history: List[ChatMessage] = None, **kwargs ): for item in self._client.run( @@ -63,7 +46,7 @@ class ReplicateLLM(LLM): ): yield item - async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs): + async def _stream_chat(self, messages: List[ChatMessage] = None, **kwargs): for item in self._client.run( self.model, input={"message": messages[-1].content, "prompt": messages[-1].content}, diff --git a/continuedev/src/continuedev/libs/llm/text_gen_interface.py b/continuedev/src/continuedev/libs/llm/text_gen_interface.py index 380f7b48..59627629 100644 --- a/continuedev/src/continuedev/libs/llm/text_gen_interface.py +++ b/continuedev/src/continuedev/libs/llm/text_gen_interface.py @@ -1,51 +1,23 @@ import json -from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union +from typing import Any, Coroutine, Dict, Generator, List, Optional, Union import websockets from ...core.main import ChatMessage -from ..util.count_tokens import ( - DEFAULT_ARGS, - compile_chat_messages, - count_tokens, - format_chat_messages, -) +from ..util.count_tokens import compile_chat_messages, format_chat_messages from . 
import LLM class TextGenUI(LLM): # this is model-specific model: str = "text-gen-ui" - max_context_length: int = 2048 server_url: str = "http://localhost:5000" streaming_url: str = "http://localhost:5005" verify_ssl: Optional[bool] = None - requires_write_log = True - - write_log: Optional[Callable[[str], None]] = None - class Config: arbitrary_types_allowed = True - async def start(self, write_log: Callable[[str], None], **kwargs): - self.write_log = write_log - - async def stop(self): - pass - - @property - def name(self): - return self.model - - @property - def context_length(self): - return self.max_context_length - - @property - def default_args(self): - return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024} - def _transform_args(self, args): args = { **args, @@ -54,18 +26,12 @@ class TextGenUI(LLM): args.pop("max_tokens", None) return args - def count_tokens(self, text: str): - return count_tokens(self.name, text) - - async def stream_complete( + async def _stream_complete( self, prompt, with_history: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.default_args.copy() - args.update(kwargs) + args = self.collect_args(**kwargs) args["stream"] = True - args = {**self.default_args, **kwargs} - self.write_log(f"Prompt: \n\n{prompt}") completion = "" @@ -89,12 +55,12 @@ class TextGenUI(LLM): self.write_log(f"Completion: \n\n{completion}") - async def stream_chat( + async def _stream_chat( self, messages: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) messages = compile_chat_messages( - self.name, + self.model, messages, self.context_length, args["max_tokens"], @@ -146,10 +112,10 @@ class TextGenUI(LLM): self.write_log(f"Completion: \n\n{completion}") - async def complete( + async def _complete( self, prompt: str, with_history: List[ChatMessage] = None, **kwargs ) -> Coroutine[Any, Any, str]: - generator = self.stream_chat( + generator = self._stream_chat( [ChatMessage(role="user", content=prompt, summary=prompt)], **kwargs ) diff --git a/continuedev/src/continuedev/libs/llm/together.py b/continuedev/src/continuedev/libs/llm/together.py index ddae91a9..d8c7334b 100644 --- a/continuedev/src/continuedev/libs/llm/together.py +++ b/continuedev/src/continuedev/libs/llm/together.py @@ -5,21 +5,23 @@ import aiohttp from ...core.main import ChatMessage from ..llm import LLM -from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens +from ..util.count_tokens import compile_chat_messages from .prompts.chat import llama2_template_messages class TogetherLLM(LLM): # this is model-specific api_key: str + "Together API key" + model: str = "togethercomputer/RedPajama-INCITE-7B-Instruct" - max_context_length: int = 2048 base_url: str = "https://api.together.xyz" verify_ssl: Optional[bool] = None _client_session: aiohttp.ClientSession = None async def start(self, **kwargs): + await super().start(**kwargs) self._client_session = aiohttp.ClientSession( connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl) ) @@ -27,31 +29,14 @@ class TogetherLLM(LLM): async def stop(self): await self._client_session.close() - @property - def name(self): - return self.model - - @property - def context_length(self): - return self.max_context_length - - @property - def default_args(self): - return {**DEFAULT_ARGS, "model": self.model, "max_tokens": 1024} - - def count_tokens(self, text: str): - return count_tokens(self.name, 
text) - - async def stream_complete( + async def _stream_complete( self, prompt, with_history: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.default_args.copy() - args.update(kwargs) + args = self.collect_args(**kwargs) args["stream_tokens"] = True - args = {**self.default_args, **kwargs} messages = compile_chat_messages( - self.name, + self.model, with_history, self.context_length, args["max_tokens"], @@ -72,12 +57,12 @@ class TogetherLLM(LLM): except: raise Exception(str(line)) - async def stream_chat( + async def _stream_chat( self, messages: List[ChatMessage] = None, **kwargs ) -> Generator[Union[Any, List, Dict], None, None]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) messages = compile_chat_messages( - self.name, + self.model, messages, self.context_length, args["max_tokens"], @@ -112,10 +97,10 @@ class TogetherLLM(LLM): "content": json_chunk["choices"][0]["text"], } - async def complete( + async def _complete( self, prompt: str, with_history: List[ChatMessage] = None, **kwargs ) -> Coroutine[Any, Any, str]: - args = {**self.default_args, **kwargs} + args = self.collect_args(**kwargs) messages = compile_chat_messages( args["model"], diff --git a/continuedev/src/continuedev/libs/util/edit_config.py b/continuedev/src/continuedev/libs/util/edit_config.py index eed43054..45a4a599 100644 --- a/continuedev/src/continuedev/libs/util/edit_config.py +++ b/continuedev/src/continuedev/libs/util/edit_config.py @@ -74,9 +74,6 @@ def add_config_import(line: str): filtered_attrs = { - "requires_api_key", - "requires_unique_id", - "requires_write_log", "class_name", "name", "llm", diff --git a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py index 43a2b800..fe049268 100644 --- a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py +++ b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py @@ -30,7 +30,7 @@ class SetupPipelineStep(Step): sdk.context.set("api_description", self.api_description) source_name = ( - await sdk.models.medium.complete( + await sdk.models.medium._complete( f"Write a snake_case name for the data source described by {self.api_description}: " ) ).strip() @@ -115,7 +115,7 @@ class ValidatePipelineStep(Step): if "Traceback" in output or "SyntaxError" in output: output = "Traceback" + output.split("Traceback")[-1] file_content = await sdk.ide.readFile(os.path.join(workspace_dir, filename)) - suggestion = await sdk.models.medium.complete( + suggestion = await sdk.models.medium._complete( dedent( f"""\ ```python @@ -131,7 +131,7 @@ class ValidatePipelineStep(Step): ) ) - api_documentation_url = await sdk.models.medium.complete( + api_documentation_url = await sdk.models.medium._complete( dedent( f"""\ The API I am trying to call is the '{sdk.context.get('api_description')}'. 
I tried calling it in the @resource function like this: @@ -216,7 +216,7 @@ class RunQueryStep(Step): ) if "Traceback" in output or "SyntaxError" in output: - suggestion = await sdk.models.medium.complete( + suggestion = await sdk.models.medium._complete( dedent( f"""\ ```python diff --git a/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py index d6769148..44065d22 100644 --- a/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py +++ b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py @@ -92,7 +92,7 @@ class LoadDataStep(Step): docs = f.read() output = "Traceback" + output.split("Traceback")[-1] - suggestion = await sdk.models.default.complete( + suggestion = await sdk.models.default._complete( dedent( f"""\ When trying to load data into BigQuery, the following error occurred: diff --git a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py index e2712746..4727c994 100644 --- a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py +++ b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py @@ -45,7 +45,7 @@ class WritePytestsRecipe(Step): Here is a complete set of pytest unit tests:""" ) - tests = await sdk.models.medium.complete(prompt) + tests = await sdk.models.medium._complete(prompt) await sdk.apply_filesystem_edit(AddFile(filepath=path, content=tests)) diff --git a/continuedev/src/continuedev/plugins/steps/chat.py b/continuedev/src/continuedev/plugins/steps/chat.py index 857183bc..d580f886 100644 --- a/continuedev/src/continuedev/plugins/steps/chat.py +++ b/continuedev/src/continuedev/plugins/steps/chat.py @@ -83,11 +83,17 @@ class SimpleChatStep(Step): messages = self.messages or await sdk.get_chat_context() - generator = sdk.models.chat.stream_chat( + generator = sdk.models.chat._stream_chat( messages, temperature=sdk.config.temperature ) - posthog_logger.capture_event("model_use", {"model": sdk.models.default.name}) + posthog_logger.capture_event( + "model_use", + { + "model": sdk.models.default.model, + "provider": sdk.models.default.__class__.__name__, + }, + ) async for chunk in generator: if sdk.current_step_was_deleted(): @@ -112,7 +118,7 @@ class SimpleChatStep(Step): await sdk.update_ui() self.name = add_ellipsis( remove_quotes_and_escapes( - await sdk.models.medium.complete( + await sdk.models.medium._complete( f'"{self.description}"\n\nPlease write a short title summarizing the message quoted above. 
Use no more than 10 words:', max_tokens=20, ) @@ -254,7 +260,7 @@ class ChatWithFunctions(Step): gpt350613 = OpenAI(model="gpt-3.5-turbo-0613") await sdk.start_model(gpt350613) - async for msg_chunk in gpt350613.stream_chat( + async for msg_chunk in gpt350613._stream_chat( await sdk.get_chat_context(), functions=functions ): if sdk.current_step_was_deleted(): diff --git a/continuedev/src/continuedev/plugins/steps/chroma.py b/continuedev/src/continuedev/plugins/steps/chroma.py index 25633942..9ee2a48d 100644 --- a/continuedev/src/continuedev/plugins/steps/chroma.py +++ b/continuedev/src/continuedev/plugins/steps/chroma.py @@ -58,7 +58,7 @@ class AnswerQuestionChroma(Step): Here is the answer:""" ) - answer = await sdk.models.medium.complete(prompt) + answer = await sdk.models.medium._complete(prompt) # Make paths relative to the workspace directory answer = answer.replace(await sdk.ide.getWorkspaceDirectory(), "") diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py index 2c7416aa..9d40822b 100644 --- a/continuedev/src/continuedev/plugins/steps/core/core.py +++ b/continuedev/src/continuedev/plugins/steps/core/core.py @@ -275,7 +275,7 @@ class DefaultModelEditCodeStep(Step): ) # If using 3.5 and overflows, upgrade to 3.5.16k - if model_to_use.name == "gpt-3.5-turbo": + if model_to_use.model == "gpt-3.5-turbo": if total_tokens > model_to_use.context_length: model_to_use = MaybeProxyOpenAI(model="gpt-3.5-turbo-0613") await sdk.start_model(model_to_use) @@ -663,11 +663,14 @@ Please output the code to be inserted at the cursor in order to fulfill the user else: messages = rendered - generator = model_to_use.stream_chat( + generator = model_to_use._stream_chat( messages, temperature=sdk.config.temperature, max_tokens=max_tokens ) - posthog_logger.capture_event("model_use", {"model": model_to_use.name}) + posthog_logger.capture_event( + "model_use", + {"model": model_to_use.model, "provider": model_to_use.__class__.__name__}, + ) try: async for chunk in generator: diff --git a/continuedev/src/continuedev/plugins/steps/help.py b/continuedev/src/continuedev/plugins/steps/help.py index 148dddb8..c73d7eef 100644 --- a/continuedev/src/continuedev/plugins/steps/help.py +++ b/continuedev/src/continuedev/plugins/steps/help.py @@ -59,7 +59,7 @@ class HelpStep(Step): ChatMessage(role="user", content=prompt, summary="Help") ) messages = await sdk.get_chat_context() - generator = sdk.models.default.stream_chat(messages) + generator = sdk.models.default._stream_chat(messages) async for chunk in generator: if "content" in chunk: self.description += chunk["content"] diff --git a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py index 721f1306..001876d0 100644 --- a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py +++ b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py @@ -26,7 +26,7 @@ class NLMultiselectStep(Step): if first_try is not None: return first_try - gpt_parsed = await sdk.models.default.complete( + gpt_parsed = await sdk.models.default._complete( f"These are the available options are: [{', '.join(self.options)}]. The user requested {user_response}. 
This is the exact string from the options array that they selected:" ) return extract_option(gpt_parsed) or self.options[0] diff --git a/continuedev/src/continuedev/plugins/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py index ca15aaab..7762666c 100644 --- a/continuedev/src/continuedev/plugins/steps/main.py +++ b/continuedev/src/continuedev/plugins/steps/main.py @@ -105,7 +105,7 @@ class FasterEditHighlightedCodeStep(Step): for rif in range_in_files: rif_dict[rif.filepath] = rif.contents - completion = await sdk.models.medium.complete(prompt) + completion = await sdk.models.medium._complete(prompt) # Temporarily doing this to generate description. self._prompt = prompt @@ -180,7 +180,7 @@ class StarCoderEditHighlightedCodeStep(Step): _prompt_and_completion: str = "" async def describe(self, models: Models) -> Coroutine[str, None, None]: - return await models.medium.complete( + return await models.medium._complete( f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:" ) diff --git a/continuedev/src/continuedev/plugins/steps/react.py b/continuedev/src/continuedev/plugins/steps/react.py index a2612731..2ed2d3d7 100644 --- a/continuedev/src/continuedev/plugins/steps/react.py +++ b/continuedev/src/continuedev/plugins/steps/react.py @@ -29,7 +29,7 @@ class NLDecisionStep(Step): Select the step which should be taken next to satisfy the user input. Say only the name of the selected step. You must choose one:""" ) - resp = (await sdk.models.medium.complete(prompt)).lower() + resp = (await sdk.models.medium._complete(prompt)).lower() step_to_run = None for step in self.steps: diff --git a/continuedev/src/continuedev/plugins/steps/search_directory.py b/continuedev/src/continuedev/plugins/steps/search_directory.py index 04fb98b7..9317bfe1 100644 --- a/continuedev/src/continuedev/plugins/steps/search_directory.py +++ b/continuedev/src/continuedev/plugins/steps/search_directory.py @@ -46,7 +46,7 @@ class WriteRegexPatternStep(Step): async def run(self, sdk: ContinueSDK): # Ask the user for a regex pattern - pattern = await sdk.models.medium.complete( + pattern = await sdk.models.medium._complete( dedent( f"""\ This is the user request: |