-rw-r--r--  continuedev/src/continuedev/core/sdk.py                9
-rw-r--r--  continuedev/src/continuedev/libs/llm/__init__.py      13
-rw-r--r--  continuedev/src/continuedev/libs/llm/proxy_server.py  84
-rw-r--r--  extension/package.json                                  2
-rw-r--r--  extension/src/continueIdeClient.ts                     12
5 files changed, 104 insertions, 16 deletions
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 1da190ff..b806ef73 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -15,6 +15,7 @@ from .observation import Observation
from ..server.ide_protocol import AbstractIdeProtocolServer
from .main import Context, ContinueCustomException, History, Step, ChatMessage, ChatMessageRole
from ..steps.core.core import *
+from ..libs.llm.proxy_server import ProxyServer
class Autopilot:
@@ -37,7 +38,9 @@ class Models:
def gpt35(self):
async def load_gpt35():
api_key = await self.sdk.get_user_secret(
- 'OPENAI_API_KEY', 'Please add your OpenAI API key to the .env file')
+ 'OPENAI_API_KEY', 'Enter your OpenAI API key, OR press enter to try for free')
+ if api_key == "":
+ return ProxyServer(self.sdk.ide.unique_id, "gpt-3.5-turbo")
return OpenAI(api_key=api_key, default_model="gpt-3.5-turbo")
return asyncio.get_event_loop().run_until_complete(load_gpt35())
@@ -45,7 +48,9 @@ class Models:
def gpt4(self):
async def load_gpt4():
api_key = await self.sdk.get_user_secret(
- 'OPENAI_API_KEY', 'Please add your OpenAI API key to the .env file')
+ 'OPENAI_API_KEY', 'Enter your OpenAI API key, OR press enter to try for free')
+ if api_key == "":
+ return ProxyServer(self.sdk.ide.unique_id, "gpt-4")
return OpenAI(api_key=api_key, default_model="gpt-4")
return asyncio.get_event_loop().run_until_complete(load_gpt4())
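The sdk.py change turns the API-key prompt into an opt-out: replying with an empty string now routes requests through Continue's hosted proxy instead of looping until a key is supplied. A minimal sketch of the selection rule; `pick_model` is a hypothetical helper, and the `OpenAI` import path (not shown in this diff) is assumed from the repo layout:

```python
# Hypothetical helper illustrating the fallback rule above; not SDK code.
from continuedev.libs.llm import LLM
from continuedev.libs.llm.openai import OpenAI  # assumed module path
from continuedev.libs.llm.proxy_server import ProxyServer

def pick_model(api_key: str, unique_id: str, model: str) -> LLM:
    if api_key == "":
        # Free-trial path: no key; requests go through the proxy server,
        # attributed only to the IDE's unique_id.
        return ProxyServer(unique_id, model)
    # Otherwise the user's own key talks to OpenAI directly.
    return OpenAI(api_key=api_key, default_model=model)
```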
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 4889c01e..2986b2c4 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -1,24 +1,25 @@
-from typing import List, Union
+from abc import ABC
+from typing import Any, Dict, Generator, List, Union
from ...core.main import ChatMessage
from ...models.main import AbstractModel
from pydantic import BaseModel
-class LLM(BaseModel):
+class LLM(ABC):
system_message: Union[str, None] = None
def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs):
"""Return the completion of the text with the given temperature."""
raise
+ def stream_chat(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ """Yield a stream of chat messages."""
+ raise NotImplementedError
+
def __call__(self, prompt: str, **kwargs):
return self.complete(prompt, **kwargs)
- def fine_tune(self, pairs: list):
- """Fine tune the model on the given prompt/completion pairs."""
- raise NotImplementedError
-
def with_system_message(self, system_message: Union[str, None]):
"""Return a new model with the given system message."""
raise NotImplementedError
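`LLM` moves from a pydantic `BaseModel` to a plain ABC, gains a `stream_chat` generator, and drops the `fine_tune` stub. Neither `complete` nor `stream_chat` is decorated with `@abstractmethod`, so the ABC documents the contract without enforcing it at instantiation time. A minimal conforming subclass for illustration (`EchoLLM` is hypothetical, and the import paths are assumed from the repo layout):

```python
# Hypothetical subclass showing the two methods a concrete LLM supplies.
from typing import Any, Dict, Generator, List, Union

from continuedev.core.main import ChatMessage  # assumed module path
from continuedev.libs.llm import LLM

class EchoLLM(LLM):
    def __init__(self, system_message: Union[str, None] = None):
        self.system_message = system_message

    def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> str:
        # A real model would call an API here; we just echo the prompt.
        return prompt

    def stream_chat(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
        # Yield one whitespace "token" at a time, mimicking streaming.
        for token in prompt.split():
            yield token
```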
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
new file mode 100644
index 00000000..f75788d2
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -0,0 +1,84 @@
+from functools import cached_property
+import json
+from typing import Any, Dict, Generator, List, Literal, Union
+import requests
+import tiktoken
+
+from ...core.main import ChatMessage
+from ..llm import LLM
+
+MAX_TOKENS_FOR_MODEL = {
+ "gpt-3.5-turbo": 4097,
+ "gpt-4": 4097,
+}
+DEFAULT_MAX_TOKENS = 2048
+CHAT_MODELS = {
+ "gpt-3.5-turbo", "gpt-4"
+}
+
+SERVER_URL = "http://127.0.0.1:8002"
+
+
+class ProxyServer(LLM):
+ unique_id: str
+ default_model: Literal["gpt-3.5-turbo", "gpt-4"]
+
+ def __init__(self, unique_id: str, default_model: Literal["gpt-3.5-turbo", "gpt-4"], system_message: str = None):
+ self.unique_id = unique_id
+ self.default_model = default_model
+ self.system_message = system_message
+
+ @cached_property
+ def __encoding_for_model(self):
+ aliases = {
+ "gpt-3.5-turbo": "gpt3"
+ }
+ return tiktoken.encoding_for_model(self.default_model)
+
+ def count_tokens(self, text: str):
+ return len(self.__encoding_for_model.encode(text, disallowed_special=()))
+
+ def stream_chat(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ resp = requests.post(f"{SERVER_URL}/stream_complete", json={
+ "chat_history": self.compile_chat_messages(with_history, prompt),
+ "model": self.default_model,
+ "unique_id": self.unique_id,
+ }, stream=True)
+ for line in resp.iter_lines():
+ if line:
+ yield line.decode("utf-8")
+
+ def __prune_chat_history(self, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int):
+ tokens = tokens_for_completion
+ for i in range(len(chat_history) - 1, -1, -1):
+ message = chat_history[i]
+ tokens += self.count_tokens(message.content)
+ if tokens > max_tokens:
+ return chat_history[i + 1:]
+ return chat_history
+
+ def compile_chat_messages(self, msgs: List[ChatMessage], prompt: str) -> List[Dict]:
+ msgs = self.__prune_chat_history(msgs, MAX_TOKENS_FOR_MODEL[self.default_model], self.count_tokens(
+ prompt) + 1000 + self.count_tokens(self.system_message or ""))
+ history = []
+ if self.system_message:
+ history.append({
+ "role": "system",
+ "content": self.system_message
+ })
+ history += [msg.dict() for msg in msgs]
+ history.append({
+ "role": "user",
+ "content": prompt
+ })
+
+ return history
+
+ def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> str:
+
+ resp = requests.post(f"{SERVER_URL}/complete", json={
+ "chat_history": self.compile_chat_messages(with_history, prompt),
+ "model": self.default_model,
+ "unique_id": self.unique_id,
+ })
+ return json.loads(resp.text)
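`ProxyServer` keeps every request inside the model's context window by pruning from the oldest end: `__prune_chat_history` walks the history newest-first, accumulating token counts on top of a reserve of `count_tokens(prompt) + 1000 + count_tokens(system_message)`, and discards everything older than the point where the budget overflows, so `stream_chat` and `complete` always post a payload with roughly 1000 tokens of headroom for the reply. Two review notes: the `aliases` dict inside `__encoding_for_model` is never read (`tiktoken.encoding_for_model` resolves "gpt-3.5-turbo" on its own), and the gpt-4 entry in `MAX_TOKENS_FOR_MODEL` reuses the 4097-token gpt-3.5 limit, which is conservative for gpt-4's 8192-token window. A standalone sketch of the pruning rule, with plain dicts in place of `ChatMessage` and a whitespace tokenizer standing in for tiktoken:

```python
# Standalone sketch of the __prune_chat_history logic above, under simplified
# assumptions: dict messages and whitespace tokenization instead of tiktoken.
from typing import Dict, List

def prune(history: List[Dict[str, str]], max_tokens: int, reserved: int) -> List[Dict[str, str]]:
    tokens = reserved
    # Walk from the newest message backwards; once the budget is exceeded,
    # drop that message and everything older than it.
    for i in range(len(history) - 1, -1, -1):
        tokens += len(history[i]["content"].split())
        if tokens > max_tokens:
            return history[i + 1:]
    return history

history = [
    {"role": "user", "content": "a " * 50},       # oldest, 50 tokens
    {"role": "assistant", "content": "b " * 50},  # 50 tokens
    {"role": "user", "content": "c " * 50},       # newest, 50 tokens
]
# 10 reserved + 50 + 50 fits under 120; adding the oldest 50 would not.
assert prune(history, max_tokens=120, reserved=10) == history[1:]
```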
diff --git a/extension/package.json b/extension/package.json
index 13cc92a2..9ff7165c 100644
--- a/extension/package.json
+++ b/extension/package.json
@@ -96,7 +96,7 @@
{
"type": "webview",
"id": "continue.continueGUIView",
- "name": "Continue GUI",
+ "name": ")",
"visibility": "visible"
}
]
diff --git a/extension/src/continueIdeClient.ts b/extension/src/continueIdeClient.ts
index c879c682..ef9a91c8 100644
--- a/extension/src/continueIdeClient.ts
+++ b/extension/src/continueIdeClient.ts
@@ -192,15 +192,13 @@ class IdeProtocolClient {
async getUserSecret(key: string) {
// Check if secret already exists in VS Code settings (global)
let secret = vscode.workspace.getConfiguration("continue").get(key);
- if (secret && secret !== "") return secret;
+ if (typeof secret !== "undefined") return secret;
// If not, ask user for secret
- while (typeof secret === "undefined" || secret === "") {
- secret = await vscode.window.showInputBox({
- prompt: `Enter secret for ${key}. You can edit this later in the Continue VS Code settings.`,
- password: true,
- });
- }
+ secret = await vscode.window.showInputBox({
+ prompt: `Enter secret for ${key}, OR press enter to try for free. You can edit this later in the Continue VS Code settings.`,
+ password: true,
+ });
// Add secret to VS Code settings
vscode.workspace
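The TypeScript side completes the loop: `getUserSecret` no longer re-prompts until the input is non-empty, so pressing Enter resolves to `""`, exactly the value that `load_gpt35`/`load_gpt4` in sdk.py interpret as "use the proxy". The guard on the stored setting also loosens from `secret && secret !== ""` to `typeof secret !== "undefined"`, so an empty string persisted in the Continue settings now short-circuits to the free-trial path instead of forcing a new prompt.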