author      Nate Sesti <33237525+sestinj@users.noreply.github.com>  2023-09-03 21:58:46 -0700
committer   GitHub <noreply@github.com>  2023-09-03 21:58:46 -0700
commit      e645a89192b28cc16a1303bfa5551834c64ecb77 (patch)
tree        6da1d0b5f59cef5c9fd9a615119742550fe1ad2c
parent      e49c6f55ae0c00bc660bbe885ea44f3a2fb1dc35 (diff)
download    sncontinue-e645a89192b28cc16a1303bfa5551834c64ecb77.tar.gz
            sncontinue-e645a89192b28cc16a1303bfa5551834c64ecb77.tar.bz2
            sncontinue-e645a89192b28cc16a1303bfa5551834c64ecb77.zip
refactor: :construction: Initial, not tested, refactor of LLM (#448)
* refactor: :construction: Initial, not tested, refactor of LLM
* refactor: :construction: replace usages of _complete with complete
* fix: :bug: fixes after refactor
* refactor: :recycle: template raw completions in chat format
* test: :white_check_mark: simplified edit prompt and UNIT TESTS!
* ci: :green_heart: unit tests in ci
* fix: :bug: fixes for unit tests in ci
* fix: :bug: start uvicorn in tests without poetry
* fix: :closed_lock_with_key: add secrets to main.yaml
* feat: :adhesive_bandage: timeout for all LLM classes
* ci: :green_heart: prepare main.yaml for main branch
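For orientation, here is a minimal sketch of the call surface this refactor introduces, based on the signatures added in `libs/llm/__init__.py` and the `OpenAI` construction shown in the diff below; the API key and model name are placeholders, and the import path assumes `continuedev/src` is on the Python path.

```python
import asyncio

from continuedev.libs.llm.openai import OpenAI


async def main():
    # Placeholders: supply a real key and any supported model name.
    llm = OpenAI(api_key="sk-...", model="gpt-3.5-turbo")
    await llm.start(write_log=print)

    # complete() replaces the old _complete(); sampling parameters are plain
    # keyword arguments that the base class folds into a CompletionOptions.
    title = await llm.complete(
        "Give a short title to describe the current chat session: ",
        max_tokens=20,
    )
    print(title)

    # stream_complete() is an async generator yielding string chunks.
    async for chunk in llm.stream_complete("def fib(n):", temperature=0.5):
        print(chunk, end="")

    await llm.stop()


asyncio.run(main())
```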
-rw-r--r--  .github/workflows/main.yaml | 11
-rw-r--r--  .vscode/launch.json | 9
-rw-r--r--  continuedev/dev_requirements.txt | 2
-rw-r--r--  continuedev/poetry.lock | 107
-rw-r--r--  continuedev/pyproject.toml | 6
-rw-r--r--  continuedev/src/continuedev/core/autopilot.py | 2
-rw-r--r--  continuedev/src/continuedev/core/models.py | 5
-rw-r--r--  continuedev/src/continuedev/libs/llm/__init__.py | 255
-rw-r--r--  continuedev/src/continuedev/libs/llm/anthropic.py | 100
-rw-r--r--  continuedev/src/continuedev/libs/llm/ggml.py | 94
-rw-r--r--  continuedev/src/continuedev/libs/llm/hf_inference_api.py | 24
-rw-r--r--  continuedev/src/continuedev/libs/llm/hf_tgi.py | 75
-rw-r--r--  continuedev/src/continuedev/libs/llm/hugging_face.py | 2
-rw-r--r--  continuedev/src/continuedev/libs/llm/llamacpp.py | 123
-rw-r--r--  continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py | 28
-rw-r--r--  continuedev/src/continuedev/libs/llm/ollama.py | 122
-rw-r--r--  continuedev/src/continuedev/libs/llm/openai.py | 118
-rw-r--r--  continuedev/src/continuedev/libs/llm/prompts/chat.py | 30
-rw-r--r--  continuedev/src/continuedev/libs/llm/prompts/edit.py | 13
-rw-r--r--  continuedev/src/continuedev/libs/llm/proxy_server.py | 88
-rw-r--r--  continuedev/src/continuedev/libs/llm/replicate.py | 20
-rw-r--r--  continuedev/src/continuedev/libs/llm/text_gen_interface.py | 72
-rw-r--r--  continuedev/src/continuedev/libs/llm/together.py | 102
-rw-r--r--  continuedev/src/continuedev/libs/util/count_tokens.py | 13
-rw-r--r--  continuedev/src/continuedev/libs/util/edit_config.py | 7
-rw-r--r--  continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py | 8
-rw-r--r--  continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py | 2
-rw-r--r--  continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py | 2
-rw-r--r--  continuedev/src/continuedev/plugins/steps/README.md | 2
-rw-r--r--  continuedev/src/continuedev/plugins/steps/chat.py | 6
-rw-r--r--  continuedev/src/continuedev/plugins/steps/chroma.py | 2
-rw-r--r--  continuedev/src/continuedev/plugins/steps/core/core.py | 12
-rw-r--r--  continuedev/src/continuedev/plugins/steps/help.py | 2
-rw-r--r--  continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py | 2
-rw-r--r--  continuedev/src/continuedev/plugins/steps/main.py | 6
-rw-r--r--  continuedev/src/continuedev/plugins/steps/react.py | 2
-rw-r--r--  continuedev/src/continuedev/plugins/steps/search_directory.py | 2
-rw-r--r--  continuedev/src/continuedev/server/session_manager.py | 8
-rw-r--r--  continuedev/src/continuedev/tests/llm_test.py | 192
-rw-r--r--  continuedev/src/continuedev/tests/util/openai_mock.py | 139
-rw-r--r--  docs/docs/concepts/sdk.md | 2
-rw-r--r--  docs/docs/customization.md | 2
-rw-r--r--  docs/docs/walkthroughs/create-a-recipe.md | 2
-rw-r--r--  extension/react-app/src/components/ModelSelect.tsx | 10
44 files changed, 999 insertions, 832 deletions
diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml
index 8c348024..49dde09c 100644
--- a/.github/workflows/main.yaml
+++ b/.github/workflows/main.yaml
@@ -39,6 +39,17 @@ jobs:
run: |
chmod 777 dist/run
+ - name: Test Python Server
+ env:
+ OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+ ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
+ TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }}
+ run: |
+ cd continuedev
+ pip install -r dev_requirements.txt
+ cd src
+ python -m pytest
+
- name: Upload Artifacts
uses: actions/upload-artifact@v3
with:
diff --git a/.vscode/launch.json b/.vscode/launch.json
index 12cfaef8..08061d13 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -17,6 +17,15 @@
],
"configurations": [
{
+ "name": "Pytest",
+ "type": "python",
+ "request": "launch",
+ "program": "continuedev/src/continuedev/tests/llm_test.py",
+ "console": "integratedTerminal",
+ "justMyCode": true,
+ "subProcess": true
+ },
+ {
"name": "Server",
"type": "python",
"request": "launch",
diff --git a/continuedev/dev_requirements.txt b/continuedev/dev_requirements.txt
new file mode 100644
index 00000000..2fa7631b
--- /dev/null
+++ b/continuedev/dev_requirements.txt
@@ -0,0 +1,2 @@
+pytest==7.4.1
+pytest-asyncio==0.21.1 \ No newline at end of file
diff --git a/continuedev/poetry.lock b/continuedev/poetry.lock
index aefc7cf9..700fd017 100644
--- a/continuedev/poetry.lock
+++ b/continuedev/poetry.lock
@@ -588,6 +588,20 @@ files = [
]
[[package]]
+name = "exceptiongroup"
+version = "1.1.3"
+description = "Backport of PEP 654 (exception groups)"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "exceptiongroup-1.1.3-py3-none-any.whl", hash = "sha256:343280667a4585d195ca1cf9cef84a4e178c4b6cf2274caef9859782b567d5e3"},
+ {file = "exceptiongroup-1.1.3.tar.gz", hash = "sha256:097acd85d473d75af5bb98e41b61ff7fe35efe6675e4f9370ec6ec5126d160e9"},
+]
+
+[package.extras]
+test = ["pytest (>=6)"]
+
+[[package]]
name = "fastapi"
version = "0.95.1"
description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
@@ -776,6 +790,17 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link
testing = ["pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-ruff"]
[[package]]
+name = "iniconfig"
+version = "2.0.0"
+description = "brain-dead simple config-ini parsing"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"},
+ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
+]
+
+[[package]]
name = "jsonref"
version = "1.1.0"
description = "jsonref is a library for automatic dereferencing of JSON Reference objects for Python."
@@ -975,6 +1000,21 @@ files = [
]
[[package]]
+name = "pluggy"
+version = "1.3.0"
+description = "plugin and hook calling mechanisms for python"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "pluggy-1.3.0-py3-none-any.whl", hash = "sha256:d89c696a773f8bd377d18e5ecda92b7a3793cbe66c87060a6fb58c7b6e1061f7"},
+ {file = "pluggy-1.3.0.tar.gz", hash = "sha256:cf61ae8f126ac6f7c451172cf30e3e43d3ca77615509771b3a984a0730651e12"},
+]
+
+[package.extras]
+dev = ["pre-commit", "tox"]
+testing = ["pytest", "pytest-benchmark"]
+
+[[package]]
name = "posthog"
version = "3.0.1"
description = "Integrate PostHog into any python application."
@@ -1186,6 +1226,46 @@ files = [
]
[[package]]
+name = "pytest"
+version = "7.4.1"
+description = "pytest: simple powerful testing with Python"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "pytest-7.4.1-py3-none-any.whl", hash = "sha256:460c9a59b14e27c602eb5ece2e47bec99dc5fc5f6513cf924a7d03a578991b1f"},
+ {file = "pytest-7.4.1.tar.gz", hash = "sha256:2f2301e797521b23e4d2585a0a3d7b5e50fdddaaf7e7d6773ea26ddb17c213ab"},
+]
+
+[package.dependencies]
+colorama = {version = "*", markers = "sys_platform == \"win32\""}
+exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
+iniconfig = "*"
+packaging = "*"
+pluggy = ">=0.12,<2.0"
+tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
+
+[package.extras]
+testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
+
+[[package]]
+name = "pytest-asyncio"
+version = "0.21.1"
+description = "Pytest support for asyncio"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "pytest-asyncio-0.21.1.tar.gz", hash = "sha256:40a7eae6dded22c7b604986855ea48400ab15b069ae38116e8c01238e9eeb64d"},
+ {file = "pytest_asyncio-0.21.1-py3-none-any.whl", hash = "sha256:8666c1c8ac02631d7c51ba282e0c69a8a452b211ffedf2599099845da5c5c37b"},
+]
+
+[package.dependencies]
+pytest = ">=7.0.0"
+
+[package.extras]
+docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
+testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"]
+
+[[package]]
name = "python-dateutil"
version = "2.8.2"
description = "Extensions to the standard Python datetime module"
@@ -1436,6 +1516,20 @@ files = [
]
[[package]]
+name = "sse-starlette"
+version = "1.6.5"
+description = "\"SSE plugin for Starlette\""
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "sse-starlette-1.6.5.tar.gz", hash = "sha256:819f2c421fb37067380fe3dcaba246c476b02651b7bb7601099a378ad802a0ac"},
+ {file = "sse_starlette-1.6.5-py3-none-any.whl", hash = "sha256:68b6b7eb49be0c72a2af80a055994c13afcaa4761b29226beb208f954c25a642"},
+]
+
+[package.dependencies]
+starlette = "*"
+
+[[package]]
name = "starlette"
version = "0.26.1"
description = "The little ASGI library that shines."
@@ -1553,6 +1647,17 @@ docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"]
testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"]
[[package]]
+name = "tomli"
+version = "2.0.1"
+description = "A lil' TOML parser"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
+ {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
+]
+
+[[package]]
name = "tqdm"
version = "4.65.0"
description = "Fast, Extensible Progress Meter"
@@ -1905,4 +2010,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata]
lock-version = "2.0"
python-versions = "^3.8.1"
-content-hash = "fe4715494ed91c691ec1eb914373a612e75751e6685678e438b73193879de98d"
+content-hash = "94cd41b3db68b5e0ef60115992042a3e00f7d80e6b4be19a55d6ed2c0089342e"
diff --git a/continuedev/pyproject.toml b/continuedev/pyproject.toml
index 8cdf1197..142c1d09 100644
--- a/continuedev/pyproject.toml
+++ b/continuedev/pyproject.toml
@@ -34,7 +34,11 @@ replicate = "^0.11.0"
redbaron = "^0.9.2"
[tool.poetry.scripts]
-typegen = "src.continuedev.models.generate_json_schema:main"
+typegen = "src.continuedev.models.generate_json_schema:main"
+
+[tool.poetry.group.dev.dependencies]
+pytest = "^7.4.1"
+pytest-asyncio = "^0.21.1"
[build-system]
requires = ["poetry-core"]
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index de0b8c53..bae82739 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -507,7 +507,7 @@ class Autopilot(ContinueBaseModel):
if self.session_info is None:
async def create_title():
- title = await self.continue_sdk.models.medium._complete(
+ title = await self.continue_sdk.models.medium.complete(
f'Give a short title to describe the current chat session. Do not put quotes around the title. The first message was: "{user_input}". Do not use more than 10 words. The title is: ',
max_tokens=20,
)
diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py
index f3bf8125..9816d5d9 100644
--- a/continuedev/src/continuedev/core/models.py
+++ b/continuedev/src/continuedev/core/models.py
@@ -5,12 +5,13 @@ from pydantic import BaseModel
from ..libs.llm import LLM
from ..libs.llm.anthropic import AnthropicLLM
from ..libs.llm.ggml import GGML
+from ..libs.llm.llamacpp import LlamaCpp
from ..libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
from ..libs.llm.ollama import Ollama
from ..libs.llm.openai import OpenAI
from ..libs.llm.replicate import ReplicateLLM
from ..libs.llm.together import TogetherLLM
-from ..libs.llm.llamacpp import LlamaCpp
+
class ContinueSDK(BaseModel):
pass
@@ -35,7 +36,7 @@ MODEL_CLASSES = {
AnthropicLLM,
ReplicateLLM,
Ollama,
- LlamaCpp
+ LlamaCpp,
]
}
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 6a321a41..059bf279 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -1,12 +1,58 @@
-from abc import ABC
from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union
+from pydantic import validator
+
from ...core.main import ChatMessage
from ...models.main import ContinueBaseModel
-from ..util.count_tokens import DEFAULT_ARGS, count_tokens
+from ..util.count_tokens import (
+ DEFAULT_ARGS,
+ DEFAULT_MAX_TOKENS,
+ compile_chat_messages,
+ count_tokens,
+ format_chat_messages,
+ prune_raw_prompt_from_top,
+)
+
+
+class CompletionOptions(ContinueBaseModel):
+ """Options for the completion."""
+
+ @validator(
+ "*",
+ pre=True,
+ always=True,
+ )
+ def ignore_none_and_set_default(cls, value, field):
+ return value if value is not None else field.default
+
+ model: str = None
+ "The model name"
+ temperature: Optional[float] = None
+ "The temperature of the completion."
+
+ top_p: Optional[float] = None
+ "The top_p of the completion."
+
+ top_k: Optional[int] = None
+ "The top_k of the completion."
+
+ presence_penalty: Optional[float] = None
+ "The presence penalty of the completion."
+
+ frequency_penalty: Optional[float] = None
+ "The frequency penalty of the completion."
+
+ stop: Optional[List[str]] = None
+ "The stop tokens of the completion."
+
+ max_tokens: int = DEFAULT_MAX_TOKENS
+ "The maximum number of tokens to generate."
+
+ functions: Optional[List[Any]] = None
+ "The functions/tools to make available to the model."
-class LLM(ContinueBaseModel, ABC):
+class LLM(ContinueBaseModel):
title: Optional[str] = None
system_message: Optional[str] = None
@@ -19,8 +65,14 @@ class LLM(ContinueBaseModel, ABC):
model: str
"The model name"
+ timeout: int = 300
+ "The timeout for the request in seconds."
+
prompt_templates: dict = {}
+ template_messages: Optional[Callable[[List[Dict[str, str]]], str]] = None
+ "A function that takes a list of messages and returns a prompt."
+
write_log: Optional[Callable[[str], None]] = None
"A function that takes a string and writes it to the log."
@@ -33,16 +85,12 @@ class LLM(ContinueBaseModel, ABC):
def dict(self, **kwargs):
original_dict = super().dict(**kwargs)
- original_dict.pop("write_log", None)
+ original_dict.pop("write_log")
+ original_dict.pop("template_messages")
+ original_dict.pop("unique_id")
original_dict["class_name"] = self.__class__.__name__
return original_dict
- def collect_args(self, **kwargs) -> Any:
- """Collect the arguments for the LLM."""
- args = {**DEFAULT_ARGS.copy(), "model": self.model, "max_tokens": 1024}
- args.update(kwargs)
- return args
-
async def start(
self, write_log: Callable[[str], None] = None, unique_id: Optional[str] = None
):
@@ -54,23 +102,194 @@ class LLM(ContinueBaseModel, ABC):
"""Stop the connection to the LLM."""
pass
- async def _complete(
- self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
+ def collect_args(self, options: CompletionOptions) -> Dict[str, Any]:
+ """Collect the arguments for the LLM."""
+ args = {**DEFAULT_ARGS.copy(), "model": self.model}
+ args.update(options.dict(exclude_unset=True, exclude_none=True))
+ return args
+
+ def compile_chat_messages(
+ self,
+ options: CompletionOptions,
+ msgs: List[ChatMessage],
+ functions: Optional[List[Any]] = None,
+ ) -> List[Dict]:
+ return compile_chat_messages(
+ model_name=options.model,
+ msgs=msgs,
+ context_length=self.context_length,
+ max_tokens=options.max_tokens,
+ functions=functions,
+ system_message=self.system_message,
+ )
+
+ def template_prompt_like_messages(self, prompt: str) -> str:
+ if self.template_messages is None:
+ return prompt
+
+ msgs = [{"role": "user", "content": prompt}]
+ if self.system_message is not None:
+ msgs.insert(0, {"role": "system", "content": self.system_message})
+
+ return self.template_messages(msgs)
+
+ async def stream_complete(
+ self,
+ prompt: str,
+ model: str = None,
+ temperature: float = None,
+ top_p: float = None,
+ top_k: int = None,
+ presence_penalty: float = None,
+ frequency_penalty: float = None,
+ stop: Optional[List[str]] = None,
+ max_tokens: Optional[int] = None,
+ functions: Optional[List[Any]] = None,
+ ) -> Generator[Union[Any, List, Dict], None, None]:
+ """Yield completion response, either streamed or not."""
+ options = CompletionOptions(
+ model=model or self.model,
+ temperature=temperature,
+ top_p=top_p,
+ top_k=top_k,
+ presence_penalty=presence_penalty,
+ frequency_penalty=frequency_penalty,
+ stop=stop,
+ max_tokens=max_tokens,
+ functions=functions,
+ )
+
+ prompt = prune_raw_prompt_from_top(
+ self.model, self.context_length, prompt, options.max_tokens
+ )
+ prompt = self.template_prompt_like_messages(prompt)
+
+ self.write_log(f"Prompt: \n\n{prompt}")
+
+ completion = ""
+ async for chunk in self._stream_complete(prompt=prompt, options=options):
+ yield chunk
+ completion += chunk
+
+ self.write_log(f"Completion: \n\n{completion}")
+
+ async def complete(
+ self,
+ prompt: str,
+ model: str = None,
+ temperature: float = None,
+ top_p: float = None,
+ top_k: int = None,
+ presence_penalty: float = None,
+ frequency_penalty: float = None,
+ stop: Optional[List[str]] = None,
+ max_tokens: Optional[int] = None,
+ functions: Optional[List[Any]] = None,
) -> Coroutine[Any, Any, str]:
- """Return the completion of the text with the given temperature."""
- raise NotImplementedError
+ """Yield completion response, either streamed or not."""
+ options = CompletionOptions(
+ model=model or self.model,
+ temperature=temperature,
+ top_p=top_p,
+ top_k=top_k,
+ presence_penalty=presence_penalty,
+ frequency_penalty=frequency_penalty,
+ stop=stop,
+ max_tokens=max_tokens,
+ functions=functions,
+ )
- def _stream_complete(
- self, prompt, with_history: List[ChatMessage] = None, **kwargs
+ prompt = prune_raw_prompt_from_top(
+ self.model, self.context_length, prompt, options.max_tokens
+ )
+ prompt = self.template_prompt_like_messages(prompt)
+
+ self.write_log(f"Prompt: \n\n{prompt}")
+
+ completion = await self._complete(prompt=prompt, options=options)
+
+ self.write_log(f"Completion: \n\n{completion}")
+ return completion
+
+ async def stream_chat(
+ self,
+ messages: List[ChatMessage],
+ model: str = None,
+ temperature: float = None,
+ top_p: float = None,
+ top_k: int = None,
+ presence_penalty: float = None,
+ frequency_penalty: float = None,
+ stop: Optional[List[str]] = None,
+ max_tokens: Optional[int] = None,
+ functions: Optional[List[Any]] = None,
) -> Generator[Union[Any, List, Dict], None, None]:
+ """Yield completion response, either streamed or not."""
+ options = CompletionOptions(
+ model=model or self.model,
+ temperature=temperature,
+ top_p=top_p,
+ top_k=top_k,
+ presence_penalty=presence_penalty,
+ frequency_penalty=frequency_penalty,
+ stop=stop,
+ max_tokens=max_tokens,
+ functions=functions,
+ )
+
+ messages = self.compile_chat_messages(
+ options=options, msgs=messages, functions=functions
+ )
+ if self.template_messages is not None:
+ prompt = self.template_messages(messages)
+ else:
+ prompt = format_chat_messages(messages)
+
+ self.write_log(f"Prompt: \n\n{prompt}")
+
+ completion = ""
+
+ # Use the template_messages function if it exists and do a raw completion
+ if self.template_messages is None:
+ async for chunk in self._stream_chat(messages=messages, options=options):
+ yield chunk
+ if "content" in chunk:
+ completion += chunk["content"]
+ else:
+ async for chunk in self._stream_complete(prompt=prompt, options=options):
+ yield {"role": "assistant", "content": chunk}
+ completion += chunk
+
+ self.write_log(f"Completion: \n\n{completion}")
+
+ def _stream_complete(
+ self, prompt, options: CompletionOptions
+ ) -> Generator[str, None, None]:
"""Stream the completion through generator."""
raise NotImplementedError
+ async def _complete(
+ self, prompt: str, options: CompletionOptions
+ ) -> Coroutine[Any, Any, str]:
+ """Return the completion of the text with the given temperature."""
+ completion = ""
+ async for chunk in self._stream_complete(prompt=prompt, options=options):
+ completion += chunk
+ return completion
+
async def _stream_chat(
- self, messages: List[ChatMessage] = None, **kwargs
+ self, messages: List[ChatMessage], options: CompletionOptions
) -> Generator[Union[Any, List, Dict], None, None]:
"""Stream the chat through generator."""
- raise NotImplementedError
+ if self.template_messages is None:
+ raise NotImplementedError(
+ "You must either implement template_messages or _stream_chat"
+ )
+
+ async for chunk in self._stream_complete(
+ prompt=self.template_messages(messages), options=options
+ ):
+ yield {"role": "assistant", "content": chunk}
def count_tokens(self, text: str):
"""Return the number of tokens in the given text."""
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
index b5aff63a..70a0868c 100644
--- a/continuedev/src/continuedev/libs/llm/anthropic.py
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -1,10 +1,9 @@
-from typing import Any, Coroutine, Dict, Generator, List, Union
+from typing import Any, Callable, Coroutine
from anthropic import AI_PROMPT, HUMAN_PROMPT, AsyncAnthropic
-from ...core.main import ChatMessage
-from ..llm import LLM
-from ..util.count_tokens import compile_chat_messages
+from ..llm import LLM, CompletionOptions
+from .prompts.chat import anthropic_template_messages
class AnthropicLLM(LLM):
@@ -15,21 +14,21 @@ class AnthropicLLM(LLM):
_async_client: AsyncAnthropic = None
+ template_messages: Callable = anthropic_template_messages
+
class Config:
arbitrary_types_allowed = True
- async def start(
- self,
- **kwargs,
- ):
+ async def start(self, **kwargs):
await super().start(**kwargs)
self._async_client = AsyncAnthropic(api_key=self.api_key)
if self.model == "claude-2":
self.context_length = 100_000
- def collect_args(self, **kwargs) -> Any:
- args = super().collect_args(**kwargs)
+ def collect_args(self, options: CompletionOptions):
+ options.stop = None
+ args = super().collect_args(options)
if "max_tokens" in args:
args["max_tokens_to_sample"] = args["max_tokens"]
@@ -40,85 +39,18 @@ class AnthropicLLM(LLM):
del args["presence_penalty"]
return args
- def __messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
- prompt = ""
-
- # Anthropic prompt must start with a Human turn
- if (
- len(messages) > 0
- and messages[0]["role"] != "user"
- and messages[0]["role"] != "system"
- ):
- prompt += f"{HUMAN_PROMPT} Hello."
- for msg in messages:
- prompt += f"{HUMAN_PROMPT if (msg['role'] == 'user' or msg['role'] == 'system') else AI_PROMPT} {msg['content']} "
-
- prompt += AI_PROMPT
- return prompt
-
- async def _stream_complete(
- self, prompt, with_history: List[ChatMessage] = None, **kwargs
- ) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.collect_args(**kwargs)
- args["stream"] = True
+ async def _stream_complete(self, prompt: str, options):
+ args = self.collect_args(options)
prompt = f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}"
- self.write_log(f"Prompt: \n\n{prompt}")
- completion = ""
async for chunk in await self._async_client.completions.create(
- prompt=prompt, **args
+ prompt=prompt, stream=True, **args
):
yield chunk.completion
- completion += chunk.completion
-
- self.write_log(f"Completion: \n\n{completion}")
-
- async def _stream_chat(
- self, messages: List[ChatMessage] = None, **kwargs
- ) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.collect_args(**kwargs)
- args["stream"] = True
-
- messages = compile_chat_messages(
- args["model"],
- messages,
- self.context_length,
- args["max_tokens_to_sample"],
- functions=args.get("functions", None),
- system_message=self.system_message,
- )
- completion = ""
- prompt = self.__messages_to_prompt(messages)
- self.write_log(f"Prompt: \n\n{prompt}")
- async for chunk in await self._async_client.completions.create(
- prompt=prompt, **args
- ):
- yield {"role": "assistant", "content": chunk.completion}
- completion += chunk.completion
-
- self.write_log(f"Completion: \n\n{completion}")
-
- async def _complete(
- self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
- ) -> Coroutine[Any, Any, str]:
- args = self.collect_args(**kwargs)
-
- messages = compile_chat_messages(
- args["model"],
- with_history,
- self.context_length,
- args["max_tokens_to_sample"],
- prompt,
- functions=None,
- system_message=self.system_message,
- )
-
- prompt = self.__messages_to_prompt(messages)
- self.write_log(f"Prompt: \n\n{prompt}")
- resp = (
+ async def _complete(self, prompt: str, options) -> Coroutine[Any, Any, str]:
+ args = self.collect_args(options)
+ prompt = f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}"
+ return (
await self._async_client.completions.create(prompt=prompt, **args)
).completion
-
- self.write_log(f"Completion: \n\n{resp}")
- return resp
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 70c265d9..e4971867 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -1,11 +1,12 @@
import json
-from typing import Any, Coroutine, Dict, Generator, List, Optional, Union
+from typing import Any, Coroutine, List, Optional
import aiohttp
from ...core.main import ChatMessage
from ..llm import LLM
-from ..util.count_tokens import compile_chat_messages, format_chat_messages
+from ..util.logging import logger
+from .prompts.edit import simplified_edit_prompt
class GGML(LLM):
@@ -13,61 +14,44 @@ class GGML(LLM):
verify_ssl: Optional[bool] = None
model: str = "ggml"
- timeout: int = 300
+ prompt_templates = {
+ "edit": simplified_edit_prompt,
+ }
class Config:
arbitrary_types_allowed = True
- async def _stream_complete(
- self, prompt, with_history: List[ChatMessage] = None, **kwargs
- ) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.collect_args(**kwargs)
- args["stream"] = True
+ async def _stream_complete(self, prompt, options):
+ args = self.collect_args(options)
- messages = compile_chat_messages(
- self.model,
- with_history,
- self.context_length,
- args["max_tokens"],
- prompt,
- functions=args.get("functions", None),
- system_message=self.system_message,
- )
-
- self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
- completion = ""
async with aiohttp.ClientSession(
connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl),
timeout=aiohttp.ClientTimeout(total=self.timeout),
) as client_session:
async with client_session.post(
- f"{self.server_url}/v1/completions", json={"messages": messages, **args}
+ f"{self.server_url}/v1/completions",
+ json={
+ "prompt": prompt,
+ "stream": True,
+ **args,
+ },
) as resp:
async for line in resp.content.iter_any():
if line:
- try:
- chunk = line.decode("utf-8")
- yield chunk
- completion += chunk
- except:
- raise Exception(str(line))
+ chunk = line.decode("utf-8")
+ if chunk.startswith(": ping - ") or chunk.startswith(
+ "data: [DONE]"
+ ):
+ continue
+ elif chunk.startswith("data: "):
+ chunk = chunk[6:]
- self.write_log(f"Completion: \n\n{completion}")
+ j = json.loads(chunk)
+ if "choices" in j:
+ yield j["choices"][0]["text"]
- async def _stream_chat(
- self, messages: List[ChatMessage] = None, **kwargs
- ) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.collect_args(**kwargs)
- messages = compile_chat_messages(
- self.model,
- messages,
- self.context_length,
- args["max_tokens"],
- None,
- functions=args.get("functions", None),
- system_message=self.system_message,
- )
- args["stream"] = True
+ async def _stream_chat(self, messages: List[ChatMessage], options):
+ args = self.collect_args(options)
async def generator():
async with aiohttp.ClientSession(
@@ -76,10 +60,9 @@ class GGML(LLM):
) as client_session:
async with client_session.post(
f"{self.server_url}/v1/chat/completions",
- json={"messages": messages, **args},
+ json={"messages": messages, "stream": True, **args},
headers={"Content-Type": "application/json"},
) as resp:
- # This is streaming application/json instaed of text/event-stream
async for line, end in resp.content.iter_chunks():
json_chunk = line.decode("utf-8")
chunks = json_chunk.split("\n")
@@ -91,34 +74,22 @@ class GGML(LLM):
):
continue
try:
- yield json.loads(chunk[6:])["choices"][0][
- "delta"
- ] # {"role": "assistant", "content": "..."}
+ yield json.loads(chunk[6:])["choices"][0]["delta"]
except:
pass
# Because quite often the first attempt fails, and it works thereafter
- self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
- completion = ""
try:
async for chunk in generator():
yield chunk
- if "content" in chunk:
- completion += chunk["content"]
- except:
+ except Exception as e:
+ logger.warning(f"Error calling /chat/completions endpoint: {e}")
async for chunk in generator():
yield chunk
- if "content" in chunk:
- completion += chunk["content"]
-
- self.write_log(f"Completion: \n\n{completion}")
- async def _complete(
- self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
- ) -> Coroutine[Any, Any, str]:
- args = self.collect_args(**kwargs)
+ async def _complete(self, prompt: str, options) -> Coroutine[Any, Any, str]:
+ args = self.collect_args(options)
- self.write_log(f"Prompt: \n\n{prompt}")
async with aiohttp.ClientSession(
connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl),
timeout=aiohttp.ClientTimeout(total=self.timeout),
@@ -133,7 +104,6 @@ class GGML(LLM):
text = await resp.text()
try:
completion = json.loads(text)["choices"][0]["text"]
- self.write_log(f"Completion: \n\n{completion}")
return completion
except Exception as e:
raise Exception(
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 3a586a43..43aac148 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -1,10 +1,11 @@
-from typing import Any, Coroutine, Dict, Generator, List, Optional
+from typing import List, Optional
import aiohttp
import requests
from ...core.main import ChatMessage
from ..llm import LLM
+from .prompts.edit import simplified_edit_prompt
DEFAULT_MAX_TIME = 120.0
@@ -17,21 +18,24 @@ class HuggingFaceInferenceAPI(LLM):
_client_session: aiohttp.ClientSession = None
+ prompt_templates = {
+ "edit": simplified_edit_prompt,
+ }
+
class Config:
arbitrary_types_allowed = True
async def start(self, **kwargs):
await super().start(**kwargs)
self._client_session = aiohttp.ClientSession(
- connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
+ connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl),
+ timeout=aiohttp.ClientTimeout(total=self.timeout),
)
async def stop(self):
await self._client_session.close()
- async def _complete(
- self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
- ):
+ async def _complete(self, prompt: str, options):
"""Return the completion of the text with the given temperature."""
API_URL = (
self.base_url or f"https://api-inference.huggingface.co/models/{self.model}"
@@ -60,14 +64,10 @@ class HuggingFaceInferenceAPI(LLM):
return data[0]["generated_text"]
- async def _stream_chat(
- self, messages: List[ChatMessage] = None, **kwargs
- ) -> Coroutine[Any, Any, Generator[Any | List | Dict, None, None]]:
+ async def _stream_chat(self, messages: List[ChatMessage], options):
response = await self._complete(messages[-1].content, messages[:-1])
yield {"content": response, "role": "assistant"}
- async def _stream_complete(
- self, prompt, with_history: List[ChatMessage] = None, **kwargs
- ) -> Generator[Any | List | Dict, None, None]:
- response = await self._complete(prompt, with_history)
+ async def _stream_complete(self, prompt, options):
+ response = await self._complete(prompt, options)
yield response
diff --git a/continuedev/src/continuedev/libs/llm/hf_tgi.py b/continuedev/src/continuedev/libs/llm/hf_tgi.py
index 5c7e0239..508ebe87 100644
--- a/continuedev/src/continuedev/libs/llm/hf_tgi.py
+++ b/continuedev/src/continuedev/libs/llm/hf_tgi.py
@@ -1,12 +1,12 @@
import json
-from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union
+from typing import Any, Callable, List, Optional
import aiohttp
from ...core.main import ChatMessage
-from ..llm import LLM
-from ..util.count_tokens import compile_chat_messages
+from ..llm import LLM, CompletionOptions
from .prompts.chat import code_llama_template_messages
+from .prompts.edit import simplified_edit_prompt
class HuggingFaceTGI(LLM):
@@ -16,11 +16,15 @@ class HuggingFaceTGI(LLM):
template_messages: Callable[[List[ChatMessage]], str] = code_llama_template_messages
+ prompt_templates = {
+ "edit": simplified_edit_prompt,
+ }
+
class Config:
arbitrary_types_allowed = True
- def collect_args(self, **kwargs) -> Any:
- args = super().collect_args(**kwargs)
+ def collect_args(self, options: CompletionOptions) -> Any:
+ args = super().collect_args(options)
args = {
**args,
"max_new_tokens": args.get("max_tokens", 1024),
@@ -28,31 +32,16 @@ class HuggingFaceTGI(LLM):
args.pop("max_tokens", None)
return args
- async def _stream_complete(
- self, prompt, with_history: List[ChatMessage] = None, **kwargs
- ) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.collect_args(**kwargs)
- args["stream"] = True
-
- messages = compile_chat_messages(
- self.model,
- with_history,
- self.context_length,
- args["max_tokens"],
- prompt,
- functions=args.get("functions", None),
- system_message=self.system_message,
- )
+ async def _stream_complete(self, prompt, options):
+ args = self.collect_args(options)
- prompt = self.template_messages(messages)
- self.write_log(f"Prompt: \n\n{prompt}")
- completion = ""
async with aiohttp.ClientSession(
- connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
+ connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl),
+ timeout=aiohttp.ClientTimeout(total=self.timeout),
) as client_session:
async with client_session.post(
f"{self.server_url}",
- json={"inputs": prompt, **args},
+ json={"inputs": prompt, "stream": True, **args},
) as resp:
async for line in resp.content.iter_any():
if line:
@@ -62,39 +51,3 @@ class HuggingFaceTGI(LLM):
"generated_text"
]
yield text
- completion += text
-
- self.write_log(f"Completion: \n\n{completion}")
-
- async def _stream_chat(
- self, messages: List[ChatMessage] = None, **kwargs
- ) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.collect_args(**kwargs)
- messages = compile_chat_messages(
- self.model,
- messages,
- self.context_length,
- args["max_tokens"],
- None,
- functions=args.get("functions", None),
- system_message=self.system_message,
- )
-
- async for chunk in self._stream_complete(
- None, self.template_messages(messages), **args
- ):
- yield {
- "role": "assistant",
- "content": chunk,
- }
-
- async def _complete(
- self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
- ) -> Coroutine[Any, Any, str]:
- args = self.collect_args(**kwargs)
-
- completion = ""
- async for chunk in self._stream_complete(prompt, with_history, **args):
- completion += chunk
-
- return completion
diff --git a/continuedev/src/continuedev/libs/llm/hugging_face.py b/continuedev/src/continuedev/libs/llm/hugging_face.py
index f246a43c..c2e934c0 100644
--- a/continuedev/src/continuedev/libs/llm/hugging_face.py
+++ b/continuedev/src/continuedev/libs/llm/hugging_face.py
@@ -1,3 +1,5 @@
+# TODO: This class is far out of date
+
from transformers import AutoModelForCausalLM, AutoTokenizer
from .llm import LLM
diff --git a/continuedev/src/continuedev/libs/llm/llamacpp.py b/continuedev/src/continuedev/libs/llm/llamacpp.py
index 9d8b548f..3596fd99 100644
--- a/continuedev/src/continuedev/libs/llm/llamacpp.py
+++ b/continuedev/src/continuedev/libs/llm/llamacpp.py
@@ -1,13 +1,12 @@
import asyncio
import json
-from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union
+from typing import Any, Callable, Dict, Optional
import aiohttp
-from ...core.main import ChatMessage
from ..llm import LLM
-from ..util.count_tokens import compile_chat_messages
from .prompts.chat import llama2_template_messages
+from .prompts.edit import simplified_edit_prompt
class LlamaCpp(LLM):
@@ -15,21 +14,20 @@ class LlamaCpp(LLM):
server_url: str = "http://localhost:8080"
verify_ssl: Optional[bool] = None
- template_messages: Callable[[List[ChatMessage]], str] = llama2_template_messages
llama_cpp_args: Dict[str, Any] = {"stop": ["[INST]"]}
use_command: Optional[str] = None
+ template_messages: Callable = llama2_template_messages
+ prompt_templates = {
+ "edit": simplified_edit_prompt,
+ }
+
class Config:
arbitrary_types_allowed = True
- def dict(self, **kwargs):
- d = super().dict(**kwargs)
- d.pop("template_messages")
- return d
-
- def collect_args(self, **kwargs) -> Any:
- args = super().collect_args(**kwargs)
+ def collect_args(self, options) -> Any:
+ args = super().collect_args(options)
if "max_tokens" in args:
args["n_predict"] = args["max_tokens"]
del args["max_tokens"]
@@ -61,119 +59,30 @@ class LlamaCpp(LLM):
await process.wait()
- async def _stream_complete(
- self, prompt, with_history: List[ChatMessage] = None, **kwargs
- ) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.collect_args(**kwargs)
- args["stream"] = True
-
- messages = compile_chat_messages(
- self.model,
- with_history,
- self.context_length,
- args["n_predict"] if "n_predict" in args else 1024,
- prompt,
- functions=args.get("functions", None),
- system_message=self.system_message,
- )
-
- prompt = self.convert_to_chat(messages)
- self.write_log(f"Prompt: \n\n{prompt}")
- completion = ""
- async with aiohttp.ClientSession(
- connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
- ) as client_session:
- async with client_session.post(
- f"{self.server_url}/completion",
- json={
- "prompt": prompt,
- **args,
- },
- headers={"Content-Type": "application/json"},
- ) as resp:
- async for line in resp.content.iter_any():
- if line:
- chunk = line.decode("utf-8")
- yield chunk
- completion += chunk
-
- self.write_log(f"Completion: \n\n{completion}")
-
- async def _stream_chat(
- self, messages: List[ChatMessage] = None, **kwargs
- ) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.collect_args(**kwargs)
- messages = compile_chat_messages(
- self.model,
- messages,
- self.context_length,
- args["n_predict"] if "n_predict" in args else 1024,
- None,
- functions=args.get("functions", None),
- system_message=self.system_message,
- )
- args["stream"] = True
-
- prompt = self.template_messages(messages)
+ async def _stream_complete(self, prompt, options):
+ args = self.collect_args(options)
headers = {"Content-Type": "application/json"}
async def server_generator():
async with aiohttp.ClientSession(
- connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
+ connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl),
+ timeout=aiohttp.ClientTimeout(total=self.timeout),
) as client_session:
async with client_session.post(
f"{self.server_url}/completion",
- json={"prompt": prompt, **args},
+ json={"prompt": prompt, "stream": True, **args},
headers=headers,
) as resp:
async for line in resp.content:
content = line.decode("utf-8")
if content.strip() == "":
continue
- yield {
- "content": json.loads(content[6:])["content"],
- "role": "assistant",
- }
+ yield json.loads(content[6:])["content"]
async def command_generator():
async for line in self.stream_from_main(prompt):
- yield {"content": line, "role": "assistant"}
+ yield line
generator = command_generator if self.use_command else server_generator
-
- # Because quite often the first attempt fails, and it works thereafter
- self.write_log(f"Prompt: \n\n{prompt}")
- completion = ""
async for chunk in generator():
yield chunk
- if "content" in chunk:
- completion += chunk["content"]
-
- self.write_log(f"Completion: \n\n{completion}")
-
- async def _complete(
- self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
- ) -> Coroutine[Any, Any, str]:
- args = self.collect_args(**kwargs)
-
- self.write_log(f"Prompt: \n\n{prompt}")
-
- if self.use_command:
- completion = ""
- async for line in self.stream_from_main(prompt):
- completion += line
- self.write_log(f"Completion: \n\n{completion}")
- return completion
- else:
- async with aiohttp.ClientSession(
- connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
- ) as client_session:
- async with client_session.post(
- f"{self.server_url}/completion",
- json={"prompt": prompt, **args},
- headers={"Content-Type": "application/json"},
- ) as resp:
- json_resp = await resp.json()
- completion = json_resp["content"]
- self.write_log(f"Completion: \n\n{completion}")
- return completion
diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
index 99b7c47f..084c57fd 100644
--- a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
+++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
@@ -1,4 +1,4 @@
-from typing import Any, Coroutine, Dict, Generator, List, Optional, Union
+from typing import Callable, List, Optional
from ...core.main import ChatMessage
from . import LLM
@@ -7,7 +7,6 @@ from .proxy_server import ProxyServer
class MaybeProxyOpenAI(LLM):
- model: str
api_key: Optional[str] = None
llm: Optional[LLM] = None
@@ -16,36 +15,33 @@ class MaybeProxyOpenAI(LLM):
if self.llm is not None:
self.llm.system_message = self.system_message
- async def start(self, **kwargs):
+ async def start(
+ self, write_log: Callable[[str], None] = None, unique_id: Optional[str] = None
+ ):
+ await super().start(write_log=lambda *args, **kwargs: None, unique_id=unique_id)
if self.api_key is None or self.api_key.strip() == "":
self.llm = ProxyServer(model=self.model)
else:
self.llm = OpenAI(api_key=self.api_key, model=self.model)
- await self.llm.start(**kwargs)
+ await self.llm.start(write_log=write_log, unique_id=unique_id)
async def stop(self):
await self.llm.stop()
- async def _complete(
- self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
- ) -> Coroutine[Any, Any, str]:
+ async def _complete(self, prompt: str, options):
self.update_llm_properties()
- return await self.llm._complete(prompt, with_history=with_history, **kwargs)
+ return await self.llm._complete(prompt, options)
- async def _stream_complete(
- self, prompt, with_history: List[ChatMessage] = None, **kwargs
- ) -> Generator[Union[Any, List, Dict], None, None]:
+ async def _stream_complete(self, prompt, options):
self.update_llm_properties()
- resp = self.llm._stream_complete(prompt, with_history=with_history, **kwargs)
+ resp = self.llm._stream_complete(prompt, options)
async for item in resp:
yield item
- async def _stream_chat(
- self, messages: List[ChatMessage] = None, **kwargs
- ) -> Generator[Union[Any, List, Dict], None, None]:
+ async def _stream_chat(self, messages: List[ChatMessage], options):
self.update_llm_properties()
- resp = self.llm._stream_chat(messages=messages, **kwargs)
+ resp = self.llm._stream_chat(messages=messages, options=options)
async for item in resp:
yield item
diff --git a/continuedev/src/continuedev/libs/llm/ollama.py b/continuedev/src/continuedev/libs/llm/ollama.py
index ef8ed47b..d0da281a 100644
--- a/continuedev/src/continuedev/libs/llm/ollama.py
+++ b/continuedev/src/continuedev/libs/llm/ollama.py
@@ -1,14 +1,11 @@
import json
-import urllib.parse
-from textwrap import dedent
-from typing import Any, Coroutine, Dict, Generator, List, Union
+from typing import Callable
import aiohttp
-from ...core.main import ChatMessage
from ..llm import LLM
-from ..util.count_tokens import compile_chat_messages
from .prompts.chat import llama2_template_messages
+from .prompts.edit import simplified_edit_prompt
class Ollama(LLM):
@@ -17,18 +14,10 @@ class Ollama(LLM):
_client_session: aiohttp.ClientSession = None
+ template_messages: Callable = llama2_template_messages
+
prompt_templates = {
- "edit": dedent(
- """\
- [INST] Consider the following code:
- ```
- {{code_to_edit}}
- ```
- Edit the code to perfectly satisfy the following user request:
- {{user_input}}
- Output nothing except for the code. No code block, no English explanation, no start/end tags.
- [/INST]"""
- ),
+ "edit": simplified_edit_prompt,
}
class Config:
@@ -36,36 +25,23 @@ class Ollama(LLM):
async def start(self, **kwargs):
await super().start(**kwargs)
- self._client_session = aiohttp.ClientSession()
+ self._client_session = aiohttp.ClientSession(
+ timeout=aiohttp.ClientTimeout(total=self.timeout)
+ )
async def stop(self):
await self._client_session.close()
- async def _stream_complete(
- self, prompt, with_history: List[ChatMessage] = None, **kwargs
- ) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.collect_args(**kwargs)
- messages = compile_chat_messages(
- self.model,
- with_history,
- self.context_length,
- args["max_tokens"],
- prompt,
- functions=None,
- system_message=self.system_message,
- )
- prompt = llama2_template_messages(messages)
-
+ async def _stream_complete(self, prompt, options):
async with self._client_session.post(
f"{self.server_url}/api/generate",
json={
"template": prompt,
"model": self.model,
"system": self.system_message,
- "options": {"temperature": args["temperature"]},
+ "options": {"temperature": options.temperature},
},
) as resp:
- url_decode_buffer = ""
async for line in resp.content.iter_any():
if line:
json_chunk = line.decode("utf-8")
@@ -74,80 +50,4 @@ class Ollama(LLM):
if chunk.strip() != "":
j = json.loads(chunk)
if "response" in j:
- url_decode_buffer += j["response"]
-
- if (
- "&" in url_decode_buffer
- and url_decode_buffer.index("&")
- > len(url_decode_buffer) - 5
- ):
- continue
- yield urllib.parse.unquote(url_decode_buffer)
- url_decode_buffer = ""
-
- async def _stream_chat(
- self, messages: List[ChatMessage] = None, **kwargs
- ) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.collect_args(**kwargs)
- messages = compile_chat_messages(
- self.model,
- messages,
- self.context_length,
- args["max_tokens"],
- None,
- functions=None,
- system_message=self.system_message,
- )
- prompt = llama2_template_messages(messages)
-
- self.write_log(f"Prompt:\n{prompt}")
- completion = ""
- async with self._client_session.post(
- f"{self.server_url}/api/generate",
- json={
- "template": prompt,
- "model": self.model,
- "system": self.system_message,
- "options": {"temperature": args["temperature"]},
- },
- ) as resp:
- async for line in resp.content.iter_chunks():
- if line[1]:
- json_chunk = line[0].decode("utf-8")
- chunks = json_chunk.split("\n")
- for chunk in chunks:
- if chunk.strip() != "":
- j = json.loads(chunk)
- if "response" in j:
- yield {
- "role": "assistant",
- "content": j["response"],
- }
- completion += j["response"]
- self.write_log(f"Completion:\n{completion}")
-
- async def _complete(
- self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
- ) -> Coroutine[Any, Any, str]:
- completion = ""
- args = self.collect_args(**kwargs)
- async with self._client_session.post(
- f"{self.server_url}/api/generate",
- json={
- "template": prompt,
- "model": self.model,
- "system": self.system_message,
- "options": {"temperature": args["temperature"]},
- },
- ) as resp:
- async for line in resp.content.iter_any():
- if line:
- json_chunk = line.decode("utf-8")
- chunks = json_chunk.split("\n")
- for chunk in chunks:
- if chunk.strip() != "":
- j = json.loads(chunk)
- if "response" in j:
- completion += urllib.parse.unquote(j["response"])
-
- return completion
+ yield j["response"]
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index a017af22..2074ae04 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -1,25 +1,10 @@
-from typing import (
- Any,
- Callable,
- Coroutine,
- Dict,
- Generator,
- List,
- Literal,
- Optional,
- Union,
-)
+from typing import Callable, List, Literal, Optional
import certifi
import openai
from ...core.main import ChatMessage
from ..llm import LLM
-from ..util.count_tokens import (
- compile_chat_messages,
- format_chat_messages,
- prune_raw_prompt_from_top,
-)
CHAT_MODELS = {"gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"}
MAX_TOKENS_FOR_MODEL = {
@@ -82,118 +67,53 @@ class OpenAI(LLM):
openai.ca_bundle_path = self.ca_bundle_path or certifi.where()
- def collect_args(self, **kwargs):
- args = super().collect_args()
+ def collect_args(self, options):
+ args = super().collect_args(options)
if self.engine is not None:
args["engine"] = self.engine
+
+ if not args["model"].endswith("0613") and "functions" in args:
+ del args["functions"]
+
return args
- async def _stream_complete(
- self, prompt, with_history: List[ChatMessage] = None, **kwargs
- ) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.collect_args(**kwargs)
+ async def _stream_complete(self, prompt, options):
+ args = self.collect_args(options)
args["stream"] = True
if args["model"] in CHAT_MODELS:
- messages = compile_chat_messages(
- args["model"],
- with_history,
- self.context_length,
- args["max_tokens"],
- prompt,
- functions=None,
- system_message=self.system_message,
- )
- self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
- completion = ""
async for chunk in await openai.ChatCompletion.acreate(
- messages=messages,
+ messages=[{"role": "user", "content": prompt}],
**args,
):
if "content" in chunk.choices[0].delta:
yield chunk.choices[0].delta.content
- completion += chunk.choices[0].delta.content
- else:
- continue # :)
-
- self.write_log(f"Completion: \n\n{completion}")
else:
- self.write_log(f"Prompt:\n\n{prompt}")
- completion = ""
async for chunk in await openai.Completion.acreate(prompt=prompt, **args):
yield chunk.choices[0].text
- completion += chunk.choices[0].text
-
- self.write_log(f"Completion:\n\n{completion}")
- async def _stream_chat(
- self, messages: List[ChatMessage] = None, **kwargs
- ) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.collect_args(**kwargs)
- args["stream"] = True
-
- if not args["model"].endswith("0613") and "functions" in args:
- del args["functions"]
+ async def _stream_chat(self, messages: List[ChatMessage], options):
+ args = self.collect_args(options)
- messages = compile_chat_messages(
- args["model"],
- messages,
- self.context_length,
- args["max_tokens"],
- None,
- functions=args.get("functions", None),
- system_message=self.system_message,
- )
- self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
- completion = ""
async for chunk in await openai.ChatCompletion.acreate(
messages=messages,
+ stream=True,
**args,
):
if len(chunk.choices) == 0:
continue
yield chunk.choices[0].delta
- if "content" in chunk.choices[0].delta:
- completion += chunk.choices[0].delta.content
- self.write_log(f"Completion: \n\n{completion}")
- async def _complete(
- self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
- ) -> Coroutine[Any, Any, str]:
- args = self.collect_args(**kwargs)
+ async def _complete(self, prompt: str, options):
+ args = self.collect_args(options)
if args["model"] in CHAT_MODELS:
- messages = compile_chat_messages(
- args["model"],
- with_history,
- self.context_length,
- args["max_tokens"],
- prompt,
- functions=None,
- system_message=self.system_message,
- )
- self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
resp = await openai.ChatCompletion.acreate(
- messages=messages,
+ messages=[{"role": "user", "content": prompt}],
**args,
)
- completion = resp.choices[0].message.content
- self.write_log(f"Completion: \n\n{completion}")
+ return resp.choices[0].message.content
else:
- prompt = prune_raw_prompt_from_top(
- args["model"], self.context_length, prompt, args["max_tokens"]
- )
- self.write_log(f"Prompt:\n\n{prompt}")
- completion = (
- (
- await openai.Completion.acreate(
- prompt=prompt,
- **args,
- )
- )
- .choices[0]
- .text
+ return (
+ (await openai.Completion.acreate(prompt=prompt, **args)).choices[0].text
)
- self.write_log(f"Completion:\n\n{completion}")
-
- return completion
diff --git a/continuedev/src/continuedev/libs/llm/prompts/chat.py b/continuedev/src/continuedev/libs/llm/prompts/chat.py
index c7c208c0..1329a2ff 100644
--- a/continuedev/src/continuedev/libs/llm/prompts/chat.py
+++ b/continuedev/src/continuedev/libs/llm/prompts/chat.py
@@ -1,9 +1,27 @@
from textwrap import dedent
+from typing import Dict, List
-from ....core.main import ChatMessage
+from anthropic import AI_PROMPT, HUMAN_PROMPT
-def llama2_template_messages(msgs: ChatMessage) -> str:
+def anthropic_template_messages(messages: List[Dict[str, str]]) -> str:
+ prompt = ""
+
+ # Anthropic prompt must start with a Human turn
+ if (
+ len(messages) > 0
+ and messages[0]["role"] != "user"
+ and messages[0]["role"] != "system"
+ ):
+ prompt += f"{HUMAN_PROMPT} Hello."
+ for msg in messages:
+ prompt += f"{HUMAN_PROMPT if (msg['role'] == 'user' or msg['role'] == 'system') else AI_PROMPT} {msg['content']} "
+
+ prompt += AI_PROMPT
+ return prompt
+
+
+def llama2_template_messages(msgs: List[Dict[str, str]]) -> str:
if len(msgs) == 0:
return ""
@@ -38,20 +56,20 @@ def llama2_template_messages(msgs: ChatMessage) -> str:
if msgs[i]["role"] == "user":
prompt += f"[INST] {msgs[i]['content']} [/INST]"
else:
- prompt += msgs[i]["content"]
+ prompt += msgs[i]["content"] + " "
return prompt
-def code_llama_template_messages(msgs: ChatMessage) -> str:
+def code_llama_template_messages(msgs: List[Dict[str, str]]) -> str:
return f"[INST] {msgs[-1]['content']}\n[/INST]"
-def extra_space_template_messages(msgs: ChatMessage) -> str:
+def extra_space_template_messages(msgs: List[Dict[str, str]]) -> str:
return f" {msgs[-1]['content']}"
-def code_llama_python_template_messages(msgs: ChatMessage) -> str:
+def code_llama_python_template_messages(msgs: List[Dict[str, str]]) -> str:
return dedent(
f"""\
[INST]
diff --git a/continuedev/src/continuedev/libs/llm/prompts/edit.py b/continuedev/src/continuedev/libs/llm/prompts/edit.py
new file mode 100644
index 00000000..a234fa61
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/prompts/edit.py
@@ -0,0 +1,13 @@
+from textwrap import dedent
+
+simplified_edit_prompt = dedent(
+ """\
+ [INST] Consider the following code:
+ ```
+ {{code_to_edit}}
+ ```
+ Edit the code to perfectly satisfy the following user request:
+ {{user_input}}
+ Output nothing except for the code. No code block, no English explanation, no start/end tags.
+ [/INST]"""
+)
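The `{{code_to_edit}}` and `{{user_input}}` placeholders are filled in elsewhere in the codebase; the template renderer itself is not part of this diff, so the snippet below assumes a plain string substitution purely for illustration.

```python
from continuedev.libs.llm.prompts.edit import simplified_edit_prompt

# Illustrative only: assumes simple placeholder substitution.
prompt = simplified_edit_prompt.replace(
    "{{code_to_edit}}", "def add(a, b):\n    return a - b"
).replace("{{user_input}}", "Fix the bug so that add() returns the sum")
print(prompt)
```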
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index 3ac6371f..d62fafa7 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -1,14 +1,13 @@
import json
import ssl
import traceback
-from typing import Any, Coroutine, Dict, Generator, List, Union
+from typing import List
import aiohttp
import certifi
from ...core.main import ChatMessage
from ..llm import LLM
-from ..util.count_tokens import compile_chat_messages, format_chat_messages
from ..util.telemetry import posthog_logger
ca_bundle_path = certifi.where()
@@ -37,7 +36,8 @@ class ProxyServer(LLM):
):
await super().start(**kwargs)
self._client_session = aiohttp.ClientSession(
- connector=aiohttp.TCPConnector(ssl_context=ssl_context)
+ connector=aiohttp.TCPConnector(ssl_context=ssl_context),
+ timeout=aiohttp.ClientTimeout(total=self.timeout),
)
self.context_length = MAX_TOKENS_FOR_MODEL[self.model]
@@ -45,60 +45,32 @@ class ProxyServer(LLM):
await self._client_session.close()
def get_headers(self):
- # headers with unique id
return {"unique_id": self.unique_id}
- async def _complete(
- self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
- ) -> Coroutine[Any, Any, str]:
- args = self.collect_args(**kwargs)
-
- messages = compile_chat_messages(
- args["model"],
- with_history,
- self.context_length,
- args["max_tokens"],
- prompt,
- functions=None,
- system_message=self.system_message,
- )
- self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
+ async def _complete(self, prompt: str, options):
+ args = self.collect_args(options)
+
async with self._client_session.post(
f"{SERVER_URL}/complete",
- json={"messages": messages, **args},
+ json={"messages": [{"role": "user", "content": prompt}], **args},
headers=self.get_headers(),
) as resp:
+ resp_text = await resp.text()
if resp.status != 200:
- raise Exception(await resp.text())
+ raise Exception(resp_text)
- response_text = await resp.text()
- self.write_log(f"Completion: \n\n{response_text}")
- return response_text
-
- async def _stream_chat(
- self, messages: List[ChatMessage] = None, **kwargs
- ) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
- args = self.collect_args(**kwargs)
- messages = compile_chat_messages(
- args["model"],
- messages,
- self.context_length,
- args["max_tokens"],
- None,
- functions=args.get("functions", None),
- system_message=self.system_message,
- )
- self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
+ return resp_text
+ async def _stream_chat(self, messages: List[ChatMessage], options):
+ args = self.collect_args(options)
async with self._client_session.post(
f"{SERVER_URL}/stream_chat",
json={"messages": messages, **args},
headers=self.get_headers(),
) as resp:
- # This is streaming application/json instaed of text/event-stream
- completion = ""
if resp.status != 200:
raise Exception(await resp.text())
+
async for line in resp.content.iter_chunks():
if line[1]:
try:
@@ -109,8 +81,7 @@ class ProxyServer(LLM):
if chunk.strip() != "":
loaded_chunk = json.loads(chunk)
yield loaded_chunk
- if "content" in loaded_chunk:
- completion += loaded_chunk["content"]
+
except Exception as e:
posthog_logger.capture_event(
"proxy_server_parse_error",
@@ -124,37 +95,18 @@ class ProxyServer(LLM):
else:
break
- self.write_log(f"Completion: \n\n{completion}")
-
- async def _stream_complete(
- self, prompt, with_history: List[ChatMessage] = None, **kwargs
- ) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.collect_args(**kwargs)
- messages = compile_chat_messages(
- self.model,
- with_history,
- self.context_length,
- args["max_tokens"],
- prompt,
- functions=args.get("functions", None),
- system_message=self.system_message,
- )
- self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
+ async def _stream_complete(self, prompt, options):
+ args = self.collect_args(options)
async with self._client_session.post(
f"{SERVER_URL}/stream_complete",
- json={"messages": messages, **args},
+ json={"messages": [{"role": "user", "content": prompt}], **args},
headers=self.get_headers(),
) as resp:
- completion = ""
if resp.status != 200:
raise Exception(await resp.text())
+
async for line in resp.content.iter_any():
if line:
- try:
- decoded_line = line.decode("utf-8")
- yield decoded_line
- completion += decoded_line
- except:
- raise Exception(str(line))
- self.write_log(f"Completion: \n\n{completion}")
+ decoded_line = line.decode("utf-8")
+ yield decoded_line
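
With templating and logging moved into the base `LLM` wrappers, `ProxyServer` now only posts already-prepared messages. A hedged usage sketch, assuming the continuedev package is importable and the hosted proxy endpoint is reachable:

```python
# Sketch only: driving the slimmed-down ProxyServer through the new public
# wrappers (complete / stream_chat / stream_complete) on the LLM base class.
import asyncio

from continuedev.core.main import ChatMessage
from continuedev.libs.llm.proxy_server import ProxyServer


async def main():
    llm = ProxyServer(model="gpt-3.5-turbo")
    # write_log and unique_id mirror what the unit tests pass to LLM.start()
    await llm.start(write_log=print, unique_id="example-unique-id")
    try:
        async for chunk in llm.stream_chat(
            messages=[
                ChatMessage(role="user", content="Say hello in one word.", summary="hi")
            ]
        ):
            if "content" in chunk:
                print(chunk["content"], end="", flush=True)
    finally:
        await llm.stop()


asyncio.run(main())
```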
diff --git a/continuedev/src/continuedev/libs/llm/replicate.py b/continuedev/src/continuedev/libs/llm/replicate.py
index fb0d3f5c..1ed493c1 100644
--- a/continuedev/src/continuedev/libs/llm/replicate.py
+++ b/continuedev/src/continuedev/libs/llm/replicate.py
@@ -5,6 +5,7 @@ import replicate
from ...core.main import ChatMessage
from . import LLM
+from .prompts.edit import simplified_edit_prompt
class ReplicateLLM(LLM):
@@ -15,13 +16,15 @@ class ReplicateLLM(LLM):
_client: replicate.Client = None
+ prompt_templates = {
+ "edit": simplified_edit_prompt,
+ }
+
async def start(self, **kwargs):
await super().start(**kwargs)
self._client = replicate.Client(api_token=self.api_key)
- async def _complete(
- self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
- ):
+ async def _complete(self, prompt: str, options):
def helper():
output = self._client.run(
self.model, input={"message": prompt, "prompt": prompt}
@@ -38,17 +41,18 @@ class ReplicateLLM(LLM):
return completion
- async def _stream_complete(
- self, prompt, with_history: List[ChatMessage] = None, **kwargs
- ):
+ async def _stream_complete(self, prompt, options):
for item in self._client.run(
self.model, input={"message": prompt, "prompt": prompt}
):
yield item
- async def _stream_chat(self, messages: List[ChatMessage] = None, **kwargs):
+ async def _stream_chat(self, messages: List[ChatMessage], options):
for item in self._client.run(
self.model,
- input={"message": messages[-1].content, "prompt": messages[-1].content},
+ input={
+ "message": messages[-1]["content"],
+ "prompt": messages[-1]["content"],
+ },
):
yield {"content": item, "role": "assistant"}
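
Note that after the refactor `_stream_chat` receives plain message dicts rather than `ChatMessage` objects, hence the `messages[-1]["content"]` lookup above. A tiny illustration of the input shape handed to the Replicate client:

```python
# Illustration of the message shape the refactored _stream_chat receives:
# plain dicts, so the last turn is read with ["content"] rather than
# attribute access on ChatMessage objects.
messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Capital of Japan?"},
]
replicate_input = {
    "message": messages[-1]["content"],
    "prompt": messages[-1]["content"],
}
print(replicate_input)
# {'message': 'Capital of Japan?', 'prompt': 'Capital of Japan?'}
```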
diff --git a/continuedev/src/continuedev/libs/llm/text_gen_interface.py b/continuedev/src/continuedev/libs/llm/text_gen_interface.py
index 59627629..e37366c7 100644
--- a/continuedev/src/continuedev/libs/llm/text_gen_interface.py
+++ b/continuedev/src/continuedev/libs/llm/text_gen_interface.py
@@ -1,42 +1,37 @@
import json
-from typing import Any, Coroutine, Dict, Generator, List, Optional, Union
+from typing import Any, List, Optional
import websockets
from ...core.main import ChatMessage
-from ..util.count_tokens import compile_chat_messages, format_chat_messages
from . import LLM
+from .prompts.edit import simplified_edit_prompt
class TextGenUI(LLM):
- # this is model-specific
model: str = "text-gen-ui"
server_url: str = "http://localhost:5000"
streaming_url: str = "http://localhost:5005"
verify_ssl: Optional[bool] = None
+ prompt_templates = {
+ "edit": simplified_edit_prompt,
+ }
+
class Config:
arbitrary_types_allowed = True
- def _transform_args(self, args):
- args = {
- **args,
- "max_new_tokens": args.get("max_tokens", 1024),
- }
+ def collect_args(self, options) -> Any:
+ args = super().collect_args(options)
+ args = {**args, "max_new_tokens": options.max_tokens}
args.pop("max_tokens", None)
return args
- async def _stream_complete(
- self, prompt, with_history: List[ChatMessage] = None, **kwargs
- ) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.collect_args(**kwargs)
- args["stream"] = True
-
- self.write_log(f"Prompt: \n\n{prompt}")
- completion = ""
+ async def _stream_complete(self, prompt, options):
+ args = self.collect_args(options)
ws_url = f"{self.streaming_url.replace('http://', 'ws://').replace('https://', 'wss://')}"
- payload = json.dumps({"prompt": prompt, **self._transform_args(args)})
+ payload = json.dumps({"prompt": prompt, "stream": True, **args})
async with websockets.connect(
f"{ws_url}/api/v1/stream", ping_interval=None
) as websocket:
@@ -48,27 +43,12 @@ class TextGenUI(LLM):
match incoming_data["event"]:
case "text_stream":
- completion += incoming_data["text"]
yield incoming_data["text"]
case "stream_end":
break
- self.write_log(f"Completion: \n\n{completion}")
-
- async def _stream_chat(
- self, messages: List[ChatMessage] = None, **kwargs
- ) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.collect_args(**kwargs)
- messages = compile_chat_messages(
- self.model,
- messages,
- self.context_length,
- args["max_tokens"],
- None,
- functions=args.get("functions", None),
- system_message=self.system_message,
- )
- args["stream"] = True
+ async def _stream_chat(self, messages: List[ChatMessage], options):
+ args = self.collect_args(options)
async def generator():
ws_url = f"{self.streaming_url.replace('http://', 'ws://').replace('https://', 'wss://')}"
@@ -77,7 +57,8 @@ class TextGenUI(LLM):
{
"user_input": messages[-1]["content"],
"history": {"internal": [history], "visible": [history]},
- **self._transform_args(args),
+ "stream": True,
+ **args,
}
)
async with websockets.connect(
@@ -102,26 +83,5 @@ class TextGenUI(LLM):
case "stream_end":
break
- # Because quite often the first attempt fails, and it works thereafter
- self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
- completion = ""
async for chunk in generator():
yield chunk
- if "content" in chunk:
- completion += chunk["content"]
-
- self.write_log(f"Completion: \n\n{completion}")
-
- async def _complete(
- self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
- ) -> Coroutine[Any, Any, str]:
- generator = self._stream_chat(
- [ChatMessage(role="user", content=prompt, summary=prompt)], **kwargs
- )
-
- completion = ""
- async for chunk in generator:
- if "content" in chunk:
- completion += chunk["content"]
-
- return completion
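
The new `collect_args` override is the only TextGen-specific argument handling left: it renames the OpenAI-style `max_tokens` to text-generation-webui's `max_new_tokens`. A small sketch of that mapping, assuming the package is importable and that `CompletionOptions` accepts `max_tokens` as the new unit tests suggest:

```python
# Minimal sketch of the max_tokens -> max_new_tokens mapping performed by
# TextGenUI.collect_args. Assumptions: continuedev is importable and
# CompletionOptions exposes a max_tokens field (as used in the tests below).
from continuedev.libs.llm import CompletionOptions
from continuedev.libs.llm.text_gen_interface import TextGenUI

llm = TextGenUI()
args = llm.collect_args(CompletionOptions(model=llm.model, max_tokens=512))

assert args["max_new_tokens"] == 512
assert "max_tokens" not in args
print(args)
```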
diff --git a/continuedev/src/continuedev/libs/llm/together.py b/continuedev/src/continuedev/libs/llm/together.py
index d8c7334b..03c9cce4 100644
--- a/continuedev/src/continuedev/libs/llm/together.py
+++ b/continuedev/src/continuedev/libs/llm/together.py
@@ -1,16 +1,14 @@
import json
-from typing import Any, Coroutine, Dict, Generator, List, Optional, Union
+from typing import Callable, Optional
import aiohttp
-from ...core.main import ChatMessage
from ..llm import LLM
-from ..util.count_tokens import compile_chat_messages
from .prompts.chat import llama2_template_messages
+from .prompts.edit import simplified_edit_prompt
class TogetherLLM(LLM):
- # this is model-specific
api_key: str
"Together API key"
@@ -20,61 +18,32 @@ class TogetherLLM(LLM):
_client_session: aiohttp.ClientSession = None
+ template_messages: Callable = llama2_template_messages
+
+ prompt_templates = {
+ "edit": simplified_edit_prompt,
+ }
+
async def start(self, **kwargs):
await super().start(**kwargs)
self._client_session = aiohttp.ClientSession(
- connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
+ connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl),
+ timeout=aiohttp.ClientTimeout(total=self.timeout),
)
async def stop(self):
await self._client_session.close()
- async def _stream_complete(
- self, prompt, with_history: List[ChatMessage] = None, **kwargs
- ) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.collect_args(**kwargs)
- args["stream_tokens"] = True
-
- messages = compile_chat_messages(
- self.model,
- with_history,
- self.context_length,
- args["max_tokens"],
- prompt,
- functions=args.get("functions", None),
- system_message=self.system_message,
- )
-
- async with self._client_session.post(
- f"{self.base_url}/inference",
- json={"prompt": llama2_template_messages(messages), **args},
- headers={"Authorization": f"Bearer {self.api_key}"},
- ) as resp:
- async for line in resp.content.iter_any():
- if line:
- try:
- yield line.decode("utf-8")
- except:
- raise Exception(str(line))
-
- async def _stream_chat(
- self, messages: List[ChatMessage] = None, **kwargs
- ) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.collect_args(**kwargs)
- messages = compile_chat_messages(
- self.model,
- messages,
- self.context_length,
- args["max_tokens"],
- None,
- functions=args.get("functions", None),
- system_message=self.system_message,
- )
- args["stream_tokens"] = True
+ async def _stream_complete(self, prompt, options):
+ args = self.collect_args(options)
async with self._client_session.post(
f"{self.base_url}/inference",
- json={"prompt": llama2_template_messages(messages), **args},
+ json={
+ "prompt": prompt,
+ "stream_tokens": True,
+ **args,
+ },
headers={"Authorization": f"Bearer {self.api_key}"},
) as resp:
async for line in resp.content.iter_chunks():
@@ -92,36 +61,19 @@ class TogetherLLM(LLM):
chunk = chunk[6:]
json_chunk = json.loads(chunk)
if "choices" in json_chunk:
- yield {
- "role": "assistant",
- "content": json_chunk["choices"][0]["text"],
- }
+ yield json_chunk["choices"][0]["text"]
- async def _complete(
- self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
- ) -> Coroutine[Any, Any, str]:
- args = self.collect_args(**kwargs)
+ async def _complete(self, prompt: str, options):
+ args = self.collect_args(options)
- messages = compile_chat_messages(
- args["model"],
- with_history,
- self.context_length,
- args["max_tokens"],
- prompt,
- functions=None,
- system_message=self.system_message,
- )
async with self._client_session.post(
f"{self.base_url}/inference",
- json={"prompt": llama2_template_messages(messages), **args},
+ json={"prompt": prompt, **args},
headers={"Authorization": f"Bearer {self.api_key}"},
) as resp:
- try:
- text = await resp.text()
- j = json.loads(text)
- if "choices" not in j["output"]:
- raise Exception(text)
- if "output" in j:
- return j["output"]["choices"][0]["text"]
- except:
- raise Exception(await resp.text())
+ text = await resp.text()
+ j = json.loads(text)
+            if "output" not in j or "choices" not in j["output"]:
+                raise Exception(text)
+            return j["output"]["choices"][0]["text"]
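
Since the llama2 chat template and the simplified edit prompt are now class defaults, configuring `TogetherLLM` only needs credentials and, optionally, a model name. A hedged sketch follows; the model id is illustrative and not taken from this patch, and a valid Together API key plus network access are required.

```python
# Hedged sketch of using the slimmed-down TogetherLLM.
import asyncio
import os

from continuedev.libs.llm.together import TogetherLLM


async def main():
    llm = TogetherLLM(
        api_key=os.environ["TOGETHER_API_KEY"],
        model="togethercomputer/llama-2-13b-chat",  # hypothetical model id
    )
    await llm.start(write_log=print, unique_id="example-unique-id")
    try:
        # Assumption: the base-class wrappers apply template_messages
        # (llama2 format by default) before _complete posts the prompt.
        print(await llm.complete("Output a single word, the capital of Japan:"))
    finally:
        await llm.stop()


asyncio.run(main())
```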
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index 2663aa1c..aaa32907 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -12,13 +12,10 @@ aliases = {
"ggml": "gpt-3.5-turbo",
"claude-2": "gpt-3.5-turbo",
}
-DEFAULT_MAX_TOKENS = 2048
+DEFAULT_MAX_TOKENS = 1024
DEFAULT_ARGS = {
"max_tokens": DEFAULT_MAX_TOKENS,
"temperature": 0.5,
- "top_p": 1,
- "frequency_penalty": 0,
- "presence_penalty": 0,
}
@@ -144,13 +141,14 @@ def compile_chat_messages(
"""
The total number of tokens is system_message + sum(msgs) + functions + prompt after it is converted to a message
"""
+
msgs_copy = [msg.copy(deep=True) for msg in msgs] if msgs is not None else []
if prompt is not None:
prompt_msg = ChatMessage(role="user", content=prompt, summary=prompt)
msgs_copy += [prompt_msg]
- if system_message is not None:
+ if system_message is not None and system_message.strip() != "":
# NOTE: System message takes second precedence to user prompt, so it is placed just before
# but move back to start after processing
rendered_system_message = render_templated_string(system_message)
@@ -168,6 +166,11 @@ def compile_chat_messages(
for function in functions:
function_tokens += count_tokens(model_name, json.dumps(function))
+ if max_tokens + function_tokens + TOKEN_BUFFER_FOR_SAFETY >= context_length:
+ raise ValueError(
+ f"max_tokens ({max_tokens}) is too close to context_length ({context_length}), which doesn't leave room for chat history. This would cause incoherent responses. Try increasing the context_length parameter of the model in your config file."
+ )
+
msgs_copy = prune_chat_history(
model_name,
msgs_copy,
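
The new guard raises before any pruning happens when the reserved completion budget plus function schemas can never fit in the context window. A worked example with the new 1024-token default; the safety-buffer constant is not shown in this hunk, so the value below is an assumption:

```python
# Worked example of the guard added to compile_chat_messages. The real
# TOKEN_BUFFER_FOR_SAFETY constant is not shown in this hunk; 100 is an
# assumed value for illustration only.
TOKEN_BUFFER_FOR_SAFETY = 100

context_length = 2048
max_tokens = 1024       # the new DEFAULT_MAX_TOKENS
function_tokens = 950   # e.g. several large function schemas

if max_tokens + function_tokens + TOKEN_BUFFER_FOR_SAFETY >= context_length:
    # 1024 + 950 + 100 = 2074 >= 2048, so compile_chat_messages raises a
    # ValueError telling the user to increase context_length in their config.
    print("compile_chat_messages would raise ValueError here")
```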
diff --git a/continuedev/src/continuedev/libs/util/edit_config.py b/continuedev/src/continuedev/libs/util/edit_config.py
index 45a4a599..7c8ee76f 100644
--- a/continuedev/src/continuedev/libs/util/edit_config.py
+++ b/continuedev/src/continuedev/libs/util/edit_config.py
@@ -80,9 +80,13 @@ filtered_attrs = {
}
+def escape_string(string: str) -> str:
+ return string.replace('"', '\\"').replace("'", "\\'")
+
+
def display_val(v: Any):
if isinstance(v, str):
- return f'"{v}"'
+ return f'"{escape_string(v)}"'
return str(v)
@@ -103,6 +107,7 @@ def create_obj_node(class_name: str, args: Dict[str, str]) -> redbaron.RedBaron:
def create_string_node(string: str) -> redbaron.RedBaron:
+ string = escape_string(string)
if "\n" in string:
return redbaron.RedBaron(f'"""{string}"""')[0]
return redbaron.RedBaron(f'"{string}"')[0]
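
The escaping keeps quotes inside config values from breaking the Python source that is written back out. For example:

```python
# Copied from the patch above into a standalone demo of the escaping behaviour.
from typing import Any


def escape_string(string: str) -> str:
    return string.replace('"', '\\"').replace("'", "\\'")


def display_val(v: Any):
    if isinstance(v, str):
        return f'"{escape_string(v)}"'
    return str(v)


print(display_val('say "hello"'))  # prints "say \"hello\"", a valid Python literal
print(display_val(42))             # prints 42
```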
diff --git a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py
index fe049268..43a2b800 100644
--- a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py
+++ b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py
@@ -30,7 +30,7 @@ class SetupPipelineStep(Step):
sdk.context.set("api_description", self.api_description)
source_name = (
- await sdk.models.medium._complete(
+ await sdk.models.medium.complete(
f"Write a snake_case name for the data source described by {self.api_description}: "
)
).strip()
@@ -115,7 +115,7 @@ class ValidatePipelineStep(Step):
if "Traceback" in output or "SyntaxError" in output:
output = "Traceback" + output.split("Traceback")[-1]
file_content = await sdk.ide.readFile(os.path.join(workspace_dir, filename))
- suggestion = await sdk.models.medium._complete(
+ suggestion = await sdk.models.medium.complete(
dedent(
f"""\
```python
@@ -131,7 +131,7 @@ class ValidatePipelineStep(Step):
)
)
- api_documentation_url = await sdk.models.medium._complete(
+ api_documentation_url = await sdk.models.medium.complete(
dedent(
f"""\
The API I am trying to call is the '{sdk.context.get('api_description')}'. I tried calling it in the @resource function like this:
@@ -216,7 +216,7 @@ class RunQueryStep(Step):
)
if "Traceback" in output or "SyntaxError" in output:
- suggestion = await sdk.models.medium._complete(
+ suggestion = await sdk.models.medium.complete(
dedent(
f"""\
```python
diff --git a/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py
index 44065d22..d6769148 100644
--- a/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py
+++ b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py
@@ -92,7 +92,7 @@ class LoadDataStep(Step):
docs = f.read()
output = "Traceback" + output.split("Traceback")[-1]
- suggestion = await sdk.models.default._complete(
+ suggestion = await sdk.models.default.complete(
dedent(
f"""\
When trying to load data into BigQuery, the following error occurred:
diff --git a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py
index 4727c994..e2712746 100644
--- a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py
+++ b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py
@@ -45,7 +45,7 @@ class WritePytestsRecipe(Step):
Here is a complete set of pytest unit tests:"""
)
- tests = await sdk.models.medium._complete(prompt)
+ tests = await sdk.models.medium.complete(prompt)
await sdk.apply_filesystem_edit(AddFile(filepath=path, content=tests))
diff --git a/continuedev/src/continuedev/plugins/steps/README.md b/continuedev/src/continuedev/plugins/steps/README.md
index a248f19c..3f2f804c 100644
--- a/continuedev/src/continuedev/plugins/steps/README.md
+++ b/continuedev/src/continuedev/plugins/steps/README.md
@@ -33,7 +33,7 @@ If you'd like to override the default description of your step, which is just th
- Return a static string
- Store state in a class attribute (prepend with a double underscore, which signifies (through Pydantic) that this is not a parameter for the Step, just internal state) during the run method, and then grab this in the describe method.
-- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.medium._complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`.
+- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.medium.complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`.
Here's an example:
diff --git a/continuedev/src/continuedev/plugins/steps/chat.py b/continuedev/src/continuedev/plugins/steps/chat.py
index d580f886..15740057 100644
--- a/continuedev/src/continuedev/plugins/steps/chat.py
+++ b/continuedev/src/continuedev/plugins/steps/chat.py
@@ -83,7 +83,7 @@ class SimpleChatStep(Step):
messages = self.messages or await sdk.get_chat_context()
- generator = sdk.models.chat._stream_chat(
+ generator = sdk.models.chat.stream_chat(
messages, temperature=sdk.config.temperature
)
@@ -118,7 +118,7 @@ class SimpleChatStep(Step):
await sdk.update_ui()
self.name = add_ellipsis(
remove_quotes_and_escapes(
- await sdk.models.medium._complete(
+ await sdk.models.medium.complete(
f'"{self.description}"\n\nPlease write a short title summarizing the message quoted above. Use no more than 10 words:',
max_tokens=20,
)
@@ -260,7 +260,7 @@ class ChatWithFunctions(Step):
gpt350613 = OpenAI(model="gpt-3.5-turbo-0613")
await sdk.start_model(gpt350613)
- async for msg_chunk in gpt350613._stream_chat(
+ async for msg_chunk in gpt350613.stream_chat(
await sdk.get_chat_context(), functions=functions
):
if sdk.current_step_was_deleted():
diff --git a/continuedev/src/continuedev/plugins/steps/chroma.py b/continuedev/src/continuedev/plugins/steps/chroma.py
index 9ee2a48d..25633942 100644
--- a/continuedev/src/continuedev/plugins/steps/chroma.py
+++ b/continuedev/src/continuedev/plugins/steps/chroma.py
@@ -58,7 +58,7 @@ class AnswerQuestionChroma(Step):
Here is the answer:"""
)
- answer = await sdk.models.medium._complete(prompt)
+ answer = await sdk.models.medium.complete(prompt)
# Make paths relative to the workspace directory
answer = answer.replace(await sdk.ide.getWorkspaceDirectory(), "")
diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py
index df29a01e..5325a918 100644
--- a/continuedev/src/continuedev/plugins/steps/core/core.py
+++ b/continuedev/src/continuedev/plugins/steps/core/core.py
@@ -107,7 +107,7 @@ class ShellCommandsStep(Step):
return f"Error when running shell commands:\n```\n{self._err_text}\n```"
cmds_str = "\n".join(self.cmds)
- return await models.medium._complete(
+ return await models.medium.complete(
f"{cmds_str}\n\nSummarize what was done in these shell commands, using markdown bullet points:"
)
@@ -121,7 +121,7 @@ class ShellCommandsStep(Step):
and output is not None
and output_contains_error(output)
):
- suggestion = await sdk.models.medium._complete(
+ suggestion = await sdk.models.medium.complete(
dedent(
f"""\
While running the command `{cmd}`, the following error occurred:
@@ -220,7 +220,7 @@ class DefaultModelEditCodeStep(Step):
self._new_contents.splitlines(),
)
)
- description = await models.medium._complete(
+ description = await models.medium.complete(
dedent(
f"""\
Diff summary: "{self.user_input}"
@@ -232,7 +232,7 @@ class DefaultModelEditCodeStep(Step):
{self.summary_prompt}"""
)
)
- name = await models.medium._complete(
+ name = await models.medium.complete(
f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. This is the title:"
)
self.name = remove_quotes_and_escapes(name)
@@ -663,7 +663,7 @@ Please output the code to be inserted at the cursor in order to fulfill the user
else:
messages = rendered
- generator = model_to_use._stream_chat(
+ generator = model_to_use.stream_chat(
messages, temperature=sdk.config.temperature, max_tokens=max_tokens
)
@@ -874,7 +874,7 @@ class ManualEditStep(ReversibleStep):
return "Manual edit step"
# TODO - only handling FileEdit here, but need all other types of FileSystemEdits
# Also requires the merge_file_edit function
- # return llm._complete(dedent(f"""This code was replaced:
+ # return llm.complete(dedent(f"""This code was replaced:
# {self.edit_diff.backward.replacement}
diff --git a/continuedev/src/continuedev/plugins/steps/help.py b/continuedev/src/continuedev/plugins/steps/help.py
index c73d7eef..148dddb8 100644
--- a/continuedev/src/continuedev/plugins/steps/help.py
+++ b/continuedev/src/continuedev/plugins/steps/help.py
@@ -59,7 +59,7 @@ class HelpStep(Step):
ChatMessage(role="user", content=prompt, summary="Help")
)
messages = await sdk.get_chat_context()
- generator = sdk.models.default._stream_chat(messages)
+ generator = sdk.models.default.stream_chat(messages)
async for chunk in generator:
if "content" in chunk:
self.description += chunk["content"]
diff --git a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py
index 001876d0..721f1306 100644
--- a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py
+++ b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py
@@ -26,7 +26,7 @@ class NLMultiselectStep(Step):
if first_try is not None:
return first_try
- gpt_parsed = await sdk.models.default._complete(
+ gpt_parsed = await sdk.models.default.complete(
f"These are the available options are: [{', '.join(self.options)}]. The user requested {user_response}. This is the exact string from the options array that they selected:"
)
return extract_option(gpt_parsed) or self.options[0]
diff --git a/continuedev/src/continuedev/plugins/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py
index cecd3a2a..ca15aaab 100644
--- a/continuedev/src/continuedev/plugins/steps/main.py
+++ b/continuedev/src/continuedev/plugins/steps/main.py
@@ -105,7 +105,7 @@ class FasterEditHighlightedCodeStep(Step):
for rif in range_in_files:
rif_dict[rif.filepath] = rif.contents
- completion = await sdk.models.medium._complete(prompt)
+ completion = await sdk.models.medium.complete(prompt)
# Temporarily doing this to generate description.
self._prompt = prompt
@@ -180,7 +180,7 @@ class StarCoderEditHighlightedCodeStep(Step):
_prompt_and_completion: str = ""
async def describe(self, models: Models) -> Coroutine[str, None, None]:
- return await models.medium._complete(
+ return await models.medium.complete(
f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:"
)
@@ -213,7 +213,7 @@ class StarCoderEditHighlightedCodeStep(Step):
segs = full_file_contents.split(rif.contents)
prompt = f"<file_prefix>{segs[0]}<file_suffix>{segs[1]}" + prompt
- completion = str(await sdk.models.starcoder._complete(prompt))
+ completion = str(await sdk.models.starcoder.complete(prompt))
eot_token = "<|endoftext|>"
completion = completion.removesuffix(eot_token)
diff --git a/continuedev/src/continuedev/plugins/steps/react.py b/continuedev/src/continuedev/plugins/steps/react.py
index 2ed2d3d7..a2612731 100644
--- a/continuedev/src/continuedev/plugins/steps/react.py
+++ b/continuedev/src/continuedev/plugins/steps/react.py
@@ -29,7 +29,7 @@ class NLDecisionStep(Step):
Select the step which should be taken next to satisfy the user input. Say only the name of the selected step. You must choose one:"""
)
- resp = (await sdk.models.medium._complete(prompt)).lower()
+ resp = (await sdk.models.medium.complete(prompt)).lower()
step_to_run = None
for step in self.steps:
diff --git a/continuedev/src/continuedev/plugins/steps/search_directory.py b/continuedev/src/continuedev/plugins/steps/search_directory.py
index 9317bfe1..04fb98b7 100644
--- a/continuedev/src/continuedev/plugins/steps/search_directory.py
+++ b/continuedev/src/continuedev/plugins/steps/search_directory.py
@@ -46,7 +46,7 @@ class WriteRegexPatternStep(Step):
async def run(self, sdk: ContinueSDK):
# Ask the user for a regex pattern
- pattern = await sdk.models.medium._complete(
+ pattern = await sdk.models.medium.complete(
dedent(
f"""\
This is the user request:
diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py
index aea20ef9..f0c33929 100644
--- a/continuedev/src/continuedev/server/session_manager.py
+++ b/continuedev/src/continuedev/server/session_manager.py
@@ -10,7 +10,6 @@ from fastapi.websockets import WebSocketState
from ..core.autopilot import Autopilot
from ..core.main import FullState
from ..libs.util.create_async_task import create_async_task
-from ..libs.util.errors import SessionNotFound
from ..libs.util.logging import logger
from ..libs.util.paths import (
getSessionFilePath,
@@ -124,7 +123,12 @@ class SessionManager:
# Read and update the sessions list
with open(getSessionsListFilePath(), "r") as f:
- sessions_list = json.load(f)
+ try:
+ sessions_list = json.load(f)
+ except json.JSONDecodeError:
+ raise Exception(
+ f"It looks like there is a JSON formatting error in your sessions.json file ({getSessionsListFilePath()}). Please fix this before creating a new session."
+ )
session_ids = [s["session_id"] for s in sessions_list]
if session_id not in session_ids:
diff --git a/continuedev/src/continuedev/tests/llm_test.py b/continuedev/src/continuedev/tests/llm_test.py
new file mode 100644
index 00000000..91ddd33f
--- /dev/null
+++ b/continuedev/src/continuedev/tests/llm_test.py
@@ -0,0 +1,192 @@
+import asyncio
+import os
+from functools import wraps
+
+import pytest
+from dotenv import load_dotenv
+
+from continuedev.core.main import ChatMessage
+from continuedev.libs.llm import LLM, CompletionOptions
+from continuedev.libs.llm.anthropic import AnthropicLLM
+from continuedev.libs.llm.ggml import GGML
+from continuedev.libs.llm.openai import OpenAI
+from continuedev.libs.llm.together import TogetherLLM
+from continuedev.libs.util.count_tokens import DEFAULT_ARGS
+from continuedev.tests.util.openai_mock import start_openai
+
+load_dotenv()
+
+TEST_PROMPT = "Output a single word, that being the capital of Japan:"
+SPEND_MONEY = True
+
+
+def start_model(model):
+ def write_log(msg: str):
+ pass
+
+ asyncio.run(model.start(write_log=write_log, unique_id="test_unique_id"))
+
+
+def async_test(func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ return asyncio.run(func(*args, **kwargs))
+
+ return wrapper
+
+
+class TestBaseLLM:
+ model = "gpt-3.5-turbo"
+ context_length = 4096
+ system_message = "test_system_message"
+
+ def setup_class(cls):
+ cls.llm = LLM(
+ model=cls.model,
+ context_length=cls.context_length,
+ system_message=cls.system_message,
+ )
+
+ start_model(cls.llm)
+
+ def test_llm_is_instance(self):
+ assert isinstance(self.llm, LLM)
+
+ def test_llm_collect_args(self):
+ options = CompletionOptions(model=self.model)
+ assert self.llm.collect_args(options) == {
+ **DEFAULT_ARGS,
+ "model": self.model,
+ }
+
+ @pytest.mark.skipif(SPEND_MONEY is False, reason="Costs money")
+ @async_test
+ async def test_completion(self):
+ if self.llm.__class__.__name__ == "LLM":
+ pytest.skip("Skipping abstract LLM")
+
+ resp = await self.llm.complete(TEST_PROMPT, temperature=0.0)
+ assert isinstance(resp, str)
+ assert resp.strip().lower() == "tokyo"
+
+ @pytest.mark.skipif(SPEND_MONEY is False, reason="Costs money")
+ @async_test
+ async def test_stream_chat(self):
+ if self.llm.__class__.__name__ == "LLM":
+ pytest.skip("Skipping abstract LLM")
+
+ completion = ""
+ role = None
+ async for chunk in self.llm.stream_chat(
+ messages=[
+ ChatMessage(role="user", content=TEST_PROMPT, summary=TEST_PROMPT)
+ ],
+ temperature=0.0,
+ ):
+ assert isinstance(chunk, dict)
+ if "content" in chunk:
+ completion += chunk["content"]
+ if "role" in chunk:
+ role = chunk["role"]
+
+ assert role == "assistant"
+ assert completion.strip().lower() == "tokyo"
+
+ @pytest.mark.skipif(SPEND_MONEY is False, reason="Costs money")
+ @async_test
+ async def test_stream_complete(self):
+ if self.llm.__class__.__name__ == "LLM":
+ pytest.skip("Skipping abstract LLM")
+
+ completion = ""
+ async for chunk in self.llm.stream_complete(TEST_PROMPT, temperature=0.0):
+ assert isinstance(chunk, str)
+ completion += chunk
+
+ assert completion.strip().lower() == "tokyo"
+
+
+class TestOpenAI(TestBaseLLM):
+ def setup_class(cls):
+ super().setup_class(cls)
+ cls.llm = OpenAI(
+ model=cls.model,
+ context_length=cls.context_length,
+ system_message=cls.system_message,
+ api_key=os.environ["OPENAI_API_KEY"],
+ # api_base=f"http://localhost:{port}",
+ )
+ start_model(cls.llm)
+ # cls.server = start_openai(port=port)
+
+ # def teardown_class(cls):
+ # cls.server.terminate()
+
+ @pytest.mark.asyncio
+ @pytest.mark.skipif(SPEND_MONEY is False, reason="Costs money")
+ async def test_completion(self):
+        resp = await self.llm.complete(TEST_PROMPT)
+ assert isinstance(resp, str)
+ assert resp.strip().lower() == "tokyo"
+
+
+class TestGGML(TestBaseLLM):
+ def setup_class(cls):
+ super().setup_class(cls)
+ port = 8000
+ cls.llm = GGML(
+ model=cls.model,
+ context_length=cls.context_length,
+ system_message=cls.system_message,
+ api_base=f"http://localhost:{port}",
+ )
+ start_model(cls.llm)
+ cls.server = start_openai(port=port)
+
+ def teardown_class(cls):
+ cls.server.terminate()
+
+ @pytest.mark.asyncio
+ async def test_stream_chat(self):
+ pytest.skip(reason="GGML is not working")
+
+ @pytest.mark.asyncio
+ async def test_stream_complete(self):
+ pytest.skip(reason="GGML is not working")
+
+ @pytest.mark.asyncio
+ async def test_completion(self):
+ pytest.skip(reason="GGML is not working")
+
+
+@pytest.mark.skipif(True, reason="Together is not working")
+class TestTogetherLLM(TestBaseLLM):
+ def setup_class(cls):
+ super().setup_class(cls)
+ cls.llm = TogetherLLM(
+ api_key=os.environ["TOGETHER_API_KEY"],
+ )
+ start_model(cls.llm)
+
+
+class TestAnthropicLLM(TestBaseLLM):
+ def setup_class(cls):
+ super().setup_class(cls)
+ cls.llm = AnthropicLLM(api_key=os.environ["ANTHROPIC_API_KEY"])
+ start_model(cls.llm)
+
+ def test_llm_collect_args(self):
+ options = CompletionOptions(model=self.model)
+ assert self.llm.collect_args(options) == {
+ "max_tokens_to_sample": DEFAULT_ARGS["max_tokens"],
+ "temperature": DEFAULT_ARGS["temperature"],
+ "model": self.model,
+ }
+
+
+if __name__ == "__main__":
+ import pytest
+
+ pytest.main()
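
Most of these tests hit live APIs when `SPEND_MONEY` is true, so it is convenient to select one provider class at a time. A sketch mirroring the module's own `__main__` block:

```python
# Run only the OpenAI subclass of the suite; requires OPENAI_API_KEY in the
# environment (or a .env file) because those tests call the live API.
import pytest

pytest.main(["continuedev/src/continuedev/tests/llm_test.py", "-k", "TestOpenAI", "-v"])
```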
diff --git a/continuedev/src/continuedev/tests/util/openai_mock.py b/continuedev/src/continuedev/tests/util/openai_mock.py
new file mode 100644
index 00000000..763c5647
--- /dev/null
+++ b/continuedev/src/continuedev/tests/util/openai_mock.py
@@ -0,0 +1,139 @@
+import asyncio
+import os
+import random
+import subprocess
+from typing import Dict, List, Optional
+
+from fastapi import FastAPI
+from fastapi.responses import StreamingResponse
+from pydantic import BaseModel
+
+openai = FastAPI()
+
+
+class CompletionBody(BaseModel):
+ prompt: str
+ max_tokens: Optional[int] = 60
+ stream: Optional[bool] = False
+
+ class Config:
+ extra = "allow"
+
+
+@openai.post("/completions")
+@openai.post("/v1/completions")
+async def mock_completion(item: CompletionBody):
+ prompt = item.prompt
+
+ text = "This is a fake completion."
+
+ if item.stream:
+
+ async def stream_text():
+ for i in range(len(text)):
+ word = random.choice(prompt.split())
+ yield {
+ "choices": [
+ {
+ "delta": {"role": "assistant", "content": word},
+ "finish_reason": None,
+ "index": 0,
+ }
+ ],
+ "created": 1677825464,
+ "id": "chatcmpl-6ptKyqKOGXZT6iQnqiXAH8adNLUzD",
+ "model": "gpt-3.5-turbo-0301",
+ "object": "chat.completion.chunk",
+ }
+ await asyncio.sleep(0.1)
+
+ return StreamingResponse(stream_text(), media_type="text/plain")
+
+ return {
+ "id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
+ "object": "text_completion",
+ "created": 1589478378,
+ "model": "gpt-3.5-turbo",
+ "choices": [
+ {
+ "text": text,
+ "index": 0,
+ "logprobs": None,
+ "finish_reason": "length",
+ }
+ ],
+ "usage": {"prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12},
+ }
+
+
+class ChatBody(BaseModel):
+ messages: List[Dict[str, str]]
+ max_tokens: Optional[int] = None
+ stream: Optional[bool] = False
+
+ class Config:
+ extra = "allow"
+
+
+@openai.post("/v1/chat/completions")
+async def mock_chat_completion(item: ChatBody):
+ text = "This is a fake completion."
+
+ if item.stream:
+
+ async def stream_text():
+ for i in range(len(text)):
+ word = text[i]
+ yield {
+ "choices": [
+ {
+ "delta": {"role": "assistant", "content": word},
+ "finish_reason": None,
+ "index": 0,
+ }
+ ],
+ "created": 1677825464,
+ "id": "chatcmpl-6ptKyqKOGXZT6iQnqiXAH8adNLUzD",
+ "model": "gpt-3.5-turbo-0301",
+ "object": "chat.completion.chunk",
+ }
+ await asyncio.sleep(0.1)
+
+ return StreamingResponse(stream_text(), media_type="text/plain")
+
+ return {
+ "id": "chatcmpl-123",
+ "object": "chat.completion",
+ "created": 1677652288,
+ "model": "gpt-3.5-turbo-0613",
+ "choices": [
+ {
+ "index": 0,
+ "message": {
+ "role": "assistant",
+ "content": text,
+ },
+ "finish_reason": "stop",
+ }
+ ],
+ "usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21},
+ }
+
+
+def start_openai(port: int = 8000):
+ server = subprocess.Popen(
+ [
+ "uvicorn",
+ "openai_mock:openai",
+ "--host",
+ "127.0.0.1",
+ "--port",
+ str(port),
+ ],
+ cwd=os.path.dirname(__file__),
+ )
+ return server
+
+
+if __name__ == "__main__":
+ start_openai()
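
The mock can also be exercised by hand. A hedged sketch using aiohttp (already a project dependency); it waits briefly for the uvicorn subprocess to bind before posting:

```python
# Sketch only: poke the mock OpenAI server's chat endpoint directly.
import asyncio
import time

import aiohttp

from continuedev.tests.util.openai_mock import start_openai

server = start_openai(port=8000)
time.sleep(2)  # crude wait for the uvicorn subprocess to start listening


async def main():
    async with aiohttp.ClientSession() as session:
        async with session.post(
            "http://127.0.0.1:8000/v1/chat/completions",
            json={"messages": [{"role": "user", "content": "hi"}]},
        ) as resp:
            data = await resp.json()
            # The non-streaming mock always answers "This is a fake completion."
            print(data["choices"][0]["message"]["content"])


try:
    asyncio.run(main())
finally:
    server.terminate()
```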
diff --git a/docs/docs/concepts/sdk.md b/docs/docs/concepts/sdk.md
index ac2bc8bf..21190aa8 100644
--- a/docs/docs/concepts/sdk.md
+++ b/docs/docs/concepts/sdk.md
@@ -23,7 +23,7 @@ The **Continue SDK** gives you all the tools you need to automate software devel
### `sdk.models`
-`sdk.models` is an instance of the `Models` class, containing many of the most commonly used LLMs or other foundation models. You can access a model (starcoder for example) like `starcoder = sdk.models.starcoder`. Right now, all of the models are `LLM`s, meaning that they offer the `complete` method, used like `bubble_sort_code = await starcoder._complete("# Write a bubble sort function below, in Python:\n")`.
+`sdk.models` is an instance of the `Models` class, containing many of the most commonly used LLMs or other foundation models. You can access a model (starcoder for example) like `starcoder = sdk.models.starcoder`. Right now, all of the models are `LLM`s, meaning that they offer the `complete` method, used like `bubble_sort_code = await starcoder.complete("# Write a bubble sort function below, in Python:\n")`.
### `sdk.history`
diff --git a/docs/docs/customization.md b/docs/docs/customization.md
index 09f7ed46..2d1d8ba4 100644
--- a/docs/docs/customization.md
+++ b/docs/docs/customization.md
@@ -292,7 +292,7 @@ class CommitMessageStep(Step):
# Ask the LLM to write a commit message,
# and set it as the description of this step
- self.description = await sdk.models.default._complete(
+ self.description = await sdk.models.default.complete(
f"{diff}\n\nWrite a short, specific (less than 50 chars) commit message about the above changes:")
config=ContinueConfig(
diff --git a/docs/docs/walkthroughs/create-a-recipe.md b/docs/docs/walkthroughs/create-a-recipe.md
index cc80be0e..2cb28f77 100644
--- a/docs/docs/walkthroughs/create-a-recipe.md
+++ b/docs/docs/walkthroughs/create-a-recipe.md
@@ -31,7 +31,7 @@ If you'd like to override the default description of your steps, which is just t
- Return a static string
- Store state in a class attribute (prepend with a double underscore, which signifies (through Pydantic) that this is not a parameter for the Step, just internal state) during the run method, and then grab this in the describe method.
-- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.medium._complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`.
+- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.medium.complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`.
## 2. Compose steps together into a complete recipe
diff --git a/extension/react-app/src/components/ModelSelect.tsx b/extension/react-app/src/components/ModelSelect.tsx
index 1cbf3f0e..83f005c7 100644
--- a/extension/react-app/src/components/ModelSelect.tsx
+++ b/extension/react-app/src/components/ModelSelect.tsx
@@ -13,21 +13,13 @@ import { useSelector } from "react-redux";
const MODEL_INFO: { title: string; class: string; args: any }[] = [
{
title: "gpt-4",
- class: "MaybeProxyOpenAI",
+ class: "OpenAI",
args: {
model: "gpt-4",
api_key: "",
},
},
{
- title: "gpt-3.5-turbo",
- class: "MaybeProxyOpenAI",
- args: {
- model: "gpt-3.5-turbo",
- api_key: "",
- },
- },
- {
title: "claude-2",
class: "AnthropicLLM",
args: {