author    Nate Sesti <sestinj@gmail.com>    2023-08-20 17:21:00 -0700
committer Nate Sesti <sestinj@gmail.com>    2023-08-20 17:21:00 -0700
commit    84ec574e182ec441e95d13c3543a934e0a036228
tree      9e058eb72fbd8de31dcf221175538f4056c38f5a /continuedev
parent    df86a15774716f7f7e0903f4917eb284708a5556
fix: :bug: fix replicate to work with models requiring prompt input
Diffstat (limited to 'continuedev')
 continuedev/src/continuedev/libs/llm/replicate.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)
diff --git a/continuedev/src/continuedev/libs/llm/replicate.py b/continuedev/src/continuedev/libs/llm/replicate.py
index c4373185..0424d827 100644
--- a/continuedev/src/continuedev/libs/llm/replicate.py
+++ b/continuedev/src/continuedev/libs/llm/replicate.py
@@ -10,7 +10,7 @@ from . import LLM
 
 class ReplicateLLM(LLM):
     api_key: str
-    model: str = "nateraw/stablecode-completion-alpha-3b-4k:e82ebe958f0a5be6846d1a82041925767edb1d1f162596c643e48fbea332b1bb"
+    model: str = "replicate/llama-2-70b-chat:58d078176e02c219e11eb4da5a02a7830a283b14cf8f94537af893ccff5ee781"
     max_context_length: int = 2048
 
     _client: replicate.Client = None
@@ -40,7 +40,9 @@ class ReplicateLLM(LLM):
         self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
     ):
         def helper():
-            output = self._client.run(self.model, input={"message": prompt})
+            output = self._client.run(
+                self.model, input={"message": prompt, "prompt": prompt}
+            )
             completion = ""
             for item in output:
                 completion += item
@@ -56,11 +58,14 @@
     async def stream_complete(
         self, prompt, with_history: List[ChatMessage] = None, **kwargs
    ):
-        for item in self._client.run(self.model, input={"message": prompt}):
+        for item in self._client.run(
+            self.model, input={"message": prompt, "prompt": prompt}
+        ):
             yield item
 
     async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs):
         for item in self._client.run(
-            self.model, input={"message": messages[-1].content}
+            self.model,
+            input={"message": messages[-1].content, "prompt": messages[-1].content},
         ):
             yield {"content": item, "role": "assistant"}