diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-07-22 22:37:13 -0700 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-07-22 22:37:13 -0700 |
commit | 4d7e72970f770eb49627589fb142c93dfb6fd73b (patch) | |
tree | 7c85fb17a9e10ac8e387a001f021aa45c8c46582 /continuedev/src | |
parent | 007780d6d60095d4e0b238358ec26b2ec776b73e (diff) | |
download | sncontinue-4d7e72970f770eb49627589fb142c93dfb6fd73b.tar.gz sncontinue-4d7e72970f770eb49627589fb142c93dfb6fd73b.tar.bz2 sncontinue-4d7e72970f770eb49627589fb142c93dfb6fd73b.zip |
@ feature (very large commit)
Diffstat (limited to 'continuedev/src')
-rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 138 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/config.py | 3 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/context.py | 205 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/context_manager.py | 119 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/main.py | 59 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 31 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/context_providers/highlighted_code_context_provider.py | 191 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/util/paths.py | 20 | ||||
-rw-r--r-- | continuedev/src/continuedev/models/generate_json_schema.py | 3 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/gui.py | 37 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/gui_protocol.py | 10 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/main.py | 28 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/meilisearch_server.py | 56 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/session_manager.py | 6 |
14 files changed, 608 insertions, 298 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index abda50b0..c0f95414 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -9,10 +9,12 @@ from pydantic import root_validator from ..models.filesystem import RangeInFileWithContents from ..models.filesystem_edit import FileEditWithFullContents from .observation import Observation, InternalErrorObservation +from .context import ContextItem, ContextItemDescription, ContextItemId, ContextManager +from ..libs.context_providers.highlighted_code_context_provider import HighlightedCodeContextProvider from ..server.ide_protocol import AbstractIdeProtocolServer from ..libs.util.queue import AsyncSubscriptionQueue from ..models.main import ContinueBaseModel -from .main import Context, ContinueCustomException, HighlightedRangeContext, Policy, History, FullState, Step, HistoryNode +from .main import Context, ContinueCustomException, Policy, History, FullState, Step, HistoryNode from ..steps.core.core import ReversibleStep, ManualEditStep, UserInputStep from ..libs.util.telemetry import capture_event from .sdk import ContinueSDK @@ -47,10 +49,11 @@ class Autopilot(ContinueBaseModel): history: History = History.from_empty() context: Context = Context() full_state: Union[FullState, None] = None - _on_update_callbacks: List[Callable[[FullState], None]] = [] - + context_manager: Union[ContextManager, None] = None continue_sdk: ContinueSDK = None + _on_update_callbacks: List[Callable[[FullState], None]] = [] + _active: bool = False _should_halt: bool = False _main_user_input_queue: List[str] = [] @@ -62,6 +65,14 @@ class Autopilot(ContinueBaseModel): async def create(cls, policy: Policy, ide: AbstractIdeProtocolServer, full_state: FullState) -> "Autopilot": autopilot = cls(ide=ide, policy=policy) autopilot.continue_sdk = await ContinueSDK.create(autopilot) + + # Load documents into the search index + autopilot.context_manager = ContextManager( + autopilot.continue_sdk.config.context_providers + [ + HighlightedCodeContextProvider(ide=ide) + ]) + await autopilot.context_manager.load_index() + return autopilot class Config: @@ -75,15 +86,16 @@ class Autopilot(ContinueBaseModel): values['history'] = full_state.history return values - def get_full_state(self) -> FullState: + async def get_full_state(self) -> FullState: full_state = FullState( history=self.history, active=self._active, user_input_queue=self._main_user_input_queue, default_model=self.continue_sdk.config.default_model, - highlighted_ranges=self._highlighted_ranges, slash_commands=self.get_available_slash_commands(), - adding_highlighted_code=self._adding_highlighted_code, + adding_highlighted_code=self.context_manager.context_providers[ + "code"].adding_highlighted_code, + selected_context_items=await self.context_manager.get_selected_items() ) self.full_state = full_state return full_state @@ -104,8 +116,8 @@ class Autopilot(ContinueBaseModel): self._main_user_input_queue = [] self._active = False - # Also remove all context - self._highlighted_ranges = [] + # Clear context + await self.context_manager.clear_context() await self.update_subscribers() @@ -114,7 +126,7 @@ class Autopilot(ContinueBaseModel): self._on_update_callbacks.append(callback) async def update_subscribers(self): - full_state = self.get_full_state() + full_state = await self.get_full_state() for callback in self._on_update_callbacks: await callback(full_state) @@ -159,81 +171,10 @@ class Autopilot(ContinueBaseModel): step = tb_step.step({"output": output, **tb_step.params}) await self._run_singular_step(step) - _highlighted_ranges: List[HighlightedRangeContext] = [] - _adding_highlighted_code: bool = False - - def _make_sure_is_editing_range(self): - """If none of the highlighted ranges are currently being edited, the first should be selected""" - if len(self._highlighted_ranges) == 0: - return - if not any(map(lambda x: x.editing, self._highlighted_ranges)): - self._highlighted_ranges[0].editing = True - - def _disambiguate_highlighted_ranges(self): - """If any files have the same name, also display their folder name""" - name_status: Dict[str, set] = { - } # basename -> set of full paths with that basename - for rif in self._highlighted_ranges: - basename = os.path.basename(rif.range.filepath) - if basename in name_status: - name_status[basename].add(rif.range.filepath) - else: - name_status[basename] = {rif.range.filepath} - - for rif in self._highlighted_ranges: - basename = os.path.basename(rif.range.filepath) - if len(name_status[basename]) > 1: - rif.display_name = os.path.join( - os.path.basename(os.path.dirname(rif.range.filepath)), basename) - else: - rif.display_name = basename - async def handle_highlighted_code(self, range_in_files: List[RangeInFileWithContents]): - # Filter out rifs from ~/.continue/diffs folder - range_in_files = [ - rif for rif in range_in_files if not os.path.dirname(rif.filepath) == os.path.expanduser("~/.continue/diffs")] - - # Make sure all filepaths are relative to workspace - workspace_path = self.continue_sdk.ide.workspace_directory - - # If not adding highlighted code - if not self._adding_highlighted_code: - if len(self._highlighted_ranges) == 1 and len(range_in_files) <= 1 and (len(range_in_files) == 0 or range_in_files[0].range.start == range_in_files[0].range.end): - # If un-highlighting the range to edit, then remove the range - self._highlighted_ranges = [] - await self.update_subscribers() - elif len(range_in_files) > 0: - # Otherwise, replace the current range with the new one - # This is the first range to be highlighted - self._highlighted_ranges = [HighlightedRangeContext( - range=range_in_files[0], editing=True, pinned=False, display_name=os.path.basename(range_in_files[0].filepath))] - await self.update_subscribers() - return - - # If current range overlaps with any others, delete them and only keep the new range - new_ranges = [] - for i, rif in enumerate(self._highlighted_ranges): - found_overlap = False - for new_rif in range_in_files: - if rif.range.filepath == new_rif.filepath and rif.range.range.overlaps_with(new_rif.range): - found_overlap = True - break - - # Also don't allow multiple ranges in same file with same content. This is useless to the model, and avoids - # the bug where cmd+f causes repeated highlights - if rif.range.filepath == new_rif.filepath and rif.range.contents == new_rif.contents: - found_overlap = True - break - - if not found_overlap: - new_ranges.append(rif) - - self._highlighted_ranges = new_ranges + [HighlightedRangeContext( - range=rif, editing=False, pinned=False, display_name=os.path.basename(rif.filepath) - ) for rif in range_in_files] - - self._make_sure_is_editing_range() - self._disambiguate_highlighted_ranges() + # Add to context manager + await self.context_manager.context_providers["code"].handle_highlighted_code( + range_in_files) await self.update_subscribers() @@ -250,29 +191,16 @@ class Autopilot(ContinueBaseModel): await self.update_subscribers() - async def delete_context_at_indices(self, indices: List[int]): - kept_ranges = [] - for i, rif in enumerate(self._highlighted_ranges): - if i not in indices: - kept_ranges.append(rif) - self._highlighted_ranges = kept_ranges - - self._make_sure_is_editing_range() - + async def delete_context_with_ids(self, ids: List[str]): + await self.context_manager.delete_context_with_ids(ids) await self.update_subscribers() async def toggle_adding_highlighted_code(self): - self._adding_highlighted_code = not self._adding_highlighted_code - await self.update_subscribers() - - async def set_editing_at_indices(self, indices: List[int]): - for i in range(len(self._highlighted_ranges)): - self._highlighted_ranges[i].editing = i in indices + self.context_manager.context_providers["code"].adding_highlighted_code = not self.context_manager.context_providers["code"].adding_highlighted_code await self.update_subscribers() - async def set_pinned_at_indices(self, indices: List[int]): - for i in range(len(self._highlighted_ranges)): - self._highlighted_ranges[i].pinned = i in indices + async def set_editing_at_ids(self, ids: List[str]): + self.context_manager.context_providers["code"].set_editing_at_ids(ids) await self.update_subscribers() async def _run_singular_step(self, step: "Step", is_future_step: bool = False) -> Coroutine[Observation, None, None]: @@ -437,10 +365,6 @@ class Autopilot(ContinueBaseModel): if len(self._main_user_input_queue) > 1: return - # Remove context unless pinned - # self._highlighted_ranges = [ - # hr for hr in self._highlighted_ranges if hr.pinned] - # await self._request_halt() # Just run the step that takes user input, and # then up to the policy to decide how to deal with it. @@ -456,3 +380,7 @@ class Autopilot(ContinueBaseModel): await self._request_halt() await self.reverse_to_index(index) await self.run_from_step(UserInputStep(user_input=user_input)) + + async def select_context_item(self, id: str, query: str): + await self.context_manager.select_context_item(id, query) + await self.update_subscribers() diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index 54f15143..bb9ca323 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -1,6 +1,7 @@ import json import os from .main import Step +from .context import ContextProvider from pydantic import BaseModel, validator from typing import List, Literal, Optional, Dict, Type, Union import yaml @@ -50,6 +51,8 @@ class ContinueConfig(BaseModel): system_message: Optional[str] = None azure_openai_info: Optional[AzureInfo] = None + context_providers: List[ContextProvider] = [] + # Want to force these to be the slash commands for now @validator('slash_commands', pre=True) def default_slash_commands_validator(cls, v): diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py new file mode 100644 index 00000000..67bba651 --- /dev/null +++ b/continuedev/src/continuedev/core/context.py @@ -0,0 +1,205 @@ + +from abc import abstractmethod +from typing import Dict, List +import meilisearch +from pydantic import BaseModel + + +from .main import ChatMessage, ContextItem, ContextItemDescription, ContextItemId +from ..server.meilisearch_server import check_meilisearch_running + + +SEARCH_INDEX_NAME = "continue_context_items" + + +class ContextProvider(BaseModel): + """ + The ContextProvider class is a plugin that lets you provide new information to the LLM by typing '@'. + When you type '@', the context provider will be asked to populate a list of options. + These options will be updated on each keystroke. + When you hit enter on an option, the context provider will add that item to the autopilot's list of context (which is all stored in the ContextManager object). + """ + + title: str + + selected_items: List[ContextItem] = [] + + async def get_selected_items(self) -> List[ContextItem]: + """ + Returns all of the selected ContextItems. + + Default implementation simply returns self.selected_items. + + Other implementations may add an async processing step. + """ + return self.selected_items + + @abstractmethod + async def provide_context_items(self) -> List[ContextItem]: + """ + Provide documents for search index. This is run on startup. + + This is the only method that must be implemented. + """ + + async def get_chat_messages(self) -> List[ChatMessage]: + """ + Returns all of the chat messages for the context provider. + + Default implementation has a string template. + """ + return [ChatMessage(role="user", content=f"{item.description.name}: {item.description.description}\n\n{item.content}", summary=item.description.description) for item in await self.get_selected_items()] + + async def get_item(self, id: ContextItemId, query: str, search_client: meilisearch.Client) -> ContextItem: + """ + Returns the ContextItem with the given id. + + Default implementation uses the search index to get the item. + """ + result = search_client.index( + SEARCH_INDEX_NAME).get_document(id.to_string()) + return ContextItem( + description=ContextItemDescription( + name=result.name, + description=result.description, + id=id + ), + content=result.content + ) + + async def delete_context_with_ids(self, ids: List[ContextItemId]): + """ + Deletes the ContextItems with the given IDs, lets ContextProviders recalculate. + + Default implementation simply deletes those with the given ids. + """ + id_strings = {id.to_string() for id in ids} + self.selected_items = list( + filter(lambda item: item.description.id.to_string() not in id_strings, self.selected_items)) + + async def clear_context(self): + """ + Clears all context. + + Default implementation simply clears the selected items. + """ + self.selected_items = [] + + async def add_context_item(self, id: ContextItemId, query: str, search_client: meilisearch.Client): + """ + Adds the given ContextItem to the list of ContextItems. + + Default implementation simply appends the item, not allowing duplicates. + + This method also allows you not to have to load all of the information until an item is selected. + """ + + # Don't add duplicate context + for item in self.selected_items: + if item.description.id.item_id == id.item_id: + return + + new_item = await self.get_item(id, query, search_client) + self.selected_items.append(new_item) + + +class ContextManager: + """ + The context manager is responsible for storing the context to be passed to the LLM, including + - ContextItems (highlighted code, GitHub Issues, etc.) + - ChatMessages in the history + - System Message + - Functions + + It is responsible for compiling all of this information into a single prompt without exceeding the token limit. + """ + + async def get_selected_items(self) -> List[ContextItem]: + """ + Returns all of the selected ContextItems. + """ + return sum([await provider.get_selected_items() for provider in self.context_providers.values()], []) + + async def get_chat_messages(self) -> List[ChatMessage]: + """ + Returns chat messages from each provider. + """ + return sum([await provider.get_chat_messages() for provider in self.context_providers.values()], []) + + search_client: meilisearch.Client + + def __init__(self, context_providers: List[ContextProvider]): + self.search_client = meilisearch.Client('http://localhost:7700') + + # If meilisearch isn't running, don't use any ContextProviders that might depend on it + if not check_meilisearch_running(): + context_providers = list( + filter(lambda cp: cp.title == "code", context_providers)) + + self.context_providers = { + prov.title: prov for prov in context_providers} + self.provider_titles = { + provider.title for provider in context_providers} + + async def load_index(self): + for _, provider in self.context_providers.items(): + context_items = await provider.provide_context_items() + documents = [ + { + "id": item.description.id.to_string(), + "name": item.description.name, + "description": item.description.description, + "content": item.content + } + for item in context_items + ] + if len(documents) > 0: + self.search_client.index( + SEARCH_INDEX_NAME).add_documents(documents) + + # def compile_chat_messages(self, max_tokens: int) -> List[Dict]: + # """ + # Compiles the chat prompt into a single string. + # """ + # return compile_chat_messages(self.model, self.chat_history, max_tokens, self.prompt, self.functions, self.system_message) + + async def select_context_item(self, id: str, query: str): + """ + Selects the ContextItem with the given id. + """ + id: ContextItemId = ContextItemId.from_string(id) + if id.provider_title not in self.provider_titles: + raise ValueError( + f"Context provider with title {id.provider_title} not found") + + await self.context_providers[id.provider_title].add_context_item(id, query, self.search_client) + + async def delete_context_with_ids(self, ids: List[str]): + """ + Deletes the ContextItems with the given IDs, lets ContextProviders recalculate. + """ + + # Group by provider title + provider_title_to_ids: Dict[str, List[ContextItemId]] = {} + for id in ids: + id: ContextItemId = ContextItemId.from_string(id) + if id.provider_title not in provider_title_to_ids: + provider_title_to_ids[id.provider_title] = [] + provider_title_to_ids[id.provider_title].append(id) + + # Recalculate context for each updated provider + for provider_title, ids in provider_title_to_ids.items(): + await self.context_providers[provider_title].delete_context_with_ids(ids) + + async def clear_context(self): + """ + Clears all context. + """ + for provider in self.context_providers.values(): + await self.context_providers[provider.title].clear_context() + + +""" +Should define "ArgsTransformer" and "PromptTransformer" classes for the different LLMs. A standard way for them to ingest the +same format of prompts so you don't have to redo all of this logic. +""" diff --git a/continuedev/src/continuedev/core/context_manager.py b/continuedev/src/continuedev/core/context_manager.py deleted file mode 100644 index 37905535..00000000 --- a/continuedev/src/continuedev/core/context_manager.py +++ /dev/null @@ -1,119 +0,0 @@ - -from abc import ABC, abstractmethod, abstractproperty -from ast import List -from pydantic import BaseModel - -from ..libs.util.count_tokens import compile_chat_messages - - -class ContextItemDescription(BaseModel): - """ - A ContextItemDescription is a description of a ContextItem that is displayed to the user when they type '@'. - - The id can be used to retrieve the ContextItem from the ContextManager. - """ - name: str - description: str - id: str - - -class ContextItem(BaseModel): - """ - A ContextItem is a single item that is stored in the ContextManager. - """ - description: ContextItemDescription - content: str - - -class ContextManager(ABC): - """ - The context manager is responsible for storing the context to be passed to the LLM, including - - ContextItems (highlighted code, GitHub Issues, etc.) - - ChatMessages in the history - - System Message - - Functions - - It is responsible for compiling all of this information into a single prompt without exceeding the token limit. - """ - - def compile_chat_messages(self, max_tokens: int) -> List[Dict]: - """ - Compiles the chat prompt into a single string. - """ - return compile_chat_messages(self.model, self.chat_history, max_tokens, self.prompt, self.functions, self.system_message) - - -""" -Should define "ArgsTransformer" and "PromptTransformer" classes for the different LLMs. A standard way for them to ingest the -same format of prompts so you don't have to redo all of this logic. -""" - - -class ContextProvider(ABC): - """ - The ContextProvider class is a plugin that lets you provide new information to the LLM by typing '@'. - When you type '@', the context provider will be asked to populate a list of options. - These options will be updated on each keystroke. - When you hit enter on an option, the context provider will add that item to the autopilot's list of context (which is all stored in the ContextManager object). - """ - - title: str - - @abstractmethod - async def load(self): - """ - Loads the ContextProvider, possibly reading persisted data from disk. This will be called on startup. - """ - - @abstractmethod - async def save(self): - """ - Saves the ContextProvider, possibly writing persisted data to disk. This will be called upon cache refresh. - """ - - @abstractmethod - async def refresh_cache(self): - """ - Refreshes the cache of items. This will be called on startup and periodically. - """ - - @abstractmethod - async def get_item_descriptions(self, query: str) -> List[ContextItemDescription]: - """ - Returns a list of options that should be displayed to the user. - """ - - @abstractmethod - async def get_item(self, id: str) -> ContextItem: - """ - Returns the ContextItem with the given id. This allows you not to have to load all of the information until an item is selected. - """ - - @abstractmethod - async def should_refresh(self) -> bool: - """ - Returns whether the ContextProvider should be refreshed. - - For example, embeddings might need to be recalculated after commits, - or GitHub issues might need to be refreshed after a new issue is created. - - This method will be called every startup? Every once in a while? Every hour? - User defined? Maybe just have a schedule instead of this method. - """ - - -class GitHubIssuesContextProvider(ContextProvider): - """ - The GitHubIssuesContextProvider is a ContextProvider that allows you to search GitHub issues in a repo. - """ - - title = "issues" - - def __init__(self, repo: str): - self.repo = repo - - async def get_item_descriptions(self, query: str) -> List[ContextItemDescription]: - pass - - async def get_item(self, id: str) -> ContextItem: - pass diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index 50d01f8d..6c6adccc 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -1,12 +1,11 @@ import json -from textwrap import dedent -from typing import Callable, Coroutine, Dict, Generator, List, Literal, Tuple, Union +from typing import Coroutine, Dict, List, Literal, Union +from pydantic.schema import schema + -from ..models.filesystem import RangeInFileWithContents from ..models.main import ContinueBaseModel -from pydantic import validator +from pydantic import BaseModel, validator from .observation import Observation -from pydantic.schema import schema ChatMessageRole = Literal["assistant", "user", "system", "function"] @@ -201,12 +200,48 @@ class SlashCommandDescription(ContinueBaseModel): description: str -class HighlightedRangeContext(ContinueBaseModel): - """Context for a highlighted range""" - range: RangeInFileWithContents - editing: bool - pinned: bool - display_name: str +class ContextItemId(BaseModel): + """ + A ContextItemId is a unique identifier for a ContextItem. + """ + provider_title: str + item_id: str + + def to_string(self) -> str: + return f"{self.provider_title}-{self.item_id}" + + @staticmethod + def from_string(string: str) -> 'ContextItemId': + provider_title, item_id = string.split('-') + return ContextItemId(provider_title=provider_title, item_id=item_id) + + +class ContextItemDescription(BaseModel): + """ + A ContextItemDescription is a description of a ContextItem that is displayed to the user when they type '@'. + + The id can be used to retrieve the ContextItem from the ContextManager. + """ + name: str + description: str + id: ContextItemId + + +class ContextItem(BaseModel): + """ + A ContextItem is a single item that is stored in the ContextManager. + """ + description: ContextItemDescription + content: str + + @validator('content', pre=True) + def content_must_be_string(cls, v): + if v is None: + return '' + return v + + editing: bool = False + editable: bool = False class FullState(ContinueBaseModel): @@ -215,9 +250,9 @@ class FullState(ContinueBaseModel): active: bool user_input_queue: List[str] default_model: str - highlighted_ranges: List[HighlightedRangeContext] slash_commands: List[SlashCommandDescription] adding_highlighted_code: bool + selected_context_items: List[ContextItem] class ContinueSDK: diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 4100efa6..59f33707 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -3,8 +3,10 @@ from functools import cached_property from typing import Coroutine, Dict, Union import os + from ..steps.core.core import DefaultModelEditCodeStep from ..models.main import Range +from .context import ContextItem from .abstract_sdk import AbstractContinueSDK from .config import ContinueConfig, load_config, load_global_config, update_global_config from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFile, AddDirectory, DeleteDirectory @@ -289,28 +291,13 @@ class ContinueSDK(AbstractContinueSDK): async def get_chat_context(self) -> List[ChatMessage]: history_context = self.history.to_chat_history() - highlighted_code = [ - hr.range for hr in self.__autopilot._highlighted_ranges] - - preface = "The following code is highlighted" - - # If no higlighted ranges, use first file as context - if len(highlighted_code) == 0: - preface = "The following file is open" - visible_files = await self.ide.getVisibleFiles() - if len(visible_files) > 0: - content = await self.ide.readFile(visible_files[0]) - highlighted_code = [ - RangeInFileWithContents.from_entire_file(visible_files[0], content)] - - for rif in highlighted_code: - msg = ChatMessage(content=f"{preface} ({rif.filepath}):\n```\n{rif.contents}\n```", - role="user", summary=f"{preface}: {rif.filepath}") - - # Don't insert after latest user message or function call - i = -1 - if len(history_context) > 0 and (history_context[i].role == "user" or history_context[i].role == "function"): - i -= 1 + + context_messages: List[ChatMessage] = await self.__autopilot.context_manager.get_chat_messages() + + # Insert at the end, but don't insert after latest user message or function call + i = -2 if (len(history_context) > 0 and ( + history_context[-1].role == "user" or history_context[-1].role == "function")) else -1 + for msg in context_messages: history_context.insert(i, msg) return history_context diff --git a/continuedev/src/continuedev/libs/context_providers/highlighted_code_context_provider.py b/continuedev/src/continuedev/libs/context_providers/highlighted_code_context_provider.py new file mode 100644 index 00000000..23d4fc86 --- /dev/null +++ b/continuedev/src/continuedev/libs/context_providers/highlighted_code_context_provider.py @@ -0,0 +1,191 @@ +import os +from typing import Any, Dict, List + +import meilisearch +from ...core.main import ChatMessage +from ...models.filesystem import RangeInFile, RangeInFileWithContents +from ...core.context import ContextItem, ContextItemDescription, ContextItemId +from pydantic import BaseModel + + +class HighlightedRangeContextItem(BaseModel): + rif: RangeInFileWithContents + item: ContextItem + + +class HighlightedCodeContextProvider(BaseModel): + """ + The ContextProvider class is a plugin that lets you provide new information to the LLM by typing '@'. + When you type '@', the context provider will be asked to populate a list of options. + These options will be updated on each keystroke. + When you hit enter on an option, the context provider will add that item to the autopilot's list of context (which is all stored in the ContextManager object). + """ + + title = "code" + + ide: Any # IdeProtocolServer + + highlighted_ranges: List[HighlightedRangeContextItem] = [] + adding_highlighted_code: bool = False + + should_get_fallback_context_item: bool = True + last_added_fallback: bool = False + + async def _get_fallback_context_item(self) -> HighlightedRangeContextItem: + if not self.should_get_fallback_context_item: + return None + + visible_files = await self.ide.getVisibleFiles() + if len(visible_files) > 0: + content = await self.ide.readFile(visible_files[0]) + rif = RangeInFileWithContents.from_entire_file( + visible_files[0], content) + + item = self._rif_to_context_item(rif, 0, True) + item.description.name = self._rif_to_name( + rif, show_line_nums=False) + + self.last_added_fallback = True + return HighlightedRangeContextItem(rif=rif, item=item) + + return None + + async def get_selected_items(self) -> List[ContextItem]: + items = [hr.item for hr in self.highlighted_ranges] + + if len(items) == 0 and (fallback_item := await self._get_fallback_context_item()): + items = [fallback_item.item] + + return items + + async def get_chat_messages(self) -> List[ContextItem]: + ranges = self.highlighted_ranges + if len(ranges) == 0 and (fallback_item := await self._get_fallback_context_item()): + ranges = [fallback_item] + + return [ChatMessage( + role="user", + content=f"Code in this file is highlighted ({r.rif.filepath}):\n```\n{r.rif.contents}\n```", + summary=f"Code in this file is highlighted: {r.rif.filepath}" + ) for r in ranges] + + def _make_sure_is_editing_range(self): + """If none of the highlighted ranges are currently being edited, the first should be selected""" + if len(self.highlighted_ranges) == 0: + return + if not any(map(lambda x: x.item.editing, self.highlighted_ranges)): + self.highlighted_ranges[0].item.editing = True + + def _disambiguate_highlighted_ranges(self): + """If any files have the same name, also display their folder name""" + name_status: Dict[str, set] = { + } # basename -> set of full paths with that basename + for hr in self.highlighted_ranges: + basename = os.path.basename(hr.rif.filepath) + if basename in name_status: + name_status[basename].add(hr.rif.filepath) + else: + name_status[basename] = {hr.rif.filepath} + + for hr in self.highlighted_ranges: + if len(name_status[basename]) > 1: + hr.item.description.name = self._rif_to_name(hr.rif, display_filename=os.path.join( + os.path.basename(os.path.dirname(hr.rif.filepath)), basename)) + else: + hr.item.description.name = self._rif_to_name( + hr.rif, display_filename=basename) + + async def provide_context_items(self) -> List[ContextItem]: + return [] + + async def delete_context_with_ids(self, ids: List[ContextItemId]) -> List[ContextItem]: + indices_to_delete = [ + int(id.item_id) for id in ids + ] + + kept_ranges = [] + for i, hr in enumerate(self.highlighted_ranges): + if i not in indices_to_delete: + kept_ranges.append(hr) + self.highlighted_ranges = kept_ranges + + self._make_sure_is_editing_range() + + if len(self.highlighted_ranges) == 0 and self.last_added_fallback: + self.should_get_fallback_context_item = False + + return [hr.item for hr in self.highlighted_ranges] + + def _rif_to_name(self, rif: RangeInFileWithContents, display_filename: str = None, show_line_nums: bool = True) -> str: + line_nums = f" ({rif.range.start.line + 1}-{rif.range.end.line + 1})" if show_line_nums else "" + return f"{display_filename or os.path.basename(rif.filepath)}{line_nums}" + + def _rif_to_context_item(self, rif: RangeInFileWithContents, idx: int, editing: bool) -> ContextItem: + return ContextItem( + description=ContextItemDescription( + name=self._rif_to_name(rif), + description=rif.filepath, + id=ContextItemId( + provider_title=self.title, + item_id=str(idx) + ) + ), + content=rif.contents, + editing=editing, + editable=True + ) + + async def handle_highlighted_code(self, range_in_files: List[RangeInFileWithContents]): + self.should_get_fallback_context_item = True + self.last_added_fallback = False + + # Filter out rifs from ~/.continue/diffs folder + range_in_files = [ + rif for rif in range_in_files if not os.path.dirname(rif.filepath) == os.path.expanduser("~/.continue/diffs")] + + # If not adding highlighted code + if not self.adding_highlighted_code: + if len(self.highlighted_ranges) == 1 and len(range_in_files) <= 1 and (len(range_in_files) == 0 or range_in_files[0].range.start == range_in_files[0].range.end): + # If un-highlighting the range to edit, then remove the range + self.highlighted_ranges = [] + elif len(range_in_files) > 0: + # Otherwise, replace the current range with the new one + # This is the first range to be highlighted + self.highlighted_ranges = [ + HighlightedRangeContextItem( + rif=range_in_files[0], + item=self._rif_to_context_item(range_in_files[0], 0, True))] + + return + + # If current range overlaps with any others, delete them and only keep the new range + new_ranges = [] + for i, hr in enumerate(self.highlighted_ranges): + found_overlap = False + for new_rif in range_in_files: + if hr.rif.filepath == new_rif.filepath and hr.rif.range.overlaps_with(new_rif.range): + found_overlap = True + break + + # Also don't allow multiple ranges in same file with same content. This is useless to the model, and avoids + # the bug where cmd+f causes repeated highlights + if hr.rif.filepath == new_rif.filepath and hr.rif.contents == new_rif.contents: + found_overlap = True + break + + if not found_overlap: + new_ranges.append(HighlightedRangeContextItem(rif=hr.rif, item=self._rif_to_context_item( + hr.rif, len(new_ranges), False))) + + self.highlighted_ranges = new_ranges + [HighlightedRangeContextItem(rif=rif, item=self._rif_to_context_item( + rif, len(new_ranges) + idx, False)) for idx, rif in enumerate(range_in_files)] + + self._make_sure_is_editing_range() + self._disambiguate_highlighted_ranges() + + async def set_editing_at_ids(self, ids: List[str]): + for hr in self.highlighted_ranges: + hr.item.editing = hr.item.description.id.to_string() in ids + + async def add_context_item(self, id: ContextItemId, query: str, search_client: meilisearch.Client, prev: List[ContextItem] = None) -> List[ContextItem]: + raise NotImplementedError() diff --git a/continuedev/src/continuedev/libs/util/paths.py b/continuedev/src/continuedev/libs/util/paths.py index fddef887..d6ce13b3 100644 --- a/continuedev/src/continuedev/libs/util/paths.py +++ b/continuedev/src/continuedev/libs/util/paths.py @@ -2,16 +2,26 @@ import os from ..constants.main import CONTINUE_SESSIONS_FOLDER, CONTINUE_GLOBAL_FOLDER, CONTINUE_SERVER_FOLDER -def getGlobalFolderPath(): - return os.path.join(os.path.expanduser("~"), CONTINUE_GLOBAL_FOLDER) +def getGlobalFolderPath(): + path = os.path.join(os.path.expanduser("~"), CONTINUE_GLOBAL_FOLDER) + os.makedirs(path, exist_ok=True) + return path def getSessionsFolderPath(): - return os.path.join(getGlobalFolderPath(), CONTINUE_SESSIONS_FOLDER) + path = os.path.join(getGlobalFolderPath(), CONTINUE_SESSIONS_FOLDER) + os.makedirs(path, exist_ok=True) + return path + def getServerFolderPath(): - return os.path.join(getGlobalFolderPath(), CONTINUE_SERVER_FOLDER) + path = os.path.join(getGlobalFolderPath(), CONTINUE_SERVER_FOLDER) + os.makedirs(path, exist_ok=True) + return path + def getSessionFilePath(session_id: str): - return os.path.join(getSessionsFolderPath(), f"{session_id}.json")
\ No newline at end of file + path = os.path.join(getSessionsFolderPath(), f"{session_id}.json") + os.makedirs(os.path.dirname(path), exist_ok=True) + return path diff --git a/continuedev/src/continuedev/models/generate_json_schema.py b/continuedev/src/continuedev/models/generate_json_schema.py index 6cebf429..06614984 100644 --- a/continuedev/src/continuedev/models/generate_json_schema.py +++ b/continuedev/src/continuedev/models/generate_json_schema.py @@ -2,6 +2,7 @@ from .main import * from .filesystem import RangeInFile, FileEdit from .filesystem_edit import FileEditWithFullContents from ..core.main import History, HistoryNode, FullState +from ..core.context import ContextItem from pydantic import schema_json_of import os @@ -13,6 +14,8 @@ MODELS_TO_GENERATE = [ FileEditWithFullContents ] + [ History, HistoryNode, FullState +] + [ + ContextItem ] RENAMES = { diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index ae57c0b6..36b2f3fa 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -91,25 +91,19 @@ class GUIProtocolServer(AbstractGUIProtocolServer): self.on_clear_history() elif message_type == "delete_at_index": self.on_delete_at_index(data["index"]) - elif message_type == "delete_context_at_indices": - self.on_delete_context_at_indices(data["indices"]) + elif message_type == "delete_context_with_ids": + self.on_delete_context_with_ids(data["ids"]) elif message_type == "toggle_adding_highlighted_code": self.on_toggle_adding_highlighted_code() elif message_type == "set_editing_at_indices": self.on_set_editing_at_indices(data["indices"]) - elif message_type == "set_pinned_at_indices": - self.on_set_pinned_at_indices(data["indices"]) elif message_type == "show_logs_at_index": self.on_show_logs_at_index(data["index"]) + elif message_type == "select_context_item": + self.select_context_item(data["id"], data["query"]) except Exception as e: print(e) - async def send_state_update(self): - state = self.session.autopilot.get_full_state().dict() - await self._send_json("state_update", { - "state": state - }) - def on_main_input(self, input: str): # Do something with user input create_async_task(self.session.autopilot.accept_user_input( @@ -144,10 +138,10 @@ class GUIProtocolServer(AbstractGUIProtocolServer): create_async_task(self.session.autopilot.delete_at_index( index), self.session.autopilot.continue_sdk.ide.unique_id) - def on_delete_context_at_indices(self, indices: List[int]): + def on_delete_context_with_ids(self, ids: List[str]): create_async_task( - self.session.autopilot.delete_context_at_indices( - indices), self.session.autopilot.continue_sdk.ide.unique_id + self.session.autopilot.delete_context_with_ids( + ids), self.session.autopilot.continue_sdk.ide.unique_id ) def on_toggle_adding_highlighted_code(self): @@ -162,18 +156,17 @@ class GUIProtocolServer(AbstractGUIProtocolServer): indices), self.session.autopilot.continue_sdk.ide.unique_id ) - def on_set_pinned_at_indices(self, indices: List[int]): - create_async_task( - self.session.autopilot.set_pinned_at_indices( - indices), self.session.autopilot.continue_sdk.ide.unique_id - ) - def on_show_logs_at_index(self, index: int): name = f"continue_logs.txt" logs = "\n\n############################################\n\n".join( ["This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"] + self.session.autopilot.continue_sdk.history.timeline[index].logs) create_async_task( - self.session.autopilot.ide.showVirtualFile(name, logs)) + self.session.autopilot.ide.showVirtualFile(name, logs), self.session.autopilot.continue_sdk.ide.unique_id) + + def select_context_item(self, id: str, query: str): + """Called when user selects an item from the dropdown""" + create_async_task( + self.session.autopilot.select_context_item(id, query), self.session.autopilot.continue_sdk.ide.unique_id) @router.websocket("/ws") @@ -188,7 +181,7 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we protocol.websocket = websocket # Update any history that may have happened before connection - await protocol.send_state_update() + await protocol.session.autopilot.update_subscribers() while AppStatus.should_exit is False: message = await websocket.receive_text() @@ -214,5 +207,5 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we if websocket.client_state != WebSocketState.DISCONNECTED: await websocket.close() - session_manager.persist_session(session.session_id) + await session_manager.persist_session(session.session_id) session_manager.remove_session(session.session_id) diff --git a/continuedev/src/continuedev/server/gui_protocol.py b/continuedev/src/continuedev/server/gui_protocol.py index 9766fcd0..fb230216 100644 --- a/continuedev/src/continuedev/server/gui_protocol.py +++ b/continuedev/src/continuedev/server/gui_protocol.py @@ -1,6 +1,8 @@ from typing import Any, Dict, List from abc import ABC, abstractmethod +from ..core.context import ContextItem + class AbstractGUIProtocolServer(ABC): @abstractmethod @@ -24,10 +26,6 @@ class AbstractGUIProtocolServer(ABC): """Called when the user inputs a step""" @abstractmethod - async def send_state_update(self, state: dict): - """Send a state update to the client""" - - @abstractmethod def on_retry_at_index(self, index: int): """Called when the user requests a retry at a previous index""" @@ -42,3 +40,7 @@ class AbstractGUIProtocolServer(ABC): @abstractmethod def on_delete_at_index(self, index: int): """Called when the user requests to delete a step at a given index""" + + @abstractmethod + def select_context_item(self, id: str, query: str): + """Called when user selects an item from the dropdown""" diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py index 42dc0cc1..7ee64041 100644 --- a/continuedev/src/continuedev/server/main.py +++ b/continuedev/src/continuedev/server/main.py @@ -1,15 +1,20 @@ +import asyncio +import subprocess import time +import meilisearch import psutil import os from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from .ide import router as ide_router -from .gui import router as gui_router -from .session_manager import session_manager import atexit import uvicorn import argparse +from .ide import router as ide_router +from .gui import router as gui_router +from .session_manager import session_manager +from .meilisearch_server import start_meilisearch + app = FastAPI() app.include_router(ide_router) @@ -41,15 +46,20 @@ args = parser.parse_args() # log_file = open('output.log', 'a') # sys.stdout = log_file - def run_server(): uvicorn.run(app, host="0.0.0.0", port=args.port) -def cleanup(): +async def cleanup_coroutine(): print("Cleaning up sessions") for session_id in session_manager.sessions: - session_manager.persist_session(session_id) + await session_manager.persist_session(session_id) + + +def cleanup(): + loop = asyncio.new_event_loop() + loop.run_until_complete(cleanup_coroutine()) + loop.close() def cpu_usage_report(): @@ -77,6 +87,12 @@ if __name__ == "__main__": # cpu_thread = threading.Thread(target=cpu_usage_loop) # cpu_thread.start() + try: + start_meilisearch() + except Exception as e: + print("Failed to start MeiliSearch") + print(e) + run_server() except Exception as e: cleanup() diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py new file mode 100644 index 00000000..419f081f --- /dev/null +++ b/continuedev/src/continuedev/server/meilisearch_server.py @@ -0,0 +1,56 @@ +import os +import subprocess + +import meilisearch +from ..libs.util.paths import getServerFolderPath + + +def check_meilisearch_installed() -> bool: + """ + Checks if MeiliSearch is installed. + """ + + serverPath = getServerFolderPath() + meilisearchPath = os.path.join(serverPath, "meilisearch") + + return os.path.exists(meilisearchPath) + + +def check_meilisearch_running() -> bool: + """ + Checks if MeiliSearch is running. + """ + + try: + client = meilisearch.Client('http://localhost:7700') + resp = client.health() + if resp["status"] != "available": + return False + return True + except Exception: + return False + + +def start_meilisearch(): + """ + Starts the MeiliSearch server, wait for it. + """ + + # Doesn't work on windows for now + if not os.name == "posix": + return + + serverPath = getServerFolderPath() + + # Check if MeiliSearch is installed + if not check_meilisearch_installed(): + # Download MeiliSearch + print("Downloading MeiliSearch...") + subprocess.run( + f"curl -L https://install.meilisearch.com | sh", shell=True, check=True, cwd=serverPath) + + # Check if MeiliSearch is running + if not check_meilisearch_running(): + print("Starting MeiliSearch...") + subprocess.Popen(["./meilisearch"], cwd=serverPath, stdout=subprocess.DEVNULL, + stderr=subprocess.STDOUT, close_fds=True, start_new_session=True) diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index 90172a4e..96daf92c 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -74,7 +74,7 @@ class SessionManager: async def on_update(state: FullState): await session_manager.send_ws_data(session_id, "state_update", { - "state": autopilot.get_full_state().dict() + "state": state.dict() }) autopilot.on_update(on_update) @@ -84,9 +84,9 @@ class SessionManager: def remove_session(self, session_id: str): del self.sessions[session_id] - def persist_session(self, session_id: str): + async def persist_session(self, session_id: str): """Save the session's FullState as a json file""" - full_state = self.sessions[session_id].autopilot.get_full_state() + full_state = await self.sessions[session_id].autopilot.get_full_state() if not os.path.exists(getSessionsFolderPath()): os.mkdir(getSessionsFolderPath()) with open(getSessionFilePath(session_id), "w") as f: |