summaryrefslogtreecommitdiff
path: root/continuedev
diff options
context:
space:
mode:
authorLuna <git@l4.pm>2023-07-29 15:10:59 -0300
committerLuna <git@l4.pm>2023-07-29 15:11:31 -0300
commit714867f9a0d99548eef30c870b32384454b873ed (patch)
tree68322ebccacadf74af8d043372294580aeb5faac /continuedev
parent2b651d2504638ea9db97ba612f702356e38a805e (diff)
downloadsncontinue-714867f9a0d99548eef30c870b32384454b873ed.tar.gz
sncontinue-714867f9a0d99548eef30c870b32384454b873ed.tar.bz2
sncontinue-714867f9a0d99548eef30c870b32384454b873ed.zip
turn Models and LLM into pydantic-compatible classes
required as they're part of the config class
Diffstat (limited to 'continuedev')
-rw-r--r--continuedev/src/continuedev/core/config.py9
-rw-r--r--continuedev/src/continuedev/core/models.py67
-rw-r--r--continuedev/src/continuedev/core/sdk.py68
-rw-r--r--continuedev/src/continuedev/libs/llm/__init__.py6
-rw-r--r--continuedev/src/continuedev/libs/llm/ggml.py2
-rw-r--r--continuedev/src/continuedev/libs/llm/openai.py9
6 files changed, 81 insertions, 80 deletions
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index 6957ae5e..af37264d 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -2,7 +2,8 @@ import json
import os
from .main import Step
from .context import ContextProvider
-from ..libs.llm.openai import OpenAI
+from ..libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
+from .models import Models
from pydantic import BaseModel, validator
from typing import List, Literal, Optional, Dict, Type, Union
import yaml
@@ -26,12 +27,6 @@ class OnTracebackSteps(BaseModel):
params: Optional[Dict] = {}
-class AzureInfo(BaseModel):
- endpoint: str
- engine: str
- api_version: str
-
-
class ContinueConfig(BaseModel):
"""
A pydantic class for the continue config file.
diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py
new file mode 100644
index 00000000..c939d504
--- /dev/null
+++ b/continuedev/src/continuedev/core/models.py
@@ -0,0 +1,67 @@
+from typing import Optional
+from pydantic import BaseModel
+from ..libs.llm import LLM
+
+
+class Models(BaseModel):
+ """Main class that holds the current model configuration"""
+ default: LLM
+ small: Optional[LLM] = None
+ medium: Optional[LLM] = None
+ large: Optional[LLM] = None
+
+ """
+ Better to have sdk.llm.stream_chat(messages, model="claude-2").
+ Then you also don't care that it' async.
+ And it's easier to add more models.
+ And intermediate shared code is easier to add.
+ And you can make constants like ContinueModels.GPT35 = "gpt-3.5-turbo"
+ PromptTransformer would be a good concept: You pass a prompt or list of messages and a model, then it outputs the prompt for that model.
+ Easy to reason about, can place anywhere.
+ And you can even pass a Prompt object to sdk.llm.stream_chat maybe, and it'll automatically be transformed for the given model.
+ This can all happen inside of Models?
+
+ class Prompt:
+ def __init__(self, ...info):
+ '''take whatever info is needed to describe the prompt'''
+
+ def to_string(self, model: str) -> str:
+ '''depending on the model, return the single prompt string'''
+ """
+
+ async def _start(llm: LLM):
+ kwargs = {}
+ if llm.required_api_key:
+ kwargs["api_key"] = await self.sdk.get_api_secret(llm.required_api_key)
+ await llm.start(**kwargs)
+
+ async def start(sdk: "ContinueSDK"):
+ self.sdk = sdk
+ self.system_message = self.sdk.config.system_message
+ await self._start(self.default)
+ if self.small:
+ await self._start(self.small)
+ else:
+ self.small = self.default
+
+ if self.medium:
+ await self._start(self.medium)
+ else:
+ self.medium = self.default
+
+ if self.large:
+ await self._start(self.large)
+ else:
+ self.large = self.default
+
+ async def stop(sdk: "ContinueSDK"):
+ await self.default.stop()
+ if self.small:
+ await self.small.stop()
+
+ if self.medium:
+ await self.medium.stop()
+
+ if self.large:
+ await self.large.stop()
+
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 784f8ed1..b0f7d40a 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -21,74 +21,6 @@ class Autopilot:
pass
-class Models:
- """Main class that holds the current model configuration"""
- default: LLM
- small: Optional[LLM] = None
- medium: Optional[LLM] = None
- large: Optional[LLM] = None
-
- """
- Better to have sdk.llm.stream_chat(messages, model="claude-2").
- Then you also don't care that it' async.
- And it's easier to add more models.
- And intermediate shared code is easier to add.
- And you can make constants like ContinueModels.GPT35 = "gpt-3.5-turbo"
- PromptTransformer would be a good concept: You pass a prompt or list of messages and a model, then it outputs the prompt for that model.
- Easy to reason about, can place anywhere.
- And you can even pass a Prompt object to sdk.llm.stream_chat maybe, and it'll automatically be transformed for the given model.
- This can all happen inside of Models?
-
- class Prompt:
- def __init__(self, ...info):
- '''take whatever info is needed to describe the prompt'''
-
- def to_string(self, model: str) -> str:
- '''depending on the model, return the single prompt string'''
- """
-
- def __init__(self, *, default, small=None, medium=None, large=None, custom=None):
- self.default = default
- self.small = small
- self.medium = medium
- self.large = large
- self.system_message = sdk.config.system_message
-
- async def _start(llm: LLM):
- kwargs = {}
- if llm.required_api_key:
- kwargs["api_key"] = await self.sdk.get_api_secret(llm.required_api_key)
- await llm.start(**kwargs)
-
- async def start(sdk: "ContinueSDK"):
- self.sdk = sdk
- await self._start(self.default)
- if self.small:
- await self._start(self.small)
- else:
- self.small = self.default
-
- if self.medium:
- await self._start(self.medium)
- else:
- self.medium = self.default
-
- if self.large:
- await self._start(self.large)
- else:
- self.large = self.default
-
- async def stop(sdk: "ContinueSDK"):
- await self.default.stop()
- if self.small:
- await self.small.stop()
-
- if self.medium:
- await self.medium.stop()
-
- if self.large:
- await self.large.stop()
-
class ContinueSDK(AbstractContinueSDK):
"""The SDK provided as parameters to a step"""
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 6ae3dd46..0f6b1505 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -1,12 +1,14 @@
+import functools
from abc import ABC
-from typing import Any, Coroutine, Dict, Generator, List, Union
+from pydantic import BaseModel, ConfigDict
+from typing import Any, Coroutine, Dict, Generator, List, Union, Optional
from ...core.main import ChatMessage
from ...models.main import AbstractModel
from pydantic import BaseModel
-class LLM(ABC):
+class LLM(BaseModel, ABC):
required_api_key: Optional[str] = None
system_message: Union[str, None] = None
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index ef771a2e..52e44bfe 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -1,6 +1,7 @@
from functools import cached_property
import json
from typing import Any, Coroutine, Dict, Generator, List, Union
+from pydantic import ConfigDict
import aiohttp
from ...core.main import ChatMessage
@@ -15,7 +16,6 @@ class GGML(LLM):
def __init__(self, system_message: str = None):
self.system_message = system_message
- @cached_property
def name(self):
return "ggml"
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index c8de90a8..ef8830a6 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -2,11 +2,17 @@ from functools import cached_property
import json
from typing import Any, Callable, Coroutine, Dict, Generator, List, Union
+from pydantic import BaseModel
from ...core.main import ChatMessage
import openai
from ..llm import LLM
from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top
-from ...core.config import AzureInfo
+
+
+class AzureInfo(BaseModel):
+ endpoint: str
+ engine: str
+ api_version: str
class OpenAI(LLM):
@@ -32,7 +38,6 @@ class OpenAI(LLM):
async def stop(self):
pass
- @cached_property
def name(self):
return self.default_model