From c6a12550ffca1ffe35630e7aa9af6913ddbe0675 Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Sat, 29 Jul 2023 22:40:04 -0700 Subject: feat: :sparkles: EmbeddingContextProvider --- continuedev/src/continuedev/core/autopilot.py | 24 +++---- continuedev/src/continuedev/core/config.py | 11 ++- continuedev/src/continuedev/core/policy.py | 80 ---------------------- continuedev/src/continuedev/libs/chroma/query.py | 2 +- continuedev/src/continuedev/libs/chroma/update.py | 2 +- .../libs/constants/default_config.py.txt | 7 +- continuedev/src/continuedev/libs/llm/ggml.py | 2 +- continuedev/src/continuedev/libs/llm/openai.py | 2 +- .../src/continuedev/libs/util/calculate_diff.py | 2 +- continuedev/src/continuedev/libs/util/strings.py | 2 +- .../src/continuedev/models/generate_json_schema.py | 2 +- .../plugins/context_providers/embeddings.py | 79 +++++++++++++++++++++ .../continuedev/plugins/context_providers/file.py | 3 +- .../plugins/context_providers/google.py | 4 +- .../continuedev/plugins/context_providers/util.py | 5 ++ .../src/continuedev/plugins/policies/default.py | 79 +++++++++++++++++++++ .../src/continuedev/plugins/steps/core/core.py | 6 +- .../continuedev/plugins/steps/search_directory.py | 4 +- continuedev/src/continuedev/server/gui.py | 2 +- continuedev/src/continuedev/server/ide.py | 4 +- .../src/continuedev/server/session_manager.py | 48 +++++++------ extension/src/suggestions.ts | 2 +- 22 files changed, 234 insertions(+), 138 deletions(-) delete mode 100644 continuedev/src/continuedev/core/policy.py create mode 100644 continuedev/src/continuedev/plugins/context_providers/embeddings.py create mode 100644 continuedev/src/continuedev/plugins/context_providers/util.py create mode 100644 continuedev/src/continuedev/plugins/policies/default.py diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 3f25e64e..12339f9b 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -9,6 +9,7 @@ from ..models.filesystem import RangeInFileWithContents from ..models.filesystem_edit import FileEditWithFullContents from .observation import Observation, InternalErrorObservation from .context import ContextManager +from ..plugins.policies.default import DefaultPolicy from ..plugins.context_providers.file import FileContextProvider from ..plugins.context_providers.highlighted_code import HighlightedCodeContextProvider from ..server.ide_protocol import AbstractIdeProtocolServer @@ -47,8 +48,9 @@ def get_error_title(e: Exception) -> str: class Autopilot(ContinueBaseModel): - policy: Policy ide: AbstractIdeProtocolServer + + policy: Policy = DefaultPolicy() history: History = History.from_empty() context: Context = Context() full_state: Union[FullState, None] = None @@ -64,20 +66,18 @@ class Autopilot(ContinueBaseModel): _user_input_queue = AsyncSubscriptionQueue() _retry_queue = AsyncSubscriptionQueue() - @classmethod - async def create(cls, policy: Policy, ide: AbstractIdeProtocolServer, full_state: FullState) -> "Autopilot": - autopilot = cls(ide=ide, policy=policy) - autopilot.continue_sdk = await ContinueSDK.create(autopilot) + async def start(self): + self.continue_sdk = await ContinueSDK.create(self) + if override_policy := self.continue_sdk.config.policy_override: + self.policy = override_policy # Load documents into the search index - autopilot.context_manager = await ContextManager.create( - autopilot.continue_sdk.config.context_providers + [ - HighlightedCodeContextProvider(ide=ide), - FileContextProvider(workspace_dir=ide.workspace_directory) + self.context_manager = await ContextManager.create( + self.continue_sdk.config.context_providers + [ + HighlightedCodeContextProvider(ide=self.ide), + FileContextProvider(workspace_dir=self.ide.workspace_directory) ]) - await autopilot.context_manager.load_index() - - return autopilot + await self.context_manager.load_index() class Config: arbitrary_types_allowed = True diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index 9fbda824..fe0946cd 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -1,10 +1,8 @@ -import json -import os -from .main import Step -from .context import ContextProvider from pydantic import BaseModel, validator -from typing import List, Literal, Optional, Dict, Type, Union -import yaml +from typing import List, Literal, Optional, Dict, Type + +from .main import Policy, Step +from .context import ContextProvider class SlashCommand(BaseModel): @@ -51,6 +49,7 @@ class ContinueConfig(BaseModel): on_traceback: Optional[List[OnTracebackSteps]] = [] system_message: Optional[str] = None openai_server_info: Optional[OpenAIServerInfo] = None + policy_override: Optional[Policy] = None context_providers: List[ContextProvider] = [] diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py deleted file mode 100644 index d90177b5..00000000 --- a/continuedev/src/continuedev/core/policy.py +++ /dev/null @@ -1,80 +0,0 @@ -from textwrap import dedent -from typing import Union - -from ..plugins.steps.chat import SimpleChatStep -from ..plugins.steps.welcome import WelcomeStep -from .config import ContinueConfig -from ..plugins.steps.steps_on_startup import StepsOnStartupStep -from .main import Step, History, Policy -from .observation import UserInputObservation -from ..plugins.steps.core.core import MessageStep -from ..plugins.steps.custom_command import CustomCommandStep -from ..plugins.steps.main import EditHighlightedCodeStep - - -def parse_slash_command(inp: str, config: ContinueConfig) -> Union[None, Step]: - """ - Parses a slash command, returning the command name and the rest of the input. - """ - if inp.startswith("/"): - command_name = inp.split(" ")[0] - after_command = " ".join(inp.split(" ")[1:]) - - for slash_command in config.slash_commands: - if slash_command.name == command_name[1:]: - params = slash_command.params - params["user_input"] = after_command - try: - return slash_command.step(**params) - except TypeError as e: - raise Exception( - f"Incorrect params used for slash command '{command_name}': {e}") - return None - - -def parse_custom_command(inp: str, config: ContinueConfig) -> Union[None, Step]: - command_name = inp.split(" ")[0] - after_command = " ".join(inp.split(" ")[1:]) - for custom_cmd in config.custom_commands: - if custom_cmd.name == command_name[1:]: - slash_command = parse_slash_command(custom_cmd.prompt, config) - if slash_command is not None: - return slash_command - return CustomCommandStep(name=custom_cmd.name, description=custom_cmd.description, prompt=custom_cmd.prompt, user_input=after_command, slash_command=command_name) - return None - - -class DefaultPolicy(Policy): - ran_code_last: bool = False - - def next(self, config: ContinueConfig, history: History) -> Step: - # At the very start, run initial Steps spcecified in the config - if history.get_current() is None: - return ( - MessageStep(name="Welcome to Continue", message=dedent("""\ - - Highlight code section and ask a question or give instructions - - Use `cmd+m` (Mac) / `ctrl+m` (Windows) to open Continue - - Use `/help` to ask questions about how to use Continue""")) >> - WelcomeStep() >> - # CreateCodebaseIndexChroma() >> - StepsOnStartupStep()) - - observation = history.get_current().observation - if observation is not None and isinstance(observation, UserInputObservation): - # This could be defined with ObservationTypePolicy. Ergonomics not right though. - user_input = observation.user_input - - slash_command = parse_slash_command(user_input, config) - if slash_command is not None: - return slash_command - - custom_command = parse_custom_command(user_input, config) - if custom_command is not None: - return custom_command - - if user_input.startswith("/edit"): - return EditHighlightedCodeStep(user_input=user_input[5:]) - - return SimpleChatStep() - - return None diff --git a/continuedev/src/continuedev/libs/chroma/query.py b/continuedev/src/continuedev/libs/chroma/query.py index f09b813a..dba4874f 100644 --- a/continuedev/src/continuedev/libs/chroma/query.py +++ b/continuedev/src/continuedev/libs/chroma/query.py @@ -59,7 +59,7 @@ class ChromaIndexManager: except: logger.warning( f"ERROR (probably found special token): {doc.text}") - continue + continue # lol filename = doc.extra_info["filename"] chunks[filename] = len(text_chunks) for i, text in enumerate(text_chunks): diff --git a/continuedev/src/continuedev/libs/chroma/update.py b/continuedev/src/continuedev/libs/chroma/update.py index 23ed950f..d5326a06 100644 --- a/continuedev/src/continuedev/libs/chroma/update.py +++ b/continuedev/src/continuedev/libs/chroma/update.py @@ -23,7 +23,7 @@ def filter_ignored_files(files: List[str], root_dir: str): """Further filter files before indexing.""" for file in files: if file.endswith(tuple(FILE_TYPES_TO_IGNORE)) or file.startswith('.git') or file.startswith('archive'): - continue + continue # nice yield root_dir + "/" + file diff --git a/continuedev/src/continuedev/libs/constants/default_config.py.txt b/continuedev/src/continuedev/libs/constants/default_config.py.txt index 1a66c847..69fd357b 100644 --- a/continuedev/src/continuedev/libs/constants/default_config.py.txt +++ b/continuedev/src/continuedev/libs/constants/default_config.py.txt @@ -12,6 +12,7 @@ from continuedev.core.sdk import ContinueSDK from continuedev.core.config import CustomCommand, SlashCommand, ContinueConfig from continuedev.plugins.context_providers.github import GitHubIssuesContextProvider from continuedev.plugins.context_providers.google import GoogleContextProvider +from continuedev.plugins.policies.default import DefaultPolicy from continuedev.plugins.steps.open_config import OpenConfigStep from continuedev.plugins.steps.clear_history import ClearHistoryStep @@ -114,5 +115,9 @@ config = ContinueConfig( # GoogleContextProvider( # serper_api_key="" # ) - ] + ], + + # Policies hold the main logic that decides which Step to take next + # You can use them to design agents, or deeply customize Continue + policy=DefaultPolicy() ) diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py index 4889a556..21374359 100644 --- a/continuedev/src/continuedev/libs/llm/ggml.py +++ b/continuedev/src/continuedev/libs/llm/ggml.py @@ -64,7 +64,7 @@ class GGML(LLM): try: json_chunk = line[0].decode("utf-8") if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"): - continue + continue # hehe chunks = json_chunk.split("\n") for chunk in chunks: if chunk.strip() != "": diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index 654c7326..d1ca4ef9 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -62,7 +62,7 @@ class OpenAI(LLM): yield chunk.choices[0].delta.content completion += chunk.choices[0].delta.content else: - continue + continue # :) self.write_log(f"Completion: \n\n{completion}") else: diff --git a/continuedev/src/continuedev/libs/util/calculate_diff.py b/continuedev/src/continuedev/libs/util/calculate_diff.py index ff0a135f..3e82bab3 100644 --- a/continuedev/src/continuedev/libs/util/calculate_diff.py +++ b/continuedev/src/continuedev/libs/util/calculate_diff.py @@ -92,7 +92,7 @@ def calculate_diff2(filepath: str, original: str, updated: str) -> List[FileEdit tag, i1, i2, j1, j2 = s.get_opcodes()[edit_index] replacement = updated[j1:j2] if tag == "equal": - continue + continue # ;) elif tag == "delete": edits.append(FileEdit.from_deletion( filepath, Range.from_indices(original, i1, i2))) diff --git a/continuedev/src/continuedev/libs/util/strings.py b/continuedev/src/continuedev/libs/util/strings.py index f1fb8d0b..285c1e47 100644 --- a/continuedev/src/continuedev/libs/util/strings.py +++ b/continuedev/src/continuedev/libs/util/strings.py @@ -12,7 +12,7 @@ def dedent_and_get_common_whitespace(s: str) -> Tuple[str, str]: for i in range(1, len(lines)): # Empty lines are wildcards if lines[i].strip() == "": - continue + continue # hey that's us! # Iterate through the leading whitespace characters of the current line for j in range(0, len(lcp)): # If it doesn't have the same whitespace as lcp, then update lcp diff --git a/continuedev/src/continuedev/models/generate_json_schema.py b/continuedev/src/continuedev/models/generate_json_schema.py index 51869fdd..2166bc37 100644 --- a/continuedev/src/continuedev/models/generate_json_schema.py +++ b/continuedev/src/continuedev/models/generate_json_schema.py @@ -39,7 +39,7 @@ def main(): json = schema_json_of(model, indent=2, title=title) except Exception as e: print(f"Failed to generate json schema for {title}: {e}") - continue + continue # pun intended with open(f"{SCHEMA_DIR}/{title}.json", "w") as f: f.write(json) diff --git a/continuedev/src/continuedev/plugins/context_providers/embeddings.py b/continuedev/src/continuedev/plugins/context_providers/embeddings.py new file mode 100644 index 00000000..42d1f754 --- /dev/null +++ b/continuedev/src/continuedev/plugins/context_providers/embeddings.py @@ -0,0 +1,79 @@ +import os +from typing import List, Optional +import uuid +from pydantic import BaseModel + +from ...core.main import ContextItemId +from ...core.context import ContextProvider +from ...core.main import ContextItem, ContextItemDescription, ContextItemId +from ...libs.chroma.query import ChromaIndexManager +from .util import remove_meilisearch_disallowed_chars + + +class EmbeddingResult(BaseModel): + filename: str + content: str + + +class EmbeddingsProvider(ContextProvider): + title = "embed" + + workspace_directory: str + + EMBEDDINGS_CONTEXT_ITEM_ID = "embeddings" + + index_manager: Optional[ChromaIndexManager] = None + + class Config: + arbitrary_types_allowed = True + + @property + def index(self): + if self.index_manager is None: + self.index_manager = ChromaIndexManager(self.workspace_directory) + return self.index_manager + + @property + def BASE_CONTEXT_ITEM(self): + return ContextItem( + content="", + description=ContextItemDescription( + name="Embedding Search", + description="Enter a query to embedding search codebase", + id=ContextItemId( + provider_title=self.title, + item_id=self.EMBEDDINGS_CONTEXT_ITEM_ID + ) + ) + ) + + async def _get_query_results(self, query: str) -> str: + results = self.index.query_codebase_index(query) + + ret = [] + for node in results.source_nodes: + resource_name = list(node.node.relationships.values())[0] + filepath = resource_name[:resource_name.index("::")] + ret.append(EmbeddingResult( + filename=filepath, content=node.node.text)) + + return ret + + async def provide_context_items(self) -> List[ContextItem]: + self.index.create_codebase_index() # TODO Synchronous here is not ideal + + return [self.BASE_CONTEXT_ITEM] + + async def add_context_item(self, id: ContextItemId, query: str): + if not id.item_id == self.EMBEDDINGS_CONTEXT_ITEM_ID: + raise Exception("Invalid item id") + + results = await self._get_query_results(query) + + for i in range(len(results)): + result = results[i] + ctx_item = self.BASE_CONTEXT_ITEM.copy() + ctx_item.description.name = os.path.basename(result.filename) + ctx_item.content = f"{result.filename}\n```\n{result.content}\n```" + ctx_item.description.id.item_id = uuid.uuid4().hex + self.selected_items.append(ctx_item) diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py index 31c8e1d9..c85a87a2 100644 --- a/continuedev/src/continuedev/plugins/context_providers/file.py +++ b/continuedev/src/continuedev/plugins/context_providers/file.py @@ -3,6 +3,7 @@ import re from typing import List from ...core.main import ContextItem, ContextItemDescription, ContextItemId from ...core.context import ContextProvider +from .util import remove_meilisearch_disallowed_chars from fnmatch import fnmatch @@ -80,7 +81,7 @@ class FileContextProvider(ContextProvider): description=file, id=ContextItemId( provider_title=self.title, - item_id=re.sub(r'[^0-9a-zA-Z_-]', '', file) + item_id=remove_meilisearch_disallowed_chars(file) ) ) )) diff --git a/continuedev/src/continuedev/plugins/context_providers/google.py b/continuedev/src/continuedev/plugins/context_providers/google.py index 64954833..4ad7c4a1 100644 --- a/continuedev/src/continuedev/plugins/context_providers/google.py +++ b/continuedev/src/continuedev/plugins/context_providers/google.py @@ -2,6 +2,7 @@ import json from typing import List import aiohttp +from .util import remove_meilisearch_disallowed_chars from ...core.main import ContextItem, ContextItemDescription, ContextItemId from ...core.context import ContextProvider @@ -60,5 +61,6 @@ class GoogleContextProvider(ContextProvider): ctx_item = self.BASE_CONTEXT_ITEM.copy() ctx_item.content = content - ctx_item.description.id.item_id = query + ctx_item.description.id.item_id = remove_meilisearch_disallowed_chars( + query) return ctx_item diff --git a/continuedev/src/continuedev/plugins/context_providers/util.py b/continuedev/src/continuedev/plugins/context_providers/util.py new file mode 100644 index 00000000..da2e6b17 --- /dev/null +++ b/continuedev/src/continuedev/plugins/context_providers/util.py @@ -0,0 +1,5 @@ +import re + + +def remove_meilisearch_disallowed_chars(id: str) -> str: + return re.sub(r'[^0-9a-zA-Z_-]', '', id) diff --git a/continuedev/src/continuedev/plugins/policies/default.py b/continuedev/src/continuedev/plugins/policies/default.py new file mode 100644 index 00000000..f479b758 --- /dev/null +++ b/continuedev/src/continuedev/plugins/policies/default.py @@ -0,0 +1,79 @@ +from textwrap import dedent +from typing import Union + +from ..steps.chat import SimpleChatStep +from ..steps.welcome import WelcomeStep +from ...core.config import ContinueConfig +from ..steps.steps_on_startup import StepsOnStartupStep +from ...core.main import Step, History, Policy +from ...core.observation import UserInputObservation +from ..steps.core.core import MessageStep +from ..steps.custom_command import CustomCommandStep +from ..steps.main import EditHighlightedCodeStep + + +def parse_slash_command(inp: str, config: ContinueConfig) -> Union[None, Step]: + """ + Parses a slash command, returning the command name and the rest of the input. + """ + if inp.startswith("/"): + command_name = inp.split(" ")[0] + after_command = " ".join(inp.split(" ")[1:]) + + for slash_command in config.slash_commands: + if slash_command.name == command_name[1:]: + params = slash_command.params + params["user_input"] = after_command + try: + return slash_command.step(**params) + except TypeError as e: + raise Exception( + f"Incorrect params used for slash command '{command_name}': {e}") + return None + + +def parse_custom_command(inp: str, config: ContinueConfig) -> Union[None, Step]: + command_name = inp.split(" ")[0] + after_command = " ".join(inp.split(" ")[1:]) + for custom_cmd in config.custom_commands: + if custom_cmd.name == command_name[1:]: + slash_command = parse_slash_command(custom_cmd.prompt, config) + if slash_command is not None: + return slash_command + return CustomCommandStep(name=custom_cmd.name, description=custom_cmd.description, prompt=custom_cmd.prompt, user_input=after_command, slash_command=command_name) + return None + + +class DefaultPolicy(Policy): + ran_code_last: bool = False + + def next(self, config: ContinueConfig, history: History) -> Step: + # At the very start, run initial Steps spcecified in the config + if history.get_current() is None: + return ( + MessageStep(name="Welcome to Continue", message=dedent("""\ + - Highlight code section and ask a question or give instructions + - Use `cmd+m` (Mac) / `ctrl+m` (Windows) to open Continue + - Use `/help` to ask questions about how to use Continue""")) >> + WelcomeStep() >> + StepsOnStartupStep()) + + observation = history.get_current().observation + if observation is not None and isinstance(observation, UserInputObservation): + # This could be defined with ObservationTypePolicy. Ergonomics not right though. + user_input = observation.user_input + + slash_command = parse_slash_command(user_input, config) + if slash_command is not None: + return slash_command + + custom_command = parse_custom_command(user_input, config) + if custom_command is not None: + return custom_command + + if user_input.startswith("/edit"): + return EditHighlightedCodeStep(user_input=user_input[5:]) + + return SimpleChatStep() + + return None diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py index c80cecc3..de7cf3ac 100644 --- a/continuedev/src/continuedev/plugins/steps/core/core.py +++ b/continuedev/src/continuedev/plugins/steps/core/core.py @@ -525,7 +525,7 @@ Please output the code to be inserted at the cursor in order to fulfill the user # Accumulate lines if "content" not in chunk: - continue + continue # ayo chunk = chunk["content"] chunk_lines = chunk.split("\n") chunk_lines[0] = unfinished_line + chunk_lines[0] @@ -546,12 +546,12 @@ Please output the code to be inserted at the cursor in order to fulfill the user break # Lines that should be ignored, like the <> tags elif self.line_to_be_ignored(chunk_lines[i], completion_lines_covered == 0): - continue + continue # noice # Check if we are currently just copying the prefix elif (lines_of_prefix_copied > 0 or completion_lines_covered == 0) and lines_of_prefix_copied < len(file_prefix.splitlines()) and chunk_lines[i] == full_file_contents_lines[lines_of_prefix_copied]: # This is a sketchy way of stopping it from repeating the file_prefix. Is a bug if output happens to have a matching line lines_of_prefix_copied += 1 - continue + continue # also nice # Because really short lines might be expected to be repeated, this is only a !heuristic! # Stop when it starts copying the file_suffix elif chunk_lines[i].strip() == line_below_highlighted_range.strip() and len(chunk_lines[i].strip()) > 4 and not (len(original_lines_below_previous_blocks) > 0 and chunk_lines[i].strip() == original_lines_below_previous_blocks[0].strip()): diff --git a/continuedev/src/continuedev/plugins/steps/search_directory.py b/continuedev/src/continuedev/plugins/steps/search_directory.py index 07b50473..522a84a3 100644 --- a/continuedev/src/continuedev/plugins/steps/search_directory.py +++ b/continuedev/src/continuedev/plugins/steps/search_directory.py @@ -20,10 +20,10 @@ def find_all_matches_in_dir(pattern: str, dirpath: str) -> List[RangeInFile]: for root, dirs, files in os.walk(dirpath): dirname = os.path.basename(root) if dirname.startswith(".") or dirname in IGNORE_DIRS: - continue + continue # continue! for file in files: if file in IGNORE_FILES: - continue + continue # pun intended with open(os.path.join(root, file), "r") as f: # Find the index of all occurences of the pattern in the file. Use re. file_content = f.read() diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 98a5aea0..cf18c56b 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -176,7 +176,7 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we message = json.loads(message) if "messageType" not in message or "data" not in message: - continue + continue # :o message_type = message["messageType"] data = message["data"] diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index e4c07029..6124f3bd 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -139,7 +139,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer): msg_string = await self.websocket.receive_text() message = json.loads(msg_string) if "messageType" not in message or "data" not in message: - continue + continue # <-- hey that's the name of this repo! message_type = message["messageType"] data = message["data"] logger.debug(f"Received message while initializing {message_type}") @@ -311,7 +311,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer): def onFileEdits(self, edits: List[FileEditWithFullContents]): if autopilot := self.__get_autopilot(): - autopilot.handle_manual_edits(edits) + pass def onDeleteAtIndex(self, index: int): if autopilot := self.__get_autopilot(): diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index cf46028f..b5580fe8 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -1,4 +1,5 @@ import os +import traceback from fastapi import WebSocket from typing import Any, Dict, List, Union from uuid import uuid4 @@ -6,12 +7,10 @@ import json from fastapi.websockets import WebSocketState -from ..plugins.steps.core.core import DisplayErrorStep +from ..plugins.steps.core.core import DisplayErrorStep, MessageStep from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath from ..models.filesystem_edit import FileEditWithFullContents -from ..libs.constants.main import CONTINUE_SESSIONS_FOLDER -from ..core.policy import DefaultPolicy -from ..core.main import FullState +from ..core.main import FullState, HistoryNode from ..core.autopilot import Autopilot from .ide_protocol import AbstractIdeProtocolServer from ..libs.util.create_async_task import create_async_task @@ -31,19 +30,6 @@ class Session: self.ws = None -class DemoAutopilot(Autopilot): - first_seen: bool = False - cumulative_edit_string = "" - - def handle_manual_edits(self, edits: List[FileEditWithFullContents]): - return - for edit in edits: - self.cumulative_edit_string += edit.fileEdit.replacement - self._manual_edits_buffer.append(edit) - # Note that you're storing a lot of unecessary data here. Can compress into EditDiffs on the spot, and merge. - # self._manual_edits_buffer = merge_file_edit(self._manual_edits_buffer, edit) - - class SessionManager: sessions: Dict[str, Session] = {} # Mapping of session_id to IDE, where the IDE is still alive @@ -65,27 +51,47 @@ class SessionManager: async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session: logger.debug(f"New session: {session_id}") + # Load the persisted state (not being used right now) full_state = None if session_id is not None and os.path.exists(getSessionFilePath(session_id)): with open(getSessionFilePath(session_id), "r") as f: full_state = FullState(**json.load(f)) - autopilot = await DemoAutopilot.create( - policy=DefaultPolicy(), ide=ide, full_state=full_state) + # Register the session and ide (do this first so that the autopilot can access the session) + autopilot = Autopilot(ide=ide) session_id = session_id or str(uuid4()) ide.session_id = session_id session = Session(session_id=session_id, autopilot=autopilot) self.sessions[session_id] = session self.registered_ides[session_id] = ide + # Set up the autopilot to update the GUI async def on_update(state: FullState): await session_manager.send_ws_data(session_id, "state_update", { "state": state.dict() }) autopilot.on_update(on_update) - create_async_task(autopilot.run_policy( - ), lambda e: autopilot.continue_sdk.run_step(DisplayErrorStep(e=e))) + + # Start the autopilot (must be after session is added to sessions) and the policy + try: + await autopilot.start() + except Exception as e: + # Have to manually add to history because autopilot isn't started + formatted_err = '\n'.join(traceback.format_exception(e)) + msg_step = MessageStep( + name="Error loading context manager", message=formatted_err) + msg_step.description = f"```\n{formatted_err}\n```" + autopilot.history.add_node(HistoryNode( + step=msg_step, + observation=None, + depth=0, + active=False + )) + logger.warning(f"Error loading context manager: {e}") + + create_async_task(autopilot.run_policy(), lambda e: autopilot.continue_sdk.run_step( + DisplayErrorStep(e=e))) return session async def remove_session(self, session_id: str): diff --git a/extension/src/suggestions.ts b/extension/src/suggestions.ts index 5c2b8860..b5be341d 100644 --- a/extension/src/suggestions.ts +++ b/extension/src/suggestions.ts @@ -72,7 +72,7 @@ export function rerenderDecorations(editorUri: string) { range.end.character === 0 ) { // Empty range, don't show it - continue; + continue; // is great } newRanges.push( new vscode.Range( -- cgit v1.2.3-70-g09d2 From 396679009fef21e13c1a6095212d1bd68e7f2a86 Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Mon, 31 Jul 2023 00:05:30 -0700 Subject: feat: :technologist: bit of customization for DefaultPolicy --- continuedev/src/continuedev/plugins/policies/default.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/continuedev/src/continuedev/plugins/policies/default.py b/continuedev/src/continuedev/plugins/policies/default.py index f479b758..2e7573f3 100644 --- a/continuedev/src/continuedev/plugins/policies/default.py +++ b/continuedev/src/continuedev/plugins/policies/default.py @@ -45,7 +45,8 @@ def parse_custom_command(inp: str, config: ContinueConfig) -> Union[None, Step]: class DefaultPolicy(Policy): - ran_code_last: bool = False + + default_step: Step = SimpleChatStep() def next(self, config: ContinueConfig, history: History) -> Step: # At the very start, run initial Steps spcecified in the config @@ -74,6 +75,6 @@ class DefaultPolicy(Policy): if user_input.startswith("/edit"): return EditHighlightedCodeStep(user_input=user_input[5:]) - return SimpleChatStep() + return self.default_step return None -- cgit v1.2.3-70-g09d2