author    Nate Sesti <sestinj@gmail.com>   2023-07-30 15:53:53 -0700
committer Nate Sesti <sestinj@gmail.com>   2023-07-30 15:53:53 -0700
commit    11d7f0a9d178b7ae8f913a2ad5e70d623ce4b11e (patch)
tree      292c6fb38cebf46cd647d6f5133c2c6246616572
parent    c57182b8533a2c86d465bbf21e3a357bda13bb41 (diff)
refactor: :construction: refactor so server runs until requesting model
-rw-r--r--  CONTRIBUTING.md                                                    |  1
-rw-r--r--  continuedev/src/continuedev/core/autopilot.py                      |  1
-rw-r--r--  continuedev/src/continuedev/core/main.py                           |  1
-rw-r--r--  continuedev/src/continuedev/core/models.py                         | 33
-rw-r--r--  continuedev/src/continuedev/libs/constants/default_config.py.txt   |  8
-rw-r--r--  continuedev/src/continuedev/libs/llm/__init__.py                   |  5
-rw-r--r--  continuedev/src/continuedev/libs/llm/anthropic.py                  | 27
-rw-r--r--  continuedev/src/continuedev/libs/llm/ggml.py                       |  6
-rw-r--r--  continuedev/src/continuedev/libs/llm/hf_inference_api.py           |  6
-rw-r--r--  continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py         | 38
-rw-r--r--  continuedev/src/continuedev/libs/llm/openai.py                     | 12
-rw-r--r--  continuedev/src/continuedev/libs/llm/proxy_server.py               | 27

12 files changed, 108 insertions, 57 deletions
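
The heart of the change: the config no longer carries a default_model string that the server must resolve at startup; it carries LLM objects, and the new MaybeProxyOpenAI wrapper defers the choice between a direct OpenAI client and the hosted proxy server until start() is called with (or without) an API key. A minimal standalone sketch of that deferred-start pattern, using simplified stand-in classes rather than the package's real ones:

    # Simplified stand-ins, not the package's real classes: the wrapper only
    # picks a backend when start() runs, so the server can boot before any
    # credentials or model clients exist.
    from typing import Optional

    class DirectBackend:
        def __init__(self, model: str, api_key: str):
            self.model, self.api_key = model, api_key

    class ProxyBackend:
        def __init__(self, model: str):
            self.model = model

    class DeferredLLM:
        def __init__(self, model: str):
            self.model = model
            self.backend = None  # chosen lazily in start()

        async def start(self, api_key: Optional[str] = None):
            # A missing or blank key falls back to the hosted proxy.
            if api_key is None or api_key.strip() == "":
                self.backend = ProxyBackend(self.model)
            else:
                self.backend = DirectBackend(self.model, api_key)
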
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index a958777f..bf39f22c 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -107,7 +107,6 @@ When state is updated on the server, we currently send the entirety of the objec
- `history`, a record of previously run Steps. Displayed in order in the sidebar.
- `active`, whether the autopilot is currently running a step. Displayed as a loader while step is running.
- `user_input_queue`, the queue of user inputs that have not yet been processed due to waiting for previous Steps to complete. Displayed below the `active` loader until popped from the queue.
-- `default_model`, the default model used for completions. Displayed as a toggleable button on the bottom of the GUI.
- `selected_context_items`, the ranges of code and other items (like GitHub Issues, files, etc...) that have been selected to include as context. Displayed just above the main text input.
- `slash_commands`, the list of available slash commands. Displayed in the main text input dropdown.
- `adding_highlighted_code`, whether highlighting of new code for context is locked. Displayed as a button adjacent to `highlighted_ranges`.
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 42a58423..beb40c75 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -94,7 +94,6 @@ class Autopilot(ContinueBaseModel):
history=self.history,
active=self._active,
user_input_queue=self._main_user_input_queue,
- default_model=self.continue_sdk.config.default_model,
slash_commands=self.get_available_slash_commands(),
adding_highlighted_code=self.context_manager.context_providers[
"code"].adding_highlighted_code,
diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py
index df9b98ef..2553850f 100644
--- a/continuedev/src/continuedev/core/main.py
+++ b/continuedev/src/continuedev/core/main.py
@@ -258,7 +258,6 @@ class FullState(ContinueBaseModel):
history: History
active: bool
user_input_queue: List[str]
- default_model: str
slash_commands: List[SlashCommandDescription]
adding_highlighted_code: bool
selected_context_items: List[ContextItem]
diff --git a/continuedev/src/continuedev/core/models.py b/continuedev/src/continuedev/core/models.py
index ec89d503..e4cb8ed6 100644
--- a/continuedev/src/continuedev/core/models.py
+++ b/continuedev/src/continuedev/core/models.py
@@ -1,5 +1,5 @@
from typing import Optional, Any
-from pydantic import BaseModel
+from pydantic import BaseModel, validator
from ..libs.llm import LLM
@@ -12,7 +12,7 @@ class Models(BaseModel):
# TODO namespace these away to not confuse readers,
# or split Models into ModelsConfig, which gets turned into Models
- sdk: Any = None
+ sdk: "ContinueSDK" = None
system_message: Any = None
"""
@@ -34,43 +34,42 @@ class Models(BaseModel):
'''depending on the model, return the single prompt string'''
"""
- async def _start(self, llm: LLM):
+ async def _start_llm(self, llm: LLM):
kwargs = {}
- if llm.required_api_key:
- kwargs["api_key"] = await self.sdk.get_api_key(llm.required_api_key)
- if llm.required_unique_id:
+ if llm.requires_api_key:
+ kwargs["api_key"] = await self.sdk.get_api_key(llm.requires_api_key)
+ if llm.requires_unique_id:
kwargs["unique_id"] = self.sdk.ide.unique_id
- if llm.required_write_log:
+ if llm.requires_write_log:
kwargs["write_log"] = self.sdk.write_log
await llm.start(**kwargs)
async def start(self, sdk: "ContinueSDK"):
+ """Start each of the LLMs, or fall back to default"""
self.sdk = sdk
self.system_message = self.sdk.config.system_message
- await self._start(self.default)
+ await self._start_llm(self.default)
if self.small:
- await self._start(self.small)
+ await self._start_llm(self.small)
else:
self.small = self.default
if self.medium:
- await self._start(self.medium)
+ await self._start_llm(self.medium)
else:
self.medium = self.default
if self.large:
- await self._start(self.large)
+ await self._start_llm(self.large)
else:
self.large = self.default
async def stop(self, sdk: "ContinueSDK"):
+ """Stop each LLM (if it's not the default, which is shared)"""
await self.default.stop()
- if self.small:
+ if self.small is not self.default:
await self.small.stop()
-
- if self.medium:
+ if self.medium is not self.default:
await self.medium.stop()
-
- if self.large:
+ if self.large is not self.default:
await self.large.stop()
-
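
Taken together, the Models changes above start each configured LLM with only the keyword arguments it declares through the new requires_* flags, alias any unset tier (small/medium/large) to default, and on shutdown skip tiers that are merely that alias so the shared client is stopped once. A self-contained sketch of the same dispatch, with a stub class standing in for the real LLM and ContinueSDK:

    # Stub stand-ins for the requires_* dispatch in Models._start_llm / Models.stop.
    from typing import Optional

    class StubLLM:
        requires_api_key: Optional[str] = None   # name of the key to fetch, if any
        requires_unique_id: bool = False
        requires_write_log: bool = False

        async def start(self, **kwargs): ...
        async def stop(self): ...

    async def start_llm(llm: StubLLM, get_api_key, unique_id, write_log):
        kwargs = {}
        if llm.requires_api_key:
            kwargs["api_key"] = await get_api_key(llm.requires_api_key)
        if llm.requires_unique_id:
            kwargs["unique_id"] = unique_id
        if llm.requires_write_log:
            kwargs["write_log"] = write_log
        await llm.start(**kwargs)

    async def stop_all(default: StubLLM, small: StubLLM, medium: StubLLM, large: StubLLM):
        await default.stop()
        for tier in (small, medium, large):
            if tier is not default:  # tiers aliased to default share one client
                await tier.stop()
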
diff --git a/continuedev/src/continuedev/libs/constants/default_config.py.txt b/continuedev/src/continuedev/libs/constants/default_config.py.txt
index f80a9ff0..5708747f 100644
--- a/continuedev/src/continuedev/libs/constants/default_config.py.txt
+++ b/continuedev/src/continuedev/libs/constants/default_config.py.txt
@@ -12,7 +12,7 @@ from continuedev.src.continuedev.core.sdk import ContinueSDK
from continuedev.src.continuedev.core.config import CustomCommand, SlashCommand, ContinueConfig
from continuedev.src.continuedev.plugins.context_providers.github import GitHubIssuesContextProvider
from continuedev.src.continuedev.plugins.context_providers.google import GoogleContextProvider
-
+from continuedev.src.continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI
class CommitMessageStep(Step):
"""
@@ -41,9 +41,9 @@ config = ContinueConfig(
# See here to learn what anonymous data we collect: https://continue.dev/docs/telemetry
allow_anonymous_telemetry=True,
- # GPT-4 is recommended for best results
- # See options here: https://continue.dev/docs/customization#change-the-default-llm
- default_model="gpt-4",
+ models=Models(
+ default=MaybeProxyOpenAI("gpt4")
+ )
# Set a system message with information that the LLM should always keep in mind
# E.g. "Please give concise answers. Always respond in Spanish."
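
With the template change above, a user's config.py declares model objects instead of a default_model string. Roughly, and assuming Models is imported from continuedev.src.continuedev.core.models (that import sits outside this hunk, so the path is an assumption), the relevant block becomes:

    # Sketch of the new config block; the Models import path is assumed,
    # the other imports appear in the template above.
    from continuedev.src.continuedev.core.config import ContinueConfig
    from continuedev.src.continuedev.core.models import Models
    from continuedev.src.continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI

    config = ContinueConfig(
        allow_anonymous_telemetry=True,
        models=Models(
            default=MaybeProxyOpenAI(model="gpt4"),
            # small / medium / large are optional and fall back to default
        ),
        # ... system message, slash commands, context providers, etc.
    )
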
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 0f6b1505..21afc338 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -9,7 +9,10 @@ from pydantic import BaseModel
class LLM(BaseModel, ABC):
- required_api_key: Optional[str] = None
+ requires_api_key: Optional[str] = None
+ requires_unique_id: bool = False
+ requires_write_log: bool = False
+
system_message: Union[str, None] = None
async def start(self, *, api_key: Optional[str] = None):
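
The base class now states its startup needs as data: requires_api_key names the key that Models should fetch via the SDK, while requires_unique_id and requires_write_log ask for the IDE id and the logging callback. A hypothetical subclass (not part of this commit) would opt in roughly like this, assuming LLM is importable from continuedev.src.continuedev.libs.llm:

    # Hypothetical subclass illustrating the new requires_* flags; not in this commit.
    from typing import Callable, Optional

    from continuedev.src.continuedev.libs.llm import LLM

    class ExampleLLM(LLM):
        model: str

        requires_api_key: Optional[str] = "EXAMPLE_API_KEY"  # key for Models to fetch
        requires_write_log: bool = True                      # wants the write_log callback

        api_key: Optional[str] = None
        write_log: Optional[Callable[[str], None]] = None

        async def start(self, *, api_key: str, write_log: Callable[[str], None]):
            self.api_key = api_key
            self.write_log = write_log

        async def stop(self):
            pass
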
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
index 846a2450..067a903b 100644
--- a/continuedev/src/continuedev/libs/llm/anthropic.py
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -9,27 +9,28 @@ from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_
class AnthropicLLM(LLM):
- required_api_key: str = "ANTHROPIC_API_KEY"
- default_model: str
- async_client: AsyncAnthropic
+ model: str
- def __init__(self, default_model: str, system_message: str = None):
- self.default_model = default_model
+ requires_api_key: str = "ANTHROPIC_API_KEY"
+ _async_client: AsyncAnthropic = None
+
+ def __init__(self, model: str, system_message: str = None):
+ self.model = model
self.system_message = system_message
- async def start(self, *, api_key):
- self.async_client = AsyncAnthropic(api_key=api_key)
+ async def start(self, *, api_key: str):
+ self._async_client = AsyncAnthropic(api_key=api_key)
async def stop(self):
pass
@cached_property
def name(self):
- return self.default_model
+ return self.model
@property
def default_args(self):
- return {**DEFAULT_ARGS, "model": self.default_model}
+ return {**DEFAULT_ARGS, "model": self.model}
def _transform_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
args = args.copy()
@@ -43,7 +44,7 @@ class AnthropicLLM(LLM):
return args
def count_tokens(self, text: str):
- return count_tokens(self.default_model, text)
+ return count_tokens(self.model, text)
def __messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
prompt = ""
@@ -63,7 +64,7 @@ class AnthropicLLM(LLM):
args["stream"] = True
args = self._transform_args(args)
- async for chunk in await self.async_client.completions.create(
+ async for chunk in await self._async_client.completions.create(
prompt=f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}",
**args
):
@@ -77,7 +78,7 @@ class AnthropicLLM(LLM):
messages = compile_chat_messages(
args["model"], messages, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message)
- async for chunk in await self.async_client.completions.create(
+ async for chunk in await self._async_client.completions.create(
prompt=self.__messages_to_prompt(messages),
**args
):
@@ -92,7 +93,7 @@ class AnthropicLLM(LLM):
messages = compile_chat_messages(
args["model"], with_history, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message)
- resp = (await self.async_client.completions.create(
+ resp = (await self._async_client.completions.create(
prompt=self.__messages_to_prompt(messages),
**args
)).completion
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 401709c9..4bcf7e54 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -16,6 +16,12 @@ class GGML(LLM):
def __init__(self, system_message: str = None):
self.system_message = system_message
+ async def start(self, **kwargs):
+ pass
+
+ async def stop(self):
+ pass
+
@property
def name(self):
return "ggml"
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 06d37596..4ad32e0e 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -8,14 +8,16 @@ DEFAULT_MAX_TIME = 120.
class HuggingFaceInferenceAPI(LLM):
- required_api_key: str = "HUGGING_FACE_TOKEN"
model: str
+ requires_api_key: str = "HUGGING_FACE_TOKEN"
+ api_key: str = None
+
def __init__(self, model: str, system_message: str = None):
self.model = model
self.system_message = system_message # TODO: Nothing being done with this
- async def start(self, *, api_key):
+ async def start(self, *, api_key: str):
self.api_key = api_key
def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs):
diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
new file mode 100644
index 00000000..d2898b5c
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
@@ -0,0 +1,38 @@
+from typing import Any, Coroutine, Dict, Generator, List, Union, Optional
+
+from ...core.main import ChatMessage
+from . import LLM
+from .proxy_server import ProxyServer
+from .openai import OpenAI
+
+
+class MaybeProxyOpenAI(LLM):
+ model: str
+
+ requires_api_key: Optional[str] = "OPENAI_API_KEY"
+ requires_write_log: bool = True
+ system_message: Union[str, None] = None
+
+ llm: Optional[LLM] = None
+
+ async def start(self, *, api_key: Optional[str] = None, **kwargs):
+ if api_key is None or api_key.strip() == "":
+ self.llm = ProxyServer(
+ unique_id="", model=self.model, write_log=kwargs["write_log"])
+ else:
+ self.llm = OpenAI(model=self.model, write_log=kwargs["write_log"])
+
+ async def stop(self):
+ await self.llm.stop()
+
+ async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
+ return await self.llm.complete(prompt, with_history=with_history, **kwargs)
+
+ def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ return self.llm.stream_complete(prompt, with_history=with_history, **kwargs)
+
+ async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ return self.llm.stream_chat(messages=messages, **kwargs)
+
+ def count_tokens(self, text: str):
+ return self.llm.count_tokens(text)
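
MaybeProxyOpenAI is what lets the server come up before any credentials exist: start() picks ProxyServer when the key is missing or blank and a direct OpenAI client otherwise, and the remaining methods (complete, stream_complete, stream_chat, count_tokens) simply forward to the chosen backend. A hedged sketch showing only the backend selection (print stands in for the real write_log callback; the chosen backend's own start-up is not exercised here):

    # Hedged sketch of the selection logic only; run inside an event loop.
    import asyncio

    from continuedev.src.continuedev.libs.llm.maybe_proxy_openai import MaybeProxyOpenAI

    async def which_backend(api_key):
        llm = MaybeProxyOpenAI(model="gpt4")
        await llm.start(api_key=api_key, write_log=print)  # None/"" -> ProxyServer branch
        return type(llm.llm).__name__

    # asyncio.run(which_backend(None))      -> "ProxyServer"
    # asyncio.run(which_backend("sk-..."))  -> "OpenAI"
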
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 5ac4d211..0c2c360b 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -17,12 +17,14 @@ class AzureInfo(BaseModel):
class OpenAI(LLM):
model: str
+
+ requires_api_key = "OPENAI_API_KEY"
+ requires_write_log = True
+
system_message: Optional[str] = None
azure_info: Optional[AzureInfo] = None
write_log: Optional[Callable[[str], None]] = None
-
- required_api_key = "OPENAI_API_KEY"
- required_write_log = True
+ api_key: str = None
async def start(self, *, api_key):
self.api_key = api_key
@@ -31,8 +33,8 @@ class OpenAI(LLM):
# Using an Azure OpenAI deployment
if self.azure_info is not None:
openai.api_type = "azure"
- openai.api_base = azure_info.endpoint
- openai.api_version = azure_info.api_version
+ openai.api_base = self.azure_info.endpoint
+ openai.api_version = self.azure_info.api_version
async def stop(self):
pass
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index 2c0e1dc4..e8f1cb46 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -1,4 +1,3 @@
-
import json
import traceback
from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union, Optional
@@ -18,20 +17,24 @@ SERVER_URL = "https://proxy-server-l6vsfbzhba-uw.a.run.app"
class ProxyServer(LLM):
- unique_id: str
model: str
system_message: Optional[str]
- write_log: Callable[[str], None]
- required_unique_id = True
- required_write_log = True
+ unique_id: str = None
+ write_log: Callable[[str], None] = None
+ _client_session: aiohttp.ClientSession
+
+ requires_unique_id = True
+ requires_write_log = True
- async def start(self):
- # TODO put ClientSession here
- pass
+ async def start(self, **kwargs):
+ self._client_session = aiohttp.ClientSession(
+ connector=aiohttp.TCPConnector(ssl_context=ssl_context))
+ self.write_log = kwargs["write_log"]
+ self.unique_id = kwargs["unique_id"]
async def stop(self):
- pass
+ await self._client_session.close()
@property
def name(self):
@@ -54,7 +57,7 @@ class ProxyServer(LLM):
messages = compile_chat_messages(
args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
- async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
+ async with self._client_session as session:
async with session.post(f"{SERVER_URL}/complete", json={
"messages": messages,
**args
@@ -72,7 +75,7 @@ class ProxyServer(LLM):
args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
- async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
+ async with self._client_session as session:
async with session.post(f"{SERVER_URL}/stream_chat", json={
"messages": messages,
**args
@@ -107,7 +110,7 @@ class ProxyServer(LLM):
self.model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
- async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
+ async with self._client_session as session:
async with session.post(f"{SERVER_URL}/stream_complete", json={
"messages": messages,
**args
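
The ProxyServer hunks above replace the per-request aiohttp.ClientSession with a single _client_session created in start() and closed in stop(), in line with the commit's goal of a long-running server. A standalone sketch of that lifecycle, with a stand-in class and none of the Continue-specific wiring:

    # Stand-in sketch: one ClientSession owned for the object's lifetime,
    # requests issued on it directly, and closed exactly once in stop().
    from typing import Optional

    import aiohttp

    class SessionOwner:
        def __init__(self, base_url: str):
            self.base_url = base_url
            self._session: Optional[aiohttp.ClientSession] = None

        async def start(self):
            self._session = aiohttp.ClientSession()

        async def post_json(self, path: str, payload: dict) -> str:
            async with self._session.post(self.base_url + path, json=payload) as resp:
                return await resp.text()

        async def stop(self):
            await self._session.close()
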