Diffstat (limited to 'continuedev')
20 files changed, 68 insertions, 44 deletions
diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py
index e4cb8ed6..900762b6 100644
--- a/continuedev/src/continuedev/core/models.py
+++ b/continuedev/src/continuedev/core/models.py
@@ -34,33 +34,23 @@ class Models(BaseModel):
         '''depending on the model, return the single prompt string'''
         """

-    async def _start_llm(self, llm: LLM):
-        kwargs = {}
-        if llm.requires_api_key:
-            kwargs["api_key"] = await self.sdk.get_api_key(llm.requires_api_key)
-        if llm.requires_unique_id:
-            kwargs["unique_id"] = self.sdk.ide.unique_id
-        if llm.requires_write_log:
-            kwargs["write_log"] = self.sdk.write_log
-        await llm.start(**kwargs)
-
     async def start(self, sdk: "ContinueSDK"):
         """Start each of the LLMs, or fall back to default"""
         self.sdk = sdk
         self.system_message = self.sdk.config.system_message
-        await self._start_llm(self.default)
+        await sdk.start_model(self.default)

         if self.small:
-            await self._start_llm(self.small)
+            await sdk.start_model(self.small)
         else:
             self.small = self.default

         if self.medium:
-            await self._start_llm(self.medium)
+            await sdk.start_model(self.medium)
         else:
             self.medium = self.default

         if self.large:
-            await self._start_llm(self.large)
+            await sdk.start_model(self.large)
         else:
             self.large = self.default
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index b0f7d40a..7febb932 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -15,13 +15,13 @@ from .main import Context, ContinueCustomException, History, HistoryNode, Step,
 from ..plugins.steps.core.core import *
 from ..libs.util.telemetry import posthog_logger
 from ..libs.util.paths import getConfigFilePath
+from .models import Models


 class Autopilot:
     pass

-
 class ContinueSDK(AbstractContinueSDK):
     """The SDK provided as parameters to a step"""
     ide: AbstractIdeProtocolServer
@@ -66,6 +66,16 @@ class ContinueSDK(AbstractContinueSDK):
     def write_log(self, message: str):
         self.history.timeline[self.history.current_index].logs.append(message)

+    async def start_model(self, llm: LLM):
+        kwargs = {}
+        if llm.requires_api_key:
+            kwargs["api_key"] = await self.get_api_key(llm.requires_api_key)
+        if llm.requires_unique_id:
+            kwargs["unique_id"] = self.ide.unique_id
+        if llm.requires_write_log:
+            kwargs["write_log"] = self.write_log
+        await llm.start(**kwargs)
+
     async def _ensure_absolute_path(self, path: str) -> str:
         if os.path.isabs(path):
             return path
diff --git a/continuedev/src/continuedev/libs/constants/default_config.py.txt b/continuedev/src/continuedev/libs/constants/default_config.py.txt
index 5708747f..7cd2226a 100644
--- a/continuedev/src/continuedev/libs/constants/default_config.py.txt
+++ b/continuedev/src/continuedev/libs/constants/default_config.py.txt
@@ -31,7 +31,10 @@ class CommitMessageStep(Step):

         # Ask gpt-3.5-16k to write a commit message,
         # and set it as the description of this step
-        self.description = await sdk.models.gpt3516k.complete(
+        gpt3516k = OpenAI(model="gpt-3.5-turbo-0613")
+        await sdk.start_model(gpt3516k)
+
+        self.description = await gpt3516k.complete(
             f"{diff}\n\nWrite a short, specific (less than 50 chars) commit message about the above changes:")
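Net effect of the three hunks above: the private `Models._start_llm` helper moves onto the SDK as a public `start_model`, so a step can build a model on the fly and let the SDK inject credentials and logging. A minimal sketch of the resulting usage, assuming the import paths shown in this diff (`ExplainDiffStep` and its prompt are invented for illustration):

```python
from continuedev.src.continuedev.core.main import Step
from continuedev.src.continuedev.core.sdk import ContinueSDK
from continuedev.src.continuedev.libs.llm.openai import OpenAI


class ExplainDiffStep(Step):  # hypothetical step, not part of this commit
    async def run(self, sdk: ContinueSDK):
        # start_model inspects the model's requires_* flags and injects
        # api_key, unique_id, and write_log before the first request.
        llm = OpenAI(model="gpt-3.5-turbo-0613")
        await sdk.start_model(llm)
        self.description = await llm.complete("Explain this diff: ...")
```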
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 21afc338..58572634 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -1,5 +1,5 @@
 import functools
-from abc import ABC
+from abc import ABC, abstractproperty
 from pydantic import BaseModel, ConfigDict
 from typing import Any, Coroutine, Dict, Generator, List, Union, Optional

@@ -15,7 +15,12 @@ class LLM(BaseModel, ABC):
     system_message: Union[str, None] = None

-    async def start(self, *, api_key: Optional[str] = None):
+    @abstractproperty
+    def name(self):
+        """Return the name of the LLM."""
+        raise NotImplementedError
+
+    async def start(self, *, api_key: Optional[str] = None, **kwargs):
         """Start the connection to the LLM."""
         raise NotImplementedError

diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
index 067a903b..c9c8e9db 100644
--- a/continuedev/src/continuedev/libs/llm/anthropic.py
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -1,7 +1,7 @@
 from functools import cached_property
 import time

-from typing import Any, Coroutine, Dict, Generator, List, Union
+from typing import Any, Coroutine, Dict, Generator, List, Optional, Union
 from ...core.main import ChatMessage
 from anthropic import HUMAN_PROMPT, AI_PROMPT, AsyncAnthropic
 from ..llm import LLM
@@ -18,7 +18,7 @@ class AnthropicLLM(LLM):
         self.model = model
         self.system_message = system_message

-    async def start(self, *, api_key: str):
+    async def start(self, *, api_key: Optional[str] = None, **kwargs):
         self._async_client = AsyncAnthropic(api_key=api_key)

     async def stop(self):
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 4ad32e0e..49f593d8 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional
 from ...core.main import ChatMessage
 from ..llm import LLM
 import requests
@@ -17,7 +17,7 @@ class HuggingFaceInferenceAPI(LLM):
         self.model = model
         self.system_message = system_message  # TODO: Nothing being done with this

-    async def start(self, *, api_key: str):
+    async def start(self, *, api_key: Optional[str] = None, **kwargs):
         self.api_key = api_key

     def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs):
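With `name` now an `abstractproperty` and `start` widened to a keyword-only signature that tolerates unused kwargs, every provider exposes the same surface to `sdk.start_model`. A sketch of the minimum a new backend must implement under this contract (`EchoLLM` is invented; the import paths are the ones this diff uses):

```python
from typing import List, Optional
from continuedev.src.continuedev.core.main import ChatMessage
from continuedev.src.continuedev.libs.llm import LLM


class EchoLLM(LLM):  # hypothetical provider for illustration
    api_key: Optional[str] = None

    @property
    def name(self):
        return "echo"

    async def start(self, *, api_key: Optional[str] = None, **kwargs):
        # Accept and ignore injected kwargs (unique_id, write_log)
        # that this backend has no use for.
        self.api_key = api_key

    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> str:
        # A real backend would call its API here; echoing keeps
        # the sketch self-contained.
        return prompt
```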
diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
index d2898b5c..121ae99e 100644
--- a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
+++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
@@ -15,6 +15,10 @@ class MaybeProxyOpenAI(LLM):

     llm: Optional[LLM] = None

+    @property
+    def name(self):
+        return self.llm.name
+
     async def start(self, *, api_key: Optional[str] = None, **kwargs):
         if api_key is None or api_key.strip() == "":
             self.llm = ProxyServer(
@@ -22,17 +26,24 @@
         else:
             self.llm = OpenAI(model=self.model, write_log=kwargs["write_log"])

+        await self.llm.start(api_key=api_key, **kwargs)
+
     async def stop(self):
         await self.llm.stop()

     async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
         return await self.llm.complete(prompt, with_history=with_history, **kwargs)

-    def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
-        return self.llm.stream_complete(prompt, with_history=with_history, **kwargs)
+    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        resp = self.llm.stream_complete(
+            prompt, with_history=with_history, **kwargs)
+        async for item in resp:
+            yield item

     async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
-        return self.llm.stream_chat(messages=messages, **kwargs)
+        resp = self.llm.stream_chat(messages=messages, **kwargs)
+        async for item in resp:
+            yield item

     def count_tokens(self, text: str):
         return self.llm.count_tokens(text)
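The streaming change above is easy to misread: because an `async def` whose body contains `yield` is an async generator, re-yielding lets callers write `async for chunk in wrapper.stream_chat(...)` exactly as they would against the wrapped model, whereas an `async def` that merely `return`s the inner generator forces callers to `await` it first. A self-contained sketch of the re-yield pattern (names invented):

```python
import asyncio


async def inner_stream():
    # Stand-in for the wrapped model's stream_complete/stream_chat.
    for chunk in ("Hel", "lo"):
        yield chunk


async def delegating_stream():
    # Containing `yield` makes this an async generator, so the wrapper
    # is drop-in compatible with the stream it wraps.
    async for item in inner_stream():
        yield item


async def main():
    async for chunk in delegating_stream():
        print(chunk, end="")

asyncio.run(main())
```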
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 0c2c360b..de02a614 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -26,7 +26,7 @@ class OpenAI(LLM):
     write_log: Optional[Callable[[str], None]] = None
     api_key: str = None

-    async def start(self, *, api_key):
+    async def start(self, *, api_key: Optional[str] = None, **kwargs):
         self.api_key = api_key
         openai.api_key = self.api_key

diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index e8f1cb46..1c942523 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -27,7 +27,7 @@ class ProxyServer(LLM):
     requires_unique_id = True
     requires_write_log = True

-    async def start(self, **kwargs):
+    async def start(self, *, api_key: Optional[str] = None, **kwargs):
         self._client_session = aiohttp.ClientSession(
             connector=aiohttp.TCPConnector(ssl_context=ssl_context))
         self.write_log = kwargs["write_log"]
diff --git a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py
index 433e309e..872f8d62 100644
--- a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py
+++ b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py
@@ -27,7 +27,7 @@ class SetupPipelineStep(Step):
     async def run(self, sdk: ContinueSDK):
         sdk.context.set("api_description", self.api_description)

-        source_name = (await sdk.models.gpt35.complete(
+        source_name = (await sdk.models.medium.complete(
             f"Write a snake_case name for the data source described by {self.api_description}: ")).strip()
         filename = f'{source_name}.py'
@@ -89,7 +89,7 @@ class ValidatePipelineStep(Step):
             if "Traceback" in output or "SyntaxError" in output:
                 output = "Traceback" + output.split("Traceback")[-1]
                 file_content = await sdk.ide.readFile(os.path.join(workspace_dir, filename))
-                suggestion = await sdk.models.gpt35.complete(dedent(f"""\
+                suggestion = await sdk.models.medium.complete(dedent(f"""\
                     ```python
                     {file_content}
                     ```
@@ -101,7 +101,7 @@
                     This is a brief summary of the error followed by a suggestion on how it can be fixed by editing the resource function:"""))

-                api_documentation_url = await sdk.models.gpt35.complete(dedent(f"""\
+                api_documentation_url = await sdk.models.medium.complete(dedent(f"""\
                     The API I am trying to call is the '{sdk.context.get('api_description')}'. I tried calling it in the @resource function like this:
                     ```python
                     {file_content}
                     ```
@@ -151,7 +151,7 @@ class RunQueryStep(Step):
         output = await sdk.run('.env/bin/python3 query.py', name="Run test query", description="Running `.env/bin/python3 query.py` to test that the data was loaded into DuckDB as expected", handle_error=False)

         if "Traceback" in output or "SyntaxError" in output:
-            suggestion = await sdk.models.gpt35.complete(dedent(f"""\
+            suggestion = await sdk.models.medium.complete(dedent(f"""\
                 ```python
                 {await sdk.ide.readFile(os.path.join(sdk.ide.workspace_directory, "query.py"))}
                 ```
diff --git a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py
index 6ef5ffd6..c66cd629 100644
--- a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py
+++ b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py
@@ -42,7 +42,7 @@ class WritePytestsRecipe(Step):
             "{self.user_input}"

             Here is a complete set of pytest unit tests:""")

-        tests = await sdk.models.gpt35.complete(prompt)
+        tests = await sdk.models.medium.complete(prompt)

         await sdk.apply_filesystem_edit(AddFile(filepath=path, content=tests))
diff --git a/continuedev/src/continuedev/plugins/steps/chat.py b/continuedev/src/continuedev/plugins/steps/chat.py
index 2c662459..0a0fbca2 100644
--- a/continuedev/src/continuedev/plugins/steps/chat.py
+++ b/continuedev/src/continuedev/plugins/steps/chat.py
@@ -9,6 +9,7 @@ from .core.core import MessageStep
 from ...core.main import FunctionCall, Models
 from ...core.main import ChatMessage, Step, step_to_json_schema
 from ...core.sdk import ContinueSDK
+from ...libs.llm.openai import OpenAI
 import openai
 import os
 from dotenv import load_dotenv
@@ -43,7 +44,7 @@ class SimpleChatStep(Step):
                 completion += chunk["content"]
                 await sdk.update_ui()
         finally:
-            self.name = remove_quotes_and_escapes(await sdk.models.gpt35.complete(
+            self.name = remove_quotes_and_escapes(await sdk.models.medium.complete(
                 f"Write a short title for the following chat message: {self.description}"))

         self.chat_context.append(ChatMessage(
@@ -168,7 +169,10 @@ class ChatWithFunctions(Step):
             msg_content = ""
             msg_step = None

-            async for msg_chunk in sdk.models.gpt350613.stream_chat(await sdk.get_chat_context(), functions=functions):
+            gpt350613 = OpenAI(model="gpt-3.5-turbo-0613")
+            await sdk.start_model(gpt350613)
+
+            async for msg_chunk in gpt350613.stream_chat(await sdk.get_chat_context(), functions=functions):
                 if sdk.current_step_was_deleted():
                     return
diff --git a/continuedev/src/continuedev/plugins/steps/chroma.py b/continuedev/src/continuedev/plugins/steps/chroma.py
index dbe8363e..658cc7f3 100644
--- a/continuedev/src/continuedev/plugins/steps/chroma.py
+++ b/continuedev/src/continuedev/plugins/steps/chroma.py
@@ -56,7 +56,7 @@ class AnswerQuestionChroma(Step):

             Here is the answer:""")

-        answer = await sdk.models.gpt35.complete(prompt)
+        answer = await sdk.models.medium.complete(prompt)

         # Make paths relative to the workspace directory
         answer = answer.replace(await sdk.ide.getWorkspaceDirectory(), "")
diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py
index 5a81e5ee..b9f27fe5 100644
--- a/continuedev/src/continuedev/plugins/steps/core/core.py
+++ b/continuedev/src/continuedev/plugins/steps/core/core.py
@@ -7,7 +7,7 @@
 from typing import Coroutine, List, Literal, Union
 from ....libs.llm.ggml import GGML
 from ....models.main import Range
-from ....libs.llm.prompt_utils import MarkdownStyleEncoderDecoder
+from ....libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
 from ....models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullContents, FileSystemEdit
 from ....models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents
 from ....core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation
@@ -84,7 +84,7 @@ class ShellCommandsStep(Step):
         for cmd in self.cmds:
             output = await sdk.ide.runCommand(cmd)
             if self.handle_error and output is not None and output_contains_error(output):
-                suggestion = await sdk.models.gpt35.complete(dedent(f"""\
+                suggestion = await sdk.models.medium.complete(dedent(f"""\
                     While running the command `{cmd}`, the following error occurred:

                     ```ascii
@@ -202,7 +202,8 @@ class DefaultModelEditCodeStep(Step):
         # If using 3.5 and overflows, upgrade to 3.5.16k
         if model_to_use.name == "gpt-3.5-turbo":
             if total_tokens > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]:
-                model_to_use = sdk.models.gpt3516k
+                model_to_use = MaybeProxyOpenAI(model="gpt-3.5-turbo-0613")
+                await sdk.start_model(model_to_use)

         # Remove tokens from the end first, and then the start to clear space
         # This part finds the start and end lines
diff --git a/continuedev/src/continuedev/plugins/steps/draft/migration.py b/continuedev/src/continuedev/plugins/steps/draft/migration.py
index a76d491b..c38f54dc 100644
--- a/continuedev/src/continuedev/plugins/steps/draft/migration.py
+++ b/continuedev/src/continuedev/plugins/steps/draft/migration.py
@@ -13,7 +13,7 @@ class MigrationStep(Step):
         recent_edits = await sdk.ide.get_recent_edits(self.edited_file)
         recent_edits_string = "\n\n".join(
             map(lambda x: x.to_string(), recent_edits))
-        description = await sdk.models.gpt35.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n")
+        description = await sdk.models.medium.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n")
         await sdk.run([
             "cd libs",
             "poetry run alembic revision --autogenerate -m " + description,
diff --git a/continuedev/src/continuedev/plugins/steps/help.py b/continuedev/src/continuedev/plugins/steps/help.py
index d3807706..4d75af30 100644
--- a/continuedev/src/continuedev/plugins/steps/help.py
+++ b/continuedev/src/continuedev/plugins/steps/help.py
@@ -49,7 +49,7 @@ class HelpStep(Step):
             summary="Help"
         ))
         messages = await sdk.get_chat_context()
-        generator = sdk.models.gpt4.stream_chat(messages)
+        generator = sdk.models.default.stream_chat(messages)
         async for chunk in generator:
             if "content" in chunk:
                 self.description += chunk["content"]
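The `DefaultModelEditCodeStep` hunk illustrates the rule this refactor introduces: a freshly constructed model is inert until `sdk.start_model` wires in its API key, unique id, and logger. A reduced sketch of the upgrade-on-overflow logic, assuming the repo's `MaybeProxyOpenAI` and a `MAX_TOKENS_FOR_MODEL` mapping like the one in the surrounding code (the limit shown is illustrative):

```python
from continuedev.src.continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI

MAX_TOKENS_FOR_MODEL = {"gpt-3.5-turbo": 4096}  # illustrative value


async def pick_model(sdk, model_to_use, total_tokens: int):
    if model_to_use.name == "gpt-3.5-turbo" \
            and total_tokens > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]:
        # Swap in the larger-context variant, then start it: the new
        # object has no credentials until start_model injects them.
        model_to_use = MaybeProxyOpenAI(model="gpt-3.5-turbo-0613")
        await sdk.start_model(model_to_use)
    return model_to_use
```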
diff --git a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py
index b54d394a..3d8d96fb 100644
--- a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py
+++ b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py
@@ -23,6 +23,6 @@ class NLMultiselectStep(Step):
         if first_try is not None:
             return first_try

-        gpt_parsed = await sdk.models.gpt35.complete(
+        gpt_parsed = await sdk.models.default.complete(
             f"These are the available options are: [{', '.join(self.options)}]. The user requested {user_response}. This is the exact string from the options array that they selected:")
         return extract_option(gpt_parsed) or self.options[0]
diff --git a/continuedev/src/continuedev/plugins/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py
index a8752df2..26c1cabd 100644
--- a/continuedev/src/continuedev/plugins/steps/main.py
+++ b/continuedev/src/continuedev/plugins/steps/main.py
@@ -100,7 +100,7 @@ class FasterEditHighlightedCodeStep(Step):
         for rif in range_in_files:
             rif_dict[rif.filepath] = rif.contents

-        completion = await sdk.models.gpt35.complete(prompt)
+        completion = await sdk.models.medium.complete(prompt)

         # Temporarily doing this to generate description.
         self._prompt = prompt
diff --git a/continuedev/src/continuedev/plugins/steps/react.py b/continuedev/src/continuedev/plugins/steps/react.py
index 8b2e7c2e..da6acdbf 100644
--- a/continuedev/src/continuedev/plugins/steps/react.py
+++ b/continuedev/src/continuedev/plugins/steps/react.py
@@ -27,7 +27,7 @@ class NLDecisionStep(Step):

             Select the step which should be taken next to satisfy the user input. Say only the name of the selected step. You must choose one:""")

-        resp = (await sdk.models.gpt35.complete(prompt)).lower()
+        resp = (await sdk.models.medium.complete(prompt)).lower()

         step_to_run = None
         for step in self.steps:
diff --git a/continuedev/src/continuedev/plugins/steps/search_directory.py b/continuedev/src/continuedev/plugins/steps/search_directory.py
index 7d02d6fa..c13047d6 100644
--- a/continuedev/src/continuedev/plugins/steps/search_directory.py
+++ b/continuedev/src/continuedev/plugins/steps/search_directory.py
@@ -42,7 +42,7 @@ class WriteRegexPatternStep(Step):
     async def run(self, sdk: ContinueSDK):
         # Ask the user for a regex pattern
-        pattern = await sdk.models.gpt35.complete(dedent(f"""\
+        pattern = await sdk.models.medium.complete(dedent(f"""\
             This is the user request:

             {self.user_request}
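Every step hunk in this commit follows the same substitution: concrete attributes such as `sdk.models.gpt35` and `sdk.models.gpt4` become the role aliases `medium` and `default`, which `Models.start` (first hunk above) guarantees are always populated by falling back to `default`. A sketch of a step written against the role-based API (`TitleStep` is invented for illustration):

```python
from continuedev.src.continuedev.core.main import Step
from continuedev.src.continuedev.core.sdk import ContinueSDK


class TitleStep(Step):  # hypothetical step, not part of this commit
    async def run(self, sdk: ContinueSDK):
        # A role, not a model name: whatever the user configured as
        # `medium` serves this cheap request; unset roles were already
        # backfilled with `default` in Models.start.
        self.name = await sdk.models.medium.complete(
            "Write a short title for this session:")
```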