diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-07-26 12:26:32 -0700 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-07-26 12:26:32 -0700 |
commit | 9a0cd644dcb5ff46817a6ea686a6de0fb764c960 (patch) | |
tree | 8151897f0bd2c0f0c92e34a10027c25058be4b57 /continuedev | |
parent | e69837541db800643f666f6f5a9635b43511295c (diff) | |
download | sncontinue-9a0cd644dcb5ff46817a6ea686a6de0fb764c960.tar.gz sncontinue-9a0cd644dcb5ff46817a6ea686a6de0fb764c960.tar.bz2 sncontinue-9a0cd644dcb5ff46817a6ea686a6de0fb764c960.zip |
fix: :bug: async with Client (meilisearch)
Diffstat (limited to 'continuedev')
3 files changed, 32 insertions, 33 deletions
diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py index 7d302656..4a141830 100644 --- a/continuedev/src/continuedev/core/context.py +++ b/continuedev/src/continuedev/core/context.py @@ -50,22 +50,23 @@ class ContextProvider(BaseModel): """ return [ChatMessage(role="user", content=f"{item.description.name}: {item.description.description}\n\n{item.content}", summary=item.description.description) for item in await self.get_selected_items()] - async def get_item(self, id: ContextItemId, query: str, search_client: Client) -> ContextItem: + async def get_item(self, id: ContextItemId, query: str) -> ContextItem: """ Returns the ContextItem with the given id. Default implementation uses the search index to get the item. """ - result = await search_client.index( - SEARCH_INDEX_NAME).get_document(id.to_string()) - return ContextItem( - description=ContextItemDescription( - name=result["name"], - description=result["description"], - id=id - ), - content=result["content"] - ) + async with Client('http://localhost:7700') as search_client: + result = await search_client.index( + SEARCH_INDEX_NAME).get_document(id.to_string()) + return ContextItem( + description=ContextItemDescription( + name=result["name"], + description=result["description"], + id=id + ), + content=result["content"] + ) async def delete_context_with_ids(self, ids: List[ContextItemId]): """ @@ -85,7 +86,7 @@ class ContextProvider(BaseModel): """ self.selected_items = [] - async def add_context_item(self, id: ContextItemId, query: str, search_client: Client): + async def add_context_item(self, id: ContextItemId, query: str): """ Adds the given ContextItem to the list of ContextItems. @@ -99,7 +100,7 @@ class ContextProvider(BaseModel): if item.description.id.item_id == id.item_id: return - new_item = await self.get_item(id, query, search_client) + new_item = await self.get_item(id, query) self.selected_items.append(new_item) @@ -126,10 +127,7 @@ class ContextManager: """ return sum([await provider.get_chat_messages() for provider in self.context_providers.values()], []) - search_client: Client - - def __init__(self, context_providers: List[ContextProvider], search_client: Client): - self.search_client = search_client + def __init__(self, context_providers: List[ContextProvider]): self.context_providers = { prov.title: prov for prov in context_providers} self.provider_titles = { @@ -137,14 +135,15 @@ class ContextManager: @classmethod async def create(cls, context_providers: List[ContextProvider]): - search_client = Client('http://localhost:7700') - health = await search_client.health() - if not health.status == "available": - print("MeiliSearch not running, avoiding any dependent context providers") - context_providers = list( - filter(lambda cp: cp.title == "code", context_providers)) + async with Client('http://localhost:7700') as search_client: + health = await search_client.health() + if not health.status == "available": + print( + "MeiliSearch not running, avoiding any dependent context providers") + context_providers = list( + filter(lambda cp: cp.title == "code", context_providers)) - return cls(context_providers, search_client) + return cls(context_providers) async def load_index(self): for _, provider in self.context_providers.items(): @@ -159,7 +158,8 @@ class ContextManager: for item in context_items ] if len(documents) > 0: - await self.search_client.index(SEARCH_INDEX_NAME).add_documents(documents) + async with Client('http://localhost:7700') as search_client: + await search_client.index(SEARCH_INDEX_NAME).add_documents(documents) # def compile_chat_messages(self, max_tokens: int) -> List[Dict]: # """ @@ -176,7 +176,7 @@ class ContextManager: raise ValueError( f"Context provider with title {id.provider_title} not found") - await self.context_providers[id.provider_title].add_context_item(id, query, self.search_client) + await self.context_providers[id.provider_title].add_context_item(id, query) async def delete_context_with_ids(self, ids: List[str]): """ diff --git a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py index 426c0804..86c5b7ab 100644 --- a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py +++ b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py @@ -1,7 +1,6 @@ import os from typing import Any, Dict, List -from meilisearch_python_async import Client from ...core.main import ChatMessage from ...models.filesystem import RangeInFile, RangeInFileWithContents from ...core.context import ContextItem, ContextItemDescription, ContextItemId @@ -187,5 +186,5 @@ class HighlightedCodeContextProvider(BaseModel): for hr in self.highlighted_ranges: hr.item.editing = hr.item.description.id.to_string() in ids - async def add_context_item(self, id: ContextItemId, query: str, search_client: Client, prev: List[ContextItem] = None) -> List[ContextItem]: + async def add_context_item(self, id: ContextItemId, query: str, prev: List[ContextItem] = None) -> List[ContextItem]: raise NotImplementedError() diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py index 840a4b77..90714455 100644 --- a/continuedev/src/continuedev/server/meilisearch_server.py +++ b/continuedev/src/continuedev/server/meilisearch_server.py @@ -53,11 +53,11 @@ async def check_meilisearch_running() -> bool: """ try: - client = Client('http://localhost:7700') - resp = await client.health() - if resp["status"] != "available": - return False - return True + async with Client('http://localhost:7700') as client: + resp = await client.health() + if resp["status"] != "available": + return False + return True except Exception: return False |