author     Nate Sesti <sestinj@gmail.com>    2023-08-20 17:21:00 -0700
committer  Nate Sesti <sestinj@gmail.com>    2023-08-20 17:21:00 -0700
commit     84ec574e182ec441e95d13c3543a934e0a036228 (patch)
tree       9e058eb72fbd8de31dcf221175538f4056c38f5a /continuedev
parent     df86a15774716f7f7e0903f4917eb284708a5556 (diff)
fix: :bug: fix replicate to work with models requiring prompt input
Diffstat (limited to 'continuedev')
 continuedev/src/continuedev/libs/llm/replicate.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)
diff --git a/continuedev/src/continuedev/libs/llm/replicate.py b/continuedev/src/continuedev/libs/llm/replicate.py
index c4373185..0424d827 100644
--- a/continuedev/src/continuedev/libs/llm/replicate.py
+++ b/continuedev/src/continuedev/libs/llm/replicate.py
@@ -10,7 +10,7 @@ from . import LLM
 
 class ReplicateLLM(LLM):
     api_key: str
-    model: str = "nateraw/stablecode-completion-alpha-3b-4k:e82ebe958f0a5be6846d1a82041925767edb1d1f162596c643e48fbea332b1bb"
+    model: str = "replicate/llama-2-70b-chat:58d078176e02c219e11eb4da5a02a7830a283b14cf8f94537af893ccff5ee781"
     max_context_length: int = 2048
 
     _client: replicate.Client = None
@@ -40,7 +40,9 @@ class ReplicateLLM(LLM):
         self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
     ):
         def helper():
-            output = self._client.run(self.model, input={"message": prompt})
+            output = self._client.run(
+                self.model, input={"message": prompt, "prompt": prompt}
+            )
             completion = ""
             for item in output:
                 completion += item
@@ -56,11 +58,14 @@ class ReplicateLLM(LLM):
     async def stream_complete(
         self, prompt, with_history: List[ChatMessage] = None, **kwargs
     ):
-        for item in self._client.run(self.model, input={"message": prompt}):
+        for item in self._client.run(
+            self.model, input={"message": prompt, "prompt": prompt}
+        ):
             yield item
 
     async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs):
         for item in self._client.run(
-            self.model, input={"message": messages[-1].content}
+            self.model,
+            input={"message": messages[-1].content, "prompt": messages[-1].content},
         ):
             yield {"content": item, "role": "assistant"}
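
Context for the patch: each Replicate model version declares its own input schema, and the new default, llama-2-70b-chat, names its text field "prompt", while the code previously sent only "message". The fix sends the text under both keys so either field name is satisfied, on the assumption that a model ignores a key its schema does not declare. Below is a minimal standalone sketch of the fixed call pattern, assuming the replicate Python package is installed; the API token value and the prompt text are placeholders, and the model identifier is the one from the diff:

    # Minimal sketch of the fixed call pattern, assuming the `replicate`
    # package is installed and the api_token is valid (placeholder here).
    import replicate

    client = replicate.Client(api_token="r8_placeholder")

    model = (
        "replicate/llama-2-70b-chat:"
        "58d078176e02c219e11eb4da5a02a7830a283b14cf8f94537af893ccff5ee781"
    )

    text = "Write a one-line docstring for a bubble sort."  # placeholder prompt

    # Send the text under both keys so that models whose schema expects
    # "prompt" (like llama-2-70b-chat) and models expecting "message"
    # both receive it. client.run yields string chunks for streaming models.
    completion = ""
    for item in client.run(model, input={"message": text, "prompt": text}):
        completion += item
    print(completion)

Sending both keys avoids branching on the configured model, which is why the patch applies the same change to complete, stream_complete, and stream_chat rather than renaming the field outright.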