summaryrefslogtreecommitdiff
path: root/continuedev
diff options
context:
space:
mode:
Diffstat (limited to 'continuedev')
-rw-r--r--continuedev/src/continuedev/libs/llm/__init__.py9
-rw-r--r--continuedev/src/continuedev/libs/llm/openai.py8
-rw-r--r--continuedev/src/continuedev/libs/llm/proxy_server.py38
-rw-r--r--continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py10
-rw-r--r--continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py2
-rw-r--r--continuedev/src/continuedev/recipes/WritePytestsRecipe/main.py2
-rw-r--r--continuedev/src/continuedev/steps/chat.py6
-rw-r--r--continuedev/src/continuedev/steps/chroma.py2
-rw-r--r--continuedev/src/continuedev/steps/core/core.py4
-rw-r--r--continuedev/src/continuedev/steps/draft/migration.py2
-rw-r--r--continuedev/src/continuedev/steps/input/nl_multiselect.py2
-rw-r--r--continuedev/src/continuedev/steps/main.py2
-rw-r--r--continuedev/src/continuedev/steps/react.py2
-rw-r--r--continuedev/src/continuedev/steps/search_directory.py2
14 files changed, 45 insertions, 46 deletions
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 2986b2c4..108eedf1 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -1,5 +1,5 @@
from abc import ABC
-from typing import Any, Dict, Generator, List, Union
+from typing import Any, Coroutine, Dict, Generator, List, Union
from ...core.main import ChatMessage
from ...models.main import AbstractModel
@@ -9,17 +9,14 @@ from pydantic import BaseModel
class LLM(ABC):
system_message: Union[str, None] = None
- def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs):
+ async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
"""Return the completion of the text with the given temperature."""
- raise
+ raise NotImplementedError
def stream_chat(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
"""Yield a stream of chat messages."""
raise NotImplementedError
- def __call__(self, prompt: str, **kwargs):
- return self.complete(prompt, **kwargs)
-
def with_system_message(self, system_message: Union[str, None]):
"""Return a new model with the given system message."""
raise NotImplementedError
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 180ea5f0..bc108129 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -1,7 +1,7 @@
import asyncio
from functools import cached_property
import time
-from typing import Any, Dict, Generator, List, Union
+from typing import Any, Coroutine, Dict, Generator, List, Union
from ...core.main import ChatMessage
import openai
import aiohttp
@@ -107,7 +107,7 @@ class OpenAI(LLM):
for chunk in generator:
yield chunk.choices[0].text
- def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> str:
+ async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
t1 = time.time()
self.completion_count += 1
@@ -115,12 +115,12 @@ class OpenAI(LLM):
"frequency_penalty": 0, "presence_penalty": 0, "stream": False} | kwargs
if args["model"] in CHAT_MODELS:
- resp = openai.ChatCompletion.create(
+ resp = await openai.ChatCompletion.acreate(
messages=self.compile_chat_messages(with_history, prompt),
**args,
).choices[0].message.content
else:
- resp = openai.Completion.create(
+ resp = await openai.Completion.acreate(
prompt=prompt,
**args,
).choices[0].text
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index 4ff57101..b2948f9a 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -1,8 +1,9 @@
from functools import cached_property
import json
-from typing import Any, Dict, Generator, List, Literal, Union
+from typing import Any, Coroutine, Dict, Generator, List, Literal, Union
import requests
import tiktoken
+import aiohttp
from ...core.main import ChatMessage
from ..llm import LLM
@@ -39,16 +40,6 @@ class ProxyServer(LLM):
def count_tokens(self, text: str):
return len(self.__encoding_for_model.encode(text, disallowed_special=()))
- def stream_chat(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
- resp = requests.post(f"{SERVER_URL}/stream_complete", json={
- "chat_history": self.compile_chat_messages(with_history, prompt),
- "model": self.default_model,
- "unique_id": self.unique_id,
- }, stream=True)
- for line in resp.iter_lines():
- if line:
- yield line.decode("utf-8")
-
def __prune_chat_history(self, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int):
tokens = tokens_for_completion
for i in range(len(chat_history) - 1, -1, -1):
@@ -75,11 +66,22 @@ class ProxyServer(LLM):
return history
- def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> str:
+ async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
+ async with aiohttp.ClientSession() as session:
+ async with session.post(f"{SERVER_URL}/complete", json={
+ "chat_history": self.compile_chat_messages(with_history, prompt),
+ "model": self.default_model,
+ "unique_id": self.unique_id,
+ }) as resp:
+ return json.loads(await resp.text())
- resp = requests.post(f"{SERVER_URL}/complete", json={
- "chat_history": self.compile_chat_messages(with_history, prompt),
- "model": self.default_model,
- "unique_id": self.unique_id,
- })
- return json.loads(resp.text)
+ async def stream_chat(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ async with aiohttp.ClientSession() as session:
+ async with session.post(f"{SERVER_URL}/stream_complete", json={
+ "chat_history": self.compile_chat_messages(with_history, prompt),
+ "model": self.default_model,
+ "unique_id": self.unique_id,
+ }) as resp:
+ async for line in resp.content:
+ if line:
+ yield line.decode("utf-8")
diff --git a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py b/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py
index 096b41c6..3fba1112 100644
--- a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py
+++ b/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py
@@ -29,8 +29,8 @@ class SetupPipelineStep(Step):
async def run(self, sdk: ContinueSDK):
sdk.context.set("api_description", self.api_description)
- source_name = sdk.models.gpt35.complete(
- f"Write a snake_case name for the data source described by {self.api_description}: ").strip()
+ source_name = (await sdk.models.gpt35.complete(
+ f"Write a snake_case name for the data source described by {self.api_description}: ")).strip()
filename = f'{source_name}.py'
# running commands to get started when creating a new dlt pipeline
@@ -91,7 +91,7 @@ class ValidatePipelineStep(Step):
if "Traceback" in output or "SyntaxError" in output:
output = "Traceback" + output.split("Traceback")[-1]
file_content = await sdk.ide.readFile(os.path.join(workspace_dir, filename))
- suggestion = sdk.models.gpt35.complete(dedent(f"""\
+ suggestion = await sdk.models.gpt35.complete(dedent(f"""\
```python
{file_content}
```
@@ -103,7 +103,7 @@ class ValidatePipelineStep(Step):
This is a brief summary of the error followed by a suggestion on how it can be fixed by editing the resource function:"""))
- api_documentation_url = sdk.models.gpt35.complete(dedent(f"""\
+ api_documentation_url = await sdk.models.gpt35.complete(dedent(f"""\
The API I am trying to call is the '{sdk.context.get('api_description')}'. I tried calling it in the @resource function like this:
```python
{file_content}
@@ -159,7 +159,7 @@ class RunQueryStep(Step):
output = await sdk.run('.env/bin/python3 query.py', name="Run test query", description="Running `.env/bin/python3 query.py` to test that the data was loaded into DuckDB as expected", handle_error=False)
if "Traceback" in output or "SyntaxError" in output:
- suggestion = sdk.models.gpt35.complete(dedent(f"""\
+ suggestion = await sdk.models.gpt35.complete(dedent(f"""\
```python
{await sdk.ide.readFile(os.path.join(sdk.ide.workspace_directory, "query.py"))}
```
diff --git a/continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py b/continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py
index 6db9fd4b..df414e2e 100644
--- a/continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py
+++ b/continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py
@@ -82,7 +82,7 @@ class LoadDataStep(Step):
docs = f.read()
output = "Traceback" + output.split("Traceback")[-1]
- suggestion = sdk.models.default.complete(dedent(f"""\
+ suggestion = await sdk.models.default.complete(dedent(f"""\
When trying to load data into BigQuery, the following error occurred:
```ascii
diff --git a/continuedev/src/continuedev/recipes/WritePytestsRecipe/main.py b/continuedev/src/continuedev/recipes/WritePytestsRecipe/main.py
index 688f44c3..6e1244b3 100644
--- a/continuedev/src/continuedev/recipes/WritePytestsRecipe/main.py
+++ b/continuedev/src/continuedev/recipes/WritePytestsRecipe/main.py
@@ -41,7 +41,7 @@ class WritePytestsRecipe(Step):
"{self.user_input}"
Here is a complete set of pytest unit tests:""")
- tests = sdk.models.gpt35.complete(prompt)
+ tests = await sdk.models.gpt35.complete(prompt)
await sdk.apply_filesystem_edit(AddFile(filepath=path, content=tests))
diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py
index 7cfe7e0c..499d127f 100644
--- a/continuedev/src/continuedev/steps/chat.py
+++ b/continuedev/src/continuedev/steps/chat.py
@@ -11,9 +11,9 @@ class SimpleChatStep(Step):
async def run(self, sdk: ContinueSDK):
self.description = f"## {self.user_input}\n\n"
- for chunk in sdk.models.default.stream_chat(self.user_input, with_history=await sdk.get_chat_context()):
+ async for chunk in sdk.models.default.stream_chat(self.user_input, with_history=await sdk.get_chat_context()):
self.description += chunk
await sdk.update_ui()
- self.name = sdk.models.gpt35.complete(
- f"Write a short title for the following chat message: {self.description}").strip()
+ self.name = (await sdk.models.gpt35.complete(
+ f"Write a short title for the following chat message: {self.description}")).strip()
diff --git a/continuedev/src/continuedev/steps/chroma.py b/continuedev/src/continuedev/steps/chroma.py
index 058455b2..9d085981 100644
--- a/continuedev/src/continuedev/steps/chroma.py
+++ b/continuedev/src/continuedev/steps/chroma.py
@@ -56,7 +56,7 @@ class AnswerQuestionChroma(Step):
Here is the answer:""")
- answer = sdk.models.gpt35.complete(prompt)
+ answer = await sdk.models.gpt35.complete(prompt)
# Make paths relative to the workspace directory
answer = answer.replace(await sdk.ide.getWorkspaceDirectory(), "")
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index aee5bc1d..eb45d1d3 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -80,7 +80,7 @@ class ShellCommandsStep(Step):
for cmd in self.cmds:
output = await sdk.ide.runCommand(cmd)
if self.handle_error and output is not None and output_contains_error(output):
- suggestion = sdk.models.gpt35.complete(dedent(f"""\
+ suggestion = await sdk.models.gpt35.complete(dedent(f"""\
While running the command `{cmd}`, the following error occurred:
```ascii
@@ -183,7 +183,7 @@ class DefaultModelEditCodeStep(Step):
prompt = self._prompt.format(
code=rif.contents, user_request=self.user_input, file_prefix=segs[0], file_suffix=segs[1])
- completion = str(sdk.models.default.complete(prompt, with_history=await sdk.get_chat_context()))
+ completion = str(await sdk.models.default.complete(prompt, with_history=await sdk.get_chat_context()))
eot_token = "<|endoftext|>"
completion = completion.removesuffix(eot_token)
diff --git a/continuedev/src/continuedev/steps/draft/migration.py b/continuedev/src/continuedev/steps/draft/migration.py
index 7c4b7eb5..f3b36b5e 100644
--- a/continuedev/src/continuedev/steps/draft/migration.py
+++ b/continuedev/src/continuedev/steps/draft/migration.py
@@ -13,7 +13,7 @@ class MigrationStep(Step):
recent_edits = await sdk.ide.get_recent_edits(self.edited_file)
recent_edits_string = "\n\n".join(
map(lambda x: x.to_string(), recent_edits))
- description = sdk.models.gpt35.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n")
+ description = await sdk.models.gpt35.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n")
await sdk.run([
"cd libs",
"poetry run alembic revision --autogenerate -m " + description,
diff --git a/continuedev/src/continuedev/steps/input/nl_multiselect.py b/continuedev/src/continuedev/steps/input/nl_multiselect.py
index 36c489c7..aee22866 100644
--- a/continuedev/src/continuedev/steps/input/nl_multiselect.py
+++ b/continuedev/src/continuedev/steps/input/nl_multiselect.py
@@ -23,6 +23,6 @@ class NLMultiselectStep(Step):
if first_try is not None:
return first_try
- gpt_parsed = sdk.models.gpt35.complete(
+ gpt_parsed = await sdk.models.gpt35.complete(
f"These are the available options are: [{', '.join(self.options)}]. The user requested {user_response}. This is the exact string from the options array that they selected:")
return extract_option(gpt_parsed) or self.options[0]
diff --git a/continuedev/src/continuedev/steps/main.py b/continuedev/src/continuedev/steps/main.py
index 0e42d8bf..3968c4a3 100644
--- a/continuedev/src/continuedev/steps/main.py
+++ b/continuedev/src/continuedev/steps/main.py
@@ -145,7 +145,7 @@ class FasterEditHighlightedCodeStep(Step):
for rif in rif_with_contents:
rif_dict[rif.filepath] = rif.contents
- completion = sdk.models.gpt35.complete(prompt)
+ completion = await sdk.models.gpt35.complete(prompt)
# Temporarily doing this to generate description.
self._prompt = prompt
diff --git a/continuedev/src/continuedev/steps/react.py b/continuedev/src/continuedev/steps/react.py
index d825d424..4d310fc8 100644
--- a/continuedev/src/continuedev/steps/react.py
+++ b/continuedev/src/continuedev/steps/react.py
@@ -27,7 +27,7 @@ class NLDecisionStep(Step):
Select the step which should be taken next to satisfy the user input. Say only the name of the selected step. You must choose one:""")
- resp = sdk.models.gpt35.complete(prompt).lower()
+ resp = (await sdk.models.gpt35.complete(prompt)).lower()
step_to_run = None
for step in self.steps:
diff --git a/continuedev/src/continuedev/steps/search_directory.py b/continuedev/src/continuedev/steps/search_directory.py
index 9f4594b9..d2966f46 100644
--- a/continuedev/src/continuedev/steps/search_directory.py
+++ b/continuedev/src/continuedev/steps/search_directory.py
@@ -41,7 +41,7 @@ class WriteRegexPatternStep(Step):
async def run(self, sdk: ContinueSDK):
# Ask the user for a regex pattern
- pattern = sdk.models.gpt35.complete(dedent(f"""\
+ pattern = await sdk.models.gpt35.complete(dedent(f"""\
This is the user request:
{self.user_request}