diff options
Diffstat (limited to 'continuedev')
-rw-r--r-- | continuedev/pyproject.toml | 2 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 15 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/context.py | 48 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/llm/ggml.py | 5 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/llm/replicate.py | 2 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/llm/together.py | 122 | ||||
-rw-r--r-- | continuedev/src/continuedev/plugins/context_providers/file.py | 22 | ||||
-rw-r--r-- | continuedev/src/continuedev/plugins/steps/help.py | 3 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/meilisearch_server.py | 16 |
9 files changed, 197 insertions, 38 deletions
diff --git a/continuedev/pyproject.toml b/continuedev/pyproject.toml index 49b3c5ed..90ff0572 100644 --- a/continuedev/pyproject.toml +++ b/continuedev/pyproject.toml @@ -9,7 +9,7 @@ readme = "README.md" python = "^3.8.1" fastapi = "^0.95.1" typer = "^0.7.0" -openai = "^0.27.8" +openai = "^0.27.5" boltons = "^23.0.0" pydantic = "^1.10.7" uvicorn = "^0.21.1" diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 256f3439..9100c34e 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -81,11 +81,7 @@ class Autopilot(ContinueBaseModel): self.continue_sdk.config.context_providers + [ HighlightedCodeContextProvider(ide=self.ide), FileContextProvider(workspace_dir=self.ide.workspace_directory) - ]) - - logger.debug("Loading index") - create_async_task(self.context_manager.load_index( - self.ide.workspace_directory)) + ], self.ide.workspace_directory) if full_state is not None: self.history = full_state.history @@ -188,6 +184,9 @@ class Autopilot(ContinueBaseModel): await self._run_singular_step(step) async def handle_highlighted_code(self, range_in_files: List[RangeInFileWithContents]): + if "code" not in self.context_manager.context_providers: + return + # Add to context manager await self.context_manager.context_providers["code"].handle_highlighted_code( range_in_files) @@ -212,10 +211,16 @@ class Autopilot(ContinueBaseModel): await self.update_subscribers() async def toggle_adding_highlighted_code(self): + if "code" not in self.context_manager.context_providers: + return + self.context_manager.context_providers["code"].adding_highlighted_code = not self.context_manager.context_providers["code"].adding_highlighted_code await self.update_subscribers() async def set_editing_at_ids(self, ids: List[str]): + if "code" not in self.context_manager.context_providers: + return + await self.context_manager.context_providers["code"].set_editing_at_ids(ids) await self.update_subscribers() diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py index b1f68b50..db1c770a 100644 --- a/continuedev/src/continuedev/core/context.py +++ b/continuedev/src/continuedev/core/context.py @@ -8,9 +8,10 @@ from pydantic import BaseModel from .main import ChatMessage, ContextItem, ContextItemDescription, ContextItemId -from ..server.meilisearch_server import check_meilisearch_running +from ..server.meilisearch_server import poll_meilisearch_running from ..libs.util.logging import logger from ..libs.util.telemetry import posthog_logger +from ..libs.util.create_async_task import create_async_task SEARCH_INDEX_NAME = "continue_context_items" @@ -140,31 +141,32 @@ class ContextManager: self.context_providers = {} self.provider_titles = set() - async def start(self, context_providers: List[ContextProvider]): + async def start(self, context_providers: List[ContextProvider], workspace_directory: str): """ Starts the context manager. """ + # Use only non-meilisearch-dependent providers until it is loaded self.context_providers = { - prov.title: prov for prov in context_providers} + title: provider for title, provider in self.context_providers.items() if title == "code" + } self.provider_titles = { provider.title for provider in context_providers} - async with Client('http://localhost:7700') as search_client: - meilisearch_running = True + # Start MeiliSearch in the background without blocking + async def start_meilisearch(context_providers): try: - - health = await search_client.health() - if not health.status == "available": - meilisearch_running = False - except: - meilisearch_running = False - - if not meilisearch_running: + await asyncio.wait_for(poll_meilisearch_running(), timeout=20) + self.context_providers = { + prov.title: prov for prov in context_providers} + logger.debug("Loading Meilisearch index...") + await self.load_index(workspace_directory) + logger.debug("Loaded Meilisearch index") + except asyncio.TimeoutError: + logger.warning("MeiliSearch did not start within 5 seconds") logger.warning( "MeiliSearch not running, avoiding any dependent context providers") - self.context_providers = { - title: provider for title, provider in self.context_providers.items() if title == "code" - } + + create_async_task(start_meilisearch(context_providers)) async def load_index(self, workspace_dir: str): for _, provider in self.context_providers.items(): @@ -176,14 +178,24 @@ class ContextManager: "id": item.description.id.to_string(), "name": item.description.name, "description": item.description.description, - "content": item.content + "content": item.content, + "workspace_dir": workspace_dir, } for item in context_items ] if len(documents) > 0: try: async with Client('http://localhost:7700') as search_client: - await asyncio.wait_for(search_client.index(SEARCH_INDEX_NAME).add_documents(documents), timeout=5) + # First, create the index if it doesn't exist + await search_client.create_index(SEARCH_INDEX_NAME) + # The index is currently shared by all workspaces + globalSearchIndex = await search_client.get_index(SEARCH_INDEX_NAME) + await asyncio.wait_for(asyncio.gather( + # Ensure that the index has the correct filterable attributes + globalSearchIndex.update_filterable_attributes( + ["workspace_dir"]), + globalSearchIndex.add_documents(documents) + ), timeout=5) except Exception as e: logger.debug(f"Error loading meilisearch index: {e}") diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py index 2f131354..25a61e63 100644 --- a/continuedev/src/continuedev/libs/llm/ggml.py +++ b/continuedev/src/continuedev/libs/llm/ggml.py @@ -82,7 +82,10 @@ class GGML(LLM): chunks = json_chunk.split("\n") for chunk in chunks: if chunk.strip() != "": - yield json.loads(chunk[6:])["choices"][0]["delta"] + yield { + "role": "assistant", + "content": json.loads(chunk[6:])["choices"][0]["delta"] + } except: raise Exception(str(line[0])) diff --git a/continuedev/src/continuedev/libs/llm/replicate.py b/continuedev/src/continuedev/libs/llm/replicate.py index 235fd906..0dd359e7 100644 --- a/continuedev/src/continuedev/libs/llm/replicate.py +++ b/continuedev/src/continuedev/libs/llm/replicate.py @@ -25,7 +25,7 @@ class ReplicateLLM(LLM): @property def default_args(self): - return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024} + return {**DEFAULT_ARGS, "model": self.model, "max_tokens": 1024} def count_tokens(self, text: str): return count_tokens(self.name, text) diff --git a/continuedev/src/continuedev/libs/llm/together.py b/continuedev/src/continuedev/libs/llm/together.py new file mode 100644 index 00000000..c3f171c9 --- /dev/null +++ b/continuedev/src/continuedev/libs/llm/together.py @@ -0,0 +1,122 @@ +import json +from typing import Any, Coroutine, Dict, Generator, List, Union + +import aiohttp +from ...core.main import ChatMessage +from ..llm import LLM +from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens + + +class TogetherLLM(LLM): + # this is model-specific + api_key: str + model: str = "togethercomputer/RedPajama-INCITE-7B-Instruct" + max_context_length: int = 2048 + base_url: str = "https://api.together.xyz" + verify_ssl: bool = True + + _client_session: aiohttp.ClientSession = None + + async def start(self, **kwargs): + self._client_session = aiohttp.ClientSession( + connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)) + + async def stop(self): + await self._client_session.close() + + @property + def name(self): + return self.model + + @property + def context_length(self): + return self.max_context_length + + @property + def default_args(self): + return {**DEFAULT_ARGS, "model": self.model, "max_tokens": 1024} + + def count_tokens(self, text: str): + return count_tokens(self.name, text) + + def convert_to_prompt(self, chat_messages: List[ChatMessage]) -> str: + system_message = None + if chat_messages[0]["role"] == "system": + system_message = chat_messages.pop(0)["content"] + + prompt = "\n" + if system_message: + prompt += f"<human>: Hi!\n<bot>: {system_message}\n" + for message in chat_messages: + prompt += f'<{"human" if message["role"] == "user" else "bot"}>: {message["content"]}\n' + + prompt += "<bot>:" + return prompt + + async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: + args = self.default_args.copy() + args.update(kwargs) + args["stream_tokens"] = True + + args = {**self.default_args, **kwargs} + messages = compile_chat_messages( + self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) + + async with self._client_session.post(f"{self.base_url}/inference", json={ + "prompt": self.convert_to_prompt(messages), + **args + }, headers={ + "Authorization": f"Bearer {self.api_key}" + }) as resp: + async for line in resp.content.iter_any(): + if line: + try: + yield line.decode("utf-8") + except: + raise Exception(str(line)) + + async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: + args = {**self.default_args, **kwargs} + messages = compile_chat_messages( + self.name, messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) + args["stream_tokens"] = True + + async with self._client_session.post(f"{self.base_url}/inference", json={ + "prompt": self.convert_to_prompt(messages), + **args + }, headers={ + "Authorization": f"Bearer {self.api_key}" + }) as resp: + async for line in resp.content.iter_chunks(): + if line[1]: + try: + json_chunk = line[0].decode("utf-8") + if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"): + continue + chunks = json_chunk.split("\n") + for chunk in chunks: + if chunk.strip() != "": + yield { + "role": "assistant", + "content": json.loads(chunk[6:])["choices"][0]["text"] + } + except: + raise Exception(str(line[0])) + + async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: + args = {**self.default_args, **kwargs} + + messages = compile_chat_messages(args["model"], with_history, self.context_length, + args["max_tokens"], prompt, functions=None, system_message=self.system_message) + async with self._client_session.post(f"{self.base_url}/inference", json={ + "prompt": self.convert_to_prompt(messages), + **args + }, headers={ + "Authorization": f"Bearer {self.api_key}" + }) as resp: + try: + text = await resp.text() + j = json.loads(text) + return j["output"]["choices"][0]["text"] + except: + raise Exception(await resp.text()) diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py index 31aa5423..b40092af 100644 --- a/continuedev/src/continuedev/plugins/context_providers/file.py +++ b/continuedev/src/continuedev/plugins/context_providers/file.py @@ -54,33 +54,37 @@ class FileContextProvider(ContextProvider): list(filter(lambda d: f"**/{d}", DEFAULT_IGNORE_DIRS)) async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]: - filepaths = [] + absolute_filepaths: List[str] = [] for root, dir_names, file_names in os.walk(workspace_dir): dir_names[:] = [d for d in dir_names if not any( fnmatch(d, pattern) for pattern in self.ignore_patterns)] for file_name in file_names: - filepaths.append(os.path.join(root, file_name)) + absolute_filepaths.append(os.path.join(root, file_name)) - if len(filepaths) > 1000: + if len(absolute_filepaths) > 1000: break - if len(filepaths) > 1000: + if len(absolute_filepaths) > 1000: break items = [] - for file in filepaths: - content = get_file_contents(file) + for absolute_filepath in absolute_filepaths: + content = get_file_contents(absolute_filepath) if content is None: continue # no pun intended + + relative_to_workspace = os.path.relpath(absolute_filepath, workspace_dir) items.append(ContextItem( content=content[:min(2000, len(content))], description=ContextItemDescription( - name=os.path.basename(file), - description=file, + name=os.path.basename(absolute_filepath), + # We should add the full path to the ContextItem + # It warrants a data modeling discussion and has no immediate use case + description=relative_to_workspace, id=ContextItemId( provider_title=self.title, - item_id=remove_meilisearch_disallowed_chars(file) + item_id=remove_meilisearch_disallowed_chars(absolute_filepath) ) ) )) diff --git a/continuedev/src/continuedev/plugins/steps/help.py b/continuedev/src/continuedev/plugins/steps/help.py index ec670999..82f885d6 100644 --- a/continuedev/src/continuedev/plugins/steps/help.py +++ b/continuedev/src/continuedev/plugins/steps/help.py @@ -39,6 +39,7 @@ class HelpStep(Step): if question.strip() == "": self.description = help else: + self.description = "The following output is generated by a language model, which may hallucinate. Type just '/help'to see a fixed answer. You can also learn more by reading [the docs](https://continue.dev/docs).\n\n" prompt = dedent(f""" Information: @@ -48,7 +49,7 @@ class HelpStep(Step): Please us the information below to provide a succinct answer to the following question: {question} - Do not cite any slash commands other than those you've been told about, which are: /edit and /feedback.""") + Do not cite any slash commands other than those you've been told about, which are: /edit and /feedback. Never refer or link to any URL.""") self.chat_context.append(ChatMessage( role="user", diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py index 7f460afc..f47c08ca 100644 --- a/continuedev/src/continuedev/server/meilisearch_server.py +++ b/continuedev/src/continuedev/server/meilisearch_server.py @@ -1,3 +1,4 @@ +import asyncio import os import shutil import subprocess @@ -58,15 +59,26 @@ async def check_meilisearch_running() -> bool: async with Client('http://localhost:7700') as client: try: resp = await client.health() - if resp["status"] != "available": + if resp.status != "available": return False return True - except: + except Exception as e: + logger.debug(e) return False except Exception: return False +async def poll_meilisearch_running(frequency: int = 0.1) -> bool: + """ + Polls MeiliSearch to see if it is running. + """ + while True: + if await check_meilisearch_running(): + return True + await asyncio.sleep(frequency) + + async def start_meilisearch(): """ Starts the MeiliSearch server, wait for it. |