diff options
| author | Luna <git@l4.pm> | 2023-07-30 23:27:24 -0300 | 
|---|---|---|
| committer | Luna <git@l4.pm> | 2023-07-30 23:27:24 -0300 | 
| commit | 0174a769f10f5ea8b1ec06787fc75eca8c45a1f1 (patch) | |
| tree | 692805a8d253159d1b3351ace3ff9a172672cfaa /continuedev | |
| parent | 374058e07ca699b5a345b270067636f6785df3af (diff) | |
| download | sncontinue-0174a769f10f5ea8b1ec06787fc75eca8c45a1f1.tar.gz sncontinue-0174a769f10f5ea8b1ec06787fc75eca8c45a1f1.tar.bz2 sncontinue-0174a769f10f5ea8b1ec06787fc75eca8c45a1f1.zip | |
fix ProxyServer client session usage
Diffstat (limited to 'continuedev')
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/proxy_server.py | 109 | 
1 files changed, 53 insertions, 56 deletions
| diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py index 56b123db..44734b1c 100644 --- a/continuedev/src/continuedev/libs/llm/proxy_server.py +++ b/continuedev/src/continuedev/libs/llm/proxy_server.py @@ -68,17 +68,16 @@ class ProxyServer(LLM):          messages = compile_chat_messages(              args["model"], with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)          self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") -        async with self._client_session as session: -            async with session.post(f"{SERVER_URL}/complete", json={ -                "messages": messages, -                **args -            }, headers=self.get_headers()) as resp: -                if resp.status != 200: -                    raise Exception(await resp.text()) - -                response_text = await resp.text() -                self.write_log(f"Completion: \n\n{response_text}") -                return response_text +        async with self._client_session.post(f"{SERVER_URL}/complete", json={ +            "messages": messages, +            **args +        }, headers=self.get_headers()) as resp: +            if resp.status != 200: +                raise Exception(await resp.text()) + +            response_text = await resp.text() +            self.write_log(f"Completion: \n\n{response_text}") +            return response_text      async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:          args = {**self.default_args, **kwargs} @@ -86,34 +85,33 @@ class ProxyServer(LLM):              args["model"], messages, self.context_length, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)          self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") -        async with self._client_session as session: -            async with session.post(f"{SERVER_URL}/stream_chat", json={ -                "messages": messages, -                **args -            }, headers=self.get_headers()) as resp: -                # This is streaming application/json instaed of text/event-stream -                completion = "" -                if resp.status != 200: -                    raise Exception(await resp.text()) -                async for line in resp.content.iter_chunks(): -                    if line[1]: -                        try: -                            json_chunk = line[0].decode("utf-8") -                            json_chunk = "{}" if json_chunk == "" else json_chunk -                            chunks = json_chunk.split("\n") -                            for chunk in chunks: -                                if chunk.strip() != "": -                                    loaded_chunk = json.loads(chunk) -                                    yield loaded_chunk -                                    if "content" in loaded_chunk: -                                        completion += loaded_chunk["content"] -                        except Exception as e: -                            posthog_logger.capture_event(self.unique_id, "proxy_server_parse_error", { -                                "error_title": "Proxy server stream_chat parsing failed", "error_message": '\n'.join(traceback.format_exception(e))}) -                    else: -                        break - -                self.write_log(f"Completion: \n\n{completion}") +        async with self._client_session.post(f"{SERVER_URL}/stream_chat", json={ +            "messages": messages, +            **args +        }, headers=self.get_headers()) as resp: +            # This is streaming application/json instaed of text/event-stream +            completion = "" +            if resp.status != 200: +                raise Exception(await resp.text()) +            async for line in resp.content.iter_chunks(): +                if line[1]: +                    try: +                        json_chunk = line[0].decode("utf-8") +                        json_chunk = "{}" if json_chunk == "" else json_chunk +                        chunks = json_chunk.split("\n") +                        for chunk in chunks: +                            if chunk.strip() != "": +                                loaded_chunk = json.loads(chunk) +                                yield loaded_chunk +                                if "content" in loaded_chunk: +                                    completion += loaded_chunk["content"] +                    except Exception as e: +                        posthog_logger.capture_event(self.unique_id, "proxy_server_parse_error", { +                            "error_title": "Proxy server stream_chat parsing failed", "error_message": '\n'.join(traceback.format_exception(e))}) +                else: +                    break + +            self.write_log(f"Completion: \n\n{completion}")      async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:          args = {**self.default_args, **kwargs} @@ -121,20 +119,19 @@ class ProxyServer(LLM):              self.model, with_history, self.context_length, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)          self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}") -        async with self._client_session as session: -            async with session.post(f"{SERVER_URL}/stream_complete", json={ -                "messages": messages, -                **args -            }, headers=self.get_headers()) as resp: -                completion = "" -                if resp.status != 200: -                    raise Exception(await resp.text()) -                async for line in resp.content.iter_any(): -                    if line: -                        try: -                            decoded_line = line.decode("utf-8") -                            yield decoded_line -                            completion += decoded_line -                        except: -                            raise Exception(str(line)) -                self.write_log(f"Completion: \n\n{completion}") +        async with self._client_session.post(f"{SERVER_URL}/stream_complete", json={ +            "messages": messages, +            **args +        }, headers=self.get_headers()) as resp: +            completion = "" +            if resp.status != 200: +                raise Exception(await resp.text()) +            async for line in resp.content.iter_any(): +                if line: +                    try: +                        decoded_line = line.decode("utf-8") +                        yield decoded_line +                        completion += decoded_line +                    except: +                        raise Exception(str(line)) +            self.write_log(f"Completion: \n\n{completion}") | 
