diff options
Diffstat (limited to 'continuedev/src')
| -rw-r--r-- | continuedev/src/continuedev/core/context.py | 2 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/__init__.py | 2 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/openai.py | 21 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/gui.py | 80 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/gui_protocol.py | 40 | 
5 files changed, 50 insertions, 95 deletions
| diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py index bfb89561..bb2c43dc 100644 --- a/continuedev/src/continuedev/core/context.py +++ b/continuedev/src/continuedev/core/context.py @@ -225,7 +225,7 @@ class ContextManager:                  await self.load_index(sdk.ide.workspace_directory)                  logger.debug("Loaded Meilisearch index")              except asyncio.TimeoutError: -                logger.warning("Meilisearch is not running. As of now, Continue does not attempt to download Meilisearch on Windows because the download process is more involved. If you'd like install Meilisearch (which allows you to reference context by typing '@' (e.g. files, GitHub issues, etc.)), follow the instructions here: https://www.meilisearch.com/docs/learn/getting_started/installation. Alternatively, you can track our progress on support for Meilisearch on Windows here: https://github.com/continuedev/continue/issues/408.") +                logger.warning("Meilisearch is not running.")          create_async_task(start_meilisearch(context_providers)) diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py index 90ef7934..1e77a691 100644 --- a/continuedev/src/continuedev/libs/llm/__init__.py +++ b/continuedev/src/continuedev/libs/llm/__init__.py @@ -9,7 +9,7 @@ class LLM(ContinueBaseModel, ABC):      requires_api_key: Optional[str] = None      requires_unique_id: bool = False      requires_write_log: bool = False - +    title: Optional[str] = None      system_message: Optional[str] = None      prompt_templates: dict = {} diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index a61103b9..c5d19ed2 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -190,23 +190,18 @@ class OpenAI(LLM):                  system_message=self.system_message,              )              self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") -            resp = ( -                ( -                    await openai.ChatCompletion.acreate( -                        messages=messages, -                        **args, -                    ) -                ) -                .choices[0] -                .message.content +            resp = await openai.ChatCompletion.acreate( +                messages=messages, +                **args,              ) -            self.write_log(f"Completion: \n\n{resp}") +            completion = resp.choices[0].message.content +            self.write_log(f"Completion: \n\n{completion}")          else:              prompt = prune_raw_prompt_from_top(                  args["model"], self.context_length, prompt, args["max_tokens"]              )              self.write_log(f"Prompt:\n\n{prompt}") -            resp = ( +            completion = (                  (                      await openai.Completion.acreate(                          prompt=prompt, @@ -216,6 +211,6 @@ class OpenAI(LLM):                  .choices[0]                  .text              ) -            self.write_log(f"Completion:\n\n{resp}") +            self.write_log(f"Completion:\n\n{completion}") -        return resp +        return completion diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index 2c4f2e4d..49541b76 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -23,7 +23,6 @@ from ..libs.util.queue import AsyncSubscriptionQueue  from ..libs.util.telemetry import posthog_logger  from ..plugins.steps.core.core import DisplayErrorStep  from ..plugins.steps.setup_model import SetupModelStep -from .gui_protocol import AbstractGUIProtocolServer  from .session_manager import Session, session_manager  router = APIRouter(prefix="/gui", tags=["gui"]) @@ -54,7 +53,7 @@ T = TypeVar("T", bound=BaseModel)  # You should probably abstract away the websocket stuff into a separate class -class GUIProtocolServer(AbstractGUIProtocolServer): +class GUIProtocolServer:      websocket: WebSocket      session: Session      sub_queue: AsyncSubscriptionQueue = AsyncSubscriptionQueue() @@ -118,8 +117,10 @@ class GUIProtocolServer(AbstractGUIProtocolServer):              self.set_system_message(data["message"])          elif message_type == "set_temperature":              self.set_temperature(float(data["temperature"])) -        elif message_type == "set_model_for_role": -            self.set_model_for_role(data["role"], data["model_class"], data["model"]) +        elif message_type == "add_model_for_role": +            self.add_model_for_role(data["role"], data["model_class"], data["model"]) +        elif message_type == "set_model_for_role_from_index": +            self.set_model_for_role_from_index(data["role"], data["index"])          elif message_type == "save_context_group":              self.save_context_group(                  data["title"], [ContextItem(**item) for item in data["context_items"]] @@ -230,51 +231,50 @@ class GUIProtocolServer(AbstractGUIProtocolServer):              self.on_error,          ) -    def set_model_for_role(self, role: str, model_class: str, model: Any): +    def set_model_for_role_from_index(self, role: str, index: int): +        async def async_stuff(): +            models = self.session.autopilot.continue_sdk.config.models + +            # Set models in SDK +            temp = models.default +            models.default = models.unused[index] +            models.unused[index] = temp +            await self.session.autopilot.continue_sdk.start_model(models.default) + +            # Set models in config.py +            JOINER = ", " +            models_args = { +                "unused": f"[{JOINER.join([display_llm_class(llm) for llm in models.unused])}]", +                ("default" if role == "*" else role): display_llm_class(models.default), +            } + +            await self.session.autopilot.set_config_attr( +                ["models"], +                create_obj_node("Models", models_args), +            ) + +            for other_role in ALL_MODEL_ROLES: +                if other_role != "default": +                    models.__setattr__(other_role, models.default) + +            await self.session.autopilot.continue_sdk.update_ui() + +        create_async_task(async_stuff(), self.on_error) + +    def add_model_for_role(self, role: str, model_class: str, model: Any):          models = self.session.autopilot.continue_sdk.config.models          unused_models = models.unused          if role == "*":              async def async_stuff(): -                # Clear all of the models and store them in unused_models -                # NOTE: There will be duplicates                  for role in ALL_MODEL_ROLES: -                    prev_model = models.__getattribute__(role)                      models.__setattr__(role, None) -                    if prev_model is not None: -                        exists = False -                        for other in unused_models: -                            if ( -                                prev_model.__class__.__name__ -                                == other.__class__.__name__ -                                and ( -                                    other.name is not None -                                    and ( -                                        not other.name.startswith("gpt") -                                        or prev_model.name == other.name -                                    ) -                                ) -                            ): -                                exists = True -                                break -                        if not exists: -                            unused_models.append(prev_model) - -                # Replace default with either new one or existing from unused_models -                for unused_model in unused_models: -                    if model_class == unused_model.__class__.__name__ and ( -                        "model" not in model -                        or model["model"] == unused_model.model -                        and model["model"].startswith("gpt") -                    ): -                        models.default = unused_model                  # Set and start the default model if didn't already exist from unused -                if models.default is None: -                    models.default = MODEL_CLASSES[model_class](**model) -                    await self.session.autopilot.continue_sdk.run_step( -                        SetupModelStep(model_class=model_class) -                    ) +                models.default = MODEL_CLASSES[model_class](**model) +                await self.session.autopilot.continue_sdk.run_step( +                    SetupModelStep(model_class=model_class) +                )                  await self.session.autopilot.continue_sdk.start_model(models.default) diff --git a/continuedev/src/continuedev/server/gui_protocol.py b/continuedev/src/continuedev/server/gui_protocol.py deleted file mode 100644 index d079475c..00000000 --- a/continuedev/src/continuedev/server/gui_protocol.py +++ /dev/null @@ -1,40 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any - - -class AbstractGUIProtocolServer(ABC): -    @abstractmethod -    async def handle_json(self, data: Any): -        """Handle a json message""" - -    @abstractmethod -    def on_main_input(self, input: str): -        """Called when the user inputs something""" - -    @abstractmethod -    def on_reverse_to_index(self, index: int): -        """Called when the user requests reverse to a previous index""" - -    @abstractmethod -    def on_refinement_input(self, input: str, index: int): -        """Called when the user inputs a refinement""" - -    @abstractmethod -    def on_step_user_input(self, input: str, index: int): -        """Called when the user inputs a step""" - -    @abstractmethod -    def on_retry_at_index(self, index: int): -        """Called when the user requests a retry at a previous index""" - -    @abstractmethod -    def on_clear_history(self): -        """Called when the user requests to clear the history""" - -    @abstractmethod -    def on_delete_at_index(self, index: int): -        """Called when the user requests to delete a step at a given index""" - -    @abstractmethod -    def select_context_item(self, id: str, query: str): -        """Called when user selects an item from the dropdown""" | 
