diff options
| author | Nate Sesti <sestinj@gmail.com> | 2023-07-25 23:09:21 -0700 |
|---|---|---|
| committer | Nate Sesti <sestinj@gmail.com> | 2023-07-25 23:09:21 -0700 |
| commit | d2afe5ff258eb2443e0f2178da523150fdb5eb0d (patch) | |
| tree | 181be2609e924d084a8ad9781283cb725377b87f /continuedev/src/continuedev/core | |
| parent | 861a873f7ecf455b9c7833040b2a8163e369e062 (diff) | |
| download | sncontinue-d2afe5ff258eb2443e0f2178da523150fdb5eb0d.tar.gz sncontinue-d2afe5ff258eb2443e0f2178da523150fdb5eb0d.tar.bz2 sncontinue-d2afe5ff258eb2443e0f2178da523150fdb5eb0d.zip | |
meilisearch async client
Diffstat (limited to 'continuedev/src/continuedev/core')
| -rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 2 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/context.py | 40 |
2 files changed, 23 insertions, 19 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index d018c29e..42a58423 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -69,7 +69,7 @@ class Autopilot(ContinueBaseModel): autopilot.continue_sdk = await ContinueSDK.create(autopilot) # Load documents into the search index - autopilot.context_manager = ContextManager( + autopilot.context_manager = await ContextManager.create( autopilot.continue_sdk.config.context_providers + [ HighlightedCodeContextProvider(ide=ide), FileContextProvider(workspace_dir=ide.workspace_directory) diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py index 67bba651..7d302656 100644 --- a/continuedev/src/continuedev/core/context.py +++ b/continuedev/src/continuedev/core/context.py @@ -1,7 +1,7 @@ from abc import abstractmethod from typing import Dict, List -import meilisearch +from meilisearch_python_async import Client from pydantic import BaseModel @@ -50,21 +50,21 @@ class ContextProvider(BaseModel): """ return [ChatMessage(role="user", content=f"{item.description.name}: {item.description.description}\n\n{item.content}", summary=item.description.description) for item in await self.get_selected_items()] - async def get_item(self, id: ContextItemId, query: str, search_client: meilisearch.Client) -> ContextItem: + async def get_item(self, id: ContextItemId, query: str, search_client: Client) -> ContextItem: """ Returns the ContextItem with the given id. Default implementation uses the search index to get the item. """ - result = search_client.index( + result = await search_client.index( SEARCH_INDEX_NAME).get_document(id.to_string()) return ContextItem( description=ContextItemDescription( - name=result.name, - description=result.description, + name=result["name"], + description=result["description"], id=id ), - content=result.content + content=result["content"] ) async def delete_context_with_ids(self, ids: List[ContextItemId]): @@ -85,7 +85,7 @@ class ContextProvider(BaseModel): """ self.selected_items = [] - async def add_context_item(self, id: ContextItemId, query: str, search_client: meilisearch.Client): + async def add_context_item(self, id: ContextItemId, query: str, search_client: Client): """ Adds the given ContextItem to the list of ContextItems. @@ -126,21 +126,26 @@ class ContextManager: """ return sum([await provider.get_chat_messages() for provider in self.context_providers.values()], []) - search_client: meilisearch.Client - - def __init__(self, context_providers: List[ContextProvider]): - self.search_client = meilisearch.Client('http://localhost:7700') - - # If meilisearch isn't running, don't use any ContextProviders that might depend on it - if not check_meilisearch_running(): - context_providers = list( - filter(lambda cp: cp.title == "code", context_providers)) + search_client: Client + def __init__(self, context_providers: List[ContextProvider], search_client: Client): + self.search_client = search_client self.context_providers = { prov.title: prov for prov in context_providers} self.provider_titles = { provider.title for provider in context_providers} + @classmethod + async def create(cls, context_providers: List[ContextProvider]): + search_client = Client('http://localhost:7700') + health = await search_client.health() + if not health.status == "available": + print("MeiliSearch not running, avoiding any dependent context providers") + context_providers = list( + filter(lambda cp: cp.title == "code", context_providers)) + + return cls(context_providers, search_client) + async def load_index(self): for _, provider in self.context_providers.items(): context_items = await provider.provide_context_items() @@ -154,8 +159,7 @@ class ContextManager: for item in context_items ] if len(documents) > 0: - self.search_client.index( - SEARCH_INDEX_NAME).add_documents(documents) + await self.search_client.index(SEARCH_INDEX_NAME).add_documents(documents) # def compile_chat_messages(self, max_tokens: int) -> List[Dict]: # """ |
