-rw-r--r--  continuedev/src/continuedev/libs/llm/proxy_server.py  109
1 file changed, 53 insertions, 56 deletions
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index 56b123db..44734b1c 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -68,17 +68,16 @@ class ProxyServer(LLM):
         messages = compile_chat_messages(
             args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
         self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
-        async with self._client_session as session:
-            async with session.post(f"{SERVER_URL}/complete", json={
-                "messages": messages,
-                **args
-            }, headers=self.get_headers()) as resp:
-                if resp.status != 200:
-                    raise Exception(await resp.text())
-
-                response_text = await resp.text()
-                self.write_log(f"Completion: \n\n{response_text}")
-                return response_text
+        async with self._client_session.post(f"{SERVER_URL}/complete", json={
+            "messages": messages,
+            **args
+        }, headers=self.get_headers()) as resp:
+            if resp.status != 200:
+                raise Exception(await resp.text())
+
+            response_text = await resp.text()
+            self.write_log(f"Completion: \n\n{response_text}")
+            return response_text
 
     async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
         args = {**self.default_args, **kwargs}
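
Why the nested `async with` is being removed: `aiohttp.ClientSession` closes itself when used as an async context manager, so `async with self._client_session as session:` would close the shared session as soon as the first request finished. The new form posts on the session directly and leaves its lifetime to whoever owns it. A minimal sketch of the failure mode and the fix, assuming a long-lived `aiohttp.ClientSession` (the URL is a placeholder, not the real proxy endpoint):

    import asyncio
    import aiohttp

    async def main() -> None:
        session = aiohttp.ClientSession()

        # Old pattern: "async with session" calls session.close() on exit,
        # leaving the session unusable for any later request.
        async with session:
            async with session.post("http://localhost:8000/complete", json={}) as resp:
                await resp.text()
        print(session.closed)  # True -- a second post() would raise

        # New pattern (as in this diff): post on the session directly,
        # and close it exactly once when the client shuts down.
        session = aiohttp.ClientSession()
        async with session.post("http://localhost:8000/complete", json={}) as resp:
            await resp.text()
        await session.close()

    asyncio.run(main())
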
@@ -86,34 +85,33 @@ class ProxyServer(LLM):
args["model"], messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
- async with self._client_session as session:
- async with session.post(f"{SERVER_URL}/stream_chat", json={
- "messages": messages,
- **args
- }, headers=self.get_headers()) as resp:
- # This is streaming application/json instaed of text/event-stream
- completion = ""
- if resp.status != 200:
- raise Exception(await resp.text())
- async for line in resp.content.iter_chunks():
- if line[1]:
- try:
- json_chunk = line[0].decode("utf-8")
- json_chunk = "{}" if json_chunk == "" else json_chunk
- chunks = json_chunk.split("\n")
- for chunk in chunks:
- if chunk.strip() != "":
- loaded_chunk = json.loads(chunk)
- yield loaded_chunk
- if "content" in loaded_chunk:
- completion += loaded_chunk["content"]
- except Exception as e:
- posthog_logger.capture_event(self.unique_id, "proxy_server_parse_error", {
- "error_title": "Proxy server stream_chat parsing failed", "error_message": '\n'.join(traceback.format_exception(e))})
- else:
- break
-
- self.write_log(f"Completion: \n\n{completion}")
+ async with self._client_session.post(f"{SERVER_URL}/stream_chat", json={
+ "messages": messages,
+ **args
+ }, headers=self.get_headers()) as resp:
+ # This is streaming application/json instaed of text/event-stream
+ completion = ""
+ if resp.status != 200:
+ raise Exception(await resp.text())
+ async for line in resp.content.iter_chunks():
+ if line[1]:
+ try:
+ json_chunk = line[0].decode("utf-8")
+ json_chunk = "{}" if json_chunk == "" else json_chunk
+ chunks = json_chunk.split("\n")
+ for chunk in chunks:
+ if chunk.strip() != "":
+ loaded_chunk = json.loads(chunk)
+ yield loaded_chunk
+ if "content" in loaded_chunk:
+ completion += loaded_chunk["content"]
+ except Exception as e:
+ posthog_logger.capture_event(self.unique_id, "proxy_server_parse_error", {
+ "error_title": "Proxy server stream_chat parsing failed", "error_message": '\n'.join(traceback.format_exception(e))})
+ else:
+ break
+
+ self.write_log(f"Completion: \n\n{completion}")
async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = {**self.default_args, **kwargs}
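
On the parsing loop in the hunk above: the proxy streams newline-delimited JSON (per the in-code comment, `application/json` instead of `text/event-stream`), and `resp.content.iter_chunks()` yields `(data, end_of_http_chunk)` pairs, which is what the `line[1]` check inspects before decoding `line[0]`. A standalone sketch of the same loop, with `url` and `payload` as placeholders:

    import json
    import aiohttp

    async def stream_ndjson(session: aiohttp.ClientSession, url: str, payload: dict):
        """Yield parsed objects from a newline-delimited JSON stream."""
        async with session.post(url, json=payload) as resp:
            if resp.status != 200:
                raise Exception(await resp.text())
            async for data, end_of_http_chunk in resp.content.iter_chunks():
                if not end_of_http_chunk:
                    break  # mirrors the diff's control flow on a partial chunk
                for piece in data.decode("utf-8").split("\n"):
                    if piece.strip():
                        yield json.loads(piece)

A more defensive consumer would buffer a partial chunk and keep reading rather than break, but the sketch keeps the diff's behavior.
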
@@ -121,20 +119,19 @@ class ProxyServer(LLM):
         messages = compile_chat_messages(
             self.model, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
         self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")
-        async with self._client_session as session:
-            async with session.post(f"{SERVER_URL}/stream_complete", json={
-                "messages": messages,
-                **args
-            }, headers=self.get_headers()) as resp:
-                completion = ""
-                if resp.status != 200:
-                    raise Exception(await resp.text())
-                async for line in resp.content.iter_any():
-                    if line:
-                        try:
-                            decoded_line = line.decode("utf-8")
-                            yield decoded_line
-                            completion += decoded_line
-                        except:
-                            raise Exception(str(line))
-                self.write_log(f"Completion: \n\n{completion}")
+        async with self._client_session.post(f"{SERVER_URL}/stream_complete", json={
+            "messages": messages,
+            **args
+        }, headers=self.get_headers()) as resp:
+            completion = ""
+            if resp.status != 200:
+                raise Exception(await resp.text())
+            async for line in resp.content.iter_any():
+                if line:
+                    try:
+                        decoded_line = line.decode("utf-8")
+                        yield decoded_line
+                        completion += decoded_line
+                    except:
+                        raise Exception(str(line))
+            self.write_log(f"Completion: \n\n{completion}")
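
All three hunks drop per-request teardown of the shared session, but the diff does not show where `self._client_session` is created or closed. One plausible lifecycle under that assumption; the class and method names below are hypothetical, not from this file:

    from typing import Optional

    import aiohttp

    class StreamingClient:
        """Hypothetical owner of a shared aiohttp.ClientSession."""

        def __init__(self) -> None:
            self._client_session: Optional[aiohttp.ClientSession] = None

        async def start(self) -> None:
            # Create the session once, when the client starts up.
            self._client_session = aiohttp.ClientSession()

        async def complete(self, url: str, payload: dict) -> str:
            # Per-request: use the session directly, as the diff now does.
            async with self._client_session.post(url, json=payload) as resp:
                if resp.status != 200:
                    raise Exception(await resp.text())
                return await resp.text()

        async def stop(self) -> None:
            # Close the session exactly once, on shutdown.
            if self._client_session is not None:
                await self._client_session.close()
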