diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-08-17 01:23:13 -0700 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-08-17 01:23:13 -0700 |
commit | 9e8925fa8a68acff46c76ac937087e61802c7861 (patch) | |
tree | 20840ec7fa726a4a5677211ac133bde51c25d41b /continuedev/src | |
parent | a8467be404cf9ae7391b7d23dec096258618ea4b (diff) | |
download | sncontinue-9e8925fa8a68acff46c76ac937087e61802c7861.tar.gz sncontinue-9e8925fa8a68acff46c76ac937087e61802c7861.tar.bz2 sncontinue-9e8925fa8a68acff46c76ac937087e61802c7861.zip |
dynamic context_provider abstract class
Diffstat (limited to 'continuedev/src')
-rw-r--r-- | continuedev/src/continuedev/plugins/context_providers/dynamic.py | 79 | ||||
-rw-r--r-- | continuedev/src/continuedev/plugins/steps/chat.py | 188 |
2 files changed, 190 insertions, 77 deletions
diff --git a/continuedev/src/continuedev/plugins/context_providers/dynamic.py b/continuedev/src/continuedev/plugins/context_providers/dynamic.py new file mode 100644 index 00000000..e6ea0e88 --- /dev/null +++ b/continuedev/src/continuedev/plugins/context_providers/dynamic.py @@ -0,0 +1,79 @@ +from abc import ABC, abstractmethod +from typing import List + +from ...core.context import ContextProvider +from ...core.main import ContextItem, ContextItemDescription, ContextItemId +from ...libs.util.create_async_task import create_async_task +from .util import remove_meilisearch_disallowed_chars + + +class DynamicProvider(ContextProvider, ABC): + """ + A title representing the provider + """ + + title: str + """ + A name representing the provider. Probably use capitalized version of title + """ + name: str + """ + A description for the provider + """ + description: str + + workspace_dir: str = None + + @property + def BASE_CONTEXT_ITEM(self): + return ContextItem( + content="", + description=ContextItemDescription( + name=self.name, + description=self.description, + id=ContextItemId(provider_title=self.title, item_id=self.title), + ), + ) + + async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]: + self.workspace_dir = workspace_dir + create_async_task(self.setup()) + return [self.BASE_CONTEXT_ITEM] + + async def get_item(self, id: ContextItemId, query: str) -> ContextItem: + if id.item_id != self.title: + raise Exception("Invalid item id") + + query = query.lstrip(self.title + " ") + results = await self.get_content(query) + + ctx_item = self.BASE_CONTEXT_ITEM.copy() + ctx_item.content = results + ctx_item.description.name = f"{self.name}: '{query}'" + ctx_item.description.id.item_id = remove_meilisearch_disallowed_chars(query) + return ctx_item + + @abstractmethod + async def get_content(self, query: str) -> str: + """Retrieve the content given the query + (e.g. search the codebase, return search results)""" + raise NotImplementedError + + @abstractmethod + async def setup(self): + """Run any setup needed (e.g. indexing the codebase)""" + raise NotImplementedError + + +""" +class ExampleDynamicProvider(DynamicProvider): + title = "example" + name = "Example" + description = "Example description" + + async def get_content(self, query: str) -> str: + return f"Example content for '{query}'" + + async def setup(self): + print("Example setup") +""" diff --git a/continuedev/src/continuedev/plugins/steps/chat.py b/continuedev/src/continuedev/plugins/steps/chat.py index ad19434f..7e674272 100644 --- a/continuedev/src/continuedev/plugins/steps/chat.py +++ b/continuedev/src/continuedev/plugins/steps/chat.py @@ -1,22 +1,21 @@ import json +import os from textwrap import dedent from typing import Any, Coroutine, List +import openai +from directory_tree import display_tree +from dotenv import load_dotenv from pydantic import Field -from ...libs.util.strings import remove_quotes_and_escapes -from .main import EditHighlightedCodeStep -from .core.core import MessageStep -from ...core.main import FunctionCall, Models -from ...core.main import ChatMessage, Step, step_to_json_schema +from ...core.main import ChatMessage, FunctionCall, Models, Step, step_to_json_schema from ...core.sdk import ContinueSDK -from ...libs.llm.openai import OpenAI from ...libs.llm.maybe_proxy_openai import MaybeProxyOpenAI +from ...libs.llm.openai import OpenAI +from ...libs.util.strings import remove_quotes_and_escapes from ...libs.util.telemetry import posthog_logger -import openai -import os -from dotenv import load_dotenv -from directory_tree import display_tree +from .core.core import MessageStep +from .main import EditHighlightedCodeStep load_dotenv() OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") @@ -33,10 +32,28 @@ class SimpleChatStep(Step): async def run(self, sdk: ContinueSDK): # Check if proxy server API key - if isinstance(sdk.models.default, MaybeProxyOpenAI) and (sdk.models.default.api_key is None or sdk.models.default.api_key.strip() == "") and len(list(filter(lambda x: not x.step.hide, sdk.history.timeline))) >= 10 and len(list(filter(lambda x: x.step.name == FREE_USAGE_STEP_NAME, sdk.history.timeline))) == 0: - await sdk.run_step(MessageStep( - name=FREE_USAGE_STEP_NAME, - message=dedent("""\ + if ( + isinstance(sdk.models.default, MaybeProxyOpenAI) + and ( + sdk.models.default.api_key is None + or sdk.models.default.api_key.strip() == "" + ) + and len(list(filter(lambda x: not x.step.hide, sdk.history.timeline))) >= 10 + and len( + list( + filter( + lambda x: x.step.name == FREE_USAGE_STEP_NAME, + sdk.history.timeline, + ) + ) + ) + == 0 + ): + await sdk.run_step( + MessageStep( + name=FREE_USAGE_STEP_NAME, + message=dedent( + """\ To make it easier to use Continue, you're getting limited free usage. When you have the chance, please enter your own OpenAI key in `~/.continue/config.py`. You can open the file by using the '/config' slash command in the text box below. Here's an example of how to edit the file: @@ -52,36 +69,39 @@ class SimpleChatStep(Step): ``` You can also learn more about customizations [here](https://continue.dev/docs/customization). - """), - )) + """ + ), + ) + ) messages = self.messages or await sdk.get_chat_context() generator = sdk.models.default.stream_chat( - messages, temperature=sdk.config.temperature) + messages, temperature=sdk.config.temperature + ) - posthog_logger.capture_event("model_use", { - "model": sdk.models.default.name - }) + posthog_logger.capture_event("model_use", {"model": sdk.models.default.name}) async for chunk in generator: if sdk.current_step_was_deleted(): # So that the message doesn't disappear self.hide = False + await sdk.update_ui() break if "content" in chunk: self.description += chunk["content"] await sdk.update_ui() - self.name = remove_quotes_and_escapes(await sdk.models.medium.complete( - f"Write a short title for the following chat message: {self.description}")) + self.name = remove_quotes_and_escapes( + await sdk.models.medium.complete( + f"Write a short title for the following chat message: {self.description}" + ) + ) - self.chat_context.append(ChatMessage( - role="assistant", - content=self.description, - summary=self.name - )) + self.chat_context.append( + ChatMessage(role="assistant", content=self.description, summary=self.name) + ) # TODO: Never actually closing. await generator.aclose() @@ -93,13 +113,17 @@ class AddFileStep(Step): filename: str file_contents: str - async def describe(self, models: Models) -> Coroutine[Any, Any, Coroutine[str, None, None]]: + async def describe( + self, models: Models + ) -> Coroutine[Any, Any, Coroutine[str, None, None]]: return f"Added a file named `{self.filename}` to the workspace." async def run(self, sdk: ContinueSDK): await sdk.add_file(self.filename, self.file_contents) - await sdk.ide.setFileOpen(os.path.join(sdk.ide.workspace_directory, self.filename)) + await sdk.ide.setFileOpen( + os.path.join(sdk.ide.workspace_directory, self.filename) + ) class DeleteFileStep(Step): @@ -107,7 +131,9 @@ class DeleteFileStep(Step): description = "Delete a file from the workspace." filename: str - async def describe(self, models: Models) -> Coroutine[Any, Any, Coroutine[str, None, None]]: + async def describe( + self, models: Models + ) -> Coroutine[Any, Any, Coroutine[str, None, None]]: return f"Deleted a file named `{self.filename}` from the workspace." async def run(self, sdk: ContinueSDK): @@ -119,7 +145,9 @@ class AddDirectoryStep(Step): description = "Add a directory to the workspace." directory_name: str - async def describe(self, models: Models) -> Coroutine[Any, Any, Coroutine[str, None, None]]: + async def describe( + self, models: Models + ) -> Coroutine[Any, Any, Coroutine[str, None, None]]: return f"Added a directory named `{self.directory_name}` to the workspace." async def run(self, sdk: ContinueSDK): @@ -142,20 +170,22 @@ class ViewDirectoryTreeStep(Step): name: str = "View Directory Tree" description: str = "View the directory tree to learn which folder and files exist. You should always do this before adding new files." - async def describe(self, models: Models) -> Coroutine[Any, Any, Coroutine[str, None, None]]: - return f"Viewed the directory tree." + async def describe( + self, models: Models + ) -> Coroutine[Any, Any, Coroutine[str, None, None]]: + return "Viewed the directory tree." async def run(self, sdk: ContinueSDK): - self.description = f"```\n{display_tree(sdk.ide.workspace_directory, True, max_depth=2)}\n```" + self.description = ( + f"```\n{display_tree(sdk.ide.workspace_directory, True, max_depth=2)}\n```" + ) class EditFileStep(Step): name: str = "Edit File" description: str = "Edit a file in the workspace that is not currently open." - filename: str = Field( - ..., description="The name of the file to edit.") - instructions: str = Field( - ..., description="The instructions to edit the file.") + filename: str = Field(..., description="The name of the file to edit.") + instructions: str = Field(..., description="The instructions to edit the file.") hide: bool = True async def run(self, sdk: ContinueSDK): @@ -164,11 +194,15 @@ class EditFileStep(Step): class ChatWithFunctions(Step): user_input: str - functions: List[Step] = [AddFileStep(filename="", file_contents=""), - EditFileStep(filename="", instructions=""), - EditHighlightedCodeStep(user_input=""), - ViewDirectoryTreeStep(), AddDirectoryStep(directory_name=""), - DeleteFileStep(filename=""), RunTerminalCommandStep(command="")] + functions: List[Step] = [ + AddFileStep(filename="", file_contents=""), + EditFileStep(filename="", instructions=""), + EditHighlightedCodeStep(user_input=""), + ViewDirectoryTreeStep(), + AddDirectoryStep(directory_name=""), + DeleteFileStep(filename=""), + RunTerminalCommandStep(command=""), + ] name: str = "Input" manage_own_chat_context: bool = True description: str = "" @@ -178,18 +212,15 @@ class ChatWithFunctions(Step): await sdk.update_ui() step_name_step_class_map = { - step.name.replace(" ", ""): step.__class__ for step in self.functions} + step.name.replace(" ", ""): step.__class__ for step in self.functions + } - functions = [step_to_json_schema( - function) for function in self.functions] + functions = [step_to_json_schema(function) for function in self.functions] - self.chat_context.append(ChatMessage( - role="user", - content=self.user_input, - summary=self.user_input - )) + self.chat_context.append( + ChatMessage(role="user", content=self.user_input, summary=self.user_input) + ) - last_function_called_index_in_history = None last_function_called_name = None last_function_called_params = None while True: @@ -202,7 +233,9 @@ class ChatWithFunctions(Step): gpt350613 = OpenAI(model="gpt-3.5-turbo-0613") await sdk.start_model(gpt350613) - async for msg_chunk in gpt350613.stream_chat(await sdk.get_chat_context(), functions=functions): + async for msg_chunk in gpt350613.stream_chat( + await sdk.get_chat_context(), functions=functions + ): if sdk.current_step_was_deleted(): return @@ -214,8 +247,7 @@ class ChatWithFunctions(Step): # sdk.history.timeline[last_function_called_index_in_history].step.description = msg_content if msg_step is None: msg_step = MessageStep( - name="Chat", - message=msg_chunk["content"] + name="Chat", message=msg_chunk["content"] ) await sdk.run_step(msg_step) else: @@ -230,14 +262,13 @@ class ChatWithFunctions(Step): func_name += msg_chunk["function_call"]["name"] if not was_function_called: - self.chat_context.append(ChatMessage( - role="assistant", - content=msg_content, - summary=msg_content - )) + self.chat_context.append( + ChatMessage( + role="assistant", content=msg_content, summary=msg_content + ) + ) break else: - last_function_called = func_name if func_name == "python" and "python" not in step_name_step_class_map: # GPT must be fine-tuned to believe this exists, but it doesn't always func_name = "EditHighlightedCodeStep" @@ -263,21 +294,20 @@ class ChatWithFunctions(Step): try: fn_call_params = json.loads(func_args) except json.JSONDecodeError: - raise Exception( - "The model returned invalid JSON. Please try again") - self.chat_context.append(ChatMessage( - role="assistant", - content=None, - function_call=FunctionCall( - name=func_name, - arguments=func_args - ), - summary=f"Called function {func_name}" - )) - last_function_called_index_in_history = sdk.history.current_index + 1 + raise Exception("The model returned invalid JSON. Please try again") + self.chat_context.append( + ChatMessage( + role="assistant", + content=None, + function_call=FunctionCall(name=func_name, arguments=func_args), + summary=f"Called function {func_name}", + ) + ) + sdk.history.current_index + 1 if func_name not in step_name_step_class_map: raise Exception( - f"The model tried to call a function ({func_name}) that does not exist. Please try again.") + f"The model tried to call a function ({func_name}) that does not exist. Please try again." + ) # if func_name == "AddFileStep": # step_to_run.hide = True @@ -292,9 +322,13 @@ class ChatWithFunctions(Step): elif func_name == "EditFile": fn_call_params["instructions"] = self.user_input - step_to_run = step_name_step_class_map[func_name]( - **fn_call_params) - if last_function_called_name is not None and last_function_called_name == func_name and last_function_called_params is not None and last_function_called_params == fn_call_params: + step_to_run = step_name_step_class_map[func_name](**fn_call_params) + if ( + last_function_called_name is not None + and last_function_called_name == func_name + and last_function_called_params is not None + and last_function_called_params == fn_call_params + ): # If it's calling the same function more than once in a row, it's probably looping and confused return last_function_called_name = func_name |