author | Nate Sesti <33237525+sestinj@users.noreply.github.com> | 2023-09-03 21:58:46 -0700
committer | GitHub <noreply@github.com> | 2023-09-03 21:58:46 -0700
commit | e645a89192b28cc16a1303bfa5551834c64ecb77 (patch)
tree | 6da1d0b5f59cef5c9fd9a615119742550fe1ad2c
parent | e49c6f55ae0c00bc660bbe885ea44f3a2fb1dc35 (diff)
refactor: :construction: Initial, not tested, refactor of LLM (#448)
* refactor: :construction: Initial, not tested, refactor of LLM
* refactor: :construction: replace usages of _complete with complete
* fix: :bug: fixes after refactor
* refactor: :recycle: template raw completions in chat format
* test: :white_check_mark: simplified edit prompt and UNIT TESTS!
* ci: :green_heart: unit tests in ci
* fix: :bug: fixes for unit tests in ci
* fix: :bug: start uvicorn in tests without poetry
* fix: :closed_lock_with_key: add secrets to main.yaml
* feat: :adhesive_bandage: timeout for all LLM classes
* ci: :green_heart: prepare main.yaml for main branch
44 files changed, 999 insertions, 832 deletions
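The core of the refactor is the rewritten `LLM` base class in `continuedev/src/continuedev/libs/llm/__init__.py`, shown in the diff below: the public `complete`, `stream_complete`, and `stream_chat` methods now build a `CompletionOptions` object, prune and template the prompt, and handle logging, while each backend only implements the private `_stream_complete` primitive (and, optionally, `_complete` or `_stream_chat`). The following sketch mirrors that template-method split in a self-contained form; `MiniCompletionOptions`, `BaseLLM`, and `EchoLLM` are invented for illustration and carry only a subset of the real fields.

```python
# Self-contained sketch of the pattern introduced by this refactor: the base class
# owns the public API and derives a blocking completion from the streaming primitive,
# so a backend only needs to implement _stream_complete(). Names here are illustrative.
import asyncio
from dataclasses import dataclass
from typing import AsyncGenerator, List, Optional


@dataclass
class MiniCompletionOptions:
    # Reduced stand-in for CompletionOptions (model, temperature, stop, max_tokens, ...).
    model: str
    temperature: Optional[float] = None
    stop: Optional[List[str]] = None
    max_tokens: int = 1024


class BaseLLM:
    model: str = "base"

    def _stream_complete(
        self, prompt: str, options: MiniCompletionOptions
    ) -> AsyncGenerator[str, None]:
        # Backends override this with an async generator.
        raise NotImplementedError

    async def complete(self, prompt: str, **kwargs) -> str:
        # Public entry point: build options once, then fold the stream into a string,
        # the same way the refactored base class derives _complete from _stream_complete.
        options = MiniCompletionOptions(model=kwargs.pop("model", self.model), **kwargs)
        completion = ""
        async for chunk in self._stream_complete(prompt, options):
            completion += chunk
        return completion


class EchoLLM(BaseLLM):
    """Toy backend that streams the prompt back word by word."""

    model = "echo"

    async def _stream_complete(self, prompt, options):
        for word in prompt.split()[: options.max_tokens]:
            yield word + " "


async def main():
    llm = EchoLLM()
    print(await llm.complete("Give a short title for this chat session", max_tokens=4))


asyncio.run(main())
```

With this split, the per-backend classes touched in the diff (GGML, LlamaCpp, Ollama, TogetherLLM, and so on) shrink to little more than their HTTP calls, since prompt compilation and logging live in one place.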
diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml
index 8c348024..49dde09c 100644
--- a/.github/workflows/main.yaml
+++ b/.github/workflows/main.yaml
@@ -39,6 +39,17 @@ jobs:
         run: |
           chmod 777 dist/run

+      - name: Test Python Server
+        env:
+          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+          ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
+          TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }}
+        run: |
+          cd continuedev
+          pip install -r dev_requirements.txt
+          cd src
+          python -m pytest
+
       - name: Upload Artifacts
         uses: actions/upload-artifact@v3
         with:
diff --git a/.vscode/launch.json b/.vscode/launch.json
index 12cfaef8..08061d13 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -17,6 +17,15 @@
   ],
   "configurations": [
     {
+      "name": "Pytest",
+      "type": "python",
+      "request": "launch",
+      "program": "continuedev/src/continuedev/tests/llm_test.py",
+      "console": "integratedTerminal",
+      "justMyCode": true,
+      "subProcess": true
+    },
+    {
       "name": "Server",
       "type": "python",
       "request": "launch",
diff --git a/continuedev/dev_requirements.txt b/continuedev/dev_requirements.txt
new file mode 100644
index 00000000..2fa7631b
--- /dev/null
+++ b/continuedev/dev_requirements.txt
@@ -0,0 +1,2 @@
+pytest==7.4.1
+pytest-asyncio==0.21.1
\ No newline at end of file diff --git a/continuedev/poetry.lock b/continuedev/poetry.lock index aefc7cf9..700fd017 100644 --- a/continuedev/poetry.lock +++ b/continuedev/poetry.lock @@ -588,6 +588,20 @@ files = [ ] [[package]] +name = "exceptiongroup" +version = "1.1.3" +description = "Backport of PEP 654 (exception groups)" +optional = false +python-versions = ">=3.7" +files = [ + {file = "exceptiongroup-1.1.3-py3-none-any.whl", hash = "sha256:343280667a4585d195ca1cf9cef84a4e178c4b6cf2274caef9859782b567d5e3"}, + {file = "exceptiongroup-1.1.3.tar.gz", hash = "sha256:097acd85d473d75af5bb98e41b61ff7fe35efe6675e4f9370ec6ec5126d160e9"}, +] + +[package.extras] +test = ["pytest (>=6)"] + +[[package]] name = "fastapi" version = "0.95.1" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" @@ -776,6 +790,17 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link testing = ["pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-ruff"] [[package]] +name = "iniconfig" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.7" +files = [ + {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, + {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, +] + +[[package]] name = "jsonref" version = "1.1.0" description = "jsonref is a library for automatic dereferencing of JSON Reference objects for Python." @@ -975,6 +1000,21 @@ files = [ ] [[package]] +name = "pluggy" +version = "1.3.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pluggy-1.3.0-py3-none-any.whl", hash = "sha256:d89c696a773f8bd377d18e5ecda92b7a3793cbe66c87060a6fb58c7b6e1061f7"}, + {file = "pluggy-1.3.0.tar.gz", hash = "sha256:cf61ae8f126ac6f7c451172cf30e3e43d3ca77615509771b3a984a0730651e12"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + +[[package]] name = "posthog" version = "3.0.1" description = "Integrate PostHog into any python application." 
@@ -1186,6 +1226,46 @@ files = [ ] [[package]] +name = "pytest" +version = "7.4.1" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-7.4.1-py3-none-any.whl", hash = "sha256:460c9a59b14e27c602eb5ece2e47bec99dc5fc5f6513cf924a7d03a578991b1f"}, + {file = "pytest-7.4.1.tar.gz", hash = "sha256:2f2301e797521b23e4d2585a0a3d7b5e50fdddaaf7e7d6773ea26ddb17c213ab"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=0.12,<2.0" +tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} + +[package.extras] +testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] + +[[package]] +name = "pytest-asyncio" +version = "0.21.1" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-asyncio-0.21.1.tar.gz", hash = "sha256:40a7eae6dded22c7b604986855ea48400ab15b069ae38116e8c01238e9eeb64d"}, + {file = "pytest_asyncio-0.21.1-py3-none-any.whl", hash = "sha256:8666c1c8ac02631d7c51ba282e0c69a8a452b211ffedf2599099845da5c5c37b"}, +] + +[package.dependencies] +pytest = ">=7.0.0" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"] + +[[package]] name = "python-dateutil" version = "2.8.2" description = "Extensions to the standard Python datetime module" @@ -1436,6 +1516,20 @@ files = [ ] [[package]] +name = "sse-starlette" +version = "1.6.5" +description = "\"SSE plugin for Starlette\"" +optional = false +python-versions = ">=3.8" +files = [ + {file = "sse-starlette-1.6.5.tar.gz", hash = "sha256:819f2c421fb37067380fe3dcaba246c476b02651b7bb7601099a378ad802a0ac"}, + {file = "sse_starlette-1.6.5-py3-none-any.whl", hash = "sha256:68b6b7eb49be0c72a2af80a055994c13afcaa4761b29226beb208f954c25a642"}, +] + +[package.dependencies] +starlette = "*" + +[[package]] name = "starlette" version = "0.26.1" description = "The little ASGI library that shines." 
@@ -1553,6 +1647,17 @@ docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] [[package]] +name = "tomli" +version = "2.0.1" +description = "A lil' TOML parser" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, + {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, +] + +[[package]] name = "tqdm" version = "4.65.0" description = "Fast, Extensible Progress Meter" @@ -1905,4 +2010,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.8.1" -content-hash = "fe4715494ed91c691ec1eb914373a612e75751e6685678e438b73193879de98d" +content-hash = "94cd41b3db68b5e0ef60115992042a3e00f7d80e6b4be19a55d6ed2c0089342e" diff --git a/continuedev/pyproject.toml b/continuedev/pyproject.toml index 8cdf1197..142c1d09 100644 --- a/continuedev/pyproject.toml +++ b/continuedev/pyproject.toml @@ -34,7 +34,11 @@ replicate = "^0.11.0" redbaron = "^0.9.2" [tool.poetry.scripts] -typegen = "src.continuedev.models.generate_json_schema:main" +typegen = "src.continuedev.models.generate_json_schema:main" + +[tool.poetry.group.dev.dependencies] +pytest = "^7.4.1" +pytest-asyncio = "^0.21.1" [build-system] requires = ["poetry-core"] diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index de0b8c53..bae82739 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -507,7 +507,7 @@ class Autopilot(ContinueBaseModel): if self.session_info is None: async def create_title(): - title = await self.continue_sdk.models.medium._complete( + title = await self.continue_sdk.models.medium.complete( f'Give a short title to describe the current chat session. Do not put quotes around the title. The first message was: "{user_input}". Do not use more than 10 words. 
The title is: ', max_tokens=20, ) diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py index f3bf8125..9816d5d9 100644 --- a/continuedev/src/continuedev/core/models.py +++ b/continuedev/src/continuedev/core/models.py @@ -5,12 +5,13 @@ from pydantic import BaseModel from ..libs.llm import LLM from ..libs.llm.anthropic import AnthropicLLM from ..libs.llm.ggml import GGML +from ..libs.llm.llamacpp import LlamaCpp from ..libs.llm.maybe_proxy_openai import MaybeProxyOpenAI from ..libs.llm.ollama import Ollama from ..libs.llm.openai import OpenAI from ..libs.llm.replicate import ReplicateLLM from ..libs.llm.together import TogetherLLM -from ..libs.llm.llamacpp import LlamaCpp + class ContinueSDK(BaseModel): pass @@ -35,7 +36,7 @@ MODEL_CLASSES = { AnthropicLLM, ReplicateLLM, Ollama, - LlamaCpp + LlamaCpp, ] } diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py index 6a321a41..059bf279 100644 --- a/continuedev/src/continuedev/libs/llm/__init__.py +++ b/continuedev/src/continuedev/libs/llm/__init__.py @@ -1,12 +1,58 @@ -from abc import ABC from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union +from pydantic import validator + from ...core.main import ChatMessage from ...models.main import ContinueBaseModel -from ..util.count_tokens import DEFAULT_ARGS, count_tokens +from ..util.count_tokens import ( + DEFAULT_ARGS, + DEFAULT_MAX_TOKENS, + compile_chat_messages, + count_tokens, + format_chat_messages, + prune_raw_prompt_from_top, +) + + +class CompletionOptions(ContinueBaseModel): + """Options for the completion.""" + + @validator( + "*", + pre=True, + always=True, + ) + def ignore_none_and_set_default(cls, value, field): + return value if value is not None else field.default + + model: str = None + "The model name" + temperature: Optional[float] = None + "The temperature of the completion." + + top_p: Optional[float] = None + "The top_p of the completion." + + top_k: Optional[int] = None + "The top_k of the completion." + + presence_penalty: Optional[float] = None + "The presence penalty of the completion." + + frequency_penalty: Optional[float] = None + "The frequency penalty of the completion." + + stop: Optional[List[str]] = None + "The stop tokens of the completion." + + max_tokens: int = DEFAULT_MAX_TOKENS + "The maximum number of tokens to generate." + + functions: Optional[List[Any]] = None + "The functions/tools to make available to the model." -class LLM(ContinueBaseModel, ABC): +class LLM(ContinueBaseModel): title: Optional[str] = None system_message: Optional[str] = None @@ -19,8 +65,14 @@ class LLM(ContinueBaseModel, ABC): model: str "The model name" + timeout: int = 300 + "The timeout for the request in seconds." + prompt_templates: dict = {} + template_messages: Optional[Callable[[List[Dict[str, str]]], str]] = None + "A function that takes a list of messages and returns a prompt." + write_log: Optional[Callable[[str], None]] = None "A function that takes a string and writes it to the log." 
@@ -33,16 +85,12 @@ class LLM(ContinueBaseModel, ABC): def dict(self, **kwargs): original_dict = super().dict(**kwargs) - original_dict.pop("write_log", None) + original_dict.pop("write_log") + original_dict.pop("template_messages") + original_dict.pop("unique_id") original_dict["class_name"] = self.__class__.__name__ return original_dict - def collect_args(self, **kwargs) -> Any: - """Collect the arguments for the LLM.""" - args = {**DEFAULT_ARGS.copy(), "model": self.model, "max_tokens": 1024} - args.update(kwargs) - return args - async def start( self, write_log: Callable[[str], None] = None, unique_id: Optional[str] = None ): @@ -54,23 +102,194 @@ class LLM(ContinueBaseModel, ABC): """Stop the connection to the LLM.""" pass - async def _complete( - self, prompt: str, with_history: List[ChatMessage] = None, **kwargs + def collect_args(self, options: CompletionOptions) -> Dict[str, Any]: + """Collect the arguments for the LLM.""" + args = {**DEFAULT_ARGS.copy(), "model": self.model} + args.update(options.dict(exclude_unset=True, exclude_none=True)) + return args + + def compile_chat_messages( + self, + options: CompletionOptions, + msgs: List[ChatMessage], + functions: Optional[List[Any]] = None, + ) -> List[Dict]: + return compile_chat_messages( + model_name=options.model, + msgs=msgs, + context_length=self.context_length, + max_tokens=options.max_tokens, + functions=functions, + system_message=self.system_message, + ) + + def template_prompt_like_messages(self, prompt: str) -> str: + if self.template_messages is None: + return prompt + + msgs = [{"role": "user", "content": prompt}] + if self.system_message is not None: + msgs.insert(0, {"role": "system", "content": self.system_message}) + + return self.template_messages(msgs) + + async def stream_complete( + self, + prompt: str, + model: str = None, + temperature: float = None, + top_p: float = None, + top_k: int = None, + presence_penalty: float = None, + frequency_penalty: float = None, + stop: Optional[List[str]] = None, + max_tokens: Optional[int] = None, + functions: Optional[List[Any]] = None, + ) -> Generator[Union[Any, List, Dict], None, None]: + """Yield completion response, either streamed or not.""" + options = CompletionOptions( + model=model or self.model, + temperature=temperature, + top_p=top_p, + top_k=top_k, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + stop=stop, + max_tokens=max_tokens, + functions=functions, + ) + + prompt = prune_raw_prompt_from_top( + self.model, self.context_length, prompt, options.max_tokens + ) + prompt = self.template_prompt_like_messages(prompt) + + self.write_log(f"Prompt: \n\n{prompt}") + + completion = "" + async for chunk in self._stream_complete(prompt=prompt, options=options): + yield chunk + completion += chunk + + self.write_log(f"Completion: \n\n{completion}") + + async def complete( + self, + prompt: str, + model: str = None, + temperature: float = None, + top_p: float = None, + top_k: int = None, + presence_penalty: float = None, + frequency_penalty: float = None, + stop: Optional[List[str]] = None, + max_tokens: Optional[int] = None, + functions: Optional[List[Any]] = None, ) -> Coroutine[Any, Any, str]: - """Return the completion of the text with the given temperature.""" - raise NotImplementedError + """Yield completion response, either streamed or not.""" + options = CompletionOptions( + model=model or self.model, + temperature=temperature, + top_p=top_p, + top_k=top_k, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + 
stop=stop, + max_tokens=max_tokens, + functions=functions, + ) - def _stream_complete( - self, prompt, with_history: List[ChatMessage] = None, **kwargs + prompt = prune_raw_prompt_from_top( + self.model, self.context_length, prompt, options.max_tokens + ) + prompt = self.template_prompt_like_messages(prompt) + + self.write_log(f"Prompt: \n\n{prompt}") + + completion = await self._complete(prompt=prompt, options=options) + + self.write_log(f"Completion: \n\n{completion}") + return completion + + async def stream_chat( + self, + messages: List[ChatMessage], + model: str = None, + temperature: float = None, + top_p: float = None, + top_k: int = None, + presence_penalty: float = None, + frequency_penalty: float = None, + stop: Optional[List[str]] = None, + max_tokens: Optional[int] = None, + functions: Optional[List[Any]] = None, ) -> Generator[Union[Any, List, Dict], None, None]: + """Yield completion response, either streamed or not.""" + options = CompletionOptions( + model=model or self.model, + temperature=temperature, + top_p=top_p, + top_k=top_k, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + stop=stop, + max_tokens=max_tokens, + functions=functions, + ) + + messages = self.compile_chat_messages( + options=options, msgs=messages, functions=functions + ) + if self.template_messages is not None: + prompt = self.template_messages(messages) + else: + prompt = format_chat_messages(messages) + + self.write_log(f"Prompt: \n\n{prompt}") + + completion = "" + + # Use the template_messages function if it exists and do a raw completion + if self.template_messages is None: + async for chunk in self._stream_chat(messages=messages, options=options): + yield chunk + if "content" in chunk: + completion += chunk["content"] + else: + async for chunk in self._stream_complete(prompt=prompt, options=options): + yield {"role": "assistant", "content": chunk} + completion += chunk + + self.write_log(f"Completion: \n\n{completion}") + + def _stream_complete( + self, prompt, options: CompletionOptions + ) -> Generator[str, None, None]: """Stream the completion through generator.""" raise NotImplementedError + async def _complete( + self, prompt: str, options: CompletionOptions + ) -> Coroutine[Any, Any, str]: + """Return the completion of the text with the given temperature.""" + completion = "" + async for chunk in self._stream_complete(prompt=prompt, options=options): + completion += chunk + return completion + async def _stream_chat( - self, messages: List[ChatMessage] = None, **kwargs + self, messages: List[ChatMessage], options: CompletionOptions ) -> Generator[Union[Any, List, Dict], None, None]: """Stream the chat through generator.""" - raise NotImplementedError + if self.template_messages is None: + raise NotImplementedError( + "You must either implement template_messages or _stream_chat" + ) + + async for chunk in self._stream_complete( + prompt=self.template_messages(messages), options=options + ): + yield {"role": "assistant", "content": chunk} def count_tokens(self, text: str): """Return the number of tokens in the given text.""" diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py index b5aff63a..70a0868c 100644 --- a/continuedev/src/continuedev/libs/llm/anthropic.py +++ b/continuedev/src/continuedev/libs/llm/anthropic.py @@ -1,10 +1,9 @@ -from typing import Any, Coroutine, Dict, Generator, List, Union +from typing import Any, Callable, Coroutine from anthropic import AI_PROMPT, HUMAN_PROMPT, AsyncAnthropic 
-from ...core.main import ChatMessage -from ..llm import LLM -from ..util.count_tokens import compile_chat_messages +from ..llm import LLM, CompletionOptions +from .prompts.chat import anthropic_template_messages class AnthropicLLM(LLM): @@ -15,21 +14,21 @@ class AnthropicLLM(LLM): _async_client: AsyncAnthropic = None + template_messages: Callable = anthropic_template_messages + class Config: arbitrary_types_allowed = True - async def start( - self, - **kwargs, - ): + async def start(self, **kwargs): await super().start(**kwargs) self._async_client = AsyncAnthropic(api_key=self.api_key) if self.model == "claude-2": self.context_length = 100_000 - def collect_args(self, **kwargs) -> Any: - args = super().collect_args(**kwargs) + def collect_args(self, options: CompletionOptions): + options.stop = None + args = super().collect_args(options) if "max_tokens" in args: args["max_tokens_to_sample"] = args["max_tokens"] @@ -40,85 +39,18 @@ class AnthropicLLM(LLM): del args["presence_penalty"] return args - def __messages_to_prompt(self, messages: List[Dict[str, str]]) -> str: - prompt = "" - - # Anthropic prompt must start with a Human turn - if ( - len(messages) > 0 - and messages[0]["role"] != "user" - and messages[0]["role"] != "system" - ): - prompt += f"{HUMAN_PROMPT} Hello." - for msg in messages: - prompt += f"{HUMAN_PROMPT if (msg['role'] == 'user' or msg['role'] == 'system') else AI_PROMPT} {msg['content']} " - - prompt += AI_PROMPT - return prompt - - async def _stream_complete( - self, prompt, with_history: List[ChatMessage] = None, **kwargs - ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.collect_args(**kwargs) - args["stream"] = True + async def _stream_complete(self, prompt: str, options): + args = self.collect_args(options) prompt = f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}" - self.write_log(f"Prompt: \n\n{prompt}") - completion = "" async for chunk in await self._async_client.completions.create( - prompt=prompt, **args + prompt=prompt, stream=True, **args ): yield chunk.completion - completion += chunk.completion - - self.write_log(f"Completion: \n\n{completion}") - - async def _stream_chat( - self, messages: List[ChatMessage] = None, **kwargs - ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.collect_args(**kwargs) - args["stream"] = True - - messages = compile_chat_messages( - args["model"], - messages, - self.context_length, - args["max_tokens_to_sample"], - functions=args.get("functions", None), - system_message=self.system_message, - ) - completion = "" - prompt = self.__messages_to_prompt(messages) - self.write_log(f"Prompt: \n\n{prompt}") - async for chunk in await self._async_client.completions.create( - prompt=prompt, **args - ): - yield {"role": "assistant", "content": chunk.completion} - completion += chunk.completion - - self.write_log(f"Completion: \n\n{completion}") - - async def _complete( - self, prompt: str, with_history: List[ChatMessage] = None, **kwargs - ) -> Coroutine[Any, Any, str]: - args = self.collect_args(**kwargs) - - messages = compile_chat_messages( - args["model"], - with_history, - self.context_length, - args["max_tokens_to_sample"], - prompt, - functions=None, - system_message=self.system_message, - ) - - prompt = self.__messages_to_prompt(messages) - self.write_log(f"Prompt: \n\n{prompt}") - resp = ( + async def _complete(self, prompt: str, options) -> Coroutine[Any, Any, str]: + args = self.collect_args(options) + prompt = f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}" + return ( await 
self._async_client.completions.create(prompt=prompt, **args) ).completion - - self.write_log(f"Completion: \n\n{resp}") - return resp diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py index 70c265d9..e4971867 100644 --- a/continuedev/src/continuedev/libs/llm/ggml.py +++ b/continuedev/src/continuedev/libs/llm/ggml.py @@ -1,11 +1,12 @@ import json -from typing import Any, Coroutine, Dict, Generator, List, Optional, Union +from typing import Any, Coroutine, List, Optional import aiohttp from ...core.main import ChatMessage from ..llm import LLM -from ..util.count_tokens import compile_chat_messages, format_chat_messages +from ..util.logging import logger +from .prompts.edit import simplified_edit_prompt class GGML(LLM): @@ -13,61 +14,44 @@ class GGML(LLM): verify_ssl: Optional[bool] = None model: str = "ggml" - timeout: int = 300 + prompt_templates = { + "edit": simplified_edit_prompt, + } class Config: arbitrary_types_allowed = True - async def _stream_complete( - self, prompt, with_history: List[ChatMessage] = None, **kwargs - ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.collect_args(**kwargs) - args["stream"] = True + async def _stream_complete(self, prompt, options): + args = self.collect_args(options) - messages = compile_chat_messages( - self.model, - with_history, - self.context_length, - args["max_tokens"], - prompt, - functions=args.get("functions", None), - system_message=self.system_message, - ) - - self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") - completion = "" async with aiohttp.ClientSession( connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl), timeout=aiohttp.ClientTimeout(total=self.timeout), ) as client_session: async with client_session.post( - f"{self.server_url}/v1/completions", json={"messages": messages, **args} + f"{self.server_url}/v1/completions", + json={ + "prompt": prompt, + "stream": True, + **args, + }, ) as resp: async for line in resp.content.iter_any(): if line: - try: - chunk = line.decode("utf-8") - yield chunk - completion += chunk - except: - raise Exception(str(line)) + chunk = line.decode("utf-8") + if chunk.startswith(": ping - ") or chunk.startswith( + "data: [DONE]" + ): + continue + elif chunk.startswith("data: "): + chunk = chunk[6:] - self.write_log(f"Completion: \n\n{completion}") + j = json.loads(chunk) + if "choices" in j: + yield j["choices"][0]["text"] - async def _stream_chat( - self, messages: List[ChatMessage] = None, **kwargs - ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.collect_args(**kwargs) - messages = compile_chat_messages( - self.model, - messages, - self.context_length, - args["max_tokens"], - None, - functions=args.get("functions", None), - system_message=self.system_message, - ) - args["stream"] = True + async def _stream_chat(self, messages: List[ChatMessage], options): + args = self.collect_args(options) async def generator(): async with aiohttp.ClientSession( @@ -76,10 +60,9 @@ class GGML(LLM): ) as client_session: async with client_session.post( f"{self.server_url}/v1/chat/completions", - json={"messages": messages, **args}, + json={"messages": messages, "stream": True, **args}, headers={"Content-Type": "application/json"}, ) as resp: - # This is streaming application/json instaed of text/event-stream async for line, end in resp.content.iter_chunks(): json_chunk = line.decode("utf-8") chunks = json_chunk.split("\n") @@ -91,34 +74,22 @@ class GGML(LLM): ): continue try: - yield 
json.loads(chunk[6:])["choices"][0][ - "delta" - ] # {"role": "assistant", "content": "..."} + yield json.loads(chunk[6:])["choices"][0]["delta"] except: pass # Because quite often the first attempt fails, and it works thereafter - self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") - completion = "" try: async for chunk in generator(): yield chunk - if "content" in chunk: - completion += chunk["content"] - except: + except Exception as e: + logger.warning(f"Error calling /chat/completions endpoint: {e}") async for chunk in generator(): yield chunk - if "content" in chunk: - completion += chunk["content"] - - self.write_log(f"Completion: \n\n{completion}") - async def _complete( - self, prompt: str, with_history: List[ChatMessage] = None, **kwargs - ) -> Coroutine[Any, Any, str]: - args = self.collect_args(**kwargs) + async def _complete(self, prompt: str, options) -> Coroutine[Any, Any, str]: + args = self.collect_args(options) - self.write_log(f"Prompt: \n\n{prompt}") async with aiohttp.ClientSession( connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl), timeout=aiohttp.ClientTimeout(total=self.timeout), @@ -133,7 +104,6 @@ class GGML(LLM): text = await resp.text() try: completion = json.loads(text)["choices"][0]["text"] - self.write_log(f"Completion: \n\n{completion}") return completion except Exception as e: raise Exception( diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py index 3a586a43..43aac148 100644 --- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py +++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py @@ -1,10 +1,11 @@ -from typing import Any, Coroutine, Dict, Generator, List, Optional +from typing import List, Optional import aiohttp import requests from ...core.main import ChatMessage from ..llm import LLM +from .prompts.edit import simplified_edit_prompt DEFAULT_MAX_TIME = 120.0 @@ -17,21 +18,24 @@ class HuggingFaceInferenceAPI(LLM): _client_session: aiohttp.ClientSession = None + prompt_templates = { + "edit": simplified_edit_prompt, + } + class Config: arbitrary_types_allowed = True async def start(self, **kwargs): await super().start(**kwargs) self._client_session = aiohttp.ClientSession( - connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl) + connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl), + timeout=aiohttp.ClientTimeout(total=self.timeout), ) async def stop(self): await self._client_session.close() - async def _complete( - self, prompt: str, with_history: List[ChatMessage] = None, **kwargs - ): + async def _complete(self, prompt: str, options): """Return the completion of the text with the given temperature.""" API_URL = ( self.base_url or f"https://api-inference.huggingface.co/models/{self.model}" @@ -60,14 +64,10 @@ class HuggingFaceInferenceAPI(LLM): return data[0]["generated_text"] - async def _stream_chat( - self, messages: List[ChatMessage] = None, **kwargs - ) -> Coroutine[Any, Any, Generator[Any | List | Dict, None, None]]: + async def _stream_chat(self, messages: List[ChatMessage], options): response = await self._complete(messages[-1].content, messages[:-1]) yield {"content": response, "role": "assistant"} - async def _stream_complete( - self, prompt, with_history: List[ChatMessage] = None, **kwargs - ) -> Generator[Any | List | Dict, None, None]: - response = await self._complete(prompt, with_history) + async def _stream_complete(self, prompt, options): + response = await self._complete(prompt, options) yield response diff 
--git a/continuedev/src/continuedev/libs/llm/hf_tgi.py b/continuedev/src/continuedev/libs/llm/hf_tgi.py index 5c7e0239..508ebe87 100644 --- a/continuedev/src/continuedev/libs/llm/hf_tgi.py +++ b/continuedev/src/continuedev/libs/llm/hf_tgi.py @@ -1,12 +1,12 @@ import json -from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union +from typing import Any, Callable, List, Optional import aiohttp from ...core.main import ChatMessage -from ..llm import LLM -from ..util.count_tokens import compile_chat_messages +from ..llm import LLM, CompletionOptions from .prompts.chat import code_llama_template_messages +from .prompts.edit import simplified_edit_prompt class HuggingFaceTGI(LLM): @@ -16,11 +16,15 @@ class HuggingFaceTGI(LLM): template_messages: Callable[[List[ChatMessage]], str] = code_llama_template_messages + prompt_templates = { + "edit": simplified_edit_prompt, + } + class Config: arbitrary_types_allowed = True - def collect_args(self, **kwargs) -> Any: - args = super().collect_args(**kwargs) + def collect_args(self, options: CompletionOptions) -> Any: + args = super().collect_args(options) args = { **args, "max_new_tokens": args.get("max_tokens", 1024), @@ -28,31 +32,16 @@ class HuggingFaceTGI(LLM): args.pop("max_tokens", None) return args - async def _stream_complete( - self, prompt, with_history: List[ChatMessage] = None, **kwargs - ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.collect_args(**kwargs) - args["stream"] = True - - messages = compile_chat_messages( - self.model, - with_history, - self.context_length, - args["max_tokens"], - prompt, - functions=args.get("functions", None), - system_message=self.system_message, - ) + async def _stream_complete(self, prompt, options): + args = self.collect_args(options) - prompt = self.template_messages(messages) - self.write_log(f"Prompt: \n\n{prompt}") - completion = "" async with aiohttp.ClientSession( - connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl) + connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl), + timeout=aiohttp.ClientTimeout(total=self.timeout), ) as client_session: async with client_session.post( f"{self.server_url}", - json={"inputs": prompt, **args}, + json={"inputs": prompt, "stream": True, **args}, ) as resp: async for line in resp.content.iter_any(): if line: @@ -62,39 +51,3 @@ class HuggingFaceTGI(LLM): "generated_text" ] yield text - completion += text - - self.write_log(f"Completion: \n\n{completion}") - - async def _stream_chat( - self, messages: List[ChatMessage] = None, **kwargs - ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.collect_args(**kwargs) - messages = compile_chat_messages( - self.model, - messages, - self.context_length, - args["max_tokens"], - None, - functions=args.get("functions", None), - system_message=self.system_message, - ) - - async for chunk in self._stream_complete( - None, self.template_messages(messages), **args - ): - yield { - "role": "assistant", - "content": chunk, - } - - async def _complete( - self, prompt: str, with_history: List[ChatMessage] = None, **kwargs - ) -> Coroutine[Any, Any, str]: - args = self.collect_args(**kwargs) - - completion = "" - async for chunk in self._stream_complete(prompt, with_history, **args): - completion += chunk - - return completion diff --git a/continuedev/src/continuedev/libs/llm/hugging_face.py b/continuedev/src/continuedev/libs/llm/hugging_face.py index f246a43c..c2e934c0 100644 --- a/continuedev/src/continuedev/libs/llm/hugging_face.py +++ 
b/continuedev/src/continuedev/libs/llm/hugging_face.py @@ -1,3 +1,5 @@ +# TODO: This class is far out of date + from transformers import AutoModelForCausalLM, AutoTokenizer from .llm import LLM diff --git a/continuedev/src/continuedev/libs/llm/llamacpp.py b/continuedev/src/continuedev/libs/llm/llamacpp.py index 9d8b548f..3596fd99 100644 --- a/continuedev/src/continuedev/libs/llm/llamacpp.py +++ b/continuedev/src/continuedev/libs/llm/llamacpp.py @@ -1,13 +1,12 @@ import asyncio import json -from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union +from typing import Any, Callable, Dict, Optional import aiohttp -from ...core.main import ChatMessage from ..llm import LLM -from ..util.count_tokens import compile_chat_messages from .prompts.chat import llama2_template_messages +from .prompts.edit import simplified_edit_prompt class LlamaCpp(LLM): @@ -15,21 +14,20 @@ class LlamaCpp(LLM): server_url: str = "http://localhost:8080" verify_ssl: Optional[bool] = None - template_messages: Callable[[List[ChatMessage]], str] = llama2_template_messages llama_cpp_args: Dict[str, Any] = {"stop": ["[INST]"]} use_command: Optional[str] = None + template_messages: Callable = llama2_template_messages + prompt_templates = { + "edit": simplified_edit_prompt, + } + class Config: arbitrary_types_allowed = True - def dict(self, **kwargs): - d = super().dict(**kwargs) - d.pop("template_messages") - return d - - def collect_args(self, **kwargs) -> Any: - args = super().collect_args(**kwargs) + def collect_args(self, options) -> Any: + args = super().collect_args(options) if "max_tokens" in args: args["n_predict"] = args["max_tokens"] del args["max_tokens"] @@ -61,119 +59,30 @@ class LlamaCpp(LLM): await process.wait() - async def _stream_complete( - self, prompt, with_history: List[ChatMessage] = None, **kwargs - ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.collect_args(**kwargs) - args["stream"] = True - - messages = compile_chat_messages( - self.model, - with_history, - self.context_length, - args["n_predict"] if "n_predict" in args else 1024, - prompt, - functions=args.get("functions", None), - system_message=self.system_message, - ) - - prompt = self.convert_to_chat(messages) - self.write_log(f"Prompt: \n\n{prompt}") - completion = "" - async with aiohttp.ClientSession( - connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl) - ) as client_session: - async with client_session.post( - f"{self.server_url}/completion", - json={ - "prompt": prompt, - **args, - }, - headers={"Content-Type": "application/json"}, - ) as resp: - async for line in resp.content.iter_any(): - if line: - chunk = line.decode("utf-8") - yield chunk - completion += chunk - - self.write_log(f"Completion: \n\n{completion}") - - async def _stream_chat( - self, messages: List[ChatMessage] = None, **kwargs - ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.collect_args(**kwargs) - messages = compile_chat_messages( - self.model, - messages, - self.context_length, - args["n_predict"] if "n_predict" in args else 1024, - None, - functions=args.get("functions", None), - system_message=self.system_message, - ) - args["stream"] = True - - prompt = self.template_messages(messages) + async def _stream_complete(self, prompt, options): + args = self.collect_args(options) headers = {"Content-Type": "application/json"} async def server_generator(): async with aiohttp.ClientSession( - connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl) + 
connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl), + timeout=aiohttp.ClientTimeout(total=self.timeout), ) as client_session: async with client_session.post( f"{self.server_url}/completion", - json={"prompt": prompt, **args}, + json={"prompt": prompt, "stream": True, **args}, headers=headers, ) as resp: async for line in resp.content: content = line.decode("utf-8") if content.strip() == "": continue - yield { - "content": json.loads(content[6:])["content"], - "role": "assistant", - } + yield json.loads(content[6:])["content"] async def command_generator(): async for line in self.stream_from_main(prompt): - yield {"content": line, "role": "assistant"} + yield line generator = command_generator if self.use_command else server_generator - - # Because quite often the first attempt fails, and it works thereafter - self.write_log(f"Prompt: \n\n{prompt}") - completion = "" async for chunk in generator(): yield chunk - if "content" in chunk: - completion += chunk["content"] - - self.write_log(f"Completion: \n\n{completion}") - - async def _complete( - self, prompt: str, with_history: List[ChatMessage] = None, **kwargs - ) -> Coroutine[Any, Any, str]: - args = self.collect_args(**kwargs) - - self.write_log(f"Prompt: \n\n{prompt}") - - if self.use_command: - completion = "" - async for line in self.stream_from_main(prompt): - completion += line - self.write_log(f"Completion: \n\n{completion}") - return completion - else: - async with aiohttp.ClientSession( - connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl) - ) as client_session: - async with client_session.post( - f"{self.server_url}/completion", - json={"prompt": prompt, **args}, - headers={"Content-Type": "application/json"}, - ) as resp: - json_resp = await resp.json() - completion = json_resp["content"] - self.write_log(f"Completion: \n\n{completion}") - return completion diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py index 99b7c47f..084c57fd 100644 --- a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py +++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py @@ -1,4 +1,4 @@ -from typing import Any, Coroutine, Dict, Generator, List, Optional, Union +from typing import Callable, List, Optional from ...core.main import ChatMessage from . 
import LLM @@ -7,7 +7,6 @@ from .proxy_server import ProxyServer class MaybeProxyOpenAI(LLM): - model: str api_key: Optional[str] = None llm: Optional[LLM] = None @@ -16,36 +15,33 @@ class MaybeProxyOpenAI(LLM): if self.llm is not None: self.llm.system_message = self.system_message - async def start(self, **kwargs): + async def start( + self, write_log: Callable[[str], None] = None, unique_id: Optional[str] = None + ): + await super().start(write_log=lambda *args, **kwargs: None, unique_id=unique_id) if self.api_key is None or self.api_key.strip() == "": self.llm = ProxyServer(model=self.model) else: self.llm = OpenAI(api_key=self.api_key, model=self.model) - await self.llm.start(**kwargs) + await self.llm.start(write_log=write_log, unique_id=unique_id) async def stop(self): await self.llm.stop() - async def _complete( - self, prompt: str, with_history: List[ChatMessage] = None, **kwargs - ) -> Coroutine[Any, Any, str]: + async def _complete(self, prompt: str, options): self.update_llm_properties() - return await self.llm._complete(prompt, with_history=with_history, **kwargs) + return await self.llm._complete(prompt, options) - async def _stream_complete( - self, prompt, with_history: List[ChatMessage] = None, **kwargs - ) -> Generator[Union[Any, List, Dict], None, None]: + async def _stream_complete(self, prompt, options): self.update_llm_properties() - resp = self.llm._stream_complete(prompt, with_history=with_history, **kwargs) + resp = self.llm._stream_complete(prompt, options) async for item in resp: yield item - async def _stream_chat( - self, messages: List[ChatMessage] = None, **kwargs - ) -> Generator[Union[Any, List, Dict], None, None]: + async def _stream_chat(self, messages: List[ChatMessage], options): self.update_llm_properties() - resp = self.llm._stream_chat(messages=messages, **kwargs) + resp = self.llm._stream_chat(messages=messages, options=options) async for item in resp: yield item diff --git a/continuedev/src/continuedev/libs/llm/ollama.py b/continuedev/src/continuedev/libs/llm/ollama.py index ef8ed47b..d0da281a 100644 --- a/continuedev/src/continuedev/libs/llm/ollama.py +++ b/continuedev/src/continuedev/libs/llm/ollama.py @@ -1,14 +1,11 @@ import json -import urllib.parse -from textwrap import dedent -from typing import Any, Coroutine, Dict, Generator, List, Union +from typing import Callable import aiohttp -from ...core.main import ChatMessage from ..llm import LLM -from ..util.count_tokens import compile_chat_messages from .prompts.chat import llama2_template_messages +from .prompts.edit import simplified_edit_prompt class Ollama(LLM): @@ -17,18 +14,10 @@ class Ollama(LLM): _client_session: aiohttp.ClientSession = None + template_messages: Callable = llama2_template_messages + prompt_templates = { - "edit": dedent( - """\ - [INST] Consider the following code: - ``` - {{code_to_edit}} - ``` - Edit the code to perfectly satisfy the following user request: - {{user_input}} - Output nothing except for the code. No code block, no English explanation, no start/end tags. 
- [/INST]""" - ), + "edit": simplified_edit_prompt, } class Config: @@ -36,36 +25,23 @@ class Ollama(LLM): async def start(self, **kwargs): await super().start(**kwargs) - self._client_session = aiohttp.ClientSession() + self._client_session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=self.timeout) + ) async def stop(self): await self._client_session.close() - async def _stream_complete( - self, prompt, with_history: List[ChatMessage] = None, **kwargs - ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.collect_args(**kwargs) - messages = compile_chat_messages( - self.model, - with_history, - self.context_length, - args["max_tokens"], - prompt, - functions=None, - system_message=self.system_message, - ) - prompt = llama2_template_messages(messages) - + async def _stream_complete(self, prompt, options): async with self._client_session.post( f"{self.server_url}/api/generate", json={ "template": prompt, "model": self.model, "system": self.system_message, - "options": {"temperature": args["temperature"]}, + "options": {"temperature": options.temperature}, }, ) as resp: - url_decode_buffer = "" async for line in resp.content.iter_any(): if line: json_chunk = line.decode("utf-8") @@ -74,80 +50,4 @@ class Ollama(LLM): if chunk.strip() != "": j = json.loads(chunk) if "response" in j: - url_decode_buffer += j["response"] - - if ( - "&" in url_decode_buffer - and url_decode_buffer.index("&") - > len(url_decode_buffer) - 5 - ): - continue - yield urllib.parse.unquote(url_decode_buffer) - url_decode_buffer = "" - - async def _stream_chat( - self, messages: List[ChatMessage] = None, **kwargs - ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.collect_args(**kwargs) - messages = compile_chat_messages( - self.model, - messages, - self.context_length, - args["max_tokens"], - None, - functions=None, - system_message=self.system_message, - ) - prompt = llama2_template_messages(messages) - - self.write_log(f"Prompt:\n{prompt}") - completion = "" - async with self._client_session.post( - f"{self.server_url}/api/generate", - json={ - "template": prompt, - "model": self.model, - "system": self.system_message, - "options": {"temperature": args["temperature"]}, - }, - ) as resp: - async for line in resp.content.iter_chunks(): - if line[1]: - json_chunk = line[0].decode("utf-8") - chunks = json_chunk.split("\n") - for chunk in chunks: - if chunk.strip() != "": - j = json.loads(chunk) - if "response" in j: - yield { - "role": "assistant", - "content": j["response"], - } - completion += j["response"] - self.write_log(f"Completion:\n{completion}") - - async def _complete( - self, prompt: str, with_history: List[ChatMessage] = None, **kwargs - ) -> Coroutine[Any, Any, str]: - completion = "" - args = self.collect_args(**kwargs) - async with self._client_session.post( - f"{self.server_url}/api/generate", - json={ - "template": prompt, - "model": self.model, - "system": self.system_message, - "options": {"temperature": args["temperature"]}, - }, - ) as resp: - async for line in resp.content.iter_any(): - if line: - json_chunk = line.decode("utf-8") - chunks = json_chunk.split("\n") - for chunk in chunks: - if chunk.strip() != "": - j = json.loads(chunk) - if "response" in j: - completion += urllib.parse.unquote(j["response"]) - - return completion + yield j["response"] diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index a017af22..2074ae04 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ 
b/continuedev/src/continuedev/libs/llm/openai.py @@ -1,25 +1,10 @@ -from typing import ( - Any, - Callable, - Coroutine, - Dict, - Generator, - List, - Literal, - Optional, - Union, -) +from typing import Callable, List, Literal, Optional import certifi import openai from ...core.main import ChatMessage from ..llm import LLM -from ..util.count_tokens import ( - compile_chat_messages, - format_chat_messages, - prune_raw_prompt_from_top, -) CHAT_MODELS = {"gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"} MAX_TOKENS_FOR_MODEL = { @@ -82,118 +67,53 @@ class OpenAI(LLM): openai.ca_bundle_path = self.ca_bundle_path or certifi.where() - def collect_args(self, **kwargs): - args = super().collect_args() + def collect_args(self, options): + args = super().collect_args(options) if self.engine is not None: args["engine"] = self.engine + + if not args["model"].endswith("0613") and "functions" in args: + del args["functions"] + return args - async def _stream_complete( - self, prompt, with_history: List[ChatMessage] = None, **kwargs - ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.collect_args(**kwargs) + async def _stream_complete(self, prompt, options): + args = self.collect_args(options) args["stream"] = True if args["model"] in CHAT_MODELS: - messages = compile_chat_messages( - args["model"], - with_history, - self.context_length, - args["max_tokens"], - prompt, - functions=None, - system_message=self.system_message, - ) - self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") - completion = "" async for chunk in await openai.ChatCompletion.acreate( - messages=messages, + messages=[{"role": "user", "content": prompt}], **args, ): if "content" in chunk.choices[0].delta: yield chunk.choices[0].delta.content - completion += chunk.choices[0].delta.content - else: - continue # :) - - self.write_log(f"Completion: \n\n{completion}") else: - self.write_log(f"Prompt:\n\n{prompt}") - completion = "" async for chunk in await openai.Completion.acreate(prompt=prompt, **args): yield chunk.choices[0].text - completion += chunk.choices[0].text - - self.write_log(f"Completion:\n\n{completion}") - async def _stream_chat( - self, messages: List[ChatMessage] = None, **kwargs - ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.collect_args(**kwargs) - args["stream"] = True - - if not args["model"].endswith("0613") and "functions" in args: - del args["functions"] + async def _stream_chat(self, messages: List[ChatMessage], options): + args = self.collect_args(options) - messages = compile_chat_messages( - args["model"], - messages, - self.context_length, - args["max_tokens"], - None, - functions=args.get("functions", None), - system_message=self.system_message, - ) - self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") - completion = "" async for chunk in await openai.ChatCompletion.acreate( messages=messages, + stream=True, **args, ): if len(chunk.choices) == 0: continue yield chunk.choices[0].delta - if "content" in chunk.choices[0].delta: - completion += chunk.choices[0].delta.content - self.write_log(f"Completion: \n\n{completion}") - async def _complete( - self, prompt: str, with_history: List[ChatMessage] = None, **kwargs - ) -> Coroutine[Any, Any, str]: - args = self.collect_args(**kwargs) + async def _complete(self, prompt: str, options): + args = self.collect_args(options) if args["model"] in CHAT_MODELS: - messages = compile_chat_messages( - args["model"], - with_history, - self.context_length, - args["max_tokens"], - prompt, - 
functions=None, - system_message=self.system_message, - ) - self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") resp = await openai.ChatCompletion.acreate( - messages=messages, + messages=[{"role": "user", "content": prompt}], **args, ) - completion = resp.choices[0].message.content - self.write_log(f"Completion: \n\n{completion}") + return resp.choices[0].message.content else: - prompt = prune_raw_prompt_from_top( - args["model"], self.context_length, prompt, args["max_tokens"] - ) - self.write_log(f"Prompt:\n\n{prompt}") - completion = ( - ( - await openai.Completion.acreate( - prompt=prompt, - **args, - ) - ) - .choices[0] - .text + return ( + (await openai.Completion.acreate(prompt=prompt, **args)).choices[0].text ) - self.write_log(f"Completion:\n\n{completion}") - - return completion diff --git a/continuedev/src/continuedev/libs/llm/prompts/chat.py b/continuedev/src/continuedev/libs/llm/prompts/chat.py index c7c208c0..1329a2ff 100644 --- a/continuedev/src/continuedev/libs/llm/prompts/chat.py +++ b/continuedev/src/continuedev/libs/llm/prompts/chat.py @@ -1,9 +1,27 @@ from textwrap import dedent +from typing import Dict, List -from ....core.main import ChatMessage +from anthropic import AI_PROMPT, HUMAN_PROMPT -def llama2_template_messages(msgs: ChatMessage) -> str: +def anthropic_template_messages(messages: List[Dict[str, str]]) -> str: + prompt = "" + + # Anthropic prompt must start with a Human turn + if ( + len(messages) > 0 + and messages[0]["role"] != "user" + and messages[0]["role"] != "system" + ): + prompt += f"{HUMAN_PROMPT} Hello." + for msg in messages: + prompt += f"{HUMAN_PROMPT if (msg['role'] == 'user' or msg['role'] == 'system') else AI_PROMPT} {msg['content']} " + + prompt += AI_PROMPT + return prompt + + +def llama2_template_messages(msgs: List[Dict[str, str]]) -> str: if len(msgs) == 0: return "" @@ -38,20 +56,20 @@ def llama2_template_messages(msgs: ChatMessage) -> str: if msgs[i]["role"] == "user": prompt += f"[INST] {msgs[i]['content']} [/INST]" else: - prompt += msgs[i]["content"] + prompt += msgs[i]["content"] + " " return prompt -def code_llama_template_messages(msgs: ChatMessage) -> str: +def code_llama_template_messages(msgs: List[Dict[str, str]]) -> str: return f"[INST] {msgs[-1]['content']}\n[/INST]" -def extra_space_template_messages(msgs: ChatMessage) -> str: +def extra_space_template_messages(msgs: List[Dict[str, str]]) -> str: return f" {msgs[-1]['content']}" -def code_llama_python_template_messages(msgs: ChatMessage) -> str: +def code_llama_python_template_messages(msgs: List[Dict[str, str]]) -> str: return dedent( f"""\ [INST] diff --git a/continuedev/src/continuedev/libs/llm/prompts/edit.py b/continuedev/src/continuedev/libs/llm/prompts/edit.py new file mode 100644 index 00000000..a234fa61 --- /dev/null +++ b/continuedev/src/continuedev/libs/llm/prompts/edit.py @@ -0,0 +1,13 @@ +from textwrap import dedent + +simplified_edit_prompt = dedent( + """\ + [INST] Consider the following code: + ``` + {{code_to_edit}} + ``` + Edit the code to perfectly satisfy the following user request: + {{user_input}} + Output nothing except for the code. No code block, no English explanation, no start/end tags. 
+ [/INST]""" +) diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py index 3ac6371f..d62fafa7 100644 --- a/continuedev/src/continuedev/libs/llm/proxy_server.py +++ b/continuedev/src/continuedev/libs/llm/proxy_server.py @@ -1,14 +1,13 @@ import json import ssl import traceback -from typing import Any, Coroutine, Dict, Generator, List, Union +from typing import List import aiohttp import certifi from ...core.main import ChatMessage from ..llm import LLM -from ..util.count_tokens import compile_chat_messages, format_chat_messages from ..util.telemetry import posthog_logger ca_bundle_path = certifi.where() @@ -37,7 +36,8 @@ class ProxyServer(LLM): ): await super().start(**kwargs) self._client_session = aiohttp.ClientSession( - connector=aiohttp.TCPConnector(ssl_context=ssl_context) + connector=aiohttp.TCPConnector(ssl_context=ssl_context), + timeout=aiohttp.ClientTimeout(total=self.timeout), ) self.context_length = MAX_TOKENS_FOR_MODEL[self.model] @@ -45,60 +45,32 @@ class ProxyServer(LLM): await self._client_session.close() def get_headers(self): - # headers with unique id return {"unique_id": self.unique_id} - async def _complete( - self, prompt: str, with_history: List[ChatMessage] = None, **kwargs - ) -> Coroutine[Any, Any, str]: - args = self.collect_args(**kwargs) - - messages = compile_chat_messages( - args["model"], - with_history, - self.context_length, - args["max_tokens"], - prompt, - functions=None, - system_message=self.system_message, - ) - self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") + async def _complete(self, prompt: str, options): + args = self.collect_args(options) + async with self._client_session.post( f"{SERVER_URL}/complete", - json={"messages": messages, **args}, + json={"messages": [{"role": "user", "content": prompt}], **args}, headers=self.get_headers(), ) as resp: + resp_text = await resp.text() if resp.status != 200: - raise Exception(await resp.text()) + raise Exception(resp_text) - response_text = await resp.text() - self.write_log(f"Completion: \n\n{response_text}") - return response_text - - async def _stream_chat( - self, messages: List[ChatMessage] = None, **kwargs - ) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]: - args = self.collect_args(**kwargs) - messages = compile_chat_messages( - args["model"], - messages, - self.context_length, - args["max_tokens"], - None, - functions=args.get("functions", None), - system_message=self.system_message, - ) - self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") + return resp_text + async def _stream_chat(self, messages: List[ChatMessage], options): + args = self.collect_args(options) async with self._client_session.post( f"{SERVER_URL}/stream_chat", json={"messages": messages, **args}, headers=self.get_headers(), ) as resp: - # This is streaming application/json instaed of text/event-stream - completion = "" if resp.status != 200: raise Exception(await resp.text()) + async for line in resp.content.iter_chunks(): if line[1]: try: @@ -109,8 +81,7 @@ class ProxyServer(LLM): if chunk.strip() != "": loaded_chunk = json.loads(chunk) yield loaded_chunk - if "content" in loaded_chunk: - completion += loaded_chunk["content"] + except Exception as e: posthog_logger.capture_event( "proxy_server_parse_error", @@ -124,37 +95,18 @@ class ProxyServer(LLM): else: break - self.write_log(f"Completion: \n\n{completion}") - - async def _stream_complete( - self, prompt, with_history: List[ChatMessage] = None, **kwargs 
- ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.collect_args(**kwargs) - messages = compile_chat_messages( - self.model, - with_history, - self.context_length, - args["max_tokens"], - prompt, - functions=args.get("functions", None), - system_message=self.system_message, - ) - self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") + async def _stream_complete(self, prompt, options): + args = self.collect_args(options) async with self._client_session.post( f"{SERVER_URL}/stream_complete", - json={"messages": messages, **args}, + json={"messages": [{"role": "user", "content": prompt}], **args}, headers=self.get_headers(), ) as resp: - completion = "" if resp.status != 200: raise Exception(await resp.text()) + async for line in resp.content.iter_any(): if line: - try: - decoded_line = line.decode("utf-8") - yield decoded_line - completion += decoded_line - except: - raise Exception(str(line)) - self.write_log(f"Completion: \n\n{completion}") + decoded_line = line.decode("utf-8") + yield decoded_line diff --git a/continuedev/src/continuedev/libs/llm/replicate.py b/continuedev/src/continuedev/libs/llm/replicate.py index fb0d3f5c..1ed493c1 100644 --- a/continuedev/src/continuedev/libs/llm/replicate.py +++ b/continuedev/src/continuedev/libs/llm/replicate.py @@ -5,6 +5,7 @@ import replicate from ...core.main import ChatMessage from . import LLM +from .prompts.edit import simplified_edit_prompt class ReplicateLLM(LLM): @@ -15,13 +16,15 @@ class ReplicateLLM(LLM): _client: replicate.Client = None + prompt_templates = { + "edit": simplified_edit_prompt, + } + async def start(self, **kwargs): await super().start(**kwargs) self._client = replicate.Client(api_token=self.api_key) - async def _complete( - self, prompt: str, with_history: List[ChatMessage] = None, **kwargs - ): + async def _complete(self, prompt: str, options): def helper(): output = self._client.run( self.model, input={"message": prompt, "prompt": prompt} @@ -38,17 +41,18 @@ class ReplicateLLM(LLM): return completion - async def _stream_complete( - self, prompt, with_history: List[ChatMessage] = None, **kwargs - ): + async def _stream_complete(self, prompt, options): for item in self._client.run( self.model, input={"message": prompt, "prompt": prompt} ): yield item - async def _stream_chat(self, messages: List[ChatMessage] = None, **kwargs): + async def _stream_chat(self, messages: List[ChatMessage], options): for item in self._client.run( self.model, - input={"message": messages[-1].content, "prompt": messages[-1].content}, + input={ + "message": messages[-1]["content"], + "prompt": messages[-1]["content"], + }, ): yield {"content": item, "role": "assistant"} diff --git a/continuedev/src/continuedev/libs/llm/text_gen_interface.py b/continuedev/src/continuedev/libs/llm/text_gen_interface.py index 59627629..e37366c7 100644 --- a/continuedev/src/continuedev/libs/llm/text_gen_interface.py +++ b/continuedev/src/continuedev/libs/llm/text_gen_interface.py @@ -1,42 +1,37 @@ import json -from typing import Any, Coroutine, Dict, Generator, List, Optional, Union +from typing import Any, List, Optional import websockets from ...core.main import ChatMessage -from ..util.count_tokens import compile_chat_messages, format_chat_messages from . 
import LLM +from .prompts.edit import simplified_edit_prompt class TextGenUI(LLM): - # this is model-specific model: str = "text-gen-ui" server_url: str = "http://localhost:5000" streaming_url: str = "http://localhost:5005" verify_ssl: Optional[bool] = None + prompt_templates = { + "edit": simplified_edit_prompt, + } + class Config: arbitrary_types_allowed = True - def _transform_args(self, args): - args = { - **args, - "max_new_tokens": args.get("max_tokens", 1024), - } + def collect_args(self, options) -> Any: + args = super().collect_args(options) + args = {**args, "max_new_tokens": options.max_tokens} args.pop("max_tokens", None) return args - async def _stream_complete( - self, prompt, with_history: List[ChatMessage] = None, **kwargs - ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.collect_args(**kwargs) - args["stream"] = True - - self.write_log(f"Prompt: \n\n{prompt}") - completion = "" + async def _stream_complete(self, prompt, options): + args = self.collect_args(options) ws_url = f"{self.streaming_url.replace('http://', 'ws://').replace('https://', 'wss://')}" - payload = json.dumps({"prompt": prompt, **self._transform_args(args)}) + payload = json.dumps({"prompt": prompt, "stream": True, **args}) async with websockets.connect( f"{ws_url}/api/v1/stream", ping_interval=None ) as websocket: @@ -48,27 +43,12 @@ class TextGenUI(LLM): match incoming_data["event"]: case "text_stream": - completion += incoming_data["text"] yield incoming_data["text"] case "stream_end": break - self.write_log(f"Completion: \n\n{completion}") - - async def _stream_chat( - self, messages: List[ChatMessage] = None, **kwargs - ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.collect_args(**kwargs) - messages = compile_chat_messages( - self.model, - messages, - self.context_length, - args["max_tokens"], - None, - functions=args.get("functions", None), - system_message=self.system_message, - ) - args["stream"] = True + async def _stream_chat(self, messages: List[ChatMessage], options): + args = self.collect_args(options) async def generator(): ws_url = f"{self.streaming_url.replace('http://', 'ws://').replace('https://', 'wss://')}" @@ -77,7 +57,8 @@ class TextGenUI(LLM): { "user_input": messages[-1]["content"], "history": {"internal": [history], "visible": [history]}, - **self._transform_args(args), + "stream": True, + **args, } ) async with websockets.connect( @@ -102,26 +83,5 @@ class TextGenUI(LLM): case "stream_end": break - # Because quite often the first attempt fails, and it works thereafter - self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") - completion = "" async for chunk in generator(): yield chunk - if "content" in chunk: - completion += chunk["content"] - - self.write_log(f"Completion: \n\n{completion}") - - async def _complete( - self, prompt: str, with_history: List[ChatMessage] = None, **kwargs - ) -> Coroutine[Any, Any, str]: - generator = self._stream_chat( - [ChatMessage(role="user", content=prompt, summary=prompt)], **kwargs - ) - - completion = "" - async for chunk in generator: - if "content" in chunk: - completion += chunk["content"] - - return completion diff --git a/continuedev/src/continuedev/libs/llm/together.py b/continuedev/src/continuedev/libs/llm/together.py index d8c7334b..03c9cce4 100644 --- a/continuedev/src/continuedev/libs/llm/together.py +++ b/continuedev/src/continuedev/libs/llm/together.py @@ -1,16 +1,14 @@ import json -from typing import Any, Coroutine, Dict, Generator, List, Optional, Union +from typing import 
Callable, Optional import aiohttp -from ...core.main import ChatMessage from ..llm import LLM -from ..util.count_tokens import compile_chat_messages from .prompts.chat import llama2_template_messages +from .prompts.edit import simplified_edit_prompt class TogetherLLM(LLM): - # this is model-specific api_key: str "Together API key" @@ -20,61 +18,32 @@ class TogetherLLM(LLM): _client_session: aiohttp.ClientSession = None + template_messages: Callable = llama2_template_messages + + prompt_templates = { + "edit": simplified_edit_prompt, + } + async def start(self, **kwargs): await super().start(**kwargs) self._client_session = aiohttp.ClientSession( - connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl) + connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl), + timeout=aiohttp.ClientTimeout(total=self.timeout), ) async def stop(self): await self._client_session.close() - async def _stream_complete( - self, prompt, with_history: List[ChatMessage] = None, **kwargs - ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.collect_args(**kwargs) - args["stream_tokens"] = True - - messages = compile_chat_messages( - self.model, - with_history, - self.context_length, - args["max_tokens"], - prompt, - functions=args.get("functions", None), - system_message=self.system_message, - ) - - async with self._client_session.post( - f"{self.base_url}/inference", - json={"prompt": llama2_template_messages(messages), **args}, - headers={"Authorization": f"Bearer {self.api_key}"}, - ) as resp: - async for line in resp.content.iter_any(): - if line: - try: - yield line.decode("utf-8") - except: - raise Exception(str(line)) - - async def _stream_chat( - self, messages: List[ChatMessage] = None, **kwargs - ) -> Generator[Union[Any, List, Dict], None, None]: - args = self.collect_args(**kwargs) - messages = compile_chat_messages( - self.model, - messages, - self.context_length, - args["max_tokens"], - None, - functions=args.get("functions", None), - system_message=self.system_message, - ) - args["stream_tokens"] = True + async def _stream_complete(self, prompt, options): + args = self.collect_args(options) async with self._client_session.post( f"{self.base_url}/inference", - json={"prompt": llama2_template_messages(messages), **args}, + json={ + "prompt": prompt, + "stream_tokens": True, + **args, + }, headers={"Authorization": f"Bearer {self.api_key}"}, ) as resp: async for line in resp.content.iter_chunks(): @@ -92,36 +61,19 @@ class TogetherLLM(LLM): chunk = chunk[6:] json_chunk = json.loads(chunk) if "choices" in json_chunk: - yield { - "role": "assistant", - "content": json_chunk["choices"][0]["text"], - } + yield json_chunk["choices"][0]["text"] - async def _complete( - self, prompt: str, with_history: List[ChatMessage] = None, **kwargs - ) -> Coroutine[Any, Any, str]: - args = self.collect_args(**kwargs) + async def _complete(self, prompt: str, options): + args = self.collect_args(options) - messages = compile_chat_messages( - args["model"], - with_history, - self.context_length, - args["max_tokens"], - prompt, - functions=None, - system_message=self.system_message, - ) async with self._client_session.post( f"{self.base_url}/inference", - json={"prompt": llama2_template_messages(messages), **args}, + json={"prompt": prompt, **args}, headers={"Authorization": f"Bearer {self.api_key}"}, ) as resp: - try: - text = await resp.text() - j = json.loads(text) - if "choices" not in j["output"]: - raise Exception(text) - if "output" in j: - return j["output"]["choices"][0]["text"] - except: - raise 
Exception(await resp.text()) + text = await resp.text() + j = json.loads(text) + if "choices" not in j["output"]: + raise Exception(text) + if "output" in j: + return j["output"]["choices"][0]["text"] diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py index 2663aa1c..aaa32907 100644 --- a/continuedev/src/continuedev/libs/util/count_tokens.py +++ b/continuedev/src/continuedev/libs/util/count_tokens.py @@ -12,13 +12,10 @@ aliases = { "ggml": "gpt-3.5-turbo", "claude-2": "gpt-3.5-turbo", } -DEFAULT_MAX_TOKENS = 2048 +DEFAULT_MAX_TOKENS = 1024 DEFAULT_ARGS = { "max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.5, - "top_p": 1, - "frequency_penalty": 0, - "presence_penalty": 0, } @@ -144,13 +141,14 @@ def compile_chat_messages( """ The total number of tokens is system_message + sum(msgs) + functions + prompt after it is converted to a message """ + msgs_copy = [msg.copy(deep=True) for msg in msgs] if msgs is not None else [] if prompt is not None: prompt_msg = ChatMessage(role="user", content=prompt, summary=prompt) msgs_copy += [prompt_msg] - if system_message is not None: + if system_message is not None and system_message.strip() != "": # NOTE: System message takes second precedence to user prompt, so it is placed just before # but move back to start after processing rendered_system_message = render_templated_string(system_message) @@ -168,6 +166,11 @@ def compile_chat_messages( for function in functions: function_tokens += count_tokens(model_name, json.dumps(function)) + if max_tokens + function_tokens + TOKEN_BUFFER_FOR_SAFETY >= context_length: + raise ValueError( + f"max_tokens ({max_tokens}) is too close to context_length ({context_length}), which doesn't leave room for chat history. This would cause incoherent responses. Try increasing the context_length parameter of the model in your config file." 
+ ) + msgs_copy = prune_chat_history( model_name, msgs_copy, diff --git a/continuedev/src/continuedev/libs/util/edit_config.py b/continuedev/src/continuedev/libs/util/edit_config.py index 45a4a599..7c8ee76f 100644 --- a/continuedev/src/continuedev/libs/util/edit_config.py +++ b/continuedev/src/continuedev/libs/util/edit_config.py @@ -80,9 +80,13 @@ filtered_attrs = { } +def escape_string(string: str) -> str: + return string.replace('"', '\\"').replace("'", "\\'") + + def display_val(v: Any): if isinstance(v, str): - return f'"{v}"' + return f'"{escape_string(v)}"' return str(v) @@ -103,6 +107,7 @@ def create_obj_node(class_name: str, args: Dict[str, str]) -> redbaron.RedBaron: def create_string_node(string: str) -> redbaron.RedBaron: + string = escape_string(string) if "\n" in string: return redbaron.RedBaron(f'"""{string}"""')[0] return redbaron.RedBaron(f'"{string}"')[0] diff --git a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py index fe049268..43a2b800 100644 --- a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py +++ b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py @@ -30,7 +30,7 @@ class SetupPipelineStep(Step): sdk.context.set("api_description", self.api_description) source_name = ( - await sdk.models.medium._complete( + await sdk.models.medium.complete( f"Write a snake_case name for the data source described by {self.api_description}: " ) ).strip() @@ -115,7 +115,7 @@ class ValidatePipelineStep(Step): if "Traceback" in output or "SyntaxError" in output: output = "Traceback" + output.split("Traceback")[-1] file_content = await sdk.ide.readFile(os.path.join(workspace_dir, filename)) - suggestion = await sdk.models.medium._complete( + suggestion = await sdk.models.medium.complete( dedent( f"""\ ```python @@ -131,7 +131,7 @@ class ValidatePipelineStep(Step): ) ) - api_documentation_url = await sdk.models.medium._complete( + api_documentation_url = await sdk.models.medium.complete( dedent( f"""\ The API I am trying to call is the '{sdk.context.get('api_description')}'. 
I tried calling it in the @resource function like this: @@ -216,7 +216,7 @@ class RunQueryStep(Step): ) if "Traceback" in output or "SyntaxError" in output: - suggestion = await sdk.models.medium._complete( + suggestion = await sdk.models.medium.complete( dedent( f"""\ ```python diff --git a/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py index 44065d22..d6769148 100644 --- a/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py +++ b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py @@ -92,7 +92,7 @@ class LoadDataStep(Step): docs = f.read() output = "Traceback" + output.split("Traceback")[-1] - suggestion = await sdk.models.default._complete( + suggestion = await sdk.models.default.complete( dedent( f"""\ When trying to load data into BigQuery, the following error occurred: diff --git a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py index 4727c994..e2712746 100644 --- a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py +++ b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py @@ -45,7 +45,7 @@ class WritePytestsRecipe(Step): Here is a complete set of pytest unit tests:""" ) - tests = await sdk.models.medium._complete(prompt) + tests = await sdk.models.medium.complete(prompt) await sdk.apply_filesystem_edit(AddFile(filepath=path, content=tests)) diff --git a/continuedev/src/continuedev/plugins/steps/README.md b/continuedev/src/continuedev/plugins/steps/README.md index a248f19c..3f2f804c 100644 --- a/continuedev/src/continuedev/plugins/steps/README.md +++ b/continuedev/src/continuedev/plugins/steps/README.md @@ -33,7 +33,7 @@ If you'd like to override the default description of your step, which is just th - Return a static string - Store state in a class attribute (prepend with a double underscore, which signifies (through Pydantic) that this is not a parameter for the Step, just internal state) during the run method, and then grab this in the describe method. -- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.medium._complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`. +- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.medium.complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`. 
Here's an example: diff --git a/continuedev/src/continuedev/plugins/steps/chat.py b/continuedev/src/continuedev/plugins/steps/chat.py index d580f886..15740057 100644 --- a/continuedev/src/continuedev/plugins/steps/chat.py +++ b/continuedev/src/continuedev/plugins/steps/chat.py @@ -83,7 +83,7 @@ class SimpleChatStep(Step): messages = self.messages or await sdk.get_chat_context() - generator = sdk.models.chat._stream_chat( + generator = sdk.models.chat.stream_chat( messages, temperature=sdk.config.temperature ) @@ -118,7 +118,7 @@ class SimpleChatStep(Step): await sdk.update_ui() self.name = add_ellipsis( remove_quotes_and_escapes( - await sdk.models.medium._complete( + await sdk.models.medium.complete( f'"{self.description}"\n\nPlease write a short title summarizing the message quoted above. Use no more than 10 words:', max_tokens=20, ) @@ -260,7 +260,7 @@ class ChatWithFunctions(Step): gpt350613 = OpenAI(model="gpt-3.5-turbo-0613") await sdk.start_model(gpt350613) - async for msg_chunk in gpt350613._stream_chat( + async for msg_chunk in gpt350613.stream_chat( await sdk.get_chat_context(), functions=functions ): if sdk.current_step_was_deleted(): diff --git a/continuedev/src/continuedev/plugins/steps/chroma.py b/continuedev/src/continuedev/plugins/steps/chroma.py index 9ee2a48d..25633942 100644 --- a/continuedev/src/continuedev/plugins/steps/chroma.py +++ b/continuedev/src/continuedev/plugins/steps/chroma.py @@ -58,7 +58,7 @@ class AnswerQuestionChroma(Step): Here is the answer:""" ) - answer = await sdk.models.medium._complete(prompt) + answer = await sdk.models.medium.complete(prompt) # Make paths relative to the workspace directory answer = answer.replace(await sdk.ide.getWorkspaceDirectory(), "") diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py index df29a01e..5325a918 100644 --- a/continuedev/src/continuedev/plugins/steps/core/core.py +++ b/continuedev/src/continuedev/plugins/steps/core/core.py @@ -107,7 +107,7 @@ class ShellCommandsStep(Step): return f"Error when running shell commands:\n```\n{self._err_text}\n```" cmds_str = "\n".join(self.cmds) - return await models.medium._complete( + return await models.medium.complete( f"{cmds_str}\n\nSummarize what was done in these shell commands, using markdown bullet points:" ) @@ -121,7 +121,7 @@ class ShellCommandsStep(Step): and output is not None and output_contains_error(output) ): - suggestion = await sdk.models.medium._complete( + suggestion = await sdk.models.medium.complete( dedent( f"""\ While running the command `{cmd}`, the following error occurred: @@ -220,7 +220,7 @@ class DefaultModelEditCodeStep(Step): self._new_contents.splitlines(), ) ) - description = await models.medium._complete( + description = await models.medium.complete( dedent( f"""\ Diff summary: "{self.user_input}" @@ -232,7 +232,7 @@ class DefaultModelEditCodeStep(Step): {self.summary_prompt}""" ) ) - name = await models.medium._complete( + name = await models.medium.complete( f"Write a very short title to describe this requested change (no quotes): '{self.user_input}'. 
This is the title:" ) self.name = remove_quotes_and_escapes(name) @@ -663,7 +663,7 @@ Please output the code to be inserted at the cursor in order to fulfill the user else: messages = rendered - generator = model_to_use._stream_chat( + generator = model_to_use.stream_chat( messages, temperature=sdk.config.temperature, max_tokens=max_tokens ) @@ -874,7 +874,7 @@ class ManualEditStep(ReversibleStep): return "Manual edit step" # TODO - only handling FileEdit here, but need all other types of FileSystemEdits # Also requires the merge_file_edit function - # return llm._complete(dedent(f"""This code was replaced: + # return llm.complete(dedent(f"""This code was replaced: # {self.edit_diff.backward.replacement} diff --git a/continuedev/src/continuedev/plugins/steps/help.py b/continuedev/src/continuedev/plugins/steps/help.py index c73d7eef..148dddb8 100644 --- a/continuedev/src/continuedev/plugins/steps/help.py +++ b/continuedev/src/continuedev/plugins/steps/help.py @@ -59,7 +59,7 @@ class HelpStep(Step): ChatMessage(role="user", content=prompt, summary="Help") ) messages = await sdk.get_chat_context() - generator = sdk.models.default._stream_chat(messages) + generator = sdk.models.default.stream_chat(messages) async for chunk in generator: if "content" in chunk: self.description += chunk["content"] diff --git a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py index 001876d0..721f1306 100644 --- a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py +++ b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py @@ -26,7 +26,7 @@ class NLMultiselectStep(Step): if first_try is not None: return first_try - gpt_parsed = await sdk.models.default._complete( + gpt_parsed = await sdk.models.default.complete( f"These are the available options are: [{', '.join(self.options)}]. The user requested {user_response}. This is the exact string from the options array that they selected:" ) return extract_option(gpt_parsed) or self.options[0] diff --git a/continuedev/src/continuedev/plugins/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py index cecd3a2a..ca15aaab 100644 --- a/continuedev/src/continuedev/plugins/steps/main.py +++ b/continuedev/src/continuedev/plugins/steps/main.py @@ -105,7 +105,7 @@ class FasterEditHighlightedCodeStep(Step): for rif in range_in_files: rif_dict[rif.filepath] = rif.contents - completion = await sdk.models.medium._complete(prompt) + completion = await sdk.models.medium.complete(prompt) # Temporarily doing this to generate description. 
self._prompt = prompt @@ -180,7 +180,7 @@ class StarCoderEditHighlightedCodeStep(Step): _prompt_and_completion: str = "" async def describe(self, models: Models) -> Coroutine[str, None, None]: - return await models.medium._complete( + return await models.medium.complete( f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:" ) @@ -213,7 +213,7 @@ class StarCoderEditHighlightedCodeStep(Step): segs = full_file_contents.split(rif.contents) prompt = f"<file_prefix>{segs[0]}<file_suffix>{segs[1]}" + prompt - completion = str(await sdk.models.starcoder._complete(prompt)) + completion = str(await sdk.models.starcoder.complete(prompt)) eot_token = "<|endoftext|>" completion = completion.removesuffix(eot_token) diff --git a/continuedev/src/continuedev/plugins/steps/react.py b/continuedev/src/continuedev/plugins/steps/react.py index 2ed2d3d7..a2612731 100644 --- a/continuedev/src/continuedev/plugins/steps/react.py +++ b/continuedev/src/continuedev/plugins/steps/react.py @@ -29,7 +29,7 @@ class NLDecisionStep(Step): Select the step which should be taken next to satisfy the user input. Say only the name of the selected step. You must choose one:""" ) - resp = (await sdk.models.medium._complete(prompt)).lower() + resp = (await sdk.models.medium.complete(prompt)).lower() step_to_run = None for step in self.steps: diff --git a/continuedev/src/continuedev/plugins/steps/search_directory.py b/continuedev/src/continuedev/plugins/steps/search_directory.py index 9317bfe1..04fb98b7 100644 --- a/continuedev/src/continuedev/plugins/steps/search_directory.py +++ b/continuedev/src/continuedev/plugins/steps/search_directory.py @@ -46,7 +46,7 @@ class WriteRegexPatternStep(Step): async def run(self, sdk: ContinueSDK): # Ask the user for a regex pattern - pattern = await sdk.models.medium._complete( + pattern = await sdk.models.medium.complete( dedent( f"""\ This is the user request: diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py index aea20ef9..f0c33929 100644 --- a/continuedev/src/continuedev/server/session_manager.py +++ b/continuedev/src/continuedev/server/session_manager.py @@ -10,7 +10,6 @@ from fastapi.websockets import WebSocketState from ..core.autopilot import Autopilot from ..core.main import FullState from ..libs.util.create_async_task import create_async_task -from ..libs.util.errors import SessionNotFound from ..libs.util.logging import logger from ..libs.util.paths import ( getSessionFilePath, @@ -124,7 +123,12 @@ class SessionManager: # Read and update the sessions list with open(getSessionsListFilePath(), "r") as f: - sessions_list = json.load(f) + try: + sessions_list = json.load(f) + except json.JSONDecodeError: + raise Exception( + f"It looks like there is a JSON formatting error in your sessions.json file ({getSessionsListFilePath()}). Please fix this before creating a new session." 
+ ) session_ids = [s["session_id"] for s in sessions_list] if session_id not in session_ids: diff --git a/continuedev/src/continuedev/tests/llm_test.py b/continuedev/src/continuedev/tests/llm_test.py new file mode 100644 index 00000000..91ddd33f --- /dev/null +++ b/continuedev/src/continuedev/tests/llm_test.py @@ -0,0 +1,192 @@ +import asyncio +import os +from functools import wraps + +import pytest +from dotenv import load_dotenv + +from continuedev.core.main import ChatMessage +from continuedev.libs.llm import LLM, CompletionOptions +from continuedev.libs.llm.anthropic import AnthropicLLM +from continuedev.libs.llm.ggml import GGML +from continuedev.libs.llm.openai import OpenAI +from continuedev.libs.llm.together import TogetherLLM +from continuedev.libs.util.count_tokens import DEFAULT_ARGS +from continuedev.tests.util.openai_mock import start_openai + +load_dotenv() + +TEST_PROMPT = "Output a single word, that being the capital of Japan:" +SPEND_MONEY = True + + +def start_model(model): + def write_log(msg: str): + pass + + asyncio.run(model.start(write_log=write_log, unique_id="test_unique_id")) + + +def async_test(func): + @wraps(func) + def wrapper(*args, **kwargs): + return asyncio.run(func(*args, **kwargs)) + + return wrapper + + +class TestBaseLLM: + model = "gpt-3.5-turbo" + context_length = 4096 + system_message = "test_system_message" + + def setup_class(cls): + cls.llm = LLM( + model=cls.model, + context_length=cls.context_length, + system_message=cls.system_message, + ) + + start_model(cls.llm) + + def test_llm_is_instance(self): + assert isinstance(self.llm, LLM) + + def test_llm_collect_args(self): + options = CompletionOptions(model=self.model) + assert self.llm.collect_args(options) == { + **DEFAULT_ARGS, + "model": self.model, + } + + @pytest.mark.skipif(SPEND_MONEY is False, reason="Costs money") + @async_test + async def test_completion(self): + if self.llm.__class__.__name__ == "LLM": + pytest.skip("Skipping abstract LLM") + + resp = await self.llm.complete(TEST_PROMPT, temperature=0.0) + assert isinstance(resp, str) + assert resp.strip().lower() == "tokyo" + + @pytest.mark.skipif(SPEND_MONEY is False, reason="Costs money") + @async_test + async def test_stream_chat(self): + if self.llm.__class__.__name__ == "LLM": + pytest.skip("Skipping abstract LLM") + + completion = "" + role = None + async for chunk in self.llm.stream_chat( + messages=[ + ChatMessage(role="user", content=TEST_PROMPT, summary=TEST_PROMPT) + ], + temperature=0.0, + ): + assert isinstance(chunk, dict) + if "content" in chunk: + completion += chunk["content"] + if "role" in chunk: + role = chunk["role"] + + assert role == "assistant" + assert completion.strip().lower() == "tokyo" + + @pytest.mark.skipif(SPEND_MONEY is False, reason="Costs money") + @async_test + async def test_stream_complete(self): + if self.llm.__class__.__name__ == "LLM": + pytest.skip("Skipping abstract LLM") + + completion = "" + async for chunk in self.llm.stream_complete(TEST_PROMPT, temperature=0.0): + assert isinstance(chunk, str) + completion += chunk + + assert completion.strip().lower() == "tokyo" + + +class TestOpenAI(TestBaseLLM): + def setup_class(cls): + super().setup_class(cls) + cls.llm = OpenAI( + model=cls.model, + context_length=cls.context_length, + system_message=cls.system_message, + api_key=os.environ["OPENAI_API_KEY"], + # api_base=f"http://localhost:{port}", + ) + start_model(cls.llm) + # cls.server = start_openai(port=port) + + # def teardown_class(cls): + # cls.server.terminate() + + 
@pytest.mark.asyncio + @pytest.mark.skipif(SPEND_MONEY is False, reason="Costs money") + async def test_completion(self): + resp = await self.llm.complete( + "Output a single word, that being the capital of Japan:" + ) + assert isinstance(resp, str) + assert resp.strip().lower() == "tokyo" + + +class TestGGML(TestBaseLLM): + def setup_class(cls): + super().setup_class(cls) + port = 8000 + cls.llm = GGML( + model=cls.model, + context_length=cls.context_length, + system_message=cls.system_message, + api_base=f"http://localhost:{port}", + ) + start_model(cls.llm) + cls.server = start_openai(port=port) + + def teardown_class(cls): + cls.server.terminate() + + @pytest.mark.asyncio + async def test_stream_chat(self): + pytest.skip(reason="GGML is not working") + + @pytest.mark.asyncio + async def test_stream_complete(self): + pytest.skip(reason="GGML is not working") + + @pytest.mark.asyncio + async def test_completion(self): + pytest.skip(reason="GGML is not working") + + +@pytest.mark.skipif(True, reason="Together is not working") +class TestTogetherLLM(TestBaseLLM): + def setup_class(cls): + super().setup_class(cls) + cls.llm = TogetherLLM( + api_key=os.environ["TOGETHER_API_KEY"], + ) + start_model(cls.llm) + + +class TestAnthropicLLM(TestBaseLLM): + def setup_class(cls): + super().setup_class(cls) + cls.llm = AnthropicLLM(api_key=os.environ["ANTHROPIC_API_KEY"]) + start_model(cls.llm) + + def test_llm_collect_args(self): + options = CompletionOptions(model=self.model) + assert self.llm.collect_args(options) == { + "max_tokens_to_sample": DEFAULT_ARGS["max_tokens"], + "temperature": DEFAULT_ARGS["temperature"], + "model": self.model, + } + + +if __name__ == "__main__": + import pytest + + pytest.main() diff --git a/continuedev/src/continuedev/tests/util/openai_mock.py b/continuedev/src/continuedev/tests/util/openai_mock.py new file mode 100644 index 00000000..763c5647 --- /dev/null +++ b/continuedev/src/continuedev/tests/util/openai_mock.py @@ -0,0 +1,139 @@ +import asyncio +import os +import random +import subprocess +from typing import Dict, List, Optional + +from fastapi import FastAPI +from fastapi.responses import StreamingResponse +from pydantic import BaseModel + +openai = FastAPI() + + +class CompletionBody(BaseModel): + prompt: str + max_tokens: Optional[int] = 60 + stream: Optional[bool] = False + + class Config: + extra = "allow" + + +@openai.post("/completions") +@openai.post("/v1/completions") +async def mock_completion(item: CompletionBody): + prompt = item.prompt + + text = "This is a fake completion." 
+ + if item.stream: + + async def stream_text(): + for i in range(len(text)): + word = random.choice(prompt.split()) + yield { + "choices": [ + { + "delta": {"role": "assistant", "content": word}, + "finish_reason": None, + "index": 0, + } + ], + "created": 1677825464, + "id": "chatcmpl-6ptKyqKOGXZT6iQnqiXAH8adNLUzD", + "model": "gpt-3.5-turbo-0301", + "object": "chat.completion.chunk", + } + await asyncio.sleep(0.1) + + return StreamingResponse(stream_text(), media_type="text/plain") + + return { + "id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7", + "object": "text_completion", + "created": 1589478378, + "model": "gpt-3.5-turbo", + "choices": [ + { + "text": text, + "index": 0, + "logprobs": None, + "finish_reason": "length", + } + ], + "usage": {"prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12}, + } + + +class ChatBody(BaseModel): + messages: List[Dict[str, str]] + max_tokens: Optional[int] = None + stream: Optional[bool] = False + + class Config: + extra = "allow" + + +@openai.post("/v1/chat/completions") +async def mock_chat_completion(item: ChatBody): + text = "This is a fake completion." + + if item.stream: + + async def stream_text(): + for i in range(len(text)): + word = text[i] + yield { + "choices": [ + { + "delta": {"role": "assistant", "content": word}, + "finish_reason": None, + "index": 0, + } + ], + "created": 1677825464, + "id": "chatcmpl-6ptKyqKOGXZT6iQnqiXAH8adNLUzD", + "model": "gpt-3.5-turbo-0301", + "object": "chat.completion.chunk", + } + await asyncio.sleep(0.1) + + return StreamingResponse(stream_text(), media_type="text/plain") + + return { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-3.5-turbo-0613", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": text, + }, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21}, + } + + +def start_openai(port: int = 8000): + server = subprocess.Popen( + [ + "uvicorn", + "openai_mock:openai", + "--host", + "127.0.0.1", + "--port", + str(port), + ], + cwd=os.path.dirname(__file__), + ) + return server + + +if __name__ == "__main__": + start_openai() diff --git a/docs/docs/concepts/sdk.md b/docs/docs/concepts/sdk.md index ac2bc8bf..21190aa8 100644 --- a/docs/docs/concepts/sdk.md +++ b/docs/docs/concepts/sdk.md @@ -23,7 +23,7 @@ The **Continue SDK** gives you all the tools you need to automate software devel ### `sdk.models`
-`sdk.models` is an instance of the `Models` class, containing many of the most commonly used LLMs or other foundation models. You can access a model (starcoder for example) like `starcoder = sdk.models.starcoder`. Right now, all of the models are `LLM`s, meaning that they offer the `complete` method, used like `bubble_sort_code = await starcoder._complete("# Write a bubble sort function below, in Python:\n")`.
+`sdk.models` is an instance of the `Models` class, containing many of the most commonly used LLMs or other foundation models. You can access a model (starcoder for example) like `starcoder = sdk.models.starcoder`. Right now, all of the models are `LLM`s, meaning that they offer the `complete` method, used like `bubble_sort_code = await starcoder.complete("# Write a bubble sort function below, in Python:\n")`.
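For context, here is a minimal sketch of how a custom Step might call this API after the rename. Only `sdk.models.default.complete` and the pattern of setting `self.description` are taken from the changes in this commit; the `Step`/`ContinueSDK` import paths and the step name are assumptions for illustration, not the library's documented API.

```python
from continuedev.core.main import Step  # assumed import path
from continuedev.core.sdk import ContinueSDK  # assumed import path


class ExplainCodeStep(Step):
    """Hypothetical step: ask the default model to explain a snippet."""

    code: str

    async def run(self, sdk: ContinueSDK):
        # The public `complete` wrapper replaces the old private `_complete` call.
        self.description = await sdk.models.default.complete(
            f"{self.code}\n\nExplain the code above in one short paragraph:"
        )
```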
### `sdk.history`
diff --git a/docs/docs/customization.md b/docs/docs/customization.md index 09f7ed46..2d1d8ba4 100644 --- a/docs/docs/customization.md +++ b/docs/docs/customization.md @@ -292,7 +292,7 @@ class CommitMessageStep(Step): # Ask the LLM to write a commit message, # and set it as the description of this step - self.description = await sdk.models.default._complete( + self.description = await sdk.models.default.complete( f"{diff}\n\nWrite a short, specific (less than 50 chars) commit message about the above changes:") config=ContinueConfig( diff --git a/docs/docs/walkthroughs/create-a-recipe.md b/docs/docs/walkthroughs/create-a-recipe.md index cc80be0e..2cb28f77 100644 --- a/docs/docs/walkthroughs/create-a-recipe.md +++ b/docs/docs/walkthroughs/create-a-recipe.md @@ -31,7 +31,7 @@ If you'd like to override the default description of your steps, which is just t - Return a static string
- Store state in a class attribute (prepend with a double underscore, which signifies (through Pydantic) that this is not a parameter for the Step, just internal state) during the run method, and then grab this in the describe method.
-- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.medium._complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`.
-- Use state in conjunction with the `models` parameter of the describe method to autogenerate a description with a language model. For example, if you'd used an attribute called `__code_written` to store a string representing some code that was written, you could implement describe as `return models.medium._complete(f"{self.\_\_code_written}\n\nSummarize the changes made in the above code.")`.
## 2. Compose steps together into a complete recipe
diff --git a/extension/react-app/src/components/ModelSelect.tsx b/extension/react-app/src/components/ModelSelect.tsx index 1cbf3f0e..83f005c7 100644 --- a/extension/react-app/src/components/ModelSelect.tsx +++ b/extension/react-app/src/components/ModelSelect.tsx @@ -13,21 +13,13 @@ import { useSelector } from "react-redux"; const MODEL_INFO: { title: string; class: string; args: any }[] = [ { title: "gpt-4", - class: "MaybeProxyOpenAI", + class: "OpenAI", args: { model: "gpt-4", api_key: "", }, }, { - title: "gpt-3.5-turbo", - class: "MaybeProxyOpenAI", - args: { - model: "gpt-3.5-turbo", - api_key: "", - }, - }, - { title: "claude-2", class: "AnthropicLLM", args: { |