author     Nate Sesti <sestinj@gmail.com>    2023-06-15 18:53:57 -0700
committer  Nate Sesti <sestinj@gmail.com>    2023-06-15 18:53:57 -0700
commit     95c320a3462548d5f39b2765f0501cc5e2f14b99
tree       e3462064f9253ad0585a6ee3cf2de01ea9968e2f
parent     b031c4bbaef986a48a66080d6b786f05abc0a793
convert all openai calls to async
14 files changed, 45 insertions, 46 deletions
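Every change below follows the same mechanical pattern: a blocking `openai.*.create(...)` call becomes an awaited `openai.*.acreate(...)` inside an `async def`, and each call site gains an `await`. A minimal sketch of that core pattern, assuming the pre-1.0 `openai` Python client used here (the model name and prompt are illustrative, not taken from the diff):

```python
import asyncio

import openai  # pre-1.0 client; reads OPENAI_API_KEY from the environment


async def complete(prompt: str) -> str:
    # Before this commit the equivalent call was openai.ChatCompletion.create(...),
    # which blocks the event loop for the full round trip to the API.
    # acreate is the async variant: the loop stays free while the request
    # is in flight, so other steps can run concurrently.
    resp = await openai.ChatCompletion.acreate(
        model="gpt-3.5-turbo",  # illustrative; the repo passes the model via **args
        messages=[{"role": "user", "content": prompt}],
    )
    return resp.choices[0].message.content


if __name__ == "__main__":
    print(asyncio.run(complete("Say hello")))
```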
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 2986b2c4..108eedf1 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -1,5 +1,5 @@
 from abc import ABC
-from typing import Any, Dict, Generator, List, Union
+from typing import Any, Coroutine, Dict, Generator, List, Union
 
 from ...core.main import ChatMessage
 from ...models.main import AbstractModel
@@ -9,17 +9,14 @@ from pydantic import BaseModel
 class LLM(ABC):
     system_message: Union[str, None] = None
 
-    def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs):
+    async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
         """Return the completion of the text with the given temperature."""
-        raise
+        raise NotImplementedError
 
     def stream_chat(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         """Yield a stream of chat messages."""
         raise NotImplementedError
 
-    def __call__(self, prompt: str, **kwargs):
-        return self.complete(prompt, **kwargs)
-
     def with_system_message(self, system_message: Union[str, None]):
         """Return a new model with the given system message."""
         raise NotImplementedError
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 180ea5f0..bc108129 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -1,7 +1,7 @@
 import asyncio
 from functools import cached_property
 import time
-from typing import Any, Dict, Generator, List, Union
+from typing import Any, Coroutine, Dict, Generator, List, Union
 from ...core.main import ChatMessage
 import openai
 import aiohttp
@@ -107,7 +107,7 @@
         for chunk in generator:
             yield chunk.choices[0].text
 
-    def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> str:
+    async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
         t1 = time.time()
 
         self.completion_count += 1
@@ -115,12 +115,12 @@
             "frequency_penalty": 0, "presence_penalty": 0, "stream": False} | kwargs
 
         if args["model"] in CHAT_MODELS:
-            resp = openai.ChatCompletion.create(
+            resp = await openai.ChatCompletion.acreate(
                 messages=self.compile_chat_messages(with_history, prompt),
                 **args,
             ).choices[0].message.content
         else:
-            resp = openai.Completion.create(
+            resp = await openai.Completion.acreate(
                 prompt=prompt,
                 **args,
             ).choices[0].text
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index 4ff57101..b2948f9a 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -1,8 +1,9 @@
 from functools import cached_property
 import json
-from typing import Any, Dict, Generator, List, Literal, Union
+from typing import Any, Coroutine, Dict, Generator, List, Literal, Union
 import requests
 import tiktoken
+import aiohttp
 
 from ...core.main import ChatMessage
 from ..llm import LLM
@@ -39,16 +40,6 @@ class ProxyServer(LLM):
     def count_tokens(self, text: str):
         return len(self.__encoding_for_model.encode(text, disallowed_special=()))
 
-    def stream_chat(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
-        resp = requests.post(f"{SERVER_URL}/stream_complete", json={
-            "chat_history": self.compile_chat_messages(with_history, prompt),
-            "model": self.default_model,
-            "unique_id": self.unique_id,
-        }, stream=True)
-        for line in resp.iter_lines():
-            if line:
-                yield line.decode("utf-8")
-
     def __prune_chat_history(self, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int):
         tokens = tokens_for_completion
         for i in range(len(chat_history) - 1, -1, -1):
@@ -75,11 +66,22 @@ class ProxyServer(LLM):
         return history
 
-    def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> str:
+    async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
+        async with aiohttp.ClientSession() as session:
+            async with session.post(f"{SERVER_URL}/complete", json={
+                "chat_history": self.compile_chat_messages(with_history, prompt),
+                "model": self.default_model,
+                "unique_id": self.unique_id,
+            }) as resp:
+                return json.loads(await resp.text())
-        resp = requests.post(f"{SERVER_URL}/complete", json={
-            "chat_history": self.compile_chat_messages(with_history, prompt),
-            "model": self.default_model,
-            "unique_id": self.unique_id,
-        })
-        return json.loads(resp.text)
+
+    async def stream_chat(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        async with aiohttp.ClientSession() as session:
+            async with session.post(f"{SERVER_URL}/stream_complete", json={
+                "chat_history": self.compile_chat_messages(with_history, prompt),
+                "model": self.default_model,
+                "unique_id": self.unique_id,
+            }) as resp:
+                async for line in resp.content:
+                    if line:
+                        yield line.decode("utf-8")
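The `ProxyServer` rewrite above also swaps blocking `requests` calls for `aiohttp` and turns `stream_chat` into an async generator. Below is a self-contained sketch of that pattern: the endpoints and payload shape mirror the diff, while `SERVER_URL`'s value and the standalone functions are assumptions for illustration:

```python
import json
from typing import Any, AsyncGenerator, Dict, List

import aiohttp

SERVER_URL = "http://localhost:8000"  # assumption: a local proxy server for the sketch


async def complete(chat_history: List[Dict[str, Any]], model: str, unique_id: str) -> str:
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{SERVER_URL}/complete", json={
            "chat_history": chat_history,
            "model": model,
            "unique_id": unique_id,
        }) as resp:
            # The proxy returns a JSON-encoded string, hence the json.loads.
            return json.loads(await resp.text())


async def stream_chat(chat_history: List[Dict[str, Any]], model: str, unique_id: str) -> AsyncGenerator[str, None]:
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{SERVER_URL}/stream_complete", json={
            "chat_history": chat_history,
            "model": model,
            "unique_id": unique_id,
        }) as resp:
            # resp.content is a StreamReader: iterating it yields each line
            # as it arrives, replacing the old blocking
            # requests.post(..., stream=True) / iter_lines() loop.
            async for line in resp.content:
                if line:
                    yield line.decode("utf-8")
```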
diff --git a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py b/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py
index 096b41c6..3fba1112 100644
--- a/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py
+++ b/continuedev/src/continuedev/recipes/CreatePipelineRecipe/steps.py
@@ -29,8 +29,8 @@ class SetupPipelineStep(Step):
     async def run(self, sdk: ContinueSDK):
         sdk.context.set("api_description", self.api_description)
 
-        source_name = sdk.models.gpt35.complete(
-            f"Write a snake_case name for the data source described by {self.api_description}: ").strip()
+        source_name = (await sdk.models.gpt35.complete(
+            f"Write a snake_case name for the data source described by {self.api_description}: ")).strip()
         filename = f'{source_name}.py'
 
         # running commands to get started when creating a new dlt pipeline
@@ -91,7 +91,7 @@ class ValidatePipelineStep(Step):
         if "Traceback" in output or "SyntaxError" in output:
             output = "Traceback" + output.split("Traceback")[-1]
             file_content = await sdk.ide.readFile(os.path.join(workspace_dir, filename))
-            suggestion = sdk.models.gpt35.complete(dedent(f"""\
+            suggestion = await sdk.models.gpt35.complete(dedent(f"""\
                 ```python
                 {file_content}
                 ```
@@ -103,7 +103,7 @@
                 This is a brief summary of the error followed by a suggestion on how it can be fixed by editing the resource function:"""))
 
-            api_documentation_url = sdk.models.gpt35.complete(dedent(f"""\
+            api_documentation_url = await sdk.models.gpt35.complete(dedent(f"""\
                 The API I am trying to call is the '{sdk.context.get('api_description')}'. I tried calling it in the @resource function like this:
                 ```python
                 {file_content}
                 ```
@@ -159,7 +159,7 @@ class RunQueryStep(Step):
         output = await sdk.run('.env/bin/python3 query.py', name="Run test query", description="Running `.env/bin/python3 query.py` to test that the data was loaded into DuckDB as expected", handle_error=False)
 
         if "Traceback" in output or "SyntaxError" in output:
-            suggestion = sdk.models.gpt35.complete(dedent(f"""\
+            suggestion = await sdk.models.gpt35.complete(dedent(f"""\
                 ```python
                 {await sdk.ide.readFile(os.path.join(sdk.ide.workspace_directory, "query.py"))}
                 ```
diff --git a/continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py b/continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py
index 6db9fd4b..df414e2e 100644
--- a/continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py
+++ b/continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py
@@ -82,7 +82,7 @@ class LoadDataStep(Step):
             docs = f.read()
 
         output = "Traceback" + output.split("Traceback")[-1]
-        suggestion = sdk.models.default.complete(dedent(f"""\
+        suggestion = await sdk.models.default.complete(dedent(f"""\
             When trying to load data into BigQuery, the following error occurred:
 
             ```ascii
diff --git a/continuedev/src/continuedev/recipes/WritePytestsRecipe/main.py b/continuedev/src/continuedev/recipes/WritePytestsRecipe/main.py
index 688f44c3..6e1244b3 100644
--- a/continuedev/src/continuedev/recipes/WritePytestsRecipe/main.py
+++ b/continuedev/src/continuedev/recipes/WritePytestsRecipe/main.py
@@ -41,7 +41,7 @@ class WritePytestsRecipe(Step):
             "{self.user_input}"
 
             Here is a complete set of pytest unit tests:""")
-        tests = sdk.models.gpt35.complete(prompt)
+        tests = await sdk.models.gpt35.complete(prompt)
 
         await sdk.apply_filesystem_edit(AddFile(filepath=path, content=tests))
diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py
index 7cfe7e0c..499d127f 100644
--- a/continuedev/src/continuedev/steps/chat.py
+++ b/continuedev/src/continuedev/steps/chat.py
@@ -11,9 +11,9 @@ class SimpleChatStep(Step):
     async def run(self, sdk: ContinueSDK):
         self.description = f"## {self.user_input}\n\n"
-        for chunk in sdk.models.default.stream_chat(self.user_input, with_history=await sdk.get_chat_context()):
+        async for chunk in sdk.models.default.stream_chat(self.user_input, with_history=await sdk.get_chat_context()):
             self.description += chunk
             await sdk.update_ui()
 
-        self.name = sdk.models.gpt35.complete(
-            f"Write a short title for the following chat message: {self.description}").strip()
+        self.name = (await sdk.models.gpt35.complete(
+            f"Write a short title for the following chat message: {self.description}")).strip()
diff --git a/continuedev/src/continuedev/steps/chroma.py b/continuedev/src/continuedev/steps/chroma.py
index 058455b2..9d085981 100644
--- a/continuedev/src/continuedev/steps/chroma.py
+++ b/continuedev/src/continuedev/steps/chroma.py
@@ -56,7 +56,7 @@ class AnswerQuestionChroma(Step):
             Here is the answer:""")
 
-        answer = sdk.models.gpt35.complete(prompt)
+        answer = await sdk.models.gpt35.complete(prompt)
 
         # Make paths relative to the workspace directory
         answer = answer.replace(await sdk.ide.getWorkspaceDirectory(), "")
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index aee5bc1d..eb45d1d3 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -80,7 +80,7 @@ class ShellCommandsStep(Step):
         for cmd in self.cmds:
             output = await sdk.ide.runCommand(cmd)
             if self.handle_error and output is not None and output_contains_error(output):
-                suggestion = sdk.models.gpt35.complete(dedent(f"""\
+                suggestion = await sdk.models.gpt35.complete(dedent(f"""\
                     While running the command `{cmd}`, the following error occurred:
 
                     ```ascii
@@ -183,7 +183,7 @@ class DefaultModelEditCodeStep(Step):
         prompt = self._prompt.format(
             code=rif.contents, user_request=self.user_input, file_prefix=segs[0], file_suffix=segs[1])
 
-        completion = str(sdk.models.default.complete(prompt, with_history=await sdk.get_chat_context()))
+        completion = str(await sdk.models.default.complete(prompt, with_history=await sdk.get_chat_context()))
 
         eot_token = "<|endoftext|>"
         completion = completion.removesuffix(eot_token)
diff --git a/continuedev/src/continuedev/steps/draft/migration.py b/continuedev/src/continuedev/steps/draft/migration.py
index 7c4b7eb5..f3b36b5e 100644
--- a/continuedev/src/continuedev/steps/draft/migration.py
+++ b/continuedev/src/continuedev/steps/draft/migration.py
@@ -13,7 +13,7 @@ class MigrationStep(Step):
         recent_edits = await sdk.ide.get_recent_edits(self.edited_file)
         recent_edits_string = "\n\n".join(
             map(lambda x: x.to_string(), recent_edits))
-        description = sdk.models.gpt35.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n")
+        description = await sdk.models.gpt35.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n")
         await sdk.run([
             "cd libs",
             "poetry run alembic revision --autogenerate -m " + description,
diff --git a/continuedev/src/continuedev/steps/input/nl_multiselect.py b/continuedev/src/continuedev/steps/input/nl_multiselect.py
index 36c489c7..aee22866 100644
--- a/continuedev/src/continuedev/steps/input/nl_multiselect.py
+++ b/continuedev/src/continuedev/steps/input/nl_multiselect.py
@@ -23,6 +23,6 @@ class NLMultiselectStep(Step):
         if first_try is not None:
             return first_try
 
-        gpt_parsed = sdk.models.gpt35.complete(
+        gpt_parsed = await sdk.models.gpt35.complete(
             f"These are the available options are: [{', '.join(self.options)}]. The user requested {user_response}. This is the exact string from the options array that they selected:")
         return extract_option(gpt_parsed) or self.options[0]
diff --git a/continuedev/src/continuedev/steps/main.py b/continuedev/src/continuedev/steps/main.py
index 0e42d8bf..3968c4a3 100644
--- a/continuedev/src/continuedev/steps/main.py
+++ b/continuedev/src/continuedev/steps/main.py
@@ -145,7 +145,7 @@ class FasterEditHighlightedCodeStep(Step):
         for rif in rif_with_contents:
             rif_dict[rif.filepath] = rif.contents
 
-        completion = sdk.models.gpt35.complete(prompt)
+        completion = await sdk.models.gpt35.complete(prompt)
 
         # Temporarily doing this to generate description.
         self._prompt = prompt
diff --git a/continuedev/src/continuedev/steps/react.py b/continuedev/src/continuedev/steps/react.py
index d825d424..4d310fc8 100644
--- a/continuedev/src/continuedev/steps/react.py
+++ b/continuedev/src/continuedev/steps/react.py
@@ -27,7 +27,7 @@ class NLDecisionStep(Step):
             Select the step which should be taken next to satisfy the user input. Say only the name of the selected step.
             You must choose one:""")
 
-        resp = sdk.models.gpt35.complete(prompt).lower()
+        resp = (await sdk.models.gpt35.complete(prompt)).lower()
 
         step_to_run = None
         for step in self.steps:
diff --git a/continuedev/src/continuedev/steps/search_directory.py b/continuedev/src/continuedev/steps/search_directory.py
index 9f4594b9..d2966f46 100644
--- a/continuedev/src/continuedev/steps/search_directory.py
+++ b/continuedev/src/continuedev/steps/search_directory.py
@@ -41,7 +41,7 @@ class WriteRegexPatternStep(Step):
     async def run(self, sdk: ContinueSDK):
         # Ask the user for a regex pattern
-        pattern = sdk.models.gpt35.complete(dedent(f"""\
+        pattern = await sdk.models.gpt35.complete(dedent(f"""\
            This is the user request:
 
            {self.user_request}
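One call-site detail the diff handles in several places (`steps/chat.py`, `steps/react.py`, `CreatePipelineRecipe/steps.py`): once `complete` is a coroutine, a chained method like `.strip()` or `.lower()` must apply to the awaited result, so the `await` expression gets parenthesized. A small runnable sketch of why, using a hypothetical stand-in `Model` class:

```python
import asyncio


class Model:
    """Hypothetical stand-in for sdk.models.gpt35 in the diff."""

    async def complete(self, prompt: str) -> str:
        await asyncio.sleep(0)  # pretend to call an API
        return "  a title  "


async def main() -> None:
    model = Model()

    # WRONG: .strip() would be looked up on the coroutine object itself and
    # raise AttributeError before anything is awaited:
    #   name = await model.complete("...").strip()

    # Right: parenthesize so the await resolves first, as the diff does:
    name = (await model.complete("Write a short title")).strip()
    print(repr(name))  # 'a title'


asyncio.run(main())
```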