diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-07-16 22:12:44 -0700 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-07-16 22:12:44 -0700 |
commit | 73e1cfbefbf450ab6564aba653e0132843223c7a (patch) | |
tree | e7b8aaec84a400e6b1d1c23ab1e703204b20a4d9 /continuedev/src/continuedev/core | |
parent | c5d05cec0cafa541c6b00153433864f95beeb56c (diff) | |
download | sncontinue-73e1cfbefbf450ab6564aba653e0132843223c7a.tar.gz sncontinue-73e1cfbefbf450ab6564aba653e0132843223c7a.tar.bz2 sncontinue-73e1cfbefbf450ab6564aba653e0132843223c7a.zip |
templated system messages
Diffstat (limited to 'continuedev/src/continuedev/core')
-rw-r--r-- | continuedev/src/continuedev/core/config.py | 1 | ||||
-rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 10 |
2 files changed, 7 insertions, 4 deletions
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index 957609c5..91a47c8e 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -85,6 +85,7 @@ class ContinueConfig(BaseModel): slash_commands: Optional[List[SlashCommand]] = DEFAULT_SLASH_COMMANDS on_traceback: Optional[List[OnTracebackSteps]] = [ OnTracebackSteps(step_name="DefaultOnTracebackStep")] + system_message: Optional[str] = None # Want to force these to be the slash commands for now @validator('slash_commands', pre=True) diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index eb60109c..ac57c122 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -34,10 +34,12 @@ MODEL_PROVIDER_TO_ENV_VAR = { class Models: provider_keys: Dict[ModelProvider, str] = {} model_providers: List[ModelProvider] + system_message: str def __init__(self, sdk: "ContinueSDK", model_providers: List[ModelProvider]): self.sdk = sdk self.model_providers = model_providers + self.system_message = sdk.config.system_message @classmethod async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models": @@ -53,12 +55,12 @@ class Models: def __load_openai_model(self, model: str) -> OpenAI: api_key = self.provider_keys["openai"] if api_key == "": - return ProxyServer(self.sdk.ide.unique_id, model) - return OpenAI(api_key=api_key, default_model=model) + return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message) + return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message) def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI: api_key = self.provider_keys["hf_inference_api"] - return HuggingFaceInferenceAPI(api_key=api_key, model=model) + return HuggingFaceInferenceAPI(api_key=api_key, model=model, system_message=self.system_message) @cached_property def starcoder(self): @@ -82,7 +84,7 @@ class Models: @cached_property def ggml(self): - return GGML() + return GGML(system_message=self.system_message) def __model_from_name(self, model_name: str): if model_name == "starcoder": |