diff options
Diffstat (limited to 'continuedev/src')
| -rw-r--r-- | continuedev/src/continuedev/core/context.py | 119 | 
1 files changed, 74 insertions, 45 deletions
diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py index 2f250839..e5d6e13b 100644 --- a/continuedev/src/continuedev/core/context.py +++ b/continuedev/src/continuedev/core/context.py @@ -1,17 +1,16 @@ - -from abc import abstractmethod  import asyncio  import time +from abc import abstractmethod  from typing import Dict, List +  from meilisearch_python_async import Client  from pydantic import BaseModel - -from .main import ChatMessage, ContextItem, ContextItemDescription, ContextItemId -from ..server.meilisearch_server import poll_meilisearch_running +from ..libs.util.create_async_task import create_async_task  from ..libs.util.logging import logger  from ..libs.util.telemetry import posthog_logger -from ..libs.util.create_async_task import create_async_task +from ..server.meilisearch_server import poll_meilisearch_running +from .main import ChatMessage, ContextItem, ContextItemDescription, ContextItemId  SEARCH_INDEX_NAME = "continue_context_items" @@ -52,7 +51,14 @@ class ContextProvider(BaseModel):          Default implementation has a string template.          """ -        return [ChatMessage(role="user", content=f"{item.description.name}: {item.description.description}\n\n{item.content}", summary=item.description.description) for item in await self.get_selected_items()] +        return [ +            ChatMessage( +                role="user", +                content=f"{item.description.name}: {item.description.description}\n\n{item.content}", +                summary=item.description.description, +            ) +            for item in await self.get_selected_items() +        ]      async def get_item(self, id: ContextItemId, query: str) -> ContextItem:          """ @@ -60,21 +66,19 @@ class ContextProvider(BaseModel):          Default implementation uses the search index to get the item.          """ -        async with Client('http://localhost:7700') as search_client: +        async with Client("http://localhost:7700") as search_client:              try: -                result = await search_client.index( -                    SEARCH_INDEX_NAME).get_document(id.to_string()) +                result = await search_client.index(SEARCH_INDEX_NAME).get_document( +                    id.to_string() +                )                  return ContextItem(                      description=ContextItemDescription( -                        name=result["name"], -                        description=result["description"], -                        id=id +                        name=result["name"], description=result["description"], id=id                      ), -                    content=result["content"] +                    content=result["content"],                  )              except Exception as e: -                logger.warning( -                    f"Error while retrieving document from meilisearch: {e}") +                logger.warning(f"Error while retrieving document from meilisearch: {e}")              return None @@ -86,7 +90,11 @@ class ContextProvider(BaseModel):          """          id_strings = {id.to_string() for id in ids}          self.selected_items = list( -            filter(lambda item: item.description.id.to_string() not in id_strings, self.selected_items)) +            filter( +                lambda item: item.description.id.to_string() not in id_strings, +                self.selected_items, +            ) +        )      async def clear_context(self):          """ @@ -129,62 +137,76 @@ class ContextManager:          """          Returns all of the selected ContextItems.          """ -        return sum([await provider.get_selected_items() for provider in self.context_providers.values()], []) +        return sum( +            [ +                await provider.get_selected_items() +                for provider in self.context_providers.values() +            ], +            [], +        )      async def get_chat_messages(self) -> List[ChatMessage]:          """          Returns chat messages from each provider.          """ -        return sum([await provider.get_chat_messages() for provider in self.context_providers.values()], []) +        return sum( +            [ +                await provider.get_chat_messages() +                for provider in self.context_providers.values() +            ], +            [], +        )      def __init__(self):          self.context_providers = {}          self.provider_titles = set() -    async def start(self, context_providers: List[ContextProvider], workspace_directory: str): +    async def start( +        self, context_providers: List[ContextProvider], workspace_directory: str +    ):          """          Starts the context manager.          """          # Use only non-meilisearch-dependent providers until it is loaded          self.context_providers = { -            provider.title: provider for provider in context_providers if provider.title == "code" +            provider.title: provider +            for provider in context_providers +            if provider.title == "code"          } -        self.provider_titles = { -            provider.title for provider in context_providers} +        self.provider_titles = {provider.title for provider in context_providers}          # Start MeiliSearch in the background without blocking          async def start_meilisearch(context_providers):              try:                  await asyncio.wait_for(poll_meilisearch_running(), timeout=20)                  self.context_providers = { -                    prov.title: prov for prov in context_providers} +                    prov.title: prov for prov in context_providers +                }                  logger.debug("Loading Meilisearch index...")                  await self.load_index(workspace_directory)                  logger.debug("Loaded Meilisearch index")              except asyncio.TimeoutError:                  logger.warning("MeiliSearch did not start within 5 seconds")                  logger.warning( -                    "MeiliSearch not running, avoiding any dependent context providers") +                    "MeiliSearch not running, avoiding any dependent context providers" +                )          create_async_task(start_meilisearch(context_providers))      async def load_index(self, workspace_dir: str):          try: -            async with Client('http://localhost:7700') as search_client: +            async with Client("http://localhost:7700") as search_client:                  # First, create the index if it doesn't exist                  # The index is currently shared by all workspaces                  await search_client.create_index(SEARCH_INDEX_NAME)                  globalSearchIndex = await search_client.get_index(SEARCH_INDEX_NAME) -                await globalSearchIndex.update_ranking_rules([ -                    "attribute", -                    "words", -                    "typo", -                    "proximity", -                    "sort", -                    "exactness" -                ]) -                await globalSearchIndex.update_filterable_attributes( -                    ["workspace_dir"]) +                await globalSearchIndex.update_ranking_rules( +                    ["attribute", "words", "typo", "proximity", "sort", "exactness"] +                ) +                await globalSearchIndex.update_searchable_attributes( +                    ["name", "description"] +                ) +                await globalSearchIndex.update_filterable_attributes(["workspace_dir"])                  for _, provider in self.context_providers.items():                      ti = time.time() @@ -201,11 +223,14 @@ class ContextManager:                          for item in context_items                      ]                      if len(documents) > 0: -                        await asyncio.wait_for(globalSearchIndex.add_documents(documents), timeout=5) +                        await asyncio.wait_for( +                            globalSearchIndex.add_documents(documents), timeout=5 +                        )                      tf = time.time()                      logger.debug( -                        f"Loaded {len(documents)} documents into meilisearch in {tf - ti} seconds for context provider {provider.title}") +                        f"Loaded {len(documents)} documents into meilisearch in {tf - ti} seconds for context provider {provider.title}" +                    )          except Exception as e:              logger.debug(f"Error loading meilisearch index: {e}") @@ -216,13 +241,17 @@ class ContextManager:          id: ContextItemId = ContextItemId.from_string(id)          if id.provider_title not in self.provider_titles:              raise ValueError( -                f"Context provider with title {id.provider_title} not found") - -        posthog_logger.capture_event("select_context_item", { -            "provider_title": id.provider_title, -            "item_id": id.item_id, -            "query": query -        }) +                f"Context provider with title {id.provider_title} not found" +            ) + +        posthog_logger.capture_event( +            "select_context_item", +            { +                "provider_title": id.provider_title, +                "item_id": id.item_id, +                "query": query, +            }, +        )          await self.context_providers[id.provider_title].add_context_item(id, query)      async def delete_context_with_ids(self, ids: List[str]):  | 
