diff options
author | Nate Sesti <33237525+sestinj@users.noreply.github.com> | 2023-07-31 00:11:29 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-31 00:11:29 -0700 |
commit | d12b3b7d922a385dc59473f951fad1b3f1db8bdc (patch) | |
tree | e0fe04351404c33a4299e2e7b2d36511d24d321b | |
parent | 8bd76be6c0925e0d5e5f6d239e9c6907df3cfd23 (diff) | |
parent | fc77dd52708d2a28cc6f138c5f0ee390b6d71a3f (diff) | |
download | sncontinue-d12b3b7d922a385dc59473f951fad1b3f1db8bdc.tar.gz sncontinue-d12b3b7d922a385dc59473f951fad1b3f1db8bdc.tar.bz2 sncontinue-d12b3b7d922a385dc59473f951fad1b3f1db8bdc.zip |
Merge pull request #327 from continuedev/at-embed
At embed
21 files changed, 167 insertions, 69 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 57e39d5c..5ab5f8ae 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -9,6 +9,7 @@ from ..models.filesystem import RangeInFileWithContents from ..models.filesystem_edit import FileEditWithFullContents from .observation import Observation, InternalErrorObservation from .context import ContextManager +from ..plugins.policies.default import DefaultPolicy from ..plugins.context_providers.file import FileContextProvider from ..plugins.context_providers.highlighted_code import HighlightedCodeContextProvider from ..server.ide_protocol import AbstractIdeProtocolServer @@ -47,8 +48,9 @@ def get_error_title(e: Exception) -> str: class Autopilot(ContinueBaseModel): - policy: Policy ide: AbstractIdeProtocolServer + + policy: Policy = DefaultPolicy() history: History = History.from_empty() context: Context = Context() full_state: Union[FullState, None] = None @@ -64,20 +66,19 @@ class Autopilot(ContinueBaseModel): _user_input_queue = AsyncSubscriptionQueue() _retry_queue = AsyncSubscriptionQueue() - @classmethod - async def create(cls, policy: Policy, ide: AbstractIdeProtocolServer, full_state: FullState) -> "Autopilot": - autopilot = cls(ide=ide, policy=policy) - autopilot.continue_sdk = await ContinueSDK.create(autopilot) + async def start(self): + self.continue_sdk = await ContinueSDK.create(self) + if override_policy := self.continue_sdk.config.policy_override: + self.policy = override_policy # Load documents into the search index - autopilot.context_manager = await ContextManager.create( - autopilot.continue_sdk.config.context_providers + [ - HighlightedCodeContextProvider(ide=ide), - FileContextProvider(workspace_dir=ide.workspace_directory) + self.context_manager = await ContextManager.create( + self.continue_sdk.config.context_providers + [ + HighlightedCodeContextProvider(ide=self.ide), + FileContextProvider(workspace_dir=self.ide.workspace_directory) ]) - await autopilot.context_manager.load_index(ide.workspace_directory) - return autopilot + await self.context_manager.load_index(self.ide.workspace_directory) class Config: arbitrary_types_allowed = True diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index 9fbda824..fe0946cd 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -1,10 +1,8 @@ -import json -import os -from .main import Step -from .context import ContextProvider from pydantic import BaseModel, validator -from typing import List, Literal, Optional, Dict, Type, Union -import yaml +from typing import List, Literal, Optional, Dict, Type + +from .main import Policy, Step +from .context import ContextProvider class SlashCommand(BaseModel): @@ -51,6 +49,7 @@ class ContinueConfig(BaseModel): on_traceback: Optional[List[OnTracebackSteps]] = [] system_message: Optional[str] = None openai_server_info: Optional[OpenAIServerInfo] = None + policy_override: Optional[Policy] = None context_providers: List[ContextProvider] = [] diff --git a/continuedev/src/continuedev/libs/chroma/query.py b/continuedev/src/continuedev/libs/chroma/query.py index f09b813a..dba4874f 100644 --- a/continuedev/src/continuedev/libs/chroma/query.py +++ b/continuedev/src/continuedev/libs/chroma/query.py @@ -59,7 +59,7 @@ class ChromaIndexManager: except: logger.warning( f"ERROR (probably found special token): {doc.text}") - continue + continue # lol filename = doc.extra_info["filename"] chunks[filename] = len(text_chunks) for i, text in enumerate(text_chunks): diff --git a/continuedev/src/continuedev/libs/chroma/update.py b/continuedev/src/continuedev/libs/chroma/update.py index 23ed950f..d5326a06 100644 --- a/continuedev/src/continuedev/libs/chroma/update.py +++ b/continuedev/src/continuedev/libs/chroma/update.py @@ -23,7 +23,7 @@ def filter_ignored_files(files: List[str], root_dir: str): """Further filter files before indexing.""" for file in files: if file.endswith(tuple(FILE_TYPES_TO_IGNORE)) or file.startswith('.git') or file.startswith('archive'): - continue + continue # nice yield root_dir + "/" + file diff --git a/continuedev/src/continuedev/libs/constants/default_config.py.txt b/continuedev/src/continuedev/libs/constants/default_config.py.txt index 1a66c847..69fd357b 100644 --- a/continuedev/src/continuedev/libs/constants/default_config.py.txt +++ b/continuedev/src/continuedev/libs/constants/default_config.py.txt @@ -12,6 +12,7 @@ from continuedev.core.sdk import ContinueSDK from continuedev.core.config import CustomCommand, SlashCommand, ContinueConfig from continuedev.plugins.context_providers.github import GitHubIssuesContextProvider from continuedev.plugins.context_providers.google import GoogleContextProvider +from continuedev.plugins.policies.default import DefaultPolicy from continuedev.plugins.steps.open_config import OpenConfigStep from continuedev.plugins.steps.clear_history import ClearHistoryStep @@ -114,5 +115,9 @@ config = ContinueConfig( # GoogleContextProvider( # serper_api_key="<your serper.dev api key>" # ) - ] + ], + + # Policies hold the main logic that decides which Step to take next + # You can use them to design agents, or deeply customize Continue + policy=DefaultPolicy() ) diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py index 4889a556..21374359 100644 --- a/continuedev/src/continuedev/libs/llm/ggml.py +++ b/continuedev/src/continuedev/libs/llm/ggml.py @@ -64,7 +64,7 @@ class GGML(LLM): try: json_chunk = line[0].decode("utf-8") if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"): - continue + continue # hehe chunks = json_chunk.split("\n") for chunk in chunks: if chunk.strip() != "": diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index 654c7326..d1ca4ef9 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -62,7 +62,7 @@ class OpenAI(LLM): yield chunk.choices[0].delta.content completion += chunk.choices[0].delta.content else: - continue + continue # :) self.write_log(f"Completion: \n\n{completion}") else: diff --git a/continuedev/src/continuedev/libs/util/calculate_diff.py b/continuedev/src/continuedev/libs/util/calculate_diff.py index ff0a135f..3e82bab3 100644 --- a/continuedev/src/continuedev/libs/util/calculate_diff.py +++ b/continuedev/src/continuedev/libs/util/calculate_diff.py @@ -92,7 +92,7 @@ def calculate_diff2(filepath: str, original: str, updated: str) -> List[FileEdit tag, i1, i2, j1, j2 = s.get_opcodes()[edit_index] replacement = updated[j1:j2] if tag == "equal": - continue + continue # ;) elif tag == "delete": edits.append(FileEdit.from_deletion( filepath, Range.from_indices(original, i1, i2))) diff --git a/continuedev/src/continuedev/libs/util/strings.py b/continuedev/src/continuedev/libs/util/strings.py index f1fb8d0b..285c1e47 100644 --- a/continuedev/src/continuedev/libs/util/strings.py +++ b/continuedev/src/continuedev/libs/util/strings.py @@ -12,7 +12,7 @@ def dedent_and_get_common_whitespace(s: str) -> Tuple[str, str]: for i in range(1, len(lines)): # Empty lines are wildcards if lines[i].strip() == "": - continue + continue # hey that's us! # Iterate through the leading whitespace characters of the current line for j in range(0, len(lcp)): # If it doesn't have the same whitespace as lcp, then update lcp diff --git a/continuedev/src/continuedev/models/generate_json_schema.py b/continuedev/src/continuedev/models/generate_json_schema.py index 51869fdd..2166bc37 100644 --- a/continuedev/src/continuedev/models/generate_json_schema.py +++ b/continuedev/src/continuedev/models/generate_json_schema.py @@ -39,7 +39,7 @@ def main(): json = schema_json_of(model, indent=2, title=title) except Exception as e: print(f"Failed to generate json schema for {title}: {e}") - continue + continue # pun intended with open(f"{SCHEMA_DIR}/{title}.json", "w") as f: f.write(json) diff --git a/continuedev/src/continuedev/plugins/context_providers/embeddings.py b/continuedev/src/continuedev/plugins/context_providers/embeddings.py new file mode 100644 index 00000000..42d1f754 --- /dev/null +++ b/continuedev/src/continuedev/plugins/context_providers/embeddings.py @@ -0,0 +1,79 @@ +import os +from typing import List, Optional +import uuid +from pydantic import BaseModel + +from ...core.main import ContextItemId +from ...core.context import ContextProvider +from ...core.main import ContextItem, ContextItemDescription, ContextItemId +from ...libs.chroma.query import ChromaIndexManager +from .util import remove_meilisearch_disallowed_chars + + +class EmbeddingResult(BaseModel): + filename: str + content: str + + +class EmbeddingsProvider(ContextProvider): + title = "embed" + + workspace_directory: str + + EMBEDDINGS_CONTEXT_ITEM_ID = "embeddings" + + index_manager: Optional[ChromaIndexManager] = None + + class Config: + arbitrary_types_allowed = True + + @property + def index(self): + if self.index_manager is None: + self.index_manager = ChromaIndexManager(self.workspace_directory) + return self.index_manager + + @property + def BASE_CONTEXT_ITEM(self): + return ContextItem( + content="", + description=ContextItemDescription( + name="Embedding Search", + description="Enter a query to embedding search codebase", + id=ContextItemId( + provider_title=self.title, + item_id=self.EMBEDDINGS_CONTEXT_ITEM_ID + ) + ) + ) + + async def _get_query_results(self, query: str) -> str: + results = self.index.query_codebase_index(query) + + ret = [] + for node in results.source_nodes: + resource_name = list(node.node.relationships.values())[0] + filepath = resource_name[:resource_name.index("::")] + ret.append(EmbeddingResult( + filename=filepath, content=node.node.text)) + + return ret + + async def provide_context_items(self) -> List[ContextItem]: + self.index.create_codebase_index() # TODO Synchronous here is not ideal + + return [self.BASE_CONTEXT_ITEM] + + async def add_context_item(self, id: ContextItemId, query: str): + if not id.item_id == self.EMBEDDINGS_CONTEXT_ITEM_ID: + raise Exception("Invalid item id") + + results = await self._get_query_results(query) + + for i in range(len(results)): + result = results[i] + ctx_item = self.BASE_CONTEXT_ITEM.copy() + ctx_item.description.name = os.path.basename(result.filename) + ctx_item.content = f"{result.filename}\n```\n{result.content}\n```" + ctx_item.description.id.item_id = uuid.uuid4().hex + self.selected_items.append(ctx_item) diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py index 634774df..31aa5423 100644 --- a/continuedev/src/continuedev/plugins/context_providers/file.py +++ b/continuedev/src/continuedev/plugins/context_providers/file.py @@ -3,6 +3,7 @@ import re from typing import List from ...core.main import ContextItem, ContextItemDescription, ContextItemId from ...core.context import ContextProvider +from .util import remove_meilisearch_disallowed_chars from fnmatch import fnmatch @@ -79,7 +80,7 @@ class FileContextProvider(ContextProvider): description=file, id=ContextItemId( provider_title=self.title, - item_id=re.sub(r'[^0-9a-zA-Z_-]', '', file) + item_id=remove_meilisearch_disallowed_chars(file) ) ) )) diff --git a/continuedev/src/continuedev/plugins/context_providers/google.py b/continuedev/src/continuedev/plugins/context_providers/google.py index fc76fe67..4b0a59ec 100644 --- a/continuedev/src/continuedev/plugins/context_providers/google.py +++ b/continuedev/src/continuedev/plugins/context_providers/google.py @@ -2,6 +2,7 @@ import json from typing import List import aiohttp +from .util import remove_meilisearch_disallowed_chars from ...core.main import ContextItem, ContextItemDescription, ContextItemId from ...core.context import ContextProvider @@ -60,5 +61,6 @@ class GoogleContextProvider(ContextProvider): ctx_item = self.BASE_CONTEXT_ITEM.copy() ctx_item.content = content - ctx_item.description.id.item_id = query + ctx_item.description.id.item_id = remove_meilisearch_disallowed_chars( + query) return ctx_item diff --git a/continuedev/src/continuedev/plugins/context_providers/util.py b/continuedev/src/continuedev/plugins/context_providers/util.py new file mode 100644 index 00000000..da2e6b17 --- /dev/null +++ b/continuedev/src/continuedev/plugins/context_providers/util.py @@ -0,0 +1,5 @@ +import re + + +def remove_meilisearch_disallowed_chars(id: str) -> str: + return re.sub(r'[^0-9a-zA-Z_-]', '', id) diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/plugins/policies/default.py index d90177b5..2e7573f3 100644 --- a/continuedev/src/continuedev/core/policy.py +++ b/continuedev/src/continuedev/plugins/policies/default.py @@ -1,15 +1,15 @@ from textwrap import dedent from typing import Union -from ..plugins.steps.chat import SimpleChatStep -from ..plugins.steps.welcome import WelcomeStep -from .config import ContinueConfig -from ..plugins.steps.steps_on_startup import StepsOnStartupStep -from .main import Step, History, Policy -from .observation import UserInputObservation -from ..plugins.steps.core.core import MessageStep -from ..plugins.steps.custom_command import CustomCommandStep -from ..plugins.steps.main import EditHighlightedCodeStep +from ..steps.chat import SimpleChatStep +from ..steps.welcome import WelcomeStep +from ...core.config import ContinueConfig +from ..steps.steps_on_startup import StepsOnStartupStep +from ...core.main import Step, History, Policy +from ...core.observation import UserInputObservation +from ..steps.core.core import MessageStep +from ..steps.custom_command import CustomCommandStep +from ..steps.main import EditHighlightedCodeStep def parse_slash_command(inp: str, config: ContinueConfig) -> Union[None, Step]: @@ -45,7 +45,8 @@ def parse_custom_command(inp: str, config: ContinueConfig) -> Union[None, Step]: class DefaultPolicy(Policy): - ran_code_last: bool = False + + default_step: Step = SimpleChatStep() def next(self, config: ContinueConfig, history: History) -> Step: # At the very start, run initial Steps spcecified in the config @@ -56,7 +57,6 @@ class DefaultPolicy(Policy): - Use `cmd+m` (Mac) / `ctrl+m` (Windows) to open Continue - Use `/help` to ask questions about how to use Continue""")) >> WelcomeStep() >> - # CreateCodebaseIndexChroma() >> StepsOnStartupStep()) observation = history.get_current().observation @@ -75,6 +75,6 @@ class DefaultPolicy(Policy): if user_input.startswith("/edit"): return EditHighlightedCodeStep(user_input=user_input[5:]) - return SimpleChatStep() + return self.default_step return None diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py index c80cecc3..de7cf3ac 100644 --- a/continuedev/src/continuedev/plugins/steps/core/core.py +++ b/continuedev/src/continuedev/plugins/steps/core/core.py @@ -525,7 +525,7 @@ Please output the code to be inserted at the cursor in order to fulfill the user # Accumulate lines if "content" not in chunk: - continue + continue # ayo chunk = chunk["content"] chunk_lines = chunk.split("\n") chunk_lines[0] = unfinished_line + chunk_lines[0] @@ -546,12 +546,12 @@ Please output the code to be inserted at the cursor in order to fulfill the user break # Lines that should be ignored, like the <> tags elif self.line_to_be_ignored(chunk_lines[i], completion_lines_covered == 0): - continue + continue # noice # Check if we are currently just copying the prefix elif (lines_of_prefix_copied > 0 or completion_lines_covered == 0) and lines_of_prefix_copied < len(file_prefix.splitlines()) and chunk_lines[i] == full_file_contents_lines[lines_of_prefix_copied]: # This is a sketchy way of stopping it from repeating the file_prefix. Is a bug if output happens to have a matching line lines_of_prefix_copied += 1 - continue + continue # also nice # Because really short lines might be expected to be repeated, this is only a !heuristic! # Stop when it starts copying the file_suffix elif chunk_lines[i].strip() == line_below_highlighted_range.strip() and len(chunk_lines[i].strip()) > 4 and not (len(original_lines_below_previous_blocks) > 0 and chunk_lines[i].strip() == original_lines_below_previous_blocks[0].strip()): diff --git a/continuedev/src/continuedev/plugins/steps/search_directory.py b/continuedev/src/continuedev/plugins/steps/search_directory.py index 07b50473..522a84a3 100644 --- a/continuedev/src/continuedev/plugins/steps/search_directory.py +++ b/continuedev/src/continuedev/plugins/steps/search_directory.py @@ -20,10 +20,10 @@ def find_all_matches_in_dir(pattern: str, dirpath: str) -> List[RangeInFile]: for root, dirs, files in os.walk(dirpath): dirname = os.path.basename(root) if dirname.startswith(".") or dirname in IGNORE_DIRS: - continue + continue # continue! for file in files: if file in IGNORE_FILES: - continue + continue # pun intended with open(os.path.join(root, file), "r") as f: # Find the index of all occurences of the pattern in the file. Use re. file_content = f.read() diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 98a5aea0..cf18c56b 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -176,7 +176,7 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we message = json.loads(message) if "messageType" not in message or "data" not in message: - continue + continue # :o message_type = message["messageType"] data = message["data"] diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index e4c07029..6124f3bd 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -139,7 +139,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer): msg_string = await self.websocket.receive_text() message = json.loads(msg_string) if "messageType" not in message or "data" not in message: - continue + continue # <-- hey that's the name of this repo! message_type = message["messageType"] data = message["data"] logger.debug(f"Received message while initializing {message_type}") @@ -311,7 +311,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer): def onFileEdits(self, edits: List[FileEditWithFullContents]): if autopilot := self.__get_autopilot(): - autopilot.handle_manual_edits(edits) + pass def onDeleteAtIndex(self, index: int): if autopilot := self.__get_autopilot(): diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index cf46028f..b5580fe8 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -1,4 +1,5 @@ import os +import traceback from fastapi import WebSocket from typing import Any, Dict, List, Union from uuid import uuid4 @@ -6,12 +7,10 @@ import json from fastapi.websockets import WebSocketState -from ..plugins.steps.core.core import DisplayErrorStep +from ..plugins.steps.core.core import DisplayErrorStep, MessageStep from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath from ..models.filesystem_edit import FileEditWithFullContents -from ..libs.constants.main import CONTINUE_SESSIONS_FOLDER -from ..core.policy import DefaultPolicy -from ..core.main import FullState +from ..core.main import FullState, HistoryNode from ..core.autopilot import Autopilot from .ide_protocol import AbstractIdeProtocolServer from ..libs.util.create_async_task import create_async_task @@ -31,19 +30,6 @@ class Session: self.ws = None -class DemoAutopilot(Autopilot): - first_seen: bool = False - cumulative_edit_string = "" - - def handle_manual_edits(self, edits: List[FileEditWithFullContents]): - return - for edit in edits: - self.cumulative_edit_string += edit.fileEdit.replacement - self._manual_edits_buffer.append(edit) - # Note that you're storing a lot of unecessary data here. Can compress into EditDiffs on the spot, and merge. - # self._manual_edits_buffer = merge_file_edit(self._manual_edits_buffer, edit) - - class SessionManager: sessions: Dict[str, Session] = {} # Mapping of session_id to IDE, where the IDE is still alive @@ -65,27 +51,47 @@ class SessionManager: async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session: logger.debug(f"New session: {session_id}") + # Load the persisted state (not being used right now) full_state = None if session_id is not None and os.path.exists(getSessionFilePath(session_id)): with open(getSessionFilePath(session_id), "r") as f: full_state = FullState(**json.load(f)) - autopilot = await DemoAutopilot.create( - policy=DefaultPolicy(), ide=ide, full_state=full_state) + # Register the session and ide (do this first so that the autopilot can access the session) + autopilot = Autopilot(ide=ide) session_id = session_id or str(uuid4()) ide.session_id = session_id session = Session(session_id=session_id, autopilot=autopilot) self.sessions[session_id] = session self.registered_ides[session_id] = ide + # Set up the autopilot to update the GUI async def on_update(state: FullState): await session_manager.send_ws_data(session_id, "state_update", { "state": state.dict() }) autopilot.on_update(on_update) - create_async_task(autopilot.run_policy( - ), lambda e: autopilot.continue_sdk.run_step(DisplayErrorStep(e=e))) + + # Start the autopilot (must be after session is added to sessions) and the policy + try: + await autopilot.start() + except Exception as e: + # Have to manually add to history because autopilot isn't started + formatted_err = '\n'.join(traceback.format_exception(e)) + msg_step = MessageStep( + name="Error loading context manager", message=formatted_err) + msg_step.description = f"```\n{formatted_err}\n```" + autopilot.history.add_node(HistoryNode( + step=msg_step, + observation=None, + depth=0, + active=False + )) + logger.warning(f"Error loading context manager: {e}") + + create_async_task(autopilot.run_policy(), lambda e: autopilot.continue_sdk.run_step( + DisplayErrorStep(e=e))) return session async def remove_session(self, session_id: str): diff --git a/extension/src/suggestions.ts b/extension/src/suggestions.ts index 5c2b8860..b5be341d 100644 --- a/extension/src/suggestions.ts +++ b/extension/src/suggestions.ts @@ -72,7 +72,7 @@ export function rerenderDecorations(editorUri: string) { range.end.character === 0 ) { // Empty range, don't show it - continue; + continue; // is great } newRanges.push( new vscode.Range( |