-rw-r--r-- | continuedev/src/continuedev/libs/llm/proxy_server.py | 109
1 file changed, 53 insertions, 56 deletions
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index 56b123db..44734b1c 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -68,17 +68,16 @@ class ProxyServer(LLM):
         messages = compile_chat_messages(
             args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
         self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
-        async with self._client_session as session:
-            async with session.post(f"{SERVER_URL}/complete", json={
-                "messages": messages,
-                **args
-            }, headers=self.get_headers()) as resp:
-                if resp.status != 200:
-                    raise Exception(await resp.text())
-
-                response_text = await resp.text()
-                self.write_log(f"Completion: \n\n{response_text}")
-                return response_text
+        async with self._client_session.post(f"{SERVER_URL}/complete", json={
+            "messages": messages,
+            **args
+        }, headers=self.get_headers()) as resp:
+            if resp.status != 200:
+                raise Exception(await resp.text())
+
+            response_text = await resp.text()
+            self.write_log(f"Completion: \n\n{response_text}")
+            return response_text
 
     async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
         args = {**self.default_args, **kwargs}
@@ -86,34 +85,33 @@ class ProxyServer(LLM):
             args["model"], messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
         self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
-        async with self._client_session as session:
-            async with session.post(f"{SERVER_URL}/stream_chat", json={
-                "messages": messages,
-                **args
-            }, headers=self.get_headers()) as resp:
-                # This is streaming application/json instead of text/event-stream
-                completion = ""
-                if resp.status != 200:
-                    raise Exception(await resp.text())
-                async for line in resp.content.iter_chunks():
-                    if line[1]:
-                        try:
-                            json_chunk = line[0].decode("utf-8")
-                            json_chunk = "{}" if json_chunk == "" else json_chunk
-                            chunks = json_chunk.split("\n")
-                            for chunk in chunks:
-                                if chunk.strip() != "":
-                                    loaded_chunk = json.loads(chunk)
-                                    yield loaded_chunk
-                                    if "content" in loaded_chunk:
-                                        completion += loaded_chunk["content"]
-                        except Exception as e:
-                            posthog_logger.capture_event(self.unique_id, "proxy_server_parse_error", {
-                                "error_title": "Proxy server stream_chat parsing failed", "error_message": '\n'.join(traceback.format_exception(e))})
-                    else:
-                        break
-
-        self.write_log(f"Completion: \n\n{completion}")
+        async with self._client_session.post(f"{SERVER_URL}/stream_chat", json={
+            "messages": messages,
+            **args
+        }, headers=self.get_headers()) as resp:
+            # This is streaming application/json instead of text/event-stream
+            completion = ""
+            if resp.status != 200:
+                raise Exception(await resp.text())
+            async for line in resp.content.iter_chunks():
+                if line[1]:
+                    try:
+                        json_chunk = line[0].decode("utf-8")
+                        json_chunk = "{}" if json_chunk == "" else json_chunk
+                        chunks = json_chunk.split("\n")
+                        for chunk in chunks:
+                            if chunk.strip() != "":
+                                loaded_chunk = json.loads(chunk)
+                                yield loaded_chunk
+                                if "content" in loaded_chunk:
+                                    completion += loaded_chunk["content"]
+                    except Exception as e:
+                        posthog_logger.capture_event(self.unique_id, "proxy_server_parse_error", {
+                            "error_title": "Proxy server stream_chat parsing failed", "error_message": '\n'.join(traceback.format_exception(e))})
+                else:
+                    break
+
+        self.write_log(f"Completion: \n\n{completion}")
 
     async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
         args = {**self.default_args, **kwargs}
@@ -121,20 +119,19 @@ class ProxyServer(LLM):
             self.model, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
         self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
-        async with self._client_session as session:
-            async with session.post(f"{SERVER_URL}/stream_complete", json={
-                "messages": messages,
-                **args
-            }, headers=self.get_headers()) as resp:
-                completion = ""
-                if resp.status != 200:
-                    raise Exception(await resp.text())
-                async for line in resp.content.iter_any():
-                    if line:
-                        try:
-                            decoded_line = line.decode("utf-8")
-                            yield decoded_line
-                            completion += decoded_line
-                        except:
-                            raise Exception(str(line))
-        self.write_log(f"Completion: \n\n{completion}")
+        async with self._client_session.post(f"{SERVER_URL}/stream_complete", json={
+            "messages": messages,
+            **args
+        }, headers=self.get_headers()) as resp:
+            completion = ""
+            if resp.status != 200:
+                raise Exception(await resp.text())
+            async for line in resp.content.iter_any():
+                if line:
+                    try:
+                        decoded_line = line.decode("utf-8")
+                        yield decoded_line
+                        completion += decoded_line
+                    except:
+                        raise Exception(str(line))
+        self.write_log(f"Completion: \n\n{completion}")
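The substance of the change is the same in all three methods: the old code entered the shared aiohttp.ClientSession itself as an async context manager (`async with self._client_session as session:`), which closes the session and its connection pool as soon as the block exits, so any later request on the same ProxyServer would fail against a closed session. The new code enters only the response returned by `.post(...)`, leaving the session open for reuse. A minimal sketch of the pattern, assuming aiohttp; the class name, endpoint, and URL below are illustrative placeholders, not the real ProxyServer API:

```python
# Minimal sketch of the session-reuse pattern, assuming aiohttp.
# "DemoClient" and SERVER_URL are illustrative placeholders.
import asyncio

import aiohttp

SERVER_URL = "http://localhost:8000"  # placeholder


class DemoClient:
    def __init__(self):
        # One long-lived session shared by every request. Entering it with
        # `async with self._client_session as session:` would close it (and
        # its connection pool) on exit, breaking all subsequent requests.
        self._client_session = aiohttp.ClientSession()

    async def complete(self, messages: list) -> str:
        # Enter only the response as a context manager; the session stays open.
        async with self._client_session.post(
            f"{SERVER_URL}/complete", json={"messages": messages}
        ) as resp:
            if resp.status != 200:
                raise Exception(await resp.text())
            return await resp.text()

    async def close(self):
        # The trade-off: the session must now be closed explicitly on shutdown.
        await self._client_session.close()


async def main():
    client = DemoClient()
    try:
        # Both calls reuse the same session and connection pool.
        print(await client.complete([{"role": "user", "content": "hi"}]))
        print(await client.complete([{"role": "user", "content": "again"}]))
    finally:
        await client.close()


if __name__ == "__main__":
    asyncio.run(main())
```

The cost of this style is lifecycle management: presumably ProxyServer creates `_client_session` during its own setup and is now responsible for closing it when it shuts down, since no `async with` will do so implicitly.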
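The stream_chat loop deserves a note. As its in-code comment says, the proxy streams newline-delimited JSON objects rather than text/event-stream, so one HTTP chunk may carry several JSON lines that must be split and parsed individually. A standalone sketch of that parsing, again assuming aiohttp; the helper name and endpoint are hypothetical, not part of ProxyServer:

```python
# Sketch of the newline-delimited JSON parsing used by stream_chat, assuming
# aiohttp. "stream_ndjson" and the url argument are hypothetical.
import json

import aiohttp


async def stream_ndjson(session: aiohttp.ClientSession, url: str, payload: dict):
    """Yield one parsed object per non-empty JSON line in the streamed body."""
    async with session.post(url, json=payload) as resp:
        if resp.status != 200:
            raise Exception(await resp.text())
        # iter_chunks() yields (data, end_of_http_chunk) pairs; like the diff,
        # this stops as soon as a chunk arrives incomplete.
        async for data, end_of_http_chunk in resp.content.iter_chunks():
            if not end_of_http_chunk:
                break
            for line in data.decode("utf-8").split("\n"):
                if line.strip():
                    yield json.loads(line)
```

A caller would consume this with `async for obj in stream_ndjson(...)`. Note that a JSON object split across two HTTP chunks would make json.loads raise; the diff handles that case by catching the exception and reporting it through posthog_logger instead of aborting the stream.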
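stream_complete takes the simpler route: resp.content.iter_any() yields raw bytes as soon as they arrive, and each chunk is decoded and yielded verbatim. One caveat with per-chunk bytes.decode("utf-8") is that a multi-byte character split across two chunks raises UnicodeDecodeError (which the bare except then re-raises as Exception(str(line))). A sketch of an incremental-decoder alternative follows; this is an option under the same aiohttp assumption, not what the diff does, and the helper name is hypothetical:

```python
# Alternative sketch: decode a byte stream incrementally so that multi-byte
# UTF-8 characters split across chunks do not raise. "stream_text" is
# hypothetical, not ProxyServer API.
import codecs

import aiohttp


async def stream_text(session: aiohttp.ClientSession, url: str, payload: dict):
    decoder = codecs.getincrementaldecoder("utf-8")()
    async with session.post(url, json=payload) as resp:
        if resp.status != 200:
            raise Exception(await resp.text())
        async for chunk in resp.content.iter_any():
            # The incremental decoder buffers a trailing partial character
            # until the next chunk completes it, instead of raising mid-stream.
            text = decoder.decode(chunk)
            if text:
                yield text
```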