diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-08-18 15:13:02 -0700 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-08-18 15:13:02 -0700 |
commit | ab7a90a0972188dcc7b8c28b1263c918776ca19d (patch) | |
tree | 23bcea6911a566a6bc49b12ea8860d6459a2bc08 /continuedev/src | |
parent | 36074201c626281b4e42bcca02c85a8a931c5914 (diff) | |
download | sncontinue-ab7a90a0972188dcc7b8c28b1263c918776ca19d.tar.gz sncontinue-ab7a90a0972188dcc7b8c28b1263c918776ca19d.tar.bz2 sncontinue-ab7a90a0972188dcc7b8c28b1263c918776ca19d.zip |
fix: :children_crossing: don't order meilisearch results by content
Diffstat (limited to 'continuedev/src')
-rw-r--r-- | continuedev/src/continuedev/core/context.py | 119 |
1 files changed, 74 insertions, 45 deletions
diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py index 2f250839..e5d6e13b 100644 --- a/continuedev/src/continuedev/core/context.py +++ b/continuedev/src/continuedev/core/context.py @@ -1,17 +1,16 @@ - -from abc import abstractmethod import asyncio import time +from abc import abstractmethod from typing import Dict, List + from meilisearch_python_async import Client from pydantic import BaseModel - -from .main import ChatMessage, ContextItem, ContextItemDescription, ContextItemId -from ..server.meilisearch_server import poll_meilisearch_running +from ..libs.util.create_async_task import create_async_task from ..libs.util.logging import logger from ..libs.util.telemetry import posthog_logger -from ..libs.util.create_async_task import create_async_task +from ..server.meilisearch_server import poll_meilisearch_running +from .main import ChatMessage, ContextItem, ContextItemDescription, ContextItemId SEARCH_INDEX_NAME = "continue_context_items" @@ -52,7 +51,14 @@ class ContextProvider(BaseModel): Default implementation has a string template. """ - return [ChatMessage(role="user", content=f"{item.description.name}: {item.description.description}\n\n{item.content}", summary=item.description.description) for item in await self.get_selected_items()] + return [ + ChatMessage( + role="user", + content=f"{item.description.name}: {item.description.description}\n\n{item.content}", + summary=item.description.description, + ) + for item in await self.get_selected_items() + ] async def get_item(self, id: ContextItemId, query: str) -> ContextItem: """ @@ -60,21 +66,19 @@ class ContextProvider(BaseModel): Default implementation uses the search index to get the item. 
""" - async with Client('http://localhost:7700') as search_client: + async with Client("http://localhost:7700") as search_client: try: - result = await search_client.index( - SEARCH_INDEX_NAME).get_document(id.to_string()) + result = await search_client.index(SEARCH_INDEX_NAME).get_document( + id.to_string() + ) return ContextItem( description=ContextItemDescription( - name=result["name"], - description=result["description"], - id=id + name=result["name"], description=result["description"], id=id ), - content=result["content"] + content=result["content"], ) except Exception as e: - logger.warning( - f"Error while retrieving document from meilisearch: {e}") + logger.warning(f"Error while retrieving document from meilisearch: {e}") return None @@ -86,7 +90,11 @@ class ContextProvider(BaseModel): """ id_strings = {id.to_string() for id in ids} self.selected_items = list( - filter(lambda item: item.description.id.to_string() not in id_strings, self.selected_items)) + filter( + lambda item: item.description.id.to_string() not in id_strings, + self.selected_items, + ) + ) async def clear_context(self): """ @@ -129,62 +137,76 @@ class ContextManager: """ Returns all of the selected ContextItems. """ - return sum([await provider.get_selected_items() for provider in self.context_providers.values()], []) + return sum( + [ + await provider.get_selected_items() + for provider in self.context_providers.values() + ], + [], + ) async def get_chat_messages(self) -> List[ChatMessage]: """ Returns chat messages from each provider. 
""" - return sum([await provider.get_chat_messages() for provider in self.context_providers.values()], []) + return sum( + [ + await provider.get_chat_messages() + for provider in self.context_providers.values() + ], + [], + ) def __init__(self): self.context_providers = {} self.provider_titles = set() - async def start(self, context_providers: List[ContextProvider], workspace_directory: str): + async def start( + self, context_providers: List[ContextProvider], workspace_directory: str + ): """ Starts the context manager. """ # Use only non-meilisearch-dependent providers until it is loaded self.context_providers = { - provider.title: provider for provider in context_providers if provider.title == "code" + provider.title: provider + for provider in context_providers + if provider.title == "code" } - self.provider_titles = { - provider.title for provider in context_providers} + self.provider_titles = {provider.title for provider in context_providers} # Start MeiliSearch in the background without blocking async def start_meilisearch(context_providers): try: await asyncio.wait_for(poll_meilisearch_running(), timeout=20) self.context_providers = { - prov.title: prov for prov in context_providers} + prov.title: prov for prov in context_providers + } logger.debug("Loading Meilisearch index...") await self.load_index(workspace_directory) logger.debug("Loaded Meilisearch index") except asyncio.TimeoutError: logger.warning("MeiliSearch did not start within 5 seconds") logger.warning( - "MeiliSearch not running, avoiding any dependent context providers") + "MeiliSearch not running, avoiding any dependent context providers" + ) create_async_task(start_meilisearch(context_providers)) async def load_index(self, workspace_dir: str): try: - async with Client('http://localhost:7700') as search_client: + async with Client("http://localhost:7700") as search_client: # First, create the index if it doesn't exist # The index is currently shared by all workspaces await 
search_client.create_index(SEARCH_INDEX_NAME) globalSearchIndex = await search_client.get_index(SEARCH_INDEX_NAME) - await globalSearchIndex.update_ranking_rules([ - "attribute", - "words", - "typo", - "proximity", - "sort", - "exactness" - ]) - await globalSearchIndex.update_filterable_attributes( - ["workspace_dir"]) + await globalSearchIndex.update_ranking_rules( + ["attribute", "words", "typo", "proximity", "sort", "exactness"] + ) + await globalSearchIndex.update_searchable_attributes( + ["name", "description"] + ) + await globalSearchIndex.update_filterable_attributes(["workspace_dir"]) for _, provider in self.context_providers.items(): ti = time.time() @@ -201,11 +223,14 @@ class ContextManager: for item in context_items ] if len(documents) > 0: - await asyncio.wait_for(globalSearchIndex.add_documents(documents), timeout=5) + await asyncio.wait_for( + globalSearchIndex.add_documents(documents), timeout=5 + ) tf = time.time() logger.debug( - f"Loaded {len(documents)} documents into meilisearch in {tf - ti} seconds for context provider {provider.title}") + f"Loaded {len(documents)} documents into meilisearch in {tf - ti} seconds for context provider {provider.title}" + ) except Exception as e: logger.debug(f"Error loading meilisearch index: {e}") @@ -216,13 +241,17 @@ class ContextManager: id: ContextItemId = ContextItemId.from_string(id) if id.provider_title not in self.provider_titles: raise ValueError( - f"Context provider with title {id.provider_title} not found") - - posthog_logger.capture_event("select_context_item", { - "provider_title": id.provider_title, - "item_id": id.item_id, - "query": query - }) + f"Context provider with title {id.provider_title} not found" + ) + + posthog_logger.capture_event( + "select_context_item", + { + "provider_title": id.provider_title, + "item_id": id.item_id, + "query": query, + }, + ) await self.context_providers[id.provider_title].add_context_item(id, query) async def delete_context_with_ids(self, ids: List[str]): |