summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--continuedev/src/continuedev/core/autopilot.py10
-rw-r--r--continuedev/src/continuedev/core/sdk.py59
-rw-r--r--continuedev/src/continuedev/libs/llm/hf_inference_api.py6
-rw-r--r--continuedev/src/continuedev/server/ide.py4
-rw-r--r--continuedev/src/continuedev/server/session_manager.py6
-rw-r--r--extension/react-app/src/components/StepContainer.tsx2
-rw-r--r--extension/react-app/src/components/TextDialog.tsx6
-rw-r--r--extension/react-app/src/pages/gui.tsx6
-rw-r--r--extension/react-app/src/util/index.ts30
9 files changed, 98 insertions, 31 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index 82439f49..0696c360 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -50,6 +50,8 @@ class Autopilot(ContinueBaseModel):
full_state: Union[FullState, None] = None
_on_update_callbacks: List[Callable[[FullState], None]] = []
+ continue_sdk: ContinueSDK = None
+
_active: bool = False
_should_halt: bool = False
_main_user_input_queue: List[str] = []
@@ -57,9 +59,11 @@ class Autopilot(ContinueBaseModel):
_user_input_queue = AsyncSubscriptionQueue()
_retry_queue = AsyncSubscriptionQueue()
- @cached_property
- def continue_sdk(self) -> ContinueSDK:
- return ContinueSDK(self)
+ @classmethod
+ async def create(cls, policy: Policy, ide: AbstractIdeProtocolServer, full_state: FullState) -> "Autopilot":
+ autopilot = cls(ide=ide, policy=policy)
+ autopilot.continue_sdk = await ContinueSDK.create(autopilot)
+ return autopilot
class Config:
arbitrary_types_allowed = True
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index aa2d8892..d73561d2 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -1,6 +1,6 @@
import asyncio
from functools import cached_property
-from typing import Coroutine, Union
+from typing import Coroutine, Dict, Union
import os
from ..steps.core.core import DefaultModelEditCodeStep
@@ -13,7 +13,7 @@ from ..libs.llm.hf_inference_api import HuggingFaceInferenceAPI
from ..libs.llm.openai import OpenAI
from .observation import Observation
from ..server.ide_protocol import AbstractIdeProtocolServer
-from .main import Context, ContinueCustomException, HighlightedRangeContext, History, Step, ChatMessage, ChatMessageRole
+from .main import Context, ContinueCustomException, History, Step, ChatMessage
from ..steps.core.core import *
from ..libs.llm.proxy_server import ProxyServer
@@ -22,26 +22,46 @@ class Autopilot:
pass
+ModelProvider = Literal["openai", "hf_inference_api", "ggml", "anthropic"]
+MODEL_PROVIDER_TO_ENV_VAR = {
+ "openai": "OPENAI_API_KEY",
+ "hf_inference_api": "HUGGING_FACE_TOKEN",
+ "anthropic": "ANTHROPIC_API_KEY"
+}
+
+
class Models:
- def __init__(self, sdk: "ContinueSDK"):
+ provider_keys: Dict[ModelProvider, str] = {}
+ model_providers: List[ModelProvider]
+
+ def __init__(self, sdk: "ContinueSDK", model_providers: List[ModelProvider]):
self.sdk = sdk
+ self.model_providers = model_providers
+
+ @classmethod
+ async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models":
+ models = Models(sdk, with_providers)
+ for provider in with_providers:
+ if provider in MODEL_PROVIDER_TO_ENV_VAR:
+ env_var = MODEL_PROVIDER_TO_ENV_VAR[provider]
+ models.provider_keys[provider] = await sdk.get_user_secret(
+ env_var, f'Please add your {env_var} to the .env file')
+
+ return models
def __load_openai_model(self, model: str) -> OpenAI:
- async def load_openai_model():
- api_key = await self.sdk.get_user_secret(
- 'OPENAI_API_KEY', 'Enter your OpenAI API key or press enter to try for free')
- if api_key == "":
- return ProxyServer(self.sdk.ide.unique_id, model)
- return OpenAI(api_key=api_key, default_model=model)
- return asyncio.get_event_loop().run_until_complete(load_openai_model())
+ api_key = self.provider_keys["openai"]
+ if api_key == "":
+ return ProxyServer(self.sdk.ide.unique_id, model)
+ return OpenAI(api_key=api_key, default_model=model)
+
+ def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI:
+ api_key = self.provider_keys["hf_inference_api"]
+ return HuggingFaceInferenceAPI(api_key=api_key, model=model)
@cached_property
def starcoder(self):
- async def load_starcoder():
- api_key = await self.sdk.get_user_secret(
- 'HUGGING_FACE_TOKEN', 'Please add your Hugging Face token to the .env file')
- return HuggingFaceInferenceAPI(api_key=api_key)
- return asyncio.get_event_loop().run_until_complete(load_starcoder())
+ return self.__load_hf_inference_api_model("bigcode/starcoder")
@cached_property
def gpt35(self):
@@ -74,7 +94,7 @@ class Models:
@property
def default(self):
default_model = self.sdk.config.default_model
- return self.__model_from_name(default_model) if default_model is not None else self.gpt35
+ return self.__model_from_name(default_model) if default_model is not None else self.gpt4
class ContinueSDK(AbstractContinueSDK):
@@ -87,10 +107,15 @@ class ContinueSDK(AbstractContinueSDK):
def __init__(self, autopilot: Autopilot):
self.ide = autopilot.ide
self.__autopilot = autopilot
- self.models = Models(self)
self.context = autopilot.context
self.config = self._load_config()
+ @classmethod
+ async def create(cls, autopilot: Autopilot) -> "ContinueSDK":
+ sdk = ContinueSDK(autopilot)
+ sdk.models = await Models.create(sdk)
+ return sdk
+
config: ContinueConfig
def _load_config(self) -> ContinueConfig:
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 1586c620..803ba122 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -9,7 +9,11 @@ DEFAULT_MAX_TIME = 120.
class HuggingFaceInferenceAPI(LLM):
api_key: str
- model: str = "bigcode/starcoder"
+ model: str
+
+ def __init__(self, api_key: str, model: str):
+ self.api_key = api_key
+ self.model = model
def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs):
"""Return the completion of the text with the given temperature."""
diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py
index 7875c94d..77b13483 100644
--- a/continuedev/src/continuedev/server/ide.py
+++ b/continuedev/src/continuedev/server/ide.py
@@ -227,8 +227,8 @@ class IdeProtocolServer(AbstractIdeProtocolServer):
})
async def getSessionId(self):
- session_id = self.session_manager.new_session(
- self, self.session_id).session_id
+ session_id = (await self.session_manager.new_session(
+ self, self.session_id)).session_id
await self._send_json("getSessionId", {
"sessionId": session_id
})
diff --git a/continuedev/src/continuedev/server/session_manager.py b/continuedev/src/continuedev/server/session_manager.py
index fb8ac386..6d109ca6 100644
--- a/continuedev/src/continuedev/server/session_manager.py
+++ b/continuedev/src/continuedev/server/session_manager.py
@@ -53,18 +53,18 @@ class SessionManager:
session_files = os.listdir(sessions_folder)
if f"{session_id}.json" in session_files and session_id in self.registered_ides:
if self.registered_ides[session_id].session_id is not None:
- return self.new_session(self.registered_ides[session_id], session_id=session_id)
+ return await self.new_session(self.registered_ides[session_id], session_id=session_id)
raise KeyError("Session ID not recognized", session_id)
return self.sessions[session_id]
- def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session:
+ async def new_session(self, ide: AbstractIdeProtocolServer, session_id: Union[str, None] = None) -> Session:
full_state = None
if session_id is not None and os.path.exists(getSessionFilePath(session_id)):
with open(getSessionFilePath(session_id), "r") as f:
full_state = FullState(**json.load(f))
- autopilot = DemoAutopilot(
+ autopilot = await DemoAutopilot.create(
policy=DemoPolicy(), ide=ide, full_state=full_state)
session_id = session_id or str(uuid4())
ide.session_id = session_id
diff --git a/extension/react-app/src/components/StepContainer.tsx b/extension/react-app/src/components/StepContainer.tsx
index 14e9b854..7f23e333 100644
--- a/extension/react-app/src/components/StepContainer.tsx
+++ b/extension/react-app/src/components/StepContainer.tsx
@@ -181,7 +181,7 @@ function StepContainer(props: StepContainerProps) {
}
className="overflow-hidden cursor-pointer"
onClick={(e) => {
- if (e.metaKey) {
+ if (isMetaEquivalentKeyPressed(e)) {
props.onToggleAll();
} else {
props.onToggle();
diff --git a/extension/react-app/src/components/TextDialog.tsx b/extension/react-app/src/components/TextDialog.tsx
index ea5727f0..c724697d 100644
--- a/extension/react-app/src/components/TextDialog.tsx
+++ b/extension/react-app/src/components/TextDialog.tsx
@@ -81,7 +81,11 @@ const TextDialog = (props: {
rows={10}
ref={textAreaRef}
onKeyDown={(e) => {
- if (e.key === "Enter" && e.metaKey && textAreaRef.current) {
+ if (
+ e.key === "Enter" &&
+ isMetaEquivalentKeyPressed(e) &&
+ textAreaRef.current
+ ) {
props.onEnter(textAreaRef.current.value);
setText("");
} else if (e.key === "Escape") {
diff --git a/extension/react-app/src/pages/gui.tsx b/extension/react-app/src/pages/gui.tsx
index 57cebac3..cb0404ab 100644
--- a/extension/react-app/src/pages/gui.tsx
+++ b/extension/react-app/src/pages/gui.tsx
@@ -137,12 +137,12 @@ function GUI(props: GUIProps) {
useEffect(() => {
const listener = (e: any) => {
// Cmd + i to toggle fast model
- if (e.key === "i" && e.metaKey && e.shiftKey) {
+ if (e.key === "i" && isMetaEquivalentKeyPressed(e) && e.shiftKey) {
setUsingFastModel((prev) => !prev);
// Cmd + backspace to stop currently running step
} else if (
e.key === "Backspace" &&
- e.metaKey &&
+ isMetaEquivalentKeyPressed(e) &&
typeof history?.current_index !== "undefined" &&
history.timeline[history.current_index]?.active
) {
@@ -220,7 +220,7 @@ function GUI(props: GUIProps) {
if (mainTextInputRef.current) {
let input = (mainTextInputRef.current as any).inputValue;
// cmd+enter to /edit
- if (event?.metaKey) {
+ if (isMetaEquivalentKeyPressed(event)) {
input = `/edit ${input}`;
}
(mainTextInputRef.current as any).setInputValue("");
diff --git a/extension/react-app/src/util/index.ts b/extension/react-app/src/util/index.ts
new file mode 100644
index 00000000..ad711321
--- /dev/null
+++ b/extension/react-app/src/util/index.ts
@@ -0,0 +1,30 @@
+type Platform = "mac" | "linux" | "windows" | "unknown";
+
+function getPlatform(): Platform {
+ const platform = window.navigator.platform.toUpperCase();
+ if (platform.indexOf("MAC") >= 0) {
+ return "mac";
+ } else if (platform.indexOf("LINUX") >= 0) {
+ return "linux";
+ } else if (platform.indexOf("WIN") >= 0) {
+ return "windows";
+ } else {
+ return "unknown";
+ }
+}
+
+function isMetaEquivalentKeyPressed(event: {
+ metaKey: boolean;
+ ctrlKey: boolean;
+}): boolean {
+ const platform = getPlatform();
+ switch (platform) {
+ case "mac":
+ return event.metaKey;
+ case "linux":
+ case "windows":
+ return event.ctrlKey;
+ default:
+ return event.metaKey;
+ }
+}