diff options
Diffstat (limited to 'continuedev/src')
28 files changed, 492 insertions, 296 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index beb40c75..b4c951b8 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -15,12 +15,13 @@ from ..server.ide_protocol import AbstractIdeProtocolServer from ..libs.util.queue import AsyncSubscriptionQueue from ..models.main import ContinueBaseModel from .main import Context, ContinueCustomException, Policy, History, FullState, Step, HistoryNode -from ..plugins.steps.core.core import ReversibleStep, ManualEditStep, UserInputStep +from ..plugins.steps.core.core import DisplayErrorStep, ReversibleStep, ManualEditStep, UserInputStep from .sdk import ContinueSDK from ..libs.util.traceback_parsers import get_python_traceback, get_javascript_traceback from openai import error as openai_errors from ..libs.util.create_async_task import create_async_task from ..libs.util.telemetry import posthog_logger +from ..libs.util.logging import logger def get_error_title(e: Exception) -> str: @@ -74,7 +75,7 @@ class Autopilot(ContinueBaseModel): HighlightedCodeContextProvider(ide=ide), FileContextProvider(workspace_dir=ide.workspace_directory) ]) - await autopilot.context_manager.load_index() + await autopilot.context_manager.load_index(ide.workspace_directory) return autopilot @@ -152,7 +153,7 @@ class Autopilot(ContinueBaseModel): await self.update_subscribers() except Exception as e: - print(e) + logger.debug(e) def handle_manual_edits(self, edits: List[FileEditWithFullContents]): for edit in edits: @@ -257,7 +258,7 @@ class Autopilot(ContinueBaseModel): e) # Attach an InternalErrorObservation to the step and unhide it. - print( + logger.error( f"Error while running step: \n{error_string}\n{error_title}") posthog_logger.capture_event('step error', { 'error_message': error_string, 'error_title': error_title, 'step_name': step.name, 'params': step.dict()}) @@ -310,8 +311,8 @@ class Autopilot(ContinueBaseModel): # Update subscribers with new description await self.update_subscribers() - create_async_task(update_description(), - self.continue_sdk.ide.unique_id) + create_async_task(update_description( + ), on_error=lambda e: self.continue_sdk.run_step(DisplayErrorStep(e=e))) return observation diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index af37264d..4fcab588 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -49,45 +49,6 @@ class ContinueConfig(BaseModel): context_providers: List[ContextProvider] = [] - # Want to force these to be the slash commands for now - @validator('slash_commands', pre=True) - def default_slash_commands_validator(cls, v): - from ..plugins.steps.open_config import OpenConfigStep - from ..plugins.steps.clear_history import ClearHistoryStep - from ..plugins.steps.feedback import FeedbackStep - from ..plugins.steps.comment_code import CommentCodeStep - from ..plugins.steps.main import EditHighlightedCodeStep - - DEFAULT_SLASH_COMMANDS = [ - SlashCommand( - name="edit", - description="Edit code in the current file or the highlighted code", - step=EditHighlightedCodeStep, - ), - SlashCommand( - name="config", - description="Open the config file to create new and edit existing slash commands", - step=OpenConfigStep, - ), - SlashCommand( - name="comment", - description="Write comments for the current file or highlighted code", - step=CommentCodeStep, - ), - SlashCommand( - name="feedback", - description="Send feedback to improve Continue", - step=FeedbackStep, - ), - SlashCommand( - name="clear", - description="Clear step history", - step=ClearHistoryStep, - ) - ] - - return DEFAULT_SLASH_COMMANDS + v - @validator('temperature', pre=True) def temperature_validator(cls, v): return max(0.0, min(1.0, v)) diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py index 8afbd610..e968c35c 100644 --- a/continuedev/src/continuedev/core/context.py +++ b/continuedev/src/continuedev/core/context.py @@ -7,7 +7,7 @@ from pydantic import BaseModel from .main import ChatMessage, ContextItem, ContextItemDescription, ContextItemId from ..server.meilisearch_server import check_meilisearch_running - +from ..libs.util.logging import logger SEARCH_INDEX_NAME = "continue_context_items" @@ -35,7 +35,7 @@ class ContextProvider(BaseModel): return self.selected_items @abstractmethod - async def provide_context_items(self) -> List[ContextItem]: + async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]: """ Provide documents for search index. This is run on startup. @@ -57,16 +57,22 @@ class ContextProvider(BaseModel): Default implementation uses the search index to get the item. """ async with Client('http://localhost:7700') as search_client: - result = await search_client.index( - SEARCH_INDEX_NAME).get_document(id.to_string()) - return ContextItem( - description=ContextItemDescription( - name=result["name"], - description=result["description"], - id=id - ), - content=result["content"] - ) + try: + result = await search_client.index( + SEARCH_INDEX_NAME).get_document(id.to_string()) + return ContextItem( + description=ContextItemDescription( + name=result["name"], + description=result["description"], + id=id + ), + content=result["content"] + ) + except Exception as e: + logger.warning( + f"Error while retrieving document from meilisearch: {e}") + + return None async def delete_context_with_ids(self, ids: List[ContextItemId]): """ @@ -100,8 +106,8 @@ class ContextProvider(BaseModel): if item.description.id.item_id == id.item_id: return - new_item = await self.get_item(id, query) - self.selected_items.append(new_item) + if new_item := await self.get_item(id, query): + self.selected_items.append(new_item) class ContextManager: @@ -146,16 +152,16 @@ class ContextManager: meilisearch_running = False if not meilisearch_running: - print( + logger.warning( "MeiliSearch not running, avoiding any dependent context providers") context_providers = list( filter(lambda cp: cp.title == "code", context_providers)) return cls(context_providers) - async def load_index(self): + async def load_index(self, workspace_dir: str): for _, provider in self.context_providers.items(): - context_items = await provider.provide_context_items() + context_items = await provider.provide_context_items(workspace_dir) documents = [ { "id": item.description.id.to_string(), @@ -166,8 +172,11 @@ class ContextManager: for item in context_items ] if len(documents) > 0: - async with Client('http://localhost:7700') as search_client: - await search_client.index(SEARCH_INDEX_NAME).add_documents(documents) + try: + async with Client('http://localhost:7700') as search_client: + await search_client.index(SEARCH_INDEX_NAME).add_documents(documents) + except Exception as e: + logger.debug(f"Error loading meilisearch index: {e}") async def select_context_item(self, id: str, query: str): """ diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 1dd4b857..be7008c0 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -1,5 +1,6 @@ from functools import cached_property -from typing import Coroutine, Dict, Union +import traceback +from typing import Coroutine, Dict, Literal, Union import os from ..plugins.steps.core.core import DefaultModelEditCodeStep @@ -16,6 +17,7 @@ from ..plugins.steps.core.core import * from ..libs.util.telemetry import posthog_logger from ..libs.util.paths import getConfigFilePath from .models import Models +from ..libs.util.logging import logger class Autopilot: @@ -43,11 +45,15 @@ class ContinueSDK(AbstractContinueSDK): config = sdk._load_config_dot_py() sdk.config = config except Exception as e: - print(e) - sdk.config = ContinueConfig() + logger.error(f"Failed to load config.py: {e}") + + sdk.config = ContinueConfig( + ) if sdk._last_valid_config is None else sdk._last_valid_config + + formatted_err = '\n'.join(traceback.format_exception(e)) msg_step = MessageStep( - name="Invalid Continue Config File", message=e.__repr__()) - msg_step.description = e.__repr__() + name="Invalid Continue Config File", message=formatted_err) + msg_step.description = f"Falling back to default config settings.\n```\n{formatted_err}\n```" sdk.history.add_node(HistoryNode( step=msg_step, observation=None, @@ -57,6 +63,11 @@ class ContinueSDK(AbstractContinueSDK): sdk.models = sdk.config.models await sdk.models.start(sdk) + + # When the config is loaded, setup posthog logger + posthog_logger.setup( + sdk.ide.unique_id, sdk.config.allow_anonymous_telemetry) + return sdk @property @@ -154,21 +165,14 @@ class ContinueSDK(AbstractContinueSDK): def _load_config_dot_py(self) -> ContinueConfig: # Use importlib to load the config file config.py at the given path path = getConfigFilePath() - try: - import importlib.util - spec = importlib.util.spec_from_file_location("config", path) - config = importlib.util.module_from_spec(spec) - spec.loader.exec_module(config) - self._last_valid_config = config.config - # When the config is loaded, setup posthog logger - posthog_logger.setup( - self.ide.unique_id, config.config.allow_anonymous_telemetry or True) + import importlib.util + spec = importlib.util.spec_from_file_location("config", path) + config = importlib.util.module_from_spec(spec) + spec.loader.exec_module(config) + self._last_valid_config = config.config - return config.config - except Exception as e: - print("Error loading config.py: ", e) - return ContinueConfig() if self._last_valid_config is None else self._last_valid_config + return config.config def get_code_context(self, only_editing: bool = False) -> List[RangeInFileWithContents]: highlighted_ranges = self.__autopilot.context_manager.context_providers[ diff --git a/continuedev/src/continuedev/libs/chroma/query.py b/continuedev/src/continuedev/libs/chroma/query.py index c27329f0..f09b813a 100644 --- a/continuedev/src/continuedev/libs/chroma/query.py +++ b/continuedev/src/continuedev/libs/chroma/query.py @@ -5,6 +5,7 @@ from llama_index import GPTVectorStoreIndex, StorageContext, load_index_from_sto from llama_index.langchain_helpers.text_splitter import TokenTextSplitter import os from .update import filter_ignored_files, load_gpt_index_documents +from ..util.logging import logger from functools import cached_property @@ -56,7 +57,8 @@ class ChromaIndexManager: try: text_chunks = text_splitter.split_text(doc.text) except: - print("ERROR (probably found special token): ", doc.text) + logger.warning( + f"ERROR (probably found special token): {doc.text}") continue filename = doc.extra_info["filename"] chunks[filename] = len(text_chunks) @@ -79,7 +81,7 @@ class ChromaIndexManager: index.storage_context.persist(persist_dir=self.index_dir) - print("Codebase index created") + logger.debug("Codebase index created") def get_modified_deleted_files(self) -> Tuple[List[str], List[str]]: """Get a list of all files that have been modified since the last commit.""" @@ -121,7 +123,7 @@ class ChromaIndexManager: del metadata["chunks"][file] - print(f"Deleted {file}") + logger.debug(f"Deleted {file}") for file in modified_files: @@ -132,7 +134,7 @@ class ChromaIndexManager: for i in range(num_chunks): index.delete(f"{file}::{i}") - print(f"Deleted old version of {file}") + logger.debug(f"Deleted old version of {file}") with open(file, "r") as f: text = f.read() @@ -145,19 +147,20 @@ class ChromaIndexManager: metadata["chunks"][file] = len(text_chunks) - print(f"Inserted new version of {file}") + logger.debug(f"Inserted new version of {file}") metadata["commit"] = self.current_commit with open(f"{self.index_dir}/metadata.json", "w") as f: json.dump(metadata, f, indent=4) - print("Codebase index updated") + logger.debug("Codebase index updated") def query_codebase_index(self, query: str) -> str: """Query the codebase index.""" if not self.check_index_exists(): - print("No index found for the codebase at ", self.index_dir) + logger.debug( + f"No index found for the codebase at {self.index_dir}") return "" storage_context = StorageContext.from_defaults( @@ -180,4 +183,4 @@ class ChromaIndexManager: documents = [Document(info)] index = GPTVectorStoreIndex(documents) index.save_to_disk(f'{self.index_dir}/additional_index.json') - print("Additional index replaced") + logger.debug("Additional index replaced") diff --git a/continuedev/src/continuedev/libs/constants/default_config.py.txt b/continuedev/src/continuedev/libs/constants/default_config.py.txt index e40a2684..be978fd3 100644 --- a/continuedev/src/continuedev/libs/constants/default_config.py.txt +++ b/continuedev/src/continuedev/libs/constants/default_config.py.txt @@ -7,12 +7,18 @@ be sure to select the Python interpreter in ~/.continue/server/env. import subprocess -from continuedev.src.continuedev.core.main import Step -from continuedev.src.continuedev.core.sdk import ContinueSDK -from continuedev.src.continuedev.core.config import CustomCommand, SlashCommand, ContinueConfig -from continuedev.src.continuedev.plugins.context_providers.github import GitHubIssuesContextProvider -from continuedev.src.continuedev.plugins.context_providers.google import GoogleContextProvider -from continuedev.src.continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI +from continuedev.core.main import Step +from continuedev.core.sdk import ContinueSDK +from continuedev.core.config import CustomCommand, SlashCommand, ContinueConfig +from continuedev.plugins.context_providers.github import GitHubIssuesContextProvider +from continuedev.plugins.context_providers.google import GoogleContextProvider +from continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI +from continuedev.plugins.steps.open_config import OpenConfigStep +from continuedev.plugins.steps.clear_history import ClearHistoryStep +from continuedev.plugins.steps.feedback import FeedbackStep +from continuedev.plugins.steps.comment_code import CommentCodeStep +from continuedev.plugins.steps.main import EditHighlightedCodeStep + class CommitMessageStep(Step): """ @@ -70,6 +76,31 @@ config = ContinueConfig( # description="This is an example slash command. Use /config to edit it and create more", # step=CommitMessageStep, # ) + SlashCommand( + name="edit", + description="Edit code in the current file or the highlighted code", + step=EditHighlightedCodeStep, + ), + SlashCommand( + name="config", + description="Open the config file to create new and edit existing slash commands", + step=OpenConfigStep, + ), + SlashCommand( + name="comment", + description="Write comments for the current file or highlighted code", + step=CommentCodeStep, + ), + SlashCommand( + name="feedback", + description="Send feedback to improve Continue", + step=FeedbackStep, + ), + SlashCommand( + name="clear", + description="Clear step history", + step=ClearHistoryStep, + ) ], # Context providers let you quickly select context by typing '@' diff --git a/continuedev/src/continuedev/libs/constants/main.py b/continuedev/src/continuedev/libs/constants/main.py index 96eb6e69..f5964df6 100644 --- a/continuedev/src/continuedev/libs/constants/main.py +++ b/continuedev/src/continuedev/libs/constants/main.py @@ -3,4 +3,4 @@ CONTINUE_GLOBAL_FOLDER = ".continue" CONTINUE_SESSIONS_FOLDER = "sessions" CONTINUE_SERVER_FOLDER = "server" - +CONTINUE_SERVER_VERSION_FILE = "server_version.txt" diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index 16428d4e..73aff383 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -1,12 +1,21 @@ from functools import cached_property import json -from typing import Any, Callable, Coroutine, Dict, Generator, List, Union, Optional +from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union, Optional from pydantic import BaseModel -from ...core.main import ChatMessage import openai -from ..llm import LLM + +from ...core.main import ChatMessage from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top +from ..llm import LLM + + +class OpenAIServerInfo(BaseModel): + api_base: Optional[str] = None + engine: Optional[str] = None + api_version: Optional[str] = None + api_type: Literal["azure", "openai"] = "openai" + CHAT_MODELS = { "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613" @@ -27,6 +36,7 @@ class AzureInfo(BaseModel): class OpenAI(LLM): model: str + openai_server_info: Optional[OpenAIServerInfo] = None requires_api_key = "OPENAI_API_KEY" requires_write_log = True @@ -41,11 +51,12 @@ class OpenAI(LLM): self.api_key = api_key openai.api_key = self.api_key - # Using an Azure OpenAI deployment - if self.azure_info is not None: - openai.api_type = "azure" - openai.api_base = self.azure_info.endpoint - openai.api_version = self.azure_info.api_version + if self.openai_server_info is not None: + openai.api_type = self.openai_server_info.api_type + if self.openai_server_info.api_base is not None: + openai.api_base = self.openai_server_info.api_base + if self.openai_server_info.api_version is not None: + openai.api_version = self.openai_server_info.api_version async def stop(self): pass @@ -61,8 +72,8 @@ class OpenAI(LLM): @property def default_args(self): args = {**DEFAULT_ARGS, "model": self.model} - if self.azure_info is not None: - args["engine"] = self.azure_info.engine + if self.openai_server_info is not None: + args["engine"] = self.openai_server_info.engine return args def count_tokens(self, text: str): diff --git a/continuedev/src/continuedev/libs/util/create_async_task.py b/continuedev/src/continuedev/libs/util/create_async_task.py index 2473c638..4c6d3c95 100644 --- a/continuedev/src/continuedev/libs/util/create_async_task.py +++ b/continuedev/src/continuedev/libs/util/create_async_task.py @@ -1,12 +1,13 @@ -from typing import Coroutine, Union +from typing import Callable, Coroutine, Optional, Union import traceback from .telemetry import posthog_logger +from .logging import logger import asyncio import nest_asyncio nest_asyncio.apply() -def create_async_task(coro: Coroutine, unique_id: Union[str, None] = None): +def create_async_task(coro: Coroutine, on_error: Optional[Callable[[Exception], Coroutine]] = None): """asyncio.create_task and log errors by adding a callback""" task = asyncio.create_task(coro) @@ -14,11 +15,16 @@ def create_async_task(coro: Coroutine, unique_id: Union[str, None] = None): try: future.result() except Exception as e: - print("Exception caught from async task: ", - '\n'.join(traceback.format_exception(e))) + formatted_tb = '\n'.join(traceback.format_exception(e)) + logger.critical( + f"Exception caught from async task: {formatted_tb}") posthog_logger.capture_event("async_task_error", { "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e)) }) + # Log the error to the GUI + if on_error is not None: + asyncio.create_task(on_error(e)) + task.add_done_callback(callback) return task diff --git a/continuedev/src/continuedev/libs/util/logging.py b/continuedev/src/continuedev/libs/util/logging.py new file mode 100644 index 00000000..668d313f --- /dev/null +++ b/continuedev/src/continuedev/libs/util/logging.py @@ -0,0 +1,30 @@ +import logging + +from .paths import getLogFilePath + +# Create a logger +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + +# Create a file handler +file_handler = logging.FileHandler(getLogFilePath()) +file_handler.setLevel(logging.DEBUG) + +# Create a console handler +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.DEBUG) + +# Create a formatter +formatter = logging.Formatter( + '[%(asctime)s] [%(levelname)s] %(message)s') + +# Add the formatter to the handlers +file_handler.setFormatter(formatter) +console_handler.setFormatter(formatter) + +# Add the handlers to the logger +logger.addHandler(file_handler) +logger.addHandler(console_handler) + +# Log a test message +logger.debug('Testing logs') diff --git a/continuedev/src/continuedev/libs/util/step_name_to_steps.py b/continuedev/src/continuedev/libs/util/step_name_to_steps.py index baa25da6..ed1e79b7 100644 --- a/continuedev/src/continuedev/libs/util/step_name_to_steps.py +++ b/continuedev/src/continuedev/libs/util/step_name_to_steps.py @@ -14,6 +14,7 @@ from ...plugins.steps.on_traceback import DefaultOnTracebackStep from ...plugins.steps.clear_history import ClearHistoryStep from ...plugins.steps.open_config import OpenConfigStep from ...plugins.steps.help import HelpStep +from ...libs.util.logging import logger # This mapping is used to convert from string in ContinueConfig json to corresponding Step class. # Used for example in slash_commands and steps_on_startup @@ -38,6 +39,6 @@ def get_step_from_name(step_name: str, params: Dict) -> Step: try: return step_name_to_step_class[step_name](**params) except: - print( + logger.error( f"Incorrect parameters for step {step_name}. Parameters provided were: {params}") raise diff --git a/continuedev/src/continuedev/libs/util/telemetry.py b/continuedev/src/continuedev/libs/util/telemetry.py index a967828e..60c910bb 100644 --- a/continuedev/src/continuedev/libs/util/telemetry.py +++ b/continuedev/src/continuedev/libs/util/telemetry.py @@ -3,6 +3,9 @@ from posthog import Posthog import os from dotenv import load_dotenv from .commonregex import clean_pii_from_any +from .logging import logger +from .paths import getServerFolderPath +from ..constants.main import CONTINUE_SERVER_VERSION_FILE load_dotenv() in_codespaces = os.getenv("CODESPACES") == "true" @@ -10,28 +13,52 @@ POSTHOG_API_KEY = 'phc_JS6XFROuNbhJtVCEdTSYk6gl5ArRrTNMpCcguAXlSPs' class PostHogLogger: + unique_id: str = "NO_UNIQUE_ID" + allow_anonymous_telemetry: bool = True + def __init__(self, api_key: str): self.api_key = api_key - self.unique_id = None - self.allow_anonymous_telemetry = True - - def setup(self, unique_id: str, allow_anonymous_telemetry: bool): - self.unique_id = unique_id - self.allow_anonymous_telemetry = allow_anonymous_telemetry # The personal API key is necessary only if you want to use local evaluation of feature flags. self.posthog = Posthog(self.api_key, host='https://app.posthog.com') + def setup(self, unique_id: str, allow_anonymous_telemetry: bool): + logger.debug(f"Setting unique_id as {unique_id}") + self.unique_id = unique_id or "NO_UNIQUE_ID" + self.allow_anonymous_telemetry = allow_anonymous_telemetry or True + def capture_event(self, event_name: str, event_properties: Any): - if not self.allow_anonymous_telemetry or self.unique_id is None: + # logger.debug( + # f"Logging to PostHog: {event_name} ({self.unique_id}, {self.allow_anonymous_telemetry}): {event_properties}") + telemetry_path = os.path.expanduser("~/.continue/telemetry.log") + + # Make sure the telemetry file exists + if not os.path.exists(telemetry_path): + os.makedirs(os.path.dirname(telemetry_path), exist_ok=True) + open(telemetry_path, "w").close() + + with open(telemetry_path, "a") as f: + str_to_write = f"{event_name}: {event_properties}\n{self.unique_id}\n{self.allow_anonymous_telemetry}\n\n" + f.write(str_to_write) + + if not self.allow_anonymous_telemetry: return + # Clean PII from event properties + event_properties = clean_pii_from_any(event_properties) + + # Add additional properties that are on every event if in_codespaces: event_properties['codespaces'] = True + server_version_file = os.path.join( + getServerFolderPath(), CONTINUE_SERVER_VERSION_FILE) + if os.path.exists(server_version_file): + with open(server_version_file, "r") as f: + event_properties['server_version'] = f.read() + # Send event to PostHog - self.posthog.capture(self.unique_id, event_name, - clean_pii_from_any(event_properties)) + self.posthog.capture(self.unique_id, event_name, event_properties) posthog_logger = PostHogLogger(api_key=POSTHOG_API_KEY) diff --git a/continuedev/src/continuedev/models/generate_json_schema.py b/continuedev/src/continuedev/models/generate_json_schema.py index 06614984..51869fdd 100644 --- a/continuedev/src/continuedev/models/generate_json_schema.py +++ b/continuedev/src/continuedev/models/generate_json_schema.py @@ -38,7 +38,7 @@ def main(): try: json = schema_json_of(model, indent=2, title=title) except Exception as e: - print(f"Failed to generate json schema for {title}: ", e) + print(f"Failed to generate json schema for {title}: {e}") continue with open(f"{SCHEMA_DIR}/{title}.json", "w") as f: diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py index 31c8e1d9..634774df 100644 --- a/continuedev/src/continuedev/plugins/context_providers/file.py +++ b/continuedev/src/continuedev/plugins/context_providers/file.py @@ -49,13 +49,12 @@ class FileContextProvider(ContextProvider): """ title = "file" - workspace_dir: str ignore_patterns: List[str] = DEFAULT_IGNORE_DIRS + \ list(filter(lambda d: f"**/{d}", DEFAULT_IGNORE_DIRS)) - async def provide_context_items(self) -> List[ContextItem]: + async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]: filepaths = [] - for root, dir_names, file_names in os.walk(self.workspace_dir): + for root, dir_names, file_names in os.walk(workspace_dir): dir_names[:] = [d for d in dir_names if not any( fnmatch(d, pattern) for pattern in self.ignore_patterns)] for file_name in file_names: diff --git a/continuedev/src/continuedev/plugins/context_providers/filetree.py b/continuedev/src/continuedev/plugins/context_providers/filetree.py new file mode 100644 index 00000000..c7b4806b --- /dev/null +++ b/continuedev/src/continuedev/plugins/context_providers/filetree.py @@ -0,0 +1,49 @@ +import json +from typing import List +import os +import aiohttp + +from ...core.main import ContextItem, ContextItemDescription, ContextItemId +from ...core.context import ContextProvider + + +def format_file_tree(startpath) -> str: + result = "" + for root, dirs, files in os.walk(startpath): + level = root.replace(startpath, '').count(os.sep) + indent = ' ' * 4 * (level) + result += '{}{}/'.format(indent, os.path.basename(root)) + "\n" + subindent = ' ' * 4 * (level + 1) + for f in files: + result += '{}{}'.format(subindent, f) + "\n" + + return result + + +class FileTreeContextProvider(ContextProvider): + title = "tree" + + workspace_dir: str = None + + def _filetree_context_item(self): + return ContextItem( + content=format_file_tree(self.workspace_dir), + description=ContextItemDescription( + name="File Tree", + description="Add a formatted file tree of this directory to the context", + id=ContextItemId( + provider_title=self.title, + item_id=self.title + ) + ) + ) + + async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]: + self.workspace_dir = workspace_dir + return [self._filetree_context_item()] + + async def get_item(self, id: ContextItemId, query: str) -> ContextItem: + if not id.item_id == self.title: + raise Exception("Invalid item id") + + return self._filetree_context_item() diff --git a/continuedev/src/continuedev/plugins/context_providers/github.py b/continuedev/src/continuedev/plugins/context_providers/github.py index 765a534d..2e7047f2 100644 --- a/continuedev/src/continuedev/plugins/context_providers/github.py +++ b/continuedev/src/continuedev/plugins/context_providers/github.py @@ -15,7 +15,7 @@ class GitHubIssuesContextProvider(ContextProvider): repo_name: str auth_token: str - async def provide_context_items(self) -> List[ContextItem]: + async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]: auth = Auth.Token(self.auth_token) gh = Github(auth=auth) diff --git a/continuedev/src/continuedev/plugins/context_providers/google.py b/continuedev/src/continuedev/plugins/context_providers/google.py index 64954833..fc76fe67 100644 --- a/continuedev/src/continuedev/plugins/context_providers/google.py +++ b/continuedev/src/continuedev/plugins/context_providers/google.py @@ -42,7 +42,7 @@ class GoogleContextProvider(ContextProvider): async with session.post(url, headers=headers, data=payload) as response: return await response.text() - async def provide_context_items(self) -> List[ContextItem]: + async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]: return [self.BASE_CONTEXT_ITEM] async def get_item(self, id: ContextItemId, query: str, _) -> ContextItem: diff --git a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py index 1d040101..acd40dc7 100644 --- a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py +++ b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List from ...core.main import ChatMessage from ...models.filesystem import RangeInFile, RangeInFileWithContents -from ...core.context import ContextItem, ContextItemDescription, ContextItemId +from ...core.context import ContextItem, ContextItemDescription, ContextItemId, ContextProvider from pydantic import BaseModel @@ -12,7 +12,7 @@ class HighlightedRangeContextItem(BaseModel): item: ContextItem -class HighlightedCodeContextProvider(BaseModel): +class HighlightedCodeContextProvider(ContextProvider): """ The ContextProvider class is a plugin that lets you provide new information to the LLM by typing '@'. When you type '@', the context provider will be asked to populate a list of options. @@ -96,9 +96,18 @@ class HighlightedCodeContextProvider(BaseModel): hr.item.description.name = self._rif_to_name( hr.rif, display_filename=basename) - async def provide_context_items(self) -> List[ContextItem]: + async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]: return [] + async def get_item(self, id: ContextItemId, query: str) -> ContextItem: + raise NotImplementedError() + + async def clear_context(self): + self.highlighted_ranges = [] + self.adding_highlighted_code = False + self.should_get_fallback_context_item = True + self.last_added_fallback = False + async def delete_context_with_ids(self, ids: List[ContextItemId]) -> List[ContextItem]: indices_to_delete = [ int(id.item_id) for id in ids diff --git a/continuedev/src/continuedev/plugins/steps/chat.py b/continuedev/src/continuedev/plugins/steps/chat.py index 0a0fbca2..455d5a13 100644 --- a/continuedev/src/continuedev/plugins/steps/chat.py +++ b/continuedev/src/continuedev/plugins/steps/chat.py @@ -5,7 +5,7 @@ from pydantic import Field from ...libs.util.strings import remove_quotes_and_escapes from .main import EditHighlightedCodeStep -from .core.core import MessageStep +from .core.core import DisplayErrorStep, MessageStep from ...core.main import FunctionCall, Models from ...core.main import ChatMessage, Step, step_to_json_schema from ...core.sdk import ContinueSDK @@ -27,34 +27,32 @@ class SimpleChatStep(Step): messages: List[ChatMessage] = None async def run(self, sdk: ContinueSDK): - completion = "" messages = self.messages or await sdk.get_chat_context() generator = sdk.models.default.stream_chat( messages, temperature=sdk.config.temperature) - try: - async for chunk in generator: - if sdk.current_step_was_deleted(): - # So that the message doesn't disappear - self.hide = False - break - if "content" in chunk: - self.description += chunk["content"] - completion += chunk["content"] - await sdk.update_ui() - finally: - self.name = remove_quotes_and_escapes(await sdk.models.medium.complete( - f"Write a short title for the following chat message: {self.description}")) - - self.chat_context.append(ChatMessage( - role="assistant", - content=completion, - summary=self.name - )) - - # TODO: Never actually closing. - await generator.aclose() + async for chunk in generator: + if sdk.current_step_was_deleted(): + # So that the message doesn't disappear + self.hide = False + break + + if "content" in chunk: + self.description += chunk["content"] + await sdk.update_ui() + + self.name = remove_quotes_and_escapes(await sdk.models.medium.complete( + f"Write a short title for the following chat message: {self.description}")) + + self.chat_context.append(ChatMessage( + role="assistant", + content=self.description, + summary=self.name + )) + + # TODO: Never actually closing. + await generator.aclose() class AddFileStep(Step): diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py index 4c5303fb..fb9ea029 100644 --- a/continuedev/src/continuedev/plugins/steps/core/core.py +++ b/continuedev/src/continuedev/plugins/steps/core/core.py @@ -1,9 +1,13 @@ # These steps are depended upon by ContinueSDK import os -import subprocess +import json import difflib from textwrap import dedent -from typing import Coroutine, List, Literal, Union +import traceback +from typing import Any, Coroutine, List, Union +import difflib + +from pydantic import validator from ....libs.llm.ggml import GGML from ....models.main import Range @@ -14,7 +18,6 @@ from ....core.observation import Observation, TextObservation, TracebackObservat from ....core.main import ChatMessage, ContinueCustomException, Step, SequentialStep from ....libs.util.count_tokens import DEFAULT_MAX_TOKENS from ....libs.util.strings import dedent_and_get_common_whitespace, remove_quotes_and_escapes -import difflib class ContinueSDK: @@ -41,6 +44,25 @@ class MessageStep(Step): return TextObservation(text=self.message) +class DisplayErrorStep(Step): + name: str = "Error in the Continue server" + e: Any + + class Config: + arbitrary_types_allowed = True + + @validator("e", pre=True, always=True) + def validate_e(cls, v): + if isinstance(v, Exception): + return '\n'.join(traceback.format_exception(v)) + + async def describe(self, models: Models) -> Coroutine[str, None, None]: + return self.e + + async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: + raise ContinueCustomException(message=self.e, title=self.name) + + class FileSystemEditStep(ReversibleStep): edit: FileSystemEdit _diff: Union[EditDiff, None] = None diff --git a/continuedev/src/continuedev/plugins/steps/help.py b/continuedev/src/continuedev/plugins/steps/help.py index 4d75af30..ec670999 100644 --- a/continuedev/src/continuedev/plugins/steps/help.py +++ b/continuedev/src/continuedev/plugins/steps/help.py @@ -34,26 +34,33 @@ class HelpStep(Step): description: str = "" async def run(self, sdk: ContinueSDK): - question = self.user_input - prompt = dedent(f"""Please us the information below to provide a succinct answer to the following quesiton: {question} - - Information: - - {help}""") - - self.chat_context.append(ChatMessage( - role="user", - content=prompt, - summary="Help" - )) - messages = await sdk.get_chat_context() - generator = sdk.models.default.stream_chat(messages) - async for chunk in generator: - if "content" in chunk: - self.description += chunk["content"] - await sdk.update_ui() + if question.strip() == "": + self.description = help + else: + prompt = dedent(f""" + Information: + + {help} + + Instructions: + + Please us the information below to provide a succinct answer to the following question: {question} + + Do not cite any slash commands other than those you've been told about, which are: /edit and /feedback.""") + + self.chat_context.append(ChatMessage( + role="user", + content=prompt, + summary="Help" + )) + messages = await sdk.get_chat_context() + generator = sdk.models.default.stream_chat(messages) + async for chunk in generator: + if "content" in chunk: + self.description += chunk["content"] + await sdk.update_ui() posthog_logger.capture_event( "help", {"question": question, "answer": self.description}) diff --git a/continuedev/src/continuedev/plugins/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py index 26c1cabd..2c3d34fc 100644 --- a/continuedev/src/continuedev/plugins/steps/main.py +++ b/continuedev/src/continuedev/plugins/steps/main.py @@ -13,6 +13,7 @@ from ...core.sdk import ContinueSDK, Models from ...core.observation import Observation from .core.core import DefaultModelEditCodeStep from ...libs.util.calculate_diff import calculate_diff2 +from ...libs.util.logging import logger class Policy(BaseModel): @@ -105,7 +106,7 @@ class FasterEditHighlightedCodeStep(Step): # Temporarily doing this to generate description. self._prompt = prompt self._completion = completion - print(completion) + logger.debug(completion) # ALTERNATIVE DECODING STEP HERE raw_file_edits = [] diff --git a/continuedev/src/continuedev/plugins/steps/search_directory.py b/continuedev/src/continuedev/plugins/steps/search_directory.py index c13047d6..966acb7c 100644 --- a/continuedev/src/continuedev/plugins/steps/search_directory.py +++ b/continuedev/src/continuedev/plugins/steps/search_directory.py @@ -65,5 +65,5 @@ class EditAllMatchesStep(Step): range=range_in_file.range, filename=range_in_file.filepath, prompt=self.user_request - ), sdk.ide.unique_id) for range_in_file in range_in_files] + )) for range_in_file in range_in_files] await asyncio.gather(*tasks) diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index c0957395..98a5aea0 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -8,10 +8,12 @@ import traceback from uvicorn.main import Server from .session_manager import session_manager, Session +from ..plugins.steps.core.core import DisplayErrorStep, MessageStep from .gui_protocol import AbstractGUIProtocolServer from ..libs.util.queue import AsyncSubscriptionQueue from ..libs.util.telemetry import posthog_logger from ..libs.util.create_async_task import create_async_task +from ..libs.util.logging import logger router = APIRouter(prefix="/gui", tags=["gui"]) @@ -25,17 +27,13 @@ class AppStatus: @staticmethod def handle_exit(*args, **kwargs): AppStatus.should_exit = True - print("Shutting down") + logger.debug("Shutting down") original_handler(*args, **kwargs) Server.handle_exit = AppStatus.handle_exit -async def session(x_continue_session_id: str = Header("anonymous")) -> Session: - return await session_manager.get_session(x_continue_session_id) - - async def websocket_session(session_id: str) -> Session: return await session_manager.get_session(session_id) @@ -73,103 +71,97 @@ class GUIProtocolServer(AbstractGUIProtocolServer): resp = await self._receive_json(message_type) return resp_model.parse_obj(resp) + def on_error(self, e: Exception): + return self.session.autopilot.continue_sdk.run_step(DisplayErrorStep(e=e)) + def handle_json(self, message_type: str, data: Any): - try: - if message_type == "main_input": - self.on_main_input(data["input"]) - elif message_type == "step_user_input": - self.on_step_user_input(data["input"], data["index"]) - elif message_type == "refinement_input": - self.on_refinement_input(data["input"], data["index"]) - elif message_type == "reverse_to_index": - self.on_reverse_to_index(data["index"]) - elif message_type == "retry_at_index": - self.on_retry_at_index(data["index"]) - elif message_type == "clear_history": - self.on_clear_history() - elif message_type == "delete_at_index": - self.on_delete_at_index(data["index"]) - elif message_type == "delete_context_with_ids": - self.on_delete_context_with_ids(data["ids"]) - elif message_type == "toggle_adding_highlighted_code": - self.on_toggle_adding_highlighted_code() - elif message_type == "set_editing_at_indices": - self.on_set_editing_at_indices(data["indices"]) - elif message_type == "show_logs_at_index": - self.on_show_logs_at_index(data["index"]) - elif message_type == "select_context_item": - self.select_context_item(data["id"], data["query"]) - except Exception as e: - print(e) + if message_type == "main_input": + self.on_main_input(data["input"]) + elif message_type == "step_user_input": + self.on_step_user_input(data["input"], data["index"]) + elif message_type == "refinement_input": + self.on_refinement_input(data["input"], data["index"]) + elif message_type == "reverse_to_index": + self.on_reverse_to_index(data["index"]) + elif message_type == "retry_at_index": + self.on_retry_at_index(data["index"]) + elif message_type == "clear_history": + self.on_clear_history() + elif message_type == "delete_at_index": + self.on_delete_at_index(data["index"]) + elif message_type == "delete_context_with_ids": + self.on_delete_context_with_ids(data["ids"]) + elif message_type == "toggle_adding_highlighted_code": + self.on_toggle_adding_highlighted_code() + elif message_type == "set_editing_at_indices": + self.on_set_editing_at_indices(data["indices"]) + elif message_type == "show_logs_at_index": + self.on_show_logs_at_index(data["index"]) + elif message_type == "select_context_item": + self.select_context_item(data["id"], data["query"]) def on_main_input(self, input: str): # Do something with user input - create_async_task(self.session.autopilot.accept_user_input( - input), self.session.autopilot.continue_sdk.ide.unique_id) + create_async_task( + self.session.autopilot.accept_user_input(input), self.on_error) def on_reverse_to_index(self, index: int): # Reverse the history to the given index - create_async_task(self.session.autopilot.reverse_to_index( - index), self.session.autopilot.continue_sdk.ide.unique_id) + create_async_task( + self.session.autopilot.reverse_to_index(index), self.on_error) def on_step_user_input(self, input: str, index: int): create_async_task( - self.session.autopilot.give_user_input(input, index), self.session.autopilot.continue_sdk.ide.unique_id) + self.session.autopilot.give_user_input(input, index), self.on_error) def on_refinement_input(self, input: str, index: int): create_async_task( - self.session.autopilot.accept_refinement_input(input, index), self.session.autopilot.continue_sdk.ide.unique_id) + self.session.autopilot.accept_refinement_input(input, index), self.on_error) def on_retry_at_index(self, index: int): create_async_task( - self.session.autopilot.retry_at_index(index), self.session.autopilot.continue_sdk.ide.unique_id) + self.session.autopilot.retry_at_index(index), self.on_error) def on_clear_history(self): - create_async_task(self.session.autopilot.clear_history( - ), self.session.autopilot.continue_sdk.ide.unique_id) + create_async_task( + self.session.autopilot.clear_history(), self.on_error) def on_delete_at_index(self, index: int): - create_async_task(self.session.autopilot.delete_at_index( - index), self.session.autopilot.continue_sdk.ide.unique_id) + create_async_task( + self.session.autopilot.delete_at_index(index), self.on_error) def on_delete_context_with_ids(self, ids: List[str]): create_async_task( - self.session.autopilot.delete_context_with_ids( - ids), self.session.autopilot.continue_sdk.ide.unique_id - ) + self.session.autopilot.delete_context_with_ids(ids), self.on_error) def on_toggle_adding_highlighted_code(self): create_async_task( - self.session.autopilot.toggle_adding_highlighted_code( - ), self.session.autopilot.continue_sdk.ide.unique_id - ) + self.session.autopilot.toggle_adding_highlighted_code(), self.on_error) def on_set_editing_at_indices(self, indices: List[int]): create_async_task( - self.session.autopilot.set_editing_at_indices( - indices), self.session.autopilot.continue_sdk.ide.unique_id - ) + self.session.autopilot.set_editing_at_indices(indices), self.on_error) def on_show_logs_at_index(self, index: int): name = f"continue_logs.txt" logs = "\n\n############################################\n\n".join( ["This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"] + self.session.autopilot.continue_sdk.history.timeline[index].logs) create_async_task( - self.session.autopilot.ide.showVirtualFile(name, logs), self.session.autopilot.continue_sdk.ide.unique_id) + self.session.autopilot.ide.showVirtualFile(name, logs), self.on_error) def select_context_item(self, id: str, query: str): """Called when user selects an item from the dropdown""" create_async_task( - self.session.autopilot.select_context_item(id, query), self.session.autopilot.continue_sdk.ide.unique_id) + self.session.autopilot.select_context_item(id, query), self.on_error) @router.websocket("/ws") async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)): try: - print("Received websocket connection at url: ", websocket.url) + logger.debug(f"Received websocket connection at url: {websocket.url}") await websocket.accept() - print("Session started") + logger.debug("Session started") session_manager.register_websocket(session.session_id, websocket) protocol = GUIProtocolServer(session) protocol.websocket = websocket @@ -179,7 +171,7 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we while AppStatus.should_exit is False: message = await websocket.receive_text() - print("Received GUI message", message) + logger.debug(f"Received GUI message {message}") if type(message) is str: message = json.loads(message) @@ -190,16 +182,21 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we protocol.handle_json(message_type, data) except WebSocketDisconnect as e: - print("GUI websocket disconnected") + logger.debug("GUI websocket disconnected") except Exception as e: - print("ERROR in gui websocket: ", e) + # Log, send to PostHog, and send to GUI + logger.debug(f"ERROR in gui websocket: {e}") + err_msg = '\n'.join(traceback.format_exception(e)) posthog_logger.capture_event("gui_error", { - "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))}) + "error_title": e.__str__() or e.__repr__(), "error_message": err_msg}) + + await protocol.session.autopilot.continue_sdk.run_step(DisplayErrorStep(e=e)) + raise e finally: - print("Closing gui websocket") + logger.debug("Closing gui websocket") if websocket.client_state != WebSocketState.DISCONNECTED: await websocket.close() await session_manager.persist_session(session.session_id) - session_manager.remove_session(session.session_id) + await session_manager.remove_session(session.session_id) diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index 87374928..e4c07029 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -10,6 +10,7 @@ from pydantic import BaseModel import traceback import asyncio +from ..plugins.steps.core.core import DisplayErrorStep, MessageStep from .meilisearch_server import start_meilisearch from ..libs.util.telemetry import posthog_logger from ..libs.util.queue import AsyncSubscriptionQueue @@ -19,6 +20,7 @@ from .gui import session_manager from .ide_protocol import AbstractIdeProtocolServer from ..libs.util.create_async_task import create_async_task from .session_manager import SessionManager +from ..libs.util.logging import logger import nest_asyncio nest_asyncio.apply() @@ -37,7 +39,7 @@ class AppStatus: @staticmethod def handle_exit(*args, **kwargs): AppStatus.should_exit = True - print("Shutting down") + logger.debug("Shutting down") original_handler(*args, **kwargs) @@ -140,7 +142,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer): continue message_type = message["messageType"] data = message["data"] - print("Received message while initializing", message_type) + logger.debug(f"Received message while initializing {message_type}") if message_type == "workspaceDirectory": self.workspace_directory = data["workspaceDirectory"] elif message_type == "uniqueId": @@ -154,9 +156,10 @@ class IdeProtocolServer(AbstractIdeProtocolServer): async def _send_json(self, message_type: str, data: Any): if self.websocket.application_state == WebSocketState.DISCONNECTED: - print("Tried to send message, but websocket is disconnected", message_type) + logger.debug( + f"Tried to send message, but websocket is disconnected: {message_type}") return - print("Sending IDE message: ", message_type) + logger.debug(f"Sending IDE message: {message_type}") await self.websocket.send_json({ "messageType": message_type, "data": data @@ -167,7 +170,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer): return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout) except asyncio.TimeoutError: raise Exception( - "IDE Protocol _receive_json timed out after 20 seconds", message_type) + f"IDE Protocol _receive_json timed out after 20 seconds: {message_type}") async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T: await self._send_json(message_type, data) @@ -277,6 +280,9 @@ class IdeProtocolServer(AbstractIdeProtocolServer): # This is where you might have triggers: plugins can subscribe to certian events # like file changes, tracebacks, etc... + def on_error(self, e: Exception): + return self.session_manager.sessions[self.session_id].autopilot.continue_sdk.run_step(DisplayErrorStep(e=e)) + def onAcceptRejectSuggestion(self, accepted: bool): posthog_logger.capture_event("accept_reject_suggestion", { "accepted": accepted @@ -309,22 +315,22 @@ class IdeProtocolServer(AbstractIdeProtocolServer): def onDeleteAtIndex(self, index: int): if autopilot := self.__get_autopilot(): - create_async_task(autopilot.delete_at_index(index), self.unique_id) + create_async_task(autopilot.delete_at_index(index), self.on_error) def onCommandOutput(self, output: str): if autopilot := self.__get_autopilot(): create_async_task( - autopilot.handle_command_output(output), self.unique_id) + autopilot.handle_command_output(output), self.on_error) def onHighlightedCodeUpdate(self, range_in_files: List[RangeInFileWithContents]): if autopilot := self.__get_autopilot(): create_async_task(autopilot.handle_highlighted_code( - range_in_files), self.unique_id) + range_in_files), self.on_error) def onMainUserInput(self, input: str): if autopilot := self.__get_autopilot(): create_async_task( - autopilot.accept_user_input(input), self.unique_id) + autopilot.accept_user_input(input), self.on_error) # Request information. Session doesn't matter. async def getOpenFiles(self) -> List[str]: @@ -354,7 +360,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer): }, GetUserSecretResponse, "getUserSecret") return resp.value except Exception as e: - print("Error getting user secret", e) + logger.debug(f"Error getting user secret: {e}") return "" async def saveFile(self, filepath: str): @@ -437,15 +443,15 @@ class IdeProtocolServer(AbstractIdeProtocolServer): async def websocket_endpoint(websocket: WebSocket, session_id: str = None): try: await websocket.accept() - print("Accepted websocket connection from, ", websocket.client) + logger.debug(f"Accepted websocket connection from {websocket.client}") await websocket.send_json({"messageType": "connected", "data": {}}) # Start meilisearch try: await start_meilisearch() except Exception as e: - print("Failed to start MeiliSearch") - print(e) + logger.debug("Failed to start MeiliSearch") + logger.debug(e) def handle_msg(msg): message = json.loads(msg) @@ -455,9 +461,9 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None): message_type = message["messageType"] data = message["data"] - print("Received IDE message: ", message_type) + logger.debug(f"Received IDE message: {message_type}") create_async_task( - ideProtocolServer.handle_json(message_type, data)) + ideProtocolServer.handle_json(message_type, data), ideProtocolServer.on_error) ideProtocolServer = IdeProtocolServer(session_manager, websocket) if session_id is not None: @@ -473,15 +479,20 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None): message = await websocket.receive_text() handle_msg(message) - print("Closing ide websocket") + logger.debug("Closing ide websocket") except WebSocketDisconnect as e: - print("IDE wbsocket disconnected") + logger.debug("IDE wbsocket disconnected") except Exception as e: - print("Error in ide websocket: ", e) + logger.debug(f"Error in ide websocket: {e}") + err_msg = '\n'.join(traceback.format_exception(e)) posthog_logger.capture_event("gui_error", { - "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))}) + "error_title": e.__str__() or e.__repr__(), "error_message": err_msg}) + + await session_manager.sessions[session_id].autopilot.continue_sdk.run_step(DisplayErrorStep(e=e)) + raise e finally: + logger.debug("Closing ide websocket") if websocket.client_state != WebSocketState.DISCONNECTED: await websocket.close() diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py index b92c9fa3..468bc855 100644 --- a/continuedev/src/continuedev/server/main.py +++ b/continuedev/src/continuedev/server/main.py @@ -15,6 +15,7 @@ from .ide import router as ide_router from .gui import router as gui_router from .session_manager import session_manager from ..libs.util.paths import getLogFilePath +from ..libs.util.logging import logger app = FastAPI() @@ -33,45 +34,49 @@ app.add_middleware( @app.get("/health") def health(): - print("Testing") + logger.debug("Health check") return {"status": "ok"} -# add cli arg for server port -parser = argparse.ArgumentParser() -parser.add_argument("-p", "--port", help="server port", - type=int, default=65432) -args = parser.parse_args() - -log_path = getLogFilePath() -LOG_CONFIG = { - 'version': 1, - 'disable_existing_loggers': False, - 'handlers': { - 'file': { - 'level': 'DEBUG', - 'class': 'logging.FileHandler', - 'filename': log_path, - }, - }, - 'root': { - 'level': 'DEBUG', - 'handlers': ['file'] - } -} -print(f"Log path: {log_path}") +class Logger(object): + def __init__(self, log_file: str): + self.terminal = sys.stdout + self.log = open(log_file, "a") + + def write(self, message): + self.terminal.write(message) + self.log.write(message) + + def flush(self): + # this flush method is needed for python 3 compatibility. + # this handles the flush command by doing nothing. + # you might want to specify some extra behavior here. + pass + + def isatty(self): + return False + + +try: + # add cli arg for server port + parser = argparse.ArgumentParser() + parser.add_argument("-p", "--port", help="server port", + type=int, default=65432) + args = parser.parse_args() +except Exception as e: + logger.debug(f"Error parsing command line arguments: {e}") + raise e def run_server(): - config = uvicorn.Config(app, host="0.0.0.0", - port=args.port, log_config=LOG_CONFIG) + config = uvicorn.Config(app, host="127.0.0.1", port=args.port) server = uvicorn.Server(config) server.run() async def cleanup_coroutine(): - print("Cleaning up sessions") + logger.debug("Cleaning up sessions") for session_id in session_manager.sessions: await session_manager.persist_session(session_id) @@ -90,13 +95,14 @@ def cpu_usage_report(): time.sleep(1) # Call cpu_percent again to get the CPU usage over the interval cpu_usage = process.cpu_percent(interval=None) - print(f"CPU usage: {cpu_usage}%") + logger.debug(f"CPU usage: {cpu_usage}%") atexit.register(cleanup) if __name__ == "__main__": try: + # Uncomment to get CPU usage reports # import threading # def cpu_usage_loop(): @@ -109,6 +115,6 @@ if __name__ == "__main__": run_server() except Exception as e: - print("Error starting Continue server: ", e) + logger.debug(f"Error starting Continue server: {e}") cleanup() raise e diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py index 00f692f5..7f460afc 100644 --- a/continuedev/src/continuedev/server/meilisearch_server.py +++ b/continuedev/src/continuedev/server/meilisearch_server.py @@ -5,6 +5,7 @@ import subprocess from meilisearch_python_async import Client from ..libs.util.paths import getServerFolderPath +from ..libs.util.logging import logger def ensure_meilisearch_installed() -> bool: @@ -39,7 +40,7 @@ def ensure_meilisearch_installed() -> bool: shutil.rmtree(p, ignore_errors=True) # Download MeiliSearch - print("Downloading MeiliSearch...") + logger.debug("Downloading MeiliSearch...") subprocess.run( f"curl -L https://install.meilisearch.com | sh", shell=True, check=True, cwd=serverPath) @@ -82,6 +83,6 @@ async def start_meilisearch(): # Check if MeiliSearch is running if not await check_meilisearch_running() or not was_already_installed: - print("Starting MeiliSearch...") + logger.debug("Starting MeiliSearch...") subprocess.Popen(["./meilisearch", "--no-analytics"], cwd=serverPath, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT, close_fds=True, start_new_session=True) diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index 3136f1bf..cf46028f 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -4,6 +4,9 @@ from typing import Any, Dict, List, Union from uuid import uuid4 import json +from fastapi.websockets import WebSocketState + +from ..plugins.steps.core.core import DisplayErrorStep from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath from ..models.filesystem_edit import FileEditWithFullContents from ..libs.constants.main import CONTINUE_SESSIONS_FOLDER @@ -13,6 +16,7 @@ from ..core.autopilot import Autopilot from .ide_protocol import AbstractIdeProtocolServer from ..libs.util.create_async_task import create_async_task from ..libs.util.errors import SessionNotFound +from ..libs.util.logging import logger class Session: @@ -59,6 +63,8 @@ class SessionManager: return self.sessions[session_id] async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session: + logger.debug(f"New session: {session_id}") + full_state = None if session_id is not None and os.path.exists(getSessionFilePath(session_id)): with open(getSessionFilePath(session_id), "r") as f: @@ -78,29 +84,35 @@ class SessionManager: }) autopilot.on_update(on_update) - create_async_task(autopilot.run_policy()) + create_async_task(autopilot.run_policy( + ), lambda e: autopilot.continue_sdk.run_step(DisplayErrorStep(e=e))) return session - def remove_session(self, session_id: str): - del self.sessions[session_id] + async def remove_session(self, session_id: str): + logger.debug(f"Removing session: {session_id}") + if session_id in self.sessions: + if session_id in self.registered_ides: + ws_to_close = self.registered_ides[session_id].websocket + if ws_to_close is not None and ws_to_close.client_state != WebSocketState.DISCONNECTED: + await self.sessions[session_id].autopilot.ide.websocket.close() + + del self.sessions[session_id] async def persist_session(self, session_id: str): """Save the session's FullState as a json file""" full_state = await self.sessions[session_id].autopilot.get_full_state() - if not os.path.exists(getSessionsFolderPath()): - os.mkdir(getSessionsFolderPath()) with open(getSessionFilePath(session_id), "w") as f: json.dump(full_state.dict(), f) def register_websocket(self, session_id: str, ws: WebSocket): self.sessions[session_id].ws = ws - print("Registered websocket for session", session_id) + logger.debug(f"Registered websocket for session {session_id}") async def send_ws_data(self, session_id: str, message_type: str, data: Any): if session_id not in self.sessions: raise SessionNotFound(f"Session {session_id} not found") if self.sessions[session_id].ws is None: - # print(f"Session {session_id} has no websocket") + # logger.debug(f"Session {session_id} has no websocket") return await self.sessions[session_id].ws.send_json({ |