author    Nate Sesti <sestinj@gmail.com>  2023-09-01 18:31:33 -0700
committer Nate Sesti <sestinj@gmail.com>  2023-09-01 18:31:33 -0700
commit    2f792f46026a6bb3c3580f2521b01ecb8c68117c (patch)
tree      36cecf6d218bc166c0a8c1c78261b4feac7f01cd /continuedev
parent    5c8b28b7fddf5b214de61102c768ef44d4087870 (diff)
feat: :sparkles: improved model dropdown
Diffstat (limited to 'continuedev')
-rw-r--r--  continuedev/src/continuedev/core/context.py          2
-rw-r--r--  continuedev/src/continuedev/libs/llm/__init__.py      2
-rw-r--r--  continuedev/src/continuedev/libs/llm/openai.py       21
-rw-r--r--  continuedev/src/continuedev/server/gui.py            80
-rw-r--r--  continuedev/src/continuedev/server/gui_protocol.py   40
5 files changed, 50 insertions(+), 95 deletions(-)
diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py
index bfb89561..bb2c43dc 100644
--- a/continuedev/src/continuedev/core/context.py
+++ b/continuedev/src/continuedev/core/context.py
@@ -225,7 +225,7 @@ class ContextManager:
await self.load_index(sdk.ide.workspace_directory)
logger.debug("Loaded Meilisearch index")
except asyncio.TimeoutError:
- logger.warning("Meilisearch is not running. As of now, Continue does not attempt to download Meilisearch on Windows because the download process is more involved. If you'd like install Meilisearch (which allows you to reference context by typing '@' (e.g. files, GitHub issues, etc.)), follow the instructions here: https://www.meilisearch.com/docs/learn/getting_started/installation. Alternatively, you can track our progress on support for Meilisearch on Windows here: https://github.com/continuedev/continue/issues/408.")
+ logger.warning("Meilisearch is not running.")
create_async_task(start_meilisearch(context_providers))
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 90ef7934..1e77a691 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -9,7 +9,7 @@ class LLM(ContinueBaseModel, ABC):
requires_api_key: Optional[str] = None
requires_unique_id: bool = False
requires_write_log: bool = False
-
+ title: Optional[str] = None
system_message: Optional[str] = None
prompt_templates: dict = {}
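
Note: the new optional 'title' field added above gives each configured model a display name for the improved dropdown. A minimal sketch of how a config entry might use it (hypothetical values; the ContinueConfig/Models import paths and constructor arguments are assumptions, only the 'title', 'model', 'default', and 'unused' attributes appear in this diff):

    from continuedev.src.continuedev.core.config import ContinueConfig  # assumed path
    from continuedev.src.continuedev.core.models import Models          # assumed path
    from continuedev.src.continuedev.libs.llm.openai import OpenAI

    config = ContinueConfig(
        models=Models(
            # "default" is the active model; "unused" entries populate the dropdown
            default=OpenAI(model="gpt-4", title="GPT-4"),
            unused=[OpenAI(model="gpt-3.5-turbo", title="GPT-3.5 Turbo")],
        )
    )
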
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index a61103b9..c5d19ed2 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -190,23 +190,18 @@ class OpenAI(LLM):
system_message=self.system_message,
)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
- resp = (
- (
- await openai.ChatCompletion.acreate(
- messages=messages,
- **args,
- )
- )
- .choices[0]
- .message.content
+ resp = await openai.ChatCompletion.acreate(
+ messages=messages,
+ **args,
)
- self.write_log(f"Completion: \n\n{resp}")
+ completion = resp.choices[0].message.content
+ self.write_log(f"Completion: \n\n{completion}")
else:
prompt = prune_raw_prompt_from_top(
args["model"], self.context_length, prompt, args["max_tokens"]
)
self.write_log(f"Prompt:\n\n{prompt}")
- resp = (
+ completion = (
(
await openai.Completion.acreate(
prompt=prompt,
@@ -216,6 +211,6 @@ class OpenAI(LLM):
.choices[0]
.text
)
- self.write_log(f"Completion:\n\n{resp}")
+ self.write_log(f"Completion:\n\n{completion}")
- return resp
+ return completion
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 2c4f2e4d..49541b76 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -23,7 +23,6 @@ from ..libs.util.queue import AsyncSubscriptionQueue
from ..libs.util.telemetry import posthog_logger
from ..plugins.steps.core.core import DisplayErrorStep
from ..plugins.steps.setup_model import SetupModelStep
-from .gui_protocol import AbstractGUIProtocolServer
from .session_manager import Session, session_manager
router = APIRouter(prefix="/gui", tags=["gui"])
@@ -54,7 +53,7 @@ T = TypeVar("T", bound=BaseModel)
# You should probably abstract away the websocket stuff into a separate class
-class GUIProtocolServer(AbstractGUIProtocolServer):
+class GUIProtocolServer:
websocket: WebSocket
session: Session
sub_queue: AsyncSubscriptionQueue = AsyncSubscriptionQueue()
@@ -118,8 +117,10 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
self.set_system_message(data["message"])
elif message_type == "set_temperature":
self.set_temperature(float(data["temperature"]))
- elif message_type == "set_model_for_role":
- self.set_model_for_role(data["role"], data["model_class"], data["model"])
+ elif message_type == "add_model_for_role":
+ self.add_model_for_role(data["role"], data["model_class"], data["model"])
+ elif message_type == "set_model_for_role_from_index":
+ self.set_model_for_role_from_index(data["role"], data["index"])
elif message_type == "save_context_group":
self.save_context_group(
data["title"], [ContextItem(**item) for item in data["context_items"]]
@@ -230,51 +231,50 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
self.on_error,
)
- def set_model_for_role(self, role: str, model_class: str, model: Any):
+ def set_model_for_role_from_index(self, role: str, index: int):
+ async def async_stuff():
+ models = self.session.autopilot.continue_sdk.config.models
+
+ # Set models in SDK
+ temp = models.default
+ models.default = models.unused[index]
+ models.unused[index] = temp
+ await self.session.autopilot.continue_sdk.start_model(models.default)
+
+ # Set models in config.py
+ JOINER = ", "
+ models_args = {
+ "unused": f"[{JOINER.join([display_llm_class(llm) for llm in models.unused])}]",
+ ("default" if role == "*" else role): display_llm_class(models.default),
+ }
+
+ await self.session.autopilot.set_config_attr(
+ ["models"],
+ create_obj_node("Models", models_args),
+ )
+
+ for other_role in ALL_MODEL_ROLES:
+ if other_role != "default":
+ models.__setattr__(other_role, models.default)
+
+ await self.session.autopilot.continue_sdk.update_ui()
+
+ create_async_task(async_stuff(), self.on_error)
+
+ def add_model_for_role(self, role: str, model_class: str, model: Any):
models = self.session.autopilot.continue_sdk.config.models
unused_models = models.unused
if role == "*":
async def async_stuff():
- # Clear all of the models and store them in unused_models
- # NOTE: There will be duplicates
for role in ALL_MODEL_ROLES:
- prev_model = models.__getattribute__(role)
models.__setattr__(role, None)
- if prev_model is not None:
- exists = False
- for other in unused_models:
- if (
- prev_model.__class__.__name__
- == other.__class__.__name__
- and (
- other.name is not None
- and (
- not other.name.startswith("gpt")
- or prev_model.name == other.name
- )
- )
- ):
- exists = True
- break
- if not exists:
- unused_models.append(prev_model)
-
- # Replace default with either new one or existing from unused_models
- for unused_model in unused_models:
- if model_class == unused_model.__class__.__name__ and (
- "model" not in model
- or model["model"] == unused_model.model
- and model["model"].startswith("gpt")
- ):
- models.default = unused_model
# Set and start the default model if didn't already exist from unused
- if models.default is None:
- models.default = MODEL_CLASSES[model_class](**model)
- await self.session.autopilot.continue_sdk.run_step(
- SetupModelStep(model_class=model_class)
- )
+ models.default = MODEL_CLASSES[model_class](**model)
+ await self.session.autopilot.continue_sdk.run_step(
+ SetupModelStep(model_class=model_class)
+ )
await self.session.autopilot.continue_sdk.start_model(models.default)
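
Note: the gui.py changes above replace the single set_model_for_role message with two: add_model_for_role registers a new model and makes it the default, while set_model_for_role_from_index swaps the default with an entry from models.unused. A rough client-side sketch of the payloads (hypothetical values; only the data fields role, model_class, model, and index appear in the diff, the outer envelope keys are an assumption):

    import json

    # Register a brand-new model for all roles ("*") and make it the default
    add_model_message = {
        "messageType": "add_model_for_role",  # envelope key assumed
        "data": {
            "role": "*",
            "model_class": "OpenAI",
            "model": {"model": "gpt-4"},
        },
    }

    # Promote the model stored at models.unused[0] to the default slot
    set_from_index_message = {
        "messageType": "set_model_for_role_from_index",  # envelope key assumed
        "data": {"role": "*", "index": 0},
    }

    # e.g. websocket.send(json.dumps(add_model_message))
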
diff --git a/continuedev/src/continuedev/server/gui_protocol.py b/continuedev/src/continuedev/server/gui_protocol.py
deleted file mode 100644
index d079475c..00000000
--- a/continuedev/src/continuedev/server/gui_protocol.py
+++ /dev/null
@@ -1,40 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Any
-
-
-class AbstractGUIProtocolServer(ABC):
- @abstractmethod
- async def handle_json(self, data: Any):
- """Handle a json message"""
-
- @abstractmethod
- def on_main_input(self, input: str):
- """Called when the user inputs something"""
-
- @abstractmethod
- def on_reverse_to_index(self, index: int):
- """Called when the user requests reverse to a previous index"""
-
- @abstractmethod
- def on_refinement_input(self, input: str, index: int):
- """Called when the user inputs a refinement"""
-
- @abstractmethod
- def on_step_user_input(self, input: str, index: int):
- """Called when the user inputs a step"""
-
- @abstractmethod
- def on_retry_at_index(self, index: int):
- """Called when the user requests a retry at a previous index"""
-
- @abstractmethod
- def on_clear_history(self):
- """Called when the user requests to clear the history"""
-
- @abstractmethod
- def on_delete_at_index(self, index: int):
- """Called when the user requests to delete a step at a given index"""
-
- @abstractmethod
- def select_context_item(self, id: str, query: str):
- """Called when user selects an item from the dropdown"""