 continuedev/src/continuedev/core/config.py |  4 +-
 continuedev/src/continuedev/core/sdk.py    | 75 ++++++++++++++++++++-----------------------
 2 files changed, 40 insertions(+), 39 deletions(-)
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 23f4fe65..6957ae5e 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -39,7 +39,9 @@ class ContinueConfig(BaseModel):
     steps_on_startup: List[Step] = []
     disallowed_steps: Optional[List[str]] = []
     allow_anonymous_telemetry: Optional[bool] = True
-    llm: LLM = OpenAI(default_model="gpt-4")
+    models: Models = Models(
+        default=MaybeProxyOpenAI(model="gpt-4"),
+    )
     temperature: Optional[float] = 0.5
     custom_commands: Optional[List[CustomCommand]] = [CustomCommand(
         name="test",
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 92a72b23..183518ac 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -8,15 +8,11 @@
 from .abstract_sdk import AbstractContinueSDK
 from .config import ContinueConfig
 from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFile, AddDirectory, DeleteDirectory
 from ..models.filesystem import RangeInFile
-from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
-from ..libs.llm.openai import OpenAI
-from ..libs.llm.anthropic import AnthropicLLM
-from ..libs.llm.ggml import GGML
+from ..libs.llm import LLM
 from .observation import Observation
 from ..server.ide_protocol import AbstractIdeProtocolServer
 from .main import Context, ContinueCustomException, History, HistoryNode, Step, ChatMessage
 from ..plugins.steps.core.core import *
-from ..libs.llm.proxy_server import ProxyServer
 from ..libs.util.telemetry import posthog_logger
 from ..libs.util.paths import getConfigFilePath
@@ -25,18 +21,12 @@ class Autopilot:
     pass


-ModelProvider = Literal["openai", "hf_inference_api", "ggml", "anthropic"]
-MODEL_PROVIDER_TO_ENV_VAR = {
-    "openai": "OPENAI_API_KEY",
-    "hf_inference_api": "HUGGING_FACE_TOKEN",
-    "anthropic": "ANTHROPIC_API_KEY",
-}
-
-
 class Models:
-    provider_keys: Dict[ModelProvider, str] = {}
-    model_providers: List[ModelProvider]
-    system_message: str
+    """Main class that holds the current model configuration"""
+    default: LLM
+    small: Optional[LLM] = None
+    medium: Optional[LLM] = None
+    large: Optional[LLM] = None

     """
     Better to have sdk.llm.stream_chat(messages, model="claude-2").
@@ -57,32 +47,41 @@ class Models:
     '''depending on the model, return the single prompt string'''
     """

-    def __init__(self, sdk: "ContinueSDK", model_providers: List[ModelProvider]):
-        self.sdk = sdk
-        self.model_providers = model_providers
-        self.system_message = sdk.config.system_message
+    def __init__(self, *, default, small=None, medium=None, large=None, custom=None):
+        self.default = default
+        self.small = small
+        self.medium = medium
+        self.large = large

-    @classmethod
-    async def create(cls, sdk: "ContinueSDK") -> "Models":
-        return self.default
+    async def start(self, sdk: "ContinueSDK"):
+        self.sdk = sdk
+        self.system_message = sdk.config.system_message
+        await self.default.start()
+        if self.small:
+            await self.small.start()
+        else:
+            self.small = self.default

-    def __load_openai_model(self, model: str) -> OpenAI:
-        api_key = self.provider_keys["openai"]
-        if api_key == "":
-            return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message, write_log=self.sdk.write_log)
-        return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, azure_info=self.sdk.config.azure_openai_info, write_log=self.sdk.write_log)
+        if self.medium:
+            await self.medium.start()
+        else:
+            self.medium = self.default

-    def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI:
-        api_key = self.provider_keys["hf_inference_api"]
-        return HuggingFaceInferenceAPI(api_key=api_key, model=model, system_message=self.system_message)
+        if self.large:
+            await self.large.start()
+        else:
+            self.large = self.default

-    def __load_anthropic_model(self, model: str) -> AnthropicLLM:
-        api_key = self.provider_keys["anthropic"]
-        return AnthropicLLM(api_key, model, self.system_message)
+    async def stop(self, sdk: "ContinueSDK"):
+        await self.default.stop()
+        if self.small and self.small is not self.default:
+            await self.small.stop()

-    @property
-    def default(self):
-        return sdk.config.llm if sdk.config.llm is not None else ProxyServer(default_model="gpt-4")
+        if self.medium and self.medium is not self.default:
+            await self.medium.stop()
+
+        if self.large and self.large is not self.default:
+            await self.large.stop()


 class ContinueSDK(AbstractContinueSDK):
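
Taken together, Models becomes a plain container of per-tier LLM objects with an explicit start/stop lifecycle: any tier left unset is aliased to default when start() runs. A minimal sketch of that behavior, using invented FakeLLM and FakeSDK stubs (not part of the codebase) in place of a real model and SDK, with Models assumed to be in scope:

    # Sketch of the Models lifecycle; FakeLLM and FakeSDK are stand-ins.
    import asyncio

    class FakeLLM:
        def __init__(self, name):
            self.name = name
        async def start(self):
            print(f"started {self.name}")
        async def stop(self):
            print(f"stopped {self.name}")

    class FakeSDK:
        class config:
            system_message = None

    async def demo():
        models = Models(default=FakeLLM("gpt-4"))
        await models.start(FakeSDK())
        # Tiers left unset fall back to the default model.
        assert models.small is models.default
        assert models.large is models.default
        await models.stop(FakeSDK())

    asyncio.run(demo())
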
@@ -118,8 +117,8 @@ class ContinueSDK(AbstractContinueSDK):
             active=False
         ))

-        sdk.models = await Models.create(sdk)
-        await sdk.models.start()
+        sdk.models = sdk.config.models
+        await sdk.models.start(sdk)
         return sdk

     @property
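
With sdk.models now taken straight from sdk.config.models, a step selects a model tier instead of a provider. A hypothetical step body illustrating the new access pattern (the complete() coroutine on LLM is an assumption; only the tier attributes appear in this diff):

    # Hypothetical step using the new model tiers; LLM.complete() is assumed.
    class ExplainDiffStep(Step):
        async def run(self, sdk: ContinueSDK):
            # `default` is whatever the user configured in config.py.
            summary = await sdk.models.default.complete("Summarize these changes.")
            # `medium` resolves to `default` unless configured otherwise.
            title = await sdk.models.medium.complete("Write a one-line commit title.")
            return None
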