summaryrefslogtreecommitdiff
path: root/continuedev/src
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-08-18 15:13:02 -0700
committerNate Sesti <sestinj@gmail.com>2023-08-18 15:13:02 -0700
commitab7a90a0972188dcc7b8c28b1263c918776ca19d (patch)
tree23bcea6911a566a6bc49b12ea8860d6459a2bc08 /continuedev/src
parent36074201c626281b4e42bcca02c85a8a931c5914 (diff)
downloadsncontinue-ab7a90a0972188dcc7b8c28b1263c918776ca19d.tar.gz
sncontinue-ab7a90a0972188dcc7b8c28b1263c918776ca19d.tar.bz2
sncontinue-ab7a90a0972188dcc7b8c28b1263c918776ca19d.zip
fix: :children_crossing: don't order meilisearch results by contnet
Diffstat (limited to 'continuedev/src')
-rw-r--r--continuedev/src/continuedev/core/context.py119
1 files changed, 74 insertions, 45 deletions
diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py
index 2f250839..e5d6e13b 100644
--- a/continuedev/src/continuedev/core/context.py
+++ b/continuedev/src/continuedev/core/context.py
@@ -1,17 +1,16 @@
-
-from abc import abstractmethod
import asyncio
import time
+from abc import abstractmethod
from typing import Dict, List
+
from meilisearch_python_async import Client
from pydantic import BaseModel
-
-from .main import ChatMessage, ContextItem, ContextItemDescription, ContextItemId
-from ..server.meilisearch_server import poll_meilisearch_running
+from ..libs.util.create_async_task import create_async_task
from ..libs.util.logging import logger
from ..libs.util.telemetry import posthog_logger
-from ..libs.util.create_async_task import create_async_task
+from ..server.meilisearch_server import poll_meilisearch_running
+from .main import ChatMessage, ContextItem, ContextItemDescription, ContextItemId
SEARCH_INDEX_NAME = "continue_context_items"
@@ -52,7 +51,14 @@ class ContextProvider(BaseModel):
Default implementation has a string template.
"""
- return [ChatMessage(role="user", content=f"{item.description.name}: {item.description.description}\n\n{item.content}", summary=item.description.description) for item in await self.get_selected_items()]
+ return [
+ ChatMessage(
+ role="user",
+ content=f"{item.description.name}: {item.description.description}\n\n{item.content}",
+ summary=item.description.description,
+ )
+ for item in await self.get_selected_items()
+ ]
async def get_item(self, id: ContextItemId, query: str) -> ContextItem:
"""
@@ -60,21 +66,19 @@ class ContextProvider(BaseModel):
Default implementation uses the search index to get the item.
"""
- async with Client('http://localhost:7700') as search_client:
+ async with Client("http://localhost:7700") as search_client:
try:
- result = await search_client.index(
- SEARCH_INDEX_NAME).get_document(id.to_string())
+ result = await search_client.index(SEARCH_INDEX_NAME).get_document(
+ id.to_string()
+ )
return ContextItem(
description=ContextItemDescription(
- name=result["name"],
- description=result["description"],
- id=id
+ name=result["name"], description=result["description"], id=id
),
- content=result["content"]
+ content=result["content"],
)
except Exception as e:
- logger.warning(
- f"Error while retrieving document from meilisearch: {e}")
+ logger.warning(f"Error while retrieving document from meilisearch: {e}")
return None
@@ -86,7 +90,11 @@ class ContextProvider(BaseModel):
"""
id_strings = {id.to_string() for id in ids}
self.selected_items = list(
- filter(lambda item: item.description.id.to_string() not in id_strings, self.selected_items))
+ filter(
+ lambda item: item.description.id.to_string() not in id_strings,
+ self.selected_items,
+ )
+ )
async def clear_context(self):
"""
@@ -129,62 +137,76 @@ class ContextManager:
"""
Returns all of the selected ContextItems.
"""
- return sum([await provider.get_selected_items() for provider in self.context_providers.values()], [])
+ return sum(
+ [
+ await provider.get_selected_items()
+ for provider in self.context_providers.values()
+ ],
+ [],
+ )
async def get_chat_messages(self) -> List[ChatMessage]:
"""
Returns chat messages from each provider.
"""
- return sum([await provider.get_chat_messages() for provider in self.context_providers.values()], [])
+ return sum(
+ [
+ await provider.get_chat_messages()
+ for provider in self.context_providers.values()
+ ],
+ [],
+ )
def __init__(self):
self.context_providers = {}
self.provider_titles = set()
- async def start(self, context_providers: List[ContextProvider], workspace_directory: str):
+ async def start(
+ self, context_providers: List[ContextProvider], workspace_directory: str
+ ):
"""
Starts the context manager.
"""
# Use only non-meilisearch-dependent providers until it is loaded
self.context_providers = {
- provider.title: provider for provider in context_providers if provider.title == "code"
+ provider.title: provider
+ for provider in context_providers
+ if provider.title == "code"
}
- self.provider_titles = {
- provider.title for provider in context_providers}
+ self.provider_titles = {provider.title for provider in context_providers}
# Start MeiliSearch in the background without blocking
async def start_meilisearch(context_providers):
try:
await asyncio.wait_for(poll_meilisearch_running(), timeout=20)
self.context_providers = {
- prov.title: prov for prov in context_providers}
+ prov.title: prov for prov in context_providers
+ }
logger.debug("Loading Meilisearch index...")
await self.load_index(workspace_directory)
logger.debug("Loaded Meilisearch index")
except asyncio.TimeoutError:
logger.warning("MeiliSearch did not start within 5 seconds")
logger.warning(
- "MeiliSearch not running, avoiding any dependent context providers")
+ "MeiliSearch not running, avoiding any dependent context providers"
+ )
create_async_task(start_meilisearch(context_providers))
async def load_index(self, workspace_dir: str):
try:
- async with Client('http://localhost:7700') as search_client:
+ async with Client("http://localhost:7700") as search_client:
# First, create the index if it doesn't exist
# The index is currently shared by all workspaces
await search_client.create_index(SEARCH_INDEX_NAME)
globalSearchIndex = await search_client.get_index(SEARCH_INDEX_NAME)
- await globalSearchIndex.update_ranking_rules([
- "attribute",
- "words",
- "typo",
- "proximity",
- "sort",
- "exactness"
- ])
- await globalSearchIndex.update_filterable_attributes(
- ["workspace_dir"])
+ await globalSearchIndex.update_ranking_rules(
+ ["attribute", "words", "typo", "proximity", "sort", "exactness"]
+ )
+ await globalSearchIndex.update_searchable_attributes(
+ ["name", "description"]
+ )
+ await globalSearchIndex.update_filterable_attributes(["workspace_dir"])
for _, provider in self.context_providers.items():
ti = time.time()
@@ -201,11 +223,14 @@ class ContextManager:
for item in context_items
]
if len(documents) > 0:
- await asyncio.wait_for(globalSearchIndex.add_documents(documents), timeout=5)
+ await asyncio.wait_for(
+ globalSearchIndex.add_documents(documents), timeout=5
+ )
tf = time.time()
logger.debug(
- f"Loaded {len(documents)} documents into meilisearch in {tf - ti} seconds for context provider {provider.title}")
+ f"Loaded {len(documents)} documents into meilisearch in {tf - ti} seconds for context provider {provider.title}"
+ )
except Exception as e:
logger.debug(f"Error loading meilisearch index: {e}")
@@ -216,13 +241,17 @@ class ContextManager:
id: ContextItemId = ContextItemId.from_string(id)
if id.provider_title not in self.provider_titles:
raise ValueError(
- f"Context provider with title {id.provider_title} not found")
-
- posthog_logger.capture_event("select_context_item", {
- "provider_title": id.provider_title,
- "item_id": id.item_id,
- "query": query
- })
+ f"Context provider with title {id.provider_title} not found"
+ )
+
+ posthog_logger.capture_event(
+ "select_context_item",
+ {
+ "provider_title": id.provider_title,
+ "item_id": id.item_id,
+ "query": query,
+ },
+ )
await self.context_providers[id.provider_title].add_context_item(id, query)
async def delete_context_with_ids(self, ids: List[str]):