summaryrefslogtreecommitdiff
path: root/continuedev/src
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-07-22 22:37:13 -0700
committerNate Sesti <sestinj@gmail.com>2023-07-22 22:37:13 -0700
commit4d7e72970f770eb49627589fb142c93dfb6fd73b (patch)
tree7c85fb17a9e10ac8e387a001f021aa45c8c46582 /continuedev/src
parent007780d6d60095d4e0b238358ec26b2ec776b73e (diff)
downloadsncontinue-4d7e72970f770eb49627589fb142c93dfb6fd73b.tar.gz
sncontinue-4d7e72970f770eb49627589fb142c93dfb6fd73b.tar.bz2
sncontinue-4d7e72970f770eb49627589fb142c93dfb6fd73b.zip
@ feature (very large commit)
Diffstat (limited to 'continuedev/src')
-rw-r--r--continuedev/src/continuedev/core/autopilot.py138
-rw-r--r--continuedev/src/continuedev/core/config.py3
-rw-r--r--continuedev/src/continuedev/core/context.py205
-rw-r--r--continuedev/src/continuedev/core/context_manager.py119
-rw-r--r--continuedev/src/continuedev/core/main.py59
-rw-r--r--continuedev/src/continuedev/core/sdk.py31
-rw-r--r--continuedev/src/continuedev/libs/context_providers/highlighted_code_context_provider.py191
-rw-r--r--continuedev/src/continuedev/libs/util/paths.py20
-rw-r--r--continuedev/src/continuedev/models/generate_json_schema.py3
-rw-r--r--continuedev/src/continuedev/server/gui.py37
-rw-r--r--continuedev/src/continuedev/server/gui_protocol.py10
-rw-r--r--continuedev/src/continuedev/server/main.py28
-rw-r--r--continuedev/src/continuedev/server/meilisearch_server.py56
-rw-r--r--continuedev/src/continuedev/server/session_manager.py6
14 files changed, 608 insertions, 298 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index abda50b0..c0f95414 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -9,10 +9,12 @@ from pydantic import root_validator
from ..models.filesystem import RangeInFileWithContents
from ..models.filesystem_edit import FileEditWithFullContents
from .observation import Observation, InternalErrorObservation
+from .context import ContextItem, ContextItemDescription, ContextItemId, ContextManager
+from ..libs.context_providers.highlighted_code_context_provider import HighlightedCodeContextProvider
from ..server.ide_protocol import AbstractIdeProtocolServer
from ..libs.util.queue import AsyncSubscriptionQueue
from ..models.main import ContinueBaseModel
-from .main import Context, ContinueCustomException, HighlightedRangeContext, Policy, History, FullState, Step, HistoryNode
+from .main import Context, ContinueCustomException, Policy, History, FullState, Step, HistoryNode
from ..steps.core.core import ReversibleStep, ManualEditStep, UserInputStep
from ..libs.util.telemetry import capture_event
from .sdk import ContinueSDK
@@ -47,10 +49,11 @@ class Autopilot(ContinueBaseModel):
history: History = History.from_empty()
context: Context = Context()
full_state: Union[FullState, None] = None
- _on_update_callbacks: List[Callable[[FullState], None]] = []
-
+ context_manager: Union[ContextManager, None] = None
continue_sdk: ContinueSDK = None
+ _on_update_callbacks: List[Callable[[FullState], None]] = []
+
_active: bool = False
_should_halt: bool = False
_main_user_input_queue: List[str] = []
@@ -62,6 +65,14 @@ class Autopilot(ContinueBaseModel):
async def create(cls, policy: Policy, ide: AbstractIdeProtocolServer, full_state: FullState) -> "Autopilot":
autopilot = cls(ide=ide, policy=policy)
autopilot.continue_sdk = await ContinueSDK.create(autopilot)
+
+ # Load documents into the search index
+ autopilot.context_manager = ContextManager(
+ autopilot.continue_sdk.config.context_providers + [
+ HighlightedCodeContextProvider(ide=ide)
+ ])
+ await autopilot.context_manager.load_index()
+
return autopilot
class Config:
@@ -75,15 +86,16 @@ class Autopilot(ContinueBaseModel):
values['history'] = full_state.history
return values
- def get_full_state(self) -> FullState:
+ async def get_full_state(self) -> FullState:
full_state = FullState(
history=self.history,
active=self._active,
user_input_queue=self._main_user_input_queue,
default_model=self.continue_sdk.config.default_model,
- highlighted_ranges=self._highlighted_ranges,
slash_commands=self.get_available_slash_commands(),
- adding_highlighted_code=self._adding_highlighted_code,
+ adding_highlighted_code=self.context_manager.context_providers[
+ "code"].adding_highlighted_code,
+ selected_context_items=await self.context_manager.get_selected_items()
)
self.full_state = full_state
return full_state
@@ -104,8 +116,8 @@ class Autopilot(ContinueBaseModel):
self._main_user_input_queue = []
self._active = False
- # Also remove all context
- self._highlighted_ranges = []
+ # Clear context
+ await self.context_manager.clear_context()
await self.update_subscribers()
@@ -114,7 +126,7 @@ class Autopilot(ContinueBaseModel):
self._on_update_callbacks.append(callback)
async def update_subscribers(self):
- full_state = self.get_full_state()
+ full_state = await self.get_full_state()
for callback in self._on_update_callbacks:
await callback(full_state)
@@ -159,81 +171,10 @@ class Autopilot(ContinueBaseModel):
step = tb_step.step({"output": output, **tb_step.params})
await self._run_singular_step(step)
- _highlighted_ranges: List[HighlightedRangeContext] = []
- _adding_highlighted_code: bool = False
-
- def _make_sure_is_editing_range(self):
- """If none of the highlighted ranges are currently being edited, the first should be selected"""
- if len(self._highlighted_ranges) == 0:
- return
- if not any(map(lambda x: x.editing, self._highlighted_ranges)):
- self._highlighted_ranges[0].editing = True
-
- def _disambiguate_highlighted_ranges(self):
- """If any files have the same name, also display their folder name"""
- name_status: Dict[str, set] = {
- } # basename -> set of full paths with that basename
- for rif in self._highlighted_ranges:
- basename = os.path.basename(rif.range.filepath)
- if basename in name_status:
- name_status[basename].add(rif.range.filepath)
- else:
- name_status[basename] = {rif.range.filepath}
-
- for rif in self._highlighted_ranges:
- basename = os.path.basename(rif.range.filepath)
- if len(name_status[basename]) > 1:
- rif.display_name = os.path.join(
- os.path.basename(os.path.dirname(rif.range.filepath)), basename)
- else:
- rif.display_name = basename
-
async def handle_highlighted_code(self, range_in_files: List[RangeInFileWithContents]):
- # Filter out rifs from ~/.continue/diffs folder
- range_in_files = [
- rif for rif in range_in_files if not os.path.dirname(rif.filepath) == os.path.expanduser("~/.continue/diffs")]
-
- # Make sure all filepaths are relative to workspace
- workspace_path = self.continue_sdk.ide.workspace_directory
-
- # If not adding highlighted code
- if not self._adding_highlighted_code:
- if len(self._highlighted_ranges) == 1 and len(range_in_files) <= 1 and (len(range_in_files) == 0 or range_in_files[0].range.start == range_in_files[0].range.end):
- # If un-highlighting the range to edit, then remove the range
- self._highlighted_ranges = []
- await self.update_subscribers()
- elif len(range_in_files) > 0:
- # Otherwise, replace the current range with the new one
- # This is the first range to be highlighted
- self._highlighted_ranges = [HighlightedRangeContext(
- range=range_in_files[0], editing=True, pinned=False, display_name=os.path.basename(range_in_files[0].filepath))]
- await self.update_subscribers()
- return
-
- # If current range overlaps with any others, delete them and only keep the new range
- new_ranges = []
- for i, rif in enumerate(self._highlighted_ranges):
- found_overlap = False
- for new_rif in range_in_files:
- if rif.range.filepath == new_rif.filepath and rif.range.range.overlaps_with(new_rif.range):
- found_overlap = True
- break
-
- # Also don't allow multiple ranges in same file with same content. This is useless to the model, and avoids
- # the bug where cmd+f causes repeated highlights
- if rif.range.filepath == new_rif.filepath and rif.range.contents == new_rif.contents:
- found_overlap = True
- break
-
- if not found_overlap:
- new_ranges.append(rif)
-
- self._highlighted_ranges = new_ranges + [HighlightedRangeContext(
- range=rif, editing=False, pinned=False, display_name=os.path.basename(rif.filepath)
- ) for rif in range_in_files]
-
- self._make_sure_is_editing_range()
- self._disambiguate_highlighted_ranges()
+ # Add to context manager
+ await self.context_manager.context_providers["code"].handle_highlighted_code(
+ range_in_files)
await self.update_subscribers()
@@ -250,29 +191,16 @@ class Autopilot(ContinueBaseModel):
await self.update_subscribers()
- async def delete_context_at_indices(self, indices: List[int]):
- kept_ranges = []
- for i, rif in enumerate(self._highlighted_ranges):
- if i not in indices:
- kept_ranges.append(rif)
- self._highlighted_ranges = kept_ranges
-
- self._make_sure_is_editing_range()
-
+ async def delete_context_with_ids(self, ids: List[str]):
+ await self.context_manager.delete_context_with_ids(ids)
await self.update_subscribers()
async def toggle_adding_highlighted_code(self):
- self._adding_highlighted_code = not self._adding_highlighted_code
- await self.update_subscribers()
-
- async def set_editing_at_indices(self, indices: List[int]):
- for i in range(len(self._highlighted_ranges)):
- self._highlighted_ranges[i].editing = i in indices
+ self.context_manager.context_providers["code"].adding_highlighted_code = not self.context_manager.context_providers["code"].adding_highlighted_code
await self.update_subscribers()
- async def set_pinned_at_indices(self, indices: List[int]):
- for i in range(len(self._highlighted_ranges)):
- self._highlighted_ranges[i].pinned = i in indices
+ async def set_editing_at_ids(self, ids: List[str]):
+ self.context_manager.context_providers["code"].set_editing_at_ids(ids)
await self.update_subscribers()
async def _run_singular_step(self, step: "Step", is_future_step: bool = False) -> Coroutine[Observation, None, None]:
@@ -437,10 +365,6 @@ class Autopilot(ContinueBaseModel):
if len(self._main_user_input_queue) > 1:
return
- # Remove context unless pinned
- # self._highlighted_ranges = [
- # hr for hr in self._highlighted_ranges if hr.pinned]
-
# await self._request_halt()
# Just run the step that takes user input, and
# then up to the policy to decide how to deal with it.
@@ -456,3 +380,7 @@ class Autopilot(ContinueBaseModel):
await self._request_halt()
await self.reverse_to_index(index)
await self.run_from_step(UserInputStep(user_input=user_input))
+
+ async def select_context_item(self, id: str, query: str):
+ await self.context_manager.select_context_item(id, query)
+ await self.update_subscribers()
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 54f15143..bb9ca323 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -1,6 +1,7 @@
import json
import os
from .main import Step
+from .context import ContextProvider
from pydantic import BaseModel, validator
from typing import List, Literal, Optional, Dict, Type, Union
import yaml
@@ -50,6 +51,8 @@ class ContinueConfig(BaseModel):
system_message: Optional[str] = None
azure_openai_info: Optional[AzureInfo] = None
+ context_providers: List[ContextProvider] = []
+
# Want to force these to be the slash commands for now
@validator('slash_commands', pre=True)
def default_slash_commands_validator(cls, v):
diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py
new file mode 100644
index 00000000..67bba651
--- /dev/null
+++ b/continuedev/src/continuedev/core/context.py
@@ -0,0 +1,205 @@
+
+from abc import abstractmethod
+from typing import Dict, List
+import meilisearch
+from pydantic import BaseModel
+
+
+from .main import ChatMessage, ContextItem, ContextItemDescription, ContextItemId
+from ..server.meilisearch_server import check_meilisearch_running
+
+
+SEARCH_INDEX_NAME = "continue_context_items"
+
+
+class ContextProvider(BaseModel):
+ """
+ The ContextProvider class is a plugin that lets you provide new information to the LLM by typing '@'.
+ When you type '@', the context provider will be asked to populate a list of options.
+ These options will be updated on each keystroke.
+ When you hit enter on an option, the context provider will add that item to the autopilot's list of context (which is all stored in the ContextManager object).
+ """
+
+ title: str
+
+ selected_items: List[ContextItem] = []
+
+ async def get_selected_items(self) -> List[ContextItem]:
+ """
+ Returns all of the selected ContextItems.
+
+ Default implementation simply returns self.selected_items.
+
+ Other implementations may add an async processing step.
+ """
+ return self.selected_items
+
+ @abstractmethod
+ async def provide_context_items(self) -> List[ContextItem]:
+ """
+ Provide documents for search index. This is run on startup.
+
+ This is the only method that must be implemented.
+ """
+
+ async def get_chat_messages(self) -> List[ChatMessage]:
+ """
+ Returns all of the chat messages for the context provider.
+
+ Default implementation has a string template.
+ """
+ return [ChatMessage(role="user", content=f"{item.description.name}: {item.description.description}\n\n{item.content}", summary=item.description.description) for item in await self.get_selected_items()]
+
+ async def get_item(self, id: ContextItemId, query: str, search_client: meilisearch.Client) -> ContextItem:
+ """
+ Returns the ContextItem with the given id.
+
+ Default implementation uses the search index to get the item.
+ """
+ result = search_client.index(
+ SEARCH_INDEX_NAME).get_document(id.to_string())
+ return ContextItem(
+ description=ContextItemDescription(
+ name=result.name,
+ description=result.description,
+ id=id
+ ),
+ content=result.content
+ )
+
+ async def delete_context_with_ids(self, ids: List[ContextItemId]):
+ """
+ Deletes the ContextItems with the given IDs, lets ContextProviders recalculate.
+
+ Default implementation simply deletes those with the given ids.
+ """
+ id_strings = {id.to_string() for id in ids}
+ self.selected_items = list(
+ filter(lambda item: item.description.id.to_string() not in id_strings, self.selected_items))
+
+ async def clear_context(self):
+ """
+ Clears all context.
+
+ Default implementation simply clears the selected items.
+ """
+ self.selected_items = []
+
+ async def add_context_item(self, id: ContextItemId, query: str, search_client: meilisearch.Client):
+ """
+ Adds the given ContextItem to the list of ContextItems.
+
+ Default implementation simply appends the item, not allowing duplicates.
+
+ This method also allows you not to have to load all of the information until an item is selected.
+ """
+
+ # Don't add duplicate context
+ for item in self.selected_items:
+ if item.description.id.item_id == id.item_id:
+ return
+
+ new_item = await self.get_item(id, query, search_client)
+ self.selected_items.append(new_item)
+
+
+class ContextManager:
+ """
+ The context manager is responsible for storing the context to be passed to the LLM, including
+ - ContextItems (highlighted code, GitHub Issues, etc.)
+ - ChatMessages in the history
+ - System Message
+ - Functions
+
+ It is responsible for compiling all of this information into a single prompt without exceeding the token limit.
+ """
+
+ async def get_selected_items(self) -> List[ContextItem]:
+ """
+ Returns all of the selected ContextItems.
+ """
+ return sum([await provider.get_selected_items() for provider in self.context_providers.values()], [])
+
+ async def get_chat_messages(self) -> List[ChatMessage]:
+ """
+ Returns chat messages from each provider.
+ """
+ return sum([await provider.get_chat_messages() for provider in self.context_providers.values()], [])
+
+ search_client: meilisearch.Client
+
+ def __init__(self, context_providers: List[ContextProvider]):
+ self.search_client = meilisearch.Client('http://localhost:7700')
+
+ # If meilisearch isn't running, don't use any ContextProviders that might depend on it
+ if not check_meilisearch_running():
+ context_providers = list(
+ filter(lambda cp: cp.title == "code", context_providers))
+
+ self.context_providers = {
+ prov.title: prov for prov in context_providers}
+ self.provider_titles = {
+ provider.title for provider in context_providers}
+
+ async def load_index(self):
+ for _, provider in self.context_providers.items():
+ context_items = await provider.provide_context_items()
+ documents = [
+ {
+ "id": item.description.id.to_string(),
+ "name": item.description.name,
+ "description": item.description.description,
+ "content": item.content
+ }
+ for item in context_items
+ ]
+ if len(documents) > 0:
+ self.search_client.index(
+ SEARCH_INDEX_NAME).add_documents(documents)
+
+ # def compile_chat_messages(self, max_tokens: int) -> List[Dict]:
+ # """
+ # Compiles the chat prompt into a single string.
+ # """
+ # return compile_chat_messages(self.model, self.chat_history, max_tokens, self.prompt, self.functions, self.system_message)
+
+ async def select_context_item(self, id: str, query: str):
+ """
+ Selects the ContextItem with the given id.
+ """
+ id: ContextItemId = ContextItemId.from_string(id)
+ if id.provider_title not in self.provider_titles:
+ raise ValueError(
+ f"Context provider with title {id.provider_title} not found")
+
+ await self.context_providers[id.provider_title].add_context_item(id, query, self.search_client)
+
+ async def delete_context_with_ids(self, ids: List[str]):
+ """
+ Deletes the ContextItems with the given IDs, lets ContextProviders recalculate.
+ """
+
+ # Group by provider title
+ provider_title_to_ids: Dict[str, List[ContextItemId]] = {}
+ for id in ids:
+ id: ContextItemId = ContextItemId.from_string(id)
+ if id.provider_title not in provider_title_to_ids:
+ provider_title_to_ids[id.provider_title] = []
+ provider_title_to_ids[id.provider_title].append(id)
+
+ # Recalculate context for each updated provider
+ for provider_title, ids in provider_title_to_ids.items():
+ await self.context_providers[provider_title].delete_context_with_ids(ids)
+
+ async def clear_context(self):
+ """
+ Clears all context.
+ """
+ for provider in self.context_providers.values():
+ await self.context_providers[provider.title].clear_context()
+
+
+"""
+Should define "ArgsTransformer" and "PromptTransformer" classes for the different LLMs. A standard way for them to ingest the
+same format of prompts so you don't have to redo all of this logic.
+"""
diff --git a/continuedev/src/continuedev/core/context_manager.py b/continuedev/src/continuedev/core/context_manager.py
deleted file mode 100644
index 37905535..00000000
--- a/continuedev/src/continuedev/core/context_manager.py
+++ /dev/null
@@ -1,119 +0,0 @@
-
-from abc import ABC, abstractmethod, abstractproperty
-from ast import List
-from pydantic import BaseModel
-
-from ..libs.util.count_tokens import compile_chat_messages
-
-
-class ContextItemDescription(BaseModel):
- """
- A ContextItemDescription is a description of a ContextItem that is displayed to the user when they type '@'.
-
- The id can be used to retrieve the ContextItem from the ContextManager.
- """
- name: str
- description: str
- id: str
-
-
-class ContextItem(BaseModel):
- """
- A ContextItem is a single item that is stored in the ContextManager.
- """
- description: ContextItemDescription
- content: str
-
-
-class ContextManager(ABC):
- """
- The context manager is responsible for storing the context to be passed to the LLM, including
- - ContextItems (highlighted code, GitHub Issues, etc.)
- - ChatMessages in the history
- - System Message
- - Functions
-
- It is responsible for compiling all of this information into a single prompt without exceeding the token limit.
- """
-
- def compile_chat_messages(self, max_tokens: int) -> List[Dict]:
- """
- Compiles the chat prompt into a single string.
- """
- return compile_chat_messages(self.model, self.chat_history, max_tokens, self.prompt, self.functions, self.system_message)
-
-
-"""
-Should define "ArgsTransformer" and "PromptTransformer" classes for the different LLMs. A standard way for them to ingest the
-same format of prompts so you don't have to redo all of this logic.
-"""
-
-
-class ContextProvider(ABC):
- """
- The ContextProvider class is a plugin that lets you provide new information to the LLM by typing '@'.
- When you type '@', the context provider will be asked to populate a list of options.
- These options will be updated on each keystroke.
- When you hit enter on an option, the context provider will add that item to the autopilot's list of context (which is all stored in the ContextManager object).
- """
-
- title: str
-
- @abstractmethod
- async def load(self):
- """
- Loads the ContextProvider, possibly reading persisted data from disk. This will be called on startup.
- """
-
- @abstractmethod
- async def save(self):
- """
- Saves the ContextProvider, possibly writing persisted data to disk. This will be called upon cache refresh.
- """
-
- @abstractmethod
- async def refresh_cache(self):
- """
- Refreshes the cache of items. This will be called on startup and periodically.
- """
-
- @abstractmethod
- async def get_item_descriptions(self, query: str) -> List[ContextItemDescription]:
- """
- Returns a list of options that should be displayed to the user.
- """
-
- @abstractmethod
- async def get_item(self, id: str) -> ContextItem:
- """
- Returns the ContextItem with the given id. This allows you not to have to load all of the information until an item is selected.
- """
-
- @abstractmethod
- async def should_refresh(self) -> bool:
- """
- Returns whether the ContextProvider should be refreshed.
-
- For example, embeddings might need to be recalculated after commits,
- or GitHub issues might need to be refreshed after a new issue is created.
-
- This method will be called every startup? Every once in a while? Every hour?
- User defined? Maybe just have a schedule instead of this method.
- """
-
-
-class GitHubIssuesContextProvider(ContextProvider):
- """
- The GitHubIssuesContextProvider is a ContextProvider that allows you to search GitHub issues in a repo.
- """
-
- title = "issues"
-
- def __init__(self, repo: str):
- self.repo = repo
-
- async def get_item_descriptions(self, query: str) -> List[ContextItemDescription]:
- pass
-
- async def get_item(self, id: str) -> ContextItem:
- pass
diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py
index 50d01f8d..6c6adccc 100644
--- a/continuedev/src/continuedev/core/main.py
+++ b/continuedev/src/continuedev/core/main.py
@@ -1,12 +1,11 @@
import json
-from textwrap import dedent
-from typing import Callable, Coroutine, Dict, Generator, List, Literal, Tuple, Union
+from typing import Coroutine, Dict, List, Literal, Union
+from pydantic.schema import schema
+
-from ..models.filesystem import RangeInFileWithContents
from ..models.main import ContinueBaseModel
-from pydantic import validator
+from pydantic import BaseModel, validator
from .observation import Observation
-from pydantic.schema import schema
ChatMessageRole = Literal["assistant", "user", "system", "function"]
@@ -201,12 +200,48 @@ class SlashCommandDescription(ContinueBaseModel):
description: str
-class HighlightedRangeContext(ContinueBaseModel):
- """Context for a highlighted range"""
- range: RangeInFileWithContents
- editing: bool
- pinned: bool
- display_name: str
+class ContextItemId(BaseModel):
+ """
+ A ContextItemId is a unique identifier for a ContextItem.
+ """
+ provider_title: str
+ item_id: str
+
+ def to_string(self) -> str:
+ return f"{self.provider_title}-{self.item_id}"
+
+ @staticmethod
+ def from_string(string: str) -> 'ContextItemId':
+ provider_title, item_id = string.split('-')
+ return ContextItemId(provider_title=provider_title, item_id=item_id)
+
+
+class ContextItemDescription(BaseModel):
+ """
+ A ContextItemDescription is a description of a ContextItem that is displayed to the user when they type '@'.
+
+ The id can be used to retrieve the ContextItem from the ContextManager.
+ """
+ name: str
+ description: str
+ id: ContextItemId
+
+
+class ContextItem(BaseModel):
+ """
+ A ContextItem is a single item that is stored in the ContextManager.
+ """
+ description: ContextItemDescription
+ content: str
+
+ @validator('content', pre=True)
+ def content_must_be_string(cls, v):
+ if v is None:
+ return ''
+ return v
+
+ editing: bool = False
+ editable: bool = False
class FullState(ContinueBaseModel):
@@ -215,9 +250,9 @@ class FullState(ContinueBaseModel):
active: bool
user_input_queue: List[str]
default_model: str
- highlighted_ranges: List[HighlightedRangeContext]
slash_commands: List[SlashCommandDescription]
adding_highlighted_code: bool
+ selected_context_items: List[ContextItem]
class ContinueSDK:
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 4100efa6..59f33707 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -3,8 +3,10 @@ from functools import cached_property
from typing import Coroutine, Dict, Union
import os
+
from ..steps.core.core import DefaultModelEditCodeStep
from ..models.main import Range
+from .context import ContextItem
from .abstract_sdk import AbstractContinueSDK
from .config import ContinueConfig, load_config, load_global_config, update_global_config
from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFile, AddDirectory, DeleteDirectory
@@ -289,28 +291,13 @@ class ContinueSDK(AbstractContinueSDK):
async def get_chat_context(self) -> List[ChatMessage]:
history_context = self.history.to_chat_history()
- highlighted_code = [
- hr.range for hr in self.__autopilot._highlighted_ranges]
-
- preface = "The following code is highlighted"
-
- # If no higlighted ranges, use first file as context
- if len(highlighted_code) == 0:
- preface = "The following file is open"
- visible_files = await self.ide.getVisibleFiles()
- if len(visible_files) > 0:
- content = await self.ide.readFile(visible_files[0])
- highlighted_code = [
- RangeInFileWithContents.from_entire_file(visible_files[0], content)]
-
- for rif in highlighted_code:
- msg = ChatMessage(content=f"{preface} ({rif.filepath}):\n```\n{rif.contents}\n```",
- role="user", summary=f"{preface}: {rif.filepath}")
-
- # Don't insert after latest user message or function call
- i = -1
- if len(history_context) > 0 and (history_context[i].role == "user" or history_context[i].role == "function"):
- i -= 1
+
+ context_messages: List[ChatMessage] = await self.__autopilot.context_manager.get_chat_messages()
+
+ # Insert at the end, but don't insert after latest user message or function call
+ i = -2 if (len(history_context) > 0 and (
+ history_context[-1].role == "user" or history_context[-1].role == "function")) else -1
+ for msg in context_messages:
history_context.insert(i, msg)
return history_context
diff --git a/continuedev/src/continuedev/libs/context_providers/highlighted_code_context_provider.py b/continuedev/src/continuedev/libs/context_providers/highlighted_code_context_provider.py
new file mode 100644
index 00000000..23d4fc86
--- /dev/null
+++ b/continuedev/src/continuedev/libs/context_providers/highlighted_code_context_provider.py
@@ -0,0 +1,191 @@
+import os
+from typing import Any, Dict, List
+
+import meilisearch
+from ...core.main import ChatMessage
+from ...models.filesystem import RangeInFile, RangeInFileWithContents
+from ...core.context import ContextItem, ContextItemDescription, ContextItemId
+from pydantic import BaseModel
+
+
+class HighlightedRangeContextItem(BaseModel):
+ rif: RangeInFileWithContents
+ item: ContextItem
+
+
+class HighlightedCodeContextProvider(BaseModel):
+ """
+ The ContextProvider class is a plugin that lets you provide new information to the LLM by typing '@'.
+ When you type '@', the context provider will be asked to populate a list of options.
+ These options will be updated on each keystroke.
+ When you hit enter on an option, the context provider will add that item to the autopilot's list of context (which is all stored in the ContextManager object).
+ """
+
+ title = "code"
+
+ ide: Any # IdeProtocolServer
+
+ highlighted_ranges: List[HighlightedRangeContextItem] = []
+ adding_highlighted_code: bool = False
+
+ should_get_fallback_context_item: bool = True
+ last_added_fallback: bool = False
+
+ async def _get_fallback_context_item(self) -> HighlightedRangeContextItem:
+ if not self.should_get_fallback_context_item:
+ return None
+
+ visible_files = await self.ide.getVisibleFiles()
+ if len(visible_files) > 0:
+ content = await self.ide.readFile(visible_files[0])
+ rif = RangeInFileWithContents.from_entire_file(
+ visible_files[0], content)
+
+ item = self._rif_to_context_item(rif, 0, True)
+ item.description.name = self._rif_to_name(
+ rif, show_line_nums=False)
+
+ self.last_added_fallback = True
+ return HighlightedRangeContextItem(rif=rif, item=item)
+
+ return None
+
+ async def get_selected_items(self) -> List[ContextItem]:
+ items = [hr.item for hr in self.highlighted_ranges]
+
+ if len(items) == 0 and (fallback_item := await self._get_fallback_context_item()):
+ items = [fallback_item.item]
+
+ return items
+
+ async def get_chat_messages(self) -> List[ContextItem]:
+ ranges = self.highlighted_ranges
+ if len(ranges) == 0 and (fallback_item := await self._get_fallback_context_item()):
+ ranges = [fallback_item]
+
+ return [ChatMessage(
+ role="user",
+ content=f"Code in this file is highlighted ({r.rif.filepath}):\n```\n{r.rif.contents}\n```",
+ summary=f"Code in this file is highlighted: {r.rif.filepath}"
+ ) for r in ranges]
+
+ def _make_sure_is_editing_range(self):
+ """If none of the highlighted ranges are currently being edited, the first should be selected"""
+ if len(self.highlighted_ranges) == 0:
+ return
+ if not any(map(lambda x: x.item.editing, self.highlighted_ranges)):
+ self.highlighted_ranges[0].item.editing = True
+
+ def _disambiguate_highlighted_ranges(self):
+ """If any files have the same name, also display their folder name"""
+ name_status: Dict[str, set] = {
+ } # basename -> set of full paths with that basename
+ for hr in self.highlighted_ranges:
+ basename = os.path.basename(hr.rif.filepath)
+ if basename in name_status:
+ name_status[basename].add(hr.rif.filepath)
+ else:
+ name_status[basename] = {hr.rif.filepath}
+
+ for hr in self.highlighted_ranges:
+ if len(name_status[basename]) > 1:
+ hr.item.description.name = self._rif_to_name(hr.rif, display_filename=os.path.join(
+ os.path.basename(os.path.dirname(hr.rif.filepath)), basename))
+ else:
+ hr.item.description.name = self._rif_to_name(
+ hr.rif, display_filename=basename)
+
+ async def provide_context_items(self) -> List[ContextItem]:
+ return []
+
+ async def delete_context_with_ids(self, ids: List[ContextItemId]) -> List[ContextItem]:
+ indices_to_delete = [
+ int(id.item_id) for id in ids
+ ]
+
+ kept_ranges = []
+ for i, hr in enumerate(self.highlighted_ranges):
+ if i not in indices_to_delete:
+ kept_ranges.append(hr)
+ self.highlighted_ranges = kept_ranges
+
+ self._make_sure_is_editing_range()
+
+ if len(self.highlighted_ranges) == 0 and self.last_added_fallback:
+ self.should_get_fallback_context_item = False
+
+ return [hr.item for hr in self.highlighted_ranges]
+
+ def _rif_to_name(self, rif: RangeInFileWithContents, display_filename: str = None, show_line_nums: bool = True) -> str:
+ line_nums = f" ({rif.range.start.line + 1}-{rif.range.end.line + 1})" if show_line_nums else ""
+ return f"{display_filename or os.path.basename(rif.filepath)}{line_nums}"
+
+ def _rif_to_context_item(self, rif: RangeInFileWithContents, idx: int, editing: bool) -> ContextItem:
+ return ContextItem(
+ description=ContextItemDescription(
+ name=self._rif_to_name(rif),
+ description=rif.filepath,
+ id=ContextItemId(
+ provider_title=self.title,
+ item_id=str(idx)
+ )
+ ),
+ content=rif.contents,
+ editing=editing,
+ editable=True
+ )
+
+ async def handle_highlighted_code(self, range_in_files: List[RangeInFileWithContents]):
+ self.should_get_fallback_context_item = True
+ self.last_added_fallback = False
+
+ # Filter out rifs from ~/.continue/diffs folder
+ range_in_files = [
+ rif for rif in range_in_files if not os.path.dirname(rif.filepath) == os.path.expanduser("~/.continue/diffs")]
+
+ # If not adding highlighted code
+ if not self.adding_highlighted_code:
+ if len(self.highlighted_ranges) == 1 and len(range_in_files) <= 1 and (len(range_in_files) == 0 or range_in_files[0].range.start == range_in_files[0].range.end):
+ # If un-highlighting the range to edit, then remove the range
+ self.highlighted_ranges = []
+ elif len(range_in_files) > 0:
+ # Otherwise, replace the current range with the new one
+ # This is the first range to be highlighted
+ self.highlighted_ranges = [
+ HighlightedRangeContextItem(
+ rif=range_in_files[0],
+ item=self._rif_to_context_item(range_in_files[0], 0, True))]
+
+ return
+
+ # If current range overlaps with any others, delete them and only keep the new range
+ new_ranges = []
+ for i, hr in enumerate(self.highlighted_ranges):
+ found_overlap = False
+ for new_rif in range_in_files:
+ if hr.rif.filepath == new_rif.filepath and hr.rif.range.overlaps_with(new_rif.range):
+ found_overlap = True
+ break
+
+ # Also don't allow multiple ranges in same file with same content. This is useless to the model, and avoids
+ # the bug where cmd+f causes repeated highlights
+ if hr.rif.filepath == new_rif.filepath and hr.rif.contents == new_rif.contents:
+ found_overlap = True
+ break
+
+ if not found_overlap:
+ new_ranges.append(HighlightedRangeContextItem(rif=hr.rif, item=self._rif_to_context_item(
+ hr.rif, len(new_ranges), False)))
+
+ self.highlighted_ranges = new_ranges + [HighlightedRangeContextItem(rif=rif, item=self._rif_to_context_item(
+ rif, len(new_ranges) + idx, False)) for idx, rif in enumerate(range_in_files)]
+
+ self._make_sure_is_editing_range()
+ self._disambiguate_highlighted_ranges()
+
+ async def set_editing_at_ids(self, ids: List[str]):
+ for hr in self.highlighted_ranges:
+ hr.item.editing = hr.item.description.id.to_string() in ids
+
+ async def add_context_item(self, id: ContextItemId, query: str, search_client: meilisearch.Client, prev: List[ContextItem] = None) -> List[ContextItem]:
+ raise NotImplementedError()
diff --git a/continuedev/src/continuedev/libs/util/paths.py b/continuedev/src/continuedev/libs/util/paths.py
index fddef887..d6ce13b3 100644
--- a/continuedev/src/continuedev/libs/util/paths.py
+++ b/continuedev/src/continuedev/libs/util/paths.py
@@ -2,16 +2,26 @@ import os
from ..constants.main import CONTINUE_SESSIONS_FOLDER, CONTINUE_GLOBAL_FOLDER, CONTINUE_SERVER_FOLDER
-def getGlobalFolderPath():
- return os.path.join(os.path.expanduser("~"), CONTINUE_GLOBAL_FOLDER)
+def getGlobalFolderPath():
+ path = os.path.join(os.path.expanduser("~"), CONTINUE_GLOBAL_FOLDER)
+ os.makedirs(path, exist_ok=True)
+ return path
def getSessionsFolderPath():
- return os.path.join(getGlobalFolderPath(), CONTINUE_SESSIONS_FOLDER)
+ path = os.path.join(getGlobalFolderPath(), CONTINUE_SESSIONS_FOLDER)
+ os.makedirs(path, exist_ok=True)
+ return path
+
def getServerFolderPath():
- return os.path.join(getGlobalFolderPath(), CONTINUE_SERVER_FOLDER)
+ path = os.path.join(getGlobalFolderPath(), CONTINUE_SERVER_FOLDER)
+ os.makedirs(path, exist_ok=True)
+ return path
+
def getSessionFilePath(session_id: str):
- return os.path.join(getSessionsFolderPath(), f"{session_id}.json") \ No newline at end of file
+ path = os.path.join(getSessionsFolderPath(), f"{session_id}.json")
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ return path
diff --git a/continuedev/src/continuedev/models/generate_json_schema.py b/continuedev/src/continuedev/models/generate_json_schema.py
index 6cebf429..06614984 100644
--- a/continuedev/src/continuedev/models/generate_json_schema.py
+++ b/continuedev/src/continuedev/models/generate_json_schema.py
@@ -2,6 +2,7 @@ from .main import *
from .filesystem import RangeInFile, FileEdit
from .filesystem_edit import FileEditWithFullContents
from ..core.main import History, HistoryNode, FullState
+from ..core.context import ContextItem
from pydantic import schema_json_of
import os
@@ -13,6 +14,8 @@ MODELS_TO_GENERATE = [
FileEditWithFullContents
] + [
History, HistoryNode, FullState
+] + [
+ ContextItem
]
RENAMES = {
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index ae57c0b6..36b2f3fa 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -91,25 +91,19 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
self.on_clear_history()
elif message_type == "delete_at_index":
self.on_delete_at_index(data["index"])
- elif message_type == "delete_context_at_indices":
- self.on_delete_context_at_indices(data["indices"])
+ elif message_type == "delete_context_with_ids":
+ self.on_delete_context_with_ids(data["ids"])
elif message_type == "toggle_adding_highlighted_code":
self.on_toggle_adding_highlighted_code()
elif message_type == "set_editing_at_indices":
self.on_set_editing_at_indices(data["indices"])
- elif message_type == "set_pinned_at_indices":
- self.on_set_pinned_at_indices(data["indices"])
elif message_type == "show_logs_at_index":
self.on_show_logs_at_index(data["index"])
+ elif message_type == "select_context_item":
+ self.select_context_item(data["id"], data["query"])
except Exception as e:
print(e)
- async def send_state_update(self):
- state = self.session.autopilot.get_full_state().dict()
- await self._send_json("state_update", {
- "state": state
- })
-
def on_main_input(self, input: str):
# Do something with user input
create_async_task(self.session.autopilot.accept_user_input(
@@ -144,10 +138,10 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
create_async_task(self.session.autopilot.delete_at_index(
index), self.session.autopilot.continue_sdk.ide.unique_id)
- def on_delete_context_at_indices(self, indices: List[int]):
+ def on_delete_context_with_ids(self, ids: List[str]):
create_async_task(
- self.session.autopilot.delete_context_at_indices(
- indices), self.session.autopilot.continue_sdk.ide.unique_id
+ self.session.autopilot.delete_context_with_ids(
+ ids), self.session.autopilot.continue_sdk.ide.unique_id
)
def on_toggle_adding_highlighted_code(self):
@@ -162,18 +156,17 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
indices), self.session.autopilot.continue_sdk.ide.unique_id
)
- def on_set_pinned_at_indices(self, indices: List[int]):
- create_async_task(
- self.session.autopilot.set_pinned_at_indices(
- indices), self.session.autopilot.continue_sdk.ide.unique_id
- )
-
def on_show_logs_at_index(self, index: int):
name = f"continue_logs.txt"
logs = "\n\n############################################\n\n".join(
["This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"] + self.session.autopilot.continue_sdk.history.timeline[index].logs)
create_async_task(
- self.session.autopilot.ide.showVirtualFile(name, logs))
+ self.session.autopilot.ide.showVirtualFile(name, logs), self.session.autopilot.continue_sdk.ide.unique_id)
+
+ def select_context_item(self, id: str, query: str):
+ """Called when user selects an item from the dropdown"""
+ create_async_task(
+ self.session.autopilot.select_context_item(id, query), self.session.autopilot.continue_sdk.ide.unique_id)
@router.websocket("/ws")
@@ -188,7 +181,7 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we
protocol.websocket = websocket
# Update any history that may have happened before connection
- await protocol.send_state_update()
+ await protocol.session.autopilot.update_subscribers()
while AppStatus.should_exit is False:
message = await websocket.receive_text()
@@ -214,5 +207,5 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we
if websocket.client_state != WebSocketState.DISCONNECTED:
await websocket.close()
- session_manager.persist_session(session.session_id)
+ await session_manager.persist_session(session.session_id)
session_manager.remove_session(session.session_id)
diff --git a/continuedev/src/continuedev/server/gui_protocol.py b/continuedev/src/continuedev/server/gui_protocol.py
index 9766fcd0..fb230216 100644
--- a/continuedev/src/continuedev/server/gui_protocol.py
+++ b/continuedev/src/continuedev/server/gui_protocol.py
@@ -1,6 +1,8 @@
from typing import Any, Dict, List
from abc import ABC, abstractmethod
+from ..core.context import ContextItem
+
class AbstractGUIProtocolServer(ABC):
@abstractmethod
@@ -24,10 +26,6 @@ class AbstractGUIProtocolServer(ABC):
"""Called when the user inputs a step"""
@abstractmethod
- async def send_state_update(self, state: dict):
- """Send a state update to the client"""
-
- @abstractmethod
def on_retry_at_index(self, index: int):
"""Called when the user requests a retry at a previous index"""
@@ -42,3 +40,7 @@ class AbstractGUIProtocolServer(ABC):
@abstractmethod
def on_delete_at_index(self, index: int):
"""Called when the user requests to delete a step at a given index"""
+
+ @abstractmethod
+ def select_context_item(self, id: str, query: str):
+ """Called when user selects an item from the dropdown"""
diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py
index 42dc0cc1..7ee64041 100644
--- a/continuedev/src/continuedev/server/main.py
+++ b/continuedev/src/continuedev/server/main.py
@@ -1,15 +1,20 @@
+import asyncio
+import subprocess
import time
+import meilisearch
import psutil
import os
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
-from .ide import router as ide_router
-from .gui import router as gui_router
-from .session_manager import session_manager
import atexit
import uvicorn
import argparse
+from .ide import router as ide_router
+from .gui import router as gui_router
+from .session_manager import session_manager
+from .meilisearch_server import start_meilisearch
+
app = FastAPI()
app.include_router(ide_router)
@@ -41,15 +46,20 @@ args = parser.parse_args()
# log_file = open('output.log', 'a')
# sys.stdout = log_file
-
def run_server():
uvicorn.run(app, host="0.0.0.0", port=args.port)
-def cleanup():
+async def cleanup_coroutine():
print("Cleaning up sessions")
for session_id in session_manager.sessions:
- session_manager.persist_session(session_id)
+ await session_manager.persist_session(session_id)
+
+
+def cleanup():
+ loop = asyncio.new_event_loop()
+ loop.run_until_complete(cleanup_coroutine())
+ loop.close()
def cpu_usage_report():
@@ -77,6 +87,12 @@ if __name__ == "__main__":
# cpu_thread = threading.Thread(target=cpu_usage_loop)
# cpu_thread.start()
+ try:
+ start_meilisearch()
+ except Exception as e:
+ print("Failed to start MeiliSearch")
+ print(e)
+
run_server()
except Exception as e:
cleanup()
diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py
new file mode 100644
index 00000000..419f081f
--- /dev/null
+++ b/continuedev/src/continuedev/server/meilisearch_server.py
@@ -0,0 +1,56 @@
+import os
+import subprocess
+
+import meilisearch
+from ..libs.util.paths import getServerFolderPath
+
+
+def check_meilisearch_installed() -> bool:
+ """
+ Checks if MeiliSearch is installed.
+ """
+
+ serverPath = getServerFolderPath()
+ meilisearchPath = os.path.join(serverPath, "meilisearch")
+
+ return os.path.exists(meilisearchPath)
+
+
+def check_meilisearch_running() -> bool:
+ """
+ Checks if MeiliSearch is running.
+ """
+
+ try:
+ client = meilisearch.Client('http://localhost:7700')
+ resp = client.health()
+ if resp["status"] != "available":
+ return False
+ return True
+ except Exception:
+ return False
+
+
+def start_meilisearch():
+ """
+ Starts the MeiliSearch server, wait for it.
+ """
+
+ # Doesn't work on windows for now
+ if not os.name == "posix":
+ return
+
+ serverPath = getServerFolderPath()
+
+ # Check if MeiliSearch is installed
+ if not check_meilisearch_installed():
+ # Download MeiliSearch
+ print("Downloading MeiliSearch...")
+ subprocess.run(
+ f"curl -L https://install.meilisearch.com | sh", shell=True, check=True, cwd=serverPath)
+
+ # Check if MeiliSearch is running
+ if not check_meilisearch_running():
+ print("Starting MeiliSearch...")
+ subprocess.Popen(["./meilisearch"], cwd=serverPath, stdout=subprocess.DEVNULL,
+ stderr=subprocess.STDOUT, close_fds=True, start_new_session=True)
diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py
index 90172a4e..96daf92c 100644
--- a/continuedev/src/continuedev/server/session_manager.py
+++ b/continuedev/src/continuedev/server/session_manager.py
@@ -74,7 +74,7 @@ class SessionManager:
async def on_update(state: FullState):
await session_manager.send_ws_data(session_id, "state_update", {
- "state": autopilot.get_full_state().dict()
+ "state": state.dict()
})
autopilot.on_update(on_update)
@@ -84,9 +84,9 @@ class SessionManager:
def remove_session(self, session_id: str):
del self.sessions[session_id]
- def persist_session(self, session_id: str):
+ async def persist_session(self, session_id: str):
"""Save the session's FullState as a json file"""
- full_state = self.sessions[session_id].autopilot.get_full_state()
+ full_state = await self.sessions[session_id].autopilot.get_full_state()
if not os.path.exists(getSessionsFolderPath()):
os.mkdir(getSessionsFolderPath())
with open(getSessionFilePath(session_id), "w") as f: