author     Nate Sesti <sestinj@gmail.com>  2023-06-15 12:07:51 -0700
committer  Nate Sesti <sestinj@gmail.com>  2023-06-15 12:07:51 -0700
commit     cb518073b0e3711a946282864c150edff35db53a (patch)
tree       652fa8f8b7b7a894490fb6e9a657c820a08afbe9
parent     6de892f12959a43c74372f1eba40ec2f53b8c537 (diff)
proxy server running locally
-rw-r--r--  continuedev/src/continuedev/core/sdk.py               |  9
-rw-r--r--  continuedev/src/continuedev/libs/llm/__init__.py      | 13
-rw-r--r--  continuedev/src/continuedev/libs/llm/proxy_server.py  | 84
-rw-r--r--  extension/package.json                                |  2
-rw-r--r--  extension/src/continueIdeClient.ts                    | 12
5 files changed, 104 insertions, 16 deletions
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 1da190ff..b806ef73 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -15,6 +15,7 @@ from .observation import Observation
 from ..server.ide_protocol import AbstractIdeProtocolServer
 from .main import Context, ContinueCustomException, History, Step, ChatMessage, ChatMessageRole
 from ..steps.core.core import *
+from ..libs.llm.proxy_server import ProxyServer
 
 
 class Autopilot:
@@ -37,7 +38,9 @@ class Models:
     def gpt35(self):
         async def load_gpt35():
             api_key = await self.sdk.get_user_secret(
-                'OPENAI_API_KEY', 'Please add your OpenAI API key to the .env file')
+                'OPENAI_API_KEY', 'Enter your OpenAI API key, OR press enter to try for free')
+            if api_key == "":
+                return ProxyServer(self.sdk.ide.unique_id, "gpt-3.5-turbo")
             return OpenAI(api_key=api_key, default_model="gpt-3.5-turbo")
         return asyncio.get_event_loop().run_until_complete(load_gpt35())
 
@@ -45,7 +48,9 @@
     def gpt4(self):
         async def load_gpt4():
             api_key = await self.sdk.get_user_secret(
-                'OPENAI_API_KEY', 'Please add your OpenAI API key to the .env file')
+                'OPENAI_API_KEY', 'Enter your OpenAI API key, OR press enter to try for free')
+            if api_key == "":
+                return ProxyServer(self.sdk.ide.unique_id, "gpt-4")
             return OpenAI(api_key=api_key, default_model="gpt-4")
         return asyncio.get_event_loop().run_until_complete(load_gpt4())
 
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 4889c01e..2986b2c4 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -1,24 +1,25 @@
-from typing import List, Union
+from abc import ABC
+from typing import Any, Dict, Generator, List, Union
 
 from ...core.main import ChatMessage
 from ...models.main import AbstractModel
 from pydantic import BaseModel
 
 
-class LLM(BaseModel):
+class LLM(ABC):
     system_message: Union[str, None] = None
 
     def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs):
         """Return the completion of the text with the given temperature."""
         raise
 
+    def stream_chat(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        """Yield a stream of chat messages."""
+        raise NotImplementedError
+
     def __call__(self, prompt: str, **kwargs):
         return self.complete(prompt, **kwargs)
 
-    def fine_tune(self, pairs: list):
-        """Fine tune the model on the given prompt/completion pairs."""
-        raise NotImplementedError
-
     def with_system_message(self, system_message: Union[str, None]):
         """Return a new model with the given system message."""
         raise NotImplementedError
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
new file
index 00000000..f75788d2
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -0,0 +1,84 @@
+from functools import cached_property
+import json
+from typing import Any, Dict, Generator, List, Literal, Union
+import requests
+import tiktoken
+
+from ...core.main import ChatMessage
+from ..llm import LLM
+
+MAX_TOKENS_FOR_MODEL = {
+    "gpt-3.5-turbo": 4097,
+    "gpt-4": 4097,
+}
+DEFAULT_MAX_TOKENS = 2048
+CHAT_MODELS = {
+    "gpt-3.5-turbo", "gpt-4"
+}
+
+SERVER_URL = "http://127.0.0.1:8002"
+
+
+class ProxyServer(LLM):
+    unique_id: str
+    default_model: Literal["gpt-3.5-turbo", "gpt-4"]
+
+    def __init__(self, unique_id: str, default_model: Literal["gpt-3.5-turbo", "gpt-4"], system_message: str = None):
+        self.unique_id = unique_id
+        self.default_model = default_model
+        self.system_message = system_message
+
+    @cached_property
+    def __encoding_for_model(self):
+        aliases = {
+            "gpt-3.5-turbo": "gpt3"
+        }
+        return tiktoken.encoding_for_model(self.default_model)
+
+    def count_tokens(self, text: str):
+        return len(self.__encoding_for_model.encode(text, disallowed_special=()))
+
+    def stream_chat(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        resp = requests.post(f"{SERVER_URL}/stream_complete", json={
+            "chat_history": self.compile_chat_messages(with_history, prompt),
+            "model": self.default_model,
+            "unique_id": self.unique_id,
+        }, stream=True)
+        for line in resp.iter_lines():
+            if line:
+                yield line.decode("utf-8")
+
+    def __prune_chat_history(self, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int):
+        tokens = tokens_for_completion
+        for i in range(len(chat_history) - 1, -1, -1):
+            message = chat_history[i]
+            tokens += self.count_tokens(message.content)
+            if tokens > max_tokens:
+                return chat_history[i + 1:]
+        return chat_history
+
+    def compile_chat_messages(self, msgs: List[ChatMessage], prompt: str) -> List[Dict]:
+        msgs = self.__prune_chat_history(msgs, MAX_TOKENS_FOR_MODEL[self.default_model], self.count_tokens(
+            prompt) + 1000 + self.count_tokens(self.system_message or ""))
+        history = []
+        if self.system_message:
+            history.append({
+                "role": "system",
+                "content": self.system_message
+            })
+        history += [msg.dict() for msg in msgs]
+        history.append({
+            "role": "user",
+            "content": prompt
+        })
+
+        return history
+
+    def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> str:
+
+        resp = requests.post(f"{SERVER_URL}/complete", json={
+            "chat_history": self.compile_chat_messages(with_history, prompt),
+            "model": self.default_model,
+            "unique_id": self.unique_id,
+        })
+        return json.loads(resp.text)
diff --git a/extension/package.json b/extension/package.json
index 13cc92a2..9ff7165c 100644
--- a/extension/package.json
+++ b/extension/package.json
@@ -96,7 +96,7 @@
       {
         "type": "webview",
         "id": "continue.continueGUIView",
-        "name": "Continue GUI",
+        "name": ")",
         "visibility": "visible"
       }
     ]
diff --git a/extension/src/continueIdeClient.ts b/extension/src/continueIdeClient.ts
index c879c682..ef9a91c8 100644
--- a/extension/src/continueIdeClient.ts
+++ b/extension/src/continueIdeClient.ts
@@ -192,15 +192,13 @@ class IdeProtocolClient {
   async getUserSecret(key: string) {
     // Check if secret already exists in VS Code settings (global)
     let secret = vscode.workspace.getConfiguration("continue").get(key);
-    if (secret && secret !== "") return secret;
+    if (typeof secret !== "undefined") return secret;
 
     // If not, ask user for secret
-    while (typeof secret === "undefined" || secret === "") {
-      secret = await vscode.window.showInputBox({
-        prompt: `Enter secret for ${key}. You can edit this later in the Continue VS Code settings.`,
-        password: true,
-      });
-    }
+    secret = await vscode.window.showInputBox({
+      prompt: `Enter secret for ${key}, OR press enter to try for free. You can edit this later in the Continue VS Code settings.`,
+      password: true,
+    });
 
     // Add secret to VS Code settings
     vscode.workspace
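The commit message mentions a proxy server running locally, but this commit only adds the client side: ProxyServer posts compiled chat messages to SERVER_URL (http://127.0.0.1:8002). The server itself is not included here, so the following is only a minimal sketch of a compatible endpoint pair, assuming FastAPI and the pre-1.0 OpenAI Python SDK; the request model, route bodies, and forwarding logic are guesses, not Continue's actual server.

import os

import openai  # pre-1.0 SDK, current when this commit landed
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

openai.api_key = os.environ["OPENAI_API_KEY"]  # server-side key (assumption)

app = FastAPI()


class CompletionRequest(BaseModel):
    # Mirrors the JSON body ProxyServer sends to both endpoints.
    chat_history: list
    model: str
    unique_id: str


@app.post("/complete")
async def complete(body: CompletionRequest):
    # Forward the compiled history to OpenAI and return the text; FastAPI
    # JSON-encodes the string, matching the client's json.loads(resp.text).
    resp = openai.ChatCompletion.create(model=body.model, messages=body.chat_history)
    return resp.choices[0].message.content


@app.post("/stream_complete")
async def stream_complete(body: CompletionRequest):
    # Emit one newline-terminated chunk per delta so the client's
    # resp.iter_lines() loop sees them as separate lines.
    def generate():
        for chunk in openai.ChatCompletion.create(
                model=body.model, messages=body.chat_history, stream=True):
            delta = chunk.choices[0].delta.get("content")
            if delta:
                yield delta + "\n"

    return StreamingResponse(generate(), media_type="text/plain")

Run with something like uvicorn server:app --port 8002 so it answers on the port the client hardcodes.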
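With a compatible server listening on 127.0.0.1:8002, the new class can be exercised directly. A usage sketch, where "user-123" stands in for the IDE's real unique_id:

from continuedev.src.continuedev.libs.llm.proxy_server import ProxyServer

llm = ProxyServer(unique_id="user-123", default_model="gpt-3.5-turbo")

# Blocking completion: POSTs to /complete and returns the decoded string.
print(llm.complete("Write a one-line docstring for a merge sort."))

# Streaming: POSTs to /stream_complete and yields lines as they arrive.
for chunk in llm.stream_chat("Summarize this change in one sentence."):
    print(chunk, end="", flush=True)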
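Worth noting is the pruning rule inside compile_chat_messages: __prune_chat_history scans from the newest message backwards, adds each message's token count to the tokens already reserved for the completion (prompt + 1000 + system message), and keeps only the newest suffix that fits the model's budget. A standalone toy version, with a word counter standing in for tiktoken:

def prune(history, max_tokens, tokens_for_completion, count):
    # Same loop shape as __prune_chat_history, minus the ChatMessage type.
    tokens = tokens_for_completion
    for i in range(len(history) - 1, -1, -1):
        tokens += count(history[i])
        if tokens > max_tokens:
            return history[i + 1:]
    return history


count_words = lambda text: len(text.split())
history = ["hello there friend", "how are you today", "fine thanks"]

# With a 10-token budget and 2 tokens reserved, the oldest message is dropped.
assert prune(history, 10, 2, count_words) == ["how are you today", "fine thanks"]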