summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--continuedev/src/continuedev/core/context.py6
-rw-r--r--continuedev/src/continuedev/libs/llm/hf_inference_api.py61
-rw-r--r--continuedev/src/continuedev/server/gui.py6
-rw-r--r--extension/react-app/src/pages/gui.tsx66
-rw-r--r--extension/src/debugPanel.ts4
5 files changed, 115 insertions, 28 deletions
diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py
index f83cbd34..b1f68b50 100644
--- a/continuedev/src/continuedev/core/context.py
+++ b/continuedev/src/continuedev/core/context.py
@@ -10,6 +10,7 @@ from pydantic import BaseModel
from .main import ChatMessage, ContextItem, ContextItemDescription, ContextItemId
from ..server.meilisearch_server import check_meilisearch_running
from ..libs.util.logging import logger
+from ..libs.util.telemetry import posthog_logger
SEARCH_INDEX_NAME = "continue_context_items"
@@ -199,6 +200,11 @@ class ContextManager:
raise ValueError(
f"Context provider with title {id.provider_title} not found")
+ posthog_logger.capture_event("select_context_item", {
+ "provider_title": id.provider_title,
+ "item_id": id.item_id,
+ "query": query
+ })
await self.context_providers[id.provider_title].add_context_item(id, query)
async def delete_context_with_ids(self, ids: List[str]):
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 49f593d8..8945250c 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -1,34 +1,58 @@
-from typing import List, Optional
+from typing import Any, Coroutine, Dict, Generator, List, Optional, Union
+import aiohttp
+import requests
+
+from ...core.main import ChatMessage
+from ..util.count_tokens import DEFAULT_ARGS, count_tokens
from ...core.main import ChatMessage
from ..llm import LLM
-import requests
-DEFAULT_MAX_TOKENS = 2048
DEFAULT_MAX_TIME = 120.
class HuggingFaceInferenceAPI(LLM):
model: str
+ hf_token: str
+
+ max_context_length: int = 2048
+ verify_ssl: bool = True
+
+ _client_session: aiohttp.ClientSession = None
- requires_api_key: str = "HUGGING_FACE_TOKEN"
- api_key: str = None
+ class Config:
+ arbitrary_types_allowed = True
- def __init__(self, model: str, system_message: str = None):
- self.model = model
- self.system_message = system_message # TODO: Nothing being done with this
+ async def start(self, **kwargs):
+ self._client_session = aiohttp.ClientSession(
+ connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl))
- async def start(self, *, api_key: Optional[str] = None, **kwargs):
- self.api_key = api_key
+ async def stop(self):
+ await self._client_session.close()
- def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs):
+ @property
+ def name(self):
+ return self.model
+
+ @property
+ def context_length(self):
+ return self.max_context_length
+
+ @property
+ def default_args(self):
+ return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
+
+ def count_tokens(self, text: str):
+ return count_tokens(self.name, text)
+
+ async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs):
"""Return the completion of the text with the given temperature."""
API_URL = f"https://api-inference.huggingface.co/models/{self.model}"
headers = {
- "Authorization": f"Bearer {self.api_key}"}
+ "Authorization": f"Bearer {self.hf_token}"}
response = requests.post(API_URL, headers=headers, json={
"inputs": prompt, "parameters": {
- "max_new_tokens": DEFAULT_MAX_TOKENS,
+ "max_new_tokens": min(250, self.max_context_length - self.count_tokens(prompt)),
"max_time": DEFAULT_MAX_TIME,
"return_full_text": False,
}
@@ -41,3 +65,14 @@ class HuggingFaceInferenceAPI(LLM):
"Hugging Face returned an error response: \n\n", data)
return data[0]["generated_text"]
+
+ async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, Generator[Any | List | Dict, None, None]]:
+ response = await self.complete(messages[-1].content, messages[:-1])
+ yield {
+ "content": response,
+ "role": "assistant"
+ }
+
+ async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Any | List | Dict, None, None]:
+ response = await self.complete(prompt, with_history)
+ yield response
diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py
index 4470999a..49d46be3 100644
--- a/continuedev/src/continuedev/server/gui.py
+++ b/continuedev/src/continuedev/server/gui.py
@@ -139,6 +139,7 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
def on_toggle_adding_highlighted_code(self):
create_async_task(
self.session.autopilot.toggle_adding_highlighted_code(), self.on_error)
+ posthog_logger.capture_event("toggle_adding_highlighted_code", {})
def on_set_editing_at_ids(self, ids: List[str]):
create_async_task(
@@ -150,6 +151,7 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
["This is a log of the exact prompt/completion pairs sent/received from the LLM during this step"] + self.session.autopilot.continue_sdk.history.timeline[index].logs)
create_async_task(
self.session.autopilot.ide.showVirtualFile(name, logs), self.on_error)
+ posthog_logger.capture_event("show_logs_at_index", {})
def select_context_item(self, id: str, query: str):
"""Called when user selects an item from the dropdown"""
@@ -164,6 +166,10 @@ class GUIProtocolServer(AbstractGUIProtocolServer):
create_async_task(
load_and_tell_to_reconnect(), self.on_error)
+ posthog_logger.capture_event("load_session", {
+ "session_id": session_id
+ })
+
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, session: Session = Depends(websocket_session)):
diff --git a/extension/react-app/src/pages/gui.tsx b/extension/react-app/src/pages/gui.tsx
index ff113636..4c89bbaa 100644
--- a/extension/react-app/src/pages/gui.tsx
+++ b/extension/react-app/src/pages/gui.tsx
@@ -351,21 +351,57 @@ function GUI(props: GUIProps) {
>
Continue Server Starting
</p>
- <p
- style={{
- margin: "auto",
- textAlign: "center",
- marginTop: "4px",
- fontSize: "12px",
- cursor: "pointer",
- opacity: 0.7,
- }}
- onClick={() => {
- postVscMessage("toggleDevTools", {});
- }}
- >
- <u>Click to view logs</u>
- </p>
+ <div className="flex mx-8 my-2">
+ <p
+ style={{
+ margin: "auto",
+ textAlign: "center",
+ marginTop: "4px",
+ fontSize: "12px",
+ cursor: "pointer",
+ opacity: 0.7,
+ }}
+ >
+ <u>
+ <a
+ style={{ color: "inherit" }}
+ href="https://continue.dev/docs/troubleshooting"
+ >
+ Troubleshooting help
+ </a>
+ </u>
+ </p>
+ <p
+ style={{
+ margin: "auto",
+ textAlign: "center",
+ marginTop: "4px",
+ fontSize: "12px",
+ cursor: "pointer",
+ opacity: 0.7,
+ }}
+ onClick={() => {
+ postVscMessage("toggleDevTools", {});
+ }}
+ >
+ <u>View logs</u>
+ </p>
+ <p
+ style={{
+ margin: "auto",
+ textAlign: "center",
+ marginTop: "4px",
+ fontSize: "12px",
+ cursor: "pointer",
+ opacity: 0.7,
+ }}
+ onClick={() => {
+ postVscMessage("reloadWindow", {});
+ }}
+ >
+ <u>Reload the window</u>
+ </p>
+ </div>
<div className="w-3/4 m-auto text-center text-xs">
Tip: Drag the Continue logo from the far left of the window to the
right, then toggle Continue using option/alt+command+m.
diff --git a/extension/src/debugPanel.ts b/extension/src/debugPanel.ts
index d133080b..e6dade37 100644
--- a/extension/src/debugPanel.ts
+++ b/extension/src/debugPanel.ts
@@ -252,6 +252,10 @@ export function setupDebugPanel(
vscode.commands.executeCommand("continue.viewLogs");
break;
}
+ case "reloadWindow": {
+ vscode.commands.executeCommand("workbench.action.reloadWindow");
+ break;
+ }
case "focusEditor": {
setFocusedOnContinueInput(false);
vscode.commands.executeCommand(