summaryrefslogtreecommitdiff
path: root/continuedev
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-07-26 12:26:32 -0700
committerNate Sesti <sestinj@gmail.com>2023-07-26 12:26:32 -0700
commit9a0cd644dcb5ff46817a6ea686a6de0fb764c960 (patch)
tree8151897f0bd2c0f0c92e34a10027c25058be4b57 /continuedev
parente69837541db800643f666f6f5a9635b43511295c (diff)
downloadsncontinue-9a0cd644dcb5ff46817a6ea686a6de0fb764c960.tar.gz
sncontinue-9a0cd644dcb5ff46817a6ea686a6de0fb764c960.tar.bz2
sncontinue-9a0cd644dcb5ff46817a6ea686a6de0fb764c960.zip
fix: :bug: async with Client (meilisearch)
Diffstat (limited to 'continuedev')
-rw-r--r--continuedev/src/continuedev/core/context.py52
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/highlighted_code.py3
-rw-r--r--continuedev/src/continuedev/server/meilisearch_server.py10
3 files changed, 32 insertions, 33 deletions
diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py
index 7d302656..4a141830 100644
--- a/continuedev/src/continuedev/core/context.py
+++ b/continuedev/src/continuedev/core/context.py
@@ -50,22 +50,23 @@ class ContextProvider(BaseModel):
"""
return [ChatMessage(role="user", content=f"{item.description.name}: {item.description.description}\n\n{item.content}", summary=item.description.description) for item in await self.get_selected_items()]
- async def get_item(self, id: ContextItemId, query: str, search_client: Client) -> ContextItem:
+ async def get_item(self, id: ContextItemId, query: str) -> ContextItem:
"""
Returns the ContextItem with the given id.
Default implementation uses the search index to get the item.
"""
- result = await search_client.index(
- SEARCH_INDEX_NAME).get_document(id.to_string())
- return ContextItem(
- description=ContextItemDescription(
- name=result["name"],
- description=result["description"],
- id=id
- ),
- content=result["content"]
- )
+ async with Client('http://localhost:7700') as search_client:
+ result = await search_client.index(
+ SEARCH_INDEX_NAME).get_document(id.to_string())
+ return ContextItem(
+ description=ContextItemDescription(
+ name=result["name"],
+ description=result["description"],
+ id=id
+ ),
+ content=result["content"]
+ )
async def delete_context_with_ids(self, ids: List[ContextItemId]):
"""
@@ -85,7 +86,7 @@ class ContextProvider(BaseModel):
"""
self.selected_items = []
- async def add_context_item(self, id: ContextItemId, query: str, search_client: Client):
+ async def add_context_item(self, id: ContextItemId, query: str):
"""
Adds the given ContextItem to the list of ContextItems.
@@ -99,7 +100,7 @@ class ContextProvider(BaseModel):
if item.description.id.item_id == id.item_id:
return
- new_item = await self.get_item(id, query, search_client)
+ new_item = await self.get_item(id, query)
self.selected_items.append(new_item)
@@ -126,10 +127,7 @@ class ContextManager:
"""
return sum([await provider.get_chat_messages() for provider in self.context_providers.values()], [])
- search_client: Client
-
- def __init__(self, context_providers: List[ContextProvider], search_client: Client):
- self.search_client = search_client
+ def __init__(self, context_providers: List[ContextProvider]):
self.context_providers = {
prov.title: prov for prov in context_providers}
self.provider_titles = {
@@ -137,14 +135,15 @@ class ContextManager:
@classmethod
async def create(cls, context_providers: List[ContextProvider]):
- search_client = Client('http://localhost:7700')
- health = await search_client.health()
- if not health.status == "available":
- print("MeiliSearch not running, avoiding any dependent context providers")
- context_providers = list(
- filter(lambda cp: cp.title == "code", context_providers))
+ async with Client('http://localhost:7700') as search_client:
+ health = await search_client.health()
+ if not health.status == "available":
+ print(
+ "MeiliSearch not running, avoiding any dependent context providers")
+ context_providers = list(
+ filter(lambda cp: cp.title == "code", context_providers))
- return cls(context_providers, search_client)
+ return cls(context_providers)
async def load_index(self):
for _, provider in self.context_providers.items():
@@ -159,7 +158,8 @@ class ContextManager:
for item in context_items
]
if len(documents) > 0:
- await self.search_client.index(SEARCH_INDEX_NAME).add_documents(documents)
+ async with Client('http://localhost:7700') as search_client:
+ await search_client.index(SEARCH_INDEX_NAME).add_documents(documents)
# def compile_chat_messages(self, max_tokens: int) -> List[Dict]:
# """
@@ -176,7 +176,7 @@ class ContextManager:
raise ValueError(
f"Context provider with title {id.provider_title} not found")
- await self.context_providers[id.provider_title].add_context_item(id, query, self.search_client)
+ await self.context_providers[id.provider_title].add_context_item(id, query)
async def delete_context_with_ids(self, ids: List[str]):
"""
diff --git a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py
index 426c0804..86c5b7ab 100644
--- a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py
+++ b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py
@@ -1,7 +1,6 @@
import os
from typing import Any, Dict, List
-from meilisearch_python_async import Client
from ...core.main import ChatMessage
from ...models.filesystem import RangeInFile, RangeInFileWithContents
from ...core.context import ContextItem, ContextItemDescription, ContextItemId
@@ -187,5 +186,5 @@ class HighlightedCodeContextProvider(BaseModel):
for hr in self.highlighted_ranges:
hr.item.editing = hr.item.description.id.to_string() in ids
- async def add_context_item(self, id: ContextItemId, query: str, search_client: Client, prev: List[ContextItem] = None) -> List[ContextItem]:
+ async def add_context_item(self, id: ContextItemId, query: str, prev: List[ContextItem] = None) -> List[ContextItem]:
raise NotImplementedError()
diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py
index 840a4b77..90714455 100644
--- a/continuedev/src/continuedev/server/meilisearch_server.py
+++ b/continuedev/src/continuedev/server/meilisearch_server.py
@@ -53,11 +53,11 @@ async def check_meilisearch_running() -> bool:
"""
try:
- client = Client('http://localhost:7700')
- resp = await client.health()
- if resp["status"] != "available":
- return False
- return True
+ async with Client('http://localhost:7700') as client:
+ resp = await client.health()
+ if resp["status"] != "available":
+ return False
+ return True
except Exception:
return False