From 885f88af1d7b35e03b1de4df3e74a60da1a777ed Mon Sep 17 00:00:00 2001
From: Nate Sesti
Date: Mon, 24 Jul 2023 00:52:06 -0700
Subject: move proxy server unique id to header

---
 continuedev/src/continuedev/core/autopilot.py |  2 ++
 .../src/continuedev/libs/llm/proxy_server.py  | 30 +++++++++++++---------
 2 files changed, 20 insertions(+), 12 deletions(-)

diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index ecc587ce..9dbced32 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -37,6 +37,8 @@ def get_error_title(e: Exception) -> str:
         return "The request failed. Please check your internet connection and try again. If this issue persists, you can use our API key for free by going to VS Code settings and changing the value of continue.OPENAI_API_KEY to \"\""
     elif isinstance(e, openai_errors.InvalidRequestError):
         return 'Invalid request sent to OpenAI. Please try again.'
+    elif "rate_limit_ip_middleware" in e.__str__():
+        return 'You have reached your limit for free usage of our token. You can continue using Continue by entering your own OpenAI API key in VS Code settings.'
     elif e.__str__().startswith("Cannot connect to host"):
         return "The request failed. Please check your internet connection and try again."
     return e.__str__() or e.__repr__()
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index b1bb8f06..75c91c4e 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -1,5 +1,4 @@
-from functools import cached_property
 import json
 import traceback
 from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union
@@ -37,6 +36,10 @@ class ProxyServer(LLM):
 
     def count_tokens(self, text: str):
         return count_tokens(self.default_model, text)
+
+    def get_headers(self):
+        # headers with unique id
+        return {"unique_id": self.unique_id}
 
     async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
         args = {**self.default_args, **kwargs}
@@ -47,16 +50,15 @@ class ProxyServer(LLM):
         async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
             async with session.post(f"{SERVER_URL}/complete", json={
                 "messages": messages,
-                "unique_id": self.unique_id,
                 **args
-            }) as resp:
-                try:
-                    response_text = await resp.text()
-                    self.write_log(f"Completion: \n\n{response_text}")
-                    return response_text
-                except:
+            }, headers=self.get_headers()) as resp:
+                if resp.status != 200:
                     raise Exception(await resp.text())
 
+                response_text = await resp.text()
+                self.write_log(f"Completion: \n\n{response_text}")
+                return response_text
+
     async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
         args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
@@ -66,11 +68,12 @@ class ProxyServer(LLM):
         async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
             async with session.post(f"{SERVER_URL}/stream_chat", json={
                 "messages": messages,
-                "unique_id": self.unique_id,
                 **args
-            }) as resp:
+            }, headers=self.get_headers()) as resp:
                 # This is streaming application/json instead of text/event-stream
                 completion = ""
+                if resp.status != 200:
+                    raise Exception(await resp.text())
                 async for line in resp.content.iter_chunks():
                     if line[1]:
                         try:
@@ -86,6 +89,8 @@ class ProxyServer(LLM):
                         except Exception as e:
                             capture_event(self.unique_id, "proxy_server_parse_error", {
                                 "error_title": "Proxy server stream_chat parsing failed", "error_message": '\n'.join(traceback.format_exception(e))})
+                    else:
+                        break
 
         self.write_log(f"Completion: \n\n{completion}")
 
@@ -98,10 +103,11 @@ class ProxyServer(LLM):
         async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:
             async with session.post(f"{SERVER_URL}/stream_complete", json={
                 "messages": messages,
-                "unique_id": self.unique_id,
                 **args
-            }) as resp:
+            }, headers=self.get_headers()) as resp:
                 completion = ""
+                if resp.status != 200:
+                    raise Exception(await resp.text())
                 async for line in resp.content.iter_any():
                     if line:
                         try:
-- 
cgit v1.2.3-70-g09d2
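
For reference, a minimal sketch of the request shape after this change, assuming aiohttp; SERVER_URL, the unique_id value, and the message payload below are placeholders, not the real Continue deployment details. The unique id now travels as a request header rather than a field in the JSON body, and non-200 responses are surfaced as exceptions, mirroring the patch:

    import asyncio

    import aiohttp

    SERVER_URL = "https://example.com"  # placeholder, not the real proxy server URL


    async def complete(messages: list, unique_id: str) -> str:
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{SERVER_URL}/complete",
                json={"messages": messages},       # body carries only the model arguments now
                headers={"unique_id": unique_id},  # unique id moved from JSON body to header
            ) as resp:
                if resp.status != 200:
                    # as in the patch: raise on server errors instead of swallowing them
                    raise Exception(await resp.text())
                return await resp.text()


    if __name__ == "__main__":
        print(asyncio.run(complete([{"role": "user", "content": "hi"}], "my-unique-id")))

A plausible upside of this layout: because the handlers splat **args into the JSON body, a stray unique_id keyword argument could previously shadow the real id, whereas a header sits outside that splat, and server-side middleware (such as the rate_limit_ip_middleware referenced in autopilot.py) can presumably read it without parsing the payload.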