summaryrefslogtreecommitdiff
path: root/continuedev
diff options
context:
space:
mode:
authorNate Sesti <33237525+sestinj@users.noreply.github.com>2023-09-28 01:02:52 -0700
committerGitHub <noreply@github.com>2023-09-28 01:02:52 -0700
commit95363a5b52f3bf73531ac76b00178fa79ca97661 (patch)
tree9b9c1614556f1f0d21f363e6a9fe950069affb5d /continuedev
parentd4acf4bb11dbd7d3d6210e2949d21143d721e81e (diff)
downloadsncontinue-95363a5b52f3bf73531ac76b00178fa79ca97661.tar.gz
sncontinue-95363a5b52f3bf73531ac76b00178fa79ca97661.tar.bz2
sncontinue-95363a5b52f3bf73531ac76b00178fa79ca97661.zip
Past input (#513)
* feat: :construction: use ComboBox in place of UserInputContainer * feat: :construction: adding context to previous inputs steps * feat: :sparkles: preview context items on click * feat: :construction: more work on context items ui * style: :construction: working out the details of ctx item buttons * feat: :sparkles: getting the final details * fix: :bug: fix height of ctx items bar * fix: :bug: last couple of details * fix: :bug: pass model param through to hf inference api * fix: :loud_sound: better logging for timeout * feat: :sparkles: option to set the meilisearch url * fix: :bug: fix height of past inputs
Diffstat (limited to 'continuedev')
-rw-r--r--continuedev/src/continuedev/__main__.py5
-rw-r--r--continuedev/src/continuedev/core/autopilot.py47
-rw-r--r--continuedev/src/continuedev/core/context.py26
-rw-r--r--continuedev/src/continuedev/core/main.py2
-rw-r--r--continuedev/src/continuedev/libs/constants/default_config.py10
-rw-r--r--continuedev/src/continuedev/libs/llm/hf_inference_api.py7
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/highlighted_code.py11
-rw-r--r--continuedev/src/continuedev/server/gui.py35
-rw-r--r--continuedev/src/continuedev/server/ide.py24
-rw-r--r--continuedev/src/continuedev/server/main.py23
-rw-r--r--continuedev/src/continuedev/server/meilisearch_server.py30
11 files changed, 160 insertions, 60 deletions
diff --git a/continuedev/src/continuedev/__main__.py b/continuedev/src/continuedev/__main__.py
index 1974d87c..caaba117 100644
--- a/continuedev/src/continuedev/__main__.py
+++ b/continuedev/src/continuedev/__main__.py
@@ -12,6 +12,9 @@ app = typer.Typer()
def main(
port: int = typer.Option(65432, help="server port"),
host: str = typer.Option("127.0.0.1", help="server host"),
+ meilisearch_url: Optional[str] = typer.Option(
+ None, help="The URL of the MeiliSearch server if running manually"
+ ),
config: Optional[str] = typer.Option(
None, help="The path to the configuration file"
),
@@ -20,7 +23,7 @@ def main(
if headless:
run(config)
else:
- run_server(port=port, host=host)
+ run_server(port=port, host=host, meilisearch_url=meilisearch_url)
if __name__ == "__main__":
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 0155e755..9ebf288b 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -37,7 +37,7 @@ from ..plugins.steps.core.core import (
)
from ..plugins.steps.on_traceback import DefaultOnTracebackStep
from ..server.ide_protocol import AbstractIdeProtocolServer
-from ..server.meilisearch_server import stop_meilisearch
+from ..server.meilisearch_server import get_meilisearch_url, stop_meilisearch
from .config import ContinueConfig
from .context import ContextManager
from .main import (
@@ -179,6 +179,7 @@ class Autopilot(ContinueBaseModel):
config=self.continue_sdk.config,
saved_context_groups=self._saved_context_groups,
context_providers=self.context_manager.get_provider_descriptions(),
+ meilisearch_url=get_meilisearch_url(),
)
self.full_state = full_state
return full_state
@@ -306,7 +307,8 @@ class Autopilot(ContinueBaseModel):
await self.update_subscribers()
async def edit_step_at_index(self, user_input: str, index: int):
- step_to_rerun = self.history.timeline[index].step.copy()
+ node_to_rerun = self.history.timeline[index].copy()
+ step_to_rerun = node_to_rerun.step
step_to_rerun.user_input = user_input
step_to_rerun.description = user_input
@@ -318,13 +320,29 @@ class Autopilot(ContinueBaseModel):
node_to_delete.deleted = True
self.history.current_index = index - 1
+
+ # Set the context to the context used by that step
+ await self.context_manager.clear_context()
+ for context_item in node_to_rerun.context_used:
+ await self.context_manager.manually_add_context_item(context_item)
+
await self.update_subscribers()
# Rerun from the current step
await self.run_from_step(step_to_rerun)
- async def delete_context_with_ids(self, ids: List[str]):
- await self.context_manager.delete_context_with_ids(ids)
+ async def delete_context_with_ids(
+ self, ids: List[str], index: Optional[int] = None
+ ):
+ if index is None:
+ await self.context_manager.delete_context_with_ids(ids)
+ else:
+ self.history.timeline[index].context_used = list(
+ filter(
+ lambda item: item.description.id.to_string() not in ids,
+ self.history.timeline[index].context_used,
+ )
+ )
await self.update_subscribers()
async def toggle_adding_highlighted_code(self):
@@ -380,7 +398,12 @@ class Autopilot(ContinueBaseModel):
# Update history - do this first so we get top-first tree ordering
index_of_history_node = self.history.add_node(
- HistoryNode(step=step, observation=None, depth=self._step_depth)
+ HistoryNode(
+ step=step,
+ observation=None,
+ depth=self._step_depth,
+ context_used=await self.context_manager.get_selected_items(),
+ )
)
# Call all subscribed callbacks
@@ -600,7 +623,7 @@ class Autopilot(ContinueBaseModel):
async def accept_user_input(self, user_input: str):
self._main_user_input_queue.append(user_input)
- await self.update_subscribers()
+ # await self.update_subscribers()
if len(self._main_user_input_queue) > 1:
return
@@ -609,7 +632,7 @@ class Autopilot(ContinueBaseModel):
# Just run the step that takes user input, and
# then up to the policy to decide how to deal with it.
self._main_user_input_queue.pop(0)
- await self.update_subscribers()
+ # await self.update_subscribers()
await self.run_from_step(UserInputStep(user_input=user_input))
while len(self._main_user_input_queue) > 0:
@@ -635,6 +658,16 @@ class Autopilot(ContinueBaseModel):
await self.context_manager.select_context_item(id, query)
await self.update_subscribers()
+ async def select_context_item_at_index(self, id: str, query: str, index: int):
+ # TODO: This is different from how it works for the main input
+ # Ideally still tracked through the ContextProviders
+ # so they can watch for duplicates
+ context_item = await self.context_manager.get_context_item(id, query)
+ if context_item is None:
+ return
+ self.history.timeline[index].context_used.append(context_item)
+ await self.update_subscribers()
+
async def set_config_attr(self, key_path: List[str], value: redbaron.RedBaron):
edit_config_property(key_path, value)
await self.update_subscribers()
diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py
index f2658602..d374dd02 100644
--- a/continuedev/src/continuedev/core/context.py
+++ b/continuedev/src/continuedev/core/context.py
@@ -10,7 +10,11 @@ from ..libs.util.create_async_task import create_async_task
from ..libs.util.devdata import dev_data_logger
from ..libs.util.logging import logger
from ..libs.util.telemetry import posthog_logger
-from ..server.meilisearch_server import poll_meilisearch_running, restart_meilisearch
+from ..server.meilisearch_server import (
+ get_meilisearch_url,
+ poll_meilisearch_running,
+ restart_meilisearch,
+)
from .main import (
ChatMessage,
ContextItem,
@@ -127,7 +131,7 @@ class ContextProvider(BaseModel):
Default implementation uses the search index to get the item.
"""
- async with Client("http://localhost:7700") as search_client:
+ async with Client(get_meilisearch_url()) as search_client:
try:
result = await search_client.index(SEARCH_INDEX_NAME).get_document(
id.to_string()
@@ -295,7 +299,7 @@ class ContextManager:
}
for item in context_items
]
- async with Client("http://localhost:7700") as search_client:
+ async with Client(get_meilisearch_url()) as search_client:
async def add_docs():
index = await search_client.get_index(SEARCH_INDEX_NAME)
@@ -313,7 +317,7 @@ class ContextManager:
"""
Deletes the documents in the search index.
"""
- async with Client("http://localhost:7700") as search_client:
+ async with Client(get_meilisearch_url()) as search_client:
await asyncio.wait_for(
search_client.index(SEARCH_INDEX_NAME).delete_documents(ids),
timeout=20,
@@ -321,7 +325,7 @@ class ContextManager:
async def load_index(self, workspace_dir: str, should_retry: bool = True):
try:
- async with Client("http://localhost:7700") as search_client:
+ async with Client(get_meilisearch_url()) as search_client:
# First, create the index if it doesn't exist
# The index is currently shared by all workspaces
await search_client.create_index(SEARCH_INDEX_NAME)
@@ -422,6 +426,18 @@ class ContextManager:
)
await self.context_providers[id.provider_title].add_context_item(id, query)
+ async def get_context_item(self, id: str, query: str) -> ContextItem:
+ """
+ Returns the ContextItem with the given id.
+ """
+ id: ContextItemId = ContextItemId.from_string(id)
+ if id.provider_title not in self.provider_titles:
+ raise ValueError(
+ f"Context provider with title {id.provider_title} not found"
+ )
+
+ return await self.context_providers[id.provider_title].get_item(id, query)
+
async def delete_context_with_ids(self, ids: List[str]):
"""
Deletes the ContextItems with the given IDs, lets ContextProviders recalculate.
diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py
index cf41aab9..617a5aaa 100644
--- a/continuedev/src/continuedev/core/main.py
+++ b/continuedev/src/continuedev/core/main.py
@@ -108,6 +108,7 @@ class HistoryNode(ContinueBaseModel):
deleted: bool = False
active: bool = True
logs: List[str] = []
+ context_used: List["ContextItem"] = []
def to_chat_messages(self) -> List[ChatMessage]:
if self.step.description is None or self.step.manage_own_chat_context:
@@ -312,6 +313,7 @@ class FullState(ContinueBaseModel):
config: ContinueConfig
saved_context_groups: Dict[str, List[ContextItem]] = {}
context_providers: List[ContextProviderDescription] = []
+ meilisearch_url: Optional[str] = None
class ContinueSDK:
diff --git a/continuedev/src/continuedev/libs/constants/default_config.py b/continuedev/src/continuedev/libs/constants/default_config.py
index a1b2de2c..92913001 100644
--- a/continuedev/src/continuedev/libs/constants/default_config.py
+++ b/continuedev/src/continuedev/libs/constants/default_config.py
@@ -31,24 +31,24 @@ config = ContinueConfig(
custom_commands=[
CustomCommand(
name="test",
- description="Write unit tests for the highlighted code",
+ description="Write unit tests for highlighted code",
prompt="Write a comprehensive set of unit tests for the selected code. It should setup, run tests that check for correctness including important edge cases, and teardown. Ensure that the tests are complete and sophisticated. Give the tests just as chat output, don't edit any file.",
)
],
slash_commands=[
SlashCommand(
name="edit",
- description="Edit code in the current file or the highlighted code",
+ description="Edit highlighted code",
step=EditHighlightedCodeStep,
),
SlashCommand(
name="config",
- description="Customize Continue - slash commands, LLMs, system message, etc.",
+ description="Customize Continue",
step=OpenConfigStep,
),
SlashCommand(
name="comment",
- description="Write comments for the current file or highlighted code",
+ description="Write comments for the highlighted code",
step=CommentCodeStep,
),
SlashCommand(
@@ -58,7 +58,7 @@ config = ContinueConfig(
),
SlashCommand(
name="share",
- description="Download and share the session transcript",
+ description="Download and share this session",
step=ShareSessionStep,
),
SlashCommand(
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index a7771018..ab1482e8 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -57,16 +57,15 @@ class HuggingFaceInferenceAPI(LLM):
if "stop" in args:
args["stop_sequences"] = args["stop"]
del args["stop"]
- if "model" in args:
- del args["model"]
+
return args
async def _stream_complete(self, prompt, options):
- self.collect_args(options)
+ args = self.collect_args(options)
client = InferenceClient(self.endpoint_url, token=self.hf_token)
- stream = client.text_generation(prompt, stream=True, details=True)
+ stream = client.text_generation(prompt, stream=True, details=True, **args)
for r in stream:
# skip special tokens
diff --git a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py
index df82b1ab..bd31531e 100644
--- a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py
+++ b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py
@@ -257,8 +257,17 @@ class HighlightedCodeContextProvider(ContextProvider):
self._disambiguate_highlighted_ranges()
async def set_editing_at_ids(self, ids: List[str]):
+ # Don't do anything if there are no valid ids here
+ count = 0
for hr in self.highlighted_ranges:
- hr.item.editing = hr.item.description.id.to_string() in ids
+ if hr.item.description.id.item_id in ids:
+ count += 1
+
+ if count == 0:
+ return
+
+ for hr in self.highlighted_ranges:
+ hr.item.editing = hr.item.description.id.item_id in ids
async def add_context_item(
self, id: ContextItemId, query: str, prev: List[ContextItem] = None
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 9d2ea47a..10f6974f 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -104,7 +104,7 @@ class GUIProtocolServer:
elif message_type == "delete_at_index":
self.on_delete_at_index(data["index"])
elif message_type == "delete_context_with_ids":
- self.on_delete_context_with_ids(data["ids"])
+ self.on_delete_context_with_ids(data["ids"], data.get("index", None))
elif message_type == "toggle_adding_highlighted_code":
self.on_toggle_adding_highlighted_code()
elif message_type == "set_editing_at_ids":
@@ -112,9 +112,11 @@ class GUIProtocolServer:
elif message_type == "show_logs_at_index":
self.on_show_logs_at_index(data["index"])
elif message_type == "show_context_virtual_file":
- self.show_context_virtual_file()
+ self.show_context_virtual_file(data.get("index", None))
elif message_type == "select_context_item":
self.select_context_item(data["id"], data["query"])
+ elif message_type == "select_context_item_at_index":
+ self.select_context_item_at_index(data["id"], data["query"], data["index"])
elif message_type == "load_session":
self.load_session(data.get("session_id", None))
elif message_type == "edit_step_at_index":
@@ -171,9 +173,9 @@ class GUIProtocolServer:
self.on_error,
)
- def on_delete_context_with_ids(self, ids: List[str]):
+ def on_delete_context_with_ids(self, ids: List[str], index: Optional[int] = None):
create_async_task(
- self.session.autopilot.delete_context_with_ids(ids), self.on_error
+ self.session.autopilot.delete_context_with_ids(ids, index), self.on_error
)
def on_toggle_adding_highlighted_code(self):
@@ -188,7 +190,7 @@ class GUIProtocolServer:
def on_show_logs_at_index(self, index: int):
name = "Continue Context"
logs = "\n\n############################################\n\n".join(
- ["This is the prompt sent to the LLM during this step"]
+ ["This is the prompt that was sent to the LLM during this step"]
+ self.session.autopilot.continue_sdk.history.timeline[index].logs
)
create_async_task(
@@ -196,12 +198,20 @@ class GUIProtocolServer:
)
posthog_logger.capture_event("show_logs_at_index", {})
- def show_context_virtual_file(self):
+ def show_context_virtual_file(self, index: Optional[int] = None):
async def async_stuff():
- msgs = await self.session.autopilot.continue_sdk.get_chat_context()
+ if index is None:
+ context_items = (
+ await self.session.autopilot.context_manager.get_selected_items()
+ )
+ elif index < len(self.session.autopilot.continue_sdk.history.timeline):
+ context_items = self.session.autopilot.continue_sdk.history.timeline[
+ index
+ ].context_used
+
ctx = "\n\n-----------------------------------\n\n".join(
- ["This is the exact context that will be passed to the LLM"]
- + list(map(lambda x: x.content, msgs))
+ ["These are the context items that will be passed to the LLM"]
+ + list(map(lambda x: x.content, context_items))
)
await self.session.autopilot.ide.showVirtualFile(
"Continue - Selected Context", ctx
@@ -218,6 +228,13 @@ class GUIProtocolServer:
self.session.autopilot.select_context_item(id, query), self.on_error
)
+ def select_context_item_at_index(self, id: str, query: str, index: int):
+ """Called when user selects an item from the dropdown for prev UserInputStep"""
+ create_async_task(
+ self.session.autopilot.select_context_item_at_index(id, query, index),
+ self.on_error,
+ )
+
def load_session(self, session_id: Optional[str] = None):
async def load_and_tell_to_reconnect():
new_session_id = await session_manager.load_session(
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index d4f0690b..32bd0f0c 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -12,6 +12,7 @@ from pydantic import BaseModel
from starlette.websockets import WebSocketDisconnect, WebSocketState
from uvicorn.main import Server
+from ..core.main import ContinueCustomException
from ..libs.util.create_async_task import create_async_task
from ..libs.util.devdata import dev_data_logger
from ..libs.util.logging import logger
@@ -39,7 +40,6 @@ from ..models.filesystem_edit import (
from ..plugins.steps.core.core import DisplayErrorStep
from .gui import session_manager
from .ide_protocol import AbstractIdeProtocolServer
-from .meilisearch_server import start_meilisearch
from .session_manager import SessionManager
nest_asyncio.apply()
@@ -201,21 +201,24 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
except RuntimeError as e:
logger.warning(f"Error sending IDE message, websocket probably closed: {e}")
- async def _receive_json(self, message_type: str, timeout: int = 20) -> Any:
+ async def _receive_json(
+ self, message_type: str, timeout: int = 20, message=None
+ ) -> Any:
try:
return await asyncio.wait_for(
self.sub_queue.get(message_type), timeout=timeout
)
except asyncio.TimeoutError:
- raise Exception(
- f"IDE Protocol _receive_json timed out after 20 seconds: {message_type}"
+ raise ContinueCustomException(
+ title=f"IDE Protocol _receive_json timed out after 20 seconds: {message_type}",
+ message=f"IDE Protocol _receive_json timed out after 20 seconds. The message sent was: {message or ''}",
)
async def _send_and_receive_json(
self, data: Any, resp_model: Type[T], message_type: str
) -> T:
await self._send_json(message_type, data)
- resp = await self._receive_json(message_type)
+ resp = await self._receive_json(message_type, message=data)
return resp_model.parse_obj(resp)
async def handle_json(self, message_type: str, data: Any):
@@ -597,17 +600,6 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None):
logger.debug(f"Accepted websocket connection from {websocket.client}")
await websocket.send_json({"messageType": "connected", "data": {}})
- # Start meilisearch
- try:
-
- async def on_err(e):
- logger.debug(f"Failed to start MeiliSearch: {e}")
-
- create_async_task(start_meilisearch(), on_err)
- except Exception as e:
- logger.debug("Failed to start MeiliSearch")
- logger.debug(e)
-
# Message handler
def handle_msg(msg):
message = json.loads(msg)
diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py
index bbae2bb2..aa6c8944 100644
--- a/continuedev/src/continuedev/server/main.py
+++ b/continuedev/src/continuedev/server/main.py
@@ -1,6 +1,8 @@
import argparse
import asyncio
import atexit
+from contextlib import asynccontextmanager
+from typing import Optional
import uvicorn
from fastapi import FastAPI
@@ -9,10 +11,21 @@ from fastapi.middleware.cors import CORSMiddleware
from ..libs.util.logging import logger
from .gui import router as gui_router
from .ide import router as ide_router
+from .meilisearch_server import start_meilisearch, stop_meilisearch
from .session_manager import router as sessions_router
from .session_manager import session_manager
-app = FastAPI()
+meilisearch_url_global = None
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+ await start_meilisearch(url=meilisearch_url_global)
+ yield
+ stop_meilisearch()
+
+
+app = FastAPI(lifespan=lifespan)
app.include_router(ide_router)
app.include_router(gui_router)
@@ -34,7 +47,13 @@ def health():
return {"status": "ok"}
-def run_server(port: int = 65432, host: str = "127.0.0.1"):
+def run_server(
+ port: int = 65432, host: str = "127.0.0.1", meilisearch_url: Optional[str] = None
+):
+ global meilisearch_url_global
+
+ meilisearch_url_global = meilisearch_url
+
config = uvicorn.Config(app, host=host, port=port)
server = uvicorn.Server(config)
server.run()
diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py
index 5e6cdd53..8929b69d 100644
--- a/continuedev/src/continuedev/server/meilisearch_server.py
+++ b/continuedev/src/continuedev/server/meilisearch_server.py
@@ -2,9 +2,11 @@ import asyncio
import os
import shutil
import subprocess
+from typing import Optional
import aiofiles
import aiohttp
+import psutil
from meilisearch_python_async import Client
from ..libs.util.logging import logger
@@ -89,13 +91,22 @@ async def ensure_meilisearch_installed() -> bool:
return True
+meilisearch_process = None
+DEFAULT_MEILISEARCH_URL = "http://localhost:7700"
+meilisearch_url = DEFAULT_MEILISEARCH_URL
+
+
+def get_meilisearch_url():
+ return meilisearch_url
+
+
async def check_meilisearch_running() -> bool:
"""
Checks if MeiliSearch is running.
"""
try:
- async with Client("http://localhost:7700") as client:
+ async with Client(meilisearch_url) as client:
try:
resp = await client.health()
if resp.status != "available":
@@ -117,14 +128,16 @@ async def poll_meilisearch_running(frequency: int = 0.1) -> bool:
await asyncio.sleep(frequency)
-meilisearch_process = None
-
-
-async def start_meilisearch():
+async def start_meilisearch(url: Optional[str] = None):
"""
Starts the MeiliSearch server, wait for it.
"""
- global meilisearch_process
+ global meilisearch_process, meilisearch_url
+
+ if url is not None:
+ logger.debug("Using MeiliSearch at URL: " + url)
+ meilisearch_url = url
+ return
serverPath = getServerFolderPath()
@@ -157,9 +170,6 @@ def stop_meilisearch():
meilisearch_process = None
-import psutil
-
-
def kill_proc(port):
for proc in psutil.process_iter():
try:
@@ -180,4 +190,4 @@ def kill_proc(port):
async def restart_meilisearch():
stop_meilisearch()
kill_proc(7700)
- await start_meilisearch()
+ await start_meilisearch(url=meilisearch_url)