diff options
author | Nate Sesti <33237525+sestinj@users.noreply.github.com> | 2023-10-09 18:37:27 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-09 18:37:27 -0700 |
commit | f09150617ed2454f3074bcf93f53aae5ae637d40 (patch) | |
tree | 5cfe614a64d921dfe58b049f426d67a8b832c71f /server/continuedev/core/context.py | |
parent | 985304a213f620cdff3f8f65f74ed7e3b79be29d (diff) | |
download | sncontinue-f09150617ed2454f3074bcf93f53aae5ae637d40.tar.gz sncontinue-f09150617ed2454f3074bcf93f53aae5ae637d40.tar.bz2 sncontinue-f09150617ed2454f3074bcf93f53aae5ae637d40.zip |
Preview (#541)
* Strong typing (#533)
* refactor: :recycle: get rid of continuedev.src.continuedev structure
* refactor: :recycle: switching back to server folder
* feat: :sparkles: make config.py imports shorter
* feat: :bookmark: publish as pre-release vscode extension
* refactor: :recycle: refactor and add more completion params to ui
* build: :building_construction: download from preview S3
* fix: :bug: fix paths
* fix: :green_heart: package:pre-release
* ci: :green_heart: more time for tests
* fix: :green_heart: fix build scripts
* fix: :bug: fix import in run.py
* fix: :bookmark: update version to try again
* ci: 💚 Update package.json version [skip ci]
* refactor: :fire: don't check for old extensions version
* fix: :bug: small bug fixes
* fix: :bug: fix config.py import paths
* ci: 💚 Update package.json version [skip ci]
* ci: :green_heart: platform-specific builds test #1
* feat: :green_heart: ship with binary
* fix: :green_heart: fix copy statement to include.exe for windows
* fix: :green_heart: cd extension before packaging
* chore: :loud_sound: count tokens generated
* fix: :green_heart: remove npm_config_arch
* fix: :green_heart: publish as pre-release!
* chore: :bookmark: update version
* perf: :green_heart: hardcode distro paths
* fix: :bug: fix yaml syntax error
* chore: :bookmark: update version
* fix: :green_heart: update permissions and version
* feat: :bug: kill old server if needed
* feat: :lipstick: update marketplace icon for pre-release
* ci: 💚 Update package.json version [skip ci]
* feat: :sparkles: auto-reload for config.py
* feat: :wrench: update default config.py imports
* feat: :sparkles: codelens in config.py
* feat: :sparkles: select model param count from UI
* ci: 💚 Update package.json version [skip ci]
* feat: :sparkles: more model options, ollama error handling
* perf: :zap: don't show server loading immediately
* fix: :bug: fixing small UI details
* ci: 💚 Update package.json version [skip ci]
* feat: :rocket: headers param on LLM class
* fix: :bug: fix headers for openai.;y
* feat: :sparkles: highlight code on cmd+shift+L
* ci: 💚 Update package.json version [skip ci]
* feat: :lipstick: sticky top bar in gui.tsx
* fix: :loud_sound: websocket logging and horizontal scrollbar
* ci: 💚 Update package.json version [skip ci]
* feat: :sparkles: allow AzureOpenAI Service through GGML
* ci: 💚 Update package.json version [skip ci]
* fix: :bug: fix automigration
* ci: 💚 Update package.json version [skip ci]
* ci: :green_heart: upload binaries in ci, download apple silicon
* chore: :fire: remove notes
* fix: :green_heart: use curl to download binary
* fix: :green_heart: set permissions on apple silicon binary
* fix: :green_heart: testing
* fix: :green_heart: cleanup file
* fix: :green_heart: fix preview.yaml
* fix: :green_heart: only upload once per binary
* fix: :green_heart: install rosetta
* ci: :green_heart: download binary after tests
* ci: 💚 Update package.json version [skip ci]
* ci: :green_heart: prepare ci for merge to main
---------
Co-authored-by: GitHub Action <action@github.com>
Diffstat (limited to 'server/continuedev/core/context.py')
-rw-r--r-- | server/continuedev/core/context.py | 516 |
1 files changed, 516 insertions, 0 deletions
import asyncio
import time
from abc import abstractmethod
from typing import Awaitable, Callable, Dict, List, Optional

from meilisearch_python_async import Client
from pydantic import BaseModel, Field

from ..libs.util.create_async_task import create_async_task
from ..libs.util.devdata import dev_data_logger
from ..libs.util.logging import logger
from ..libs.util.telemetry import posthog_logger
from ..server.meilisearch_server import (
    check_meilisearch_running,
    get_meilisearch_url,
    poll_meilisearch_running,
    restart_meilisearch,
    start_meilisearch,
)
from .main import (
    ChatMessage,
    ContextItem,
    ContextItemDescription,
    ContextItemId,
    ContextProviderDescription,
)


class ContinueSDK(BaseModel):
    """To avoid circular imports"""

    ...


# Name of the shared Meilisearch index holding all context items.
SEARCH_INDEX_NAME = "continue_context_items"


class ContextProvider(BaseModel):
    """
    The ContextProvider class is a plugin that lets you provide new information to the LLM by typing '@'.
    When you type '@', the context provider will be asked to populate a list of options.
    These options will be updated on each keystroke.
    When you hit enter on an option, the context provider will add that item to the autopilot's list of context (which is all stored in the ContextManager object).
    """

    title: str = Field(
        ...,
        description="The title of the ContextProvider. This is what must be typed in the input to trigger the ContextProvider.",
    )
    sdk: ContinueSDK = Field(
        None, description="The ContinueSDK instance accessible by the ContextProvider"
    )
    delete_documents: Callable[[List[str]], Awaitable] = Field(
        None, description="Function to delete documents"
    )
    update_documents: Callable[[List[ContextItem], str], Awaitable] = Field(
        None, description="Function to update documents"
    )

    display_title: str = Field(
        ...,
        description="The display title of the ContextProvider shown in the dropdown menu",
    )
    description: str = Field(
        ...,
        description="A description of the ContextProvider displayed in the dropdown menu",
    )
    dynamic: bool = Field(
        ..., description="Indicates whether the ContextProvider is dynamic"
    )
    requires_query: bool = Field(
        False,
        description="Indicates whether the ContextProvider requires a query. For example, the SearchContextProvider requires you to type '@search <STRING_TO_SEARCH>'. This will change the behavior of the UI so that it can indicate the expectation for a query.",
    )

    selected_items: List[ContextItem] = Field(
        [], description="List of selected items in the ContextProvider"
    )

    def dict(self, *args, **kwargs):
        """Serialize the provider, omitting runtime-only callables and the SDK."""
        original_dict = super().dict(*args, **kwargs)
        original_dict.pop("sdk", None)
        original_dict.pop("delete_documents", None)
        original_dict.pop("update_documents", None)
        return original_dict

    async def start(self, sdk: ContinueSDK, delete_documents, update_documents):
        """
        Starts the context provider.

        Default implementation sets the sdk.
        """
        self.sdk = sdk
        self.delete_documents = delete_documents
        self.update_documents = update_documents

    async def get_selected_items(self) -> List[ContextItem]:
        """
        Returns all of the selected ContextItems.

        Default implementation simply returns self.selected_items.

        Other implementations may add an async processing step.
        """
        return self.selected_items

    @abstractmethod
    async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]:
        """
        Provide documents for search index. This is run on startup.

        This is the only method that must be implemented.
        """

    async def get_chat_messages(self) -> List[ChatMessage]:
        """
        Returns all of the chat messages for the context provider.

        Default implementation has a string template.
        """
        return [
            ChatMessage(
                role="user",
                content=f"{item.description.name}: {item.description.description}\n\n{item.content}",
                summary=item.description.description,
            )
            for item in await self.get_selected_items()
        ]

    async def get_item(self, id: ContextItemId, query: str) -> Optional[ContextItem]:
        """
        Returns the ContextItem with the given id, or None if the document
        cannot be retrieved from the search index.

        Default implementation uses the search index to get the item.
        """
        async with Client(get_meilisearch_url()) as search_client:
            try:
                result = await search_client.index(SEARCH_INDEX_NAME).get_document(
                    id.to_string()
                )
                return ContextItem(
                    description=ContextItemDescription(
                        name=result["name"], description=result["description"], id=id
                    ),
                    content=result["content"],
                )
            except Exception as e:
                # Best-effort lookup: a missing/unreachable index should not crash selection.
                logger.warning(f"Error while retrieving document from meilisearch: {e}")

            return None

    async def delete_context_with_ids(self, ids: List[ContextItemId]):
        """
        Deletes the ContextItems with the given IDs, lets ContextProviders recalculate.

        Default implementation simply deletes those with the given ids.
        """
        id_strings = {id.to_string() for id in ids}
        self.selected_items = list(
            filter(
                lambda item: item.description.id.to_string() not in id_strings,
                self.selected_items,
            )
        )

    async def clear_context(self):
        """
        Clears all context.

        Default implementation simply clears the selected items.
        """
        self.selected_items = []

    async def add_context_item(self, id: ContextItemId, query: str):
        """
        Adds the given ContextItem to the list of ContextItems.

        Default implementation simply appends the item, not allowing duplicates.

        This method also allows you not to have to load all of the information until an item is selected.
        """

        # Don't add duplicate context
        for item in self.selected_items:
            if item.description.id.item_id == id.item_id:
                return

        if new_item := await self.get_item(id, query):
            self.selected_items.append(new_item)

    async def manually_add_context_item(self, context_item: ContextItem):
        """Append a pre-built ContextItem, skipping duplicates by item_id."""
        for item in self.selected_items:
            if item.description.id.item_id == context_item.description.id.item_id:
                return

        self.selected_items.append(context_item)


class ContextManager:
    """
    The context manager is responsible for storing the context to be passed to the LLM, including
    - ContextItems (highlighted code, GitHub Issues, etc.)
    - ChatMessages in the history
    - System Message
    - Functions

    It is responsible for compiling all of this information into a single prompt without exceeding the token limit.
    """

    def get_provider_descriptions(self) -> List[ContextProviderDescription]:
        """
        Returns a list of ContextProviderDescriptions for each context provider.
        """
        return [
            ContextProviderDescription(
                title=provider.title,
                display_title=provider.display_title,
                description=provider.description,
                dynamic=provider.dynamic,
                requires_query=provider.requires_query,
            )
            for provider in self.context_providers.values()
            if provider.title != "code"
        ]

    async def get_selected_items(self) -> List[ContextItem]:
        """
        Returns all of the selected ContextItems.
        """
        return sum(
            [
                await provider.get_selected_items()
                for provider in self.context_providers.values()
            ],
            [],
        )

    async def get_chat_messages(self) -> List[ChatMessage]:
        """
        Returns chat messages from each provider.
        """
        return sum(
            [
                await provider.get_chat_messages()
                for provider in self.context_providers.values()
            ],
            [],
        )

    def __init__(self):
        # Maps provider title -> ContextProvider instance.
        self.context_providers: Dict[str, ContextProvider] = {}
        self.provider_titles = set()

    async def start(
        self,
        context_providers: List[ContextProvider],
        sdk: ContinueSDK,
        only_reloading: bool = False,
    ):
        """
        Starts the context manager.

        Registers the given providers, then kicks off a background task that
        ensures Meilisearch is running and loads each provider's documents
        into the search index. When only_reloading is True, only providers
        that were not previously registered are (re)indexed.
        """
        new_context_providers = {
            provider.title: provider
            for provider in context_providers
            if provider.title not in self.provider_titles
        }

        self.context_providers = {
            provider.title: provider for provider in context_providers
        }
        self.provider_titles = {provider.title for provider in context_providers}

        for provider in context_providers:
            await provider.start(
                sdk,
                ContextManager.delete_documents,
                ContextManager.update_documents,
            )

        async def on_err(e):
            logger.warning(f"Error loading meilisearch index: {e}")

        # Start MeiliSearch in the background without blocking
        async def load_index(providers_to_load: Dict[str, ContextProvider]):
            running = await check_meilisearch_running()
            if not running:
                await start_meilisearch()
                try:
                    await asyncio.wait_for(poll_meilisearch_running(), timeout=20)
                except asyncio.TimeoutError:
                    logger.warning(
                        "Meilisearch did not start in less than 20 seconds. Stopping polling."
                    )
                    return

            logger.debug("Loading Meilisearch index...")
            await self.load_index(
                sdk.ide.workspace_directory, providers_to_load=providers_to_load
            )
            logger.debug("Loaded Meilisearch index")

        # BUG FIX: load_index iterates providers_to_load with .items(), so it must
        # always receive a dict. Previously the non-reloading branch passed the raw
        # list `context_providers`, raising AttributeError and aborting indexing.
        providers_to_load = (
            new_context_providers if only_reloading else self.context_providers
        )
        create_async_task(load_index(providers_to_load), on_err)

    @staticmethod
    async def update_documents(context_items: List[ContextItem], workspace_dir: str):
        """
        Updates the documents in the search index.
        """
        documents = [
            {
                "id": item.description.id.to_string(),
                "name": item.description.name,
                "description": item.description.description,
                "content": item.content,
                "workspace_dir": workspace_dir,
                "provider_name": item.description.id.provider_title,
            }
            for item in context_items
        ]
        async with Client(get_meilisearch_url()) as search_client:

            async def add_docs():
                index = await search_client.get_index(SEARCH_INDEX_NAME)
                await index.add_documents(documents or [])

            try:
                await asyncio.wait_for(add_docs(), timeout=20)
            except asyncio.TimeoutError:
                logger.warning("Failed to add document to meilisearch in 20 seconds")
            except Exception as e:
                logger.warning(f"Error adding document to meilisearch: {e}")

    @staticmethod
    async def delete_documents(ids):
        """
        Deletes the documents in the search index.
        """
        async with Client(get_meilisearch_url()) as search_client:
            try:
                await asyncio.wait_for(
                    search_client.index(SEARCH_INDEX_NAME).delete_documents(ids),
                    timeout=20,
                )
            except asyncio.TimeoutError:
                logger.warning(
                    "Failed to delete document from meilisearch in 20 seconds"
                )
            except Exception as e:
                logger.warning(f"Error deleting document from meilisearch: {e}")

    async def load_index(
        self,
        workspace_dir: str,
        should_retry: bool = True,
        providers_to_load: Optional[Dict[str, ContextProvider]] = None,
    ):
        """
        Indexes the context items of the given providers (or all registered
        providers) into Meilisearch. On failure, restarts Meilisearch and
        retries once when should_retry is True.
        """
        try:
            async with Client(get_meilisearch_url()) as search_client:
                # First, create the index if it doesn't exist
                # The index is currently shared by all workspaces
                await search_client.create_index(SEARCH_INDEX_NAME)
                globalSearchIndex = await search_client.get_index(SEARCH_INDEX_NAME)
                await globalSearchIndex.update_ranking_rules(
                    ["attribute", "words", "typo", "proximity", "sort", "exactness"]
                )
                await globalSearchIndex.update_searchable_attributes(
                    ["name", "description"]
                )
                await globalSearchIndex.update_filterable_attributes(
                    ["workspace_dir", "provider_name"]
                )

                async def load_context_provider(provider: ContextProvider):
                    # Pull all items from the provider and push them as documents.
                    context_items = await provider.provide_context_items(workspace_dir)
                    documents = [
                        {
                            "id": item.description.id.to_string(),
                            "name": item.description.name,
                            "description": item.description.description,
                            "content": item.content,
                            "workspace_dir": workspace_dir,
                            "provider_name": provider.title,
                        }
                        for item in context_items
                    ]
                    if len(documents) > 0:
                        await globalSearchIndex.add_documents(documents)

                    return len(documents)

                async def safe_load(provider: ContextProvider):
                    # Wrap each provider load so one slow/broken provider
                    # cannot take down the others.
                    ti = time.time()
                    try:
                        num_documents = await asyncio.wait_for(
                            load_context_provider(provider), timeout=20
                        )
                    except asyncio.TimeoutError:
                        logger.warning(
                            f"Failed to add documents to meilisearch for context provider {provider.__class__.__name__} in 20 seconds"
                        )
                        return
                    except Exception as e:
                        logger.warning(
                            f"Error adding documents to meilisearch for context provider {provider.__class__.__name__}: {e}"
                        )
                        return

                    tf = time.time()
                    logger.debug(
                        f"Loaded {num_documents} documents into meilisearch in {tf - ti} seconds for context provider {provider.title}"
                    )

                tasks = [
                    safe_load(provider)
                    for _, provider in (
                        providers_to_load or self.context_providers
                    ).items()
                ]
                await asyncio.wait_for(asyncio.gather(*tasks), timeout=20)

        except Exception as e:
            logger.debug(f"Error loading meilisearch index: {e}")
            if should_retry:
                await restart_meilisearch()
                try:
                    await asyncio.wait_for(poll_meilisearch_running(), timeout=20)
                except asyncio.TimeoutError:
                    logger.warning(
                        "Meilisearch did not restart in less than 20 seconds. Stopping polling."
                    )
                await self.load_index(workspace_dir, False)

    async def select_context_item(self, id: str, query: str):
        """
        Selects the ContextItem with the given id.
        """
        id: ContextItemId = ContextItemId.from_string(id)
        if id.provider_title not in self.provider_titles:
            raise ValueError(
                f"Context provider with title {id.provider_title} not found"
            )

        posthog_logger.capture_event(
            "select_context_item",
            {
                "provider_title": id.provider_title,
                "item_id": id.item_id,
                "query": query,
            },
        )
        dev_data_logger.capture(
            "select_context_item",
            {
                "provider_title": id.provider_title,
                "item_id": id.item_id,
                "query": query,
            },
        )
        await self.context_providers[id.provider_title].add_context_item(id, query)

    async def get_context_item(self, id: str, query: str) -> ContextItem:
        """
        Returns the ContextItem with the given id.
        """
        id: ContextItemId = ContextItemId.from_string(id)
        if id.provider_title not in self.provider_titles:
            raise ValueError(
                f"Context provider with title {id.provider_title} not found"
            )

        return await self.context_providers[id.provider_title].get_item(id, query)

    async def delete_context_with_ids(self, ids: List[str]):
        """
        Deletes the ContextItems with the given IDs, lets ContextProviders recalculate.
        """

        # Group by provider title
        provider_title_to_ids: Dict[str, List[ContextItemId]] = {}
        for id in ids:
            id: ContextItemId = ContextItemId.from_string(id)
            if id.provider_title not in provider_title_to_ids:
                provider_title_to_ids[id.provider_title] = []
            provider_title_to_ids[id.provider_title].append(id)

        # Recalculate context for each updated provider
        for provider_title, ids in provider_title_to_ids.items():
            await self.context_providers[provider_title].delete_context_with_ids(ids)

    async def clear_context(self):
        """
        Clears all context.
        """
        for provider in self.context_providers.values():
            # Iterating values already yields the provider; no need for a
            # second dict lookup by title.
            await provider.clear_context()

    async def manually_add_context_item(self, item: ContextItem):
        """
        Adds the given ContextItem to the list of ContextItems.
        """
        if item.description.id.provider_title not in self.provider_titles:
            return

        await self.context_providers[
            item.description.id.provider_title
        ].manually_add_context_item(item)


"""
Should define "ArgsTransformer" and "PromptTransformer" classes for the different LLMs. A standard way for them to ingest the
same format of prompts so you don't have to redo all of this logic.
"""