diff options
Diffstat (limited to 'continuedev/src')
| -rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 138 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/config.py | 3 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/context.py | 205 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/context_manager.py | 119 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/main.py | 59 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 31 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/context_providers/highlighted_code_context_provider.py | 191 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/util/paths.py | 20 | ||||
| -rw-r--r-- | continuedev/src/continuedev/models/generate_json_schema.py | 3 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/gui.py | 37 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/gui_protocol.py | 10 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/main.py | 28 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/meilisearch_server.py | 56 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/session_manager.py | 6 | 
14 files changed, 608 insertions, 298 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index abda50b0..c0f95414 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -9,10 +9,12 @@ from pydantic import root_validator  from ..models.filesystem import RangeInFileWithContents  from ..models.filesystem_edit import FileEditWithFullContents  from .observation import Observation, InternalErrorObservation +from .context import ContextItem, ContextItemDescription, ContextItemId, ContextManager +from ..libs.context_providers.highlighted_code_context_provider import HighlightedCodeContextProvider  from ..server.ide_protocol import AbstractIdeProtocolServer  from ..libs.util.queue import AsyncSubscriptionQueue  from ..models.main import ContinueBaseModel -from .main import Context, ContinueCustomException, HighlightedRangeContext, Policy, History, FullState, Step, HistoryNode +from .main import Context, ContinueCustomException, Policy, History, FullState, Step, HistoryNode  from ..steps.core.core import ReversibleStep, ManualEditStep, UserInputStep  from ..libs.util.telemetry import capture_event  from .sdk import ContinueSDK @@ -47,10 +49,11 @@ class Autopilot(ContinueBaseModel):      history: History = History.from_empty()      context: Context = Context()      full_state: Union[FullState, None] = None -    _on_update_callbacks: List[Callable[[FullState], None]] = [] - +    context_manager: Union[ContextManager, None] = None      continue_sdk: ContinueSDK = None +    _on_update_callbacks: List[Callable[[FullState], None]] = [] +      _active: bool = False      _should_halt: bool = False      _main_user_input_queue: List[str] = [] @@ -62,6 +65,14 @@ class Autopilot(ContinueBaseModel):      async def create(cls, policy: Policy, ide: AbstractIdeProtocolServer, full_state: FullState) -> "Autopilot":          autopilot = cls(ide=ide, policy=policy)          autopilot.continue_sdk = await 
ContinueSDK.create(autopilot) + +        # Load documents into the search index +        autopilot.context_manager = ContextManager( +            autopilot.continue_sdk.config.context_providers + [ +                HighlightedCodeContextProvider(ide=ide) +            ]) +        await autopilot.context_manager.load_index() +          return autopilot      class Config: @@ -75,15 +86,16 @@ class Autopilot(ContinueBaseModel):              values['history'] = full_state.history          return values -    def get_full_state(self) -> FullState: +    async def get_full_state(self) -> FullState:          full_state = FullState(              history=self.history,              active=self._active,              user_input_queue=self._main_user_input_queue,              default_model=self.continue_sdk.config.default_model, -            highlighted_ranges=self._highlighted_ranges,              slash_commands=self.get_available_slash_commands(), -            adding_highlighted_code=self._adding_highlighted_code, +            adding_highlighted_code=self.context_manager.context_providers[ +                "code"].adding_highlighted_code, +            selected_context_items=await self.context_manager.get_selected_items()          )          self.full_state = full_state          return full_state @@ -104,8 +116,8 @@ class Autopilot(ContinueBaseModel):          self._main_user_input_queue = []          self._active = False -        # Also remove all context -        self._highlighted_ranges = [] +        # Clear context +        await self.context_manager.clear_context()          await self.update_subscribers() @@ -114,7 +126,7 @@ class Autopilot(ContinueBaseModel):          self._on_update_callbacks.append(callback)      async def update_subscribers(self): -        full_state = self.get_full_state() +        full_state = await self.get_full_state()          for callback in self._on_update_callbacks:              await callback(full_state) @@ -159,81 +171,10 @@ class 
Autopilot(ContinueBaseModel):                      step = tb_step.step({"output": output, **tb_step.params})                      await self._run_singular_step(step) -    _highlighted_ranges: List[HighlightedRangeContext] = [] -    _adding_highlighted_code: bool = False - -    def _make_sure_is_editing_range(self): -        """If none of the highlighted ranges are currently being edited, the first should be selected""" -        if len(self._highlighted_ranges) == 0: -            return -        if not any(map(lambda x: x.editing, self._highlighted_ranges)): -            self._highlighted_ranges[0].editing = True - -    def _disambiguate_highlighted_ranges(self): -        """If any files have the same name, also display their folder name""" -        name_status: Dict[str, set] = { -        }  # basename -> set of full paths with that basename -        for rif in self._highlighted_ranges: -            basename = os.path.basename(rif.range.filepath) -            if basename in name_status: -                name_status[basename].add(rif.range.filepath) -            else: -                name_status[basename] = {rif.range.filepath} - -        for rif in self._highlighted_ranges: -            basename = os.path.basename(rif.range.filepath) -            if len(name_status[basename]) > 1: -                rif.display_name = os.path.join( -                    os.path.basename(os.path.dirname(rif.range.filepath)), basename) -            else: -                rif.display_name = basename -      async def handle_highlighted_code(self, range_in_files: List[RangeInFileWithContents]): -        # Filter out rifs from ~/.continue/diffs folder -        range_in_files = [ -            rif for rif in range_in_files if not os.path.dirname(rif.filepath) == os.path.expanduser("~/.continue/diffs")] - -        # Make sure all filepaths are relative to workspace -        workspace_path = self.continue_sdk.ide.workspace_directory - -        # If not adding highlighted code -        if not 
self._adding_highlighted_code: -            if len(self._highlighted_ranges) == 1 and len(range_in_files) <= 1 and (len(range_in_files) == 0 or range_in_files[0].range.start == range_in_files[0].range.end): -                # If un-highlighting the range to edit, then remove the range -                self._highlighted_ranges = [] -                await self.update_subscribers() -            elif len(range_in_files) > 0: -                # Otherwise, replace the current range with the new one -                # This is the first range to be highlighted -                self._highlighted_ranges = [HighlightedRangeContext( -                    range=range_in_files[0], editing=True, pinned=False, display_name=os.path.basename(range_in_files[0].filepath))] -                await self.update_subscribers() -            return - -        # If current range overlaps with any others, delete them and only keep the new range -        new_ranges = [] -        for i, rif in enumerate(self._highlighted_ranges): -            found_overlap = False -            for new_rif in range_in_files: -                if rif.range.filepath == new_rif.filepath and rif.range.range.overlaps_with(new_rif.range): -                    found_overlap = True -                    break - -                # Also don't allow multiple ranges in same file with same content. 
This is useless to the model, and avoids -                # the bug where cmd+f causes repeated highlights -                if rif.range.filepath == new_rif.filepath and rif.range.contents == new_rif.contents: -                    found_overlap = True -                    break - -            if not found_overlap: -                new_ranges.append(rif) - -        self._highlighted_ranges = new_ranges + [HighlightedRangeContext( -            range=rif, editing=False, pinned=False, display_name=os.path.basename(rif.filepath) -        ) for rif in range_in_files] - -        self._make_sure_is_editing_range() -        self._disambiguate_highlighted_ranges() +        # Add to context manager +        await self.context_manager.context_providers["code"].handle_highlighted_code( +            range_in_files)          await self.update_subscribers() @@ -250,29 +191,16 @@ class Autopilot(ContinueBaseModel):          await self.update_subscribers() -    async def delete_context_at_indices(self, indices: List[int]): -        kept_ranges = [] -        for i, rif in enumerate(self._highlighted_ranges): -            if i not in indices: -                kept_ranges.append(rif) -        self._highlighted_ranges = kept_ranges - -        self._make_sure_is_editing_range() - +    async def delete_context_with_ids(self, ids: List[str]): +        await self.context_manager.delete_context_with_ids(ids)          await self.update_subscribers()      async def toggle_adding_highlighted_code(self): -        self._adding_highlighted_code = not self._adding_highlighted_code -        await self.update_subscribers() - -    async def set_editing_at_indices(self, indices: List[int]): -        for i in range(len(self._highlighted_ranges)): -            self._highlighted_ranges[i].editing = i in indices +        self.context_manager.context_providers["code"].adding_highlighted_code = not self.context_manager.context_providers["code"].adding_highlighted_code          await self.update_subscribers() 
-    async def set_pinned_at_indices(self, indices: List[int]): -        for i in range(len(self._highlighted_ranges)): -            self._highlighted_ranges[i].pinned = i in indices +    async def set_editing_at_ids(self, ids: List[str]): +        self.context_manager.context_providers["code"].set_editing_at_ids(ids)          await self.update_subscribers()      async def _run_singular_step(self, step: "Step", is_future_step: bool = False) -> Coroutine[Observation, None, None]: @@ -437,10 +365,6 @@ class Autopilot(ContinueBaseModel):          if len(self._main_user_input_queue) > 1:              return -        # Remove context unless pinned -        # self._highlighted_ranges = [ -        #     hr for hr in self._highlighted_ranges if hr.pinned] -          # await self._request_halt()          # Just run the step that takes user input, and          # then up to the policy to decide how to deal with it. @@ -456,3 +380,7 @@ class Autopilot(ContinueBaseModel):          await self._request_halt()          await self.reverse_to_index(index)          await self.run_from_step(UserInputStep(user_input=user_input)) + +    async def select_context_item(self, id: str, query: str): +        await self.context_manager.select_context_item(id, query) +        await self.update_subscribers() diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index 54f15143..bb9ca323 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -1,6 +1,7 @@  import json  import os  from .main import Step +from .context import ContextProvider  from pydantic import BaseModel, validator  from typing import List, Literal, Optional, Dict, Type, Union  import yaml @@ -50,6 +51,8 @@ class ContinueConfig(BaseModel):      system_message: Optional[str] = None      azure_openai_info: Optional[AzureInfo] = None +    context_providers: List[ContextProvider] = [] +      # Want to force these to be the slash commands 
for now      @validator('slash_commands', pre=True)      def default_slash_commands_validator(cls, v): diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py new file mode 100644 index 00000000..67bba651 --- /dev/null +++ b/continuedev/src/continuedev/core/context.py @@ -0,0 +1,205 @@ + +from abc import abstractmethod +from typing import Dict, List +import meilisearch +from pydantic import BaseModel + + +from .main import ChatMessage, ContextItem, ContextItemDescription, ContextItemId +from ..server.meilisearch_server import check_meilisearch_running + + +SEARCH_INDEX_NAME = "continue_context_items" + + +class ContextProvider(BaseModel): +    """ +    The ContextProvider class is a plugin that lets you provide new information to the LLM by typing '@'. +    When you type '@', the context provider will be asked to populate a list of options. +    These options will be updated on each keystroke. +    When you hit enter on an option, the context provider will add that item to the autopilot's list of context (which is all stored in the ContextManager object). +    """ + +    title: str + +    selected_items: List[ContextItem] = [] + +    async def get_selected_items(self) -> List[ContextItem]: +        """ +        Returns all of the selected ContextItems. + +        Default implementation simply returns self.selected_items. + +        Other implementations may add an async processing step. +        """ +        return self.selected_items + +    @abstractmethod +    async def provide_context_items(self) -> List[ContextItem]: +        """ +        Provide documents for search index. This is run on startup. + +        This is the only method that must be implemented. +        """ + +    async def get_chat_messages(self) -> List[ChatMessage]: +        """ +        Returns all of the chat messages for the context provider. + +        Default implementation has a string template. 
+        """ +        return [ChatMessage(role="user", content=f"{item.description.name}: {item.description.description}\n\n{item.content}", summary=item.description.description) for item in await self.get_selected_items()] + +    async def get_item(self, id: ContextItemId, query: str, search_client: meilisearch.Client) -> ContextItem: +        """ +        Returns the ContextItem with the given id. + +        Default implementation uses the search index to get the item. +        """ +        result = search_client.index( +            SEARCH_INDEX_NAME).get_document(id.to_string()) +        return ContextItem( +            description=ContextItemDescription( +                name=result.name, +                description=result.description, +                id=id +            ), +            content=result.content +        ) + +    async def delete_context_with_ids(self, ids: List[ContextItemId]): +        """ +        Deletes the ContextItems with the given IDs, lets ContextProviders recalculate. + +        Default implementation simply deletes those with the given ids. +        """ +        id_strings = {id.to_string() for id in ids} +        self.selected_items = list( +            filter(lambda item: item.description.id.to_string() not in id_strings, self.selected_items)) + +    async def clear_context(self): +        """ +        Clears all context. + +        Default implementation simply clears the selected items. +        """ +        self.selected_items = [] + +    async def add_context_item(self, id: ContextItemId, query: str, search_client: meilisearch.Client): +        """ +        Adds the given ContextItem to the list of ContextItems. + +        Default implementation simply appends the item, not allowing duplicates. + +        This method also allows you not to have to load all of the information until an item is selected. 
+        """ + +        # Don't add duplicate context +        for item in self.selected_items: +            if item.description.id.item_id == id.item_id: +                return + +        new_item = await self.get_item(id, query, search_client) +        self.selected_items.append(new_item) + + +class ContextManager: +    """ +    The context manager is responsible for storing the context to be passed to the LLM, including +    - ContextItems (highlighted code, GitHub Issues, etc.) +    - ChatMessages in the history +    - System Message +    - Functions + +    It is responsible for compiling all of this information into a single prompt without exceeding the token limit. +    """ + +    async def get_selected_items(self) -> List[ContextItem]: +        """ +        Returns all of the selected ContextItems. +        """ +        return sum([await provider.get_selected_items() for provider in self.context_providers.values()], []) + +    async def get_chat_messages(self) -> List[ChatMessage]: +        """ +        Returns chat messages from each provider. 
+        """ +        return sum([await provider.get_chat_messages() for provider in self.context_providers.values()], []) + +    search_client: meilisearch.Client + +    def __init__(self, context_providers: List[ContextProvider]): +        self.search_client = meilisearch.Client('http://localhost:7700') + +        # If meilisearch isn't running, don't use any ContextProviders that might depend on it +        if not check_meilisearch_running(): +            context_providers = list( +                filter(lambda cp: cp.title == "code", context_providers)) + +        self.context_providers = { +            prov.title: prov for prov in context_providers} +        self.provider_titles = { +            provider.title for provider in context_providers} + +    async def load_index(self): +        for _, provider in self.context_providers.items(): +            context_items = await provider.provide_context_items() +            documents = [ +                { +                    "id": item.description.id.to_string(), +                    "name": item.description.name, +                    "description": item.description.description, +                    "content": item.content +                } +                for item in context_items +            ] +            if len(documents) > 0: +                self.search_client.index( +                    SEARCH_INDEX_NAME).add_documents(documents) + +    # def compile_chat_messages(self, max_tokens: int) -> List[Dict]: +    #     """ +    #     Compiles the chat prompt into a single string. +    #     """ +    #     return compile_chat_messages(self.model, self.chat_history, max_tokens, self.prompt, self.functions, self.system_message) + +    async def select_context_item(self, id: str, query: str): +        """ +        Selects the ContextItem with the given id. 
+        """ +        id: ContextItemId = ContextItemId.from_string(id) +        if id.provider_title not in self.provider_titles: +            raise ValueError( +                f"Context provider with title {id.provider_title} not found") + +        await self.context_providers[id.provider_title].add_context_item(id, query, self.search_client) + +    async def delete_context_with_ids(self, ids: List[str]): +        """ +        Deletes the ContextItems with the given IDs, lets ContextProviders recalculate. +        """ + +        # Group by provider title +        provider_title_to_ids: Dict[str, List[ContextItemId]] = {} +        for id in ids: +            id: ContextItemId = ContextItemId.from_string(id) +            if id.provider_title not in provider_title_to_ids: +                provider_title_to_ids[id.provider_title] = [] +            provider_title_to_ids[id.provider_title].append(id) + +        # Recalculate context for each updated provider +        for provider_title, ids in provider_title_to_ids.items(): +            await self.context_providers[provider_title].delete_context_with_ids(ids) + +    async def clear_context(self): +        """ +        Clears all context. +        """ +        for provider in self.context_providers.values(): +            await self.context_providers[provider.title].clear_context() + + +""" +Should define "ArgsTransformer" and "PromptTransformer" classes for the different LLMs. A standard way for them to ingest the +same format of prompts so you don't have to redo all of this logic. 
+""" diff --git a/continuedev/src/continuedev/core/context_manager.py b/continuedev/src/continuedev/core/context_manager.py deleted file mode 100644 index 37905535..00000000 --- a/continuedev/src/continuedev/core/context_manager.py +++ /dev/null @@ -1,119 +0,0 @@ - -from abc import ABC, abstractmethod, abstractproperty -from ast import List -from pydantic import BaseModel - -from ..libs.util.count_tokens import compile_chat_messages - - -class ContextItemDescription(BaseModel): -    """ -    A ContextItemDescription is a description of a ContextItem that is displayed to the user when they type '@'. - -    The id can be used to retrieve the ContextItem from the ContextManager. -    """ -    name: str -    description: str -    id:  str - - -class ContextItem(BaseModel): -    """ -    A ContextItem is a single item that is stored in the ContextManager. -    """ -    description: ContextItemDescription -    content: str - - -class ContextManager(ABC): -    """ -    The context manager is responsible for storing the context to be passed to the LLM, including -    - ContextItems (highlighted code, GitHub Issues, etc.) -    - ChatMessages in the history -    - System Message -    - Functions - -    It is responsible for compiling all of this information into a single prompt without exceeding the token limit. -    """ - -    def compile_chat_messages(self, max_tokens: int) -> List[Dict]: -        """ -        Compiles the chat prompt into a single string. -        """ -        return compile_chat_messages(self.model, self.chat_history, max_tokens, self.prompt, self.functions, self.system_message) - - -""" -Should define "ArgsTransformer" and "PromptTransformer" classes for the different LLMs. A standard way for them to ingest the -same format of prompts so you don't have to redo all of this logic. -""" - - -class ContextProvider(ABC): -    """ -    The ContextProvider class is a plugin that lets you provide new information to the LLM by typing '@'. 
-    When you type '@', the context provider will be asked to populate a list of options. -    These options will be updated on each keystroke. -    When you hit enter on an option, the context provider will add that item to the autopilot's list of context (which is all stored in the ContextManager object). -    """ - -    title: str - -    @abstractmethod -    async def load(self): -        """ -        Loads the ContextProvider, possibly reading persisted data from disk. This will be called on startup. -        """ - -    @abstractmethod -    async def save(self): -        """ -        Saves the ContextProvider, possibly writing persisted data to disk. This will be called upon cache refresh. -        """ - -    @abstractmethod -    async def refresh_cache(self): -        """ -        Refreshes the cache of items. This will be called on startup and periodically. -        """ - -    @abstractmethod -    async def get_item_descriptions(self, query: str) -> List[ContextItemDescription]: -        """ -        Returns a list of options that should be displayed to the user. -        """ - -    @abstractmethod -    async def get_item(self, id: str) -> ContextItem: -        """ -        Returns the ContextItem with the given id. This allows you not to have to load all of the information until an item is selected. -        """ - -    @abstractmethod -    async def should_refresh(self) -> bool: -        """ -        Returns whether the ContextProvider should be refreshed. - -        For example, embeddings might need to be recalculated after commits, -        or GitHub issues might need to be refreshed after a new issue is created. - -        This method will be called every startup? Every once in a while? Every hour? -        User defined? Maybe just have a schedule instead of this method. 
-        """ - - -class GitHubIssuesContextProvider(ContextProvider): -    """ -    The GitHubIssuesContextProvider is a ContextProvider that allows you to search GitHub issues in a repo. -    """ - -    title = "issues" - -    def __init__(self, repo: str): -        self.repo = repo - -    async def get_item_descriptions(self, query: str) -> List[ContextItemDescription]: -        pass - -    async def get_item(self, id: str) -> ContextItem: -        pass diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index 50d01f8d..6c6adccc 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -1,12 +1,11 @@  import json -from textwrap import dedent -from typing import Callable, Coroutine, Dict, Generator, List, Literal, Tuple, Union +from typing import Coroutine, Dict, List, Literal, Union +from pydantic.schema import schema + -from ..models.filesystem import RangeInFileWithContents  from ..models.main import ContinueBaseModel -from pydantic import validator +from pydantic import BaseModel, validator  from .observation import Observation -from pydantic.schema import schema  ChatMessageRole = Literal["assistant", "user", "system", "function"] @@ -201,12 +200,48 @@ class SlashCommandDescription(ContinueBaseModel):      description: str -class HighlightedRangeContext(ContinueBaseModel): -    """Context for a highlighted range""" -    range: RangeInFileWithContents -    editing: bool -    pinned: bool -    display_name: str +class ContextItemId(BaseModel): +    """ +    A ContextItemId is a unique identifier for a ContextItem. 
+    """ +    provider_title: str +    item_id: str + +    def to_string(self) -> str: +        return f"{self.provider_title}-{self.item_id}" + +    @staticmethod +    def from_string(string: str) -> 'ContextItemId': +        provider_title, item_id = string.split('-') +        return ContextItemId(provider_title=provider_title, item_id=item_id) + + +class ContextItemDescription(BaseModel): +    """ +    A ContextItemDescription is a description of a ContextItem that is displayed to the user when they type '@'. + +    The id can be used to retrieve the ContextItem from the ContextManager. +    """ +    name: str +    description: str +    id:  ContextItemId + + +class ContextItem(BaseModel): +    """ +    A ContextItem is a single item that is stored in the ContextManager. +    """ +    description: ContextItemDescription +    content: str + +    @validator('content', pre=True) +    def content_must_be_string(cls, v): +        if v is None: +            return '' +        return v + +    editing: bool = False +    editable: bool = False  class FullState(ContinueBaseModel): @@ -215,9 +250,9 @@ class FullState(ContinueBaseModel):      active: bool      user_input_queue: List[str]      default_model: str -    highlighted_ranges: List[HighlightedRangeContext]      slash_commands: List[SlashCommandDescription]      adding_highlighted_code: bool +    selected_context_items: List[ContextItem]  class ContinueSDK: diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 4100efa6..59f33707 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -3,8 +3,10 @@ from functools import cached_property  from typing import Coroutine, Dict, Union  import os +  from ..steps.core.core import DefaultModelEditCodeStep  from ..models.main import Range +from .context import ContextItem  from .abstract_sdk import AbstractContinueSDK  from .config import ContinueConfig, load_config, load_global_config, 
update_global_config  from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFile, AddDirectory, DeleteDirectory @@ -289,28 +291,13 @@ class ContinueSDK(AbstractContinueSDK):      async def get_chat_context(self) -> List[ChatMessage]:          history_context = self.history.to_chat_history() -        highlighted_code = [ -            hr.range for hr in self.__autopilot._highlighted_ranges] - -        preface = "The following code is highlighted" - -        # If no higlighted ranges, use first file as context -        if len(highlighted_code) == 0: -            preface = "The following file is open" -            visible_files = await self.ide.getVisibleFiles() -            if len(visible_files) > 0: -                content = await self.ide.readFile(visible_files[0]) -                highlighted_code = [ -                    RangeInFileWithContents.from_entire_file(visible_files[0], content)] - -        for rif in highlighted_code: -            msg = ChatMessage(content=f"{preface} ({rif.filepath}):\n```\n{rif.contents}\n```", -                              role="user", summary=f"{preface}: {rif.filepath}") - -            # Don't insert after latest user message or function call -            i = -1 -            if len(history_context) > 0 and (history_context[i].role == "user" or history_context[i].role == "function"): -                i -= 1 + +        context_messages: List[ChatMessage] = await self.__autopilot.context_manager.get_chat_messages() + +        # Insert at the end, but don't insert after latest user message or function call +        i = -2 if (len(history_context) > 0 and ( +            history_context[-1].role == "user" or history_context[-1].role == "function")) else -1 +        for msg in context_messages:              history_context.insert(i, msg)          return history_context diff --git a/continuedev/src/continuedev/libs/context_providers/highlighted_code_context_provider.py 
import os
from typing import Any, Dict, List, Optional

import meilisearch
from ...core.main import ChatMessage
from ...models.filesystem import RangeInFile, RangeInFileWithContents
from ...core.context import ContextItem, ContextItemDescription, ContextItemId
from pydantic import BaseModel


class HighlightedRangeContextItem(BaseModel):
    """Pairs a highlighted range with the context item that represents it."""

    # The highlighted range together with its file path and contents.
    rif: RangeInFileWithContents
    # The context item built from `rif` (see `_rif_to_context_item`).
    item: ContextItem


class HighlightedCodeContextProvider(BaseModel):
    """
    Context provider (title "code") that tracks the code ranges highlighted in the IDE.

    Each highlighted range is kept as a HighlightedRangeContextItem. When nothing is
    highlighted, the first visible file can be offered as a fallback item until the
    user explicitly deletes it.
    """

    title = "code"

    ide: Any  # IdeProtocolServer

    # Currently tracked ranges; each item's id is its position at the time it was (re)built.
    highlighted_ranges: List[HighlightedRangeContextItem] = []
    # When True, new highlights are added to the list instead of replacing it.
    adding_highlighted_code: bool = False

    # Cleared when the user deletes the fallback item; re-enabled on the next highlight event.
    should_get_fallback_context_item: bool = True
    # Set when the fallback item was the last thing handed out.
    last_added_fallback: bool = False

    async def _get_fallback_context_item(self) -> Optional[HighlightedRangeContextItem]:
        """
        Build an item for the first file reported by `ide.getVisibleFiles()`.

        Returns None when the fallback is disabled or no file is visible.
        """
        if not self.should_get_fallback_context_item:
            return None

        visible_files = await self.ide.getVisibleFiles()
        if len(visible_files) == 0:
            return None

        filepath = visible_files[0]
        content = await self.ide.readFile(filepath)
        rif = RangeInFileWithContents.from_entire_file(filepath, content)

        item = self._rif_to_context_item(rif, 0, True)
        # The fallback is named without a line-number suffix.
        item.description.name = self._rif_to_name(rif, show_line_nums=False)

        self.last_added_fallback = True
        return HighlightedRangeContextItem(rif=rif, item=item)

    async def get_selected_items(self) -> List[ContextItem]:
        """Return the items for all highlighted ranges, or the fallback item if there are none."""
        items = [hr.item for hr in self.highlighted_ranges]

        if len(items) == 0 and (fallback_item := await self._get_fallback_context_item()):
            items = [fallback_item.item]

        return items

    async def get_chat_messages(self) -> List[ChatMessage]:
        """Return one user ChatMessage per highlighted range (or for the fallback item)."""
        ranges = self.highlighted_ranges
        if len(ranges) == 0 and (fallback_item := await self._get_fallback_context_item()):
            ranges = [fallback_item]

        return [ChatMessage(
            role="user",
            content=f"Code in this file is highlighted ({r.rif.filepath}):\n```\n{r.rif.contents}\n```",
            summary=f"Code in this file is highlighted: {r.rif.filepath}"
        ) for r in ranges]

    def _make_sure_is_editing_range(self):
        """If none of the highlighted ranges are currently being edited, the first should be selected"""
        if len(self.highlighted_ranges) == 0:
            return
        if not any(hr.item.editing for hr in self.highlighted_ranges):
            self.highlighted_ranges[0].item.editing = True

    def _disambiguate_highlighted_ranges(self):
        """If any files have the same name, also display their folder name"""
        # basename -> set of full paths with that basename
        name_status: Dict[str, set] = {}
        for hr in self.highlighted_ranges:
            basename = os.path.basename(hr.rif.filepath)
            name_status.setdefault(basename, set()).add(hr.rif.filepath)

        for hr in self.highlighted_ranges:
            # Recompute per range: reusing the variable left over from the loop above
            # would test and rename every item with the *last* file's basename.
            basename = os.path.basename(hr.rif.filepath)
            if len(name_status[basename]) > 1:
                parent_folder = os.path.basename(
                    os.path.dirname(hr.rif.filepath))
                display_filename = os.path.join(parent_folder, basename)
            else:
                display_filename = basename
            hr.item.description.name = self._rif_to_name(
                hr.rif, display_filename=display_filename)

    async def provide_context_items(self) -> List[ContextItem]:
        """This provider offers no items of its own; it always returns an empty list."""
        return []

    async def delete_context_with_ids(self, ids: List[ContextItemId]) -> List[ContextItem]:
        """
        Remove the ranges whose item id matches one of `ids` and return the remaining items.

        Deleting down to an empty list after the fallback item was handed out disables
        the fallback, so it does not immediately reappear.
        """
        # Match on the id stored on each item rather than on list position: ids are not
        # renumbered after a deletion, so positions and ids can drift apart.
        item_ids_to_delete = {id.item_id for id in ids}

        self.highlighted_ranges = [
            hr for hr in self.highlighted_ranges
            if hr.item.description.id.item_id not in item_ids_to_delete
        ]

        self._make_sure_is_editing_range()

        if len(self.highlighted_ranges) == 0 and self.last_added_fallback:
            self.should_get_fallback_context_item = False

        return [hr.item for hr in self.highlighted_ranges]

    def _rif_to_name(self, rif: RangeInFileWithContents, display_filename: Optional[str] = None, show_line_nums: bool = True) -> str:
        """
        Build the display name: `display_filename` (or the file's basename) followed,
        unless `show_line_nums` is False, by the 1-based "(start-end)" line range.
        """
        line_nums = f" ({rif.range.start.line + 1}-{rif.range.end.line + 1})" if show_line_nums else ""
        return f"{display_filename or os.path.basename(rif.filepath)}{line_nums}"

    def _rif_to_context_item(self, rif: RangeInFileWithContents, idx: int, editing: bool) -> ContextItem:
        """Wrap `rif` in an editable ContextItem whose item id is `str(idx)`."""
        return ContextItem(
            description=ContextItemDescription(
                name=self._rif_to_name(rif),
                description=rif.filepath,
                id=ContextItemId(
                    provider_title=self.title,
                    item_id=str(idx)
                )
            ),
            content=rif.contents,
            editing=editing,
            editable=True
        )

    async def handle_highlighted_code(self, range_in_files: List[RangeInFileWithContents]):
        """
        Update the tracked ranges in response to a new highlight event.

        When not in "adding" mode the single current range is replaced (or cleared when
        the selection is empty). In "adding" mode the new ranges are appended and any
        existing range that conflicts with a new one is dropped.
        """
        self.should_get_fallback_context_item = True
        self.last_added_fallback = False

        # Filter out rifs from ~/.continue/diffs folder
        diffs_folder = os.path.expanduser("~/.continue/diffs")
        range_in_files = [
            rif for rif in range_in_files
            if os.path.dirname(rif.filepath) != diffs_folder
        ]

        # If not adding highlighted code
        if not self.adding_highlighted_code:
            selection_is_empty = len(range_in_files) == 0 or (
                range_in_files[0].range.start == range_in_files[0].range.end)
            if len(self.highlighted_ranges) == 1 and len(range_in_files) <= 1 and selection_is_empty:
                # If un-highlighting the range to edit, then remove the range
                self.highlighted_ranges = []
            elif len(range_in_files) > 0:
                # Otherwise, replace the current range with the new one
                # This is the first range to be highlighted
                self.highlighted_ranges = [
                    HighlightedRangeContextItem(
                        rif=range_in_files[0],
                        item=self._rif_to_context_item(range_in_files[0], 0, True))]

            return

        # If current range overlaps with any others, delete them and only keep the new range
        new_ranges = []
        for hr in self.highlighted_ranges:
            found_overlap = False
            for new_rif in range_in_files:
                if hr.rif.filepath != new_rif.filepath:
                    continue

                # Also don't allow multiple ranges in same file with same content. This is useless to the model, and avoids
                # the bug where cmd+f causes repeated highlights
                if hr.rif.range.overlaps_with(new_rif.range) or hr.rif.contents == new_rif.contents:
                    found_overlap = True
                    break

            if not found_overlap:
                # Kept ranges are rebuilt so that their ids match their new positions.
                new_ranges.append(HighlightedRangeContextItem(rif=hr.rif, item=self._rif_to_context_item(
                    hr.rif, len(new_ranges), False)))

        self.highlighted_ranges = new_ranges + [HighlightedRangeContextItem(rif=rif, item=self._rif_to_context_item(
            rif, len(new_ranges) + idx, False)) for idx, rif in enumerate(range_in_files)]

        self._make_sure_is_editing_range()
        self._disambiguate_highlighted_ranges()

    async def set_editing_at_ids(self, ids: List[str]):
        """Mark as editing exactly those ranges whose id string (`id.to_string()`) is in `ids`."""
        for hr in self.highlighted_ranges:
            hr.item.editing = hr.item.description.id.to_string() in ids

    async def add_context_item(self, id: ContextItemId, query: str, search_client: meilisearch.Client, prev: List[ContextItem] = None) -> List[ContextItem]:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError()
def _ensure_dir(folder: str) -> str:
    """Create `folder` (and any missing parents) if needed and return it unchanged."""
    os.makedirs(folder, exist_ok=True)
    return folder


def getGlobalFolderPath():
    """Return the global folder under the user's home directory, creating it if missing."""
    home = os.path.expanduser("~")
    return _ensure_dir(os.path.join(home, CONTINUE_GLOBAL_FOLDER))


def getSessionsFolderPath():
    """Return the sessions folder inside the global folder, creating it if missing."""
    return _ensure_dir(os.path.join(getGlobalFolderPath(), CONTINUE_SESSIONS_FOLDER))


def getServerFolderPath():
    """Return the server folder inside the global folder, creating it if missing."""
    return _ensure_dir(os.path.join(getGlobalFolderPath(), CONTINUE_SERVER_FOLDER))


def getSessionFilePath(session_id: str):
    """
    Return the path of the `<session_id>.json` file inside the sessions folder.

    The file itself is not created; only its parent directory is ensured to exist.
    """
    session_file = os.path.join(getSessionsFolderPath(), f"{session_id}.json")
    _ensure_dir(os.path.dirname(session_file))
    return session_file
message_type == "show_logs_at_index":                  self.on_show_logs_at_index(data["index"]) +            elif message_type == "select_context_item": +                self.select_context_item(data["id"], data["query"])          except Exception as e:              print(e) -    async def send_state_update(self): -        state = self.session.autopilot.get_full_state().dict() -        await self._send_json("state_update", { -            "state": state -        }) -      def on_main_input(self, input: str):          # Do something with user input          create_async_task(self.session.autopilot.accept_user_input( @@ -144,10 +138,10 @@ class GUIProtocolServer(AbstractGUIProtocolServer):          create_async_task(self.session.autopilot.delete_at_index(              index), self.session.autopilot.continue_sdk.ide.unique_id) -    def on_delete_context_at_indices(self, indices: List[int]): +    def on_delete_context_with_ids(self, ids: List[str]):          create_async_task( -            self.session.autopilot.delete_context_at_indices( -                indices), self.session.autopilot.continue_sdk.ide.unique_id +            self.session.autopilot.delete_context_with_ids( +                ids), self.session.autopilot.continue_sdk.ide.unique_id          )      def on_toggle_adding_highlighted_code(self): @@ -162,18 +156,17 @@ class GUIProtocolServer(AbstractGUIProtocolServer):                  indices), self.session.autopilot.continue_sdk.ide.unique_id          ) -    def on_set_pinned_at_indices(self, indices: List[int]): -        create_async_task( -            self.session.autopilot.set_pinned_at_indices( -                indices), self.session.autopilot.continue_sdk.ide.unique_id -        ) -      def on_show_logs_at_index(self, index: int):          name = f"continue_logs.txt"          logs = "\n\n############################################\n\n".join(              ["This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"] 
+ self.session.autopilot.continue_sdk.history.timeline[index].logs)          create_async_task( -            self.session.autopilot.ide.showVirtualFile(name, logs)) +            self.session.autopilot.ide.showVirtualFile(name, logs), self.session.autopilot.continue_sdk.ide.unique_id) + +    def select_context_item(self, id: str, query: str): +        """Called when user selects an item from the dropdown""" +        create_async_task( +            self.session.autopilot.select_context_item(id, query), self.session.autopilot.continue_sdk.ide.unique_id)  @router.websocket("/ws") @@ -188,7 +181,7 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we          protocol.websocket = websocket          # Update any history that may have happened before connection -        await protocol.send_state_update() +        await protocol.session.autopilot.update_subscribers()          while AppStatus.should_exit is False:              message = await websocket.receive_text() @@ -214,5 +207,5 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we          if websocket.client_state != WebSocketState.DISCONNECTED:              await websocket.close() -        session_manager.persist_session(session.session_id) +        await session_manager.persist_session(session.session_id)          session_manager.remove_session(session.session_id) diff --git a/continuedev/src/continuedev/server/gui_protocol.py b/continuedev/src/continuedev/server/gui_protocol.py index 9766fcd0..fb230216 100644 --- a/continuedev/src/continuedev/server/gui_protocol.py +++ b/continuedev/src/continuedev/server/gui_protocol.py @@ -1,6 +1,8 @@  from typing import Any, Dict, List  from abc import ABC, abstractmethod +from ..core.context import ContextItem +  class AbstractGUIProtocolServer(ABC):      @abstractmethod @@ -24,10 +26,6 @@ class AbstractGUIProtocolServer(ABC):          """Called when the user inputs a step"""      @abstractmethod -    async def 
send_state_update(self, state: dict): -        """Send a state update to the client""" - -    @abstractmethod      def on_retry_at_index(self, index: int):          """Called when the user requests a retry at a previous index""" @@ -42,3 +40,7 @@ class AbstractGUIProtocolServer(ABC):      @abstractmethod      def on_delete_at_index(self, index: int):          """Called when the user requests to delete a step at a given index""" + +    @abstractmethod +    def select_context_item(self, id: str, query: str): +        """Called when user selects an item from the dropdown""" diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py index 42dc0cc1..7ee64041 100644 --- a/continuedev/src/continuedev/server/main.py +++ b/continuedev/src/continuedev/server/main.py @@ -1,15 +1,20 @@ +import asyncio +import subprocess  import time +import meilisearch  import psutil  import os  from fastapi import FastAPI  from fastapi.middleware.cors import CORSMiddleware -from .ide import router as ide_router -from .gui import router as gui_router -from .session_manager import session_manager  import atexit  import uvicorn  import argparse +from .ide import router as ide_router +from .gui import router as gui_router +from .session_manager import session_manager +from .meilisearch_server import start_meilisearch +  app = FastAPI()  app.include_router(ide_router) @@ -41,15 +46,20 @@ args = parser.parse_args()  # log_file = open('output.log', 'a')  # sys.stdout = log_file -  def run_server():      uvicorn.run(app, host="0.0.0.0", port=args.port) -def cleanup(): +async def cleanup_coroutine():      print("Cleaning up sessions")      for session_id in session_manager.sessions: -        session_manager.persist_session(session_id) +        await session_manager.persist_session(session_id) + + +def cleanup(): +    loop = asyncio.new_event_loop() +    loop.run_until_complete(cleanup_coroutine()) +    loop.close()  def cpu_usage_report(): @@ -77,6 +87,12 @@ 
import os
import subprocess

import meilisearch
from ..libs.util.paths import getServerFolderPath


def check_meilisearch_installed() -> bool:
    """
    Report whether a `meilisearch` entry exists inside the server folder.
    """
    binary_path = os.path.join(getServerFolderPath(), "meilisearch")
    return os.path.exists(binary_path)


def check_meilisearch_running() -> bool:
    """
    Report whether the health endpoint at http://localhost:7700 answers "available".

    Any exception raised while querying it is treated as "not running".
    """
    try:
        health = meilisearch.Client('http://localhost:7700').health()
        return health["status"] == "available"
    except Exception:
        return False


def start_meilisearch():
    """
    Download MeiliSearch into the server folder if it is missing, then launch it
    as a detached background process if it is not already running.

    Returns right after spawning the process; it does not wait for the server to
    become healthy.
    """
    # Doesn't work on windows for now
    if os.name != "posix":
        return

    server_dir = getServerFolderPath()

    # Fetch the binary via the install script when it is not present yet
    installed = check_meilisearch_installed()
    if not installed:
        print("Downloading MeiliSearch...")
        subprocess.run(
            "curl -L https://install.meilisearch.com | sh",
            shell=True,
            check=True,
            cwd=server_dir,
        )

    # Launch only when no instance is already answering
    running = check_meilisearch_running()
    if not running:
        print("Starting MeiliSearch...")
        subprocess.Popen(
            ["./meilisearch"],
            cwd=server_dir,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.STDOUT,
            close_fds=True,
            start_new_session=True,
        )
