diff options
author | Nate Sesti <33237525+sestinj@users.noreply.github.com> | 2023-09-28 01:02:52 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-28 01:02:52 -0700 |
commit | 95363a5b52f3bf73531ac76b00178fa79ca97661 (patch) | |
tree | 9b9c1614556f1f0d21f363e6a9fe950069affb5d /continuedev | |
parent | d4acf4bb11dbd7d3d6210e2949d21143d721e81e (diff) | |
download | sncontinue-95363a5b52f3bf73531ac76b00178fa79ca97661.tar.gz sncontinue-95363a5b52f3bf73531ac76b00178fa79ca97661.tar.bz2 sncontinue-95363a5b52f3bf73531ac76b00178fa79ca97661.zip |
Past input (#513)
* feat: :construction: use ComboBox in place of UserInputContainer
* feat: :construction: adding context to previous inputs steps
* feat: :sparkles: preview context items on click
* feat: :construction: more work on context items ui
* style: :construction: working out the details of ctx item buttons
* feat: :sparkles: getting the final details
* fix: :bug: fix height of ctx items bar
* fix: :bug: last couple of details
* fix: :bug: pass model param through to hf inference api
* fix: :loud_sound: better logging for timeout
* feat: :sparkles: option to set the meilisearch url
* fix: :bug: fix height of past inputs
Diffstat (limited to 'continuedev')
-rw-r--r-- | continuedev/src/continuedev/__main__.py | 5 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 47 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/context.py | 26 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/main.py | 2 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/constants/default_config.py | 10 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/llm/hf_inference_api.py | 7 | ||||
-rw-r--r-- | continuedev/src/continuedev/plugins/context_providers/highlighted_code.py | 11 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/gui.py | 35 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/ide.py | 24 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/main.py | 23 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/meilisearch_server.py | 30 |
11 files changed, 160 insertions, 60 deletions
diff --git a/continuedev/src/continuedev/__main__.py b/continuedev/src/continuedev/__main__.py index 1974d87c..caaba117 100644 --- a/continuedev/src/continuedev/__main__.py +++ b/continuedev/src/continuedev/__main__.py @@ -12,6 +12,9 @@ app = typer.Typer() def main( port: int = typer.Option(65432, help="server port"), host: str = typer.Option("127.0.0.1", help="server host"), + meilisearch_url: Optional[str] = typer.Option( + None, help="The URL of the MeiliSearch server if running manually" + ), config: Optional[str] = typer.Option( None, help="The path to the configuration file" ), @@ -20,7 +23,7 @@ def main( if headless: run(config) else: - run_server(port=port, host=host) + run_server(port=port, host=host, meilisearch_url=meilisearch_url) if __name__ == "__main__": diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 0155e755..9ebf288b 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -37,7 +37,7 @@ from ..plugins.steps.core.core import ( ) from ..plugins.steps.on_traceback import DefaultOnTracebackStep from ..server.ide_protocol import AbstractIdeProtocolServer -from ..server.meilisearch_server import stop_meilisearch +from ..server.meilisearch_server import get_meilisearch_url, stop_meilisearch from .config import ContinueConfig from .context import ContextManager from .main import ( @@ -179,6 +179,7 @@ class Autopilot(ContinueBaseModel): config=self.continue_sdk.config, saved_context_groups=self._saved_context_groups, context_providers=self.context_manager.get_provider_descriptions(), + meilisearch_url=get_meilisearch_url(), ) self.full_state = full_state return full_state @@ -306,7 +307,8 @@ class Autopilot(ContinueBaseModel): await self.update_subscribers() async def edit_step_at_index(self, user_input: str, index: int): - step_to_rerun = self.history.timeline[index].step.copy() + node_to_rerun = self.history.timeline[index].copy() + step_to_rerun = node_to_rerun.step step_to_rerun.user_input = user_input step_to_rerun.description = user_input @@ -318,13 +320,29 @@ class Autopilot(ContinueBaseModel): node_to_delete.deleted = True self.history.current_index = index - 1 + + # Set the context to the context used by that step + await self.context_manager.clear_context() + for context_item in node_to_rerun.context_used: + await self.context_manager.manually_add_context_item(context_item) + await self.update_subscribers() # Rerun from the current step await self.run_from_step(step_to_rerun) - async def delete_context_with_ids(self, ids: List[str]): - await self.context_manager.delete_context_with_ids(ids) + async def delete_context_with_ids( + self, ids: List[str], index: Optional[int] = None + ): + if index is None: + await self.context_manager.delete_context_with_ids(ids) + else: + self.history.timeline[index].context_used = list( + filter( + lambda item: item.description.id.to_string() not in ids, + self.history.timeline[index].context_used, + ) + ) await self.update_subscribers() async def toggle_adding_highlighted_code(self): @@ -380,7 +398,12 @@ class Autopilot(ContinueBaseModel): # Update history - do this first so we get top-first tree ordering index_of_history_node = self.history.add_node( - HistoryNode(step=step, observation=None, depth=self._step_depth) + HistoryNode( + step=step, + observation=None, + depth=self._step_depth, + context_used=await self.context_manager.get_selected_items(), + ) ) # Call all subscribed callbacks @@ -600,7 +623,7 @@ class Autopilot(ContinueBaseModel): async def accept_user_input(self, user_input: str): self._main_user_input_queue.append(user_input) - await self.update_subscribers() + # await self.update_subscribers() if len(self._main_user_input_queue) > 1: return @@ -609,7 +632,7 @@ class Autopilot(ContinueBaseModel): # Just run the step that takes user input, and # then up to the policy to decide how to deal with it. self._main_user_input_queue.pop(0) - await self.update_subscribers() + # await self.update_subscribers() await self.run_from_step(UserInputStep(user_input=user_input)) while len(self._main_user_input_queue) > 0: @@ -635,6 +658,16 @@ class Autopilot(ContinueBaseModel): await self.context_manager.select_context_item(id, query) await self.update_subscribers() + async def select_context_item_at_index(self, id: str, query: str, index: int): + # TODO: This is different from how it works for the main input + # Ideally still tracked through the ContextProviders + # so they can watch for duplicates + context_item = await self.context_manager.get_context_item(id, query) + if context_item is None: + return + self.history.timeline[index].context_used.append(context_item) + await self.update_subscribers() + async def set_config_attr(self, key_path: List[str], value: redbaron.RedBaron): edit_config_property(key_path, value) await self.update_subscribers() diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py index f2658602..d374dd02 100644 --- a/continuedev/src/continuedev/core/context.py +++ b/continuedev/src/continuedev/core/context.py @@ -10,7 +10,11 @@ from ..libs.util.create_async_task import create_async_task from ..libs.util.devdata import dev_data_logger from ..libs.util.logging import logger from ..libs.util.telemetry import posthog_logger -from ..server.meilisearch_server import poll_meilisearch_running, restart_meilisearch +from ..server.meilisearch_server import ( + get_meilisearch_url, + poll_meilisearch_running, + restart_meilisearch, +) from .main import ( ChatMessage, ContextItem, @@ -127,7 +131,7 @@ class ContextProvider(BaseModel): Default implementation uses the search index to get the item. """ - async with Client("http://localhost:7700") as search_client: + async with Client(get_meilisearch_url()) as search_client: try: result = await search_client.index(SEARCH_INDEX_NAME).get_document( id.to_string() @@ -295,7 +299,7 @@ class ContextManager: } for item in context_items ] - async with Client("http://localhost:7700") as search_client: + async with Client(get_meilisearch_url()) as search_client: async def add_docs(): index = await search_client.get_index(SEARCH_INDEX_NAME) @@ -313,7 +317,7 @@ class ContextManager: """ Deletes the documents in the search index. """ - async with Client("http://localhost:7700") as search_client: + async with Client(get_meilisearch_url()) as search_client: await asyncio.wait_for( search_client.index(SEARCH_INDEX_NAME).delete_documents(ids), timeout=20, @@ -321,7 +325,7 @@ class ContextManager: async def load_index(self, workspace_dir: str, should_retry: bool = True): try: - async with Client("http://localhost:7700") as search_client: + async with Client(get_meilisearch_url()) as search_client: # First, create the index if it doesn't exist # The index is currently shared by all workspaces await search_client.create_index(SEARCH_INDEX_NAME) @@ -422,6 +426,18 @@ class ContextManager: ) await self.context_providers[id.provider_title].add_context_item(id, query) + async def get_context_item(self, id: str, query: str) -> ContextItem: + """ + Returns the ContextItem with the given id. + """ + id: ContextItemId = ContextItemId.from_string(id) + if id.provider_title not in self.provider_titles: + raise ValueError( + f"Context provider with title {id.provider_title} not found" + ) + + return await self.context_providers[id.provider_title].get_item(id, query) + async def delete_context_with_ids(self, ids: List[str]): """ Deletes the ContextItems with the given IDs, lets ContextProviders recalculate. diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index cf41aab9..617a5aaa 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -108,6 +108,7 @@ class HistoryNode(ContinueBaseModel): deleted: bool = False active: bool = True logs: List[str] = [] + context_used: List["ContextItem"] = [] def to_chat_messages(self) -> List[ChatMessage]: if self.step.description is None or self.step.manage_own_chat_context: @@ -312,6 +313,7 @@ class FullState(ContinueBaseModel): config: ContinueConfig saved_context_groups: Dict[str, List[ContextItem]] = {} context_providers: List[ContextProviderDescription] = [] + meilisearch_url: Optional[str] = None class ContinueSDK: diff --git a/continuedev/src/continuedev/libs/constants/default_config.py b/continuedev/src/continuedev/libs/constants/default_config.py index a1b2de2c..92913001 100644 --- a/continuedev/src/continuedev/libs/constants/default_config.py +++ b/continuedev/src/continuedev/libs/constants/default_config.py @@ -31,24 +31,24 @@ config = ContinueConfig( custom_commands=[ CustomCommand( name="test", - description="Write unit tests for the highlighted code", + description="Write unit tests for highlighted code", prompt="Write a comprehensive set of unit tests for the selected code. It should setup, run tests that check for correctness including important edge cases, and teardown. Ensure that the tests are complete and sophisticated. Give the tests just as chat output, don't edit any file.", ) ], slash_commands=[ SlashCommand( name="edit", - description="Edit code in the current file or the highlighted code", + description="Edit highlighted code", step=EditHighlightedCodeStep, ), SlashCommand( name="config", - description="Customize Continue - slash commands, LLMs, system message, etc.", + description="Customize Continue", step=OpenConfigStep, ), SlashCommand( name="comment", - description="Write comments for the current file or highlighted code", + description="Write comments for the highlighted code", step=CommentCodeStep, ), SlashCommand( @@ -58,7 +58,7 @@ config = ContinueConfig( ), SlashCommand( name="share", - description="Download and share the session transcript", + description="Download and share this session", step=ShareSessionStep, ), SlashCommand( diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py index a7771018..ab1482e8 100644 --- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py +++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py @@ -57,16 +57,15 @@ class HuggingFaceInferenceAPI(LLM): if "stop" in args: args["stop_sequences"] = args["stop"] del args["stop"] - if "model" in args: - del args["model"] + return args async def _stream_complete(self, prompt, options): - self.collect_args(options) + args = self.collect_args(options) client = InferenceClient(self.endpoint_url, token=self.hf_token) - stream = client.text_generation(prompt, stream=True, details=True) + stream = client.text_generation(prompt, stream=True, details=True, **args) for r in stream: # skip special tokens diff --git a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py index df82b1ab..bd31531e 100644 --- a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py +++ b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py @@ -257,8 +257,17 @@ class HighlightedCodeContextProvider(ContextProvider): self._disambiguate_highlighted_ranges() async def set_editing_at_ids(self, ids: List[str]): + # Don't do anything if there are no valid ids here + count = 0 for hr in self.highlighted_ranges: - hr.item.editing = hr.item.description.id.to_string() in ids + if hr.item.description.id.item_id in ids: + count += 1 + + if count == 0: + return + + for hr in self.highlighted_ranges: + hr.item.editing = hr.item.description.id.item_id in ids async def add_context_item( self, id: ContextItemId, query: str, prev: List[ContextItem] = None diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 9d2ea47a..10f6974f 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -104,7 +104,7 @@ class GUIProtocolServer: elif message_type == "delete_at_index": self.on_delete_at_index(data["index"]) elif message_type == "delete_context_with_ids": - self.on_delete_context_with_ids(data["ids"]) + self.on_delete_context_with_ids(data["ids"], data.get("index", None)) elif message_type == "toggle_adding_highlighted_code": self.on_toggle_adding_highlighted_code() elif message_type == "set_editing_at_ids": @@ -112,9 +112,11 @@ class GUIProtocolServer: elif message_type == "show_logs_at_index": self.on_show_logs_at_index(data["index"]) elif message_type == "show_context_virtual_file": - self.show_context_virtual_file() + self.show_context_virtual_file(data.get("index", None)) elif message_type == "select_context_item": self.select_context_item(data["id"], data["query"]) + elif message_type == "select_context_item_at_index": + self.select_context_item_at_index(data["id"], data["query"], data["index"]) elif message_type == "load_session": self.load_session(data.get("session_id", None)) elif message_type == "edit_step_at_index": @@ -171,9 +173,9 @@ class GUIProtocolServer: self.on_error, ) - def on_delete_context_with_ids(self, ids: List[str]): + def on_delete_context_with_ids(self, ids: List[str], index: Optional[int] = None): create_async_task( - self.session.autopilot.delete_context_with_ids(ids), self.on_error + self.session.autopilot.delete_context_with_ids(ids, index), self.on_error ) def on_toggle_adding_highlighted_code(self): @@ -188,7 +190,7 @@ class GUIProtocolServer: def on_show_logs_at_index(self, index: int): name = "Continue Context" logs = "\n\n############################################\n\n".join( - ["This is the prompt sent to the LLM during this step"] + ["This is the prompt that was sent to the LLM during this step"] + self.session.autopilot.continue_sdk.history.timeline[index].logs ) create_async_task( @@ -196,12 +198,20 @@ class GUIProtocolServer: ) posthog_logger.capture_event("show_logs_at_index", {}) - def show_context_virtual_file(self): + def show_context_virtual_file(self, index: Optional[int] = None): async def async_stuff(): - msgs = await self.session.autopilot.continue_sdk.get_chat_context() + if index is None: + context_items = ( + await self.session.autopilot.context_manager.get_selected_items() + ) + elif index < len(self.session.autopilot.continue_sdk.history.timeline): + context_items = self.session.autopilot.continue_sdk.history.timeline[ + index + ].context_used + ctx = "\n\n-----------------------------------\n\n".join( - ["This is the exact context that will be passed to the LLM"] - + list(map(lambda x: x.content, msgs)) + ["These are the context items that will be passed to the LLM"] + + list(map(lambda x: x.content, context_items)) ) await self.session.autopilot.ide.showVirtualFile( "Continue - Selected Context", ctx @@ -218,6 +228,13 @@ class GUIProtocolServer: self.session.autopilot.select_context_item(id, query), self.on_error ) + def select_context_item_at_index(self, id: str, query: str, index: int): + """Called when user selects an item from the dropdown for prev UserInputStep""" + create_async_task( + self.session.autopilot.select_context_item_at_index(id, query, index), + self.on_error, + ) + def load_session(self, session_id: Optional[str] = None): async def load_and_tell_to_reconnect(): new_session_id = await session_manager.load_session( diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index d4f0690b..32bd0f0c 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -12,6 +12,7 @@ from pydantic import BaseModel from starlette.websockets import WebSocketDisconnect, WebSocketState from uvicorn.main import Server +from ..core.main import ContinueCustomException from ..libs.util.create_async_task import create_async_task from ..libs.util.devdata import dev_data_logger from ..libs.util.logging import logger @@ -39,7 +40,6 @@ from ..models.filesystem_edit import ( from ..plugins.steps.core.core import DisplayErrorStep from .gui import session_manager from .ide_protocol import AbstractIdeProtocolServer -from .meilisearch_server import start_meilisearch from .session_manager import SessionManager nest_asyncio.apply() @@ -201,21 +201,24 @@ class IdeProtocolServer(AbstractIdeProtocolServer): except RuntimeError as e: logger.warning(f"Error sending IDE message, websocket probably closed: {e}") - async def _receive_json(self, message_type: str, timeout: int = 20) -> Any: + async def _receive_json( + self, message_type: str, timeout: int = 20, message=None + ) -> Any: try: return await asyncio.wait_for( self.sub_queue.get(message_type), timeout=timeout ) except asyncio.TimeoutError: - raise Exception( - f"IDE Protocol _receive_json timed out after 20 seconds: {message_type}" + raise ContinueCustomException( + title=f"IDE Protocol _receive_json timed out after 20 seconds: {message_type}", + message=f"IDE Protocol _receive_json timed out after 20 seconds. The message sent was: {message or ''}", ) async def _send_and_receive_json( self, data: Any, resp_model: Type[T], message_type: str ) -> T: await self._send_json(message_type, data) - resp = await self._receive_json(message_type) + resp = await self._receive_json(message_type, message=data) return resp_model.parse_obj(resp) async def handle_json(self, message_type: str, data: Any): @@ -597,17 +600,6 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None): logger.debug(f"Accepted websocket connection from {websocket.client}") await websocket.send_json({"messageType": "connected", "data": {}}) - # Start meilisearch - try: - - async def on_err(e): - logger.debug(f"Failed to start MeiliSearch: {e}") - - create_async_task(start_meilisearch(), on_err) - except Exception as e: - logger.debug("Failed to start MeiliSearch") - logger.debug(e) - # Message handler def handle_msg(msg): message = json.loads(msg) diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py index bbae2bb2..aa6c8944 100644 --- a/continuedev/src/continuedev/server/main.py +++ b/continuedev/src/continuedev/server/main.py @@ -1,6 +1,8 @@ import argparse import asyncio import atexit +from contextlib import asynccontextmanager +from typing import Optional import uvicorn from fastapi import FastAPI @@ -9,10 +11,21 @@ from fastapi.middleware.cors import CORSMiddleware from ..libs.util.logging import logger from .gui import router as gui_router from .ide import router as ide_router +from .meilisearch_server import start_meilisearch, stop_meilisearch from .session_manager import router as sessions_router from .session_manager import session_manager -app = FastAPI() +meilisearch_url_global = None + + +@asynccontextmanager +async def lifespan(app: FastAPI): + await start_meilisearch(url=meilisearch_url_global) + yield + stop_meilisearch() + + +app = FastAPI(lifespan=lifespan) app.include_router(ide_router) app.include_router(gui_router) @@ -34,7 +47,13 @@ def health(): return {"status": "ok"} -def run_server(port: int = 65432, host: str = "127.0.0.1"): +def run_server( + port: int = 65432, host: str = "127.0.0.1", meilisearch_url: Optional[str] = None +): + global meilisearch_url_global + + meilisearch_url_global = meilisearch_url + config = uvicorn.Config(app, host=host, port=port) server = uvicorn.Server(config) server.run() diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py index 5e6cdd53..8929b69d 100644 --- a/continuedev/src/continuedev/server/meilisearch_server.py +++ b/continuedev/src/continuedev/server/meilisearch_server.py @@ -2,9 +2,11 @@ import asyncio import os import shutil import subprocess +from typing import Optional import aiofiles import aiohttp +import psutil from meilisearch_python_async import Client from ..libs.util.logging import logger @@ -89,13 +91,22 @@ async def ensure_meilisearch_installed() -> bool: return True +meilisearch_process = None +DEFAULT_MEILISEARCH_URL = "http://localhost:7700" +meilisearch_url = DEFAULT_MEILISEARCH_URL + + +def get_meilisearch_url(): + return meilisearch_url + + async def check_meilisearch_running() -> bool: """ Checks if MeiliSearch is running. """ try: - async with Client("http://localhost:7700") as client: + async with Client(meilisearch_url) as client: try: resp = await client.health() if resp.status != "available": @@ -117,14 +128,16 @@ async def poll_meilisearch_running(frequency: int = 0.1) -> bool: await asyncio.sleep(frequency) -meilisearch_process = None - - -async def start_meilisearch(): +async def start_meilisearch(url: Optional[str] = None): """ Starts the MeiliSearch server, wait for it. """ - global meilisearch_process + global meilisearch_process, meilisearch_url + + if url is not None: + logger.debug("Using MeiliSearch at URL: " + url) + meilisearch_url = url + return serverPath = getServerFolderPath() @@ -157,9 +170,6 @@ def stop_meilisearch(): meilisearch_process = None -import psutil - - def kill_proc(port): for proc in psutil.process_iter(): try: @@ -180,4 +190,4 @@ def kill_proc(port): async def restart_meilisearch(): stop_meilisearch() kill_proc(7700) - await start_meilisearch() + await start_meilisearch(url=meilisearch_url) |