summaryrefslogtreecommitdiff
path: root/continuedev
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-07-16 16:25:02 -0700
committerNate Sesti <sestinj@gmail.com>2023-07-16 16:25:02 -0700
commit4c3a25a1c8938f8132233e021c74d98eb19d7ddd (patch)
tree8460e5703f224e7ef5c2c7eca6b470f338b93e1e /continuedev
parent3ded151331933c9a1352cc46c3cc67c5733d1c86 (diff)
parenta4a815628f702af806603015ec6805edd151328b (diff)
downloadsncontinue-4c3a25a1c8938f8132233e021c74d98eb19d7ddd.tar.gz
sncontinue-4c3a25a1c8938f8132233e021c74d98eb19d7ddd.tar.bz2
sncontinue-4c3a25a1c8938f8132233e021c74d98eb19d7ddd.zip
Merge branch 'main' into ggml-server
Diffstat (limited to 'continuedev')
-rw-r--r--continuedev/src/continuedev/core/abstract_sdk.py4
-rw-r--r--continuedev/src/continuedev/core/autopilot.py44
-rw-r--r--continuedev/src/continuedev/core/config.py9
-rw-r--r--continuedev/src/continuedev/core/policy.py7
-rw-r--r--continuedev/src/continuedev/core/sdk.py95
-rw-r--r--continuedev/src/continuedev/libs/constants/main.py6
-rw-r--r--continuedev/src/continuedev/libs/llm/hf_inference_api.py6
-rw-r--r--continuedev/src/continuedev/libs/util/create_async_task.py24
-rw-r--r--continuedev/src/continuedev/libs/util/errors.py2
-rw-r--r--continuedev/src/continuedev/libs/util/paths.py17
-rw-r--r--continuedev/src/continuedev/libs/util/step_name_to_steps.py4
-rw-r--r--continuedev/src/continuedev/recipes/TemplateRecipe/main.py4
-rw-r--r--continuedev/src/continuedev/recipes/WritePytestsRecipe/main.py2
-rw-r--r--continuedev/src/continuedev/server/gui.py83
-rw-r--r--continuedev/src/continuedev/server/ide.py149
-rw-r--r--continuedev/src/continuedev/server/ide_protocol.py27
-rw-r--r--continuedev/src/continuedev/server/main.py16
-rw-r--r--continuedev/src/continuedev/server/session_manager.py51
-rw-r--r--continuedev/src/continuedev/server/state_manager.py21
-rw-r--r--continuedev/src/continuedev/steps/chat.py21
-rw-r--r--continuedev/src/continuedev/steps/core/core.py178
-rw-r--r--continuedev/src/continuedev/steps/help.py59
-rw-r--r--continuedev/src/continuedev/steps/main.py48
-rw-r--r--continuedev/src/continuedev/steps/open_config.py9
-rw-r--r--continuedev/src/continuedev/steps/search_directory.py5
-rw-r--r--continuedev/src/continuedev/steps/welcome.py2
26 files changed, 604 insertions, 289 deletions
diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py
index 7bd3da6c..94d7be10 100644
--- a/continuedev/src/continuedev/core/abstract_sdk.py
+++ b/continuedev/src/continuedev/core/abstract_sdk.py
@@ -76,9 +76,7 @@ class AbstractContinueSDK(ABC):
async def get_user_secret(self, env_var: str, prompt: str) -> str:
pass
- @abstractproperty
- def config(self) -> ContinueConfig:
- pass
+ config: ContinueConfig
@abstractmethod
def set_loading_message(self, message: str):
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 5c3baafd..0696c360 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -1,13 +1,13 @@
from functools import cached_property
import traceback
import time
-from typing import Any, Callable, Coroutine, Dict, List
+from typing import Any, Callable, Coroutine, Dict, List, Union
import os
from aiohttp import ClientPayloadError
+from pydantic import root_validator
from ..models.filesystem import RangeInFileWithContents
from ..models.filesystem_edit import FileEditWithFullContents
-from ..libs.llm import LLM
from .observation import Observation, InternalErrorObservation
from ..server.ide_protocol import AbstractIdeProtocolServer
from ..libs.util.queue import AsyncSubscriptionQueue
@@ -16,10 +16,10 @@ from .main import Context, ContinueCustomException, HighlightedRangeContext, Pol
from ..steps.core.core import ReversibleStep, ManualEditStep, UserInputStep
from ..libs.util.telemetry import capture_event
from .sdk import ContinueSDK
-import asyncio
from ..libs.util.step_name_to_steps import get_step_from_name
from ..libs.util.traceback_parsers import get_python_traceback, get_javascript_traceback
from openai import error as openai_errors
+from ..libs.util.create_async_task import create_async_task
def get_error_title(e: Exception) -> str:
@@ -34,9 +34,11 @@ def get_error_title(e: Exception) -> str:
elif isinstance(e, ClientPayloadError):
return "The request to OpenAI failed. Please try again."
elif isinstance(e, openai_errors.APIConnectionError):
- return "The request failed. Please check your internet connection and try again."
+ return "The request failed. Please check your internet connection and try again. If this issue persists, you can use our API key for free by going to VS Code settings and changing the value of continue.OPENAI_API_KEY to \"\""
elif isinstance(e, openai_errors.InvalidRequestError):
return 'Your API key does not have access to GPT-4. You can use ours for free by going to VS Code settings and changing the value of continue.OPENAI_API_KEY to ""'
+ elif e.__str__().startswith("Cannot connect to host"):
+ return "The request failed. Please check your internet connection and try again."
return e.__str__() or e.__repr__()
@@ -45,8 +47,11 @@ class Autopilot(ContinueBaseModel):
ide: AbstractIdeProtocolServer
history: History = History.from_empty()
context: Context = Context()
+ full_state: Union[FullState, None] = None
_on_update_callbacks: List[Callable[[FullState], None]] = []
+ continue_sdk: ContinueSDK = None
+
_active: bool = False
_should_halt: bool = False
_main_user_input_queue: List[str] = []
@@ -54,16 +59,25 @@ class Autopilot(ContinueBaseModel):
_user_input_queue = AsyncSubscriptionQueue()
_retry_queue = AsyncSubscriptionQueue()
- @cached_property
- def continue_sdk(self) -> ContinueSDK:
- return ContinueSDK(self)
+ @classmethod
+ async def create(cls, policy: Policy, ide: AbstractIdeProtocolServer, full_state: FullState) -> "Autopilot":
+ autopilot = cls(ide=ide, policy=policy)
+ autopilot.continue_sdk = await ContinueSDK.create(autopilot)
+ return autopilot
class Config:
arbitrary_types_allowed = True
keep_untouched = (cached_property,)
+ @root_validator(pre=True)
+ def fill_in_values(cls, values):
+ full_state: FullState = values.get('full_state')
+ if full_state is not None:
+ values['history'] = full_state.history
+ return values
+
def get_full_state(self) -> FullState:
- return FullState(
+ full_state = FullState(
history=self.history,
active=self._active,
user_input_queue=self._main_user_input_queue,
@@ -72,6 +86,8 @@ class Autopilot(ContinueBaseModel):
slash_commands=self.get_available_slash_commands(),
adding_highlighted_code=self._adding_highlighted_code,
)
+ self.full_state = full_state
+ return full_state
def get_available_slash_commands(self) -> List[Dict]:
custom_commands = list(map(lambda x: {
@@ -207,6 +223,8 @@ class Autopilot(ContinueBaseModel):
async def delete_at_index(self, index: int):
self.history.timeline[index].step.hide = True
self.history.timeline[index].deleted = True
+ self.history.timeline[index].active = False
+
await self.update_subscribers()
async def delete_context_at_indices(self, indices: List[int]):
@@ -250,7 +268,7 @@ class Autopilot(ContinueBaseModel):
# i -= 1
capture_event(self.continue_sdk.ide.unique_id, 'step run', {
- 'step_name': step.name, 'params': step.dict()})
+ 'step_name': step.name, 'params': step.dict()})
if not is_future_step:
# Check manual edits buffer, clear out if needed by creating a ManualEditStep
@@ -284,12 +302,13 @@ class Autopilot(ContinueBaseModel):
e.__class__, ContinueCustomException)
error_string = e.message if is_continue_custom_exception else '\n'.join(
- traceback.format_tb(e.__traceback__)) + f"\n\n{e.__repr__()}"
+ traceback.format_exception(e))
error_title = e.title if is_continue_custom_exception else get_error_title(
e)
# Attach an InternalErrorObservation to the step and unhide it.
- print(f"Error while running step: \n{error_string}\n{error_title}")
+ print(
+ f"Error while running step: \n{error_string}\n{error_title}")
capture_event(self.continue_sdk.ide.unique_id, 'step error', {
'error_message': error_string, 'error_title': error_title, 'step_name': step.name, 'params': step.dict()})
@@ -341,7 +360,8 @@ class Autopilot(ContinueBaseModel):
# Update subscribers with new description
await self.update_subscribers()
- asyncio.create_task(update_description())
+ create_async_task(update_description(),
+ self.continue_sdk.ide.unique_id)
return observation
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 55f5bc60..6e430c04 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -45,6 +45,11 @@ DEFAULT_SLASH_COMMANDS = [
step_name="OpenConfigStep",
),
SlashCommand(
+ name="help",
+ description="Ask a question like '/help what is given to the llm as context?'",
+ step_name="HelpStep",
+ ),
+ SlashCommand(
name="comment",
description="Write comments for the current file or highlighted code",
step_name="CommentCodeStep",
@@ -131,7 +136,7 @@ def load_global_config() -> ContinueConfig:
config_path = os.path.join(global_dir, 'config.json')
if not os.path.exists(config_path):
with open(config_path, 'w') as f:
- json.dump(ContinueConfig().dict(), f)
+ json.dump(ContinueConfig().dict(), f, indent=4)
with open(config_path, 'r') as f:
try:
config_dict = json.load(f)
@@ -151,7 +156,7 @@ def update_global_config(config: ContinueConfig):
yaml_path = os.path.join(global_dir, 'config.yaml')
if os.path.exists(yaml_path):
with open(config_path, 'w') as f:
- yaml.dump(config.dict(), f)
+ yaml.dump(config.dict(), f, indent=4)
else:
config_path = os.path.join(global_dir, 'config.json')
with open(config_path, 'w') as f:
diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py
index b8363df2..bc897357 100644
--- a/continuedev/src/continuedev/core/policy.py
+++ b/continuedev/src/continuedev/core/policy.py
@@ -59,11 +59,8 @@ class DemoPolicy(Policy):
return (
MessageStep(name="Welcome to Continue", message=dedent("""\
- Highlight code and ask a question or give instructions
- - Use `cmd+k` (Mac) / `ctrl+k` (Windows) to open Continue
- - Use `cmd+shift+e` / `ctrl+shift+e` to open file Explorer
- - Add your own OpenAI API key to VS Code Settings with `cmd+,`
- - Use slash commands when you want fine-grained control
- - Past steps are included as part of the context by default""")) >>
+ - Use `cmd+m` (Mac) / `ctrl+m` (Windows) to open Continue
+ - Use `/help` to ask questions about how to use Continue""")) >>
WelcomeStep() >>
# SetupContinueWorkspaceStep() >>
# CreateCodebaseIndexChroma() >>
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 22393746..9389e1e9 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -1,6 +1,6 @@
import asyncio
from functools import cached_property
-from typing import Coroutine, Union
+from typing import Coroutine, Dict, Union
import os
from ..steps.core.core import DefaultModelEditCodeStep
@@ -14,7 +14,7 @@ from ..libs.llm.openai import OpenAI
from ..libs.llm.ggml import GGML
from .observation import Observation
from ..server.ide_protocol import AbstractIdeProtocolServer
-from .main import Context, ContinueCustomException, HighlightedRangeContext, History, Step, ChatMessage, ChatMessageRole
+from .main import Context, ContinueCustomException, History, Step, ChatMessage
from ..steps.core.core import *
from ..libs.llm.proxy_server import ProxyServer
@@ -23,26 +23,46 @@ class Autopilot:
pass
+ModelProvider = Literal["openai", "hf_inference_api", "ggml", "anthropic"]
+MODEL_PROVIDER_TO_ENV_VAR = {
+ "openai": "OPENAI_API_KEY",
+ "hf_inference_api": "HUGGING_FACE_TOKEN",
+ "anthropic": "ANTHROPIC_API_KEY"
+}
+
+
class Models:
- def __init__(self, sdk: "ContinueSDK"):
+ provider_keys: Dict[ModelProvider, str] = {}
+ model_providers: List[ModelProvider]
+
+ def __init__(self, sdk: "ContinueSDK", model_providers: List[ModelProvider]):
self.sdk = sdk
+ self.model_providers = model_providers
+
+ @classmethod
+ async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models":
+ models = Models(sdk, with_providers)
+ for provider in with_providers:
+ if provider in MODEL_PROVIDER_TO_ENV_VAR:
+ env_var = MODEL_PROVIDER_TO_ENV_VAR[provider]
+ models.provider_keys[provider] = await sdk.get_user_secret(
+ env_var, f'Please add your {env_var} to the .env file')
+
+ return models
def __load_openai_model(self, model: str) -> OpenAI:
- async def load_openai_model():
- api_key = await self.sdk.get_user_secret(
- 'OPENAI_API_KEY', 'Enter your OpenAI API key or press enter to try for free')
- if api_key == "":
- return ProxyServer(self.sdk.ide.unique_id, model)
- return OpenAI(api_key=api_key, default_model=model)
- return asyncio.get_event_loop().run_until_complete(load_openai_model())
+ api_key = self.provider_keys["openai"]
+ if api_key == "":
+ return ProxyServer(self.sdk.ide.unique_id, model)
+ return OpenAI(api_key=api_key, default_model=model)
+
+ def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI:
+ api_key = self.provider_keys["hf_inference_api"]
+ return HuggingFaceInferenceAPI(api_key=api_key, model=model)
@cached_property
def starcoder(self):
- async def load_starcoder():
- api_key = await self.sdk.get_user_secret(
- 'HUGGING_FACE_TOKEN', 'Please add your Hugging Face token to the .env file')
- return HuggingFaceInferenceAPI(api_key=api_key)
- return asyncio.get_event_loop().run_until_complete(load_starcoder())
+ return self.__load_hf_inference_api_model("bigcode/starcoder")
@cached_property
def gpt35(self):
@@ -80,7 +100,7 @@ class Models:
def default(self):
return self.ggml
default_model = self.sdk.config.default_model
- return self.__model_from_name(default_model) if default_model is not None else self.gpt35
+ return self.__model_from_name(default_model) if default_model is not None else self.gpt4
class ContinueSDK(AbstractContinueSDK):
@@ -93,8 +113,27 @@ class ContinueSDK(AbstractContinueSDK):
def __init__(self, autopilot: Autopilot):
self.ide = autopilot.ide
self.__autopilot = autopilot
- self.models = Models(self)
self.context = autopilot.context
+ self.config = self._load_config()
+
+ @classmethod
+ async def create(cls, autopilot: Autopilot) -> "ContinueSDK":
+ sdk = ContinueSDK(autopilot)
+ sdk.models = await Models.create(sdk)
+ return sdk
+
+ config: ContinueConfig
+
+ def _load_config(self) -> ContinueConfig:
+ dir = self.ide.workspace_directory
+ yaml_path = os.path.join(dir, '.continue', 'config.yaml')
+ json_path = os.path.join(dir, '.continue', 'config.json')
+ if os.path.exists(yaml_path):
+ return load_config(yaml_path)
+ elif os.path.exists(json_path):
+ return load_config(json_path)
+ else:
+ return load_global_config()
@property
def history(self) -> History:
@@ -172,18 +211,6 @@ class ContinueSDK(AbstractContinueSDK):
async def get_user_secret(self, env_var: str, prompt: str) -> str:
return await self.ide.getUserSecret(env_var)
- @property
- def config(self) -> ContinueConfig:
- dir = self.ide.workspace_directory
- yaml_path = os.path.join(dir, '.continue', 'config.yaml')
- json_path = os.path.join(dir, '.continue', 'config.json')
- if os.path.exists(yaml_path):
- return load_config(yaml_path)
- elif os.path.exists(json_path):
- return load_config(json_path)
- else:
- return load_global_config()
-
def get_code_context(self, only_editing: bool = False) -> List[RangeInFileWithContents]:
context = list(filter(lambda x: x.editing, self.__autopilot._highlighted_ranges)
) if only_editing else self.__autopilot._highlighted_ranges
@@ -208,14 +235,14 @@ class ContinueSDK(AbstractContinueSDK):
preface = "The following code is highlighted"
+ # If no higlighted ranges, use first file as context
if len(highlighted_code) == 0:
preface = "The following file is open"
- # Get the full contents of all open files
- files = await self.ide.getOpenFiles()
- if len(files) > 0:
- content = await self.ide.readFile(files[0])
+ visible_files = await self.ide.getVisibleFiles()
+ if len(visible_files) > 0:
+ content = await self.ide.readFile(visible_files[0])
highlighted_code = [
- RangeInFileWithContents.from_entire_file(files[0], content)]
+ RangeInFileWithContents.from_entire_file(visible_files[0], content)]
for rif in highlighted_code:
msg = ChatMessage(content=f"{preface} ({rif.filepath}):\n```\n{rif.contents}\n```",
diff --git a/continuedev/src/continuedev/libs/constants/main.py b/continuedev/src/continuedev/libs/constants/main.py
new file mode 100644
index 00000000..96eb6e69
--- /dev/null
+++ b/continuedev/src/continuedev/libs/constants/main.py
@@ -0,0 +1,6 @@
+## PATHS ##
+
+CONTINUE_GLOBAL_FOLDER = ".continue"
+CONTINUE_SESSIONS_FOLDER = "sessions"
+CONTINUE_SERVER_FOLDER = "server"
+
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 1586c620..803ba122 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -9,7 +9,11 @@ DEFAULT_MAX_TIME = 120.
class HuggingFaceInferenceAPI(LLM):
api_key: str
- model: str = "bigcode/starcoder"
+ model: str
+
+ def __init__(self, api_key: str, model: str):
+ self.api_key = api_key
+ self.model = model
def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs):
"""Return the completion of the text with the given temperature."""
diff --git a/continuedev/src/continuedev/libs/util/create_async_task.py b/continuedev/src/continuedev/libs/util/create_async_task.py
new file mode 100644
index 00000000..354cea82
--- /dev/null
+++ b/continuedev/src/continuedev/libs/util/create_async_task.py
@@ -0,0 +1,24 @@
+from typing import Coroutine, Union
+import traceback
+from .telemetry import capture_event
+import asyncio
+import nest_asyncio
+nest_asyncio.apply()
+
+
+def create_async_task(coro: Coroutine, unique_id: Union[str, None] = None):
+ """asyncio.create_task and log errors by adding a callback"""
+ task = asyncio.create_task(coro)
+
+ def callback(future: asyncio.Future):
+ try:
+ future.result()
+ except Exception as e:
+ print("Exception caught from async task: ",
+ '\n'.join(traceback.format_exception(e)))
+ capture_event(unique_id or "None", "async_task_error", {
+ "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))
+ })
+
+ task.add_done_callback(callback)
+ return task
diff --git a/continuedev/src/continuedev/libs/util/errors.py b/continuedev/src/continuedev/libs/util/errors.py
new file mode 100644
index 00000000..46074cfc
--- /dev/null
+++ b/continuedev/src/continuedev/libs/util/errors.py
@@ -0,0 +1,2 @@
+class SessionNotFound(Exception):
+ pass
diff --git a/continuedev/src/continuedev/libs/util/paths.py b/continuedev/src/continuedev/libs/util/paths.py
new file mode 100644
index 00000000..fddef887
--- /dev/null
+++ b/continuedev/src/continuedev/libs/util/paths.py
@@ -0,0 +1,17 @@
+import os
+
+from ..constants.main import CONTINUE_SESSIONS_FOLDER, CONTINUE_GLOBAL_FOLDER, CONTINUE_SERVER_FOLDER
+
+def getGlobalFolderPath():
+ return os.path.join(os.path.expanduser("~"), CONTINUE_GLOBAL_FOLDER)
+
+
+
+def getSessionsFolderPath():
+ return os.path.join(getGlobalFolderPath(), CONTINUE_SESSIONS_FOLDER)
+
+def getServerFolderPath():
+ return os.path.join(getGlobalFolderPath(), CONTINUE_SERVER_FOLDER)
+
+def getSessionFilePath(session_id: str):
+ return os.path.join(getSessionsFolderPath(), f"{session_id}.json") \ No newline at end of file
diff --git a/continuedev/src/continuedev/libs/util/step_name_to_steps.py b/continuedev/src/continuedev/libs/util/step_name_to_steps.py
index d329e110..49056c81 100644
--- a/continuedev/src/continuedev/libs/util/step_name_to_steps.py
+++ b/continuedev/src/continuedev/libs/util/step_name_to_steps.py
@@ -13,6 +13,7 @@ from ...recipes.DeployPipelineAirflowRecipe.main import DeployPipelineAirflowRec
from ...steps.on_traceback import DefaultOnTracebackStep
from ...steps.clear_history import ClearHistoryStep
from ...steps.open_config import OpenConfigStep
+from ...steps.help import HelpStep
# This mapping is used to convert from string in ContinueConfig json to corresponding Step class.
# Used for example in slash_commands and steps_on_startup
@@ -28,7 +29,8 @@ step_name_to_step_class = {
"DeployPipelineAirflowRecipe": DeployPipelineAirflowRecipe,
"DefaultOnTracebackStep": DefaultOnTracebackStep,
"ClearHistoryStep": ClearHistoryStep,
- "OpenConfigStep": OpenConfigStep
+ "OpenConfigStep": OpenConfigStep,
+ "HelpStep": HelpStep,
}
diff --git a/continuedev/src/continuedev/recipes/TemplateRecipe/main.py b/continuedev/src/continuedev/recipes/TemplateRecipe/main.py
index 94675725..16132cfd 100644
--- a/continuedev/src/continuedev/recipes/TemplateRecipe/main.py
+++ b/continuedev/src/continuedev/recipes/TemplateRecipe/main.py
@@ -20,8 +20,8 @@ class TemplateRecipe(Step):
# The code executed when the recipe is run
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
- open_files = await sdk.ide.getOpenFiles()
+ visible_files = await sdk.ide.getVisibleFiles()
await sdk.edit_file(
- filename=open_files[0],
+ filename=visible_files[0],
prompt=f"Append a statement to print `Hello, {self.name}!` at the end of the file."
)
diff --git a/continuedev/src/continuedev/recipes/WritePytestsRecipe/main.py b/continuedev/src/continuedev/recipes/WritePytestsRecipe/main.py
index 6e1244b3..c7a65fa6 100644
--- a/continuedev/src/continuedev/recipes/WritePytestsRecipe/main.py
+++ b/continuedev/src/continuedev/recipes/WritePytestsRecipe/main.py
@@ -14,7 +14,7 @@ class WritePytestsRecipe(Step):
async def run(self, sdk: ContinueSDK):
if self.for_filepath is None:
- self.for_filepath = (await sdk.ide.getOpenFiles())[0]
+ self.for_filepath = (await sdk.ide.getVisibleFiles())[0]
filename = os.path.basename(self.for_filepath)
dirname = os.path.dirname(self.for_filepath)
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 8e9b1fb9..4201353e 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -1,15 +1,17 @@
+import asyncio
import json
from fastapi import Depends, Header, WebSocket, APIRouter
+from starlette.websockets import WebSocketState, WebSocketDisconnect
from typing import Any, List, Type, TypeVar, Union
from pydantic import BaseModel
+import traceback
from uvicorn.main import Server
from .session_manager import SessionManager, session_manager, Session
from .gui_protocol import AbstractGUIProtocolServer
from ..libs.util.queue import AsyncSubscriptionQueue
-import asyncio
-import nest_asyncio
-nest_asyncio.apply()
+from ..libs.util.telemetry import capture_event
+from ..libs.util.create_async_task import create_async_task
router = APIRouter(prefix="/gui", tags=["gui"])
@@ -30,12 +32,12 @@ class AppStatus:
Server.handle_exit = AppStatus.handle_exit
-def session(x_continue_session_id: str = Header("anonymous")) -> Session:
- return session_manager.get_session(x_continue_session_id)
+async def session(x_continue_session_id: str = Header("anonymous")) -> Session:
+ return await session_manager.get_session(x_continue_session_id)
-def websocket_session(session_id: str) -> Session:
- return session_manager.get_session(session_id)
+async def websocket_session(session_id: str) -> Session:
+ return await session_manager.get_session(session_id)
T = TypeVar("T", bound=BaseModel)
@@ -52,13 +54,19 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
self.session = session
async def _send_json(self, message_type: str, data: Any):
+ if self.websocket.application_state == WebSocketState.DISCONNECTED:
+ return
await self.websocket.send_json({
"messageType": message_type,
"data": data
})
- async def _receive_json(self, message_type: str) -> Any:
- return await self.sub_queue.get(message_type)
+ async def _receive_json(self, message_type: str, timeout: int = 5) -> Any:
+ try:
+ return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout)
+ except asyncio.TimeoutError:
+ raise Exception(
+ "GUI Protocol _receive_json timed out after 5 seconds")
async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T:
await self._send_json(message_type, data)
@@ -102,51 +110,60 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
def on_main_input(self, input: str):
# Do something with user input
- asyncio.create_task(self.session.autopilot.accept_user_input(input))
+ create_async_task(self.session.autopilot.accept_user_input(
+ input), self.session.autopilot.continue_sdk.ide.unique_id)
def on_reverse_to_index(self, index: int):
# Reverse the history to the given index
- asyncio.create_task(self.session.autopilot.reverse_to_index(index))
+ create_async_task(self.session.autopilot.reverse_to_index(
+ index), self.session.autopilot.continue_sdk.ide.unique_id)
def on_step_user_input(self, input: str, index: int):
- asyncio.create_task(
- self.session.autopilot.give_user_input(input, index))
+ create_async_task(
+ self.session.autopilot.give_user_input(input, index), self.session.autopilot.continue_sdk.ide.unique_id)
def on_refinement_input(self, input: str, index: int):
- asyncio.create_task(
- self.session.autopilot.accept_refinement_input(input, index))
+ create_async_task(
+ self.session.autopilot.accept_refinement_input(input, index), self.session.autopilot.continue_sdk.ide.unique_id)
def on_retry_at_index(self, index: int):
- asyncio.create_task(
- self.session.autopilot.retry_at_index(index))
+ create_async_task(
+ self.session.autopilot.retry_at_index(index), self.session.autopilot.continue_sdk.ide.unique_id)
def on_change_default_model(self, model: str):
- asyncio.create_task(self.session.autopilot.change_default_model(model))
+ create_async_task(self.session.autopilot.change_default_model(
+ model), self.session.autopilot.continue_sdk.ide.unique_id)
def on_clear_history(self):
- asyncio.create_task(self.session.autopilot.clear_history())
+ create_async_task(self.session.autopilot.clear_history(
+ ), self.session.autopilot.continue_sdk.ide.unique_id)
def on_delete_at_index(self, index: int):
- asyncio.create_task(self.session.autopilot.delete_at_index(index))
+ create_async_task(self.session.autopilot.delete_at_index(
+ index), self.session.autopilot.continue_sdk.ide.unique_id)
def on_delete_context_at_indices(self, indices: List[int]):
- asyncio.create_task(
- self.session.autopilot.delete_context_at_indices(indices)
+ create_async_task(
+ self.session.autopilot.delete_context_at_indices(
+ indices), self.session.autopilot.continue_sdk.ide.unique_id
)
def on_toggle_adding_highlighted_code(self):
- asyncio.create_task(
- self.session.autopilot.toggle_adding_highlighted_code()
+ create_async_task(
+ self.session.autopilot.toggle_adding_highlighted_code(
+ ), self.session.autopilot.continue_sdk.ide.unique_id
)
def on_set_editing_at_indices(self, indices: List[int]):
- asyncio.create_task(
- self.session.autopilot.set_editing_at_indices(indices)
+ create_async_task(
+ self.session.autopilot.set_editing_at_indices(
+ indices), self.session.autopilot.continue_sdk.ide.unique_id
)
def on_set_pinned_at_indices(self, indices: List[int]):
- asyncio.create_task(
- self.session.autopilot.set_pinned_at_indices(indices)
+ create_async_task(
+ self.session.autopilot.set_pinned_at_indices(
+ indices), self.session.autopilot.continue_sdk.ide.unique_id
)
@@ -176,11 +193,17 @@ async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(we
data = message["data"]
protocol.handle_json(message_type, data)
-
+ except WebSocketDisconnect as e:
+ print("GUI websocket disconnected")
except Exception as e:
print("ERROR in gui websocket: ", e)
+ capture_event(session.autopilot.continue_sdk.ide.unique_id, "gui_error", {
+ "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))})
raise e
finally:
print("Closing gui websocket")
- await websocket.close()
+ if websocket.client_state != WebSocketState.DISCONNECTED:
+ await websocket.close()
+
+ session_manager.persist_session(session.session_id)
session_manager.remove_session(session.session_id)
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index e4a6266a..a91708ec 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -5,7 +5,9 @@ import os
from typing import Any, Dict, List, Type, TypeVar, Union
import uuid
from fastapi import WebSocket, Body, APIRouter
+from starlette.websockets import WebSocketState, WebSocketDisconnect
from uvicorn.main import Server
+import traceback
from ..libs.util.telemetry import capture_event
from ..libs.util.queue import AsyncSubscriptionQueue
@@ -15,6 +17,7 @@ from pydantic import BaseModel
from .gui import SessionManager, session_manager
from .ide_protocol import AbstractIdeProtocolServer
import asyncio
+from ..libs.util.create_async_task import create_async_task
import nest_asyncio
nest_asyncio.apply()
@@ -50,6 +53,10 @@ class OpenFilesResponse(BaseModel):
openFiles: List[str]
+class VisibleFilesResponse(BaseModel):
+ visibleFiles: List[str]
+
+
class HighlightedCodeResponse(BaseModel):
highlightedCode: List[RangeInFile]
@@ -110,19 +117,52 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
websocket: WebSocket
session_manager: SessionManager
sub_queue: AsyncSubscriptionQueue = AsyncSubscriptionQueue()
+ session_id: Union[str, None] = None
def __init__(self, session_manager: SessionManager, websocket: WebSocket):
self.websocket = websocket
self.session_manager = session_manager
+ workspace_directory: str = None
+ unique_id: str = None
+
+ async def initialize(self, session_id: str) -> List[str]:
+ self.session_id = session_id
+ await self._send_json("workspaceDirectory", {})
+ await self._send_json("uniqueId", {})
+ other_msgs = []
+ while True:
+ msg_string = await self.websocket.receive_text()
+ message = json.loads(msg_string)
+ if "messageType" not in message or "data" not in message:
+ continue
+ message_type = message["messageType"]
+ data = message["data"]
+ if message_type == "workspaceDirectory":
+ self.workspace_directory = data["workspaceDirectory"]
+ elif message_type == "uniqueId":
+ self.unique_id = data["uniqueId"]
+ else:
+ other_msgs.append(msg_string)
+
+ if self.workspace_directory is not None and self.unique_id is not None:
+ break
+ return other_msgs
+
async def _send_json(self, message_type: str, data: Any):
+ if self.websocket.application_state == WebSocketState.DISCONNECTED:
+ return
await self.websocket.send_json({
"messageType": message_type,
"data": data
})
- async def _receive_json(self, message_type: str) -> Any:
- return await self.sub_queue.get(message_type)
+ async def _receive_json(self, message_type: str, timeout: int = 5) -> Any:
+ try:
+ return await asyncio.wait_for(self.sub_queue.get(message_type), timeout=timeout)
+ except asyncio.TimeoutError:
+ raise Exception(
+ "IDE Protocol _receive_json timed out after 5 seconds")
async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T:
await self._send_json(message_type, data)
@@ -130,8 +170,8 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
return resp_model.parse_obj(resp)
async def handle_json(self, message_type: str, data: Any):
- if message_type == "openGUI":
- await self.openGUI()
+ if message_type == "getSessionId":
+ await self.getSessionId()
elif message_type == "setFileOpen":
await self.setFileOpen(data["filepath"], data["open"])
elif message_type == "setSuggestionsLocked":
@@ -154,8 +194,12 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
self.onMainUserInput(data["input"])
elif message_type == "deleteAtIndex":
self.onDeleteAtIndex(data["index"])
- elif message_type in ["highlightedCode", "openFiles", "readFile", "editFile", "workspaceDirectory", "getUserSecret", "runCommand", "uniqueId"]:
+ elif message_type in ["highlightedCode", "openFiles", "visibleFiles", "readFile", "editFile", "getUserSecret", "runCommand"]:
self.sub_queue.post(message_type, data)
+ elif message_type == "workspaceDirectory":
+ self.workspace_directory = data["workspaceDirectory"]
+ elif message_type == "uniqueId":
+ self.unique_id = data["uniqueId"]
else:
raise ValueError("Unknown message type", message_type)
@@ -187,9 +231,10 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
"locked": locked
})
- async def openGUI(self):
- session_id = self.session_manager.new_session(self)
- await self._send_json("openGUI", {
+ async def getSessionId(self):
+ session_id = (await self.session_manager.new_session(
+ self, self.session_id)).session_id
+ await self._send_json("getSessionId", {
"sessionId": session_id
})
@@ -242,53 +287,40 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
def onOpenGUIRequest(self):
pass
+ def __get_autopilot(self):
+ return self.session_manager.sessions[self.session_id].autopilot
+
def onFileEdits(self, edits: List[FileEditWithFullContents]):
- # Send the file edits to ALL autopilots.
- # Maybe not ideal behavior
- for _, session in self.session_manager.sessions.items():
- session.autopilot.handle_manual_edits(edits)
+ if autopilot := self.__get_autopilot():
+ autopilot.handle_manual_edits(edits)
def onDeleteAtIndex(self, index: int):
- for _, session in self.session_manager.sessions.items():
- asyncio.create_task(session.autopilot.delete_at_index(index))
+ if autopilot := self.__get_autopilot():
+ create_async_task(autopilot.delete_at_index(index), self.unique_id)
def onCommandOutput(self, output: str):
- # Send the output to ALL autopilots.
- # Maybe not ideal behavior
- for _, session in self.session_manager.sessions.items():
- asyncio.create_task(
- session.autopilot.handle_command_output(output))
+ if autopilot := self.__get_autopilot():
+ create_async_task(
+ autopilot.handle_command_output(output), self.unique_id)
def onHighlightedCodeUpdate(self, range_in_files: List[RangeInFileWithContents]):
- for _, session in self.session_manager.sessions.items():
- asyncio.create_task(
- session.autopilot.handle_highlighted_code(range_in_files))
+ if autopilot := self.__get_autopilot():
+ create_async_task(autopilot.handle_highlighted_code(
+ range_in_files), self.unique_id)
def onMainUserInput(self, input: str):
- for _, session in self.session_manager.sessions.items():
- asyncio.create_task(
- session.autopilot.accept_user_input(input))
+ if autopilot := self.__get_autopilot():
+ create_async_task(
+ autopilot.accept_user_input(input), self.unique_id)
# Request information. Session doesn't matter.
async def getOpenFiles(self) -> List[str]:
resp = await self._send_and_receive_json({}, OpenFilesResponse, "openFiles")
return resp.openFiles
- async def getWorkspaceDirectory(self) -> str:
- resp = await self._send_and_receive_json({}, WorkspaceDirectoryResponse, "workspaceDirectory")
- return resp.workspaceDirectory
-
- async def get_unique_id(self) -> str:
- resp = await self._send_and_receive_json({}, UniqueIdResponse, "uniqueId")
- return resp.uniqueId
-
- @property
- def workspace_directory(self) -> str:
- return asyncio.run(self.getWorkspaceDirectory())
-
- @cached_property_no_none
- def unique_id(self) -> str:
- return asyncio.run(self.get_unique_id())
+ async def getVisibleFiles(self) -> List[str]:
+ resp = await self._send_and_receive_json({}, VisibleFilesResponse, "visibleFiles")
+ return resp.visibleFiles
async def getHighlightedCode(self) -> List[RangeInFile]:
resp = await self._send_and_receive_json({}, HighlightedCodeResponse, "highlightedCode")
@@ -389,28 +421,45 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
@router.websocket("/ws")
-async def websocket_endpoint(websocket: WebSocket):
+async def websocket_endpoint(websocket: WebSocket, session_id: str = None):
try:
await websocket.accept()
print("Accepted websocket connection from, ", websocket.client)
await websocket.send_json({"messageType": "connected", "data": {}})
- ideProtocolServer = IdeProtocolServer(session_manager, websocket)
-
- while AppStatus.should_exit is False:
- message = await websocket.receive_text()
- message = json.loads(message)
+ def handle_msg(msg):
+ message = json.loads(msg)
if "messageType" not in message or "data" not in message:
- continue
+ return
message_type = message["messageType"]
data = message["data"]
- await ideProtocolServer.handle_json(message_type, data)
+ create_async_task(
+ ideProtocolServer.handle_json(message_type, data))
+
+ ideProtocolServer = IdeProtocolServer(session_manager, websocket)
+ if session_id is not None:
+ session_manager.registered_ides[session_id] = ideProtocolServer
+ other_msgs = await ideProtocolServer.initialize(session_id)
+
+ for other_msg in other_msgs:
+ handle_msg(other_msg)
+
+ while AppStatus.should_exit is False:
+ message = await websocket.receive_text()
+ handle_msg(message)
print("Closing ide websocket")
- await websocket.close()
+ except WebSocketDisconnect as e:
+ print("IDE wbsocket disconnected")
except Exception as e:
print("Error in ide websocket: ", e)
- await websocket.close()
+ capture_event(ideProtocolServer.unique_id, "gui_error", {
+ "error_title": e.__str__() or e.__repr__(), "error_message": '\n'.join(traceback.format_exception(e))})
raise e
+ finally:
+ if websocket.client_state != WebSocketState.DISCONNECTED:
+ await websocket.close()
+
+ session_manager.registered_ides.pop(ideProtocolServer.session_id)
diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py
index dfdca504..d0fb0bf8 100644
--- a/continuedev/src/continuedev/server/ide_protocol.py
+++ b/continuedev/src/continuedev/server/ide_protocol.py
@@ -1,5 +1,6 @@
-from typing import Any, List
+from typing import Any, List, Union
from abc import ABC, abstractmethod, abstractproperty
+from fastapi import WebSocket
from ..models.main import Traceback
from ..models.filesystem_edit import FileEdit, FileSystemEdit, EditDiff
@@ -7,6 +8,9 @@ from ..models.filesystem import RangeInFile, RangeInFileWithContents
class AbstractIdeProtocolServer(ABC):
+ websocket: WebSocket
+ session_id: Union[str, None]
+
@abstractmethod
async def handle_json(self, data: Any):
"""Handle a json message"""
@@ -16,10 +20,6 @@ class AbstractIdeProtocolServer(ABC):
"""Show a suggestion to the user"""
@abstractmethod
- async def getWorkspaceDirectory(self):
- """Get the workspace directory"""
-
- @abstractmethod
async def setFileOpen(self, filepath: str, open: bool = True):
"""Set whether a file is open"""
@@ -28,8 +28,8 @@ class AbstractIdeProtocolServer(ABC):
"""Set whether suggestions are locked"""
@abstractmethod
- async def openGUI(self):
- """Open a GUI"""
+ async def getSessionId(self):
+ """Get a new session ID"""
@abstractmethod
async def showSuggestionsAndWait(self, suggestions: List[FileEdit]) -> bool:
@@ -56,6 +56,10 @@ class AbstractIdeProtocolServer(ABC):
"""Get a list of open files"""
@abstractmethod
+ async def getVisibleFiles(self) -> List[str]:
+ """Get a list of visible files"""
+
+ @abstractmethod
async def getHighlightedCode(self) -> List[RangeInFile]:
"""Get a list of highlighted code"""
@@ -103,10 +107,5 @@ class AbstractIdeProtocolServer(ABC):
async def showDiff(self, filepath: str, replacement: str, step_index: int):
"""Show a diff"""
- @abstractproperty
- def workspace_directory(self) -> str:
- """Get the workspace directory"""
-
- @abstractproperty
- def unique_id(self) -> str:
- """Get a unique ID for this IDE"""
+ workspace_directory: str
+ unique_id: str
diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py
index f4d82903..aa093853 100644
--- a/continuedev/src/continuedev/server/main.py
+++ b/continuedev/src/continuedev/server/main.py
@@ -4,7 +4,8 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from .ide import router as ide_router
from .gui import router as gui_router
-import logging
+from .session_manager import session_manager
+import atexit
import uvicorn
import argparse
@@ -44,5 +45,16 @@ def run_server():
uvicorn.run(app, host="0.0.0.0", port=args.port)
+def cleanup():
+ print("Cleaning up sessions")
+ for session_id in session_manager.sessions:
+ session_manager.persist_session(session_id)
+
+
+atexit.register(cleanup)
if __name__ == "__main__":
- run_server()
+ try:
+ run_server()
+ except Exception as e:
+ cleanup()
+ raise e
diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py
index 99a38146..6d109ca6 100644
--- a/continuedev/src/continuedev/server/session_manager.py
+++ b/continuedev/src/continuedev/server/session_manager.py
@@ -1,20 +1,24 @@
+import os
from fastapi import WebSocket
from typing import Any, Dict, List, Union
from uuid import uuid4
+import json
+from ..libs.util.paths import getSessionFilePath, getSessionsFolderPath
from ..models.filesystem_edit import FileEditWithFullContents
+from ..libs.constants.main import CONTINUE_SESSIONS_FOLDER
from ..core.policy import DemoPolicy
from ..core.main import FullState
from ..core.autopilot import Autopilot
from .ide_protocol import AbstractIdeProtocolServer
-import asyncio
-import nest_asyncio
-nest_asyncio.apply()
+from ..libs.util.create_async_task import create_async_task
+from ..libs.util.errors import SessionNotFound
class Session:
session_id: str
autopilot: Autopilot
+ # The GUI websocket for the session
ws: Union[WebSocket, None]
def __init__(self, session_id: str, autopilot: Autopilot):
@@ -38,18 +42,35 @@ class DemoAutopilot(Autopilot):
class SessionManager:
sessions: Dict[str, Session] = {}
- _event_loop: Union[asyncio.BaseEventLoop, None] = None
+ # Mapping of session_id to IDE, where the IDE is still alive
+ registered_ides: Dict[str, AbstractIdeProtocolServer] = {}
- def get_session(self, session_id: str) -> Session:
+ async def get_session(self, session_id: str) -> Session:
if session_id not in self.sessions:
+ # Check then whether it is persisted by listing all files in the sessions folder
+ # And only if the IDE is still alive
+ sessions_folder = getSessionsFolderPath()
+ session_files = os.listdir(sessions_folder)
+ if f"{session_id}.json" in session_files and session_id in self.registered_ides:
+ if self.registered_ides[session_id].session_id is not None:
+ return await self.new_session(self.registered_ides[session_id], session_id=session_id)
+
raise KeyError("Session ID not recognized", session_id)
return self.sessions[session_id]
- def new_session(self, ide: AbstractIdeProtocolServer) -> str:
- autopilot = DemoAutopilot(policy=DemoPolicy(), ide=ide)
- session_id = str(uuid4())
+ async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session:
+ full_state = None
+ if session_id is not None and os.path.exists(getSessionFilePath(session_id)):
+ with open(getSessionFilePath(session_id), "r") as f:
+ full_state = FullState(**json.load(f))
+
+ autopilot = await DemoAutopilot.create(
+ policy=DemoPolicy(), ide=ide, full_state=full_state)
+ session_id = session_id or str(uuid4())
+ ide.session_id = session_id
session = Session(session_id=session_id, autopilot=autopilot)
self.sessions[session_id] = session
+ self.registered_ides[session_id] = ide
async def on_update(state: FullState):
await session_manager.send_ws_data(session_id, "state_update", {
@@ -57,17 +78,27 @@ class SessionManager:
})
autopilot.on_update(on_update)
- asyncio.create_task(autopilot.run_policy())
- return session_id
+ create_async_task(autopilot.run_policy())
+ return session
def remove_session(self, session_id: str):
del self.sessions[session_id]
+ def persist_session(self, session_id: str):
+ """Save the session's FullState as a json file"""
+ full_state = self.sessions[session_id].autopilot.get_full_state()
+ if not os.path.exists(getSessionsFolderPath()):
+ os.mkdir(getSessionsFolderPath())
+ with open(getSessionFilePath(session_id), "w") as f:
+ json.dump(full_state.dict(), f)
+
def register_websocket(self, session_id: str, ws: WebSocket):
self.sessions[session_id].ws = ws
print("Registered websocket for session", session_id)
async def send_ws_data(self, session_id: str, message_type: str, data: Any):
+ if session_id not in self.sessions:
+ raise SessionNotFound(f"Session {session_id} not found")
if self.sessions[session_id].ws is None:
print(f"Session {session_id} has no websocket")
return
diff --git a/continuedev/src/continuedev/server/state_manager.py b/continuedev/src/continuedev/server/state_manager.py
deleted file mode 100644
index c9bd760b..00000000
--- a/continuedev/src/continuedev/server/state_manager.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from typing import Any, List, Tuple, Union
-from fastapi import WebSocket
-from pydantic import BaseModel
-from ..core.main import FullState
-
-# State updates represented as (path, replacement) pairs
-StateUpdate = Tuple[List[Union[str, int]], Any]
-
-
-class StateManager:
- """
- A class that acts as the source of truth for state, ingesting changes to the entire object and streaming only the updated portions to client.
- """
-
- def __init__(self, ws: WebSocket):
- self.ws = ws
-
- def _send_update(self, updates: List[StateUpdate]):
- self.ws.send_json(
- [update.dict() for update in updates]
- )
diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py
index 1df1e0bf..3751dec2 100644
--- a/continuedev/src/continuedev/steps/chat.py
+++ b/continuedev/src/continuedev/steps/chat.py
@@ -27,14 +27,21 @@ class SimpleChatStep(Step):
async def run(self, sdk: ContinueSDK):
completion = ""
messages = self.messages or await sdk.get_chat_context()
- async for chunk in sdk.models.default.stream_chat(messages, temperature=0.5):
- if sdk.current_step_was_deleted():
- return
- if "content" in chunk:
- self.description += chunk["content"]
- completion += chunk["content"]
- await sdk.update_ui()
+ generator = sdk.models.default.stream_chat(messages, temperature=0.5)
+ try:
+ async for chunk in generator:
+ if sdk.current_step_was_deleted():
+ # So that the message doesn't disappear
+ self.hide = False
+ return
+
+ if "content" in chunk:
+ self.description += chunk["content"]
+ completion += chunk["content"]
+ await sdk.update_ui()
+ finally:
+ await generator.aclose()
self.name = (await sdk.models.gpt35.complete(
f"Write a short title for the following chat message: {self.description}")).strip()
diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py
index 0b067d7d..2c9d8c01 100644
--- a/continuedev/src/continuedev/steps/core/core.py
+++ b/continuedev/src/continuedev/steps/core/core.py
@@ -10,7 +10,7 @@ from ...libs.llm.prompt_utils import MarkdownStyleEncoderDecoder
from ...models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullContents, FileSystemEdit
from ...models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents
from ...core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation
-from ...core.main import ChatMessage, Step, SequentialStep
+from ...core.main import ChatMessage, ContinueCustomException, Step, SequentialStep
from ...libs.util.count_tokens import MAX_TOKENS_FOR_MODEL, DEFAULT_MAX_TOKENS
from ...libs.util.dedent import dedent_and_get_common_whitespace
import difflib
@@ -153,7 +153,8 @@ class DefaultModelEditCodeStep(Step):
Main task:
""")
-
+ _previous_contents: str = ""
+ _new_contents: str = ""
_prompt_and_completion: str = ""
def _cleanup_output(self, output: str) -> str:
@@ -168,10 +169,19 @@ class DefaultModelEditCodeStep(Step):
return output
async def describe(self, models: Models) -> Coroutine[str, None, None]:
- description = await models.gpt3516k.complete(dedent(f"""\
- {self._prompt_and_completion}
-
- Please give brief a description of the changes made above using markdown bullet points. Be concise and only mention changes made to the commit before, not prefix or suffix:"""))
+ if self._previous_contents.strip() == self._new_contents.strip():
+ description = "No edits were made"
+ else:
+ description = await models.gpt3516k.complete(dedent(f"""\
+ ```original
+ {self._previous_contents}
+ ```
+
+ ```new
+ {self._new_contents}
+ ```
+
+ Please give brief a description of the changes made above using markdown bullet points. Be concise:"""))
name = await models.gpt3516k.complete(f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. This is the title:")
self.name = self._cleanup_output(name)
@@ -263,6 +273,23 @@ class DefaultModelEditCodeStep(Step):
return file_prefix, rif.contents, file_suffix, model_to_use, max_tokens
def compile_prompt(self, file_prefix: str, contents: str, file_suffix: str, sdk: ContinueSDK) -> str:
+ if contents.strip() == "":
+ # Seperate prompt for insertion at the cursor, the other tends to cause it to repeat whole file
+ prompt = dedent(f"""\
+<file_prefix>
+{file_prefix}
+</file_prefix>
+<insertion_code_here>
+<file_suffix>
+{file_suffix}
+</file_suffix>
+<user_request>
+{self.user_input}
+</user_request>
+
+Please output the code to be inserted at the cursor in order to fulfill the user_request. Do NOT preface your answer or write anything other than code. You should not write any tags, just the code. Make sure to correctly indent the code:""")
+ return prompt
+
prompt = self._prompt
if file_prefix.strip() != "":
prompt += dedent(f"""
@@ -304,15 +331,32 @@ class DefaultModelEditCodeStep(Step):
prompt = self.compile_prompt(file_prefix, contents, file_suffix, sdk)
full_file_contents_lines = full_file_contents.split("\n")
- async def sendDiffUpdate(lines: List[str], sdk: ContinueSDK):
- nonlocal full_file_contents_lines, rif
+ lines_to_display = []
+
+ async def sendDiffUpdate(lines: List[str], sdk: ContinueSDK, final: bool = False):
+ nonlocal full_file_contents_lines, rif, lines_to_display
completion = "\n".join(lines)
full_prefix_lines = full_file_contents_lines[:rif.range.start.line]
full_suffix_lines = full_file_contents_lines[rif.range.end.line:]
+
+ # Don't do this at the very end, just show the inserted code
+ if final:
+ lines_to_display = []
+ # Only recalculate at every new-line, because this is sort of expensive
+ elif completion.endswith("\n"):
+ contents_lines = rif.contents.split("\n")
+ rewritten_lines = 0
+ for line in lines:
+ for i in range(rewritten_lines, len(contents_lines)):
+ if difflib.SequenceMatcher(None, line, contents_lines[i]).ratio() > 0.7 and contents_lines[i].strip() != "":
+ rewritten_lines = i + 1
+ break
+ lines_to_display = contents_lines[rewritten_lines:]
+
new_file_contents = "\n".join(
- full_prefix_lines) + "\n" + completion + "\n" + "\n".join(full_suffix_lines)
+ full_prefix_lines) + "\n" + completion + "\n" + ("\n".join(lines_to_display) + "\n" if len(lines_to_display) > 0 else "") + "\n".join(full_suffix_lines)
step_index = sdk.history.current_index
@@ -431,6 +475,14 @@ class DefaultModelEditCodeStep(Step):
current_block_lines.append(line)
messages = await sdk.get_chat_context()
+ # Delete the last user and assistant messages
+ i = len(messages) - 1
+ deleted = 0
+ while i >= 0 and deleted < 2:
+ if messages[i].role == "user" or messages[i].role == "assistant":
+ messages.pop(i)
+ deleted += 1
+ i -= 1
messages.append(ChatMessage(
role="user",
content=prompt,
@@ -448,58 +500,63 @@ class DefaultModelEditCodeStep(Step):
messages = [ChatMessage(
role="user", content=f"```\n{rif.contents}\n```\n{self.user_input}\n```\n", summary=self.user_input)]
- async for chunk in model_to_use.stream_chat(messages, temperature=0, max_tokens=max_tokens):
- # Stop early if it is repeating the file_suffix or the step was deleted
- if repeating_file_suffix:
- break
- if sdk.current_step_was_deleted():
- return
+ generator = model_to_use.stream_chat(
+ messages, temperature=0, max_tokens=max_tokens)
- # Accumulate lines
- if "content" not in chunk:
- continue
- chunk = chunk["content"]
- chunk_lines = chunk.split("\n")
- chunk_lines[0] = unfinished_line + chunk_lines[0]
- if chunk.endswith("\n"):
- unfinished_line = ""
- chunk_lines.pop() # because this will be an empty string
- else:
- unfinished_line = chunk_lines.pop()
-
- # Deal with newly accumulated lines
- for i in range(len(chunk_lines)):
- # Trailing whitespace doesn't matter
- chunk_lines[i] = chunk_lines[i].rstrip()
- chunk_lines[i] = common_whitespace + chunk_lines[i]
-
- # Lines that should signify the end of generation
- if self.is_end_line(chunk_lines[i]):
- break
- # Lines that should be ignored, like the <> tags
- elif self.line_to_be_ignored(chunk_lines[i], completion_lines_covered == 0):
- continue
- # Check if we are currently just copying the prefix
- elif (lines_of_prefix_copied > 0 or completion_lines_covered == 0) and lines_of_prefix_copied < len(file_prefix.splitlines()) and chunk_lines[i] == full_file_contents_lines[lines_of_prefix_copied]:
- # This is a sketchy way of stopping it from repeating the file_prefix. Is a bug if output happens to have a matching line
- lines_of_prefix_copied += 1
- continue
- # Because really short lines might be expected to be repeated, this is only a !heuristic!
- # Stop when it starts copying the file_suffix
- elif chunk_lines[i].strip() == line_below_highlighted_range.strip() and len(chunk_lines[i].strip()) > 4 and not (len(original_lines_below_previous_blocks) > 0 and chunk_lines[i].strip() == original_lines_below_previous_blocks[0].strip()):
- repeating_file_suffix = True
+ try:
+ async for chunk in generator:
+ # Stop early if it is repeating the file_suffix or the step was deleted
+ if repeating_file_suffix:
break
+ if sdk.current_step_was_deleted():
+ return
- # If none of the above, insert the line!
- if False:
- await handle_generated_line(chunk_lines[i])
+ # Accumulate lines
+ if "content" not in chunk:
+ continue
+ chunk = chunk["content"]
+ chunk_lines = chunk.split("\n")
+ chunk_lines[0] = unfinished_line + chunk_lines[0]
+ if chunk.endswith("\n"):
+ unfinished_line = ""
+ chunk_lines.pop() # because this will be an empty string
+ else:
+ unfinished_line = chunk_lines.pop()
+
+ # Deal with newly accumulated lines
+ for i in range(len(chunk_lines)):
+ # Trailing whitespace doesn't matter
+ chunk_lines[i] = chunk_lines[i].rstrip()
+ chunk_lines[i] = common_whitespace + chunk_lines[i]
+
+ # Lines that should signify the end of generation
+ if self.is_end_line(chunk_lines[i]):
+ break
+ # Lines that should be ignored, like the <> tags
+ elif self.line_to_be_ignored(chunk_lines[i], completion_lines_covered == 0):
+ continue
+ # Check if we are currently just copying the prefix
+ elif (lines_of_prefix_copied > 0 or completion_lines_covered == 0) and lines_of_prefix_copied < len(file_prefix.splitlines()) and chunk_lines[i] == full_file_contents_lines[lines_of_prefix_copied]:
+ # This is a sketchy way of stopping it from repeating the file_prefix. Is a bug if output happens to have a matching line
+ lines_of_prefix_copied += 1
+ continue
+ # Because really short lines might be expected to be repeated, this is only a !heuristic!
+ # Stop when it starts copying the file_suffix
+ elif chunk_lines[i].strip() == line_below_highlighted_range.strip() and len(chunk_lines[i].strip()) > 4 and not (len(original_lines_below_previous_blocks) > 0 and chunk_lines[i].strip() == original_lines_below_previous_blocks[0].strip()):
+ repeating_file_suffix = True
+ break
- lines.append(chunk_lines[i])
- completion_lines_covered += 1
- current_line_in_file += 1
+ # If none of the above, insert the line!
+ if False:
+ await handle_generated_line(chunk_lines[i])
- await sendDiffUpdate(lines + [common_whitespace + unfinished_line], sdk)
+ lines.append(chunk_lines[i])
+ completion_lines_covered += 1
+ current_line_in_file += 1
+ await sendDiffUpdate(lines + [common_whitespace if unfinished_line.startswith("<") else (common_whitespace + unfinished_line)], sdk)
+ finally:
+ await generator.aclose()
# Add the unfinished line
if unfinished_line != "" and not self.line_to_be_ignored(unfinished_line, completion_lines_covered == 0) and not self.is_end_line(unfinished_line):
unfinished_line = common_whitespace + unfinished_line
@@ -508,7 +565,7 @@ class DefaultModelEditCodeStep(Step):
completion_lines_covered += 1
current_line_in_file += 1
- await sendDiffUpdate(lines, sdk)
+ await sendDiffUpdate(lines, sdk, final=True)
if False:
# If the current block isn't empty, add that suggestion
@@ -542,6 +599,8 @@ class DefaultModelEditCodeStep(Step):
# Record the completion
completion = "\n".join(lines)
+ self._previous_contents = "\n".join(original_lines)
+ self._new_contents = completion
self._prompt_and_completion += prompt + completion
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
@@ -562,6 +621,13 @@ class DefaultModelEditCodeStep(Step):
rif_dict[rif.filepath] = rif.contents
for rif in rif_with_contents:
+ # If the file doesn't exist, ask them to save it first
+ if not os.path.exists(rif.filepath):
+ message = f"The file {rif.filepath} does not exist. Please save it first."
+ raise ContinueCustomException(
+ title=message, message=message
+ )
+
await sdk.ide.setFileOpen(rif.filepath)
await sdk.ide.setSuggestionsLocked(rif.filepath, True)
await self.stream_rif(rif, sdk)
diff --git a/continuedev/src/continuedev/steps/help.py b/continuedev/src/continuedev/steps/help.py
new file mode 100644
index 00000000..ba1e6087
--- /dev/null
+++ b/continuedev/src/continuedev/steps/help.py
@@ -0,0 +1,59 @@
+from textwrap import dedent
+from ..core.main import ChatMessage, Step
+from ..core.sdk import ContinueSDK
+from ..libs.util.telemetry import capture_event
+
+help = dedent("""\
+ Continue is an open-source coding autopilot. It is a VS Code extension that brings the power of ChatGPT to your IDE.
+
+ It gathers context for you and stores your interactions automatically, so that you can avoid copy/paste now and benefit from a customized Large Language Model (LLM) later.
+
+ Continue can be used to...
+ 1. Edit chunks of code with specific instructions (e.g. "/edit migrate this digital ocean terraform file into one that works for GCP")
+ 2. Get answers to questions without switching windows (e.g. "how do I find running process on port 8000?")
+ 3. Generate files from scratch (e.g. "/edit Create a Python CLI tool that uses the posthog api to get events from DAUs")
+
+ You tell Continue to edit a specific section of code by highlighting it. If you highlight multiple code sections, then it will only edit the one with the purple glow around it. You can switch which one has the purple glow by clicking the paint brush.
+
+ If you don't highlight any code, then Continue will insert at the location of your cursor.
+
+ Continue passes all of the sections of code you highlight, the code above and below the to-be edited highlighted code section, and all previous steps above input box as context to the LLM.
+
+ You can use cmd+m (Mac) / ctrl+m (Windows) to open Continue. You can use cmd+shift+e / ctrl+shift+e to open file Explorer. You can add your own OpenAI API key to VS Code Settings with `cmd+,`
+
+ If Continue is stuck loading, try using `cmd+shift+p` to open the command palette, search "Reload Window", and then select it. This will reload VS Code and Continue and often fixes issues.
+
+ If you have feedback, please use /feedback to let us know how you would like to use Continue. We are excited to hear from you!""")
+
+
+class HelpStep(Step):
+
+ name: str = "Help"
+ user_input: str
+ manage_own_chat_context: bool = True
+ description: str = ""
+
+ async def run(self, sdk: ContinueSDK):
+
+ question = self.user_input
+
+ prompt = dedent(f"""Please us the information below to provide a succinct answer to the following quesiton: {question}
+
+ Information:
+
+ {help}""")
+
+ self.chat_context.append(ChatMessage(
+ role="user",
+ content=prompt,
+ summary="Help"
+ ))
+ messages = await sdk.get_chat_context()
+ generator = sdk.models.gpt4.stream_chat(messages)
+ async for chunk in generator:
+ if "content" in chunk:
+ self.description += chunk["content"]
+ await sdk.update_ui()
+
+ capture_event(sdk.ide.unique_id, "help", {
+ "question": question, "answer": self.description})
diff --git a/continuedev/src/continuedev/steps/main.py b/continuedev/src/continuedev/steps/main.py
index 4f543022..ce7cbc60 100644
--- a/continuedev/src/continuedev/steps/main.py
+++ b/continuedev/src/continuedev/steps/main.py
@@ -10,7 +10,7 @@ from ..models.filesystem import RangeInFile, RangeInFileWithContents
from ..core.observation import Observation, TextObservation, TracebackObservation
from ..libs.llm.prompt_utils import MarkdownStyleEncoderDecoder
from textwrap import dedent
-from ..core.main import Step
+from ..core.main import ContinueCustomException, Step
from ..core.sdk import ContinueSDK, Models
from ..core.observation import Observation
import subprocess
@@ -99,8 +99,8 @@ class FasterEditHighlightedCodeStep(Step):
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
range_in_files = await sdk.get_code_context(only_editing=True)
if len(range_in_files) == 0:
- # Get the full contents of all open files
- files = await sdk.ide.getOpenFiles()
+ # Get the full contents of all visible files
+ files = await sdk.ide.getVisibleFiles()
contents = {}
for file in files:
contents[file] = await sdk.ide.readFile(file)
@@ -191,8 +191,8 @@ class StarCoderEditHighlightedCodeStep(Step):
range_in_files = await sdk.get_code_context(only_editing=True)
found_highlighted_code = len(range_in_files) > 0
if not found_highlighted_code:
- # Get the full contents of all open files
- files = await sdk.ide.getOpenFiles()
+ # Get the full contents of all visible files
+ files = await sdk.ide.getVisibleFiles()
contents = {}
for file in files:
contents[file] = await sdk.ide.readFile(file)
@@ -251,44 +251,28 @@ class EditHighlightedCodeStep(Step):
highlighted_code = await sdk.ide.getHighlightedCode()
if highlighted_code is not None:
for rif in highlighted_code:
+ if os.path.dirname(rif.filepath) == os.path.expanduser(os.path.join("~", ".continue", "diffs")):
+ raise ContinueCustomException(
+ message="Please accept or reject the change before making another edit in this file.", title="Accept/Reject First")
if rif.range.start == rif.range.end:
range_in_files.append(
RangeInFileWithContents.from_range_in_file(rif, ""))
- # If nothing highlighted, edit the first open file
+ # If still no highlighted code, raise error
if len(range_in_files) == 0:
- # Get the full contents of all open files
- files = await sdk.ide.getOpenFiles()
- contents = {}
- for file in files:
- contents[file] = await sdk.ide.readFile(file)
-
- range_in_files = [RangeInFileWithContents.from_entire_file(
- filepath, content) for filepath, content in contents.items()]
-
- # If still no highlighted code, create a new file and edit there
- if len(range_in_files) == 0:
- # Create a new file
- new_file_path = "new_file.txt"
- await sdk.add_file(new_file_path, "")
- range_in_files = [
- RangeInFileWithContents.from_entire_file(new_file_path, "")]
+ raise ContinueCustomException(
+ message="Please highlight some code and try again.", title="No Code Selected")
range_in_files = list(map(lambda x: RangeInFile(
filepath=x.filepath, range=x.range
), range_in_files))
- await sdk.run_step(DefaultModelEditCodeStep(user_input=self.user_input, range_in_files=range_in_files))
-
+ for range_in_file in range_in_files:
+ if os.path.dirname(range_in_file.filepath) == os.path.expanduser(os.path.join("~", ".continue", "diffs")):
+ self.description = "Please accept or reject the change before making another edit in this file."
+ return
-class FindCodeStep(Step):
- prompt: str
-
- async def describe(self, models: Models) -> Coroutine[str, None, None]:
- return "Finding code"
-
- async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
- return await sdk.ide.getOpenFiles()
+ await sdk.run_step(DefaultModelEditCodeStep(user_input=self.user_input, range_in_files=range_in_files))
class UserInputStep(Step):
diff --git a/continuedev/src/continuedev/steps/open_config.py b/continuedev/src/continuedev/steps/open_config.py
index 441cb0e7..af55a95a 100644
--- a/continuedev/src/continuedev/steps/open_config.py
+++ b/continuedev/src/continuedev/steps/open_config.py
@@ -14,11 +14,14 @@ class OpenConfigStep(Step):
"custom_commands": [
{
"name": "test",
- "description": "Write unit tests like I do for the highlighted code"
+ "description": "Write unit tests like I do for the highlighted code",
"prompt": "Write a comprehensive set of unit tests for the selected code. It should setup, run tests that check for correctness including important edge cases, and teardown. Ensure that the tests are complete and sophisticated."
}
- ],
- ```""")
+ ]
+ ```
+ `"name"` is the command you will type.
+ `"description"` is the description displayed in the slash command menu.
+ `"prompt"` is the instruction given to the model. The overall prompt becomes "Task: {prompt}, Additional info: {user_input}". For example, if you entered "/test exactly 5 assertions", the overall prompt would become "Task: Write a comprehensive...and sophisticated, Additional info: exactly 5 assertions".""")
async def run(self, sdk: ContinueSDK):
global_dir = os.path.expanduser('~/.continue')
diff --git a/continuedev/src/continuedev/steps/search_directory.py b/continuedev/src/continuedev/steps/search_directory.py
index 2eecc99c..bfb97630 100644
--- a/continuedev/src/continuedev/steps/search_directory.py
+++ b/continuedev/src/continuedev/steps/search_directory.py
@@ -6,6 +6,7 @@ from ..models.filesystem import RangeInFile
from ..models.main import Range
from ..core.main import Step
from ..core.sdk import ContinueSDK
+from ..libs.util.create_async_task import create_async_task
import os
import re
@@ -60,9 +61,9 @@ class EditAllMatchesStep(Step):
# Search all files for a given string
range_in_files = find_all_matches_in_dir(self.pattern, self.directory or await sdk.ide.getWorkspaceDirectory())
- tasks = [asyncio.create_task(sdk.edit_file(
+ tasks = [create_async_task(sdk.edit_file(
range=range_in_file.range,
filename=range_in_file.filepath,
prompt=self.user_request
- )) for range_in_file in range_in_files]
+ ), sdk.ide.unique_id) for range_in_file in range_in_files]
await asyncio.gather(*tasks)
diff --git a/continuedev/src/continuedev/steps/welcome.py b/continuedev/src/continuedev/steps/welcome.py
index 32ebc3ba..2dece649 100644
--- a/continuedev/src/continuedev/steps/welcome.py
+++ b/continuedev/src/continuedev/steps/welcome.py
@@ -29,4 +29,4 @@ class WelcomeStep(Step):
- Ask about how the class works, how to write it in another language, etc.
\"\"\"""")))
- await sdk.ide.setFileOpen(filepath=filepath)
+ # await sdk.ide.setFileOpen(filepath=filepath)