summary refs log tree commit diff
path: root/continuedev/src
diff options
context:
space:
mode:
Diffstat (limited to 'continuedev/src')
-rw-r--r--continuedev/src/continuedev/libs/llm/__init__.py3
-rw-r--r--continuedev/src/continuedev/libs/llm/replicate.py56
2 files changed, 59 insertions, 0 deletions
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 40edb99b..70c67856 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -12,6 +12,9 @@ class LLM(ContinueBaseModel, ABC):
system_message: Optional[str] = None
+ class Config:
+ arbitrary_types_allowed = True
+
@abstractproperty
def name(self):
"""Return the name of the LLM."""
diff --git a/continuedev/src/continuedev/libs/llm/replicate.py b/continuedev/src/continuedev/libs/llm/replicate.py
new file mode 100644
index 00000000..b13e2dec
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/replicate.py
@@ -0,0 +1,56 @@
+from abc import abstractproperty
+from typing import List, Optional
+import replicate
+import concurrent.futures
+
+from ..util.count_tokens import DEFAULT_ARGS, count_tokens
+from ...core.main import ChatMessage
+from . import LLM
+
+
class ReplicateLLM(LLM):
    """LLM implementation backed by the Replicate hosted-model API.

    Credentials are supplied via ``api_key``; ``start()`` must be awaited
    before any completion method is called so the client exists.
    """

    # Replicate API token used to authenticate the client.
    api_key: str
    # Model slug pinned to a specific version hash.
    model: str = "nateraw/stablecode-completion-alpha-3b-4k:e82ebe958f0a5be6846d1a82041925767edb1d1f162596c643e48fbea332b1bb"
    # Maximum context window, in tokens, reported to callers.
    max_context_length: int = 2048

    # Set by start(); remains None until then.
    _client: Optional[replicate.Client] = None

    @property
    def name(self):
        """Identifier used for token counting and default args."""
        return self.model

    @property
    def context_length(self):
        return self.max_context_length

    @property
    def default_args(self):
        return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}

    def count_tokens(self, text: str):
        return count_tokens(self.name, text)

    async def start(self):
        """Construct the Replicate client; must run before completions."""
        self._client = replicate.Client(api_token=self.api_key)

    async def stop(self):
        """Nothing to release; the client needs no explicit shutdown."""
        pass

    async def complete(self, prompt: str, with_history: Optional[List[ChatMessage]] = None, **kwargs):
        """Return the full completion for *prompt* as a single string.

        NOTE(review): ``with_history`` and ``kwargs`` are accepted for
        interface compatibility but currently ignored.
        NOTE(review): ``replicate.Client.run`` blocks inside this async
        method — consider ``run_in_executor`` to avoid stalling the
        event loop.
        """
        output = self._client.run(self.model, input={"message": prompt})
        # join is linear; repeated += concatenation is not guaranteed to be
        return ''.join(output)

    async def stream_complete(self, prompt, with_history: Optional[List[ChatMessage]] = None, **kwargs):
        """Yield completion chunks for *prompt* as they arrive.

        NOTE(review): iterating the Replicate stream is blocking; same
        event-loop caveat as complete().
        """
        for item in self._client.run(self.model, input={"message": prompt}):
            yield item

    async def stream_chat(self, messages: Optional[List[ChatMessage]] = None, **kwargs):
        """Yield chat deltas as assistant-role message dicts.

        NOTE(review): only the most recent message is sent to the model —
        earlier history is dropped; confirm this is intended for
        multi-turn chat.
        """
        for item in self._client.run(self.model, input={"message": messages[-1].content}):
            yield {
                "content": item,
                "role": "assistant"
            }