summary refs log tree commit diff
path: root/continuedev
diff options
context:
space:
mode:
Diffstat (limited to 'continuedev')
-rw-r--r--  continuedev/pyproject.toml                                    2
-rw-r--r--  continuedev/src/continuedev/core/autopilot.py                15
-rw-r--r--  continuedev/src/continuedev/core/context.py                  48
-rw-r--r--  continuedev/src/continuedev/libs/llm/ggml.py                  5
-rw-r--r--  continuedev/src/continuedev/libs/llm/replicate.py             2
-rw-r--r--  continuedev/src/continuedev/libs/llm/together.py            122
-rw-r--r--  continuedev/src/continuedev/plugins/context_providers/file.py 22
-rw-r--r--  continuedev/src/continuedev/plugins/steps/help.py             3
-rw-r--r--  continuedev/src/continuedev/server/meilisearch_server.py     16
9 files changed, 197 insertions, 38 deletions
diff --git a/continuedev/pyproject.toml b/continuedev/pyproject.toml
index 49b3c5ed..90ff0572 100644
--- a/continuedev/pyproject.toml
+++ b/continuedev/pyproject.toml
@@ -9,7 +9,7 @@ readme = "README.md"
python = "^3.8.1"
fastapi = "^0.95.1"
typer = "^0.7.0"
-openai = "^0.27.8"
+openai = "^0.27.5"
boltons = "^23.0.0"
pydantic = "^1.10.7"
uvicorn = "^0.21.1"
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 256f3439..9100c34e 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -81,11 +81,7 @@ class Autopilot(ContinueBaseModel):
self.continue_sdk.config.context_providers + [
HighlightedCodeContextProvider(ide=self.ide),
FileContextProvider(workspace_dir=self.ide.workspace_directory)
- ])
-
- logger.debug("Loading index")
- create_async_task(self.context_manager.load_index(
- self.ide.workspace_directory))
+ ], self.ide.workspace_directory)
if full_state is not None:
self.history = full_state.history
@@ -188,6 +184,9 @@ class Autopilot(ContinueBaseModel):
await self._run_singular_step(step)
async def handle_highlighted_code(self, range_in_files: List[RangeInFileWithContents]):
+ if "code" not in self.context_manager.context_providers:
+ return
+
# Add to context manager
await self.context_manager.context_providers["code"].handle_highlighted_code(
range_in_files)
@@ -212,10 +211,16 @@ class Autopilot(ContinueBaseModel):
await self.update_subscribers()
async def toggle_adding_highlighted_code(self):
+ if "code" not in self.context_manager.context_providers:
+ return
+
self.context_manager.context_providers["code"].adding_highlighted_code = not self.context_manager.context_providers["code"].adding_highlighted_code
await self.update_subscribers()
async def set_editing_at_ids(self, ids: List[str]):
+ if "code" not in self.context_manager.context_providers:
+ return
+
await self.context_manager.context_providers["code"].set_editing_at_ids(ids)
await self.update_subscribers()
diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py
index b1f68b50..db1c770a 100644
--- a/continuedev/src/continuedev/core/context.py
+++ b/continuedev/src/continuedev/core/context.py
@@ -8,9 +8,10 @@ from pydantic import BaseModel
from .main import ChatMessage, ContextItem, ContextItemDescription, ContextItemId
-from ..server.meilisearch_server import check_meilisearch_running
+from ..server.meilisearch_server import poll_meilisearch_running
from ..libs.util.logging import logger
from ..libs.util.telemetry import posthog_logger
+from ..libs.util.create_async_task import create_async_task
SEARCH_INDEX_NAME = "continue_context_items"
@@ -140,31 +141,32 @@ class ContextManager:
self.context_providers = {}
self.provider_titles = set()
- async def start(self, context_providers: List[ContextProvider]):
+ async def start(self, context_providers: List[ContextProvider], workspace_directory: str):
"""
Starts the context manager.
"""
+ # Use only non-meilisearch-dependent providers until it is loaded
self.context_providers = {
- prov.title: prov for prov in context_providers}
+ title: provider for title, provider in self.context_providers.items() if title == "code"
+ }
self.provider_titles = {
provider.title for provider in context_providers}
- async with Client('http://localhost:7700') as search_client:
- meilisearch_running = True
+ # Start MeiliSearch in the background without blocking
+ async def start_meilisearch(context_providers):
try:
-
- health = await search_client.health()
- if not health.status == "available":
- meilisearch_running = False
- except:
- meilisearch_running = False
-
- if not meilisearch_running:
+ await asyncio.wait_for(poll_meilisearch_running(), timeout=20)
+ self.context_providers = {
+ prov.title: prov for prov in context_providers}
+ logger.debug("Loading Meilisearch index...")
+ await self.load_index(workspace_directory)
+ logger.debug("Loaded Meilisearch index")
+ except asyncio.TimeoutError:
+            logger.warning("MeiliSearch did not start within 20 seconds")
logger.warning(
"MeiliSearch not running, avoiding any dependent context providers")
- self.context_providers = {
- title: provider for title, provider in self.context_providers.items() if title == "code"
- }
+
+ create_async_task(start_meilisearch(context_providers))
async def load_index(self, workspace_dir: str):
for _, provider in self.context_providers.items():
@@ -176,14 +178,24 @@ class ContextManager:
"id": item.description.id.to_string(),
"name": item.description.name,
"description": item.description.description,
- "content": item.content
+ "content": item.content,
+ "workspace_dir": workspace_dir,
}
for item in context_items
]
if len(documents) > 0:
try:
async with Client('http://localhost:7700') as search_client:
- await asyncio.wait_for(search_client.index(SEARCH_INDEX_NAME).add_documents(documents), timeout=5)
+ # First, create the index if it doesn't exist
+ await search_client.create_index(SEARCH_INDEX_NAME)
+ # The index is currently shared by all workspaces
+ globalSearchIndex = await search_client.get_index(SEARCH_INDEX_NAME)
+ await asyncio.wait_for(asyncio.gather(
+ # Ensure that the index has the correct filterable attributes
+ globalSearchIndex.update_filterable_attributes(
+ ["workspace_dir"]),
+ globalSearchIndex.add_documents(documents)
+ ), timeout=5)
except Exception as e:
logger.debug(f"Error loading meilisearch index: {e}")
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 2f131354..25a61e63 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -82,7 +82,10 @@ class GGML(LLM):
chunks = json_chunk.split("\n")
for chunk in chunks:
if chunk.strip() != "":
- yield json.loads(chunk[6:])["choices"][0]["delta"]
+ yield {
+ "role": "assistant",
+ "content": json.loads(chunk[6:])["choices"][0]["delta"]
+ }
except:
raise Exception(str(line[0]))
diff --git a/continuedev/src/continuedev/libs/llm/replicate.py b/continuedev/src/continuedev/libs/llm/replicate.py
index 235fd906..0dd359e7 100644
--- a/continuedev/src/continuedev/libs/llm/replicate.py
+++ b/continuedev/src/continuedev/libs/llm/replicate.py
@@ -25,7 +25,7 @@ class ReplicateLLM(LLM):
@property
def default_args(self):
- return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
+ return {**DEFAULT_ARGS, "model": self.model, "max_tokens": 1024}
def count_tokens(self, text: str):
return count_tokens(self.name, text)
diff --git a/continuedev/src/continuedev/libs/llm/together.py b/continuedev/src/continuedev/libs/llm/together.py
new file mode 100644
index 00000000..c3f171c9
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/together.py
@@ -0,0 +1,122 @@
+import json
+from typing import Any, Coroutine, Dict, Generator, List, Union
+
+import aiohttp
+from ...core.main import ChatMessage
+from ..llm import LLM
+from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens
+
+
+class TogetherLLM(LLM):
+ # this is model-specific
+ api_key: str
+ model: str = "togethercomputer/RedPajama-INCITE-7B-Instruct"
+ max_context_length: int = 2048
+ base_url: str = "https://api.together.xyz"
+ verify_ssl: bool = True
+
+ _client_session: aiohttp.ClientSession = None
+
+ async def start(self, **kwargs):
+ self._client_session = aiohttp.ClientSession(
+ connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl))
+
+ async def stop(self):
+ await self._client_session.close()
+
+ @property
+ def name(self):
+ return self.model
+
+ @property
+ def context_length(self):
+ return self.max_context_length
+
+ @property
+ def default_args(self):
+ return {**DEFAULT_ARGS, "model": self.model, "max_tokens": 1024}
+
+ def count_tokens(self, text: str):
+ return count_tokens(self.name, text)
+
+ def convert_to_prompt(self, chat_messages: List[ChatMessage]) -> str:
+ system_message = None
+ if chat_messages[0]["role"] == "system":
+ system_message = chat_messages.pop(0)["content"]
+
+ prompt = "\n"
+ if system_message:
+ prompt += f"<human>: Hi!\n<bot>: {system_message}\n"
+ for message in chat_messages:
+ prompt += f'<{"human" if message["role"] == "user" else "bot"}>: {message["content"]}\n'
+
+ prompt += "<bot>:"
+ return prompt
+
+ async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ args = self.default_args.copy()
+ args.update(kwargs)
+ args["stream_tokens"] = True
+
+ args = {**self.default_args, **kwargs}
+ messages = compile_chat_messages(
+ self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
+
+ async with self._client_session.post(f"{self.base_url}/inference", json={
+ "prompt": self.convert_to_prompt(messages),
+ **args
+ }, headers={
+ "Authorization": f"Bearer {self.api_key}"
+ }) as resp:
+ async for line in resp.content.iter_any():
+ if line:
+ try:
+ yield line.decode("utf-8")
+ except:
+ raise Exception(str(line))
+
+ async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ args = {**self.default_args, **kwargs}
+ messages = compile_chat_messages(
+ self.name, messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+ args["stream_tokens"] = True
+
+ async with self._client_session.post(f"{self.base_url}/inference", json={
+ "prompt": self.convert_to_prompt(messages),
+ **args
+ }, headers={
+ "Authorization": f"Bearer {self.api_key}"
+ }) as resp:
+ async for line in resp.content.iter_chunks():
+ if line[1]:
+ try:
+ json_chunk = line[0].decode("utf-8")
+ if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"):
+ continue
+ chunks = json_chunk.split("\n")
+ for chunk in chunks:
+ if chunk.strip() != "":
+ yield {
+ "role": "assistant",
+ "content": json.loads(chunk[6:])["choices"][0]["text"]
+ }
+ except:
+ raise Exception(str(line[0]))
+
+ async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
+ args = {**self.default_args, **kwargs}
+
+ messages = compile_chat_messages(args["model"], with_history, self.context_length,
+ args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+ async with self._client_session.post(f"{self.base_url}/inference", json={
+ "prompt": self.convert_to_prompt(messages),
+ **args
+ }, headers={
+ "Authorization": f"Bearer {self.api_key}"
+ }) as resp:
+ try:
+ text = await resp.text()
+ j = json.loads(text)
+ return j["output"]["choices"][0]["text"]
+ except:
+ raise Exception(await resp.text())
diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py
index 31aa5423..b40092af 100644
--- a/continuedev/src/continuedev/plugins/context_providers/file.py
+++ b/continuedev/src/continuedev/plugins/context_providers/file.py
@@ -54,33 +54,37 @@ class FileContextProvider(ContextProvider):
list(filter(lambda d: f"**/{d}", DEFAULT_IGNORE_DIRS))
async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]:
- filepaths = []
+ absolute_filepaths: List[str] = []
for root, dir_names, file_names in os.walk(workspace_dir):
dir_names[:] = [d for d in dir_names if not any(
fnmatch(d, pattern) for pattern in self.ignore_patterns)]
for file_name in file_names:
- filepaths.append(os.path.join(root, file_name))
+ absolute_filepaths.append(os.path.join(root, file_name))
- if len(filepaths) > 1000:
+ if len(absolute_filepaths) > 1000:
break
- if len(filepaths) > 1000:
+ if len(absolute_filepaths) > 1000:
break
items = []
- for file in filepaths:
- content = get_file_contents(file)
+ for absolute_filepath in absolute_filepaths:
+ content = get_file_contents(absolute_filepath)
if content is None:
continue # no pun intended
+
+ relative_to_workspace = os.path.relpath(absolute_filepath, workspace_dir)
items.append(ContextItem(
content=content[:min(2000, len(content))],
description=ContextItemDescription(
- name=os.path.basename(file),
- description=file,
+ name=os.path.basename(absolute_filepath),
+ # We should add the full path to the ContextItem
+ # It warrants a data modeling discussion and has no immediate use case
+ description=relative_to_workspace,
id=ContextItemId(
provider_title=self.title,
- item_id=remove_meilisearch_disallowed_chars(file)
+ item_id=remove_meilisearch_disallowed_chars(absolute_filepath)
)
)
))
diff --git a/continuedev/src/continuedev/plugins/steps/help.py b/continuedev/src/continuedev/plugins/steps/help.py
index ec670999..82f885d6 100644
--- a/continuedev/src/continuedev/plugins/steps/help.py
+++ b/continuedev/src/continuedev/plugins/steps/help.py
@@ -39,6 +39,7 @@ class HelpStep(Step):
if question.strip() == "":
self.description = help
else:
+            self.description = "The following output is generated by a language model, which may hallucinate. Type just '/help' to see a fixed answer. You can also learn more by reading [the docs](https://continue.dev/docs).\n\n"
prompt = dedent(f"""
Information:
@@ -48,7 +49,7 @@ class HelpStep(Step):
Please use the information below to provide a succinct answer to the following question: {question}
- Do not cite any slash commands other than those you've been told about, which are: /edit and /feedback.""")
+ Do not cite any slash commands other than those you've been told about, which are: /edit and /feedback. Never refer or link to any URL.""")
self.chat_context.append(ChatMessage(
role="user",
diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py
index 7f460afc..f47c08ca 100644
--- a/continuedev/src/continuedev/server/meilisearch_server.py
+++ b/continuedev/src/continuedev/server/meilisearch_server.py
@@ -1,3 +1,4 @@
+import asyncio
import os
import shutil
import subprocess
@@ -58,15 +59,26 @@ async def check_meilisearch_running() -> bool:
async with Client('http://localhost:7700') as client:
try:
resp = await client.health()
- if resp["status"] != "available":
+ if resp.status != "available":
return False
return True
- except:
+ except Exception as e:
+ logger.debug(e)
return False
except Exception:
return False
+async def poll_meilisearch_running(frequency: int = 0.1) -> bool:
+ """
+ Polls MeiliSearch to see if it is running.
+ """
+ while True:
+ if await check_meilisearch_running():
+ return True
+ await asyncio.sleep(frequency)
+
+
async def start_meilisearch():
"""
Starts the MeiliSearch server, wait for it.