diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-06-01 00:09:19 -0400 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-06-01 00:09:19 -0400 |
commit | ea5d50af9ba84242c25e82069d86c08ac039e543 (patch) | |
tree | 29b99530ca6df54f5bc641be667c6820ff7e4014 /continuedev/src | |
parent | e6dded34c26fd17ede17776755cc41c26782a045 (diff) | |
download | sncontinue-ea5d50af9ba84242c25e82069d86c08ac039e543.tar.gz sncontinue-ea5d50af9ba84242c25e82069d86c08ac039e543.tar.bz2 sncontinue-ea5d50af9ba84242c25e82069d86c08ac039e543.zip |
Polishing for dlt codespace and !config!
Diffstat (limited to 'continuedev/src')
-rw-r--r-- | continuedev/src/continuedev/core/config.py | 29 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/policy.py | 9 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 17 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/steps/draft/dlt.py | 4 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/steps/main.py | 20 | ||||
-rw-r--r-- | continuedev/src/continuedev/libs/steps/steps_on_startup.py | 30 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/main.py | 8 | ||||
-rw-r--r-- | continuedev/src/continuedev/server/notebook.py | 2 |
8 files changed, 112 insertions, 7 deletions
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py new file mode 100644 index 00000000..e62f0e4f --- /dev/null +++ b/continuedev/src/continuedev/core/config.py @@ -0,0 +1,29 @@ +import json +import os +from pydantic import BaseModel +from typing import List, Optional, Dict +import yaml + + +class ContinueConfig(BaseModel): + """ + A pydantic class for the continue config file. + """ + steps_on_startup: Optional[Dict[str, Dict]] = {} + server_url: Optional[str] = None + + +def load_config(config_file: str) -> ContinueConfig: + """ + Load the config file and return a ContinueConfig object. + """ + _, ext = os.path.splitext(config_file) + if ext == '.yaml': + with open(config_file, 'r') as f: + config_dict = yaml.safe_load(f) + elif ext == '.json': + with open(config_file, 'r') as f: + config_dict = json.load(f) + else: + raise ValueError(f'Unknown config file extension: {ext}') + return ContinueConfig(**config_dict) diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py index c0ba0f4f..9f68515f 100644 --- a/continuedev/src/continuedev/core/policy.py +++ b/continuedev/src/continuedev/core/policy.py @@ -1,9 +1,10 @@ from typing import List, Tuple, Type -from ..libs.steps.ty import CreatePipelineStep +from ..libs.steps.steps_on_startup import StepsOnStartupStep +from ..libs.steps.draft.dlt import CreatePipelineStep from .main import Step, Validator, History, Policy from .observation import Observation, TracebackObservation, UserInputObservation -from ..libs.steps.main import EditHighlightedCodeStep, SolveTracebackStep, RunCodeStep, FasterEditHighlightedCodeStep, StarCoderEditHighlightedCodeStep +from ..libs.steps.main import EditHighlightedCodeStep, SolveTracebackStep, RunCodeStep, FasterEditHighlightedCodeStep, StarCoderEditHighlightedCodeStep, MessageStep, EmptyStep from ..libs.steps.nate import WritePytestsStep, CreateTableStep # from ..libs.steps.chroma import AnswerQuestionChroma, EditFileChroma from ..libs.steps.continue_step import ContinueStepStep @@ -13,6 +14,10 @@ class DemoPolicy(Policy): ran_code_last: bool = False def next(self, history: History) -> Step: + # At the very start, run initial Steps spcecified in the config + if history.get_current() is None: + return MessageStep(message="Welcome to Continue!") >> StepsOnStartupStep() + observation = history.last_observation() if observation is not None and isinstance(observation, UserInputObservation): # This could be defined with ObservationTypePolicy. Ergonomics not right though. diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 6ae0be04..4d82a1ae 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -1,5 +1,7 @@ import os from typing import Coroutine, Union + +from .config import ContinueConfig, load_config from ..models.filesystem_edit import FileSystemEdit, AddFile, DeleteFile, AddDirectory, DeleteDirectory from ..models.filesystem import RangeInFile from ..libs.llm import LLM @@ -106,3 +108,18 @@ class ContinueSDK: val = (await self.run_step(WaitForUserInputStep(prompt=prompt))).text save_env_var(env_var, val) return val + + async def get_config(self) -> ContinueConfig: + dir = await self.ide.getWorkspaceDirectory() + yaml_path = os.path.join(dir, 'continue.yaml') + json_path = os.path.join(dir, 'continue.json') + if os.path.exists(yaml_path): + return load_config(yaml_path) + elif os.path.exists(json_path): + return load_config(json_path) + else: + return ContinueConfig() + + def set_loading_message(self, message: str): + # self.__agent.set_loading_message(message) + raise NotImplementedError() diff --git a/continuedev/src/continuedev/libs/steps/draft/dlt.py b/continuedev/src/continuedev/libs/steps/draft/dlt.py index 460aa0cc..778ced1d 100644 --- a/continuedev/src/continuedev/libs/steps/draft/dlt.py +++ b/continuedev/src/continuedev/libs/steps/draft/dlt.py @@ -2,7 +2,7 @@ from textwrap import dedent from ....models.filesystem_edit import AddFile from ....core.main import Step from ....core.sdk import ContinueSDK -from ..main import WaitForUserInputStep +from ..core.core import WaitForUserInputStep class SetupPipelineStep(Step): @@ -77,6 +77,6 @@ class CreatePipelineStep(Step): async def run(self, sdk: ContinueSDK): await sdk.run_step( WaitForUserInputStep(prompt="What API do you want to load data from?") >> - SetupPipelineStep() >> + SetupPipelineStep(api_description="WeatherAPI.com API") >> ValidatePipelineStep() ) diff --git a/continuedev/src/continuedev/libs/steps/main.py b/continuedev/src/continuedev/libs/steps/main.py index 70c0d4b8..d31db0eb 100644 --- a/continuedev/src/continuedev/libs/steps/main.py +++ b/continuedev/src/continuedev/libs/steps/main.py @@ -331,3 +331,23 @@ class SolveTracebackStep(Step): await sdk.run_step(EditCodeStep( range_in_files=range_in_files, prompt=prompt)) return None + + +class MessageStep(Step): + message: str + + async def describe(self, models: Models) -> Coroutine[str, None, None]: + return self.message + + async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: + pass + + +class EmptyStep(Step): + hide: bool = True + + async def describe(self, models: Models) -> Coroutine[str, None, None]: + return "" + + async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: + pass diff --git a/continuedev/src/continuedev/libs/steps/steps_on_startup.py b/continuedev/src/continuedev/libs/steps/steps_on_startup.py new file mode 100644 index 00000000..fd1eb8f0 --- /dev/null +++ b/continuedev/src/continuedev/libs/steps/steps_on_startup.py @@ -0,0 +1,30 @@ + + +from ...core.main import ContinueSDK, Models, Step +from .main import UserInputStep +from .draft.dlt import CreatePipelineStep + + +step_name_to_step_class = { + "UserInputStep": UserInputStep, + "CreatePipelineStep": CreatePipelineStep +} + + +class StepsOnStartupStep(Step): + hide: bool = True + + async def describe(self, models: Models): + return "Running steps on startup" + + async def run(self, sdk: ContinueSDK): + steps_descriptions = (await sdk.get_config()).steps_on_startup + + for step_name, step_params in steps_descriptions.items(): + try: + step = step_name_to_step_class[step_name](**step_params) + except: + print( + f"Incorrect parameters for step {step_name}. Parameters provided were: {step_params}") + continue + await sdk.run_step(step) diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py index 1977bfdd..1ffe1450 100644 --- a/continuedev/src/continuedev/server/main.py +++ b/continuedev/src/continuedev/server/main.py @@ -1,3 +1,4 @@ +import os from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from .ide import router as ide_router @@ -32,8 +33,11 @@ args = parser.parse_args() def run_server(): - uvicorn.run(app, host="0.0.0.0", port=args.port, - log_config="logging.yaml") + if os.path.exists("logging.yaml"): + uvicorn.run(app, host="0.0.0.0", port=args.port, + log_config="logging.yaml") + else: + uvicorn.run(app, host="0.0.0.0", port=args.port) if __name__ == "__main__": diff --git a/continuedev/src/continuedev/server/notebook.py b/continuedev/src/continuedev/server/notebook.py index ab9211a8..9ca510dd 100644 --- a/continuedev/src/continuedev/server/notebook.py +++ b/continuedev/src/continuedev/server/notebook.py @@ -53,7 +53,7 @@ class NotebookProtocolServer(AbstractNotebookProtocolServer): async def _send_json(self, message_type: str, data: Any): await self.websocket.send_json({ - "message_type": message_type, + "messageType": message_type, "data": data }) |