summaryrefslogtreecommitdiff
path: root/continuedev/src/continuedev/core
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-07-16 22:12:44 -0700
committerNate Sesti <sestinj@gmail.com>2023-07-16 22:12:44 -0700
commit73e1cfbefbf450ab6564aba653e0132843223c7a (patch)
treee7b8aaec84a400e6b1d1c23ab1e703204b20a4d9 /continuedev/src/continuedev/core
parentc5d05cec0cafa541c6b00153433864f95beeb56c (diff)
downloadsncontinue-73e1cfbefbf450ab6564aba653e0132843223c7a.tar.gz
sncontinue-73e1cfbefbf450ab6564aba653e0132843223c7a.tar.bz2
sncontinue-73e1cfbefbf450ab6564aba653e0132843223c7a.zip
templated system messages
Diffstat (limited to 'continuedev/src/continuedev/core')
-rw-r--r--continuedev/src/continuedev/core/config.py1
-rw-r--r--continuedev/src/continuedev/core/sdk.py10
2 files changed, 7 insertions, 4 deletions
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 957609c5..91a47c8e 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -85,6 +85,7 @@ class ContinueConfig(BaseModel):
slash_commands: Optional[List[SlashCommand]] = DEFAULT_SLASH_COMMANDS
on_traceback: Optional[List[OnTracebackSteps]] = [
OnTracebackSteps(step_name="DefaultOnTracebackStep")]
+ system_message: Optional[str] = None
# Want to force these to be the slash commands for now
@validator('slash_commands', pre=True)
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index eb60109c..ac57c122 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -34,10 +34,12 @@ MODEL_PROVIDER_TO_ENV_VAR = {
class Models:
provider_keys: Dict[ModelProvider, str] = {}
model_providers: List[ModelProvider]
+ system_message: str
def __init__(self, sdk: "ContinueSDK", model_providers: List[ModelProvider]):
self.sdk = sdk
self.model_providers = model_providers
+ self.system_message = sdk.config.system_message
@classmethod
async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models":
@@ -53,12 +55,12 @@ class Models:
def __load_openai_model(self, model: str) -> OpenAI:
api_key = self.provider_keys["openai"]
if api_key == "":
- return ProxyServer(self.sdk.ide.unique_id, model)
- return OpenAI(api_key=api_key, default_model=model)
+ return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message)
+ return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message)
def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI:
api_key = self.provider_keys["hf_inference_api"]
- return HuggingFaceInferenceAPI(api_key=api_key, model=model)
+ return HuggingFaceInferenceAPI(api_key=api_key, model=model, system_message=self.system_message)
@cached_property
def starcoder(self):
@@ -82,7 +84,7 @@ class Models:
@cached_property
def ggml(self):
- return GGML()
+ return GGML(system_message=self.system_message)
def __model_from_name(self, model_name: str):
if model_name == "starcoder":