summaryrefslogtreecommitdiff
path: root/continuedev/src
diff options
context:
space:
mode:
Diffstat (limited to 'continuedev/src')
-rw-r--r--continuedev/src/continuedev/core/autopilot.py13
-rw-r--r--continuedev/src/continuedev/core/config.py39
-rw-r--r--continuedev/src/continuedev/core/context.py47
-rw-r--r--continuedev/src/continuedev/core/sdk.py40
-rw-r--r--continuedev/src/continuedev/libs/chroma/query.py19
-rw-r--r--continuedev/src/continuedev/libs/constants/default_config.py.txt43
-rw-r--r--continuedev/src/continuedev/libs/constants/main.py2
-rw-r--r--continuedev/src/continuedev/libs/llm/openai.py31
-rw-r--r--continuedev/src/continuedev/libs/util/create_async_task.py14
-rw-r--r--continuedev/src/continuedev/libs/util/logging.py30
-rw-r--r--continuedev/src/continuedev/libs/util/step_name_to_steps.py3
-rw-r--r--continuedev/src/continuedev/libs/util/telemetry.py45
-rw-r--r--continuedev/src/continuedev/models/generate_json_schema.py2
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/file.py5
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/filetree.py49
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/github.py2
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/google.py2
-rw-r--r--continuedev/src/continuedev/plugins/context_providers/highlighted_code.py15
-rw-r--r--continuedev/src/continuedev/plugins/steps/chat.py46
-rw-r--r--continuedev/src/continuedev/plugins/steps/core/core.py28
-rw-r--r--continuedev/src/continuedev/plugins/steps/help.py43
-rw-r--r--continuedev/src/continuedev/plugins/steps/main.py3
-rw-r--r--continuedev/src/continuedev/plugins/steps/search_directory.py2
-rw-r--r--continuedev/src/continuedev/server/gui.py121
-rw-r--r--continuedev/src/continuedev/server/ide.py49
-rw-r--r--continuedev/src/continuedev/server/main.py64
-rw-r--r--continuedev/src/continuedev/server/meilisearch_server.py5
-rw-r--r--continuedev/src/continuedev/server/session_manager.py26
28 files changed, 492 insertions, 296 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index beb40c75..b4c951b8 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -15,12 +15,13 @@ from ..server.ide_protocol import AbstractIdeProtocolServer
from ..libs.util.queue import AsyncSubscriptionQueue
from ..models.main import ContinueBaseModel
from .main import Context, ContinueCustomException, Policy, History, FullState, Step, HistoryNode
-from ..plugins.steps.core.core import ReversibleStep, ManualEditStep, UserInputStep
+from ..plugins.steps.core.core import DisplayErrorStep, ReversibleStep, ManualEditStep, UserInputStep
from .sdk import ContinueSDK
from ..libs.util.traceback_parsers import get_python_traceback, get_javascript_traceback
from openai import error as openai_errors
from ..libs.util.create_async_task import create_async_task
from ..libs.util.telemetry import posthog_logger
+from ..libs.util.logging import logger
def get_error_title(e: Exception) -> str:
@@ -74,7 +75,7 @@ class Autopilot(ContinueBaseModel):
HighlightedCodeContextProvider(ide=ide),
FileContextProvider(workspace_dir=ide.workspace_directory)
])
- await autopilot.context_manager.load_index()
+ await autopilot.context_manager.load_index(ide.workspace_directory)
return autopilot
@@ -152,7 +153,7 @@ class Autopilot(ContinueBaseModel):
await self.update_subscribers()
except Exception as e:
- print(e)
+ logger.debug(e)
def handle_manual_edits(self, edits: List[FileEditWithFullContents]):
for edit in edits:
@@ -257,7 +258,7 @@ class Autopilot(ContinueBaseModel):
e)
# Attach an InternalErrorObservation to the step and unhide it.
- print(
+ logger.error(
f"Error while running step: \n{error_string}\n{error_title}")
posthog_logger.capture_event('step error', {
'error_message': error_string, 'error_title': error_title, 'step_name': step.name, 'params': step.dict()})
@@ -310,8 +311,8 @@ class Autopilot(ContinueBaseModel):
# Update subscribers with new description
await self.update_subscribers()
- create_async_task(update_description(),
- self.continue_sdk.ide.unique_id)
+ create_async_task(update_description(
+ ), on_error=lambda e: self.continue_sdk.run_step(DisplayErrorStep(e=e)))
return observation
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index af37264d..4fcab588 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -49,45 +49,6 @@ class ContinueConfig(BaseModel):
context_providers: List[ContextProvider] = []
- # Want to force these to be the slash commands for now
- @validator('slash_commands', pre=True)
- def default_slash_commands_validator(cls, v):
- from ..plugins.steps.open_config import OpenConfigStep
- from ..plugins.steps.clear_history import ClearHistoryStep
- from ..plugins.steps.feedback import FeedbackStep
- from ..plugins.steps.comment_code import CommentCodeStep
- from ..plugins.steps.main import EditHighlightedCodeStep
-
- DEFAULT_SLASH_COMMANDS = [
- SlashCommand(
- name="edit",
- description="Edit code in the current file or the highlighted code",
- step=EditHighlightedCodeStep,
- ),
- SlashCommand(
- name="config",
- description="Open the config file to create new and edit existing slash commands",
- step=OpenConfigStep,
- ),
- SlashCommand(
- name="comment",
- description="Write comments for the current file or highlighted code",
- step=CommentCodeStep,
- ),
- SlashCommand(
- name="feedback",
- description="Send feedback to improve Continue",
- step=FeedbackStep,
- ),
- SlashCommand(
- name="clear",
- description="Clear step history",
- step=ClearHistoryStep,
- )
- ]
-
- return DEFAULT_SLASH_COMMANDS + v
-
@validator('temperature', pre=True)
def temperature_validator(cls, v):
return max(0.0, min(1.0, v))
diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py
index 8afbd610..e968c35c 100644
--- a/continuedev/src/continuedev/core/context.py
+++ b/continuedev/src/continuedev/core/context.py
@@ -7,7 +7,7 @@ from pydantic import BaseModel
from .main import ChatMessage, ContextItem, ContextItemDescription, ContextItemId
from ..server.meilisearch_server import check_meilisearch_running
-
+from ..libs.util.logging import logger
SEARCH_INDEX_NAME = "continue_context_items"
@@ -35,7 +35,7 @@ class ContextProvider(BaseModel):
return self.selected_items
@abstractmethod
- async def provide_context_items(self) -> List[ContextItem]:
+ async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]:
"""
Provide documents for search index. This is run on startup.
@@ -57,16 +57,22 @@ class ContextProvider(BaseModel):
Default implementation uses the search index to get the item.
"""
async with Client('http://localhost:7700') as search_client:
- result = await search_client.index(
- SEARCH_INDEX_NAME).get_document(id.to_string())
- return ContextItem(
- description=ContextItemDescription(
- name=result["name"],
- description=result["description"],
- id=id
- ),
- content=result["content"]
- )
+ try:
+ result = await search_client.index(
+ SEARCH_INDEX_NAME).get_document(id.to_string())
+ return ContextItem(
+ description=ContextItemDescription(
+ name=result["name"],
+ description=result["description"],
+ id=id
+ ),
+ content=result["content"]
+ )
+ except Exception as e:
+ logger.warning(
+ f"Error while retrieving document from meilisearch: {e}")
+
+ return None
async def delete_context_with_ids(self, ids: List[ContextItemId]):
"""
@@ -100,8 +106,8 @@ class ContextProvider(BaseModel):
if item.description.id.item_id == id.item_id:
return
- new_item = await self.get_item(id, query)
- self.selected_items.append(new_item)
+ if new_item := await self.get_item(id, query):
+ self.selected_items.append(new_item)
class ContextManager:
@@ -146,16 +152,16 @@ class ContextManager:
meilisearch_running = False
if not meilisearch_running:
- print(
+ logger.warning(
"MeiliSearch not running, avoiding any dependent context providers")
context_providers = list(
filter(lambda cp: cp.title == "code", context_providers))
return cls(context_providers)
- async def load_index(self):
+ async def load_index(self, workspace_dir: str):
for _, provider in self.context_providers.items():
- context_items = await provider.provide_context_items()
+ context_items = await provider.provide_context_items(workspace_dir)
documents = [
{
"id": item.description.id.to_string(),
@@ -166,8 +172,11 @@ class ContextManager:
for item in context_items
]
if len(documents) > 0:
- async with Client('http://localhost:7700') as search_client:
- await search_client.index(SEARCH_INDEX_NAME).add_documents(documents)
+ try:
+ async with Client('http://localhost:7700') as search_client:
+ await search_client.index(SEARCH_INDEX_NAME).add_documents(documents)
+ except Exception as e:
+ logger.debug(f"Error loading meilisearch index: {e}")
async def select_context_item(self, id: str, query: str):
"""
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 1dd4b857..be7008c0 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -1,5 +1,6 @@
from functools import cached_property
-from typing import Coroutine, Dict, Union
+import traceback
+from typing import Coroutine, Dict, Literal, Union
import os
from ..plugins.steps.core.core import DefaultModelEditCodeStep
@@ -16,6 +17,7 @@ from ..plugins.steps.core.core import *
from ..libs.util.telemetry import posthog_logger
from ..libs.util.paths import getConfigFilePath
from .models import Models
+from ..libs.util.logging import logger
class Autopilot:
@@ -43,11 +45,15 @@ class ContinueSDK(AbstractContinueSDK):
config = sdk._load_config_dot_py()
sdk.config = config
except Exception as e:
- print(e)
- sdk.config = ContinueConfig()
+ logger.error(f"Failed to load config.py: {e}")
+
+ sdk.config = ContinueConfig(
+ ) if sdk._last_valid_config is None else sdk._last_valid_config
+
+ formatted_err = '\n'.join(traceback.format_exception(e))
msg_step = MessageStep(
- name="Invalid Continue Config File", message=e.__repr__())
- msg_step.description = e.__repr__()
+ name="Invalid Continue Config File", message=formatted_err)
+ msg_step.description = f"Falling back to default config settings.\n```\n{formatted_err}\n```"
sdk.history.add_node(HistoryNode(
step=msg_step,
observation=None,
@@ -57,6 +63,11 @@ class ContinueSDK(AbstractContinueSDK):
sdk.models = sdk.config.models
await sdk.models.start(sdk)
+
+ # When the config is loaded, setup posthog logger
+ posthog_logger.setup(
+ sdk.ide.unique_id, sdk.config.allow_anonymous_telemetry)
+
return sdk
@property
@@ -154,21 +165,14 @@ class ContinueSDK(AbstractContinueSDK):
def _load_config_dot_py(self) -> ContinueConfig:
# Use importlib to load the config file config.py at the given path
path = getConfigFilePath()
- try:
- import importlib.util
- spec = importlib.util.spec_from_file_location("config", path)
- config = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(config)
- self._last_valid_config = config.config
- # When the config is loaded, setup posthog logger
- posthog_logger.setup(
- self.ide.unique_id, config.config.allow_anonymous_telemetry or True)
+ import importlib.util
+ spec = importlib.util.spec_from_file_location("config", path)
+ config = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(config)
+ self._last_valid_config = config.config
- return config.config
- except Exception as e:
- print("Error loading config.py: ", e)
- return ContinueConfig() if self._last_valid_config is None else self._last_valid_config
+ return config.config
def get_code_context(self, only_editing: bool = False) -> List[RangeInFileWithContents]:
highlighted_ranges = self.__autopilot.context_manager.context_providers[
diff --git a/continuedev/src/continuedev/libs/chroma/query.py b/continuedev/src/continuedev/libs/chroma/query.py
index c27329f0..f09b813a 100644
--- a/continuedev/src/continuedev/libs/chroma/query.py
+++ b/continuedev/src/continuedev/libs/chroma/query.py
@@ -5,6 +5,7 @@ from llama_index import GPTVectorStoreIndex, StorageContext, load_index_from_sto
from llama_index.langchain_helpers.text_splitter import TokenTextSplitter
import os
from .update import filter_ignored_files, load_gpt_index_documents
+from ..util.logging import logger
from functools import cached_property
@@ -56,7 +57,8 @@ class ChromaIndexManager:
try:
text_chunks = text_splitter.split_text(doc.text)
except:
- print("ERROR (probably found special token): ", doc.text)
+ logger.warning(
+ f"ERROR (probably found special token): {doc.text}")
continue
filename = doc.extra_info["filename"]
chunks[filename] = len(text_chunks)
@@ -79,7 +81,7 @@ class ChromaIndexManager:
index.storage_context.persist(persist_dir=self.index_dir)
- print("Codebase index created")
+ logger.debug("Codebase index created")
def get_modified_deleted_files(self) -> Tuple[List[str], List[str]]:
"""Get a list of all files that have been modified since the last commit."""
@@ -121,7 +123,7 @@ class ChromaIndexManager:
del metadata["chunks"][file]
- print(f"Deleted {file}")
+ logger.debug(f"Deleted {file}")
for file in modified_files:
@@ -132,7 +134,7 @@ class ChromaIndexManager:
for i in range(num_chunks):
index.delete(f"{file}::{i}")
- print(f"Deleted old version of {file}")
+ logger.debug(f"Deleted old version of {file}")
with open(file, "r") as f:
text = f.read()
@@ -145,19 +147,20 @@ class ChromaIndexManager:
metadata["chunks"][file] = len(text_chunks)
- print(f"Inserted new version of {file}")
+ logger.debug(f"Inserted new version of {file}")
metadata["commit"] = self.current_commit
with open(f"{self.index_dir}/metadata.json", "w") as f:
json.dump(metadata, f, indent=4)
- print("Codebase index updated")
+ logger.debug("Codebase index updated")
def query_codebase_index(self, query: str) -> str:
"""Query the codebase index."""
if not self.check_index_exists():
- print("No index found for the codebase at ", self.index_dir)
+ logger.debug(
+ f"No index found for the codebase at {self.index_dir}")
return ""
storage_context = StorageContext.from_defaults(
@@ -180,4 +183,4 @@ class ChromaIndexManager:
documents = [Document(info)]
index = GPTVectorStoreIndex(documents)
index.save_to_disk(f'{self.index_dir}/additional_index.json')
- print("Additional index replaced")
+ logger.debug("Additional index replaced")
diff --git a/continuedev/src/continuedev/libs/constants/default_config.py.txt b/continuedev/src/continuedev/libs/constants/default_config.py.txt
index e40a2684..be978fd3 100644
--- a/continuedev/src/continuedev/libs/constants/default_config.py.txt
+++ b/continuedev/src/continuedev/libs/constants/default_config.py.txt
@@ -7,12 +7,18 @@ be sure to select the Python interpreter in ~/.continue/server/env.
import subprocess
-from continuedev.src.continuedev.core.main import Step
-from continuedev.src.continuedev.core.sdk import ContinueSDK
-from continuedev.src.continuedev.core.config import CustomCommand, SlashCommand, ContinueConfig
-from continuedev.src.continuedev.plugins.context_providers.github import GitHubIssuesContextProvider
-from continuedev.src.continuedev.plugins.context_providers.google import GoogleContextProvider
-from continuedev.src.continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
+from continuedev.core.main import Step
+from continuedev.core.sdk import ContinueSDK
+from continuedev.core.config import CustomCommand, SlashCommand, ContinueConfig
+from continuedev.plugins.context_providers.github import GitHubIssuesContextProvider
+from continuedev.plugins.context_providers.google import GoogleContextProvider
+from continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
+from continuedev.plugins.steps.open_config import OpenConfigStep
+from continuedev.plugins.steps.clear_history import ClearHistoryStep
+from continuedev.plugins.steps.feedback import FeedbackStep
+from continuedev.plugins.steps.comment_code import CommentCodeStep
+from continuedev.plugins.steps.main import EditHighlightedCodeStep
+
class CommitMessageStep(Step):
"""
@@ -70,6 +76,31 @@ config = ContinueConfig(
# description="This is an example slash command. Use /config to edit it and create more",
# step=CommitMessageStep,
# )
+ SlashCommand(
+ name="edit",
+ description="Edit code in the current file or the highlighted code",
+ step=EditHighlightedCodeStep,
+ ),
+ SlashCommand(
+ name="config",
+ description="Open the config file to create new and edit existing slash commands",
+ step=OpenConfigStep,
+ ),
+ SlashCommand(
+ name="comment",
+ description="Write comments for the current file or highlighted code",
+ step=CommentCodeStep,
+ ),
+ SlashCommand(
+ name="feedback",
+ description="Send feedback to improve Continue",
+ step=FeedbackStep,
+ ),
+ SlashCommand(
+ name="clear",
+ description="Clear step history",
+ step=ClearHistoryStep,
+ )
],
# Context providers let you quickly select context by typing '@'
diff --git a/continuedev/src/continuedev/libs/constants/main.py b/continuedev/src/continuedev/libs/constants/main.py
index 96eb6e69..f5964df6 100644
--- a/continuedev/src/continuedev/libs/constants/main.py
+++ b/continuedev/src/continuedev/libs/constants/main.py
@@ -3,4 +3,4 @@
CONTINUE_GLOBAL_FOLDER = ".continue"
CONTINUE_SESSIONS_FOLDER = "sessions"
CONTINUE_SERVER_FOLDER = "server"
-
+CONTINUE_SERVER_VERSION_FILE = "server_version.txt"
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 16428d4e..73aff383 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -1,12 +1,21 @@
from functools import cached_property
import json
-from typing import Any, Callable, Coroutine, Dict, Generator, List, Union, Optional
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union, Optional
from pydantic import BaseModel
-from ...core.main import ChatMessage
import openai
-from ..llm import LLM
+
+from ...core.main import ChatMessage
from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top
+from ..llm import LLM
+
+
+class OpenAIServerInfo(BaseModel):
+ api_base: Optional[str] = None
+ engine: Optional[str] = None
+ api_version: Optional[str] = None
+ api_type: Literal["azure", "openai"] = "openai"
+
CHAT_MODELS = {
"gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"
@@ -27,6 +36,7 @@ class AzureInfo(BaseModel):
class OpenAI(LLM):
model: str
+ openai_server_info: Optional[OpenAIServerInfo] = None
requires_api_key = "OPENAI_API_KEY"
requires_write_log = True
@@ -41,11 +51,12 @@ class OpenAI(LLM):
self.api_key = api_key
openai.api_key = self.api_key
- # Using an Azure OpenAI deployment
- if self.azure_info is not None:
- openai.api_type = "azure"
- openai.api_base = self.azure_info.endpoint
- openai.api_version = self.azure_info.api_version
+ if self.openai_server_info is not None:
+ openai.api_type = self.openai_server_info.api_type
+ if self.openai_server_info.api_base is not None:
+ openai.api_base = self.openai_server_info.api_base
+ if self.openai_server_info.api_version is not None:
+ openai.api_version = self.openai_server_info.api_version
async def stop(self):
pass
@@ -61,8 +72,8 @@ class OpenAI(LLM):
@property
def default_args(self):
args = {**DEFAULT_ARGS, "model": self.model}
- if self.azure_info is not None:
- args["engine"] = self.azure_info.engine
+ if self.openai_server_info is not None:
+ args["engine"] = self.openai_server_info.engine
return args
def count_tokens(self, text: str):
diff --git a/continuedev/src/continuedev/libs/util/create_async_task.py b/continuedev/src/continuedev/libs/util/create_async_task.py
index 2473c638..4c6d3c95 100644
--- a/continuedev/src/continuedev/libs/util/create_async_task.py
+++ b/continuedev/src/continuedev/libs/util/create_async_task.py
@@ -1,12 +1,13 @@
-from typing import Coroutine, Union
+from typing import Callable, Coroutine, Optional, Union
import traceback
from .telemetry import posthog_logger
+from .logging import logger
import asyncio
import nest_asyncio
nest_asyncio.apply()
-def create_async_task(coro: Coroutine, unique_id: Union[str, None] = None):
+def create_async_task(coro: Coroutine, on_error: Optional[Callable[[Exception], Coroutine]] = None):
"""asyncio.create_task and log errors by adding a callback"""
task = asyncio.create_task(coro)
@@ -14,11 +15,16 @@ def create_async_task(coro: Coroutine, unique_id: Union[str, None] = None):
try:
future.result()
except Exception as e:
- print("Exception caught from async task: ",
- '\n'.join(traceback.format_exception(e)))
+ formatted_tb = '\n'.join(traceback.format_exception(e))
+ logger.critical(
+ f"Exception caught from async task: {formatted_tb}")
posthog_logger.capture_event("async_task_error", {
"error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))
})
+ # Log the error to the GUI
+ if on_error is not None:
+ asyncio.create_task(on_error(e))
+
task.add_done_callback(callback)
return task
diff --git a/continuedev/src/continuedev/libs/util/logging.py b/continuedev/src/continuedev/libs/util/logging.py
new file mode 100644
index 00000000..668d313f
--- /dev/null
+++ b/continuedev/src/continuedev/libs/util/logging.py
@@ -0,0 +1,30 @@
+import logging
+
+from .paths import getLogFilePath
+
+# Create a logger
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+# Create a file handler
+file_handler = logging.FileHandler(getLogFilePath())
+file_handler.setLevel(logging.DEBUG)
+
+# Create a console handler
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.DEBUG)
+
+# Create a formatter
+formatter = logging.Formatter(
+ '[%(asctime)s] [%(levelname)s] %(message)s')
+
+# Add the formatter to the handlers
+file_handler.setFormatter(formatter)
+console_handler.setFormatter(formatter)
+
+# Add the handlers to the logger
+logger.addHandler(file_handler)
+logger.addHandler(console_handler)
+
+# Log a test message
+logger.debug('Testing logs')
diff --git a/continuedev/src/continuedev/libs/util/step_name_to_steps.py b/continuedev/src/continuedev/libs/util/step_name_to_steps.py
index baa25da6..ed1e79b7 100644
--- a/continuedev/src/continuedev/libs/util/step_name_to_steps.py
+++ b/continuedev/src/continuedev/libs/util/step_name_to_steps.py
@@ -14,6 +14,7 @@ from ...plugins.steps.on_traceback import DefaultOnTracebackStep
from ...plugins.steps.clear_history import ClearHistoryStep
from ...plugins.steps.open_config import OpenConfigStep
from ...plugins.steps.help import HelpStep
+from ...libs.util.logging import logger
# This mapping is used to convert from string in ContinueConfig json to corresponding Step class.
# Used for example in slash_commands and steps_on_startup
@@ -38,6 +39,6 @@ def get_step_from_name(step_name: str, params: Dict) -> Step:
try:
return step_name_to_step_class[step_name](**params)
except:
- print(
+ logger.error(
f"Incorrect parameters for step {step_name}. Parameters provided were: {params}")
raise
diff --git a/continuedev/src/continuedev/libs/util/telemetry.py b/continuedev/src/continuedev/libs/util/telemetry.py
index a967828e..60c910bb 100644
--- a/continuedev/src/continuedev/libs/util/telemetry.py
+++ b/continuedev/src/continuedev/libs/util/telemetry.py
@@ -3,6 +3,9 @@ from posthog import Posthog
import os
from dotenv import load_dotenv
from .commonregex import clean_pii_from_any
+from .logging import logger
+from .paths import getServerFolderPath
+from ..constants.main import CONTINUE_SERVER_VERSION_FILE
load_dotenv()
in_codespaces = os.getenv("CODESPACES") == "true"
@@ -10,28 +13,52 @@ POSTHOG_API_KEY = 'phc_JS6XFROuNbhJtVCEdTSYk6gl5ArRrTNMpCcguAXlSPs'
class PostHogLogger:
+ unique_id: str = "NO_UNIQUE_ID"
+ allow_anonymous_telemetry: bool = True
+
def __init__(self, api_key: str):
self.api_key = api_key
- self.unique_id = None
- self.allow_anonymous_telemetry = True
-
- def setup(self, unique_id: str, allow_anonymous_telemetry: bool):
- self.unique_id = unique_id
- self.allow_anonymous_telemetry = allow_anonymous_telemetry
# The personal API key is necessary only if you want to use local evaluation of feature flags.
self.posthog = Posthog(self.api_key, host='https://app.posthog.com')
+ def setup(self, unique_id: str, allow_anonymous_telemetry: bool):
+ logger.debug(f"Setting unique_id as {unique_id}")
+ self.unique_id = unique_id or "NO_UNIQUE_ID"
+ self.allow_anonymous_telemetry = allow_anonymous_telemetry or True
+
def capture_event(self, event_name: str, event_properties: Any):
- if not self.allow_anonymous_telemetry or self.unique_id is None:
+ # logger.debug(
+ # f"Logging to PostHog: {event_name} ({self.unique_id}, {self.allow_anonymous_telemetry}): {event_properties}")
+ telemetry_path = os.path.expanduser("~/.continue/telemetry.log")
+
+ # Make sure the telemetry file exists
+ if not os.path.exists(telemetry_path):
+ os.makedirs(os.path.dirname(telemetry_path), exist_ok=True)
+ open(telemetry_path, "w").close()
+
+ with open(telemetry_path, "a") as f:
+ str_to_write = f"{event_name}: {event_properties}\n{self.unique_id}\n{self.allow_anonymous_telemetry}\n\n"
+ f.write(str_to_write)
+
+ if not self.allow_anonymous_telemetry:
return
+ # Clean PII from event properties
+ event_properties = clean_pii_from_any(event_properties)
+
+ # Add additional properties that are on every event
if in_codespaces:
event_properties['codespaces'] = True
+ server_version_file = os.path.join(
+ getServerFolderPath(), CONTINUE_SERVER_VERSION_FILE)
+ if os.path.exists(server_version_file):
+ with open(server_version_file, "r") as f:
+ event_properties['server_version'] = f.read()
+
# Send event to PostHog
- self.posthog.capture(self.unique_id, event_name,
- clean_pii_from_any(event_properties))
+ self.posthog.capture(self.unique_id, event_name, event_properties)
posthog_logger = PostHogLogger(api_key=POSTHOG_API_KEY)
diff --git a/continuedev/src/continuedev/models/generate_json_schema.py b/continuedev/src/continuedev/models/generate_json_schema.py
index 06614984..51869fdd 100644
--- a/continuedev/src/continuedev/models/generate_json_schema.py
+++ b/continuedev/src/continuedev/models/generate_json_schema.py
@@ -38,7 +38,7 @@ def main():
try:
json = schema_json_of(model, indent=2, title=title)
except Exception as e:
- print(f"Failed to generate json schema for {title}: ", e)
+ print(f"Failed to generate json schema for {title}: {e}")
continue
with open(f"{SCHEMA_DIR}/{title}.json", "w") as f:
diff --git a/continuedev/src/continuedev/plugins/context_providers/file.py b/continuedev/src/continuedev/plugins/context_providers/file.py
index 31c8e1d9..634774df 100644
--- a/continuedev/src/continuedev/plugins/context_providers/file.py
+++ b/continuedev/src/continuedev/plugins/context_providers/file.py
@@ -49,13 +49,12 @@ class FileContextProvider(ContextProvider):
"""
title = "file"
- workspace_dir: str
ignore_patterns: List[str] = DEFAULT_IGNORE_DIRS + \
list(filter(lambda d: f"**/{d}", DEFAULT_IGNORE_DIRS))
- async def provide_context_items(self) -> List[ContextItem]:
+ async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]:
filepaths = []
- for root, dir_names, file_names in os.walk(self.workspace_dir):
+ for root, dir_names, file_names in os.walk(workspace_dir):
dir_names[:] = [d for d in dir_names if not any(
fnmatch(d, pattern) for pattern in self.ignore_patterns)]
for file_name in file_names:
diff --git a/continuedev/src/continuedev/plugins/context_providers/filetree.py b/continuedev/src/continuedev/plugins/context_providers/filetree.py
new file mode 100644
index 00000000..c7b4806b
--- /dev/null
+++ b/continuedev/src/continuedev/plugins/context_providers/filetree.py
@@ -0,0 +1,49 @@
+import json
+from typing import List
+import os
+import aiohttp
+
+from ...core.main import ContextItem, ContextItemDescription, ContextItemId
+from ...core.context import ContextProvider
+
+
+def format_file_tree(startpath) -> str:
+ result = ""
+ for root, dirs, files in os.walk(startpath):
+ level = root.replace(startpath, '').count(os.sep)
+ indent = ' ' * 4 * (level)
+ result += '{}{}/'.format(indent, os.path.basename(root)) + "\n"
+ subindent = ' ' * 4 * (level + 1)
+ for f in files:
+ result += '{}{}'.format(subindent, f) + "\n"
+
+ return result
+
+
+class FileTreeContextProvider(ContextProvider):
+ title = "tree"
+
+ workspace_dir: str = None
+
+ def _filetree_context_item(self):
+ return ContextItem(
+ content=format_file_tree(self.workspace_dir),
+ description=ContextItemDescription(
+ name="File Tree",
+ description="Add a formatted file tree of this directory to the context",
+ id=ContextItemId(
+ provider_title=self.title,
+ item_id=self.title
+ )
+ )
+ )
+
+ async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]:
+ self.workspace_dir = workspace_dir
+ return [self._filetree_context_item()]
+
+ async def get_item(self, id: ContextItemId, query: str) -> ContextItem:
+ if not id.item_id == self.title:
+ raise Exception("Invalid item id")
+
+ return self._filetree_context_item()
diff --git a/continuedev/src/continuedev/plugins/context_providers/github.py b/continuedev/src/continuedev/plugins/context_providers/github.py
index 765a534d..2e7047f2 100644
--- a/continuedev/src/continuedev/plugins/context_providers/github.py
+++ b/continuedev/src/continuedev/plugins/context_providers/github.py
@@ -15,7 +15,7 @@ class GitHubIssuesContextProvider(ContextProvider):
repo_name: str
auth_token: str
- async def provide_context_items(self) -> List[ContextItem]:
+ async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]:
auth = Auth.Token(self.auth_token)
gh = Github(auth=auth)
diff --git a/continuedev/src/continuedev/plugins/context_providers/google.py b/continuedev/src/continuedev/plugins/context_providers/google.py
index 64954833..fc76fe67 100644
--- a/continuedev/src/continuedev/plugins/context_providers/google.py
+++ b/continuedev/src/continuedev/plugins/context_providers/google.py
@@ -42,7 +42,7 @@ class GoogleContextProvider(ContextProvider):
async with session.post(url, headers=headers, data=payload) as response:
return await response.text()
- async def provide_context_items(self) -> List[ContextItem]:
+ async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]:
return [self.BASE_CONTEXT_ITEM]
async def get_item(self, id: ContextItemId, query: str, _) -> ContextItem:
diff --git a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py
index 1d040101..acd40dc7 100644
--- a/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py
+++ b/continuedev/src/continuedev/plugins/context_providers/highlighted_code.py
@@ -3,7 +3,7 @@ from typing import Any, Dict, List
from ...core.main import ChatMessage
from ...models.filesystem import RangeInFile, RangeInFileWithContents
-from ...core.context import ContextItem, ContextItemDescription, ContextItemId
+from ...core.context import ContextItem, ContextItemDescription, ContextItemId, ContextProvider
from pydantic import BaseModel
@@ -12,7 +12,7 @@ class HighlightedRangeContextItem(BaseModel):
item: ContextItem
-class HighlightedCodeContextProvider(BaseModel):
+class HighlightedCodeContextProvider(ContextProvider):
"""
The ContextProvider class is a plugin that lets you provide new information to the LLM by typing '@'.
When you type '@', the context provider will be asked to populate a list of options.
@@ -96,9 +96,18 @@ class HighlightedCodeContextProvider(BaseModel):
hr.item.description.name = self._rif_to_name(
hr.rif, display_filename=basename)
- async def provide_context_items(self) -> List[ContextItem]:
+ async def provide_context_items(self, workspace_dir: str) -> List[ContextItem]:
return []
+ async def get_item(self, id: ContextItemId, query: str) -> ContextItem:
+ raise NotImplementedError()
+
+ async def clear_context(self):
+ self.highlighted_ranges = []
+ self.adding_highlighted_code = False
+ self.should_get_fallback_context_item = True
+ self.last_added_fallback = False
+
async def delete_context_with_ids(self, ids: List[ContextItemId]) -> List[ContextItem]:
indices_to_delete = [
int(id.item_id) for id in ids
diff --git a/continuedev/src/continuedev/plugins/steps/chat.py b/continuedev/src/continuedev/plugins/steps/chat.py
index 0a0fbca2..455d5a13 100644
--- a/continuedev/src/continuedev/plugins/steps/chat.py
+++ b/continuedev/src/continuedev/plugins/steps/chat.py
@@ -5,7 +5,7 @@ from pydantic import Field
from ...libs.util.strings import remove_quotes_and_escapes
from .main import EditHighlightedCodeStep
-from .core.core import MessageStep
+from .core.core import DisplayErrorStep, MessageStep
from ...core.main import FunctionCall, Models
from ...core.main import ChatMessage, Step, step_to_json_schema
from ...core.sdk import ContinueSDK
@@ -27,34 +27,32 @@ class SimpleChatStep(Step):
messages: List[ChatMessage] = None
async def run(self, sdk: ContinueSDK):
- completion = ""
messages = self.messages or await sdk.get_chat_context()
generator = sdk.models.default.stream_chat(
messages, temperature=sdk.config.temperature)
- try:
- async for chunk in generator:
- if sdk.current_step_was_deleted():
- # So that the message doesn't disappear
- self.hide = False
- break
- if "content" in chunk:
- self.description += chunk["content"]
- completion += chunk["content"]
- await sdk.update_ui()
- finally:
- self.name = remove_quotes_and_escapes(await sdk.models.medium.complete(
- f"Write a short title for the following chat message: {self.description}"))
-
- self.chat_context.append(ChatMessage(
- role="assistant",
- content=completion,
- summary=self.name
- ))
-
- # TODO: Never actually closing.
- await generator.aclose()
+ async for chunk in generator:
+ if sdk.current_step_was_deleted():
+ # So that the message doesn't disappear
+ self.hide = False
+ break
+
+ if "content" in chunk:
+ self.description += chunk["content"]
+ await sdk.update_ui()
+
+ self.name = remove_quotes_and_escapes(await sdk.models.medium.complete(
+ f"Write a short title for the following chat message: {self.description}"))
+
+ self.chat_context.append(ChatMessage(
+ role="assistant",
+ content=self.description,
+ summary=self.name
+ ))
+
+ # TODO: Never actually closing.
+ await generator.aclose()
class AddFileStep(Step):
diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py
index 4c5303fb..fb9ea029 100644
--- a/continuedev/src/continuedev/plugins/steps/core/core.py
+++ b/continuedev/src/continuedev/plugins/steps/core/core.py
@@ -1,9 +1,13 @@
# These steps are depended upon by ContinueSDK
import os
-import subprocess
+import json
import difflib
from textwrap import dedent
-from typing import Coroutine, List, Literal, Union
+import traceback
+from typing import Any, Coroutine, List, Union
+import difflib
+
+from pydantic import validator
from ....libs.llm.ggml import GGML
from ....models.main import Range
@@ -14,7 +18,6 @@ from ....core.observation import Observation, TextObservation, TracebackObservat
from ....core.main import ChatMessage, ContinueCustomException, Step, SequentialStep
from ....libs.util.count_tokens import DEFAULT_MAX_TOKENS
from ....libs.util.strings import dedent_and_get_common_whitespace, remove_quotes_and_escapes
-import difflib
class ContinueSDK:
@@ -41,6 +44,25 @@ class MessageStep(Step):
return TextObservation(text=self.message)
+class DisplayErrorStep(Step):
+ name: str = "Error in the Continue server"
+ e: Any
+
+ class Config:
+ arbitrary_types_allowed = True
+
+ @validator("e", pre=True, always=True)
+ def validate_e(cls, v):
+ if isinstance(v, Exception):
+ return '\n'.join(traceback.format_exception(v))
+
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
+ return self.e
+
+ async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
+ raise ContinueCustomException(message=self.e, title=self.name)
+
+
class FileSystemEditStep(ReversibleStep):
edit: FileSystemEdit
_diff: Union[EditDiff, None] = None
diff --git a/continuedev/src/continuedev/plugins/steps/help.py b/continuedev/src/continuedev/plugins/steps/help.py
index 4d75af30..ec670999 100644
--- a/continuedev/src/continuedev/plugins/steps/help.py
+++ b/continuedev/src/continuedev/plugins/steps/help.py
@@ -34,26 +34,33 @@ class HelpStep(Step):
description: str = ""
async def run(self, sdk: ContinueSDK):
-
question = self.user_input
- prompt = dedent(f"""Please us the information below to provide a succinct answer to the following quesiton: {question}
-
- Information:
-
- {help}""")
-
- self.chat_context.append(ChatMessage(
- role="user",
- content=prompt,
- summary="Help"
- ))
- messages = await sdk.get_chat_context()
- generator = sdk.models.default.stream_chat(messages)
- async for chunk in generator:
- if "content" in chunk:
- self.description += chunk["content"]
- await sdk.update_ui()
+ if question.strip() == "":
+ self.description = help
+ else:
+ prompt = dedent(f"""
+ Information:
+
+ {help}
+
+ Instructions:
+
+ Please us the information below to provide a succinct answer to the following question: {question}
+
+ Do not cite any slash commands other than those you've been told about, which are: /edit and /feedback.""")
+
+ self.chat_context.append(ChatMessage(
+ role="user",
+ content=prompt,
+ summary="Help"
+ ))
+ messages = await sdk.get_chat_context()
+ generator = sdk.models.default.stream_chat(messages)
+ async for chunk in generator:
+ if "content" in chunk:
+ self.description += chunk["content"]
+ await sdk.update_ui()
posthog_logger.capture_event(
"help", {"question": question, "answer": self.description})
diff --git a/continuedev/src/continuedev/plugins/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py
index 26c1cabd..2c3d34fc 100644
--- a/continuedev/src/continuedev/plugins/steps/main.py
+++ b/continuedev/src/continuedev/plugins/steps/main.py
@@ -13,6 +13,7 @@ from ...core.sdk import ContinueSDK, Models
from ...core.observation import Observation
from .core.core import DefaultModelEditCodeStep
from ...libs.util.calculate_diff import calculate_diff2
+from ...libs.util.logging import logger
class Policy(BaseModel):
@@ -105,7 +106,7 @@ class FasterEditHighlightedCodeStep(Step):
# Temporarily doing this to generate description.
self._prompt = prompt
self._completion = completion
- print(completion)
+ logger.debug(completion)
# ALTERNATIVE DECODING STEP HERE
raw_file_edits = []
diff --git a/continuedev/src/continuedev/plugins/steps/search_directory.py b/continuedev/src/continuedev/plugins/steps/search_directory.py
index c13047d6..966acb7c 100644
--- a/continuedev/src/continuedev/plugins/steps/search_directory.py
+++ b/continuedev/src/continuedev/plugins/steps/search_directory.py
@@ -65,5 +65,5 @@ class EditAllMatchesStep(Step):
range=range_in_file.range,
filename=range_in_file.filepath,
prompt=self.user_request
- ), sdk.ide.unique_id) for range_in_file in range_in_files]
+ )) for range_in_file in range_in_files]
await asyncio.gather(*tasks)
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index c0957395..98a5aea0 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -8,10 +8,12 @@ import traceback
from uvicorn.main import Server
from .session_manager import session_manager, Session
+from ..plugins.steps.core.core import DisplayErrorStep, MessageStep
from .gui_protocol import AbstractGUIProtocolServer
from ..libs.util.queue import AsyncSubscriptionQueue
from ..libs.util.telemetry import posthog_logger
from ..libs.util.create_async_task import create_async_task
+from ..libs.util.logging import logger
router = APIRouter(prefix="/gui", tags=["gui"])
@@ -25,17 +27,13 @@ class AppStatus:
@staticmethod
def handle_exit(*args, **kwargs):
AppStatus.should_exit = True
- print("Shutting down")
+ logger.debug("Shutting down")
original_handler(*args, **kwargs)
Server.handle_exit = AppStatus.handle_exit
-async def session(x_continue_session_id: str = Header("anonymous")) -> Session:
- return await session_manager.get_session(x_continue_session_id)
-
-
async def websocket_session(session_id: str) -> Session:
return await session_manager.get_session(session_id)
@@ -73,103 +71,97 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
resp = await self._receive_json(message_type)
return resp_model.parse_obj(resp)
+ def on_error(self, e: Exception):
+ return self.session.autopilot.continue_sdk.run_step(DisplayErrorStep(e=e))
+
def handle_json(self, message_type: str, data: Any):
- try:
- if message_type == "main_input":
- self.on_main_input(data["input"])
- elif message_type == "step_user_input":
- self.on_step_user_input(data["input"], data["index"])
- elif message_type == "refinement_input":
- self.on_refinement_input(data["input"], data["index"])
- elif message_type == "reverse_to_index":
- self.on_reverse_to_index(data["index"])
- elif message_type == "retry_at_index":
- self.on_retry_at_index(data["index"])
- elif message_type == "clear_history":
- self.on_clear_history()
- elif message_type == "delete_at_index":
- self.on_delete_at_index(data["index"])
- elif message_type == "delete_context_with_ids":
- self.on_delete_context_with_ids(data["ids"])
- elif message_type == "toggle_adding_highlighted_code":
- self.on_toggle_adding_highlighted_code()
- elif message_type == "set_editing_at_indices":
- self.on_set_editing_at_indices(data["indices"])
- elif message_type == "show_logs_at_index":
- self.on_show_logs_at_index(data["index"])
- elif message_type == "select_context_item":
- self.select_context_item(data["id"], data["query"])
- except Exception as e:
- print(e)
+ if message_type == "main_input":
+ self.on_main_input(data["input"])
+ elif message_type == "step_user_input":
+ self.on_step_user_input(data["input"], data["index"])
+ elif message_type == "refinement_input":
+ self.on_refinement_input(data["input"], data["index"])
+ elif message_type == "reverse_to_index":
+ self.on_reverse_to_index(data["index"])
+ elif message_type == "retry_at_index":
+ self.on_retry_at_index(data["index"])
+ elif message_type == "clear_history":
+ self.on_clear_history()
+ elif message_type == "delete_at_index":
+ self.on_delete_at_index(data["index"])
+ elif message_type == "delete_context_with_ids":
+ self.on_delete_context_with_ids(data["ids"])
+ elif message_type == "toggle_adding_highlighted_code":
+ self.on_toggle_adding_highlighted_code()
+ elif message_type == "set_editing_at_indices":
+ self.on_set_editing_at_indices(data["indices"])
+ elif message_type == "show_logs_at_index":
+ self.on_show_logs_at_index(data["index"])
+ elif message_type == "select_context_item":
+ self.select_context_item(data["id"], data["query"])
def on_main_input(self, input: str):
# Do something with user input
- create_async_task(self.session.autopilot.accept_user_input(
- input), self.session.autopilot.continue_sdk.ide.unique_id)
+ create_async_task(
+ self.session.autopilot.accept_user_input(input), self.on_error)
def on_reverse_to_index(self, index: int):
# Reverse the history to the given index
- create_async_task(self.session.autopilot.reverse_to_index(
- index), self.session.autopilot.continue_sdk.ide.unique_id)
+ create_async_task(
+ self.session.autopilot.reverse_to_index(index), self.on_error)
def on_step_user_input(self, input: str, index: int):
create_async_task(
- self.session.autopilot.give_user_input(input, index), self.session.autopilot.continue_sdk.ide.unique_id)
+ self.session.autopilot.give_user_input(input, index), self.on_error)
def on_refinement_input(self, input: str, index: int):
create_async_task(
- self.session.autopilot.accept_refinement_input(input, index), self.session.autopilot.continue_sdk.ide.unique_id)
+ self.session.autopilot.accept_refinement_input(input, index), self.on_error)
def on_retry_at_index(self, index: int):
create_async_task(
- self.session.autopilot.retry_at_index(index), self.session.autopilot.continue_sdk.ide.unique_id)
+ self.session.autopilot.retry_at_index(index), self.on_error)
def on_clear_history(self):
- create_async_task(self.session.autopilot.clear_history(
- ), self.session.autopilot.continue_sdk.ide.unique_id)
+ create_async_task(
+ self.session.autopilot.clear_history(), self.on_error)
def on_delete_at_index(self, index: int):
- create_async_task(self.session.autopilot.delete_at_index(
- index), self.session.autopilot.continue_sdk.ide.unique_id)
+ create_async_task(
+ self.session.autopilot.delete_at_index(index), self.on_error)
def on_delete_context_with_ids(self, ids: List[str]):
create_async_task(
- self.session.autopilot.delete_context_with_ids(
- ids), self.session.autopilot.continue_sdk.ide.unique_id
- )
+ self.session.autopilot.delete_context_with_ids(ids), self.on_error)
def on_toggle_adding_highlighted_code(self):
create_async_task(
- self.session.autopilot.toggle_adding_highlighted_code(
- ), self.session.autopilot.continue_sdk.ide.unique_id
- )
+ self.session.autopilot.toggle_adding_highlighted_code(), self.on_error)
def on_set_editing_at_indices(self, indices: List[int]):
create_async_task(
- self.session.autopilot.set_editing_at_indices(
- indices), self.session.autopilot.continue_sdk.ide.unique_id
- )
+ self.session.autopilot.set_editing_at_indices(indices), self.on_error)
def on_show_logs_at_index(self, index: int):
name = f"continue_logs.txt"
logs = "\n\n############################################\n\n".join(
["This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"] + self.session.autopilot.continue_sdk.history.timeline[index].logs)
create_async_task(
- self.session.autopilot.ide.showVirtualFile(name, logs), self.session.autopilot.continue_sdk.ide.unique_id)
+ self.session.autopilot.ide.showVirtualFile(name, logs), self.on_error)
def select_context_item(self, id: str, query: str):
"""Called when user selects an item from the dropdown"""
create_async_task(
- self.session.autopilot.select_context_item(id, query), self.session.autopilot.continue_sdk.ide.unique_id)
+ self.session.autopilot.select_context_item(id, query), self.on_error)
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)):
try:
- print("Received websocket connection at url: ", websocket.url)
+ logger.debug(f"Received websocket connection at url: {websocket.url}")
await websocket.accept()
- print("Session started")
+ logger.debug("Session started")
session_manager.register_websocket(session.session_id, websocket)
protocol = GUIProtocolServer(session)
protocol.websocket = websocket
@@ -179,7 +171,7 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we
while AppStatus.should_exit is False:
message = await websocket.receive_text()
- print("Received GUI message", message)
+ logger.debug(f"Received GUI message {message}")
if type(message) is str:
message = json.loads(message)
@@ -190,16 +182,21 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we
protocol.handle_json(message_type, data)
except WebSocketDisconnect as e:
- print("GUI websocket disconnected")
+ logger.debug("GUI websocket disconnected")
except Exception as e:
- print("ERROR in gui websocket: ", e)
+ # Log, send to PostHog, and send to GUI
+ logger.debug(f"ERROR in gui websocket: {e}")
+ err_msg = '\n'.join(traceback.format_exception(e))
posthog_logger.capture_event("gui_error", {
- "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))})
+ "error_title": e.__str__() or e.__repr__(), "error_message": err_msg})
+
+ await protocol.session.autopilot.continue_sdk.run_step(DisplayErrorStep(e=e))
+
raise e
finally:
- print("Closing gui websocket")
+ logger.debug("Closing gui websocket")
if websocket.client_state != WebSocketState.DISCONNECTED:
await websocket.close()
await session_manager.persist_session(session.session_id)
- session_manager.remove_session(session.session_id)
+ await session_manager.remove_session(session.session_id)
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index 87374928..e4c07029 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -10,6 +10,7 @@ from pydantic import BaseModel
import traceback
import asyncio
+from ..plugins.steps.core.core import DisplayErrorStep, MessageStep
from .meilisearch_server import start_meilisearch
from ..libs.util.telemetry import posthog_logger
from ..libs.util.queue import AsyncSubscriptionQueue
@@ -19,6 +20,7 @@ from .gui import session_manager
from .ide_protocol import AbstractIdeProtocolServer
from ..libs.util.create_async_task import create_async_task
from .session_manager import SessionManager
+from ..libs.util.logging import logger
import nest_asyncio
nest_asyncio.apply()
@@ -37,7 +39,7 @@ class AppStatus:
@staticmethod
def handle_exit(*args, **kwargs):
AppStatus.should_exit = True
- print("Shutting down")
+ logger.debug("Shutting down")
original_handler(*args, **kwargs)
@@ -140,7 +142,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
continue
message_type = message["messageType"]
data = message["data"]
- print("Received message while initializing", message_type)
+ logger.debug(f"Received message while initializing {message_type}")
if message_type == "workspaceDirectory":
self.workspace_directory = data["workspaceDirectory"]
elif message_type == "uniqueId":
@@ -154,9 +156,10 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
async def _send_json(self, message_type: str, data: Any):
if self.websocket.application_state == WebSocketState.DISCONNECTED:
- print("Tried to send message, but websocket is disconnected", message_type)
+ logger.debug(
+ f"Tried to send message, but websocket is disconnected: {message_type}")
return
- print("Sending IDE message: ", message_type)
+ logger.debug(f"Sending IDE message: {message_type}")
await self.websocket.send_json({
"messageType": message_type,
"data": data
@@ -167,7 +170,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout)
except asyncio.TimeoutError:
raise Exception(
- "IDE Protocol _receive_json timed out after 20 seconds", message_type)
+ f"IDE Protocol _receive_json timed out after 20 seconds: {message_type}")
async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T:
await self._send_json(message_type, data)
@@ -277,6 +280,9 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
# This is where you might have triggers: plugins can subscribe to certian events
# like file changes, tracebacks, etc...
+ def on_error(self, e: Exception):
+ return self.session_manager.sessions[self.session_id].autopilot.continue_sdk.run_step(DisplayErrorStep(e=e))
+
def onAcceptRejectSuggestion(self, accepted: bool):
posthog_logger.capture_event("accept_reject_suggestion", {
"accepted": accepted
@@ -309,22 +315,22 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
def onDeleteAtIndex(self, index: int):
if autopilot := self.__get_autopilot():
- create_async_task(autopilot.delete_at_index(index), self.unique_id)
+ create_async_task(autopilot.delete_at_index(index), self.on_error)
def onCommandOutput(self, output: str):
if autopilot := self.__get_autopilot():
create_async_task(
- autopilot.handle_command_output(output), self.unique_id)
+ autopilot.handle_command_output(output), self.on_error)
def onHighlightedCodeUpdate(self, range_in_files: List[RangeInFileWithContents]):
if autopilot := self.__get_autopilot():
create_async_task(autopilot.handle_highlighted_code(
- range_in_files), self.unique_id)
+ range_in_files), self.on_error)
def onMainUserInput(self, input: str):
if autopilot := self.__get_autopilot():
create_async_task(
- autopilot.accept_user_input(input), self.unique_id)
+ autopilot.accept_user_input(input), self.on_error)
# Request information. Session doesn't matter.
async def getOpenFiles(self) -> List[str]:
@@ -354,7 +360,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
}, GetUserSecretResponse, "getUserSecret")
return resp.value
except Exception as e:
- print("Error getting user secret", e)
+ logger.debug(f"Error getting user secret: {e}")
return ""
async def saveFile(self, filepath: str):
@@ -437,15 +443,15 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
async def websocket_endpoint(websocket: WebSocket, session_id: str = None):
try:
await websocket.accept()
- print("Accepted websocket connection from, ", websocket.client)
+ logger.debug(f"Accepted websocket connection from {websocket.client}")
await websocket.send_json({"messageType": "connected", "data": {}})
# Start meilisearch
try:
await start_meilisearch()
except Exception as e:
- print("Failed to start MeiliSearch")
- print(e)
+ logger.debug("Failed to start MeiliSearch")
+ logger.debug(e)
def handle_msg(msg):
message = json.loads(msg)
@@ -455,9 +461,9 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None):
message_type = message["messageType"]
data = message["data"]
- print("Received IDE message: ", message_type)
+ logger.debug(f"Received IDE message: {message_type}")
create_async_task(
- ideProtocolServer.handle_json(message_type, data))
+ ideProtocolServer.handle_json(message_type, data), ideProtocolServer.on_error)
ideProtocolServer = IdeProtocolServer(session_manager, websocket)
if session_id is not None:
@@ -473,15 +479,20 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str = None):
message = await websocket.receive_text()
handle_msg(message)
- print("Closing ide websocket")
+ logger.debug("Closing ide websocket")
except WebSocketDisconnect as e:
- print("IDE wbsocket disconnected")
+ logger.debug("IDE wbsocket disconnected")
except Exception as e:
- print("Error in ide websocket: ", e)
+ logger.debug(f"Error in ide websocket: {e}")
+ err_msg = '\n'.join(traceback.format_exception(e))
posthog_logger.capture_event("gui_error", {
- "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))})
+ "error_title": e.__str__() or e.__repr__(), "error_message": err_msg})
+
+ await session_manager.sessions[session_id].autopilot.continue_sdk.run_step(DisplayErrorStep(e=e))
+
raise e
finally:
+ logger.debug("Closing ide websocket")
if websocket.client_state != WebSocketState.DISCONNECTED:
await websocket.close()
diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py
index b92c9fa3..468bc855 100644
--- a/continuedev/src/continuedev/server/main.py
+++ b/continuedev/src/continuedev/server/main.py
@@ -15,6 +15,7 @@ from .ide import router as ide_router
from .gui import router as gui_router
from .session_manager import session_manager
from ..libs.util.paths import getLogFilePath
+from ..libs.util.logging import logger
app = FastAPI()
@@ -33,45 +34,49 @@ app.add_middleware(
@app.get("/health")
def health():
- print("Testing")
+ logger.debug("Health check")
return {"status": "ok"}
-# add cli arg for server port
-parser = argparse.ArgumentParser()
-parser.add_argument("-p", "--port", help="server port",
- type=int, default=65432)
-args = parser.parse_args()
-
-log_path = getLogFilePath()
-LOG_CONFIG = {
- 'version': 1,
- 'disable_existing_loggers': False,
- 'handlers': {
- 'file': {
- 'level': 'DEBUG',
- 'class': 'logging.FileHandler',
- 'filename': log_path,
- },
- },
- 'root': {
- 'level': 'DEBUG',
- 'handlers': ['file']
- }
-}
-print(f"Log path: {log_path}")
+class Logger(object):
+ def __init__(self, log_file: str):
+ self.terminal = sys.stdout
+ self.log = open(log_file, "a")
+
+ def write(self, message):
+ self.terminal.write(message)
+ self.log.write(message)
+
+ def flush(self):
+ # this flush method is needed for python 3 compatibility.
+ # this handles the flush command by doing nothing.
+ # you might want to specify some extra behavior here.
+ pass
+
+ def isatty(self):
+ return False
+
+
+try:
+ # add cli arg for server port
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-p", "--port", help="server port",
+ type=int, default=65432)
+ args = parser.parse_args()
+except Exception as e:
+ logger.debug(f"Error parsing command line arguments: {e}")
+ raise e
def run_server():
- config = uvicorn.Config(app, host="0.0.0.0",
- port=args.port, log_config=LOG_CONFIG)
+ config = uvicorn.Config(app, host="127.0.0.1", port=args.port)
server = uvicorn.Server(config)
server.run()
async def cleanup_coroutine():
- print("Cleaning up sessions")
+ logger.debug("Cleaning up sessions")
for session_id in session_manager.sessions:
await session_manager.persist_session(session_id)
@@ -90,13 +95,14 @@ def cpu_usage_report():
time.sleep(1)
# Call cpu_percent again to get the CPU usage over the interval
cpu_usage = process.cpu_percent(interval=None)
- print(f"CPU usage: {cpu_usage}%")
+ logger.debug(f"CPU usage: {cpu_usage}%")
atexit.register(cleanup)
if __name__ == "__main__":
try:
+ # Uncomment to get CPU usage reports
# import threading
# def cpu_usage_loop():
@@ -109,6 +115,6 @@ if __name__ == "__main__":
run_server()
except Exception as e:
- print("Error starting Continue server: ", e)
+ logger.debug(f"Error starting Continue server: {e}")
cleanup()
raise e
diff --git a/continuedev/src/continuedev/server/meilisearch_server.py b/continuedev/src/continuedev/server/meilisearch_server.py
index 00f692f5..7f460afc 100644
--- a/continuedev/src/continuedev/server/meilisearch_server.py
+++ b/continuedev/src/continuedev/server/meilisearch_server.py
@@ -5,6 +5,7 @@ import subprocess
from meilisearch_python_async import Client
from ..libs.util.paths import getServerFolderPath
+from ..libs.util.logging import logger
def ensure_meilisearch_installed() -> bool:
@@ -39,7 +40,7 @@ def ensure_meilisearch_installed() -> bool:
shutil.rmtree(p, ignore_errors=True)
# Download MeiliSearch
- print("Downloading MeiliSearch...")
+ logger.debug("Downloading MeiliSearch...")
subprocess.run(
f"curl -L https://install.meilisearch.com | sh", shell=True, check=True, cwd=serverPath)
@@ -82,6 +83,6 @@ async def start_meilisearch():
# Check if MeiliSearch is running
if not await check_meilisearch_running() or not was_already_installed:
- print("Starting MeiliSearch...")
+ logger.debug("Starting MeiliSearch...")
subprocess.Popen(["./meilisearch", "--no-analytics"], cwd=serverPath, stdout=subprocess.DEVNULL,
stderr=subprocess.STDOUT, close_fds=True, start_new_session=True)
diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py
index 3136f1bf..cf46028f 100644
--- a/continuedev/src/continuedev/server/session_manager.py
+++ b/continuedev/src/continuedev/server/session_manager.py
@@ -4,6 +4,9 @@ from typing import Any, Dict, List, Union
from uuid import uuid4
import json
+from fastapi.websockets import WebSocketState
+
+from ..plugins.steps.core.core import DisplayErrorStep
from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath
from ..models.filesystem_edit import FileEditWithFullContents
from ..libs.constants.main import CONTINUE_SESSIONS_FOLDER
@@ -13,6 +16,7 @@ from ..core.autopilot import Autopilot
from .ide_protocol import AbstractIdeProtocolServer
from ..libs.util.create_async_task import create_async_task
from ..libs.util.errors import SessionNotFound
+from ..libs.util.logging import logger
class Session:
@@ -59,6 +63,8 @@ class SessionManager:
return self.sessions[session_id]
async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session:
+ logger.debug(f"New session: {session_id}")
+
full_state = None
if session_id is not None and os.path.exists(getSessionFilePath(session_id)):
with open(getSessionFilePath(session_id), "r") as f:
@@ -78,29 +84,35 @@ class SessionManager:
})
autopilot.on_update(on_update)
- create_async_task(autopilot.run_policy())
+ create_async_task(autopilot.run_policy(
+ ), lambda e: autopilot.continue_sdk.run_step(DisplayErrorStep(e=e)))
return session
- def remove_session(self, session_id: str):
- del self.sessions[session_id]
+ async def remove_session(self, session_id: str):
+ logger.debug(f"Removing session: {session_id}")
+ if session_id in self.sessions:
+ if session_id in self.registered_ides:
+ ws_to_close = self.registered_ides[session_id].websocket
+ if ws_to_close is not None and ws_to_close.client_state != WebSocketState.DISCONNECTED:
+ await self.sessions[session_id].autopilot.ide.websocket.close()
+
+ del self.sessions[session_id]
async def persist_session(self, session_id: str):
"""Save the session's FullState as a json file"""
full_state = await self.sessions[session_id].autopilot.get_full_state()
- if not os.path.exists(getSessionsFolderPath()):
- os.mkdir(getSessionsFolderPath())
with open(getSessionFilePath(session_id), "w") as f:
json.dump(full_state.dict(), f)
def register_websocket(self, session_id: str, ws: WebSocket):
self.sessions[session_id].ws = ws
- print("Registered websocket for session", session_id)
+ logger.debug(f"Registered websocket for session {session_id}")
async def send_ws_data(self, session_id: str, message_type: str, data: Any):
if session_id not in self.sessions:
raise SessionNotFound(f"Session {session_id} not found")
if self.sessions[session_id].ws is None:
- # print(f"Session {session_id} has no websocket")
+ # logger.debug(f"Session {session_id} has no websocket")
return
await self.sessions[session_id].ws.send_json({