summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--continuedev/src/continuedev/core/agent.py34
-rw-r--r--continuedev/src/continuedev/core/env.py4
-rw-r--r--continuedev/src/continuedev/core/main.py14
-rw-r--r--continuedev/src/continuedev/core/policy.py17
-rw-r--r--continuedev/src/continuedev/core/sdk.py13
-rw-r--r--continuedev/src/continuedev/libs/steps/chroma.py2
-rw-r--r--continuedev/src/continuedev/libs/steps/core/core.py33
-rw-r--r--continuedev/src/continuedev/libs/steps/draft/dlt.py2
-rw-r--r--continuedev/src/continuedev/libs/steps/main.py20
-rw-r--r--continuedev/src/continuedev/libs/steps/migration.py2
-rw-r--r--continuedev/src/continuedev/libs/steps/nate.py6
-rw-r--r--continuedev/src/continuedev/libs/steps/pytest.py2
-rw-r--r--continuedev/src/continuedev/libs/steps/ty.py2
-rw-r--r--continuedev/src/continuedev/server/ide.py50
-rw-r--r--continuedev/src/continuedev/server/main.py2
-rw-r--r--continuedev/src/continuedev/server/notebook.py207
-rw-r--r--continuedev/src/continuedev/server/notebook_protocol.py28
-rw-r--r--continuedev/src/continuedev/server/session_manager.py101
-rw-r--r--extension/package-lock.json19
-rw-r--r--extension/package.json7
-rw-r--r--extension/react-app/src/hooks/ContinueNotebookClientProtocol.ts13
-rw-r--r--extension/react-app/src/hooks/messenger.ts108
-rw-r--r--extension/react-app/src/hooks/useContinueNotebookProtocol.ts49
-rw-r--r--extension/react-app/src/hooks/useWebsocket.ts171
-rw-r--r--extension/react-app/src/hooks/vscodeMessenger.ts68
-rw-r--r--extension/react-app/src/tabs/notebook.tsx70
-rw-r--r--extension/react-app/src/vscode/index.ts1
-rw-r--r--extension/react-app/tsconfig.json2
-rw-r--r--extension/scripts/continuedev-0.1.0-py3-none-any.whlbin53104 -> 56070 bytes
-rw-r--r--extension/src/activation/activate.ts4
-rw-r--r--extension/src/activation/environmentSetup.ts104
-rw-r--r--extension/src/commands.ts4
-rw-r--r--extension/src/continueIdeClient.ts146
-rw-r--r--extension/src/debugPanel.ts26
-rw-r--r--extension/src/extension.ts7
-rw-r--r--extension/src/test/runTest.ts30
-rw-r--r--extension/src/util/messenger.ts108
37 files changed, 864 insertions, 612 deletions
diff --git a/continuedev/src/continuedev/core/agent.py b/continuedev/src/continuedev/core/agent.py
index 329e3d4c..7f7466a2 100644
--- a/continuedev/src/continuedev/core/agent.py
+++ b/continuedev/src/continuedev/core/agent.py
@@ -13,7 +13,6 @@ from .sdk import ContinueSDK
class Agent(ContinueBaseModel):
- llm: LLM
policy: Policy
ide: AbstractIdeProtocolServer
history: History = History.from_empty()
@@ -31,27 +30,24 @@ class Agent(ContinueBaseModel):
def get_full_state(self) -> FullState:
return FullState(history=self.history, active=self._active, user_input_queue=self._main_user_input_queue)
- def on_update(self, callback: Callable[["FullState"], None]):
+ def on_update(self, callback: Coroutine["FullState", None, None]):
"""Subscribe to changes to state"""
self._on_update_callbacks.append(callback)
- def update_subscribers(self):
+ async def update_subscribers(self):
full_state = self.get_full_state()
for callback in self._on_update_callbacks:
- callback(full_state)
-
- def __get_step_params(self, step: "Step"):
- return ContinueSDK(agent=self, llm=self.llm.with_system_message(step.system_message))
+ await callback(full_state)
def give_user_input(self, input: str, index: int):
- self._user_input_queue.post(index, input)
+ self._user_input_queue.post(str(index), input)
async def wait_for_user_input(self) -> str:
self._active = False
- self.update_subscribers()
- user_input = await self._user_input_queue.get(self.history.current_index)
+ await self.update_subscribers()
+ user_input = await self._user_input_queue.get(str(self.history.current_index))
self._active = True
- self.update_subscribers()
+ await self.update_subscribers()
return user_input
_manual_edits_buffer: List[FileEditWithFullContents] = []
@@ -62,9 +58,9 @@ class Agent(ContinueBaseModel):
current_step = self.history.get_current().step
self.history.step_back()
if issubclass(current_step.__class__, ReversibleStep):
- await current_step.reverse(self.__get_step_params(current_step))
+ await current_step.reverse(ContinueSDK(self))
- self.update_subscribers()
+ await self.update_subscribers()
except Exception as e:
print(e)
@@ -94,17 +90,17 @@ class Agent(ContinueBaseModel):
# Run step
self._step_depth += 1
- observation = await step(self.__get_step_params(step))
+ observation = await step(ContinueSDK(self))
self._step_depth -= 1
# Add observation to history
self.history.get_current().observation = observation
# Update its description
- step._set_description(await step.describe(self.llm))
+ step._set_description(await step.describe(ContinueSDK(self)))
# Call all subscribed callbacks
- self.update_subscribers()
+ await self.update_subscribers()
return observation
@@ -138,7 +134,7 @@ class Agent(ContinueBaseModel):
# Doing this so active can make it to the frontend after steps are done. But want better state syncing tools
for callback in self._on_update_callbacks:
- callback(None)
+ await callback(None)
async def run_from_observation(self, observation: Observation):
next_step = self.policy.next(self.history)
@@ -158,7 +154,7 @@ class Agent(ContinueBaseModel):
async def accept_user_input(self, user_input: str):
self._main_user_input_queue.append(user_input)
- self.update_subscribers()
+ await self.update_subscribers()
if len(self._main_user_input_queue) > 1:
return
@@ -167,7 +163,7 @@ class Agent(ContinueBaseModel):
# Just run the step that takes user input, and
# then up to the policy to decide how to deal with it.
self._main_user_input_queue.pop(0)
- self.update_subscribers()
+ await self.update_subscribers()
await self.run_from_step(UserInputStep(user_input=user_input))
while len(self._main_user_input_queue) > 0:
diff --git a/continuedev/src/continuedev/core/env.py b/continuedev/src/continuedev/core/env.py
index 6267ed60..edd3297c 100644
--- a/continuedev/src/continuedev/core/env.py
+++ b/continuedev/src/continuedev/core/env.py
@@ -8,6 +8,10 @@ def get_env_var(var_name: str):
def save_env_var(var_name: str, var_value: str):
+ if not os.path.exists('.env'):
+ with open('.env', 'w') as f:
+ f.write(f'{var_name}="{var_value}"\n')
+ return
with open('.env', 'r') as f:
lines = f.readlines()
with open('.env', 'w') as f:
diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py
index 51fcd299..6be5139b 100644
--- a/continuedev/src/continuedev/core/main.py
+++ b/continuedev/src/continuedev/core/main.py
@@ -72,7 +72,7 @@ class ContinueSDK:
pass
-class SequentialStep:
+class Models:
pass
@@ -94,7 +94,7 @@ class Step(ContinueBaseModel):
class Config:
copy_on_model_validation = False
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
if self._description is not None:
return self._description
return "Running step: " + self.name
@@ -135,6 +135,16 @@ class Step(ContinueBaseModel):
return SequentialStep(steps=steps)
+class SequentialStep(Step):
+ steps: list[Step]
+ hide: bool = True
+
+ async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
+ for step in self.steps:
+ observation = await sdk.run_step(step)
+ return observation
+
+
class ValidatorObservation(Observation):
passed: bool
observation: Observation
diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py
index 07101576..c0ba0f4f 100644
--- a/continuedev/src/continuedev/core/policy.py
+++ b/continuedev/src/continuedev/core/policy.py
@@ -1,20 +1,16 @@
from typing import List, Tuple, Type
-
-from ..models.main import ContinueBaseModel
-
from ..libs.steps.ty import CreatePipelineStep
from .main import Step, Validator, History, Policy
from .observation import Observation, TracebackObservation, UserInputObservation
from ..libs.steps.main import EditHighlightedCodeStep, SolveTracebackStep, RunCodeStep, FasterEditHighlightedCodeStep, StarCoderEditHighlightedCodeStep
from ..libs.steps.nate import WritePytestsStep, CreateTableStep
-from ..libs.steps.chroma import AnswerQuestionChroma, EditFileChroma
+# from ..libs.steps.chroma import AnswerQuestionChroma, EditFileChroma
from ..libs.steps.continue_step import ContinueStepStep
class DemoPolicy(Policy):
ran_code_last: bool = False
- cmd: str
def next(self, history: History) -> Step:
observation = history.last_observation()
@@ -26,18 +22,15 @@ class DemoPolicy(Policy):
return CreatePipelineStep()
elif "/table" in observation.user_input:
return CreateTableStep(sql_str=" ".join(observation.user_input.split(" ")[1:]))
- elif "/ask" in observation.user_input:
- return AnswerQuestionChroma(question=" ".join(observation.user_input.split(" ")[1:]))
- elif "/edit" in observation.user_input:
- return EditFileChroma(request=" ".join(observation.user_input.split(" ")[1:]))
+ # elif "/ask" in observation.user_input:
+ # return AnswerQuestionChroma(question=" ".join(observation.user_input.split(" ")[1:]))
+ # elif "/edit" in observation.user_input:
+ # return EditFileChroma(request=" ".join(observation.user_input.split(" ")[1:]))
elif "/step" in observation.user_input:
return ContinueStepStep(prompt=" ".join(observation.user_input.split(" ")[1:]))
return StarCoderEditHighlightedCodeStep(user_input=observation.user_input)
state = history.get_current()
- if state is None or not self.ran_code_last:
- self.ran_code_last = True
- return RunCodeStep(cmd=self.cmd)
if observation is not None and isinstance(observation, TracebackObservation):
self.ran_code_last = False
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 750b335d..6ae0be04 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -4,6 +4,7 @@ from ..models.filesystem_edit import FileSystemEdit, AddFile, DeleteFile, AddDir
from ..models.filesystem import RangeInFile
from ..libs.llm import LLM
from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
+from ..libs.llm.openai import OpenAI
from .observation import Observation
from ..server.ide_protocol import AbstractIdeProtocolServer
from .main import History, Step
@@ -29,20 +30,20 @@ class Models:
'HUGGING_FACE_TOKEN', 'Please enter your Hugging Face token')
return HuggingFaceInferenceAPI(api_key=api_key)
+ async def gpt35(self):
+ api_key = await self.sdk.get_user_secret(
+ 'OPENAI_API_KEY', 'Please enter your OpenAI API key')
+ return OpenAI(api_key=api_key, default_model="gpt-3.5-turbo")
+
class ContinueSDK:
"""The SDK provided as parameters to a step"""
- llm: LLM
ide: AbstractIdeProtocolServer
steps: ContinueSDKSteps
models: Models
__agent: Agent
- def __init__(self, agent: Agent, llm: Union[LLM, None] = None):
- if llm is None:
- self.llm = agent.llm
- else:
- self.llm = llm
+ def __init__(self, agent: Agent):
self.ide = agent.ide
self.__agent = agent
self.steps = ContinueSDKSteps(self)
diff --git a/continuedev/src/continuedev/libs/steps/chroma.py b/continuedev/src/continuedev/libs/steps/chroma.py
index f13a2bab..39424c5c 100644
--- a/continuedev/src/continuedev/libs/steps/chroma.py
+++ b/continuedev/src/continuedev/libs/steps/chroma.py
@@ -40,7 +40,7 @@ class AnswerQuestionChroma(Step):
Here is the answer:""")
- answer = sdk.llm.complete(prompt)
+ answer = (await sdk.models.gpt35()).complete(prompt)
print(answer)
self._answer = answer
diff --git a/continuedev/src/continuedev/libs/steps/core/core.py b/continuedev/src/continuedev/libs/steps/core/core.py
index 0338d635..14b3cb80 100644
--- a/continuedev/src/continuedev/libs/steps/core/core.py
+++ b/continuedev/src/continuedev/libs/steps/core/core.py
@@ -4,27 +4,18 @@ from textwrap import dedent
from typing import Coroutine, List, Union
from ...llm.prompt_utils import MarkdownStyleEncoderDecoder
-from ...util.traceback_parsers import parse_python_traceback
-
from ....models.filesystem_edit import EditDiff, FileEditWithFullContents, FileSystemEdit
from ....models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents
-from ...llm import LLM
from ....core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation
-from ....core.main import Step
+from ....core.main import Step, SequentialStep
class ContinueSDK:
pass
-class SequentialStep(Step):
- steps: list[Step]
- hide: bool = True
-
- async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
- for step in self.steps:
- observation = await sdk.run_step(step)
- return observation
+class Models:
+ pass
class ReversibleStep(Step):
@@ -52,7 +43,7 @@ def ShellCommandsStep(Step):
cwd: str | None = None
name: str = "Run Shell Commands"
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
return "\n".join(self.cmds)
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
@@ -81,13 +72,13 @@ class EditCodeStep(Step):
_prompt: Union[str, None] = None
_completion: Union[str, None] = None
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
if self._edit_diffs is None:
return "Editing files: " + ", ".join(map(lambda rif: rif.filepath, self.range_in_files))
elif len(self._edit_diffs) == 0:
return "No edits made"
else:
- return llm.complete(dedent(f"""{self._prompt}{self._completion}
+ return (await models.gpt35()).complete(dedent(f"""{self._prompt}{self._completion}
Maximally concise summary of changes in bullet points (can use markdown):
"""))
@@ -102,7 +93,7 @@ class EditCodeStep(Step):
code_string = enc_dec.encode()
prompt = self.prompt.format(code=code_string)
- completion = sdk.llm.complete(prompt)
+ completion = (await sdk.models.gpt35()).complete(prompt)
# Temporarily doing this to generate description.
self._prompt = prompt
@@ -127,7 +118,7 @@ class EditFileStep(Step):
prompt: str
hide: bool = True
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
return "Editing file: " + self.filepath
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
@@ -145,7 +136,7 @@ class ManualEditStep(ReversibleStep):
hide: bool = True
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
return "Manual edit step"
# TODO - only handling FileEdit here, but need all other types of FileSystemEdits
# Also requires the merge_file_edit function
@@ -181,7 +172,7 @@ class UserInputStep(Step):
name: str = "User Input"
hide: bool = True
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
return self.user_input
async def run(self, sdk: ContinueSDK) -> Coroutine[UserInputObservation, None, None]:
@@ -194,7 +185,7 @@ class WaitForUserInputStep(Step):
_description: Union[str, None] = None
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
return self.prompt
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
@@ -207,7 +198,7 @@ class WaitForUserConfirmationStep(Step):
prompt: str
name: str = "Waiting for user confirmation"
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
return self.prompt
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
diff --git a/continuedev/src/continuedev/libs/steps/draft/dlt.py b/continuedev/src/continuedev/libs/steps/draft/dlt.py
index 5ba5692a..460aa0cc 100644
--- a/continuedev/src/continuedev/libs/steps/draft/dlt.py
+++ b/continuedev/src/continuedev/libs/steps/draft/dlt.py
@@ -10,7 +10,7 @@ class SetupPipelineStep(Step):
api_description: str # e.g. "I want to load data from the weatherapi.com API"
async def run(self, sdk: ContinueSDK):
- source_name = sdk.llm.complete(
+ source_name = (await sdk.models.gpt35()).complete(
f"Write a snake_case name for the data source described by {self.api_description}: ").strip()
filename = f'{source_name}.py'
diff --git a/continuedev/src/continuedev/libs/steps/main.py b/continuedev/src/continuedev/libs/steps/main.py
index c8a85800..70c0d4b8 100644
--- a/continuedev/src/continuedev/libs/steps/main.py
+++ b/continuedev/src/continuedev/libs/steps/main.py
@@ -11,7 +11,7 @@ from ...core.observation import Observation, TextObservation, TracebackObservati
from ..llm.prompt_utils import MarkdownStyleEncoderDecoder
from textwrap import dedent
from ...core.main import Step
-from ...core.sdk import ContinueSDK
+from ...core.sdk import ContinueSDK, Models
from ...core.observation import Observation
import subprocess
from .core.core import EditCodeStep
@@ -20,7 +20,7 @@ from .core.core import EditCodeStep
class RunCodeStep(Step):
cmd: str
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
return f"Ran command: `{self.cmd}`"
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
@@ -59,7 +59,7 @@ class RunCommandStep(Step):
name: str = "Run command"
_description: str = None
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
if self._description is not None:
return self._description
return self.cmd
@@ -125,7 +125,7 @@ class FasterEditHighlightedCodeStep(Step):
Here is the description of changes to make:
""")
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
return "Editing highlighted code"
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
@@ -154,7 +154,7 @@ class FasterEditHighlightedCodeStep(Step):
for rif in rif_with_contents:
rif_dict[rif.filepath] = rif.contents
- completion = sdk.llm.complete(prompt)
+ completion = (await sdk.models.gpt35()).complete(prompt)
# Temporarily doing this to generate description.
self._prompt = prompt
@@ -215,10 +215,10 @@ class FasterEditHighlightedCodeStep(Step):
class StarCoderEditHighlightedCodeStep(Step):
user_input: str
- hide = True
+ hide = False
_prompt: str = "<commit_before>{code}<commit_msg>{user_request}<commit_after>"
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
return "Editing highlighted code"
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
@@ -271,7 +271,7 @@ This is the user request:
This is the code after being changed to perfectly satisfy the user request:
""")
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
return "Editing highlighted code"
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
@@ -293,7 +293,7 @@ This is the code after being changed to perfectly satisfy the user request:
class FindCodeStep(Step):
prompt: str
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
return "Finding code"
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
@@ -307,7 +307,7 @@ class UserInputStep(Step):
class SolveTracebackStep(Step):
traceback: Traceback
- async def describe(self, llm: LLM) -> Coroutine[str, None, None]:
+ async def describe(self, models: Models) -> Coroutine[str, None, None]:
return f"```\n{self.traceback.full_traceback}\n```"
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
diff --git a/continuedev/src/continuedev/libs/steps/migration.py b/continuedev/src/continuedev/libs/steps/migration.py
index f044a60f..7b70422d 100644
--- a/continuedev/src/continuedev/libs/steps/migration.py
+++ b/continuedev/src/continuedev/libs/steps/migration.py
@@ -15,7 +15,7 @@ class MigrationStep(Step):
recent_edits = await sdk.ide.get_recent_edits(self.edited_file)
recent_edits_string = "\n\n".join(
map(lambda x: x.to_string(), recent_edits))
- description = await sdk.llm.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n")
+ description = await (await sdk.models.gpt35()).complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n")
await sdk.run_step(RunCommandStep(cmd=f"cd libs && poetry run alembic revision --autogenerate -m {description}"))
migration_file = f"libs/alembic/versions/{?}.py"
contents = await sdk.ide.readFile(migration_file)
diff --git a/continuedev/src/continuedev/libs/steps/nate.py b/continuedev/src/continuedev/libs/steps/nate.py
index a0e728e5..2f84e9d7 100644
--- a/continuedev/src/continuedev/libs/steps/nate.py
+++ b/continuedev/src/continuedev/libs/steps/nate.py
@@ -45,7 +45,7 @@ Here are additional instructions:
Here is a complete set of pytest unit tests:
""")
- # tests = sdk.llm.complete(prompt)
+ # tests = (await sdk.models.gpt35()).complete(prompt)
tests = '''
import pytest
@@ -169,9 +169,9 @@ export class Order {
tracking_number: string;
}'''
time.sleep(2)
- # orm_entity = sdk.llm.complete(
+ # orm_entity = (await sdk.models.gpt35()).complete(
# f"{self.sql_str}\n\nWrite a TypeORM entity called {entity_name} for this table, importing as necessary:")
- # sdk.llm.complete("What is the name of the entity?")
+ # (await sdk.models.gpt35()).complete("What is the name of the entity?")
await sdk.apply_filesystem_edit(AddFile(filepath=f"/Users/natesesti/Desktop/continue/extension/examples/python/MyProject/src/entity/{entity_name}.ts", content=orm_entity))
await sdk.ide.setFileOpen(f"/Users/natesesti/Desktop/continue/extension/examples/python/MyProject/src/entity/{entity_name}.ts", True)
diff --git a/continuedev/src/continuedev/libs/steps/pytest.py b/continuedev/src/continuedev/libs/steps/pytest.py
index b4e6dfd2..2e83ae2d 100644
--- a/continuedev/src/continuedev/libs/steps/pytest.py
+++ b/continuedev/src/continuedev/libs/steps/pytest.py
@@ -33,5 +33,5 @@ class WritePytestsStep(Step):
Here is a complete set of pytest unit tests:
""")
- tests = sdk.llm.complete(prompt)
+ tests = (await sdk.models.gpt35()).complete(prompt)
await sdk.apply_filesystem_edit(AddFile(filepath=path, content=tests))
diff --git a/continuedev/src/continuedev/libs/steps/ty.py b/continuedev/src/continuedev/libs/steps/ty.py
index 5ff03f04..9dde7c86 100644
--- a/continuedev/src/continuedev/libs/steps/ty.py
+++ b/continuedev/src/continuedev/libs/steps/ty.py
@@ -18,7 +18,7 @@ class SetupPipelineStep(Step):
api_description: str # e.g. "I want to load data from the weatherapi.com API"
async def run(self, sdk: ContinueSDK):
- # source_name = sdk.llm.complete(
+ # source_name = (await sdk.models.gpt35()).complete(
# f"Write a snake_case name for the data source described by {self.api_description}: ").strip()
filename = f'/Users/natesesti/Desktop/continue/extension/examples/python/{source_name}.py'
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index dd1dc463..50296841 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -1,5 +1,6 @@
# This is a separate server from server/main.py
import asyncio
+import json
import os
from typing import Any, Dict, List, Type, TypeVar, Union
import uuid
@@ -90,31 +91,33 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
def __init__(self, session_manager: SessionManager):
self.session_manager = session_manager
- async def _send_json(self, data: Any):
- await self.websocket.send_json(data)
+ async def _send_json(self, message_type: str, data: Any):
+ await self.websocket.send_json({
+ "messageType": message_type,
+ "data": data
+ })
async def _receive_json(self, message_type: str) -> Any:
return await self.sub_queue.get(message_type)
async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T:
- await self._send_json(data)
+ await self._send_json(message_type, data)
resp = await self._receive_json(message_type)
return resp_model.parse_obj(resp)
- async def handle_json(self, data: Any):
- t = data["messageType"]
- if t == "openNotebook":
+ async def handle_json(self, message_type: str, data: Any):
+ if message_type == "openNotebook":
await self.openNotebook()
- elif t == "setFileOpen":
+ elif message_type == "setFileOpen":
await self.setFileOpen(data["filepath"], data["open"])
- elif t == "fileEdits":
+ elif message_type == "fileEdits":
fileEdits = list(
map(lambda d: FileEditWithFullContents.parse_obj(d), data["fileEdits"]))
self.onFileEdits(fileEdits)
- elif t in ["highlightedCode", "openFiles", "readFile", "editFile", "workspaceDirectory"]:
- self.sub_queue.post(t, data)
+ elif message_type in ["highlightedCode", "openFiles", "readFile", "editFile", "workspaceDirectory"]:
+ self.sub_queue.post(message_type, data)
else:
- raise ValueError("Unknown message type", t)
+ raise ValueError("Unknown message type", message_type)
# ------------------------------- #
# Request actions in IDE, doesn't matter which Session
@@ -123,24 +126,21 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
async def setFileOpen(self, filepath: str, open: bool = True):
# Agent needs access to this.
- await self.websocket.send_json({
- "messageType": "setFileOpen",
+ await self._send_json("setFileOpen", {
"filepath": filepath,
"open": open
})
async def openNotebook(self):
session_id = self.session_manager.new_session(self)
- await self._send_json({
- "messageType": "openNotebook",
+ await self._send_json("openNotebook", {
"sessionId": session_id
})
async def showSuggestionsAndWait(self, suggestions: List[FileEdit]) -> bool:
ids = [str(uuid.uuid4()) for _ in suggestions]
for i in range(len(suggestions)):
- self._send_json({
- "messageType": "showSuggestion",
+ self._send_json("showSuggestion", {
"suggestion": suggestions[i],
"suggestionId": ids[i]
})
@@ -210,8 +210,7 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
async def saveFile(self, filepath: str):
"""Save a file"""
- await self._send_json({
- "messageType": "saveFile",
+ await self._send_json("saveFile", {
"filepath": filepath
})
@@ -293,10 +292,17 @@ ideProtocolServer = IdeProtocolServer(session_manager)
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
print("Accepted websocket connection from, ", websocket.client)
- await websocket.send_json({"messageType": "connected"})
+ await websocket.send_json({"messageType": "connected", "data": {}})
ideProtocolServer.websocket = websocket
while True:
- data = await websocket.receive_json()
- await ideProtocolServer.handle_json(data)
+ message = await websocket.receive_text()
+ message = json.loads(message)
+
+ if "messageType" not in message or "data" not in message:
+ continue
+ message_type = message["messageType"]
+ data = message["data"]
+
+ await ideProtocolServer.handle_json(message_type, data)
await websocket.close()
diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py
index 11ad1d8f..e87d5fa9 100644
--- a/continuedev/src/continuedev/server/main.py
+++ b/continuedev/src/continuedev/server/main.py
@@ -32,7 +32,7 @@ args = parser.parse_args()
def run_server():
- uvicorn.run(app, host="0.0.0.0", port=args.port, log_config="logging.ini")
+ uvicorn.run(app, host="0.0.0.0", port=args.port)
if __name__ == "__main__":
diff --git a/continuedev/src/continuedev/server/notebook.py b/continuedev/src/continuedev/server/notebook.py
index c5dcea31..edb61a45 100644
--- a/continuedev/src/continuedev/server/notebook.py
+++ b/continuedev/src/continuedev/server/notebook.py
@@ -1,18 +1,12 @@
-from fastapi import FastAPI, Depends, Header, WebSocket, APIRouter
-from typing import Any, Dict, List, Union
-from uuid import uuid4
+import json
+from fastapi import Depends, Header, WebSocket, APIRouter
+from typing import Any, Type, TypeVar, Union
from pydantic import BaseModel
from uvicorn.main import Server
-from ..models.filesystem_edit import FileEditWithFullContents
-from ..core.policy import DemoPolicy
-from ..core.main import FullState, History, Step
-from ..core.agent import Agent
-from ..libs.steps.nate import ImplementAbstractMethodStep
-from ..core.observation import Observation
-from ..libs.llm.openai import OpenAI
-from .ide_protocol import AbstractIdeProtocolServer
-from ..core.env import get_env_var
+from .session_manager import SessionManager, session_manager, Session
+from .notebook_protocol import AbstractNotebookProtocolServer
+from ..libs.util.queue import AsyncSubscriptionQueue
import asyncio
import nest_asyncio
nest_asyncio.apply()
@@ -36,160 +30,99 @@ class AppStatus:
Server.handle_exit = AppStatus.handle_exit
-class Session:
- session_id: str
- agent: Agent
- ws: Union[WebSocket, None]
-
- def __init__(self, session_id: str, agent: Agent):
- self.session_id = session_id
- self.agent = agent
- self.ws = None
-
-
-class DemoAgent(Agent):
- first_seen: bool = False
- cumulative_edit_string = ""
-
- def handle_manual_edits(self, edits: List[FileEditWithFullContents]):
- for edit in edits:
- self.cumulative_edit_string += edit.fileEdit.replacement
- self._manual_edits_buffer.append(edit)
- # Note that you're storing a lot of unecessary data here. Can compress into EditDiffs on the spot, and merge.
- # self._manual_edits_buffer = merge_file_edit(self._manual_edits_buffer, edit)
- # FOR DEMO PURPOSES
- if edit.fileEdit.filepath.endswith("filesystem.py") and "List" in self.cumulative_edit_string and ":" in edit.fileEdit.replacement:
- self.cumulative_edit_string = ""
- asyncio.create_task(self.run_from_step(
- ImplementAbstractMethodStep()))
-
-
-class SessionManager:
- sessions: Dict[str, Session] = {}
- _event_loop: Union[asyncio.BaseEventLoop, None] = None
-
- def get_session(self, session_id: str) -> Session:
- if session_id not in self.sessions:
- raise KeyError("Session ID not recognized")
- return self.sessions[session_id]
-
- def new_session(self, ide: AbstractIdeProtocolServer) -> str:
- cmd = "python3 /Users/natesesti/Desktop/continue/extension/examples/python/main.py"
- agent = DemoAgent(llm=OpenAI(api_key=get_env_var("OPENAI_API_KEY")),
- policy=DemoPolicy(cmd=cmd), ide=ide)
- session_id = str(uuid4())
- session = Session(session_id=session_id, agent=agent)
- self.sessions[session_id] = session
+def session(x_continue_session_id: str = Header("anonymous")) -> Session:
+ return session_manager.get_session(x_continue_session_id)
- def on_update(state: FullState):
- session_manager.send_ws_data(session_id, {
- "messageType": "state",
- "state": agent.get_full_state().dict()
- })
- agent.on_update(on_update)
- asyncio.create_task(agent.run_policy())
- return session_id
+def websocket_session(session_id: str) -> Session:
+ return session_manager.get_session(session_id)
- def remove_session(self, session_id: str):
- del self.sessions[session_id]
- def register_websocket(self, session_id: str, ws: WebSocket):
- self.sessions[session_id].ws = ws
- print("Registered websocket for session", session_id)
+T = TypeVar("T", bound=BaseModel)
- def send_ws_data(self, session_id: str, data: Any):
- if self.sessions[session_id].ws is None:
- print(f"Session {session_id} has no websocket")
- return
+# You should probably abstract away the websocket stuff into a separate class
- async def a():
- await self.sessions[session_id].ws.send_json(data)
- # Run coroutine in background
- if self._event_loop is None or self._event_loop.is_closed():
- self._event_loop = asyncio.new_event_loop()
- self._event_loop.run_until_complete(a())
- self._event_loop.close()
- else:
- self._event_loop.run_until_complete(a())
- self._event_loop.close()
+class NotebookProtocolServer(AbstractNotebookProtocolServer):
+ websocket: WebSocket
+ session: Session
+ sub_queue: AsyncSubscriptionQueue = AsyncSubscriptionQueue()
+ def __init__(self, session: Session):
+ self.session = session
-session_manager = SessionManager()
+ async def _send_json(self, data: Any):
+ await self.websocket.send_json(data)
+ async def _receive_json(self, message_type: str) -> Any:
+ return await self.sub_queue.get(message_type)
-def session(x_continue_session_id: str = Header("anonymous")) -> Session:
- return session_manager.get_session(x_continue_session_id)
+ async def _send_and_receive_json(self, data: Any, resp_model: Type[T], message_type: str) -> T:
+ await self._send_json(data)
+ resp = await self._receive_json(message_type)
+ return resp_model.parse_obj(resp)
+ def handle_json(self, message_type: str, data: Any):
+ try:
+ if message_type == "main_input":
+ self.on_main_input(data["input"])
+ elif message_type == "step_user_input":
+ self.on_step_user_input(data["input"], data["index"])
+ elif message_type == "refinement_input":
+ self.on_refinement_input(data["input"], data["index"])
+ elif message_type == "reverse_to_index":
+ self.on_reverse_to_index(data["index"])
+ except Exception as e:
+ print(e)
-def websocket_session(session_id: str) -> Session:
- return session_manager.get_session(session_id)
+ async def send_state_update(self):
+ state = self.session.agent.get_full_state().dict()
+ await self._send_json({
+ "messageType": "state_update",
+ "state": state
+ })
+ def on_main_input(self, input: str):
+ # Do something with user input
+ asyncio.create_task(self.session.agent.accept_user_input(input))
-class StartSessionBody(BaseModel):
- config_file_path: Union[str, None]
+ def on_reverse_to_index(self, index: int):
+ # Reverse the history to the given index
+ asyncio.create_task(self.session.agent.reverse_to_index(index))
+ def on_step_user_input(self, input: str, index: int):
+ asyncio.create_task(
+ self.session.agent.give_user_input(input, index))
-class StartSessionResp(BaseModel):
- session_id: str
+ def on_refinement_input(self, input: str, index: int):
+ asyncio.create_task(
+ self.session.agent.accept_refinement_input(input, index))
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)):
await websocket.accept()
+ print("Session started")
session_manager.register_websocket(session.session_id, websocket)
- data = await websocket.receive_text()
+ protocol = NotebookProtocolServer(session)
+ protocol.websocket = websocket
+
# Update any history that may have happened before connection
- await websocket.send_json({
- "messageType": "state",
- "state": session_manager.get_session(session.session_id).agent.get_full_state().dict()
- })
- print("Session started", data)
+ await protocol.send_state_update()
+
while AppStatus.should_exit is False:
- data = await websocket.receive_json()
- print("Received data", data)
+ message = await websocket.receive_json()
+ print("Received message", message)
+ if type(message) is str:
+ message = json.loads(message)
- if "messageType" not in data:
+ if "messageType" not in message or "data" not in message:
continue
- messageType = data["messageType"]
+ message_type = message["messageType"]
+ data = message["data"]
- try:
- if messageType == "main_input":
- # Do something with user input
- asyncio.create_task(
- session.agent.accept_user_input(data["value"]))
- elif messageType == "step_user_input":
- asyncio.create_task(
- session.agent.give_user_input(data["value"], data["index"]))
- elif messageType == "refinement_input":
- asyncio.create_task(
- session.agent.accept_refinement_input(data["value"], data["index"]))
- elif messageType == "reverse":
- # Reverse the history to the given index
- asyncio.create_task(
- session.agent.reverse_to_index(data["index"]))
- except Exception as e:
- print(e)
+ protocol.handle_json(message_type, data)
print("Closing websocket")
await websocket.close()
-
-
-@router.post("/run")
-def request_run(step: Step, session=Depends(session)):
- """Tell an agent to take a specific action."""
- asyncio.create_task(session.agent.run_from_step(step))
- return "Success"
-
-
-@router.get("/history")
-def get_history(session=Depends(session)) -> History:
- return session.agent.history
-
-
-@router.post("/observation")
-def post_observation(observation: Observation, session=Depends(session)):
- asyncio.create_task(session.agent.run_from_observation(observation))
- return "Success"
diff --git a/continuedev/src/continuedev/server/notebook_protocol.py b/continuedev/src/continuedev/server/notebook_protocol.py
new file mode 100644
index 00000000..c2be82e0
--- /dev/null
+++ b/continuedev/src/continuedev/server/notebook_protocol.py
@@ -0,0 +1,28 @@
+from typing import Any
+from abc import ABC, abstractmethod
+
+
+class AbstractNotebookProtocolServer(ABC):
+ @abstractmethod
+ async def handle_json(self, data: Any):
+ """Handle a json message"""
+
+ @abstractmethod
+ def on_main_input(self, input: str):
+ """Called when the user inputs something"""
+
+ @abstractmethod
+ def on_reverse_to_index(self, index: int):
+ """Called when the user requests reverse to a previous index"""
+
+ @abstractmethod
+ def on_refinement_input(self, input: str, index: int):
+ """Called when the user inputs a refinement"""
+
+ @abstractmethod
+ def on_step_user_input(self, input: str, index: int):
+ """Called when the user inputs a step"""
+
+ @abstractmethod
+ async def send_state_update(self, state: dict):
+ """Send a state update to the client"""
diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py
new file mode 100644
index 00000000..b48c21b7
--- /dev/null
+++ b/continuedev/src/continuedev/server/session_manager.py
@@ -0,0 +1,101 @@
+from fastapi import WebSocket
+from typing import Any, Dict, List, Union
+from uuid import uuid4
+
+from ..models.filesystem_edit import FileEditWithFullContents
+from ..core.policy import DemoPolicy
+from ..core.main import FullState
+from ..core.agent import Agent
+from ..libs.steps.nate import ImplementAbstractMethodStep
+from .ide_protocol import AbstractIdeProtocolServer
+import asyncio
+import nest_asyncio
+nest_asyncio.apply()
+
+
+class Session:
+ session_id: str
+ agent: Agent
+ ws: Union[WebSocket, None]
+
+ def __init__(self, session_id: str, agent: Agent):
+ self.session_id = session_id
+ self.agent = agent
+ self.ws = None
+
+
+class DemoAgent(Agent):
+ first_seen: bool = False
+ cumulative_edit_string = ""
+
+ def handle_manual_edits(self, edits: List[FileEditWithFullContents]):
+ for edit in edits:
+ self.cumulative_edit_string += edit.fileEdit.replacement
+ self._manual_edits_buffer.append(edit)
+ # Note that you're storing a lot of unecessary data here. Can compress into EditDiffs on the spot, and merge.
+ # self._manual_edits_buffer = merge_file_edit(self._manual_edits_buffer, edit)
+ # FOR DEMO PURPOSES
+ if edit.fileEdit.filepath.endswith("filesystem.py") and "List" in self.cumulative_edit_string and ":" in edit.fileEdit.replacement:
+ self.cumulative_edit_string = ""
+ asyncio.create_task(self.run_from_step(
+ ImplementAbstractMethodStep()))
+
+
+class SessionManager:
+ sessions: Dict[str, Session] = {}
+ _event_loop: Union[asyncio.BaseEventLoop, None] = None
+
+ def get_session(self, session_id: str) -> Session:
+ if session_id not in self.sessions:
+ raise KeyError("Session ID not recognized")
+ return self.sessions[session_id]
+
+ def new_session(self, ide: AbstractIdeProtocolServer) -> str:
+ agent = DemoAgent(policy=DemoPolicy(), ide=ide)
+ session_id = str(uuid4())
+ session = Session(session_id=session_id, agent=agent)
+ self.sessions[session_id] = session
+
+ async def on_update(state: FullState):
+ await session_manager.send_ws_data(session_id, "state_update", {
+ "state": agent.get_full_state().dict()
+ })
+
+ agent.on_update(on_update)
+ asyncio.create_task(agent.run_policy())
+ return session_id
+
+ def remove_session(self, session_id: str):
+ del self.sessions[session_id]
+
+ def register_websocket(self, session_id: str, ws: WebSocket):
+ self.sessions[session_id].ws = ws
+ print("Registered websocket for session", session_id)
+
+ async def send_ws_data(self, session_id: str, message_type: str, data: Any):
+ if self.sessions[session_id].ws is None:
+ print(f"Session {session_id} has no websocket")
+ return
+
+ async def a():
+ await self.sessions[session_id].ws.send_json({
+ "messageType": message_type,
+ "data": data
+ })
+
+ # Run coroutine in background
+ await self.sessions[session_id].ws.send_json({
+ "messageType": message_type,
+ "data": data
+ })
+ return
+ if self._event_loop is None or self._event_loop.is_closed():
+ self._event_loop = asyncio.new_event_loop()
+ self._event_loop.run_until_complete(a())
+ self._event_loop.close()
+ else:
+ self._event_loop.run_until_complete(a())
+ self._event_loop.close()
+
+
+session_manager = SessionManager()
diff --git a/extension/package-lock.json b/extension/package-lock.json
index 20ac24be..04af09d3 100644
--- a/extension/package-lock.json
+++ b/extension/package-lock.json
@@ -28,6 +28,7 @@
"@types/node": "16.x",
"@types/node-fetch": "^2.6.2",
"@types/vscode": "^1.74.0",
+ "@types/ws": "^8.5.4",
"@typescript-eslint/eslint-plugin": "^5.45.0",
"@typescript-eslint/parser": "^5.45.0",
"@vscode/test-electron": "^2.2.0",
@@ -2027,6 +2028,15 @@
"integrity": "sha512-LyeCIU3jb9d38w0MXFwta9r0Jx23ugujkAxdwLTNCyspdZTKUc43t7ppPbCiPoQ/Ivd/pnDFZrb4hWd45wrsgA==",
"dev": true
},
+ "node_modules/@types/ws": {
+ "version": "8.5.4",
+ "resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.5.4.tgz",
+ "integrity": "sha512-zdQDHKUgcX/zBc4GrwsE/7dVdAD8JR4EuiAXiiUhhfyIJXXb2+PrGshFyeXWQPMmmZ2XxgaqclgpIC7eTXc1mg==",
+ "dev": true,
+ "dependencies": {
+ "@types/node": "*"
+ }
+ },
"node_modules/@typescript-eslint/eslint-plugin": {
"version": "5.48.2",
"resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-5.48.2.tgz",
@@ -9246,6 +9256,15 @@
"integrity": "sha512-LyeCIU3jb9d38w0MXFwta9r0Jx23ugujkAxdwLTNCyspdZTKUc43t7ppPbCiPoQ/Ivd/pnDFZrb4hWd45wrsgA==",
"dev": true
},
+ "@types/ws": {
+ "version": "8.5.4",
+ "resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.5.4.tgz",
+ "integrity": "sha512-zdQDHKUgcX/zBc4GrwsE/7dVdAD8JR4EuiAXiiUhhfyIJXXb2+PrGshFyeXWQPMmmZ2XxgaqclgpIC7eTXc1mg==",
+ "dev": true,
+ "requires": {
+ "@types/node": "*"
+ }
+ },
"@typescript-eslint/eslint-plugin": {
"version": "5.48.2",
"resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-5.48.2.tgz",
diff --git a/extension/package.json b/extension/package.json
index dc0192c3..c96655a9 100644
--- a/extension/package.json
+++ b/extension/package.json
@@ -148,7 +148,7 @@
},
"scripts": {
"vscode:prepublish": "npm run esbuild-base -- --minify",
- "esbuild-base": "esbuild ./src/extension.ts --bundle --outfile=out/extension.js --external:vscode --format=cjs --platform=node",
+ "esbuild-base": "rm -rf ./out && esbuild ./src/extension.ts --bundle --outfile=out/extension.js --external:vscode --format=cjs --platform=node",
"esbuild": "rm -rf ./out && npm run esbuild-base -- --sourcemap",
"esbuild-watch": "npm run esbuild-base -- --sourcemap --watch",
"test-compile": "tsc -p ./",
@@ -160,9 +160,9 @@
"pretest": "npm run compile && npm run lint",
"lint": "eslint src --ext ts",
"test": "node ./out/test/runTest.js",
- "package": "cp ./config/prod_config.json ./config/config.json && mkdir -p ./build && vsce package --out ./build && chmod 777 ./build/continue-0.0.2.vsix && cp ./config/dev_config.json ./config/config.json",
+ "package": "cp ./config/prod_config.json ./config/config.json && mkdir -p ./build && vsce package --out ./build && chmod 777 ./build/continue-0.0.5.vsix && cp ./config/dev_config.json ./config/config.json",
"full-package": "cd ../continuedev && poetry build && cp ./dist/continuedev-0.1.0-py3-none-any.whl ../extension/scripts/continuedev-0.1.0-py3-none-any.whl && cd ../extension && npm run typegen && npm run clientgen && cd react-app && npm run build && cd .. && npm run package",
- "install-extension": "code --install-extension ./build/continue-0.0.1.vsix",
+ "install-extension": "code --install-extension ./build/continue-0.0.5.vsix",
"uninstall": "code --uninstall-extension .continue",
"reinstall": "rm -rf ./build && npm run package && npm run uninstall && npm run install-extension"
},
@@ -173,6 +173,7 @@
"@types/node": "16.x",
"@types/node-fetch": "^2.6.2",
"@types/vscode": "^1.74.0",
+ "@types/ws": "^8.5.4",
"@typescript-eslint/eslint-plugin": "^5.45.0",
"@typescript-eslint/parser": "^5.45.0",
"@vscode/test-electron": "^2.2.0",
diff --git a/extension/react-app/src/hooks/ContinueNotebookClientProtocol.ts b/extension/react-app/src/hooks/ContinueNotebookClientProtocol.ts
new file mode 100644
index 00000000..75fd7373
--- /dev/null
+++ b/extension/react-app/src/hooks/ContinueNotebookClientProtocol.ts
@@ -0,0 +1,13 @@
+abstract class AbstractContinueNotebookClientProtocol {
+ abstract sendMainInput(input: string): void;
+
+ abstract reverseToIndex(index: number): void;
+
+ abstract sendRefinementInput(input: string, index: number): void;
+
+ abstract sendStepUserInput(input: string, index: number): void;
+
+ abstract onStateUpdate(state: any): void;
+}
+
+export default AbstractContinueNotebookClientProtocol;
diff --git a/extension/react-app/src/hooks/messenger.ts b/extension/react-app/src/hooks/messenger.ts
new file mode 100644
index 00000000..e2a0bab8
--- /dev/null
+++ b/extension/react-app/src/hooks/messenger.ts
@@ -0,0 +1,108 @@
+// console.log("Websocket import");
+// const WebSocket = require("ws");
+
+export abstract class Messenger {
+ abstract send(messageType: string, data: object): void;
+
+ abstract onMessageType(
+ messageType: string,
+ callback: (data: object) => void
+ ): void;
+
+ abstract onMessage(callback: (messageType: string, data: any) => void): void;
+
+ abstract onOpen(callback: () => void): void;
+
+ abstract onClose(callback: () => void): void;
+
+ abstract sendAndReceive(messageType: string, data: any): Promise<any>;
+}
+
+export class WebsocketMessenger extends Messenger {
+ websocket: WebSocket;
+ private onMessageListeners: {
+ [messageType: string]: ((data: object) => void)[];
+ } = {};
+ private onOpenListeners: (() => void)[] = [];
+ private onCloseListeners: (() => void)[] = [];
+ private serverUrl: string;
+
+ _newWebsocket(): WebSocket {
+ // // Dynamic import, because WebSocket is builtin with browser, but not with node. And can't use require in browser.
+ // if (typeof process === "object") {
+ // console.log("Using node");
+ // // process is only available in Node
+ // var WebSocket = require("ws");
+ // }
+
+ const newWebsocket = new WebSocket(this.serverUrl);
+ for (const listener of this.onOpenListeners) {
+ this.onOpen(listener);
+ }
+ for (const listener of this.onCloseListeners) {
+ this.onClose(listener);
+ }
+ for (const messageType in this.onMessageListeners) {
+ for (const listener of this.onMessageListeners[messageType]) {
+ this.onMessageType(messageType, listener);
+ }
+ }
+ return newWebsocket;
+ }
+
+ constructor(serverUrl: string) {
+ super();
+ this.serverUrl = serverUrl;
+ this.websocket = this._newWebsocket();
+ }
+
+ send(messageType: string, data: object) {
+ const payload = JSON.stringify({ messageType, data });
+ if (this.websocket.readyState === this.websocket.OPEN) {
+ this.websocket.send(payload);
+ } else {
+ if (this.websocket.readyState !== this.websocket.CONNECTING) {
+ this.websocket = this._newWebsocket();
+ }
+ this.websocket.addEventListener("open", () => {
+ this.websocket.send(payload);
+ });
+ }
+ }
+
+ sendAndReceive(messageType: string, data: any): Promise<any> {
+ return new Promise((resolve, reject) => {
+ const eventListener = (data: any) => {
+ // THIS ISN"T GETTING CALLED
+ resolve(data);
+ this.websocket.removeEventListener("message", eventListener);
+ };
+ this.onMessageType(messageType, eventListener);
+ this.send(messageType, data);
+ });
+ }
+
+ onMessageType(messageType: string, callback: (data: any) => void): void {
+ this.websocket.addEventListener("message", (event: any) => {
+ const msg = JSON.parse(event.data);
+ if (msg.messageType === messageType) {
+ callback(msg.data);
+ }
+ });
+ }
+
+ onMessage(callback: (messageType: string, data: any) => void): void {
+ this.websocket.addEventListener("message", (event) => {
+ const msg = JSON.parse(event.data);
+ callback(msg.messageType, msg.data);
+ });
+ }
+
+ onOpen(callback: () => void): void {
+ this.websocket.addEventListener("open", callback);
+ }
+
+ onClose(callback: () => void): void {
+ this.websocket.addEventListener("close", callback);
+ }
+}
diff --git a/extension/react-app/src/hooks/useContinueNotebookProtocol.ts b/extension/react-app/src/hooks/useContinueNotebookProtocol.ts
new file mode 100644
index 00000000..d5ffbf09
--- /dev/null
+++ b/extension/react-app/src/hooks/useContinueNotebookProtocol.ts
@@ -0,0 +1,49 @@
+import AbstractContinueNotebookClientProtocol from "./ContinueNotebookClientProtocol";
+// import { Messenger, WebsocketMessenger } from "../../../src/util/messenger";
+import { Messenger, WebsocketMessenger } from "./messenger";
+import { VscodeMessenger } from "./vscodeMessenger";
+
+class ContinueNotebookClientProtocol extends AbstractContinueNotebookClientProtocol {
+ messenger: Messenger;
+ // Server URL must contain the session ID param
+ serverUrlWithSessionId: string;
+
+ constructor(
+ serverUrlWithSessionId: string,
+ useVscodeMessagePassing: boolean = false
+ ) {
+ super();
+ this.serverUrlWithSessionId = serverUrlWithSessionId;
+ if (useVscodeMessagePassing) {
+ this.messenger = new VscodeMessenger(serverUrlWithSessionId);
+ } else {
+ this.messenger = new WebsocketMessenger(serverUrlWithSessionId);
+ }
+ }
+
+ sendMainInput(input: string) {
+ this.messenger.send("main_input", { input });
+ }
+
+ reverseToIndex(index: number) {
+ this.messenger.send("reverse_to_index", { index });
+ }
+
+ sendRefinementInput(input: string, index: number) {
+ this.messenger.send("refinement_input", { input, index });
+ }
+
+ sendStepUserInput(input: string, index: number) {
+ this.messenger.send("step_user_input", { input, index });
+ }
+
+ onStateUpdate(callback: (state: any) => void) {
+ this.messenger.onMessageType("state_update", (data: any) => {
+ if (data.state) {
+ callback(data.state);
+ }
+ });
+ }
+}
+
+export default ContinueNotebookClientProtocol;
diff --git a/extension/react-app/src/hooks/useWebsocket.ts b/extension/react-app/src/hooks/useWebsocket.ts
index 6e8e68fa..b98be577 100644
--- a/extension/react-app/src/hooks/useWebsocket.ts
+++ b/extension/react-app/src/hooks/useWebsocket.ts
@@ -1,158 +1,39 @@
import React, { useEffect, useState } from "react";
import { RootStore } from "../redux/store";
import { useSelector } from "react-redux";
+import ContinueNotebookClientProtocol from "./useContinueNotebookProtocol";
import { postVscMessage } from "../vscode";
-abstract class Messenger {
- abstract send(data: string): void;
-}
-
-class VscodeMessenger extends Messenger {
- url: string;
-
- constructor(
- url: string,
- onMessage: (message: { data: any }) => void,
- onOpen: (messenger: Messenger) => void,
- onClose: (messenger: Messenger) => void
- ) {
- super();
- this.url = url;
- window.addEventListener("message", (event: any) => {
- switch (event.data.type) {
- case "websocketForwardingMessage":
- onMessage(event.data);
- break;
- case "websocketForwardingOpen":
- onOpen(this);
- break;
- case "websocketForwardingClose":
- onClose(this);
- break;
- }
- });
-
- postVscMessage("websocketForwardingOpen", { url: this.url });
- }
-
- send(data: string) {
- postVscMessage("websocketForwardingMessage", {
- message: data,
- url: this.url,
- });
- }
-}
-
-class WebsocketMessenger extends Messenger {
- websocket: WebSocket;
- constructor(
- websocket: WebSocket,
- onMessage: (message: { data: any }) => void,
- onOpen: (messenger: Messenger) => void,
- onClose: (messenger: Messenger) => void
- ) {
- super();
- this.websocket = websocket;
-
- websocket.addEventListener("close", () => {
- onClose(this);
- });
-
- websocket.addEventListener("open", () => {
- onOpen(this);
- });
-
- websocket.addEventListener("message", (event) => {
- onMessage(event.data);
- });
- }
-
- static async connect(
- url: string,
- sessionId: string,
- onMessage: (message: { data: any }) => void,
- onOpen: (messenger: Messenger) => void,
- onClose: (messenger: Messenger) => void
- ): Promise<WebsocketMessenger> {
- const ws = new WebSocket(url);
-
- return new Promise((resolve, reject) => {
- ws.addEventListener("open", () => {
- resolve(new WebsocketMessenger(ws, onMessage, onOpen, onClose));
- });
- });
- }
-
- send(data: string) {
- this.websocket.send(JSON.stringify(data));
- }
-}
-
-function useContinueWebsocket(
- serverUrl: string,
- onMessage: (message: { data: any }) => void,
- useVscodeMessagePassing: boolean = true
-) {
+function useContinueNotebookProtocol(useVscodeMessagePassing: boolean = false) {
const sessionId = useSelector((state: RootStore) => state.config.sessionId);
- const [websocket, setWebsocket] = useState<Messenger | undefined>(undefined);
+ const serverHttpUrl = useSelector((state: RootStore) => state.config.apiUrl);
+ const [client, setClient] = useState<
+ ContinueNotebookClientProtocol | undefined
+ >(undefined);
- async function connect() {
- while (!sessionId) {
- await new Promise((resolve) => setTimeout(resolve, 300));
+ useEffect(() => {
+ if (!sessionId || !serverHttpUrl) {
+ if (useVscodeMessagePassing) {
+ postVscMessage("onLoad", {});
+ }
+ setClient(undefined);
+ return;
}
- console.log("Creating websocket", sessionId);
- console.log("Using vscode message passing", useVscodeMessagePassing);
-
- const onClose = (messenger: Messenger) => {
- console.log("Websocket closed");
- setWebsocket(undefined);
- };
-
- const onOpen = (messenger: Messenger) => {
- console.log("Websocket opened");
- messenger.send(JSON.stringify({ sessionId }));
- };
-
- const url =
- serverUrl.replace("http", "ws") +
+ const serverUrlWithSessionId =
+ serverHttpUrl.replace("http", "ws") +
"/notebook/ws?session_id=" +
encodeURIComponent(sessionId);
- const messenger: Messenger = useVscodeMessagePassing
- ? new VscodeMessenger(url, onMessage, onOpen, onClose)
- : await WebsocketMessenger.connect(
- url,
- sessionId,
- onMessage,
- onOpen,
- onClose
- );
-
- setWebsocket(messenger);
-
- return messenger;
- }
-
- async function getConnection() {
- if (!websocket) {
- return await connect();
- }
- return websocket;
- }
-
- async function send(message: object) {
- let ws = await getConnection();
- ws.send(JSON.stringify(message));
- }
-
- useEffect(() => {
- if (!sessionId) {
- return;
- }
- connect();
- }, [sessionId]);
-
- return { send };
+ console.log("Creating websocket", serverUrlWithSessionId);
+ console.log("Using vscode message passing", useVscodeMessagePassing);
+ const newClient = new ContinueNotebookClientProtocol(
+ serverUrlWithSessionId,
+ useVscodeMessagePassing
+ );
+ setClient(newClient);
+ }, [sessionId, serverHttpUrl]);
+
+ return client;
}
-export default useContinueWebsocket;
+export default useContinueNotebookProtocol;
diff --git a/extension/react-app/src/hooks/vscodeMessenger.ts b/extension/react-app/src/hooks/vscodeMessenger.ts
new file mode 100644
index 00000000..746c4302
--- /dev/null
+++ b/extension/react-app/src/hooks/vscodeMessenger.ts
@@ -0,0 +1,68 @@
+import { postVscMessage } from "../vscode";
+// import { Messenger } from "../../../src/util/messenger";
+import { Messenger } from "./messenger";
+
+export class VscodeMessenger extends Messenger {
+ serverUrl: string;
+
+ constructor(serverUrl: string) {
+ super();
+ this.serverUrl = serverUrl;
+ postVscMessage("websocketForwardingOpen", { url: this.serverUrl });
+ }
+
+ send(messageType: string, data: object) {
+ postVscMessage("websocketForwardingMessage", {
+ message: { messageType, data },
+ url: this.serverUrl,
+ });
+ }
+
+ onMessageType(messageType: string, callback: (data: object) => void): void {
+ window.addEventListener("message", (event: any) => {
+ if (event.data.type === "websocketForwardingMessage") {
+ if (event.data.message.messageType === messageType) {
+ callback(event.data.message.data);
+ }
+ }
+ });
+ }
+
+ onMessage(callback: (messageType: string, data: any) => void): void {
+ window.addEventListener("message", (event: any) => {
+ if (event.data.type === "websocketForwardingMessage") {
+ callback(event.data.message.messageType, event.data.message.data);
+ }
+ });
+ }
+
+ sendAndReceive(messageType: string, data: any): Promise<any> {
+ return new Promise((resolve) => {
+ const handler = (event: any) => {
+ if (event.data.type === "websocketForwardingMessage") {
+ if (event.data.message.messageType === messageType) {
+ window.removeEventListener("message", handler);
+ resolve(event.data.message.data);
+ }
+ }
+ };
+ window.addEventListener("message", handler);
+ this.send(messageType, data);
+ });
+ }
+
+ onOpen(callback: () => void): void {
+ window.addEventListener("message", (event: any) => {
+ if (event.data.type === "websocketForwardingOpen") {
+ callback();
+ }
+ });
+ }
+ onClose(callback: () => void): void {
+ window.addEventListener("message", (event: any) => {
+ if (event.data.type === "websocketForwardingClose") {
+ callback();
+ }
+ });
+ }
+}
diff --git a/extension/react-app/src/tabs/notebook.tsx b/extension/react-app/src/tabs/notebook.tsx
index a9c69c5b..02c9ff31 100644
--- a/extension/react-app/src/tabs/notebook.tsx
+++ b/extension/react-app/src/tabs/notebook.tsx
@@ -14,6 +14,7 @@ import StepContainer from "../components/StepContainer";
import { useSelector } from "react-redux";
import { RootStore } from "../redux/store";
import useContinueWebsocket from "../hooks/useWebsocket";
+import useContinueNotebookProtocol from "../hooks/useWebsocket";
let TopNotebookDiv = styled.div`
display: grid;
@@ -33,8 +34,6 @@ interface NotebookProps {
}
function Notebook(props: NotebookProps) {
- const serverUrl = useSelector((state: RootStore) => state.config.apiUrl);
-
const [waitingForSteps, setWaitingForSteps] = useState(false);
const [userInputQueue, setUserInputQueue] = useState<string[]>([]);
const [history, setHistory] = useState<History | undefined>();
@@ -157,30 +156,17 @@ function Notebook(props: NotebookProps) {
// } as any
// );
- const { send: websocketSend } = useContinueWebsocket(serverUrl, (msg) => {
- let data = JSON.parse(msg.data);
- if (data.messageType === "state") {
- setWaitingForSteps(data.state.active);
- setHistory(data.state.history);
- setUserInputQueue(data.state.user_input_queue);
- }
- });
+ const client = useContinueNotebookProtocol();
- // useEffect(() => {
- // (async () => {
- // if (sessionId && props.firstObservation) {
- // let resp = await fetch(serverUrl + "/observation", {
- // method: "POST",
- // headers: new Headers({
- // "x-continue-session-id": sessionId,
- // }),
- // body: JSON.stringify({
- // observation: props.firstObservation,
- // }),
- // });
- // }
- // })();
- // }, [props.firstObservation]);
+ useEffect(() => {
+ console.log("CLIENT ON STATE UPDATE: ", client, client?.onStateUpdate);
+ client?.onStateUpdate((state) => {
+ console.log("Received state update: ", state);
+ setWaitingForSteps(state.active);
+ setHistory(state.history);
+ setUserInputQueue(state.user_input_queue);
+ });
+ }, [client]);
const mainTextInputRef = useRef<HTMLTextAreaElement>(null);
@@ -201,14 +187,12 @@ function Notebook(props: NotebookProps) {
const onMainTextInput = () => {
if (mainTextInputRef.current) {
- let value = mainTextInputRef.current.value;
+ if (!client) return;
+ let input = mainTextInputRef.current.value;
setWaitingForSteps(true);
- websocketSend({
- messageType: "main_input",
- value: value,
- });
+ client.sendMainInput(input);
setUserInputQueue((queue) => {
- return [...queue, value];
+ return [...queue, input];
});
mainTextInputRef.current.value = "";
mainTextInputRef.current.style.height = "";
@@ -216,17 +200,20 @@ function Notebook(props: NotebookProps) {
};
const onStepUserInput = (input: string, index: number) => {
+ if (!client) return;
console.log("Sending step user input", input, index);
- websocketSend({
- messageType: "step_user_input",
- value: input,
- index,
- });
+ client.sendStepUserInput(input, index);
};
// const iterations = useSelector(selectIterations);
return (
<TopNotebookDiv>
+ {typeof client === "undefined" && (
+ <>
+ <Loader></Loader>
+ <p>Server disconnected</p>
+ </>
+ )}
{history?.timeline.map((node: HistoryNode, index: number) => {
return (
<StepContainer
@@ -237,17 +224,10 @@ function Notebook(props: NotebookProps) {
inFuture={index > history?.current_index}
historyNode={node}
onRefinement={(input: string) => {
- websocketSend({
- messageType: "refinement_input",
- value: input,
- index,
- });
+ client?.sendRefinementInput(input, index);
}}
onReverse={() => {
- websocketSend({
- messageType: "reverse",
- index,
- });
+ client?.reverseToIndex(index);
}}
/>
);
diff --git a/extension/react-app/src/vscode/index.ts b/extension/react-app/src/vscode/index.ts
index 7e373cd9..0785aa4d 100644
--- a/extension/react-app/src/vscode/index.ts
+++ b/extension/react-app/src/vscode/index.ts
@@ -5,6 +5,7 @@ declare const vscode: any;
export function postVscMessage(type: string, data: any) {
if (typeof vscode === "undefined") {
+ console.log("Unable to send message: vscode is undefined");
return;
}
vscode.postMessage({
diff --git a/extension/react-app/tsconfig.json b/extension/react-app/tsconfig.json
index 3d0a51a8..940a9359 100644
--- a/extension/react-app/tsconfig.json
+++ b/extension/react-app/tsconfig.json
@@ -16,6 +16,6 @@
"noEmit": true,
"jsx": "react-jsx"
},
- "include": ["src"],
+ "include": ["src", "../src/util/messenger.ts"],
"references": [{ "path": "./tsconfig.node.json" }]
}
diff --git a/extension/scripts/continuedev-0.1.0-py3-none-any.whl b/extension/scripts/continuedev-0.1.0-py3-none-any.whl
index d1483db9..2019c904 100644
--- a/extension/scripts/continuedev-0.1.0-py3-none-any.whl
+++ b/extension/scripts/continuedev-0.1.0-py3-none-any.whl
Binary files differ
diff --git a/extension/src/activation/activate.ts b/extension/src/activation/activate.ts
index a0aa560b..712ffe13 100644
--- a/extension/src/activation/activate.ts
+++ b/extension/src/activation/activate.ts
@@ -10,7 +10,7 @@ import { getContinueServerUrl } from "../bridge";
export let extensionContext: vscode.ExtensionContext | undefined = undefined;
-export let ideProtocolClient: IdeProtocolClient | undefined = undefined;
+export let ideProtocolClient: IdeProtocolClient;
export function activateExtension(
context: vscode.ExtensionContext,
@@ -24,7 +24,7 @@ export function activateExtension(
let serverUrl = getContinueServerUrl();
ideProtocolClient = new IdeProtocolClient(
- serverUrl.replace("http", "ws") + "/ide/ws",
+ `${serverUrl.replace("http", "ws")}/ide/ws`,
context
);
diff --git a/extension/src/activation/environmentSetup.ts b/extension/src/activation/environmentSetup.ts
index 93a471ff..ad6ac71b 100644
--- a/extension/src/activation/environmentSetup.ts
+++ b/extension/src/activation/environmentSetup.ts
@@ -28,18 +28,7 @@ async function runCommand(cmd: string): Promise<[string, string | undefined]> {
return [stdout, stderr];
}
-async function getPythonCmdAssumingInstalled() {
- const [, stderr] = await runCommand("python3 --version");
- if (stderr) {
- return "python";
- }
- return "python3";
-}
-
-async function setupPythonEnv() {
- console.log("Setting up python env for Continue extension...");
- // First check that python3 is installed
-
+async function getPythonPipCommands() {
var [stdout, stderr] = await runCommand("python3 --version");
let pythonCmd = "python3";
if (stderr) {
@@ -58,28 +47,77 @@ async function setupPythonEnv() {
}
}
let pipCmd = pythonCmd.endsWith("3") ? "pip3" : "pip";
+ return [pythonCmd, pipCmd];
+}
+function getActivateUpgradeCommands(pythonCmd: string, pipCmd: string) {
let activateCmd = ". env/bin/activate";
let pipUpgradeCmd = `${pipCmd} install --upgrade pip`;
if (process.platform == "win32") {
activateCmd = ".\\env\\Scripts\\activate";
pipUpgradeCmd = `${pythonCmd} -m pip install --upgrade pip`;
}
+ return [activateCmd, pipUpgradeCmd];
+}
- let command = `cd ${path.join(
+function checkEnvExists() {
+ const envBinActivatePath = path.join(
getExtensionUri().fsPath,
- "scripts"
- )} && ${pythonCmd} -m venv env && ${activateCmd} && ${pipUpgradeCmd} && ${pipCmd} install -r requirements.txt`;
- var [stdout, stderr] = await runCommand(command);
- if (stderr) {
- throw new Error(stderr);
+ "scripts",
+ "env",
+ "bin",
+ "activate"
+ );
+ return fs.existsSync(envBinActivatePath);
+}
+
+async function setupPythonEnv() {
+ console.log("Setting up python env for Continue extension...");
+
+ // Assemble the command to create the env
+ const [pythonCmd, pipCmd] = await getPythonPipCommands();
+ const [activateCmd, pipUpgradeCmd] = getActivateUpgradeCommands(
+ pythonCmd,
+ pipCmd
+ );
+ const createEnvCommand = [
+ `cd ${path.join(getExtensionUri().fsPath, "scripts")}`,
+ `${pythonCmd} -m venv env`,
+ ].join(" && ");
+
+ // Repeat until it is successfully created (sometimes it fails to generate the bin, need to try again)
+ while (true) {
+ const [, stderr] = await runCommand(createEnvCommand);
+ if (stderr) {
+ throw new Error(stderr);
+ }
+ if (checkEnvExists()) {
+ break;
+ } else {
+ // Remove the env and try again
+ const removeCommand = `rm -rf ${path.join(
+ getExtensionUri().fsPath,
+ "scripts",
+ "env"
+ )}`;
+ await runCommand(removeCommand);
+ }
}
console.log(
"Successfully set up python env at ",
getExtensionUri().fsPath + "/scripts/env"
);
- await startContinuePythonServer();
+ const installRequirementsCommand = [
+ `cd ${path.join(getExtensionUri().fsPath, "scripts")}`,
+ activateCmd,
+ pipUpgradeCmd,
+ `${pipCmd} install -r requirements.txt`,
+ ].join(" && ");
+ const [, stderr] = await runCommand(installRequirementsCommand);
+ if (stderr) {
+ throw new Error(stderr);
+ }
}
function readEnvFile(path: string) {
@@ -116,29 +154,19 @@ function writeEnvFile(path: string, key: string, value: string) {
}
export async function startContinuePythonServer() {
+ await setupPythonEnv();
+
// Check vscode settings
let serverUrl = getContinueServerUrl();
if (serverUrl !== "http://localhost:8000") {
return;
}
- let envFile = path.join(getExtensionUri().fsPath, "scripts", ".env");
- let openai_api_key: string | undefined =
- readEnvFile(envFile)["OPENAI_API_KEY"];
- while (typeof openai_api_key === "undefined" || openai_api_key === "") {
- openai_api_key = await vscode.window.showInputBox({
- prompt: "Enter your OpenAI API key",
- placeHolder: "Enter your OpenAI API key",
- });
- // Write to .env file
- }
- writeEnvFile(envFile, "OPENAI_API_KEY", openai_api_key);
-
console.log("Starting Continue python server...");
// Check if already running by calling /health
try {
- let response = await fetch(serverUrl + "/health");
+ const response = await fetch(serverUrl + "/health");
if (response.status === 200) {
console.log("Continue python server already running");
return;
@@ -152,15 +180,18 @@ export async function startContinuePythonServer() {
pythonCmd = "python";
}
+ // let command = `cd ${path.join(
+ // getExtensionUri().fsPath,
+ // "scripts"
+ // )} && ${activateCmd} && cd env/lib/python3.11/site-packages && ${pythonCmd} -m continuedev.server.main`;
let command = `cd ${path.join(
getExtensionUri().fsPath,
"scripts"
)} && ${activateCmd} && cd .. && ${pythonCmd} -m scripts.run_continue_server`;
try {
// exec(command);
- let child = spawn(command, {
+ const child = spawn(command, {
shell: true,
- detached: true,
});
child.stdout.on("data", (data: any) => {
console.log(`stdout: ${data}`);
@@ -194,11 +225,6 @@ export function isPythonEnvSetup(): boolean {
return fs.existsSync(path.join(pathToEnvCfg));
}
-export async function setupExtensionEnvironment() {
- console.log("Setting up environment for Continue extension...");
- await Promise.all([setupPythonEnv()]);
-}
-
export async function downloadPython3() {
// Download python3 and return the command to run it (python or python3)
let os = process.platform;
diff --git a/extension/src/commands.ts b/extension/src/commands.ts
index 18f08e31..aeeb4b4f 100644
--- a/extension/src/commands.ts
+++ b/extension/src/commands.ts
@@ -62,11 +62,11 @@ const commandsMap: { [command: string]: (...args: any) => any } = {
"continue.acceptSuggestion": acceptSuggestionCommand,
"continue.rejectSuggestion": rejectSuggestionCommand,
"continue.openDebugPanel": () => {
- ideProtocolClient?.openNotebook();
+ ideProtocolClient.openNotebook();
},
"continue.focusContinueInput": async () => {
if (!debugPanelWebview) {
- await ideProtocolClient?.openNotebook();
+ await ideProtocolClient.openNotebook();
}
debugPanelWebview?.postMessage({
type: "focusContinueInput",
diff --git a/extension/src/continueIdeClient.ts b/extension/src/continueIdeClient.ts
index 35eb668d..477d1420 100644
--- a/extension/src/continueIdeClient.ts
+++ b/extension/src/continueIdeClient.ts
@@ -10,30 +10,28 @@ import {
} from "./suggestions";
import { debugPanelWebview, setupDebugPanel } from "./debugPanel";
import { FileEditWithFullContents } from "../schema/FileEditWithFullContents";
-const util = require("util");
-const exec = util.promisify(require("child_process").exec);
-const WebSocket = require("ws");
import fs = require("fs");
+import { WebsocketMessenger } from "./util/messenger";
class IdeProtocolClient {
- private _ws: WebSocket | null = null;
- private _panels: Map<string, vscode.WebviewPanel> = new Map();
- private readonly _serverUrl: string;
- private readonly _context: vscode.ExtensionContext;
+ private messenger: WebsocketMessenger | null = null;
+ private panels: Map<string, vscode.WebviewPanel> = new Map();
+ private readonly context: vscode.ExtensionContext;
private _makingEdit = 0;
constructor(serverUrl: string, context: vscode.ExtensionContext) {
- this._context = context;
- this._serverUrl = serverUrl;
- let ws = new WebSocket(serverUrl);
- this._ws = ws;
- ws.onclose = () => {
- this._ws = null;
- };
- ws.on("message", (data: any) => {
- this.handleMessage(JSON.parse(data));
+ this.context = context;
+
+ let messenger = new WebsocketMessenger(serverUrl);
+ this.messenger = messenger;
+ messenger.onClose(() => {
+ this.messenger = null;
+ });
+ messenger.onMessage((messageType, data) => {
+ this.handleMessage(messageType, data);
});
+
// Setup listeners for any file changes in open editors
vscode.workspace.onDidChangeTextDocument((event) => {
if (this._makingEdit === 0) {
@@ -58,125 +56,52 @@ class IdeProtocolClient {
};
}
);
- this.send("fileEdits", { fileEdits });
+ this.messenger?.send("fileEdits", { fileEdits });
} else {
this._makingEdit--;
}
});
}
- async isConnected() {
- if (this._ws === null || this._ws.readyState !== WebSocket.OPEN) {
- let ws = new WebSocket(this._serverUrl);
- ws.onclose = () => {
- this._ws = null;
- };
- ws.on("message", (data: any) => {
- this.handleMessage(JSON.parse(data));
- });
- this._ws = ws;
-
- return new Promise((resolve, reject) => {
- ws.addEventListener("open", () => {
- resolve(null);
- });
- });
- }
- }
-
- async startCore() {
- var { stdout, stderr } = await exec(
- "cd /Users/natesesti/Desktop/continue/continue && poetry shell"
- );
- if (stderr) {
- throw new Error(stderr);
- }
- var { stdout, stderr } = await exec(
- "cd .. && uvicorn continue.src.server.main:app --reload --reload-dir continue"
- );
- if (stderr) {
- throw new Error(stderr);
- }
- var { stdout, stderr } = await exec("python3 -m continue.src.libs.ide");
- if (stderr) {
- throw new Error(stderr);
- }
- }
-
- async send(messageType: string, data: object) {
- await this.isConnected();
- let msg = JSON.stringify({ messageType, ...data });
- this._ws!.send(msg);
- console.log("Sent message", msg);
- }
-
- async receiveMessage(messageType: string): Promise<any> {
- await this.isConnected();
- console.log("Connected to websocket");
- return await new Promise((resolve, reject) => {
- if (!this._ws) {
- reject("Not connected to websocket");
- }
- this._ws!.onmessage = (event: any) => {
- let message = JSON.parse(event.data);
- console.log("RECEIVED MESSAGE", message);
- if (message.messageType === messageType) {
- resolve(message);
- }
- };
- });
- }
-
- async sendAndReceive(message: any, messageType: string): Promise<any> {
- try {
- await this.send(messageType, message);
- let msg = await this.receiveMessage(messageType);
- console.log("Received message", msg);
- return msg;
- } catch (e) {
- console.log("Error sending message", e);
- }
- }
-
- async handleMessage(message: any) {
- switch (message.messageType) {
+ async handleMessage(messageType: string, data: any) {
+ switch (messageType) {
case "highlightedCode":
- this.send("highlightedCode", {
+ this.messenger?.send("highlightedCode", {
highlightedCode: this.getHighlightedCode(),
});
break;
case "workspaceDirectory":
- this.send("workspaceDirectory", {
+ this.messenger?.send("workspaceDirectory", {
workspaceDirectory: this.getWorkspaceDirectory(),
});
case "openFiles":
- this.send("openFiles", {
+ this.messenger?.send("openFiles", {
openFiles: this.getOpenFiles(),
});
break;
case "readFile":
- this.send("readFile", {
- contents: this.readFile(message.filepath),
+ this.messenger?.send("readFile", {
+ contents: this.readFile(data.filepath),
});
break;
case "editFile":
- let fileEdit = await this.editFile(message.edit);
- this.send("editFile", {
+ const fileEdit = await this.editFile(data.edit);
+ this.messenger?.send("editFile", {
fileEdit,
});
break;
case "saveFile":
- this.saveFile(message.filepath);
+ this.saveFile(data.filepath);
break;
case "setFileOpen":
- this.openFile(message.filepath);
+ this.openFile(data.filepath);
// TODO: Close file
break;
case "openNotebook":
case "connected":
break;
default:
- throw Error("Unknown message type:" + message.messageType);
+ throw Error("Unknown message type:" + messageType);
}
}
getWorkspaceDirectory() {
@@ -209,17 +134,20 @@ class IdeProtocolClient {
// Initiate Request
closeNotebook(sessionId: string) {
- this._panels.get(sessionId)?.dispose();
- this._panels.delete(sessionId);
+ this.panels.get(sessionId)?.dispose();
+ this.panels.delete(sessionId);
}
async openNotebook() {
console.log("OPENING NOTEBOOK");
- let resp = await this.sendAndReceive({}, "openNotebook");
- let sessionId = resp.sessionId;
+ if (this.messenger === null) {
+ console.log("MESSENGER IS NULL");
+ }
+ const resp = await this.messenger?.sendAndReceive("openNotebook", {});
+ const sessionId = resp.sessionId;
console.log("SESSION ID", sessionId);
- let column = getRightViewColumn();
+ const column = getRightViewColumn();
const panel = vscode.window.createWebviewPanel(
"continue.debugPanelView",
"Continue",
@@ -231,9 +159,9 @@ class IdeProtocolClient {
);
// And set its HTML content
- panel.webview.html = setupDebugPanel(panel, this._context, sessionId);
+ panel.webview.html = setupDebugPanel(panel, this.context, sessionId);
- this._panels.set(sessionId, panel);
+ this.panels.set(sessionId, panel);
}
acceptRejectSuggestion(accept: boolean, key: SuggestionRanges) {
diff --git a/extension/src/debugPanel.ts b/extension/src/debugPanel.ts
index 4192595c..a295085f 100644
--- a/extension/src/debugPanel.ts
+++ b/extension/src/debugPanel.ts
@@ -16,7 +16,6 @@ import {
import { sendTelemetryEvent, TelemetryEvent } from "./telemetry";
import { RangeInFile, SerializedDebugContext } from "./client";
import { addFileSystemToDebugContext } from "./util/util";
-const WebSocket = require("ws");
class StreamManager {
private _fullText: string = "";
@@ -108,15 +107,15 @@ class WebsocketConnection {
this._onOpen = onOpen;
this._onClose = onClose;
- this._ws.onmessage = (event) => {
+ this._ws.addEventListener("message", (event) => {
this._onMessage(event.data);
- };
- this._ws.onclose = () => {
+ });
+ this._ws.addEventListener("close", () => {
this._onClose();
- };
- this._ws.onopen = () => {
+ });
+ this._ws.addEventListener("open", () => {
this._onOpen();
- };
+ });
}
public send(message: string) {
@@ -230,6 +229,19 @@ export function setupDebugPanel(
apiUrl: getContinueServerUrl(),
sessionId,
});
+
+ // // Listen for changes to server URL in settings
+ // vscode.workspace.onDidChangeConfiguration((event) => {
+ // if (event.affectsConfiguration("continue.serverUrl")) {
+ // debugPanelWebview?.postMessage({
+ // type: "onLoad",
+ // vscMachineId: vscode.env.machineId,
+ // apiUrl: getContinueServerUrl(),
+ // sessionId,
+ // });
+ // }
+ // });
+
break;
}
diff --git a/extension/src/extension.ts b/extension/src/extension.ts
index e0b94278..88af0d19 100644
--- a/extension/src/extension.ts
+++ b/extension/src/extension.ts
@@ -4,7 +4,6 @@
import * as vscode from "vscode";
import {
- setupExtensionEnvironment,
isPythonEnvSetup,
startContinuePythonServer,
} from "./activation/environmentSetup";
@@ -26,11 +25,7 @@ export function activate(context: vscode.ExtensionContext) {
cancellable: false,
},
async () => {
- if (isPythonEnvSetup()) {
- await startContinuePythonServer();
- } else {
- await setupExtensionEnvironment();
- }
+ await startContinuePythonServer();
dynamicImportAndActivate(context, true);
}
);
diff --git a/extension/src/test/runTest.ts b/extension/src/test/runTest.ts
index 27b3ceb2..e810ed5b 100644
--- a/extension/src/test/runTest.ts
+++ b/extension/src/test/runTest.ts
@@ -1,23 +1,23 @@
-import * as path from 'path';
+import * as path from "path";
-import { runTests } from '@vscode/test-electron';
+import { runTests } from "@vscode/test-electron";
async function main() {
- try {
- // The folder containing the Extension Manifest package.json
- // Passed to `--extensionDevelopmentPath`
- const extensionDevelopmentPath = path.resolve(__dirname, '../../');
+ try {
+ // The folder containing the Extension Manifest package.json
+ // Passed to `--extensionDevelopmentPath`
+ const extensionDevelopmentPath = path.resolve(__dirname, "../../");
- // The path to test runner
- // Passed to --extensionTestsPath
- const extensionTestsPath = path.resolve(__dirname, './suite/index');
+ // The path to test runner
+ // Passed to --extensionTestsPath
+ const extensionTestsPath = path.resolve(__dirname, "./suite/index");
- // Download VS Code, unzip it and run the integration test
- await runTests({ extensionDevelopmentPath, extensionTestsPath });
- } catch (err) {
- console.error('Failed to run tests');
- process.exit(1);
- }
+ // Download VS Code, unzip it and run the integration test
+ await runTests({ extensionDevelopmentPath, extensionTestsPath });
+ } catch (err) {
+ console.error("Failed to run tests");
+ process.exit(1);
+ }
}
main();
diff --git a/extension/src/util/messenger.ts b/extension/src/util/messenger.ts
new file mode 100644
index 00000000..6f8bb29d
--- /dev/null
+++ b/extension/src/util/messenger.ts
@@ -0,0 +1,108 @@
+console.log("Websocket import");
+const WebSocket = require("ws");
+
+export abstract class Messenger {
+ abstract send(messageType: string, data: object): void;
+
+ abstract onMessageType(
+ messageType: string,
+ callback: (data: object) => void
+ ): void;
+
+ abstract onMessage(callback: (messageType: string, data: any) => void): void;
+
+ abstract onOpen(callback: () => void): void;
+
+ abstract onClose(callback: () => void): void;
+
+ abstract sendAndReceive(messageType: string, data: any): Promise<any>;
+}
+
+export class WebsocketMessenger extends Messenger {
+ websocket: WebSocket;
+ private onMessageListeners: {
+ [messageType: string]: ((data: object) => void)[];
+ } = {};
+ private onOpenListeners: (() => void)[] = [];
+ private onCloseListeners: (() => void)[] = [];
+ private serverUrl: string;
+
+ _newWebsocket(): WebSocket {
+ // // Dynamic import, because WebSocket is builtin with browser, but not with node. And can't use require in browser.
+ // if (typeof process === "object") {
+ // console.log("Using node");
+ // // process is only available in Node
+ // var WebSocket = require("ws");
+ // }
+
+ const newWebsocket = new WebSocket(this.serverUrl);
+ for (const listener of this.onOpenListeners) {
+ this.onOpen(listener);
+ }
+ for (const listener of this.onCloseListeners) {
+ this.onClose(listener);
+ }
+ for (const messageType in this.onMessageListeners) {
+ for (const listener of this.onMessageListeners[messageType]) {
+ this.onMessageType(messageType, listener);
+ }
+ }
+ return newWebsocket;
+ }
+
+ constructor(serverUrl: string) {
+ super();
+ this.serverUrl = serverUrl;
+ this.websocket = this._newWebsocket();
+ }
+
+ send(messageType: string, data: object) {
+ const payload = JSON.stringify({ messageType, data });
+ if (this.websocket.readyState === this.websocket.OPEN) {
+ this.websocket.send(payload);
+ } else {
+ if (this.websocket.readyState !== this.websocket.CONNECTING) {
+ this.websocket = this._newWebsocket();
+ }
+ this.websocket.addEventListener("open", () => {
+ this.websocket.send(payload);
+ });
+ }
+ }
+
+ sendAndReceive(messageType: string, data: any): Promise<any> {
+ return new Promise((resolve, reject) => {
+ const eventListener = (data: any) => {
+ // THIS ISN"T GETTING CALLED
+ resolve(data);
+ this.websocket.removeEventListener("message", eventListener);
+ };
+ this.onMessageType(messageType, eventListener);
+ this.send(messageType, data);
+ });
+ }
+
+ onMessageType(messageType: string, callback: (data: any) => void): void {
+ this.websocket.addEventListener("message", (event: any) => {
+ const msg = JSON.parse(event.data);
+ if (msg.messageType === messageType) {
+ callback(msg.data);
+ }
+ });
+ }
+
+ onMessage(callback: (messageType: string, data: any) => void): void {
+ this.websocket.addEventListener("message", (event) => {
+ const msg = JSON.parse(event.data);
+ callback(msg.messageType, msg.data);
+ });
+ }
+
+ onOpen(callback: () => void): void {
+ this.websocket.addEventListener("open", callback);
+ }
+
+ onClose(callback: () => void): void {
+ this.websocket.addEventListener("close", callback);
+ }
+}