summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CONTRIBUTING.md1
-rw-r--r--continuedev/src/continuedev/core/abstract_sdk.py2
-rw-r--r--continuedev/src/continuedev/core/autopilot.py1
-rw-r--r--continuedev/src/continuedev/core/config.py19
-rw-r--r--continuedev/src/continuedev/core/context.py6
-rw-r--r--continuedev/src/continuedev/core/main.py1
-rw-r--r--continuedev/src/continuedev/core/models.py65
-rw-r--r--continuedev/src/continuedev/core/sdk.py139
-rw-r--r--continuedev/src/continuedev/libs/constants/default_config.py.txt13
-rw-r--r--continuedev/src/continuedev/libs/llm/__init__.py31
-rw-r--r--continuedev/src/continuedev/libs/llm/anthropic.py47
-rw-r--r--continuedev/src/continuedev/libs/llm/ggml.py100
-rw-r--r--continuedev/src/continuedev/libs/llm/hf_inference_api.py12
-rw-r--r--continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py53
-rw-r--r--continuedev/src/continuedev/libs/llm/openai.py92
-rw-r--r--continuedev/src/continuedev/libs/llm/proxy_server.py165
-rw-r--r--continuedev/src/continuedev/libs/llm/utils.py34
-rw-r--r--continuedev/src/continuedev/libs/util/count_tokens.py86
-rw-r--r--continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py8
-rw-r--r--continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py2
-rw-r--r--continuedev/src/continuedev/plugins/steps/README.md2
-rw-r--r--continuedev/src/continuedev/plugins/steps/chat.py8
-rw-r--r--continuedev/src/continuedev/plugins/steps/chroma.py2
-rw-r--r--continuedev/src/continuedev/plugins/steps/core/core.py31
-rw-r--r--continuedev/src/continuedev/plugins/steps/draft/migration.py2
-rw-r--r--continuedev/src/continuedev/plugins/steps/help.py2
-rw-r--r--continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py2
-rw-r--r--continuedev/src/continuedev/plugins/steps/main.py4
-rw-r--r--continuedev/src/continuedev/plugins/steps/react.py2
-rw-r--r--continuedev/src/continuedev/plugins/steps/search_directory.py2
-rw-r--r--docs/docs/customization.md59
-rw-r--r--docs/docs/walkthroughs/create-a-recipe.md2
32 files changed, 555 insertions, 440 deletions
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index e6dea5c4..50d694f4 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -107,7 +107,6 @@ When state is updated on the server, we currently send the entirety of the objec
- `history`, a record of previously run Steps. Displayed in order in the sidebar.
- `active`, whether the autopilot is currently running a step. Displayed as a loader while step is running.
- `user_input_queue`, the queue of user inputs that have not yet been processed due to waiting for previous Steps to complete. Displayed below the `active` loader until popped from the queue.
-- `default_model`, the default model used for completions. Displayed as a toggleable button on the bottom of the GUI.
- `selected_context_items`, the ranges of code and other items (like GitHub Issues, files, etc...) that have been selected to include as context. Displayed just above the main text input.
- `slash_commands`, the list of available slash commands. Displayed in the main text input dropdown.
- `adding_highlighted_code`, whether highlighting of new code for context is locked. Displayed as a button adjacent to `highlighted_ranges`.
diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py
index 94d7be10..e048f877 100644
--- a/continuedev/src/continuedev/core/abstract_sdk.py
+++ b/continuedev/src/continuedev/core/abstract_sdk.py
@@ -73,7 +73,7 @@ class AbstractContinueSDK(ABC):
pass
@abstractmethod
- async def get_user_secret(self, env_var: str, prompt: str) -> str:
+ async def get_user_secret(self, env_var: str) -> str:
pass
config: ContinueConfig
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 5ab5f8ae..de95a259 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -96,7 +96,6 @@ class Autopilot(ContinueBaseModel):
history=self.history,
active=self._active,
user_input_queue=self._main_user_input_queue,
- default_model=self.continue_sdk.config.default_model,
slash_commands=self.get_available_slash_commands(),
adding_highlighted_code=self.context_manager.context_providers[
"code"].adding_highlighted_code,
diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py
index fe0946cd..84b6b10b 100644
--- a/continuedev/src/continuedev/core/config.py
+++ b/continuedev/src/continuedev/core/config.py
@@ -1,3 +1,9 @@
+import json
+import os
+from .main import Step
+from .context import ContextProvider
+from ..libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
+from .models import Models
from pydantic import BaseModel, validator
from typing import List, Literal, Optional, Dict, Type
@@ -23,13 +29,6 @@ class OnTracebackSteps(BaseModel):
params: Optional[Dict] = {}
-class OpenAIServerInfo(BaseModel):
- api_base: Optional[str] = None
- engine: Optional[str] = None
- api_version: Optional[str] = None
- api_type: Literal["azure", "openai"] = "openai"
-
-
class ContinueConfig(BaseModel):
"""
A pydantic class for the continue config file.
@@ -37,8 +36,9 @@ class ContinueConfig(BaseModel):
steps_on_startup: List[Step] = []
disallowed_steps: Optional[List[str]] = []
allow_anonymous_telemetry: Optional[bool] = True
- default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k",
- "gpt-4", "claude-2", "ggml"] = 'gpt-4'
+ models: Models = Models(
+ default=MaybeProxyOpenAI(model="gpt-4"),
+ )
temperature: Optional[float] = 0.5
custom_commands: Optional[List[CustomCommand]] = [CustomCommand(
name="test",
@@ -48,7 +48,6 @@ class ContinueConfig(BaseModel):
slash_commands: Optional[List[SlashCommand]] = []
on_traceback: Optional[List[OnTracebackSteps]] = []
system_message: Optional[str] = None
- openai_server_info: Optional[OpenAIServerInfo] = None
policy_override: Optional[Policy] = None
context_providers: List[ContextProvider] = []
diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py
index 86522ce1..e968c35c 100644
--- a/continuedev/src/continuedev/core/context.py
+++ b/continuedev/src/continuedev/core/context.py
@@ -178,12 +178,6 @@ class ContextManager:
except Exception as e:
logger.debug(f"Error loading meilisearch index: {e}")
- # def compile_chat_messages(self, max_tokens: int) -> List[Dict]:
- # """
- # Compiles the chat prompt into a single string.
- # """
- # return compile_chat_messages(self.model, self.chat_history, max_tokens, self.prompt, self.functions, self.system_message)
-
async def select_context_item(self, id: str, query: str):
"""
Selects the ContextItem with the given id.
diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py
index df9b98ef..2553850f 100644
--- a/continuedev/src/continuedev/core/main.py
+++ b/continuedev/src/continuedev/core/main.py
@@ -258,7 +258,6 @@ class FullState(ContinueBaseModel):
history: History
active: bool
user_input_queue: List[str]
- default_model: str
slash_commands: List[SlashCommandDescription]
adding_highlighted_code: bool
selected_context_items: List[ContextItem]
diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py
new file mode 100644
index 00000000..900762b6
--- /dev/null
+++ b/continuedev/src/continuedev/core/models.py
@@ -0,0 +1,65 @@
+from typing import Optional, Any
+from pydantic import BaseModel, validator
+from ..libs.llm import LLM
+
+
+class Models(BaseModel):
+ """Main class that holds the current model configuration"""
+ default: LLM
+ small: Optional[LLM] = None
+ medium: Optional[LLM] = None
+ large: Optional[LLM] = None
+
+ # TODO namespace these away to not confuse readers,
+ # or split Models into ModelsConfig, which gets turned into Models
+ sdk: "ContinueSDK" = None
+ system_message: Any = None
+
+ """
+ Better to have sdk.llm.stream_chat(messages, model="claude-2").
+ Then you also don't care that it' async.
+ And it's easier to add more models.
+ And intermediate shared code is easier to add.
+ And you can make constants like ContinueModels.GPT35 = "gpt-3.5-turbo"
+ PromptTransformer would be a good concept: You pass a prompt or list of messages and a model, then it outputs the prompt for that model.
+ Easy to reason about, can place anywhere.
+ And you can even pass a Prompt object to sdk.llm.stream_chat maybe, and it'll automatically be transformed for the given model.
+ This can all happen inside of Models?
+
+ class Prompt:
+ def __init__(self, ...info):
+ '''take whatever info is needed to describe the prompt'''
+
+ def to_string(self, model: str) -> str:
+ '''depending on the model, return the single prompt string'''
+ """
+
+ async def start(self, sdk: "ContinueSDK"):
+ """Start each of the LLMs, or fall back to default"""
+ self.sdk = sdk
+ self.system_message = self.sdk.config.system_message
+ await sdk.start_model(self.default)
+ if self.small:
+ await sdk.start_model(self.small)
+ else:
+ self.small = self.default
+
+ if self.medium:
+ await sdk.start_model(self.medium)
+ else:
+ self.medium = self.default
+
+ if self.large:
+ await sdk.start_model(self.large)
+ else:
+ self.large = self.default
+
+ async def stop(self, sdk: "ContinueSDK"):
+ """Stop each LLM (if it's not the default, which is shared)"""
+ await self.default.stop()
+ if self.small is not self.default:
+ await self.small.stop()
+ if self.medium is not self.default:
+ await self.medium.stop()
+ if self.large is not self.default:
+ await self.large.stop()
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 50b89a8b..bf22d696 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -9,17 +9,14 @@ from .abstract_sdk import AbstractContinueSDK
from .config import ContinueConfig
from ..models.filesystem_edit import FileEdit, FileSystemEdit, AddFile, DeleteFile, AddDirectory, DeleteDirectory
from ..models.filesystem import RangeInFile
-from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
-from ..libs.llm.openai import OpenAI
-from ..libs.llm.anthropic import AnthropicLLM
-from ..libs.llm.ggml import GGML
+from ..libs.llm import LLM
from .observation import Observation
from ..server.ide_protocol import AbstractIdeProtocolServer
from .main import Context, ContinueCustomException, History, HistoryNode, Step, ChatMessage
from ..plugins.steps.core.core import *
-from ..libs.llm.proxy_server import ProxyServer
from ..libs.util.telemetry import posthog_logger
from ..libs.util.paths import getConfigFilePath
+from .models import Models
from ..libs.util.logging import logger
@@ -27,121 +24,6 @@ class Autopilot:
pass
-ModelProvider = Literal["openai", "hf_inference_api", "ggml", "anthropic"]
-MODEL_PROVIDER_TO_ENV_VAR = {
- "openai": "OPENAI_API_KEY",
- "hf_inference_api": "HUGGING_FACE_TOKEN",
- "anthropic": "ANTHROPIC_API_KEY",
-}
-
-
-class Models:
- provider_keys: Dict[ModelProvider, str] = {}
- model_providers: List[ModelProvider]
- system_message: str
-
- """
- Better to have sdk.llm.stream_chat(messages, model="claude-2").
- Then you also don't care that it' async.
- And it's easier to add more models.
- And intermediate shared code is easier to add.
- And you can make constants like ContinueModels.GPT35 = "gpt-3.5-turbo"
- PromptTransformer would be a good concept: You pass a prompt or list of messages and a model, then it outputs the prompt for that model.
- Easy to reason about, can place anywhere.
- And you can even pass a Prompt object to sdk.llm.stream_chat maybe, and it'll automatically be transformed for the given model.
- This can all happen inside of Models?
-
- class Prompt:
- def __init__(self, ...info):
- '''take whatever info is needed to describe the prompt'''
-
- def to_string(self, model: str) -> str:
- '''depending on the model, return the single prompt string'''
- """
-
- def __init__(self, sdk: "ContinueSDK", model_providers: List[ModelProvider]):
- self.sdk = sdk
- self.model_providers = model_providers
- self.system_message = sdk.config.system_message
-
- @classmethod
- async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models":
- if sdk.config.default_model == "claude-2":
- with_providers.append("anthropic")
-
- models = Models(sdk, with_providers)
- for provider in with_providers:
- if provider in MODEL_PROVIDER_TO_ENV_VAR:
- env_var = MODEL_PROVIDER_TO_ENV_VAR[provider]
- models.provider_keys[provider] = await sdk.get_user_secret(
- env_var, f'Please add your {env_var} to the .env file')
-
- return models
-
- def __load_openai_model(self, model: str) -> OpenAI:
- api_key = self.provider_keys["openai"]
- if api_key == "":
- return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message, write_log=self.sdk.write_log)
- return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message, openai_server_info=self.sdk.config.openai_server_info, write_log=self.sdk.write_log)
-
- def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI:
- api_key = self.provider_keys["hf_inference_api"]
- return HuggingFaceInferenceAPI(api_key=api_key, model=model, system_message=self.system_message)
-
- def __load_anthropic_model(self, model: str) -> AnthropicLLM:
- api_key = self.provider_keys["anthropic"]
- return AnthropicLLM(api_key, model, self.system_message)
-
- @cached_property
- def claude2(self):
- return self.__load_anthropic_model("claude-2")
-
- @cached_property
- def starcoder(self):
- return self.__load_hf_inference_api_model("bigcode/starcoder")
-
- @cached_property
- def gpt35(self):
- return self.__load_openai_model("gpt-3.5-turbo")
-
- @cached_property
- def gpt350613(self):
- return self.__load_openai_model("gpt-3.5-turbo-0613")
-
- @cached_property
- def gpt3516k(self):
- return self.__load_openai_model("gpt-3.5-turbo-16k")
-
- @cached_property
- def gpt4(self):
- return self.__load_openai_model("gpt-4")
-
- @cached_property
- def ggml(self):
- return GGML(system_message=self.system_message)
-
- def __model_from_name(self, model_name: str):
- if model_name == "starcoder":
- return self.starcoder
- elif model_name == "gpt-3.5-turbo":
- return self.gpt35
- elif model_name == "gpt-3.5-turbo-16k":
- return self.gpt3516k
- elif model_name == "gpt-4":
- return self.gpt4
- elif model_name == "claude-2":
- return self.claude2
- elif model_name == "ggml":
- return self.ggml
- else:
- raise Exception(f"Unknown model {model_name}")
-
- @property
- def default(self):
- default_model = self.sdk.config.default_model
- return self.__model_from_name(default_model) if default_model is not None else self.gpt4
-
-
class ContinueSDK(AbstractContinueSDK):
"""The SDK provided as parameters to a step"""
ide: AbstractIdeProtocolServer
@@ -179,11 +61,13 @@ class ContinueSDK(AbstractContinueSDK):
active=False
))
+ sdk.models = sdk.config.models
+ await sdk.models.start(sdk)
+
# When the config is loaded, setup posthog logger
posthog_logger.setup(
sdk.ide.unique_id, sdk.config.allow_anonymous_telemetry)
- sdk.models = await Models.create(sdk)
return sdk
@property
@@ -193,6 +77,16 @@ class ContinueSDK(AbstractContinueSDK):
def write_log(self, message: str):
self.history.timeline[self.history.current_index].logs.append(message)
+ async def start_model(self, llm: LLM):
+ kwargs = {}
+ if llm.requires_api_key:
+ kwargs["api_key"] = await self.get_user_secret(llm.requires_api_key)
+ if llm.requires_unique_id:
+ kwargs["unique_id"] = self.ide.unique_id
+ if llm.requires_write_log:
+ kwargs["write_log"] = self.write_log
+ await llm.start(**kwargs)
+
async def _ensure_absolute_path(self, path: str) -> str:
if os.path.isabs(path):
return path
@@ -262,7 +156,8 @@ class ContinueSDK(AbstractContinueSDK):
path = await self._ensure_absolute_path(path)
return await self.run_step(FileSystemEditStep(edit=DeleteDirectory(path=path)))
- async def get_user_secret(self, env_var: str, prompt: str) -> str:
+ async def get_user_secret(self, env_var: str) -> str:
+ # TODO support error prompt dynamically set on env_var
return await self.ide.getUserSecret(env_var)
_last_valid_config: ContinueConfig = None
diff --git a/continuedev/src/continuedev/libs/constants/default_config.py.txt b/continuedev/src/continuedev/libs/constants/default_config.py.txt
index 69fd357b..cf8b0324 100644
--- a/continuedev/src/continuedev/libs/constants/default_config.py.txt
+++ b/continuedev/src/continuedev/libs/constants/default_config.py.txt
@@ -9,9 +9,11 @@ import subprocess
from continuedev.core.main import Step
from continuedev.core.sdk import ContinueSDK
+from continuedev.core.models import Models
from continuedev.core.config import CustomCommand, SlashCommand, ContinueConfig
from continuedev.plugins.context_providers.github import GitHubIssuesContextProvider
from continuedev.plugins.context_providers.google import GoogleContextProvider
+from continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
from continuedev.plugins.policies.default import DefaultPolicy
from continuedev.plugins.steps.open_config import OpenConfigStep
@@ -36,9 +38,9 @@ class CommitMessageStep(Step):
diff = subprocess.check_output(
["git", "diff"], cwd=dir).decode("utf-8")
- # Ask gpt-3.5-16k to write a commit message,
+ # Ask the LLM to write a commit message,
# and set it as the description of this step
- self.description = await sdk.models.gpt3516k.complete(
+ self.description = await sdk.models.default.complete(
f"{diff}\n\nWrite a short, specific (less than 50 chars) commit message about the above changes:")
@@ -48,9 +50,10 @@ config = ContinueConfig(
# See here to learn what anonymous data we collect: https://continue.dev/docs/telemetry
allow_anonymous_telemetry=True,
- # GPT-4 is recommended for best results
- # See options here: https://continue.dev/docs/customization#change-the-default-llm
- default_model="gpt-4",
+ models=Models(
+ default=MaybeProxyOpenAI(model="gpt-4"),
+ medium=MaybeProxyOpenAI(model="gpt-3.5-turbo")
+ ),
# Set a system message with information that the LLM should always keep in mind
# E.g. "Please give concise answers. Always respond in Spanish."
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 2766db4b..50577993 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -1,14 +1,30 @@
-from abc import ABC
-from typing import Any, Coroutine, Dict, Generator, List, Union
+from abc import ABC, abstractproperty
+from typing import Any, Coroutine, Dict, Generator, List, Union, Optional
from ...core.main import ChatMessage
-from ...models.main import AbstractModel
-from pydantic import BaseModel
+from ...models.main import ContinueBaseModel
-class LLM(ABC):
+class LLM(ContinueBaseModel, ABC):
+ requires_api_key: Optional[str] = None
+ requires_unique_id: bool = False
+ requires_write_log: bool = False
+
system_message: Union[str, None] = None
+ @abstractproperty
+ def name(self):
+ """Return the name of the LLM."""
+ raise NotImplementedError
+
+ async def start(self, *, api_key: Optional[str] = None, **kwargs):
+ """Start the connection to the LLM."""
+ raise NotImplementedError
+
+ async def stop(self):
+ """Stop the connection to the LLM."""
+ raise NotImplementedError
+
async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
"""Return the completion of the text with the given temperature."""
raise NotImplementedError
@@ -24,3 +40,8 @@ class LLM(ABC):
def count_tokens(self, text: str):
"""Return the number of tokens in the given text."""
raise NotImplementedError
+
+ @abstractproperty
+ def context_length(self) -> int:
+ """Return the context length of the LLM in tokens, as counted by count_tokens."""
+ raise NotImplementedError
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
index 625d4e57..ec1b7e40 100644
--- a/continuedev/src/continuedev/libs/llm/anthropic.py
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -1,32 +1,39 @@
from functools import cached_property
import time
-from typing import Any, Coroutine, Dict, Generator, List, Union
+from typing import Any, Coroutine, Dict, Generator, List, Optional, Union
from ...core.main import ChatMessage
from anthropic import HUMAN_PROMPT, AI_PROMPT, AsyncAnthropic
from ..llm import LLM
-from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, prune_raw_prompt_from_top
+from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens
class AnthropicLLM(LLM):
- api_key: str
- default_model: str
- async_client: AsyncAnthropic
+ model: str = "claude-2"
- def __init__(self, api_key: str, default_model: str, system_message: str = None):
- self.api_key = api_key
- self.default_model = default_model
+ requires_api_key: str = "ANTHROPIC_API_KEY"
+ _async_client: AsyncAnthropic = None
+
+ class Config:
+ arbitrary_types_allowed = True
+
+ def __init__(self, model: str, system_message: str = None):
+ self.model = model
self.system_message = system_message
- self.async_client = AsyncAnthropic(api_key=api_key)
+ async def start(self, *, api_key: Optional[str] = None, **kwargs):
+ self._async_client = AsyncAnthropic(api_key=api_key)
+
+ async def stop(self):
+ pass
@cached_property
def name(self):
- return self.default_model
+ return self.model
@property
def default_args(self):
- return {**DEFAULT_ARGS, "model": self.default_model}
+ return {**DEFAULT_ARGS, "model": self.model}
def _transform_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
args = args.copy()
@@ -40,7 +47,13 @@ class AnthropicLLM(LLM):
return args
def count_tokens(self, text: str):
- return count_tokens(self.default_model, text)
+ return count_tokens(self.model, text)
+
+ @property
+ def context_length(self):
+ if self.model == "claude-2":
+ return 100000
+ raise Exception(f"Unknown Anthropic model {self.model}")
def __messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
prompt = ""
@@ -60,7 +73,7 @@ class AnthropicLLM(LLM):
args["stream"] = True
args = self._transform_args(args)
- async for chunk in await self.async_client.completions.create(
+ async for chunk in await self._async_client.completions.create(
prompt=f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}",
**args
):
@@ -73,8 +86,8 @@ class AnthropicLLM(LLM):
args = self._transform_args(args)
messages = compile_chat_messages(
- args["model"], messages, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message)
- async for chunk in await self.async_client.completions.create(
+ args["model"], messages, self.context_length, self.context_length, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message)
+ async for chunk in await self._async_client.completions.create(
prompt=self.__messages_to_prompt(messages),
**args
):
@@ -88,8 +101,8 @@ class AnthropicLLM(LLM):
args = self._transform_args(args)
messages = compile_chat_messages(
- args["model"], with_history, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message)
- resp = (await self.async_client.completions.create(
+ args["model"], with_history, self.context_length, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message)
+ resp = (await self._async_client.completions.create(
prompt=self.__messages_to_prompt(messages),
**args
)).completion
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 21374359..7742e8c3 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -1,6 +1,7 @@
from functools import cached_property
import json
from typing import Any, Coroutine, Dict, Generator, List, Union
+from pydantic import ConfigDict
import aiohttp
from ...core.main import ChatMessage
@@ -11,15 +12,29 @@ SERVER_URL = "http://localhost:8000"
class GGML(LLM):
+ # this is model-specific
+ max_context_length: int = 2048
- def __init__(self, system_message: str = None):
- self.system_message = system_message
+ _client_session: aiohttp.ClientSession = None
- @cached_property
+ class Config:
+ arbitrary_types_allowed = True
+
+ async def start(self, **kwargs):
+ self._client_session = aiohttp.ClientSession()
+
+ async def stop(self):
+ await self._client_session.close()
+
+ @property
def name(self):
return "ggml"
@property
+ def context_length(self):
+ return self.max_context_length
+
+ @property
def default_args(self):
return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
@@ -33,54 +48,51 @@ class GGML(LLM):
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
-
- async with aiohttp.ClientSession() as session:
- async with session.post(f"{SERVER_URL}/v1/completions", json={
- "messages": messages,
- **args
- }) as resp:
- async for line in resp.content.iter_any():
- if line:
- try:
- yield line.decode("utf-8")
- except:
- raise Exception(str(line))
+ self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
+
+ async with self._client_session.post(f"{SERVER_URL}/v1/completions", json={
+ "messages": messages,
+ **args
+ }) as resp:
+ async for line in resp.content.iter_any():
+ if line:
+ try:
+ yield line.decode("utf-8")
+ except:
+ raise Exception(str(line))
async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.name, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+ self.name, messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
args["stream"] = True
- async with aiohttp.ClientSession() as session:
- async with session.post(f"{SERVER_URL}/v1/chat/completions", json={
- "messages": messages,
- **args
- }) as resp:
- # This is streaming application/json instaed of text/event-stream
- async for line in resp.content.iter_chunks():
- if line[1]:
- try:
- json_chunk = line[0].decode("utf-8")
- if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"):
- continue # hehe
- chunks = json_chunk.split("\n")
- for chunk in chunks:
- if chunk.strip() != "":
- yield json.loads(chunk[6:])["choices"][0]["delta"]
- except:
- raise Exception(str(line[0]))
+ async with self._client_session.post(f"{SERVER_URL}/v1/chat/completions", json={
+ "messages": messages,
+ **args
+ }) as resp:
+ # This is streaming application/json instaed of text/event-stream
+ async for line in resp.content.iter_chunks():
+ if line[1]:
+ try:
+ json_chunk = line[0].decode("utf-8")
+ if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"):
+ continue
+ chunks = json_chunk.split("\n")
+ for chunk in chunks:
+ if chunk.strip() != "":
+ yield json.loads(chunk[6:])["choices"][0]["delta"]
+ except:
+ raise Exception(str(line[0]))
async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
args = {**self.default_args, **kwargs}
- async with aiohttp.ClientSession() as session:
- async with session.post(f"{SERVER_URL}/v1/completions", json={
- "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
- **args
- }) as resp:
- try:
- return await resp.text()
- except:
- raise Exception(await resp.text())
+ async with self._client_session.post(f"{SERVER_URL}/v1/completions", json={
+ "messages": compile_chat_messages(args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message),
+ **args
+ }) as resp:
+ try:
+ return await resp.text()
+ except:
+ raise Exception(await resp.text())
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 36f03270..49f593d8 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional
from ...core.main import ChatMessage
from ..llm import LLM
import requests
@@ -8,14 +8,18 @@ DEFAULT_MAX_TIME = 120.
class HuggingFaceInferenceAPI(LLM):
- api_key: str
model: str
- def __init__(self, api_key: str, model: str, system_message: str = None):
- self.api_key = api_key
+ requires_api_key: str = "HUGGING_FACE_TOKEN"
+ api_key: str = None
+
+ def __init__(self, model: str, system_message: str = None):
self.model = model
self.system_message = system_message # TODO: Nothing being done with this
+ async def start(self, *, api_key: Optional[str] = None, **kwargs):
+ self.api_key = api_key
+
def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs):
"""Return the completion of the text with the given temperature."""
API_URL = f"https://api-inference.huggingface.co/models/{self.model}"
diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
new file mode 100644
index 00000000..edf58fd7
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
@@ -0,0 +1,53 @@
+from typing import Any, Coroutine, Dict, Generator, List, Union, Optional, Callable
+
+from ...core.main import ChatMessage
+from . import LLM
+from .proxy_server import ProxyServer
+from .openai import OpenAI
+
+
+class MaybeProxyOpenAI(LLM):
+ model: str
+
+ requires_api_key: Optional[str] = "OPENAI_API_KEY"
+ requires_write_log: bool = True
+ requires_unique_id: bool = True
+ system_message: Union[str, None] = None
+
+ llm: Optional[LLM] = None
+
+ @property
+ def name(self):
+ return self.llm.name
+
+ @property
+ def context_length(self):
+ return self.llm.context_length
+
+ async def start(self, *, api_key: Optional[str] = None, unique_id: str, write_log: Callable[[str], None]):
+ if api_key is None or api_key.strip() == "":
+ self.llm = ProxyServer(model=self.model)
+ else:
+ self.llm = OpenAI(model=self.model)
+
+ await self.llm.start(api_key=api_key, write_log=write_log, unique_id=unique_id)
+
+ async def stop(self):
+ await self.llm.stop()
+
+ async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
+ return await self.llm.complete(prompt, with_history=with_history, **kwargs)
+
+ async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ resp = self.llm.stream_complete(
+ prompt, with_history=with_history, **kwargs)
+ async for item in resp:
+ yield item
+
+ async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ resp = self.llm.stream_chat(messages=messages, **kwargs)
+ async for item in resp:
+ yield item
+
+ def count_tokens(self, text: str):
+ return self.llm.count_tokens(text)
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index d1ca4ef9..fce6e8ab 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -1,48 +1,83 @@
from functools import cached_property
import json
-from typing import Any, Callable, Coroutine, Dict, Generator, List, Union
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union, Optional
-from ...core.main import ChatMessage
+from pydantic import BaseModel
import openai
+
+from ...core.main import ChatMessage
+from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top
from ..llm import LLM
-from ..util.count_tokens import compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens, format_chat_messages, prune_raw_prompt_from_top
-from ...core.config import OpenAIServerInfo
+
+
+class OpenAIServerInfo(BaseModel):
+ api_base: Optional[str] = None
+ engine: Optional[str] = None
+ api_version: Optional[str] = None
+ api_type: Literal["azure", "openai"] = "openai"
+
+
+CHAT_MODELS = {
+ "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"
+}
+MAX_TOKENS_FOR_MODEL = {
+ "gpt-3.5-turbo": 4096,
+ "gpt-3.5-turbo-0613": 4096,
+ "gpt-3.5-turbo-16k": 16384,
+ "gpt-4": 8192,
+}
+
+
+class AzureInfo(BaseModel):
+ endpoint: str
+ engine: str
+ api_version: str
class OpenAI(LLM):
- api_key: str
- default_model: str
+ model: str
+ openai_server_info: Optional[OpenAIServerInfo] = None
- def __init__(self, api_key: str, default_model: str, system_message: str = None, openai_server_info: OpenAIServerInfo = None, write_log: Callable[[str], None] = None):
- self.api_key = api_key
- self.default_model = default_model
- self.system_message = system_message
- self.openai_server_info = openai_server_info
+ requires_api_key = "OPENAI_API_KEY"
+ requires_write_log = True
+
+ system_message: Optional[str] = None
+ azure_info: Optional[AzureInfo] = None
+ write_log: Optional[Callable[[str], None]] = None
+ api_key: str = None
+
+ async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], **kwargs):
self.write_log = write_log
+ self.api_key = api_key
+ openai.api_key = self.api_key
- openai.api_key = api_key
+ if self.openai_server_info is not None:
+ openai.api_type = self.openai_server_info.api_type
+ if self.openai_server_info.api_base is not None:
+ openai.api_base = self.openai_server_info.api_base
+ if self.openai_server_info.api_version is not None:
+ openai.api_version = self.openai_server_info.api_version
- # Using an Azure OpenAI deployment
- if openai_server_info is not None:
- openai.api_type = openai_server_info.api_type
- if openai_server_info.api_base is not None:
- openai.api_base = openai_server_info.api_base
- if openai_server_info.api_version is not None:
- openai.api_version = openai_server_info.api_version
+ async def stop(self):
+ pass
- @cached_property
+ @property
def name(self):
- return self.default_model
+ return self.model
+
+ @property
+ def context_length(self):
+ return MAX_TOKENS_FOR_MODEL[self.model]
@property
def default_args(self):
- args = {**DEFAULT_ARGS, "model": self.default_model}
+ args = {**DEFAULT_ARGS, "model": self.model}
if self.openai_server_info is not None:
args["engine"] = self.openai_server_info.engine
return args
def count_tokens(self, text: str):
- return count_tokens(self.default_model, text)
+ return count_tokens(self.model, text)
async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = self.default_args.copy()
@@ -51,7 +86,7 @@ class OpenAI(LLM):
if args["model"] in CHAT_MODELS:
messages = compile_chat_messages(
- args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+ args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
completion = ""
async for chunk in await openai.ChatCompletion.acreate(
@@ -78,12 +113,13 @@ class OpenAI(LLM):
args = self.default_args.copy()
args.update(kwargs)
args["stream"] = True
- args["model"] = self.default_model if self.default_model in CHAT_MODELS else "gpt-3.5-turbo-0613"
+ # TODO what to do here? why should we change to gpt-3.5-turbo-0613 if the user didn't ask for it?
+ args["model"] = self.model if self.model in CHAT_MODELS else "gpt-3.5-turbo-0613"
if not args["model"].endswith("0613") and "functions" in args:
del args["functions"]
messages = compile_chat_messages(
- args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+ args["model"], messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
completion = ""
async for chunk in await openai.ChatCompletion.acreate(
@@ -100,7 +136,7 @@ class OpenAI(LLM):
if args["model"] in CHAT_MODELS:
messages = compile_chat_messages(
- args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+ args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
resp = (await openai.ChatCompletion.acreate(
messages=messages,
@@ -109,7 +145,7 @@ class OpenAI(LLM):
self.write_log(f"Completion: \n\n{resp}")
else:
prompt = prune_raw_prompt_from_top(
- args["model"], prompt, args["max_tokens"])
+ args["model"], self.context_length, prompt, args["max_tokens"])
self.write_log(f"Prompt:\n\n{prompt}")
resp = (await openai.Completion.acreate(
prompt=prompt,
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index f9e3fa01..1a48f213 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -1,7 +1,6 @@
-
import json
import traceback
-from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union, Optional
import aiohttp
from ...core.main import ChatMessage
from ..llm import LLM
@@ -16,26 +15,51 @@ ssl_context = ssl.create_default_context(cafile=ca_bundle_path)
# SERVER_URL = "http://127.0.0.1:8080"
SERVER_URL = "https://proxy-server-l6vsfbzhba-uw.a.run.app"
+MAX_TOKENS_FOR_MODEL = {
+ "gpt-3.5-turbo": 4096,
+ "gpt-3.5-turbo-0613": 4096,
+ "gpt-3.5-turbo-16k": 16384,
+ "gpt-4": 8192,
+}
+
class ProxyServer(LLM):
- unique_id: str
- name: str
- default_model: Literal["gpt-3.5-turbo", "gpt-4"]
- write_log: Callable[[str], None]
+ model: str
+ system_message: Optional[str]
- def __init__(self, unique_id: str, default_model: Literal["gpt-3.5-turbo", "gpt-4"], system_message: str = None, write_log: Callable[[str], None] = None):
- self.unique_id = unique_id
- self.default_model = default_model
- self.system_message = system_message
- self.name = default_model
+ unique_id: str = None
+ write_log: Callable[[str], None] = None
+ _client_session: aiohttp.ClientSession
+
+ requires_unique_id = True
+ requires_write_log = True
+
+ class Config:
+ arbitrary_types_allowed = True
+
+ async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], unique_id: str, **kwargs):
+ self._client_session = aiohttp.ClientSession(
+ connector=aiohttp.TCPConnector(ssl_context=ssl_context))
self.write_log = write_log
+ self.unique_id = unique_id
+
+ async def stop(self):
+ await self._client_session.close()
+
+ @property
+ def name(self):
+ return self.model
+
+ @property
+ def context_length(self):
+ return MAX_TOKENS_FOR_MODEL[self.model]
@property
def default_args(self):
- return {**DEFAULT_ARGS, "model": self.default_model}
+ return {**DEFAULT_ARGS, "model": self.model}
def count_tokens(self, text: str):
- return count_tokens(self.default_model, text)
+ return count_tokens(self.model, text)
def get_headers(self):
# headers with unique id
@@ -45,75 +69,72 @@ class ProxyServer(LLM):
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+ args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
- async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
- async with session.post(f"{SERVER_URL}/complete", json={
- "messages": messages,
- **args
- }, headers=self.get_headers()) as resp:
- if resp.status != 200:
- raise Exception(await resp.text())
-
- response_text = await resp.text()
- self.write_log(f"Completion: \n\n{response_text}")
- return response_text
+ async with self._client_session.post(f"{SERVER_URL}/complete", json={
+ "messages": messages,
+ **args
+ }, headers=self.get_headers()) as resp:
+ if resp.status != 200:
+ raise Exception(await resp.text())
+
+ response_text = await resp.text()
+ self.write_log(f"Completion: \n\n{response_text}")
+ return response_text
async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
+ args["model"], messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
- async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
- async with session.post(f"{SERVER_URL}/stream_chat", json={
- "messages": messages,
- **args
- }, headers=self.get_headers()) as resp:
- # This is streaming application/json instaed of text/event-stream
- completion = ""
- if resp.status != 200:
- raise Exception(await resp.text())
- async for line in resp.content.iter_chunks():
- if line[1]:
- try:
- json_chunk = line[0].decode("utf-8")
- json_chunk = "{}" if json_chunk == "" else json_chunk
- chunks = json_chunk.split("\n")
- for chunk in chunks:
- if chunk.strip() != "":
- loaded_chunk = json.loads(chunk)
- yield loaded_chunk
- if "content" in loaded_chunk:
- completion += loaded_chunk["content"]
- except Exception as e:
- posthog_logger.capture_event(self.unique_id, "proxy_server_parse_error", {
- "error_title": "Proxy server stream_chat parsing failed", "error_message": '\n'.join(traceback.format_exception(e))})
- else:
- break
-
- self.write_log(f"Completion: \n\n{completion}")
+ async with self._client_session.post(f"{SERVER_URL}/stream_chat", json={
+ "messages": messages,
+ **args
+ }, headers=self.get_headers()) as resp:
+ # This is streaming application/json instaed of text/event-stream
+ completion = ""
+ if resp.status != 200:
+ raise Exception(await resp.text())
+ async for line in resp.content.iter_chunks():
+ if line[1]:
+ try:
+ json_chunk = line[0].decode("utf-8")
+ json_chunk = "{}" if json_chunk == "" else json_chunk
+ chunks = json_chunk.split("\n")
+ for chunk in chunks:
+ if chunk.strip() != "":
+ loaded_chunk = json.loads(chunk)
+ yield loaded_chunk
+ if "content" in loaded_chunk:
+ completion += loaded_chunk["content"]
+ except Exception as e:
+ posthog_logger.capture_event(self.unique_id, "proxy_server_parse_error", {
+ "error_title": "Proxy server stream_chat parsing failed", "error_message": '\n'.join(traceback.format_exception(e))})
+ else:
+ break
+
+ self.write_log(f"Completion: \n\n{completion}")
async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
+ self.model, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
- async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
- async with session.post(f"{SERVER_URL}/stream_complete", json={
- "messages": messages,
- **args
- }, headers=self.get_headers()) as resp:
- completion = ""
- if resp.status != 200:
- raise Exception(await resp.text())
- async for line in resp.content.iter_any():
- if line:
- try:
- decoded_line = line.decode("utf-8")
- yield decoded_line
- completion += decoded_line
- except:
- raise Exception(str(line))
- self.write_log(f"Completion: \n\n{completion}")
+ async with self._client_session.post(f"{SERVER_URL}/stream_complete", json={
+ "messages": messages,
+ **args
+ }, headers=self.get_headers()) as resp:
+ completion = ""
+ if resp.status != 200:
+ raise Exception(await resp.text())
+ async for line in resp.content.iter_any():
+ if line:
+ try:
+ decoded_line = line.decode("utf-8")
+ yield decoded_line
+ completion += decoded_line
+ except:
+ raise Exception(str(line))
+ self.write_log(f"Completion: \n\n{completion}")
diff --git a/continuedev/src/continuedev/libs/llm/utils.py b/continuedev/src/continuedev/libs/llm/utils.py
deleted file mode 100644
index 76240d4e..00000000
--- a/continuedev/src/continuedev/libs/llm/utils.py
+++ /dev/null
@@ -1,34 +0,0 @@
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from transformers import GPT2TokenizerFast
-
-gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
-def count_tokens(text: str) -> int:
- return len(gpt2_tokenizer.encode(text))
-
-prices = {
- # All prices are per 1k tokens
- "fine-tune-train": {
- "davinci": 0.03,
- "curie": 0.03,
- "babbage": 0.0006,
- "ada": 0.0004,
- },
- "completion": {
- "davinci": 0.02,
- "curie": 0.002,
- "babbage": 0.0005,
- "ada": 0.0004,
- },
- "fine-tune-completion": {
- "davinci": 0.12,
- "curie": 0.012,
- "babbage": 0.0024,
- "ada": 0.0016,
- },
- "embedding": {
- "ada": 0.0004
- }
-}
-
-def get_price(text: str, model: str="davinci", task: str="completion") -> float:
- return count_tokens(text) * prices[task][model] / 1000 \ No newline at end of file
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index c58ae499..6add7b1a 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -2,43 +2,47 @@ import json
from typing import Dict, List, Union
from ...core.main import ChatMessage
from .templating import render_templated_string
+from ...libs.llm import LLM
import tiktoken
+# TODO move many of these into specific LLM.properties() function that
+# contains max tokens, if its a chat model or not, default args (not all models
+# want to be run at 0.5 temp). also lets custom models made for long contexts
+# exist here (likg LLongMA)
aliases = {
"ggml": "gpt-3.5-turbo",
"claude-2": "gpt-3.5-turbo",
}
DEFAULT_MAX_TOKENS = 2048
-MAX_TOKENS_FOR_MODEL = {
- "gpt-3.5-turbo": 4096,
- "gpt-3.5-turbo-0613": 4096,
- "gpt-3.5-turbo-16k": 16384,
- "gpt-4": 8192,
- "ggml": 2048,
- "claude-2": 100000
-}
-CHAT_MODELS = {
- "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"
-}
DEFAULT_ARGS = {"max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.5, "top_p": 1,
"frequency_penalty": 0, "presence_penalty": 0}
-def encoding_for_model(model: str):
- return tiktoken.encoding_for_model(aliases.get(model, model))
+def encoding_for_model(model_name: str):
+ try:
+ return tiktoken.encoding_for_model(aliases.get(model_name, model_name))
+ except:
+ return tiktoken.encoding_for_model("gpt-3.5-turbo")
-def count_tokens(model: str, text: Union[str, None]):
+def count_tokens(model_name: str, text: Union[str, None]):
if text is None:
return 0
- encoding = encoding_for_model(model)
+ encoding = encoding_for_model(model_name)
return len(encoding.encode(text, disallowed_special=()))
-def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: int):
- max_tokens = MAX_TOKENS_FOR_MODEL.get(
- model, DEFAULT_MAX_TOKENS) - tokens_for_completion
- encoding = encoding_for_model(model)
+def count_chat_message_tokens(model_name: str, chat_message: ChatMessage) -> int:
+ # Doing simpler, safer version of what is here:
+ # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+ # every message follows <|start|>{role/name}\n{content}<|end|>\n
+ TOKENS_PER_MESSAGE = 4
+ return count_tokens(model_name, chat_message.content) + TOKENS_PER_MESSAGE
+
+
+def prune_raw_prompt_from_top(model_name: str, context_length: int, prompt: str, tokens_for_completion: int):
+ max_tokens = context_length - tokens_for_completion
+ encoding = encoding_for_model(model_name)
tokens = encoding.encode(prompt, disallowed_special=())
if len(tokens) <= max_tokens:
return prompt
@@ -46,53 +50,45 @@ def prune_raw_prompt_from_top(model: str, prompt: str, tokens_for_completion: in
return encoding.decode(tokens[-max_tokens:])
-def count_chat_message_tokens(model: str, chat_message: ChatMessage) -> int:
- # Doing simpler, safer version of what is here:
- # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
- # every message follows <|start|>{role/name}\n{content}<|end|>\n
- TOKENS_PER_MESSAGE = 4
- return count_tokens(model, chat_message.content) + TOKENS_PER_MESSAGE
-
-
-def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int):
+def prune_chat_history(model_name: str, chat_history: List[ChatMessage], context_length: int, tokens_for_completion: int):
total_tokens = tokens_for_completion + \
- sum(count_chat_message_tokens(model, message)
+ sum(count_chat_message_tokens(model_name, message)
for message in chat_history)
# 1. Replace beyond last 5 messages with summary
i = 0
- while total_tokens > max_tokens and i < len(chat_history) - 5:
+ while total_tokens > context_length and i < len(chat_history) - 5:
message = chat_history[0]
- total_tokens -= count_tokens(model, message.content)
- total_tokens += count_tokens(model, message.summary)
+ total_tokens -= count_tokens(model_name, message.content)
+ total_tokens += count_tokens(model_name, message.summary)
message.content = message.summary
i += 1
# 2. Remove entire messages until the last 5
- while len(chat_history) > 5 and total_tokens > max_tokens and len(chat_history) > 0:
+ while len(chat_history) > 5 and total_tokens > context_length and len(chat_history) > 0:
message = chat_history.pop(0)
- total_tokens -= count_tokens(model, message.content)
+ total_tokens -= count_tokens(model_name, message.content)
# 3. Truncate message in the last 5, except last 1
i = 0
- while total_tokens > max_tokens and len(chat_history) > 0 and i < len(chat_history) - 1:
+ while total_tokens > context_length and len(chat_history) > 0 and i < len(chat_history) - 1:
message = chat_history[i]
- total_tokens -= count_tokens(model, message.content)
- total_tokens += count_tokens(model, message.summary)
+ total_tokens -= count_tokens(model_name, message.content)
+ total_tokens += count_tokens(model_name, message.summary)
message.content = message.summary
i += 1
# 4. Remove entire messages in the last 5, except last 1
- while total_tokens > max_tokens and len(chat_history) > 1:
+ while total_tokens > context_length and len(chat_history) > 1:
message = chat_history.pop(0)
- total_tokens -= count_tokens(model, message.content)
+ total_tokens -= count_tokens(model_name, message.content)
# 5. Truncate last message
- if total_tokens > max_tokens and len(chat_history) > 0:
+ if total_tokens > context_length and len(chat_history) > 0:
message = chat_history[0]
message.content = prune_raw_prompt_from_top(
- model, message.content, tokens_for_completion)
- total_tokens = max_tokens
+ model_name, context_length, message.content, tokens_for_completion)
+ total_tokens = context_length
return chat_history
@@ -101,7 +97,7 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens:
TOKEN_BUFFER_FOR_SAFETY = 100
-def compile_chat_messages(model: str, msgs: Union[List[ChatMessage], None], max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:
+def compile_chat_messages(model_name: str, msgs: Union[List[ChatMessage], None], context_length: int, max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:
"""
The total number of tokens is system_message + sum(msgs) + functions + prompt after it is converted to a message
"""
@@ -125,10 +121,10 @@ def compile_chat_messages(model: str, msgs: Union[List[ChatMessage], None], max_
function_tokens = 0
if functions is not None:
for function in functions:
- function_tokens += count_tokens(model, json.dumps(function))
+ function_tokens += count_tokens(model_name, json.dumps(function))
msgs_copy = prune_chat_history(
- model, msgs_copy, MAX_TOKENS_FOR_MODEL[model], function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY)
+ model_name, msgs_copy, context_length, function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY)
history = [msg.to_dict(with_functions=functions is not None)
for msg in msgs_copy]
diff --git a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py
index 433e309e..872f8d62 100644
--- a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py
+++ b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py
@@ -27,7 +27,7 @@ class SetupPipelineStep(Step):
async def run(self, sdk: ContinueSDK):
sdk.context.set("api_description", self.api_description)
- source_name = (await sdk.models.gpt35.complete(
+ source_name = (await sdk.models.medium.complete(
f"Write a snake_case name for the data source described by {self.api_description}: ")).strip()
filename = f'{source_name}.py'
@@ -89,7 +89,7 @@ class ValidatePipelineStep(Step):
if "Traceback" in output or "SyntaxError" in output:
output = "Traceback" + output.split("Traceback")[-1]
file_content = await sdk.ide.readFile(os.path.join(workspace_dir, filename))
- suggestion = await sdk.models.gpt35.complete(dedent(f"""\
+ suggestion = await sdk.models.medium.complete(dedent(f"""\
```python
{file_content}
```
@@ -101,7 +101,7 @@ class ValidatePipelineStep(Step):
This is a brief summary of the error followed by a suggestion on how it can be fixed by editing the resource function:"""))
- api_documentation_url = await sdk.models.gpt35.complete(dedent(f"""\
+ api_documentation_url = await sdk.models.medium.complete(dedent(f"""\
The API I am trying to call is the '{sdk.context.get('api_description')}'. I tried calling it in the @resource function like this:
```python
{file_content}
@@ -151,7 +151,7 @@ class RunQueryStep(Step):
output = await sdk.run('.env/bin/python3 query.py', name="Run test query", description="Running `.env/bin/python3 query.py` to test that the data was loaded into DuckDB as expected", handle_error=False)
if "Traceback" in output or "SyntaxError" in output:
- suggestion = await sdk.models.gpt35.complete(dedent(f"""\
+ suggestion = await sdk.models.medium.complete(dedent(f"""\
```python
{await sdk.ide.readFile(os.path.join(sdk.ide.workspace_directory, "query.py"))}
```
diff --git a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py
index 6ef5ffd6..c66cd629 100644
--- a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py
+++ b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py
@@ -42,7 +42,7 @@ class WritePytestsRecipe(Step):
"{self.user_input}"
Here is a complete set of pytest unit tests:""")
- tests = await sdk.models.gpt35.complete(prompt)
+ tests = await sdk.models.medium.complete(prompt)
await sdk.apply_filesystem_edit(AddFile(filepath=path, content=tests))
diff --git a/continuedev/src/continuedev/plugins/steps/README.md b/continuedev/src/continuedev/plugins/steps/README.md
index 12073835..3f2f804c 100644
--- a/continuedev/src/continuedev/plugins/steps/README.md
+++ b/continuedev/src/continuedev/plugins/steps/README.md
@@ -33,7 +33,7 @@ If you'd like to override the default description of your step, which is just th
- Return a static string
- Store state in a class attribute (prepend with a double underscore, which signifies (through Pydantic) that this is not a parameter for the Step, just internal state) during the run method, and then grab this in the describe method.
-- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.gpt35.complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`.
+- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.medium.complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`.
Here's an example:
diff --git a/continuedev/src/continuedev/plugins/steps/chat.py b/continuedev/src/continuedev/plugins/steps/chat.py
index f72a8ec0..455d5a13 100644
--- a/continuedev/src/continuedev/plugins/steps/chat.py
+++ b/continuedev/src/continuedev/plugins/steps/chat.py
@@ -9,6 +9,7 @@ from .core.core import DisplayErrorStep, MessageStep
from ...core.main import FunctionCall, Models
from ...core.main import ChatMessage, Step, step_to_json_schema
from ...core.sdk import ContinueSDK
+from ...libs.llm.openai import OpenAI
import openai
import os
from dotenv import load_dotenv
@@ -41,7 +42,7 @@ class SimpleChatStep(Step):
self.description += chunk["content"]
await sdk.update_ui()
- self.name = remove_quotes_and_escapes(await sdk.models.gpt35.complete(
+ self.name = remove_quotes_and_escapes(await sdk.models.medium.complete(
f"Write a short title for the following chat message: {self.description}"))
self.chat_context.append(ChatMessage(
@@ -166,7 +167,10 @@ class ChatWithFunctions(Step):
msg_content = ""
msg_step = None
- async for msg_chunk in sdk.models.gpt350613.stream_chat(await sdk.get_chat_context(), functions=functions):
+ gpt350613 = OpenAI(model="gpt-3.5-turbo-0613")
+ await sdk.start_model(gpt350613)
+
+ async for msg_chunk in gpt350613.stream_chat(await sdk.get_chat_context(), functions=functions):
if sdk.current_step_was_deleted():
return
diff --git a/continuedev/src/continuedev/plugins/steps/chroma.py b/continuedev/src/continuedev/plugins/steps/chroma.py
index dbe8363e..658cc7f3 100644
--- a/continuedev/src/continuedev/plugins/steps/chroma.py
+++ b/continuedev/src/continuedev/plugins/steps/chroma.py
@@ -56,7 +56,7 @@ class AnswerQuestionChroma(Step):
Here is the answer:""")
- answer = await sdk.models.gpt35.complete(prompt)
+ answer = await sdk.models.medium.complete(prompt)
# Make paths relative to the workspace directory
answer = answer.replace(await sdk.ide.getWorkspaceDirectory(), "")
diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py
index de7cf3ac..4476c7ae 100644
--- a/continuedev/src/continuedev/plugins/steps/core/core.py
+++ b/continuedev/src/continuedev/plugins/steps/core/core.py
@@ -11,11 +11,12 @@ from pydantic import validator
from ....libs.llm.ggml import GGML
from ....models.main import Range
+from ....libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
from ....models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullContents, FileSystemEdit
from ....models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents
-from ....core.observation import Observation, TextObservation, UserInputObservation
-from ....core.main import ChatMessage, ContinueCustomException, Step
-from ....libs.util.count_tokens import MAX_TOKENS_FOR_MODEL, DEFAULT_MAX_TOKENS
+from ....core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation
+from ....core.main import ChatMessage, ContinueCustomException, Step, SequentialStep
+from ....libs.util.count_tokens import DEFAULT_MAX_TOKENS
from ....libs.util.strings import dedent_and_get_common_whitespace, remove_quotes_and_escapes
@@ -97,7 +98,7 @@ class ShellCommandsStep(Step):
return f"Error when running shell commands:\n```\n{self._err_text}\n```"
cmds_str = "\n".join(self.cmds)
- return await models.gpt35.complete(f"{cmds_str}\n\nSummarize what was done in these shell commands, using markdown bullet points:")
+ return await models.medium.complete(f"{cmds_str}\n\nSummarize what was done in these shell commands, using markdown bullet points:")
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
cwd = await sdk.ide.getWorkspaceDirectory() if self.cwd is None else self.cwd
@@ -105,7 +106,7 @@ class ShellCommandsStep(Step):
for cmd in self.cmds:
output = await sdk.ide.runCommand(cmd)
if self.handle_error and output is not None and output_contains_error(output):
- suggestion = await sdk.models.gpt35.complete(dedent(f"""\
+ suggestion = await sdk.models.medium.complete(dedent(f"""\
While running the command `{cmd}`, the following error occurred:
```ascii
@@ -185,7 +186,7 @@ class DefaultModelEditCodeStep(Step):
else:
changes = '\n'.join(difflib.ndiff(
self._previous_contents.splitlines(), self._new_contents.splitlines()))
- description = await models.gpt3516k.complete(dedent(f"""\
+ description = await models.medium.complete(dedent(f"""\
Diff summary: "{self.user_input}"
```diff
@@ -193,7 +194,7 @@ class DefaultModelEditCodeStep(Step):
```
Please give brief a description of the changes made above using markdown bullet points. Be concise:"""))
- name = await models.gpt3516k.complete(f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. This is the title:")
+ name = await models.medium.complete(f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. This is the title:")
self.name = remove_quotes_and_escapes(name)
return f"{remove_quotes_and_escapes(description)}"
@@ -203,8 +204,7 @@ class DefaultModelEditCodeStep(Step):
# We care because if this prompt itself goes over the limit, then the entire message will have to be cut from the completion.
# Overflow won't happen, but prune_chat_messages in count_tokens.py will cut out this whole thing, instead of us cutting out only as many lines as we need.
model_to_use = sdk.models.default
- max_tokens = int(MAX_TOKENS_FOR_MODEL.get(
- model_to_use.name, DEFAULT_MAX_TOKENS) / 2)
+ max_tokens = int(model_to_use.context_length / 2)
TOKENS_TO_BE_CONSIDERED_LARGE_RANGE = 1200
if model_to_use.count_tokens(rif.contents) > TOKENS_TO_BE_CONSIDERED_LARGE_RANGE:
@@ -222,8 +222,9 @@ class DefaultModelEditCodeStep(Step):
# If using 3.5 and overflows, upgrade to 3.5.16k
if model_to_use.name == "gpt-3.5-turbo":
- if total_tokens > MAX_TOKENS_FOR_MODEL["gpt-3.5-turbo"]:
- model_to_use = sdk.models.gpt3516k
+ if total_tokens > model_to_use.context_length:
+ model_to_use = MaybeProxyOpenAI(model="gpt-3.5-turbo-0613")
+ await sdk.start_model(model_to_use)
# Remove tokens from the end first, and then the start to clear space
# This part finds the start and end lines
@@ -233,20 +234,20 @@ class DefaultModelEditCodeStep(Step):
cur_start_line = 0
cur_end_line = len(full_file_contents_lst) - 1
- if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+ if total_tokens > model_to_use.context_length:
while cur_end_line > min_end_line:
total_tokens -= model_to_use.count_tokens(
full_file_contents_lst[cur_end_line])
cur_end_line -= 1
- if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+ if total_tokens < model_to_use.context_length:
break
- if total_tokens > MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+ if total_tokens > model_to_use.context_length:
while cur_start_line < max_start_line:
cur_start_line += 1
total_tokens -= model_to_use.count_tokens(
full_file_contents_lst[cur_start_line])
- if total_tokens < MAX_TOKENS_FOR_MODEL[model_to_use.name]:
+ if total_tokens < model_to_use.context_length:
break
# Now use the found start/end lines to get the prefix and suffix strings
diff --git a/continuedev/src/continuedev/plugins/steps/draft/migration.py b/continuedev/src/continuedev/plugins/steps/draft/migration.py
index a76d491b..c38f54dc 100644
--- a/continuedev/src/continuedev/plugins/steps/draft/migration.py
+++ b/continuedev/src/continuedev/plugins/steps/draft/migration.py
@@ -13,7 +13,7 @@ class MigrationStep(Step):
recent_edits = await sdk.ide.get_recent_edits(self.edited_file)
recent_edits_string = "\n\n".join(
map(lambda x: x.to_string(), recent_edits))
- description = await sdk.models.gpt35.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n")
+ description = await sdk.models.medium.complete(f"{recent_edits_string}\n\nGenerate a short description of the migration made in the above changes:\n")
await sdk.run([
"cd libs",
"poetry run alembic revision --autogenerate -m " + description,
diff --git a/continuedev/src/continuedev/plugins/steps/help.py b/continuedev/src/continuedev/plugins/steps/help.py
index 6997a547..ec670999 100644
--- a/continuedev/src/continuedev/plugins/steps/help.py
+++ b/continuedev/src/continuedev/plugins/steps/help.py
@@ -56,7 +56,7 @@ class HelpStep(Step):
summary="Help"
))
messages = await sdk.get_chat_context()
- generator = sdk.models.gpt4.stream_chat(messages)
+ generator = sdk.models.default.stream_chat(messages)
async for chunk in generator:
if "content" in chunk:
self.description += chunk["content"]
diff --git a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py
index b54d394a..3d8d96fb 100644
--- a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py
+++ b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py
@@ -23,6 +23,6 @@ class NLMultiselectStep(Step):
if first_try is not None:
return first_try
- gpt_parsed = await sdk.models.gpt35.complete(
+ gpt_parsed = await sdk.models.default.complete(
f"These are the available options are: [{', '.join(self.options)}]. The user requested {user_response}. This is the exact string from the options array that they selected:")
return extract_option(gpt_parsed) or self.options[0]
diff --git a/continuedev/src/continuedev/plugins/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py
index f28d9660..d2d6f4dd 100644
--- a/continuedev/src/continuedev/plugins/steps/main.py
+++ b/continuedev/src/continuedev/plugins/steps/main.py
@@ -101,7 +101,7 @@ class FasterEditHighlightedCodeStep(Step):
for rif in range_in_files:
rif_dict[rif.filepath] = rif.contents
- completion = await sdk.models.gpt35.complete(prompt)
+ completion = await sdk.models.medium.complete(prompt)
# Temporarily doing this to generate description.
self._prompt = prompt
@@ -169,7 +169,7 @@ class StarCoderEditHighlightedCodeStep(Step):
_prompt_and_completion: str = ""
async def describe(self, models: Models) -> Coroutine[str, None, None]:
- return await models.gpt35.complete(f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:")
+ return await models.medium.complete(f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:")
async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]:
range_in_files = await sdk.get_code_context(only_editing=True)
diff --git a/continuedev/src/continuedev/plugins/steps/react.py b/continuedev/src/continuedev/plugins/steps/react.py
index 8b2e7c2e..da6acdbf 100644
--- a/continuedev/src/continuedev/plugins/steps/react.py
+++ b/continuedev/src/continuedev/plugins/steps/react.py
@@ -27,7 +27,7 @@ class NLDecisionStep(Step):
Select the step which should be taken next to satisfy the user input. Say only the name of the selected step. You must choose one:""")
- resp = (await sdk.models.gpt35.complete(prompt)).lower()
+ resp = (await sdk.models.medium.complete(prompt)).lower()
step_to_run = None
for step in self.steps:
diff --git a/continuedev/src/continuedev/plugins/steps/search_directory.py b/continuedev/src/continuedev/plugins/steps/search_directory.py
index 522a84a3..456dba84 100644
--- a/continuedev/src/continuedev/plugins/steps/search_directory.py
+++ b/continuedev/src/continuedev/plugins/steps/search_directory.py
@@ -42,7 +42,7 @@ class WriteRegexPatternStep(Step):
async def run(self, sdk: ContinueSDK):
# Ask the user for a regex pattern
- pattern = await sdk.models.gpt35.complete(dedent(f"""\
+ pattern = await sdk.models.medium.complete(dedent(f"""\
This is the user request:
{self.user_request}
diff --git a/docs/docs/customization.md b/docs/docs/customization.md
index fa4d110e..60764527 100644
--- a/docs/docs/customization.md
+++ b/docs/docs/customization.md
@@ -4,11 +4,25 @@ Continue can be deeply customized by editing the `ContinueConfig` object in `~/.
## Change the default LLM
-Change the `default_model` field to any of "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "claude-2", or "ggml".
+In `config.py`, you'll find the `models` property:
+
+```python
+config = ContinueConfig(
+ ...
+ models=Models(
+ default=MaybeProxyOpenAI(model="gpt-4"),
+ medium=MaybeProxyOpenAI(model="gpt-3.5-turbo")
+ )
+)
+```
+
+The `default` model is the one used for most operations, including responding to your messages and editing code. The `medium` model is used for summarization tasks that require less quality. There are also `small` and `large` roles that can be filled, but all will fall back to `default` if not set. The values of these fields must be of the [`LLM`](https://github.com/continuedev/continue/blob/main/continuedev/src/continuedev/libs/llm/__init__.py) class, which implements methods for retrieving and streaming completions from an LLM.
+
+Below, we describe the `LLM` classes available in the Continue core library, and how they can be used.
### Adding an OpenAI API key
-New users can try out Continue with GPT-4 using a proxy server that securely makes calls to OpenAI using our API key. Continue should just work the first time you install the extension in VS Code.
+With the `MaybeProxyOpenAI` `LLM`, new users can try out Continue with GPT-4 using a proxy server that securely makes calls to OpenAI using our API key. Continue should just work the first time you install the extension in VS Code.
Once you are using Continue regularly though, you will need to add an OpenAI API key that has access to GPT-4 by following these steps:
@@ -18,34 +32,55 @@ Once you are using Continue regularly though, you will need to add an OpenAI API
4. Click Edit in settings.json under Continue: OpenAI_API_KEY" section
5. Paste your API key as the value for "continue.OPENAI_API_KEY" in settings.json
-### claude-2 and gpt-X
+The `MaybeProxyOpenAI` class will automatically switch to using your API key instead of ours. If you'd like to explicitly use one or the other, you can use the `ProxyServer` or `OpenAI` classes instead.
+
+These classes support any models available through the OpenAI API, assuming your API key has access, including "gpt-4", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", and "gpt-4-32k".
+
+### claude-2
-If you have access, simply set `default_model` to the model you would like to use, then you will be prompted for a personal API key after reloading VS Code. If using an OpenAI model, you can press enter to try with our API key for free.
+Import the `Anthropic` LLM class and set it as the default model:
+
+```python
+from continuedev.libs.llm.anthropic import Anthropic
+
+config = ContinueConfig(
+ ...
+ models=Models(
+ default=Anthropic(model="claude-2")
+ )
+)
+```
+
+Continue will automatically prompt you for your Anthropic API key, which must have access to Claude 2. You can request early access [here](https://www.anthropic.com/earlyaccess).
### Local models with ggml
See our [5 minute quickstart](https://github.com/continuedev/ggml-server-example) to run any model locally with ggml. While these models don't yet perform as well, they are free, entirely private, and run offline.
-Once the model is running on localhost:8000, set `default_model` in `~/.continue/config.py` to "ggml".
+Once the model is running on localhost:8000, import the `GGML` LLM class from `continuedev.libs.llm.ggml` and set `default=GGML(max_context_length=2048)`.
### Self-hosting an open-source model
-If you want to self-host on Colab, RunPod, Replicate, HuggingFace, Haven, or another hosting provider you will need to wire up a new LLM class. It only needs to implement 3 methods: `stream_complete`, `complete`, and `stream_chat`, and you can see examples in `continuedev/src/continuedev/libs/llm`.
+If you want to self-host on Colab, RunPod, Replicate, HuggingFace, Haven, or another hosting provider you will need to wire up a new LLM class. It only needs to implement 3 primary methods: `stream_complete`, `complete`, and `stream_chat`, and you can see examples in `continuedev/src/continuedev/libs/llm`.
If by chance the provider has the exact same API interface as OpenAI, the `GGML` class will work for you out of the box, after changing the endpoint at the top of the file.
### Azure OpenAI Service
-If you'd like to use OpenAI models but are concerned about privacy, you can use the Azure OpenAI service, which is GDPR and HIPAA compliant. After applying for access [here](https://azure.microsoft.com/en-us/products/ai-services/openai-service), you will typically hear back within only a few days. Once you have access, set `default_model` to "gpt-4", and then set the `openai_server_info` property in the `ContinueConfig` like so:
+If you'd like to use OpenAI models but are concerned about privacy, you can use the Azure OpenAI service, which is GDPR and HIPAA compliant. After applying for access [here](https://azure.microsoft.com/en-us/products/ai-services/openai-service), you will typically hear back within only a few days. Once you have access, instantiate the model like so:
```python
+from continuedev.libs.llm.openai import OpenAI, OpenAIServerInfo
+
config = ContinueConfig(
...
- openai_server_info=OpenAIServerInfo(
- api_base="https://my-azure-openai-instance.openai.azure.com/",
- engine="my-azure-openai-deployment",
- api_version="2023-03-15-preview",
- api_type="azure"
+ models=Models(
+ default=OpenAI(model="gpt-3.5-turbo", server_info=OpenAIServerInfo(
+ api_base="https://my-azure-openai-instance.openai.azure.com/"
+ engine="my-azure-openai-deployment",
+ api_version="2023-03-15-preview",
+ api_type="azure"
+ ))
)
)
```
diff --git a/docs/docs/walkthroughs/create-a-recipe.md b/docs/docs/walkthroughs/create-a-recipe.md
index 5d80d083..2cb28f77 100644
--- a/docs/docs/walkthroughs/create-a-recipe.md
+++ b/docs/docs/walkthroughs/create-a-recipe.md
@@ -31,7 +31,7 @@ If you'd like to override the default description of your steps, which is just t
- Return a static string
- Store state in a class attribute (prepend with a double underscore, which signifies (through Pydantic) that this is not a parameter for the Step, just internal state) during the run method, and then grab this in the describe method.
-- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.gpt35.complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`.
+- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.medium.complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`.
## 2. Compose steps together into a complete recipe