-rw-r--r-- | continuedev/src/continuedev/core/context.py              |  6
-rw-r--r-- | continuedev/src/continuedev/libs/llm/hf_inference_api.py | 61
-rw-r--r-- | continuedev/src/continuedev/server/gui.py                |  6
-rw-r--r-- | extension/react-app/src/pages/gui.tsx                     | 66
-rw-r--r-- | extension/src/debugPanel.ts                               |  4
5 files changed, 115 insertions, 28 deletions
diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py
index f83cbd34..b1f68b50 100644
--- a/continuedev/src/continuedev/core/context.py
+++ b/continuedev/src/continuedev/core/context.py
@@ -10,6 +10,7 @@ from pydantic import BaseModel
 from .main import ChatMessage, ContextItem, ContextItemDescription, ContextItemId
 from ..server.meilisearch_server import check_meilisearch_running
 from ..libs.util.logging import logger
+from ..libs.util.telemetry import posthog_logger

 SEARCH_INDEX_NAME = "continue_context_items"

@@ -199,6 +200,11 @@ class ContextManager:
             raise ValueError(
                 f"Context provider with title {id.provider_title} not found")

+        posthog_logger.capture_event("select_context_item", {
+            "provider_title": id.provider_title,
+            "item_id": id.item_id,
+            "query": query
+        })
         await self.context_providers[id.provider_title].add_context_item(id, query)

     async def delete_context_with_ids(self, ids: List[str]):
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 49f593d8..8945250c 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -1,34 +1,58 @@
-from typing import List, Optional
+from typing import Any, Coroutine, Dict, Generator, List, Optional, Union
+
+import aiohttp
+import requests
+
+from ...core.main import ChatMessage
+from ..util.count_tokens import DEFAULT_ARGS, count_tokens

 from ...core.main import ChatMessage
 from ..llm import LLM
-import requests

-DEFAULT_MAX_TOKENS = 2048
 DEFAULT_MAX_TIME = 120.


 class HuggingFaceInferenceAPI(LLM):
     model: str
+    hf_token: str
+
+    max_context_length: int = 2048
+    verify_ssl: bool = True
+
+    _client_session: aiohttp.ClientSession = None

-    requires_api_key: str = "HUGGING_FACE_TOKEN"
-    api_key: str = None
+    class Config:
+        arbitrary_types_allowed = True

-    def __init__(self, model: str, system_message: str = None):
-        self.model = model
-        self.system_message = system_message  # TODO: Nothing being done with this
+    async def start(self, **kwargs):
+        self._client_session = aiohttp.ClientSession(
+            connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl))

-    async def start(self, *, api_key: Optional[str] = None, **kwargs):
-        self.api_key = api_key
+    async def stop(self):
+        await self._client_session.close()

-    def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs):
+    @property
+    def name(self):
+        return self.model
+
+    @property
+    def context_length(self):
+        return self.max_context_length
+
+    @property
+    def default_args(self):
+        return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
+
+    def count_tokens(self, text: str):
+        return count_tokens(self.name, text)
+
+    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs):
         """Return the completion of the text with the given temperature."""
         API_URL = f"https://api-inference.huggingface.co/models/{self.model}"
         headers = {
-            "Authorization": f"Bearer {self.api_key}"}
+            "Authorization": f"Bearer {self.hf_token}"}
         response = requests.post(API_URL, headers=headers, json={
             "inputs": prompt, "parameters": {
-                "max_new_tokens": DEFAULT_MAX_TOKENS,
+                "max_new_tokens": min(250, self.max_context_length - self.count_tokens(prompt)),
                 "max_time": DEFAULT_MAX_TIME,
                 "return_full_text": False,
             }
@@ -41,3 +65,14 @@ class HuggingFaceInferenceAPI(LLM):
                 "Hugging Face returned an error response: \n\n", data)

         return data[0]["generated_text"]
+
+    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
+        response = await self.complete(messages[-1].content, messages[:-1])
+        yield {
+            "content": response,
+            "role": "assistant"
+        }
+
+    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        response = await self.complete(prompt, with_history)
+        yield response
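After this rewrite, HuggingFaceInferenceAPI is declared with pydantic-style fields (model, hf_token) and managed through an async start/stop lifecycle instead of a constructor plus api_key. A minimal usage sketch under those assumptions; the model id and token are placeholders, the import path follows this repo's layout, and keyword construction assumes the LLM base class is a pydantic model, as the class Config above suggests:

    import asyncio

    from continuedev.src.continuedev.libs.llm.hf_inference_api import HuggingFaceInferenceAPI

    async def main():
        # Placeholder model id and token; any Hugging Face Inference API model works here
        llm = HuggingFaceInferenceAPI(model="bigcode/starcoder", hf_token="<your HF token>")
        await llm.start()  # opens the aiohttp session the class now owns
        try:
            # complete() is now a coroutine and must be awaited
            # (internally it still posts synchronously via requests)
            print(await llm.complete("def fibonacci(n):"))
            # stream_complete() wraps the same completion as a single yielded chunk
            async for chunk in llm.stream_complete("def fibonacci(n):"):
                print(chunk)
        finally:
            await llm.stop()  # closes the aiohttp session

    asyncio.run(main())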
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 4470999a..49d46be3 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -139,6 +139,7 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
     def on_toggle_adding_highlighted_code(self):
         create_async_task(
             self.session.autopilot.toggle_adding_highlighted_code(), self.on_error)
+        posthog_logger.capture_event("toggle_adding_highlighted_code", {})

     def on_set_editing_at_ids(self, ids: List[str]):
         create_async_task(
@@ -150,6 +151,7 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
             ["This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"] + self.session.autopilot.continue_sdk.history.timeline[index].logs)
         create_async_task(
             self.session.autopilot.ide.showVirtualFile(name, logs), self.on_error)
+        posthog_logger.capture_event("show_logs_at_index", {})

     def select_context_item(self, id: str, query: str):
         """Called when user selects an item from the dropdown"""
@@ -164,6 +166,10 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
         create_async_task(
             load_and_tell_to_reconnect(), self.on_error)

+        posthog_logger.capture_event("load_session", {
+            "session_id": session_id
+        })
+

 @router.websocket("/ws")
 async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)):
diff --git a/extension/react-app/src/pages/gui.tsx b/extension/react-app/src/pages/gui.tsx
index ff113636..4c89bbaa 100644
--- a/extension/react-app/src/pages/gui.tsx
+++ b/extension/react-app/src/pages/gui.tsx
@@ -351,21 +351,57 @@ function GUI(props: GUIProps) {
         >
           Continue Server Starting
         </p>
-        <p
-          style={{
-            margin: "auto",
-            textAlign: "center",
-            marginTop: "4px",
-            fontSize: "12px",
-            cursor: "pointer",
-            opacity: 0.7,
-          }}
-          onClick={() => {
-            postVscMessage("toggleDevTools", {});
-          }}
-        >
-          <u>Click to view logs</u>
-        </p>
+        <div className="flex mx-8 my-2">
+          <p
+            style={{
+              margin: "auto",
+              textAlign: "center",
+              marginTop: "4px",
+              fontSize: "12px",
+              cursor: "pointer",
+              opacity: 0.7,
+            }}
+          >
+            <u>
+              <a
+                style={{ color: "inherit" }}
+                href="https://continue.dev/docs/troubleshooting"
+              >
+                Troubleshooting help
+              </a>
+            </u>
+          </p>
+          <p
+            style={{
+              margin: "auto",
+              textAlign: "center",
+              marginTop: "4px",
+              fontSize: "12px",
+              cursor: "pointer",
+              opacity: 0.7,
+            }}
+            onClick={() => {
+              postVscMessage("toggleDevTools", {});
+            }}
+          >
+            <u>View logs</u>
+          </p>
+          <p
+            style={{
+              margin: "auto",
+              textAlign: "center",
+              marginTop: "4px",
+              fontSize: "12px",
+              cursor: "pointer",
+              opacity: 0.7,
+            }}
+            onClick={() => {
+              postVscMessage("reloadWindow", {});
+            }}
+          >
+            <u>Reload the window</u>
+          </p>
+        </div>
         <div className="w-3/4 m-auto text-center text-xs">
           Tip: Drag the Continue logo from the far left of the window to the
           right, then toggle Continue using option/alt+command+m.
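The telemetry added in context.py and gui.py above funnels UI actions through posthog_logger.capture_event(event_name, properties). That module's internals are not part of this diff; the sketch below is a hypothetical stand-in showing the call pattern the diff implies, built on the public posthog Python client:

    # Hypothetical stand-in for continuedev's telemetry module (not shown in this diff).
    # Assumes the `posthog` client library; capture(distinct_id, event, properties)
    # is its documented module-level call.
    import posthog

    class PostHogLogger:
        def __init__(self, api_key: str, unique_id: str):
            posthog.project_api_key = api_key
            self.unique_id = unique_id

        def capture_event(self, event_name: str, event_properties: dict) -> None:
            # Each GUI action becomes one named event with a small property dict
            # (titles and ids only, never file contents)
            posthog.capture(self.unique_id, event_name, event_properties)

    posthog_logger = PostHogLogger(api_key="<project api key>", unique_id="<anonymous user id>")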
diff --git a/extension/src/debugPanel.ts b/extension/src/debugPanel.ts
index d133080b..e6dade37 100644
--- a/extension/src/debugPanel.ts
+++ b/extension/src/debugPanel.ts
@@ -252,6 +252,10 @@ export function setupDebugPanel(
         vscode.commands.executeCommand("continue.viewLogs");
         break;
       }
+      case "reloadWindow": {
+        vscode.commands.executeCommand("workbench.action.reloadWindow");
+        break;
+      }
       case "focusEditor": {
         setFocusedOnContinueInput(false);
         vscode.commands.executeCommand(