summaryrefslogtreecommitdiff
path: root/server/continuedev/core/context.py
diff options
context:
space:
mode:
authorNate Sesti <33237525+sestinj@users.noreply.github.com>2023-10-09 18:37:27 -0700
committerGitHub <noreply@github.com>2023-10-09 18:37:27 -0700
commitf09150617ed2454f3074bcf93f53aae5ae637d40 (patch)
tree5cfe614a64d921dfe58b049f426d67a8b832c71f /server/continuedev/core/context.py
parent985304a213f620cdff3f8f65f74ed7e3b79be29d (diff)
downloadsncontinue-f09150617ed2454f3074bcf93f53aae5ae637d40.tar.gz
sncontinue-f09150617ed2454f3074bcf93f53aae5ae637d40.tar.bz2
sncontinue-f09150617ed2454f3074bcf93f53aae5ae637d40.zip
Preview (#541)
* Strong typing (#533) * refactor: :recycle: get rid of continuedev.src.continuedev structure * refactor: :recycle: switching back to server folder * feat: :sparkles: make config.py imports shorter * feat: :bookmark: publish as pre-release vscode extension * refactor: :recycle: refactor and add more completion params to ui * build: :building_construction: download from preview S3 * fix: :bug: fix paths * fix: :green_heart: package:pre-release * ci: :green_heart: more time for tests * fix: :green_heart: fix build scripts * fix: :bug: fix import in run.py * fix: :bookmark: update version to try again * ci: 💚 Update package.json version [skip ci] * refactor: :fire: don't check for old extensions version * fix: :bug: small bug fixes * fix: :bug: fix config.py import paths * ci: 💚 Update package.json version [skip ci] * ci: :green_heart: platform-specific builds test #1 * feat: :green_heart: ship with binary * fix: :green_heart: fix copy statement to include.exe for windows * fix: :green_heart: cd extension before packaging * chore: :loud_sound: count tokens generated * fix: :green_heart: remove npm_config_arch * fix: :green_heart: publish as pre-release! 
* chore: :bookmark: update version * perf: :green_heart: hardcode distro paths * fix: :bug: fix yaml syntax error * chore: :bookmark: update version * fix: :green_heart: update permissions and version * feat: :bug: kill old server if needed * feat: :lipstick: update marketplace icon for pre-release * ci: 💚 Update package.json version [skip ci] * feat: :sparkles: auto-reload for config.py * feat: :wrench: update default config.py imports * feat: :sparkles: codelens in config.py * feat: :sparkles: select model param count from UI * ci: 💚 Update package.json version [skip ci] * feat: :sparkles: more model options, ollama error handling * perf: :zap: don't show server loading immediately * fix: :bug: fixing small UI details * ci: 💚 Update package.json version [skip ci] * feat: :rocket: headers param on LLM class * fix: :bug: fix headers for openai.;y * feat: :sparkles: highlight code on cmd+shift+L * ci: 💚 Update package.json version [skip ci] * feat: :lipstick: sticky top bar in gui.tsx * fix: :loud_sound: websocket logging and horizontal scrollbar * ci: 💚 Update package.json version [skip ci] * feat: :sparkles: allow AzureOpenAI Service through GGML * ci: 💚 Update package.json version [skip ci] * fix: :bug: fix automigration * ci: 💚 Update package.json version [skip ci] * ci: :green_heart: upload binaries in ci, download apple silicon * chore: :fire: remove notes * fix: :green_heart: use curl to download binary * fix: :green_heart: set permissions on apple silicon binary * fix: :green_heart: testing * fix: :green_heart: cleanup file * fix: :green_heart: fix preview.yaml * fix: :green_heart: only upload once per binary * fix: :green_heart: install rosetta * ci: :green_heart: download binary after tests * ci: 💚 Update package.json version [skip ci] * ci: :green_heart: prepare ci for merge to main --------- Co-authored-by: GitHub Action <action@github.com>
Diffstat (limited to 'server/continuedev/core/context.py')
-rw-r--r--server/continuedev/core/context.py516
1 files changed, 516 insertions, 0 deletions
diff --git a/server/continuedev/core/context.py b/server/continuedev/core/context.py
new file mode 100644
index 00000000..547a1593
--- /dev/null
+++ b/server/continuedev/core/context.py
@@ -0,0 +1,516 @@
+import asyncio
+import time
+from abc import abstractmethod
+from typing import Awaitable, Callable, Dict, List, Optional
+
+from meilisearch_python_async import Client
+from pydantic import BaseModel, Field
+
+from ..libs.util.create_async_task import create_async_task
+from ..libs.util.devdata import dev_data_logger
+from ..libs.util.logging import logger
+from ..libs.util.telemetry import posthog_logger
+from ..server.meilisearch_server import (
+ check_meilisearch_running,
+ get_meilisearch_url,
+ poll_meilisearch_running,
+ restart_meilisearch,
+ start_meilisearch,
+)
+from .main import (
+ ChatMessage,
+ ContextItem,
+ ContextItemDescription,
+ ContextItemId,
+ ContextProviderDescription,
+)
+
+
class ContinueSDK(BaseModel):
    """Forward-declared stand-in for the real ContinueSDK, used only to
    avoid a circular import; the actual instance is injected at runtime."""

    pass
+
+
# Name of the Meilisearch index holding context items.
# NOTE: the index is shared across all workspaces (see load_index below);
# documents are tagged with "workspace_dir" to tell them apart.
SEARCH_INDEX_NAME = "continue_context_items"
+
+
class ContextProvider(BaseModel):
    """
    The ContextProvider class is a plugin that lets you provide new information to the LLM by typing '@'.
    When you type '@', the context provider will be asked to populate a list of options.
    These options will be updated on each keystroke.
    When you hit enter on an option, the context provider will add that item to the autopilot's list of context (which is all stored in the ContextManager object).
    """

    title: str = Field(
        ...,
        description="The title of the ContextProvider. This is what must be typed in the input to trigger the ContextProvider.",
    )
    # sdk / delete_documents / update_documents are injected by start() at
    # runtime and stripped from serialization by the dict() override below.
    sdk: ContinueSDK = Field(
        None, description="The ContinueSDK instance accessible by the ContextProvider"
    )
    delete_documents: Callable[[List[str]], Awaitable] = Field(
        None, description="Function to delete documents"
    )
    update_documents: Callable[[List[ContextItem], str], Awaitable] = Field(
        None, description="Function to update documents"
    )

    display_title: str = Field(
        ...,
        description="The display title of the ContextProvider shown in the dropdown menu",
    )
    description: str = Field(
        ...,
        description="A description of the ContextProvider displayed in the dropdown menu",
    )
    dynamic: bool = Field(
        ..., description="Indicates whether the ContextProvider is dynamic"
    )
    requires_query: bool = Field(
        False,
        description="Indicates whether the ContextProvider requires a query. For example, the SearchContextProvider requires you to type '@search <STRING_TO_SEARCH>'. This will change the behavior of the UI so that it can indicate the expectation for a query.",
    )

    selected_items: List[ContextItem] = Field(
        [], description="List of selected items in the ContextProvider"
    )

    def dict(self, *args, **kwargs):
        """Serialize the model, omitting runtime-only fields (sdk and the
        two document callbacks) that are not JSON-serializable."""
        original_dict = super().dict(*args, **kwargs)
        original_dict.pop("sdk", None)
        original_dict.pop("delete_documents", None)
        original_dict.pop("update_documents", None)
        return original_dict

    async def start(
        self,
        sdk: ContinueSDK,
        delete_documents: Callable[[List[str]], Awaitable],
        update_documents: Callable[[List[ContextItem], str], Awaitable],
    ) -> None:
        """
        Starts the context provider.

        Default implementation sets the sdk.
        """
        self.sdk = sdk
        self.delete_documents = delete_documents
        self.update_documents = update_documents

    async def get_selected_items(self) -> List[ContextItem]:
        """
        Returns all of the selected ContextItems.

        Default implementation simply returns self.selected_items.

        Other implementations may add an async processing step.
        """
        return self.selected_items

    @abstractmethod
    async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]:
        """
        Provide documents for search index. This is run on startup.

        This is the only method that must be implemented.
        """
        # NOTE: @abstractmethod is not enforced here because the base is
        # pydantic.BaseModel, not abc.ABC — subclasses that forget to
        # override will silently return None. TODO confirm intended.

    async def get_chat_messages(self) -> List[ChatMessage]:
        """
        Returns all of the chat messages for the context provider.

        Default implementation has a string template.
        """
        return [
            ChatMessage(
                role="user",
                content=f"{item.description.name}: {item.description.description}\n\n{item.content}",
                summary=item.description.description,
            )
            for item in await self.get_selected_items()
        ]

    async def get_item(self, id: ContextItemId, query: str) -> Optional[ContextItem]:
        """
        Returns the ContextItem with the given id, or None if the lookup
        in the search index fails for any reason.

        Default implementation uses the search index to get the item.
        """
        async with Client(get_meilisearch_url()) as search_client:
            try:
                result = await search_client.index(SEARCH_INDEX_NAME).get_document(
                    id.to_string()
                )
                return ContextItem(
                    description=ContextItemDescription(
                        name=result["name"], description=result["description"], id=id
                    ),
                    content=result["content"],
                )
            except Exception as e:
                # Best-effort: a missing/unreachable index degrades to None.
                logger.warning(f"Error while retrieving document from meilisearch: {e}")

        return None

    async def delete_context_with_ids(self, ids: List[ContextItemId]) -> None:
        """
        Deletes the ContextItems with the given IDs, lets ContextProviders recalculate.

        Default implementation simply deletes those with the given ids.
        """
        # Compare by serialized id so distinct-but-equal id objects match.
        id_strings = {id.to_string() for id in ids}
        self.selected_items = list(
            filter(
                lambda item: item.description.id.to_string() not in id_strings,
                self.selected_items,
            )
        )

    async def clear_context(self) -> None:
        """
        Clears all context.

        Default implementation simply clears the selected items.
        """
        self.selected_items = []

    async def add_context_item(self, id: ContextItemId, query: str) -> None:
        """
        Adds the given ContextItem to the list of ContextItems.

        Default implementation simply appends the item, not allowing duplicates.

        This method also allows you not to have to load all of the information until an item is selected.
        """

        # Don't add duplicate context
        # (duplicates are detected by item_id only, not provider title)
        for item in self.selected_items:
            if item.description.id.item_id == id.item_id:
                return

        # get_item may return None on index failure; skip appending then.
        if new_item := await self.get_item(id, query):
            self.selected_items.append(new_item)

    async def manually_add_context_item(self, context_item: ContextItem) -> None:
        """Append an already-constructed ContextItem, skipping duplicates
        (matched by item_id), without consulting the search index."""
        for item in self.selected_items:
            if item.description.id.item_id == context_item.description.id.item_id:
                return

        self.selected_items.append(context_item)
+
+
class ContextManager:
    """
    The context manager is responsible for storing the context to be passed to the LLM, including
    - ContextItems (highlighted code, GitHub Issues, etc.)
    - ChatMessages in the history
    - System Message
    - Functions

    It is responsible for compiling all of this information into a single prompt without exceeding the token limit.
    """

    def __init__(self):
        # Maps provider title -> ContextProvider instance.
        self.context_providers = {}
        # Titles of registered providers, kept as a set for O(1) membership checks.
        self.provider_titles = set()

    def get_provider_descriptions(self) -> List[ContextProviderDescription]:
        """
        Returns a list of ContextProviderDescriptions for each context provider.

        The "code" provider is excluded — presumably it is surfaced separately
        in the UI; TODO confirm against the client.
        """
        return [
            ContextProviderDescription(
                title=provider.title,
                display_title=provider.display_title,
                description=provider.description,
                dynamic=provider.dynamic,
                requires_query=provider.requires_query,
            )
            for provider in self.context_providers.values()
            if provider.title != "code"
        ]

    async def get_selected_items(self) -> List[ContextItem]:
        """
        Returns all of the selected ContextItems, flattened across providers.
        """
        return sum(
            [
                await provider.get_selected_items()
                for provider in self.context_providers.values()
            ],
            [],
        )

    async def get_chat_messages(self) -> List[ChatMessage]:
        """
        Returns chat messages from each provider, flattened into one list.
        """
        return sum(
            [
                await provider.get_chat_messages()
                for provider in self.context_providers.values()
            ],
            [],
        )

    async def start(
        self,
        context_providers: List[ContextProvider],
        sdk: ContinueSDK,
        only_reloading: bool = False,
    ):
        """
        Starts the context manager.

        Registers and starts the given providers, then kicks off indexing of
        their context items into Meilisearch as a background task (never
        blocking startup).

        If only_reloading is True, only providers not previously registered
        are indexed; otherwise all given providers are indexed.
        """
        # Providers that were not registered before this call.
        new_context_providers = [
            provider
            for provider in context_providers
            if provider.title not in self.provider_titles
        ]

        self.context_providers = {
            provider.title: provider for provider in context_providers
        }
        self.provider_titles = {provider.title for provider in context_providers}

        for provider in context_providers:
            await provider.start(
                sdk,
                ContextManager.delete_documents,
                ContextManager.update_documents,
            )

        async def on_err(e):
            logger.warning(f"Error loading meilisearch index: {e}")

        # Start MeiliSearch in the background without blocking
        async def load_index(providers_to_load: List[ContextProvider]):
            running = await check_meilisearch_running()
            if not running:
                await start_meilisearch()
                try:
                    await asyncio.wait_for(poll_meilisearch_running(), timeout=20)
                except asyncio.TimeoutError:
                    logger.warning(
                        "Meilisearch did not start in less than 20 seconds. Stopping polling."
                    )
                    return

            logger.debug("Loading Meilisearch index...")
            await self.load_index(
                sdk.ide.workspace_directory, providers_to_load=providers_to_load
            )
            logger.debug("Loaded Meilisearch index")

        # BUG FIX: previously this was a dict in the only_reloading branch and
        # a list otherwise, and load_index called .items() on it — which
        # raises AttributeError for the list (the error was swallowed by
        # load_index's broad except). Both branches now yield a plain list.
        providers_to_load = (
            new_context_providers if only_reloading else context_providers
        )
        create_async_task(load_index(providers_to_load), on_err)

    @staticmethod
    async def update_documents(context_items: List[ContextItem], workspace_dir: str):
        """
        Updates the documents in the search index (best-effort: failures and
        timeouts are logged, never raised).
        """
        documents = [
            {
                "id": item.description.id.to_string(),
                "name": item.description.name,
                "description": item.description.description,
                "content": item.content,
                "workspace_dir": workspace_dir,
                "provider_name": item.description.id.provider_title,
            }
            for item in context_items
        ]
        async with Client(get_meilisearch_url()) as search_client:

            async def add_docs():
                index = await search_client.get_index(SEARCH_INDEX_NAME)
                await index.add_documents(documents or [])

            try:
                await asyncio.wait_for(add_docs(), timeout=20)
            except asyncio.TimeoutError:
                logger.warning("Failed to add document to meilisearch in 20 seconds")
            except Exception as e:
                logger.warning(f"Error adding document to meilisearch: {e}")

    @staticmethod
    async def delete_documents(ids: List[str]):
        """
        Deletes the documents in the search index (best-effort: failures and
        timeouts are logged, never raised).
        """
        async with Client(get_meilisearch_url()) as search_client:
            try:
                await asyncio.wait_for(
                    search_client.index(SEARCH_INDEX_NAME).delete_documents(ids),
                    timeout=20,
                )
            except asyncio.TimeoutError:
                logger.warning(
                    "Failed to delete document from meilisearch in 20 seconds"
                )
            except Exception as e:
                logger.warning(f"Error deleting document from meilisearch: {e}")

    async def load_index(
        self,
        workspace_dir: str,
        should_retry: bool = True,
        providers_to_load: Optional[List[ContextProvider]] = None,
    ):
        """
        (Re)creates the shared search index and loads each provider's context
        items into it. Each provider is loaded concurrently with a 20-second
        timeout; on overall failure, Meilisearch is restarted and loading is
        retried once (should_retry=False on the retry).

        providers_to_load: if given (non-empty), only these providers are
        indexed; otherwise all registered providers are.
        """
        try:
            async with Client(get_meilisearch_url()) as search_client:
                # First, create the index if it doesn't exist
                # The index is currently shared by all workspaces
                await search_client.create_index(SEARCH_INDEX_NAME)
                globalSearchIndex = await search_client.get_index(SEARCH_INDEX_NAME)
                await globalSearchIndex.update_ranking_rules(
                    ["attribute", "words", "typo", "proximity", "sort", "exactness"]
                )
                await globalSearchIndex.update_searchable_attributes(
                    ["name", "description"]
                )
                await globalSearchIndex.update_filterable_attributes(
                    ["workspace_dir", "provider_name"]
                )

                async def load_context_provider(provider: ContextProvider):
                    # Ask the provider for its items and push them into the
                    # index; returns the number of documents added.
                    context_items = await provider.provide_context_items(workspace_dir)
                    documents = [
                        {
                            "id": item.description.id.to_string(),
                            "name": item.description.name,
                            "description": item.description.description,
                            "content": item.content,
                            "workspace_dir": workspace_dir,
                            "provider_name": provider.title,
                        }
                        for item in context_items
                    ]
                    if len(documents) > 0:
                        await globalSearchIndex.add_documents(documents)

                    return len(documents)

                async def safe_load(provider: ContextProvider):
                    # Wrap a single provider's load so one failure/timeout
                    # cannot abort the others gathered below.
                    ti = time.time()
                    try:
                        num_documents = await asyncio.wait_for(
                            load_context_provider(provider), timeout=20
                        )
                    except asyncio.TimeoutError:
                        logger.warning(
                            f"Failed to add documents to meilisearch for context provider {provider.__class__.__name__} in 20 seconds"
                        )
                        return
                    except Exception as e:
                        logger.warning(
                            f"Error adding documents to meilisearch for context provider {provider.__class__.__name__}: {e}"
                        )
                        return

                    tf = time.time()
                    logger.debug(
                        f"Loaded {num_documents} documents into meilisearch in {tf - ti} seconds for context provider {provider.title}"
                    )

                # BUG FIX: this previously called .items() on providers_to_load,
                # which is a *list* (or None) per the signature — iterate the
                # providers directly instead. An empty/None providers_to_load
                # falls back to all registered providers, as before.
                tasks = [
                    safe_load(provider)
                    for provider in (
                        providers_to_load or list(self.context_providers.values())
                    )
                ]
                await asyncio.wait_for(asyncio.gather(*tasks), timeout=20)

        except Exception as e:
            logger.debug(f"Error loading meilisearch index: {e}")
            if should_retry:
                await restart_meilisearch()
                try:
                    await asyncio.wait_for(poll_meilisearch_running(), timeout=20)
                except asyncio.TimeoutError:
                    logger.warning(
                        "Meilisearch did not restart in less than 20 seconds. Stopping polling."
                    )
                # BUG FIX: the retry previously dropped providers_to_load,
                # re-indexing every provider; preserve the original request.
                await self.load_index(
                    workspace_dir, False, providers_to_load=providers_to_load
                )

    async def select_context_item(self, id: str, query: str):
        """
        Selects the ContextItem with the given id.

        Raises:
            ValueError: if the id's provider title is not registered.
        """
        id: ContextItemId = ContextItemId.from_string(id)
        if id.provider_title not in self.provider_titles:
            raise ValueError(
                f"Context provider with title {id.provider_title} not found"
            )

        posthog_logger.capture_event(
            "select_context_item",
            {
                "provider_title": id.provider_title,
                "item_id": id.item_id,
                "query": query,
            },
        )
        dev_data_logger.capture(
            "select_context_item",
            {
                "provider_title": id.provider_title,
                "item_id": id.item_id,
                "query": query,
            },
        )
        await self.context_providers[id.provider_title].add_context_item(id, query)

    async def get_context_item(self, id: str, query: str) -> Optional[ContextItem]:
        """
        Returns the ContextItem with the given id (None if the provider's
        lookup fails).

        Raises:
            ValueError: if the id's provider title is not registered.
        """
        id: ContextItemId = ContextItemId.from_string(id)
        if id.provider_title not in self.provider_titles:
            raise ValueError(
                f"Context provider with title {id.provider_title} not found"
            )

        return await self.context_providers[id.provider_title].get_item(id, query)

    async def delete_context_with_ids(self, ids: List[str]):
        """
        Deletes the ContextItems with the given IDs, lets ContextProviders recalculate.
        """

        # Group by provider title
        provider_title_to_ids: Dict[str, List[ContextItemId]] = {}
        for id in ids:
            id: ContextItemId = ContextItemId.from_string(id)
            provider_title_to_ids.setdefault(id.provider_title, []).append(id)

        # Recalculate context for each updated provider
        for provider_title, grouped_ids in provider_title_to_ids.items():
            await self.context_providers[provider_title].delete_context_with_ids(
                grouped_ids
            )

    async def clear_context(self):
        """
        Clears all context from every registered provider.
        """
        for provider in self.context_providers.values():
            # (was a redundant dict lookup by provider.title)
            await provider.clear_context()

    async def manually_add_context_item(self, item: ContextItem):
        """
        Adds the given ContextItem to the list of ContextItems.

        Silently ignored if the item's provider is not registered.
        """
        if item.description.id.provider_title not in self.provider_titles:
            return

        await self.context_providers[
            item.description.id.provider_title
        ].manually_add_context_item(item)
+
+
+"""
+Should define "ArgsTransformer" and "PromptTransformer" classes for the different LLMs. A standard way for them to ingest the
+same format of prompts so you don't have to redo all of this logic.
+"""