diff options
Diffstat (limited to 'continuedev')
| -rw-r--r-- | continuedev/src/continuedev/core/config.py | 29 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/policy.py | 9 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 17 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/steps/draft/dlt.py | 4 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/steps/main.py | 20 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/steps/steps_on_startup.py | 30 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/main.py | 8 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/notebook.py | 2 | 
8 files changed, 112 insertions, 7 deletions
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py new file mode 100644 index 00000000..e62f0e4f --- /dev/null +++ b/continuedev/src/continuedev/core/config.py @@ -0,0 +1,29 @@ +import json +import os +from pydantic import BaseModel +from typing import List, Optional, Dict +import yaml + + +class ContinueConfig(BaseModel): +    """ +    A pydantic class for the continue config file. +    """ +    steps_on_startup: Optional[Dict[str, Dict]] = {} +    server_url: Optional[str] = None + + +def load_config(config_file: str) -> ContinueConfig: +    """ +    Load the config file and return a ContinueConfig object. +    """ +    _, ext = os.path.splitext(config_file) +    if ext == '.yaml': +        with open(config_file, 'r') as f: +            config_dict = yaml.safe_load(f) +    elif ext == '.json': +        with open(config_file, 'r') as f: +            config_dict = json.load(f) +    else: +        raise ValueError(f'Unknown config file extension: {ext}') +    return ContinueConfig(**config_dict) diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py index c0ba0f4f..9f68515f 100644 --- a/continuedev/src/continuedev/core/policy.py +++ b/continuedev/src/continuedev/core/policy.py @@ -1,9 +1,10 @@  from typing import List, Tuple, Type -from ..libs.steps.ty import CreatePipelineStep +from ..libs.steps.steps_on_startup import StepsOnStartupStep +from ..libs.steps.draft.dlt import CreatePipelineStep  from .main import Step, Validator, History, Policy  from .observation import Observation, TracebackObservation, UserInputObservation -from ..libs.steps.main import EditHighlightedCodeStep, SolveTracebackStep, RunCodeStep, FasterEditHighlightedCodeStep, StarCoderEditHighlightedCodeStep +from ..libs.steps.main import EditHighlightedCodeStep, SolveTracebackStep, RunCodeStep, FasterEditHighlightedCodeStep, StarCoderEditHighlightedCodeStep, MessageStep, EmptyStep  from ..libs.steps.nate import WritePytestsStep, CreateTableStep  # from ..libs.steps.chroma import AnswerQuestionChroma, EditFileChroma  from ..libs.steps.continue_step import ContinueStepStep @@ -13,6 +14,10 @@ class DemoPolicy(Policy):      ran_code_last: bool = False      def next(self, history: History) -> Step: +        # At the very start, run initial Steps spcecified in the config +        if history.get_current() is None: +            return MessageStep(message="Welcome to Continue!") >> StepsOnStartupStep() +          observation = history.last_observation()          if observation is not None and isinstance(observation, UserInputObservation):              # This could be defined with ObservationTypePolicy. Ergonomics not right though. diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 6ae0be04..4d82a1ae 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -1,5 +1,7 @@  import os  from typing import Coroutine, Union + +from .config import ContinueConfig, load_config  from ..models.filesystem_edit import FileSystemEdit, AddFile, DeleteFile, AddDirectory, DeleteDirectory  from ..models.filesystem import RangeInFile  from ..libs.llm import LLM @@ -106,3 +108,18 @@ class ContinueSDK:          val = (await self.run_step(WaitForUserInputStep(prompt=prompt))).text          save_env_var(env_var, val)          return val + +    async def get_config(self) -> ContinueConfig: +        dir = await self.ide.getWorkspaceDirectory() +        yaml_path = os.path.join(dir, 'continue.yaml') +        json_path = os.path.join(dir, 'continue.json') +        if os.path.exists(yaml_path): +            return load_config(yaml_path) +        elif os.path.exists(json_path): +            return load_config(json_path) +        else: +            return ContinueConfig() + +    def set_loading_message(self, message: str): +        # self.__agent.set_loading_message(message) +        raise NotImplementedError() diff --git a/continuedev/src/continuedev/libs/steps/draft/dlt.py b/continuedev/src/continuedev/libs/steps/draft/dlt.py index 460aa0cc..778ced1d 100644 --- a/continuedev/src/continuedev/libs/steps/draft/dlt.py +++ b/continuedev/src/continuedev/libs/steps/draft/dlt.py @@ -2,7 +2,7 @@ from textwrap import dedent  from ....models.filesystem_edit import AddFile  from ....core.main import Step  from ....core.sdk import ContinueSDK -from ..main import WaitForUserInputStep +from ..core.core import WaitForUserInputStep  class SetupPipelineStep(Step): @@ -77,6 +77,6 @@ class CreatePipelineStep(Step):      async def run(self, sdk: ContinueSDK):          await sdk.run_step(              WaitForUserInputStep(prompt="What API do you want to load data from?") >> -            SetupPipelineStep() >> +            SetupPipelineStep(api_description="WeatherAPI.com API") >>              ValidatePipelineStep()          ) diff --git a/continuedev/src/continuedev/libs/steps/main.py b/continuedev/src/continuedev/libs/steps/main.py index 70c0d4b8..d31db0eb 100644 --- a/continuedev/src/continuedev/libs/steps/main.py +++ b/continuedev/src/continuedev/libs/steps/main.py @@ -331,3 +331,23 @@ class SolveTracebackStep(Step):          await sdk.run_step(EditCodeStep(              range_in_files=range_in_files, prompt=prompt))          return None + + +class MessageStep(Step): +    message: str + +    async def describe(self, models: Models) -> Coroutine[str, None, None]: +        return self.message + +    async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: +        pass + + +class EmptyStep(Step): +    hide: bool = True + +    async def describe(self, models: Models) -> Coroutine[str, None, None]: +        return "" + +    async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: +        pass diff --git a/continuedev/src/continuedev/libs/steps/steps_on_startup.py b/continuedev/src/continuedev/libs/steps/steps_on_startup.py new file mode 100644 index 00000000..fd1eb8f0 --- /dev/null +++ b/continuedev/src/continuedev/libs/steps/steps_on_startup.py @@ -0,0 +1,30 @@ + + +from ...core.main import ContinueSDK, Models, Step +from .main import UserInputStep +from .draft.dlt import CreatePipelineStep + + +step_name_to_step_class = { +    "UserInputStep": UserInputStep, +    "CreatePipelineStep": CreatePipelineStep +} + + +class StepsOnStartupStep(Step): +    hide: bool = True + +    async def describe(self, models: Models): +        return "Running steps on startup" + +    async def run(self, sdk: ContinueSDK): +        steps_descriptions = (await sdk.get_config()).steps_on_startup + +        for step_name, step_params in steps_descriptions.items(): +            try: +                step = step_name_to_step_class[step_name](**step_params) +            except: +                print( +                    f"Incorrect parameters for step {step_name}. Parameters provided were: {step_params}") +                continue +            await sdk.run_step(step) diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py index 1977bfdd..1ffe1450 100644 --- a/continuedev/src/continuedev/server/main.py +++ b/continuedev/src/continuedev/server/main.py @@ -1,3 +1,4 @@ +import os  from fastapi import FastAPI  from fastapi.middleware.cors import CORSMiddleware  from .ide import router as ide_router @@ -32,8 +33,11 @@ args = parser.parse_args()  def run_server(): -    uvicorn.run(app, host="0.0.0.0", port=args.port, -                log_config="logging.yaml") +    if os.path.exists("logging.yaml"): +        uvicorn.run(app, host="0.0.0.0", port=args.port, +                    log_config="logging.yaml") +    else: +        uvicorn.run(app, host="0.0.0.0", port=args.port)  if __name__ == "__main__": diff --git a/continuedev/src/continuedev/server/notebook.py b/continuedev/src/continuedev/server/notebook.py index ab9211a8..9ca510dd 100644 --- a/continuedev/src/continuedev/server/notebook.py +++ b/continuedev/src/continuedev/server/notebook.py @@ -53,7 +53,7 @@ class NotebookProtocolServer(AbstractNotebookProtocolServer):      async def _send_json(self, message_type: str, data: Any):          await self.websocket.send_json({ -            "message_type": message_type, +            "messageType": message_type,              "data": data          })  | 
