diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-06-13 13:06:18 -0700 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-06-13 13:06:18 -0700 |
commit | 2e98a6a61bc32e77c03feefe211dce0dd920ce0d (patch) | |
tree | 1075b79df46044607d494f4c5300f82ddceac629 /continuedev | |
parent | 358e038019c397456c4e2a97de045de0c4299529 (diff) | |
download | sncontinue-2e98a6a61bc32e77c03feefe211dce0dd920ce0d.tar.gz sncontinue-2e98a6a61bc32e77c03feefe211dce0dd920ce0d.tar.bz2 sncontinue-2e98a6a61bc32e77c03feefe211dce0dd920ce0d.zip |
configurable default model and small updates
Diffstat (limited to 'continuedev')
10 files changed, 54 insertions, 25 deletions
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index ed056be9..ebd1fdb8 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -1,11 +1,9 @@ import json import os from pydantic import BaseModel, validator -from typing import List, Optional, Dict +from typing import List, Literal, Optional, Dict import yaml -from .main import Step - class SlashCommand(BaseModel): name: str @@ -21,6 +19,7 @@ class ContinueConfig(BaseModel): steps_on_startup: Optional[Dict[str, Dict]] = {} server_url: Optional[str] = None allow_anonymous_telemetry: Optional[bool] = True + default_model: Literal["gpt-3.5-turbo", "gpt-4", "starcoder"] slash_commands: Optional[List[SlashCommand]] = [ # SlashCommand( # name="pytest", diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py index 2b50307a..9fcda882 100644 --- a/continuedev/src/continuedev/core/policy.py +++ b/continuedev/src/continuedev/core/policy.py @@ -26,7 +26,7 @@ class DemoPolicy(Policy): # At the very start, run initial Steps spcecified in the config if history.get_current() is None: return ( - # MessageStep(name="Welcome to Continue!", message="") >> + MessageStep(name="Welcome to Continue!", message="You can type a question or instructions for a file edit in the text box. If you highlight code, edits will be localized to the highlighted range. Otherwise, the currently open file is taken as context. If you type '/', you can see the list of available slash commands.") >> # SetupContinueWorkspaceStep() >> # CreateCodebaseIndexChroma() >> StepsOnStartupStep()) diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 1f4cdfb2..a54048b0 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -3,7 +3,7 @@ from functools import cached_property from typing import Coroutine, Union import os -from ..steps.core.core import Gpt35EditCodeStep +from ..steps.core.core import DefaultModelEditCodeStep from ..models.main import Range from .abstract_sdk import AbstractContinueSDK from .config import ContinueConfig, load_config @@ -41,6 +41,29 @@ class Models: return OpenAI(api_key=api_key, default_model="gpt-3.5-turbo") return asyncio.get_event_loop().run_until_complete(load_gpt35()) + @cached_property + def gpt4(self): + async def load_gpt4(): + api_key = await self.sdk.get_user_secret( + 'OPENAI_API_KEY', 'Please add your OpenAI API key to the .env file') + return OpenAI(api_key=api_key, default_model="gpt-4") + return asyncio.get_event_loop().run_until_complete(load_gpt4()) + + def __model_from_name(self, model_name: str): + if model_name == "starcoder": + return self.starcoder + elif model_name == "gpt35": + return self.gpt35 + elif model_name == "gpt4": + return self.gpt4 + else: + raise Exception(f"Unknown model {model_name}") + + @cached_property + def default(self): + default_model = self.sdk.config.default_model + return self.__model_from_name(default_model) if default_model is not None else self.gpt35 + class ContinueSDK(AbstractContinueSDK): """The SDK provided as parameters to a step""" @@ -85,7 +108,7 @@ class ContinueSDK(AbstractContinueSDK): await self.ide.setFileOpen(filepath) contents = await self.ide.readFile(filepath) - await self.run_step(Gpt35EditCodeStep( + await self.run_step(DefaultModelEditCodeStep( range_in_files=[RangeInFile(filepath=filepath, range=range) if range is not None else RangeInFile.from_entire_file( filepath, contents)], user_input=prompt, diff --git a/continuedev/src/continuedev/recipes/AddTransformRecipe/steps.py b/continuedev/src/continuedev/recipes/AddTransformRecipe/steps.py index 6a743fb5..9744146c 100644 --- a/continuedev/src/continuedev/recipes/AddTransformRecipe/steps.py +++ b/continuedev/src/continuedev/recipes/AddTransformRecipe/steps.py @@ -49,15 +49,15 @@ class AddTransformStep(Step): filename = f'{source_name}_pipeline.py' abs_filepath = os.path.join(sdk.ide.workspace_directory, filename) + # Open the file and highlight the function to be edited + await sdk.ide.setFileOpen(abs_filepath) + await sdk.run_step(MessageStep(message=dedent("""\ This step will customize your resource function with a transform of your choice: - Add a filter or map transformation depending on your request - Load the data into a local DuckDB instance - Open up a Streamlit app for you to view the data"""), name="Write transformation function")) - # Open the file and highlight the function to be edited - await sdk.ide.setFileOpen(abs_filepath) - with open(os.path.join(os.path.dirname(__file__), 'dlt_transform_docs.md')) as f: dlt_transform_docs = f.read() @@ -77,6 +77,8 @@ class AddTransformStep(Step): name=f"Writing transform function {AI_ASSISTED_STRING}" ) + await sdk.wait_for_user_confirmation("Press Continue to confirm that the changes are okay before we run the pipeline.") + # run the pipeline and load the data await sdk.run(f'python3 {filename}', name="Run the pipeline", description=f"Running `python3 {filename}` to load the data into a local DuckDB instance") diff --git a/continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py b/continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py index 4b8971c2..6426df87 100644 --- a/continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py +++ b/continuedev/src/continuedev/recipes/DDtoBQRecipe/steps.py @@ -82,7 +82,7 @@ class LoadDataStep(Step): docs = f.read() output = "Traceback" + output.split("Traceback")[-1] - suggestion = sdk.models.gpt35.complete(dedent(f"""\ + suggestion = sdk.models.gpt4.complete(dedent(f"""\ When trying to load data into BigQuery, the following error occurred: ```ascii diff --git a/continuedev/src/continuedev/recipes/DeployPipelineAirflowRecipe/main.py b/continuedev/src/continuedev/recipes/DeployPipelineAirflowRecipe/main.py index f3601c2d..503b0c85 100644 --- a/continuedev/src/continuedev/recipes/DeployPipelineAirflowRecipe/main.py +++ b/continuedev/src/continuedev/recipes/DeployPipelineAirflowRecipe/main.py @@ -5,7 +5,7 @@ from ...core.main import Step from ...core.sdk import ContinueSDK from ...steps.core.core import WaitForUserInputStep from ...steps.core.core import MessageStep -from .steps import SetupPipelineStep, DeployAirflowStep +from .steps import SetupPipelineStep, DeployAirflowStep, RunPipelineStep # https://github.com/dlt-hub/dlt-deploy-template/blob/master/airflow-composer/dag_template.py @@ -46,5 +46,6 @@ class DeployPipelineAirflowRecipe(Step): ) await sdk.run_step( SetupPipelineStep(source_name=source_name) >> + RunPipelineStep(source_name=source_name) >> DeployAirflowStep(source_name=source_name) ) diff --git a/continuedev/src/continuedev/recipes/DeployPipelineAirflowRecipe/steps.py b/continuedev/src/continuedev/recipes/DeployPipelineAirflowRecipe/steps.py index c9749348..2f346b49 100644 --- a/continuedev/src/continuedev/recipes/DeployPipelineAirflowRecipe/steps.py +++ b/continuedev/src/continuedev/recipes/DeployPipelineAirflowRecipe/steps.py @@ -69,24 +69,26 @@ class DeployAirflowStep(Step): f'dlt --non-interactive deploy {self.source_name}_pipeline.py airflow-composer', ], description="Running `dlt deploy airflow` to deploy the dlt pipeline to Airflow", name="Deploy dlt pipeline to Airflow") - # Modify the DAG file + # Get filepaths, open the DAG file directory = await sdk.ide.getWorkspaceDirectory() pipeline_filepath = os.path.join( directory, f"{self.source_name}_pipeline.py") dag_filepath = os.path.join( directory, f"dags/dag_{self.source_name}_pipeline.py") + await sdk.ide.setFileOpen(dag_filepath) + # Replace the pipeline name and dataset name await sdk.run_step(FindAndReplaceStep(filepath=pipeline_filepath, pattern="'pipeline_name'", replacement=f"'{self.source_name}_pipeline'")) await sdk.run_step(FindAndReplaceStep(filepath=pipeline_filepath, pattern="'dataset_name'", replacement=f"'{self.source_name}_data'")) await sdk.run_step(FindAndReplaceStep(filepath=pipeline_filepath, pattern="pipeline_or_source_script", replacement=f"{self.source_name}_pipeline")) # Prompt the user for the DAG schedule - edit_dag_range = Range.from_shorthand(18, 0, 23, 0) - await sdk.ide.highlightCode(range_in_file=RangeInFile(filepath=dag_filepath, range=edit_dag_range), color="#33993333") - response = await sdk.run_step(WaitForUserInputStep(prompt="When would you like this Airflow DAG to run? (e.g. every day, every Monday, every 1st of the month, etc.)")) - await sdk.edit_file(dag_filepath, prompt=f"Edit the DAG so that it runs at the following schedule: '{response.text}'", - range=edit_dag_range) + # edit_dag_range = Range.from_shorthand(18, 0, 23, 0) + # await sdk.ide.highlightCode(range_in_file=RangeInFile(filepath=dag_filepath, range=edit_dag_range), color="#33993333") + # response = await sdk.run_step(WaitForUserInputStep(prompt="When would you like this Airflow DAG to run? (e.g. every day, every Monday, every 1st of the month, etc.)")) + # await sdk.edit_file(dag_filepath, prompt=f"Edit the DAG so that it runs at the following schedule: '{response.text}'", + # range=edit_dag_range) # Tell the user to check the schedule and fill in owner, email, other default_args await sdk.run_step(MessageStep(message="Fill in the owner, email, and other default_args in the DAG file with your own personal information. Then the DAG will be ready to run!", name="Fill in default_args")) diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py index aadcfa8e..aca0ab2a 100644 --- a/continuedev/src/continuedev/steps/chat.py +++ b/continuedev/src/continuedev/steps/chat.py @@ -11,7 +11,7 @@ class SimpleChatStep(Step): async def run(self, sdk: ContinueSDK): self.description = "" - for chunk in sdk.models.gpt35.stream_chat(self.user_input, with_history=await sdk.get_chat_context()): + for chunk in sdk.models.default.stream_chat(self.user_input, with_history=await sdk.get_chat_context()): self.description += chunk await sdk.update_ui() diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py index 8dc2478b..ff498b9b 100644 --- a/continuedev/src/continuedev/steps/core/core.py +++ b/continuedev/src/continuedev/steps/core/core.py @@ -109,7 +109,7 @@ class ShellCommandsStep(Step): # return None -class Gpt35EditCodeStep(Step): +class DefaultModelEditCodeStep(Step): user_input: str range_in_files: List[RangeInFile] name: str = "Editing Code" @@ -153,11 +153,13 @@ class Gpt35EditCodeStep(Step): async def describe(self, models: Models) -> Coroutine[str, None, None]: description = models.gpt35.complete( f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points. Be concise and only mention changes made to the commit before, not prefix or suffix:") - self.name = models.gpt35.complete( - f"Write a short title for this description: {description}") + # self.name = models.gpt35.complete( + # f"Write a short title for this description: {description}") return description async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: + self.name = self.user_input + rif_with_contents = [] for range_in_file in self.range_in_files: file_contents = await sdk.ide.readRangeInFile(range_in_file) @@ -174,7 +176,7 @@ class Gpt35EditCodeStep(Step): prompt = self._prompt.format( code=rif.contents, user_request=self.user_input, file_prefix=segs[0], file_suffix=segs[1]) - completion = str(sdk.models.gpt35.complete(prompt)) + completion = str(sdk.models.default.complete(prompt)) eot_token = "<|endoftext|>" completion = completion.removesuffix(eot_token) @@ -225,7 +227,7 @@ class EditFileStep(Step): async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: file_contents = await sdk.ide.readFile(self.filepath) - await sdk.run_step(Gpt35EditCodeStep( + await sdk.run_step(DefaultModelEditCodeStep( range_in_files=[RangeInFile.from_entire_file( self.filepath, file_contents)], user_input=self.prompt diff --git a/continuedev/src/continuedev/steps/main.py b/continuedev/src/continuedev/steps/main.py index 36e4f519..0e42d8bf 100644 --- a/continuedev/src/continuedev/steps/main.py +++ b/continuedev/src/continuedev/steps/main.py @@ -15,7 +15,7 @@ from ..core.main import Step from ..core.sdk import ContinueSDK, Models from ..core.observation import Observation import subprocess -from .core.core import Gpt35EditCodeStep +from .core.core import DefaultModelEditCodeStep from ..libs.util.calculate_diff import calculate_diff2 @@ -287,7 +287,7 @@ class EditHighlightedCodeStep(Step): range_in_files = [RangeInFile.from_entire_file( filepath, content) for filepath, content in contents.items()] - await sdk.run_step(Gpt35EditCodeStep(user_input=self.user_input, range_in_files=range_in_files)) + await sdk.run_step(DefaultModelEditCodeStep(user_input=self.user_input, range_in_files=range_in_files)) class FindCodeStep(Step): |