summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Nate Sesti <sestinj@gmail.com> 2023-08-20 17:21:00 -0700
committer Nate Sesti <sestinj@gmail.com> 2023-08-20 17:21:00 -0700
commit 84ec574e182ec441e95d13c3543a934e0a036228 (patch)
tree 9e058eb72fbd8de31dcf221175538f4056c38f5a
parent df86a15774716f7f7e0903f4917eb284708a5556 (diff)
download sncontinue-84ec574e182ec441e95d13c3543a934e0a036228.tar.gz
sncontinue-84ec574e182ec441e95d13c3543a934e0a036228.tar.bz2
sncontinue-84ec574e182ec441e95d13c3543a934e0a036228.zip
fix: :bug: fix replicate to work with models requiring prompt input
-rw-r--r-- continuedev/src/continuedev/libs/llm/replicate.py 13
-rw-r--r-- docs/docs/customization.md 6
2 files changed, 12 insertions, 7 deletions
diff --git a/continuedev/src/continuedev/libs/llm/replicate.py b/continuedev/src/continuedev/libs/llm/replicate.py
index c4373185..0424d827 100644
--- a/continuedev/src/continuedev/libs/llm/replicate.py
+++ b/continuedev/src/continuedev/libs/llm/replicate.py
@@ -10,7 +10,7 @@ from . import LLM
class ReplicateLLM(LLM):
api_key: str
- model: str = "nateraw/stablecode-completion-alpha-3b-4k:e82ebe958f0a5be6846d1a82041925767edb1d1f162596c643e48fbea332b1bb"
+ model: str = "replicate/llama-2-70b-chat:58d078176e02c219e11eb4da5a02a7830a283b14cf8f94537af893ccff5ee781"
max_context_length: int = 2048
_client: replicate.Client = None
@@ -40,7 +40,9 @@ class ReplicateLLM(LLM):
self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
):
def helper():
- output = self._client.run(self.model, input={"message": prompt})
+ output = self._client.run(
+ self.model, input={"message": prompt, "prompt": prompt}
+ )
completion = ""
for item in output:
completion += item
@@ -56,11 +58,14 @@ class ReplicateLLM(LLM):
async def stream_complete(
self, prompt, with_history: List[ChatMessage] = None, **kwargs
):
- for item in self._client.run(self.model, input={"message": prompt}):
+ for item in self._client.run(
+ self.model, input={"message": prompt, "prompt": prompt}
+ ):
yield item
async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs):
for item in self._client.run(
- self.model, input={"message": messages[-1].content}
+ self.model,
+ input={"message": messages[-1].content, "prompt": messages[-1].content},
):
yield {"content": item, "role": "assistant"}
diff --git a/docs/docs/customization.md b/docs/docs/customization.md
index a3774e91..b7279fe3 100644
--- a/docs/docs/customization.md
+++ b/docs/docs/customization.md
@@ -117,7 +117,7 @@ config = ContinueConfig(
)
```
-### Replicate (beta)
+### Replicate
Replicate is a great option for newly released language models or models that you've deployed through their platform. Sign up for an account [here](https://replicate.ai/), copy your API key, and then select any model from the [Replicate Streaming List](https://replicate.com/collections/streaming-language-models). Change `~/.continue/config.py` to look like this:
@@ -129,13 +129,13 @@ config = ContinueConfig(
...
models=Models(
default=ReplicateLLM(
- model="stablecode-completion-alpha-3b-4k",
+ model="replicate/llama-2-70b-chat:58d078176e02c219e11eb4da5a02a7830a283b14cf8f94537af893ccff5ee781",
api_key="my-replicate-api-key")
)
)
```
-If you don't specify the `model` parameter, it will default to `stablecode-completion-alpha-3b-4k`.
+If you don't specify the `model` parameter, it will default to `replicate/llama-2-70b-chat:58d078176e02c219e11eb4da5a02a7830a283b14cf8f94537af893ccff5ee781`.
### Self-hosting an open-source model