diff options
author | Nate Sesti <sestinj@gmail.com> | 2023-08-16 11:56:57 -0700 |
---|---|---|
committer | Nate Sesti <sestinj@gmail.com> | 2023-08-16 11:56:57 -0700 |
commit | 72e6956f2dce9ca343e68a51e4d07eb5d8bd701a (patch) | |
tree | a61a04b519954ff47fa4c3e102bf91b3c127f97e | |
parent | 121fa902f8c53f536a8ca586a8db10e0c54fdf27 (diff) | |
download | sncontinue-72e6956f2dce9ca343e68a51e4d07eb5d8bd701a.tar.gz sncontinue-72e6956f2dce9ca343e68a51e4d07eb5d8bd701a.tar.bz2 sncontinue-72e6956f2dce9ca343e68a51e4d07eb5d8bd701a.zip |
fix together.py for llama-70b
-rw-r--r-- | continuedev/src/continuedev/libs/llm/together.py | 26 |
1 files changed, 15 insertions, 11 deletions
diff --git a/continuedev/src/continuedev/libs/llm/together.py b/continuedev/src/continuedev/libs/llm/together.py index c3f171c9..874dea07 100644 --- a/continuedev/src/continuedev/libs/llm/together.py +++ b/continuedev/src/continuedev/libs/llm/together.py @@ -89,19 +89,20 @@ class TogetherLLM(LLM): }) as resp: async for line in resp.content.iter_chunks(): if line[1]: - try: - json_chunk = line[0].decode("utf-8") - if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"): - continue - chunks = json_chunk.split("\n") - for chunk in chunks: - if chunk.strip() != "": + json_chunk = line[0].decode("utf-8") + if json_chunk.startswith(": ping - ") or json_chunk.startswith("data: [DONE]"): + continue + if json_chunk.startswith("data: "): + json_chunk = json_chunk[6:] + chunks = json_chunk.split("\n") + for chunk in chunks: + if chunk.strip() != "": + json_chunk = json.loads(chunk) + if "choices" in json_chunk: yield { "role": "assistant", - "content": json.loads(chunk[6:])["choices"][0]["text"] + "content": json_chunk["choices"][0]["text"] } - except: - raise Exception(str(line[0])) async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: args = {**self.default_args, **kwargs} @@ -117,6 +118,9 @@ class TogetherLLM(LLM): try: text = await resp.text() j = json.loads(text) - return j["output"]["choices"][0]["text"] + if "choices" not in j["output"]: + raise Exception(text) + if "output" in j: + return j["output"]["choices"][0]["text"] except: raise Exception(await resp.text()) |