summaryrefslogtreecommitdiff
path: root/continuedev
diff options
context:
space:
mode:
Diffstat (limited to 'continuedev')
-rw-r--r--continuedev/src/continuedev/core/models.py18
-rw-r--r--continuedev/src/continuedev/core/sdk.py12
-rw-r--r--continuedev/src/continuedev/libs/constants/default_config.py.txt5
-rw-r--r--continuedev/src/continuedev/libs/llm/__init__.py9
-rw-r--r--continuedev/src/continuedev/libs/llm/anthropic.py4
-rw-r--r--continuedev/src/continuedev/libs/llm/hf_inference_api.py4
-rw-r--r--continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py17
-rw-r--r--continuedev/src/continuedev/libs/llm/openai.py2
-rw-r--r--continuedev/src/continuedev/libs/llm/proxy_server.py2
-rw-r--r--continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py8
-rw-r--r--continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py2
-rw-r--r--continuedev/src/continuedev/plugins/steps/chat.py8
-rw-r--r--continuedev/src/continuedev/plugins/steps/chroma.py2
-rw-r--r--continuedev/src/continuedev/plugins/steps/core/core.py7
-rw-r--r--continuedev/src/continuedev/plugins/steps/draft/migration.py2
-rw-r--r--continuedev/src/continuedev/plugins/steps/help.py2
-rw-r--r--continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py2
-rw-r--r--continuedev/src/continuedev/plugins/steps/main.py2
-rw-r--r--continuedev/src/continuedev/plugins/steps/react.py2
-rw-r--r--continuedev/src/continuedev/plugins/steps/search_directory.py2
20 files changed, 68 insertions, 44 deletions
diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py
index e4cb8ed6..900762b6 100644
--- a/continuedev/src/continuedev/core/models.py
+++ b/continuedev/src/continuedev/core/models.py
@@ -34,33 +34,23 @@ class Models(BaseModel):
'''depending on the model, return the single prompt string'''
"""
- async def _start_llm(self, llm: LLM):
- kwargs = {}
- if llm.requires_api_key:
- kwargs["api_key"] = await self.sdk.get_api_key(llm.requires_api_key)
- if llm.requires_unique_id:
- kwargs["unique_id"] = self.sdk.ide.unique_id
- if llm.requires_write_log:
- kwargs["write_log"] = self.sdk.write_log
- await llm.start(**kwargs)
-
async def start(self, sdk: "ContinueSDK"):
"""Start each of the LLMs, or fall back to default"""
self.sdk = sdk
self.system_message = self.sdk.config.system_message
- await self._start_llm(self.default)
+ await sdk.start_model(self.default)
if self.small:
- await self._start_llm(self.small)
+ await sdk.start_model(self.small)
else:
self.small = self.default
if self.medium:
- await self._start_llm(self.medium)
+ await sdk.start_model(self.medium)
else:
self.medium = self.default
if self.large:
- await self._start_llm(self.large)
+ await sdk.start_model(self.large)
else:
self.large = self.default
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index b0f7d40a..7febb932 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -15,13 +15,13 @@ from .main import Context, ContinueCustomException, History, HistoryNode, Step,
from ..plugins.steps.core.core import *
from ..libs.util.telemetry import posthog_logger
from ..libs.util.paths import getConfigFilePath
+from .models import Models
class Autopilot:
pass
-
class ContinueSDK(AbstractContinueSDK):
"""The SDK provided as parameters to a step"""
ide: AbstractIdeProtocolServer
@@ -66,6 +66,16 @@ class ContinueSDK(AbstractContinueSDK):
def write_log(self, message: str):
self.history.timeline[self.history.current_index].logs.append(message)
+ async def start_model(self, llm: LLM):
+ kwargs = {}
+ if llm.requires_api_key:
+ kwargs["api_key"] = await self.get_api_key(llm.requires_api_key)
+ if llm.requires_unique_id:
+ kwargs["unique_id"] = self.ide.unique_id
+ if llm.requires_write_log:
+ kwargs["write_log"] = self.write_log
+ await llm.start(**kwargs)
+
async def _ensure_absolute_path(self, path: str) -> str:
if os.path.isabs(path):
return path
diff --git a/continuedev/src/continuedev/libs/constants/default_config.py.txt b/continuedev/src/continuedev/libs/constants/default_config.py.txt
index 5708747f..7cd2226a 100644
--- a/continuedev/src/continuedev/libs/constants/default_config.py.txt
+++ b/continuedev/src/continuedev/libs/constants/default_config.py.txt
@@ -31,7 +31,10 @@ class CommitMessageStep(Step):
# Ask gpt-3.5-16k to write a commit message,
# and set it as the description of this step
- self.description = await sdk.models.gpt3516k.complete(
+ gpt3516k = OpenAI(model="gpt-3.5-turbo-0613")
+ await sdk.start_model(gpt3516k)
+
+ self.description = await gpt3516k.complete(
f"{diff}\n\nWrite a short, specific (less than 50 chars) commit message about the above changes:")
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 21afc338..58572634 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -1,5 +1,5 @@
import functools
-from abc import ABC
+from abc import ABC, abstractproperty
from pydantic import BaseModel, ConfigDict
from typing import Any, Coroutine, Dict, Generator, List, Union, Optional
@@ -15,7 +15,12 @@ class LLM(BaseModel, ABC):
system_message: Union[str, None] = None
- async def start(self, *, api_key: Optional[str] = None):
+ @abstractproperty
+ def name(self):
+ """Return the name of the LLM."""
+ raise NotImplementedError
+
+ async def start(self, *, api_key: Optional[str] = None, **kwargs):
"""Start the connection to the LLM."""
raise NotImplementedError
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
index 067a903b..c9c8e9db 100644
--- a/continuedev/src/continuedev/libs/llm/anthropic.py
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -1,7 +1,7 @@
from functools import cached_property
import time
-from typing import Any, Coroutine, Dict, Generator, List, Union
+from typing import Any, Coroutine, Dict, Generator, List, Optional, Union
from ...core.main import ChatMessage
from anthropic import HUMAN_PROMPT, AI_PROMPT, AsyncAnthropic
from ..llm import LLM
@@ -18,7 +18,7 @@ class AnthropicLLM(LLM):
self.model = model
self.system_message = system_message
- async def start(self, *, api_key: str):
+ async def start(self, *, api_key: Optional[str] = None, **kwargs):
self._async_client = AsyncAnthropic(api_key=api_key)
async def stop(self):
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 4ad32e0e..49f593d8 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional
from ...core.main import ChatMessage
from ..llm import LLM
import requests
@@ -17,7 +17,7 @@ class HuggingFaceInferenceAPI(LLM):
self.model = model
self.system_message = system_message # TODO: Nothing being done with this
- async def start(self, *, api_key: str):
+ async def start(self, *, api_key: Optional[str] = None, **kwargs):
self.api_key = api_key
def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs):
diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
index d2898b5c..121ae99e 100644
--- a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
+++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
@@ -15,6 +15,10 @@ class MaybeProxyOpenAI(LLM):
llm: Optional[LLM] = None
+ @property
+ def name(self):
+ return self.llm.name
+
async def start(self, *, api_key: Optional[str] = None, **kwargs):
if api_key is None or api_key.strip() == "":
self.llm = ProxyServer(
@@ -22,17 +26,24 @@ class MaybeProxyOpenAI(LLM):
else:
self.llm = OpenAI(model=self.model, write_log=kwargs["write_log"])
+ await self.llm.start(api_key=api_key, **kwargs)
+
async def stop(self):
await self.llm.stop()
async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
return await self.llm.complete(prompt, with_history=with_history, **kwargs)
- def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
- return self.llm.stream_complete(prompt, with_history=with_history, **kwargs)
+ async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ resp = self.llm.stream_complete(
+ prompt, with_history=with_history, **kwargs)
+ async for item in resp:
+ yield item
async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
- return self.llm.stream_chat(messages=messages, **kwargs)
+ resp = self.llm.stream_chat(messages=messages, **kwargs)
+ async for item in resp:
+ yield item
def count_tokens(self, text: str):
return self.llm.count_tokens(text)
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 0c2c360b..de02a614 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -26,7 +26,7 @@ class OpenAI(LLM):
write_log: Optional[Callable[[str], None]] = None
api_key: str = None
- async def start(self, *, api_key):
+ async def start(self, *, api_key: Optional[str] = None, **kwargs):
self.api_key = api_key
openai.api_key = self.api_key
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index e8f1cb46..1c942523 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -27,7 +27,7 @@ class ProxyServer(LLM):
requires_unique_id = True
requires_write_log = True
- async def start(self, **kwargs):
+ async def start(self, *, api_key: Optional[str] = None, **kwargs):
self._client_session = aiohttp.ClientSession(
connector=aiohttp.TCPConnector(ssl_context=ssl_context))
self.write_log = kwargs["write_log"]
diff --git a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py
index 433e309e..872f8d62 100644
--- a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py
+++ b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py
@@ -27,7 +27,7 @@ class SetupPipelineStep(Step):
async def run(self, sdk: ContinueSDK):
sdk.context.set("api_description", self.api_description)
- source_name = (await sdk.models.gpt35.complete(
+ source_name = (await sdk.models.medium.complete(
f"Write a snake_case name for the data source described by {self.api_description}: ")).strip()
filename = f'{source_name}.py'
@@ -89,7 +89,7 @@ class ValidatePipelineStep(Step):
if "Traceback" in output or "SyntaxError" in output:
output = "Traceback" + output.split("Traceback")[-1]
file_content = await sdk.ide.readFile(os.path.join(workspace_dir, filename))
- suggestion = await sdk.models.gpt35.complete(dedent(f"""\
+ suggestion = await sdk.models.medium.complete(dedent(f"""\
```python
{file_content}
```
@@ -101,7 +101,7 @@ class ValidatePipelineStep(Step):
This is a brief summary of the error followed by a suggestion on how it can be fixed by editing the resource function:"""))
- api_documentation_url = await sdk.models.gpt35.complete(dedent(f"""\
+ api_documentation_url = await sdk.models.medium.complete(dedent(f"""\
The API I am trying to call is the '{sdk.context.get('api_description')}'. I tried calling it in the @resource function like this:
```python
{file_content}
@@ -151,7 +151,7 @@ class RunQueryStep(Step):
output = await sdk.run('.env/bin/python3 query.py', name="Run test query", description="Running `.env/bin/python3 query.py` to test that the data was loaded into DuckDB as expected", handle_error=False)
if "Traceback" in output or "SyntaxError" in output:
- suggestion = await sdk.models.gpt35.complete(dedent(f"""\
+ suggestion = await sdk.models.medium.complete(dedent(f"""\
```python
{await sdk.ide.readFile(os.path.join(sdk.ide.workspace_directory, "query.py"))}
```
diff --git a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py
index 6ef5ffd6..c66cd629 100644
--- a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py
+++ b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py
@@ -42,7 +42,7 @@ class WritePytestsRecipe(Step):
"{self.user_input}"
Here is a complete set of pytest unit tests:""")
- tests = await sdk.models.gpt35.complete(prompt)
+ tests = await sdk.models.medium.complete(prompt)
await sdk.apply_filesystem_edit(AddFile(filepath=path, content=tests))
diff --git a/continuedev/src/continuedev/plugins/steps/chat.py b/continuedev/src/continuedev/plugins/steps/chat.py
index 2c662459..0a0fbca2 100644
--- a/continuedev/src/continuedev/plugins/steps/chat.py
+++ b/continuedev/src/continuedev/plugins/steps/chat.py
@@ -9,6 +9,7 @@ from .core.core import MessageStep
from ...core.main import FunctionCall, Models
from ...core.main import ChatMessage, Step, step_to_json_schema
from ...core.sdk import ContinueSDK
+from ...libs.llm.openai import OpenAI
import openai
import os
from dotenv import load_dotenv
@@ -43,7 +44,7 @@ class SimpleChatStep(Step):
completion += chunk["content"]
await sdk.update_ui()
finally:
- self.name = remove_quotes_and_escapes(await sdk.models.gpt35.complete(
+ self.name = remove_quotes_and_escapes(await sdk.models.medium.complete(
f"Write a short title for the following chat message: {self.description}"))
self.chat_context.append(ChatMessage(
@@ -168,7 +169,10 @@ class ChatWithFunctions(Step):
msg_content = ""
msg_step = None
- async for msg_chunk in sdk.models.gpt350613.stream_chat(await sdk.get_chat_context(), functions=functions):
+ gpt350613 = OpenAI(model="gpt-3.5-turbo-0613")
+ await sdk.start_model(gpt350613)
+
+ async for msg_chunk in gpt350613.stream_chat(await sdk.get_chat_context(), functions=functions):
if sdk.current_step_was_deleted():
return
diff --git a/continuedev/src/continuedev/plugins/steps/chroma.py b/continuedev/src/continuedev/plugins/steps/chroma.py
index dbe8363e..658cc7f3 100644
--- a/continuedev/src/continuedev/plugins/steps/chroma.py
+++ b/continuedev/src/continuedev/plugins/steps/chroma.py
@@ -56,7 +56,7 @@ class AnswerQuestionChroma(Step):
Here is the answer:""")
- answer = await sdk.models.gpt35.complete(prompt)
+ answer = await sdk.models.medium.complete(prompt)
# Make paths relative to the workspace directory
answer = answer.replace(await sdk.ide.getWorkspaceDirectory(), "")
diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py
index 5a81e5ee..b9f27fe5 100644
--- a/continuedev/src/continuedev/plugins/steps/core/core.py
+++ b/continuedev/src/continuedev/plugins/steps/core/core.py
@@ -7,7 +7,7 @@ from typing import Coroutine, List, Literal, Union
from ....libs.llm.ggml import GGML
from ....models.main import Range
-from ....libs.llm.prompt_utils import MarkdownStyleEncoderDecoder
+from ....libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
from ....models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullContents, FileSystemEdit
from ....models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents
from ....core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation
@@ -84,7 +84,7 @@ class ShellCommandsStep(Step):
for cmd in self.cmds:
output = await sdk.ide.runCommand(cmd)
if self.handle_error and output is not None and output_contains_error(output):
- suggestion = await sdk.models.gpt35.complete(dedent(f"""\
+ suggestion = await sdk.models.medium.complete(dedent(f"""\
While running the command `{cmd}`, the following error occurred:
```ascii
@@ -202,7 +202,8 @@ class DefaultModelEditCodeStep(Step):
# If using 3.5 and overflows, upgrade to 3.5.16k
if model_to_use.name == "gpt-3.5-turbo":
if total_tokens > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]:
- model_to_use = sdk.models.gpt3516k
+ model_to_use = MaybeProxyOpenAI(model="gpt-3.5-turbo-0613")
+ await sdk.start_model(model_to_use)
# Remove tokens from the end first, and then the start to clear space
# This part finds the start and end lines
diff --git a/continuedev/src/continuedev/plugins/steps/draft/migration.py b/continuedev/src/continuedev/plugins/steps/draft/migration.py
index a76d491b..c38f54dc 100644
--- a/continuedev/src/continuedev/plugins/steps/draft/migration.py
+++ b/continuedev/src/continuedev/plugins/steps/draft/migration.py
@@ -13,7 +13,7 @@ class MigrationStep(Step):
recent_edits = await sdk.ide.get_recent_edits(self.edited_file)
recent_edits_string = "\n\n".join(
map(lambda x: x.to_string(), recent_edits))
- description = await sdk.models.gpt35.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n")
+ description = await sdk.models.medium.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n")
await sdk.run([
"cd libs",
"poetry run alembic revision --autogenerate -m " + description,
diff --git a/continuedev/src/continuedev/plugins/steps/help.py b/continuedev/src/continuedev/plugins/steps/help.py
index d3807706..4d75af30 100644
--- a/continuedev/src/continuedev/plugins/steps/help.py
+++ b/continuedev/src/continuedev/plugins/steps/help.py
@@ -49,7 +49,7 @@ class HelpStep(Step):
summary="Help"
))
messages = await sdk.get_chat_context()
- generator = sdk.models.gpt4.stream_chat(messages)
+ generator = sdk.models.default.stream_chat(messages)
async for chunk in generator:
if "content" in chunk:
self.description += chunk["content"]
diff --git a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py
index b54d394a..3d8d96fb 100644
--- a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py
+++ b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py
@@ -23,6 +23,6 @@ class NLMultiselectStep(Step):
if first_try is not None:
return first_try
- gpt_parsed = await sdk.models.gpt35.complete(
+ gpt_parsed = await sdk.models.default.complete(
f"These are the available options are: [{', '.join(self.options)}]. The user requested {user_response}. This is the exact string from the options array that they selected:")
return extract_option(gpt_parsed) or self.options[0]
diff --git a/continuedev/src/continuedev/plugins/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py
index a8752df2..26c1cabd 100644
--- a/continuedev/src/continuedev/plugins/steps/main.py
+++ b/continuedev/src/continuedev/plugins/steps/main.py
@@ -100,7 +100,7 @@ class FasterEditHighlightedCodeStep(Step):
for rif in range_in_files:
rif_dict[rif.filepath] = rif.contents
- completion = await sdk.models.gpt35.complete(prompt)
+ completion = await sdk.models.medium.complete(prompt)
# Temporarily doing this to generate description.
self._prompt = prompt
diff --git a/continuedev/src/continuedev/plugins/steps/react.py b/continuedev/src/continuedev/plugins/steps/react.py
index 8b2e7c2e..da6acdbf 100644
--- a/continuedev/src/continuedev/plugins/steps/react.py
+++ b/continuedev/src/continuedev/plugins/steps/react.py
@@ -27,7 +27,7 @@ class NLDecisionStep(Step):
Select the step which should be taken next to satisfy the user input. Say only the name of the selected step. You must choose one:""")
- resp = (await sdk.models.gpt35.complete(prompt)).lower()
+ resp = (await sdk.models.medium.complete(prompt)).lower()
step_to_run = None
for step in self.steps:
diff --git a/continuedev/src/continuedev/plugins/steps/search_directory.py b/continuedev/src/continuedev/plugins/steps/search_directory.py
index 7d02d6fa..c13047d6 100644
--- a/continuedev/src/continuedev/plugins/steps/search_directory.py
+++ b/continuedev/src/continuedev/plugins/steps/search_directory.py
@@ -42,7 +42,7 @@ class WriteRegexPatternStep(Step):
async def run(self, sdk: ContinueSDK):
# Ask the user for a regex pattern
- pattern = await sdk.models.gpt35.complete(dedent(f"""\
+ pattern = await sdk.models.medium.complete(dedent(f"""\
This is the user request:
{self.user_request}