author    Nate Sesti <33237525+sestinj@users.noreply.github.com>  2023-07-31 13:23:48 -0700
committer GitHub <noreply@github.com>                             2023-07-31 13:23:48 -0700
commit    b9b965614f26eaf716c0a2d7ffa6f4aab52eefa9 (patch)
tree      2fecad4589d4cbe1bdd2df2d6ed8630d28312dd4
parent    457c9940ec6bdabd89de84a23abbf246aaf662c4 (diff)
parent    72d18fb8aaac9d192a508cd54fdb296321972379 (diff)
Merge pull request #330 from continuedev/ollama
Llama-2 support with Ollama
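For context, Ollama's /api/generate endpoint (used by the new ollama.py below) streams newline-delimited JSON objects whose "response" field carries the generated text. A minimal sketch of the same request outside Continue, assuming a local `ollama serve` on the default port from the diff; it uses the synchronous requests library for brevity, whereas the diff itself uses aiohttp:

    # Sketch only: mirrors the endpoint and payload shape used in ollama.py below.
    import json
    import requests

    resp = requests.post(
        "http://localhost:11434/api/generate",
        json={"model": "llama2", "prompt": "[INST] Hello [/INST]"},
        stream=True,
    )
    for line in resp.iter_lines():
        if line.strip():
            chunk = json.loads(line)
            if "response" in chunk:
                print(chunk["response"], end="")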
-rw-r--r--  continuedev/src/continuedev/core/autopilot.py    4
-rw-r--r--  continuedev/src/continuedev/core/sdk.py          1
-rw-r--r--  continuedev/src/continuedev/libs/llm/ollama.py   139
-rw-r--r--  extension/package-lock.json                      4
-rw-r--r--  extension/package.json                           2
5 files changed, 145 insertions(+), 5 deletions(-)
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index de95a259..d92c51cd 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -98,8 +98,8 @@ class Autopilot(ContinueBaseModel):
             user_input_queue=self._main_user_input_queue,
             slash_commands=self.get_available_slash_commands(),
             adding_highlighted_code=self.context_manager.context_providers[
-                "code"].adding_highlighted_code,
-            selected_context_items=await self.context_manager.get_selected_items()
+                "code"].adding_highlighted_code if self.context_manager is not None else False,
+            selected_context_items=await self.context_manager.get_selected_items() if self.context_manager is not None else [],
         )
         self.full_state = full_state
         return full_state
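The autopilot.py hunk above guards two FullState fields against a context_manager that has not been constructed yet, substituting False and [] as defaults. A minimal sketch of the pattern, with hypothetical names:

    # Hypothetical reduction of the guard added above: fall back to a safe
    # default while an optional collaborator is still None during startup.
    from typing import Optional

    class ContextManager:
        adding_highlighted_code: bool = True

    class Autopilot:
        context_manager: Optional[ContextManager] = None

        def snapshot(self) -> dict:
            cm = self.context_manager
            return {"adding_highlighted_code": cm.adding_highlighted_code if cm is not None else False}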
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index bf22d696..a5b16168 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -40,6 +40,7 @@ class ContinueSDK(AbstractContinueSDK):
     @classmethod
     async def create(cls, autopilot: Autopilot) -> "ContinueSDK":
         sdk = ContinueSDK(autopilot)
+        autopilot.continue_sdk = sdk
         try:
             config = sdk._load_config_dot_py()
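The sdk.py hunk above wires a back-reference from the Autopilot to the SDK its async factory creates, so either object can reach the other. The pattern in isolation (names hypothetical except continue_sdk; the real create() goes on to load config.py):

    # Sketch of the factory-plus-backlink pattern from the hunk above.
    class Autopilot:
        continue_sdk: "ContinueSDK" = None

    class ContinueSDK:
        def __init__(self, autopilot: Autopilot):
            self.autopilot = autopilot

        @classmethod
        async def create(cls, autopilot: Autopilot) -> "ContinueSDK":
            sdk = cls(autopilot)
            autopilot.continue_sdk = sdk  # the line this commit adds
            return sdk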
diff --git a/continuedev/src/continuedev/libs/llm/ollama.py b/continuedev/src/continuedev/libs/llm/ollama.py
new file mode 100644
index 00000000..a9f9f7aa
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/ollama.py
@@ -0,0 +1,139 @@
+from functools import cached_property
+import json
+from typing import Any, Coroutine, Dict, Generator, List, Union
+
+import aiohttp
+from ...core.main import ChatMessage
+from ..llm import LLM
+from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens
+
+
+class Ollama(LLM):
+    model: str = "llama2"
+    server_url: str = "http://localhost:11434"
+    max_context_length: int = 2048
+
+    _client_session: aiohttp.ClientSession = None
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    async def start(self, **kwargs):
+        self._client_session = aiohttp.ClientSession()
+
+    async def stop(self):
+        await self._client_session.close()
+
+    @property
+    def name(self):
+        return self.model
+
+    @property
+    def context_length(self) -> int:
+        return self.max_context_length
+
+    @property
+    def default_args(self):
+        return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
+
+    def count_tokens(self, text: str):
+        return count_tokens(self.name, text)
+
+    def convert_to_chat(self, msgs: List[ChatMessage]) -> str:
+        if len(msgs) == 0:
+            return ""
+
+        prompt = ""
+        has_system = msgs[0]["role"] == "system"
+        if has_system:
+            system_message = f"""\
+<<SYS>>
+{self.system_message}
+<</SYS>>
+
+"""
+            if len(msgs) > 1:
+                prompt += f"[INST] {system_message}{msgs[1]['content']} [/INST]"
+            else:
+                prompt += f"[INST] {system_message} [/INST]"
+                return prompt
+
+        for i in range(2 if has_system else 0, len(msgs)):
+            if msgs[i]["role"] == "user":
+                prompt += f"[INST] {msgs[i]['content']} [/INST]"
+            else:
+                prompt += msgs[i]['content']
+
+        return prompt
+
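convert_to_chat above renders chat history in the Llama-2 prompt format. As an illustration only (not part of the diff), a system message plus a single user turn would yield roughly:

    # Illustrative input for convert_to_chat, assuming system_message="Be brief."
    msgs = [
        {"role": "system", "content": "Be brief."},
        {"role": "user", "content": "What is Ollama?"},
    ]
    # Expected prompt string, per the [INST] / <<SYS>> template above:
    # [INST] <<SYS>>
    # Be brief.
    # <</SYS>>
    #
    # What is Ollama? [/INST]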
+    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        args = {**self.default_args, **kwargs}
+        messages = compile_chat_messages(
+            self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+        prompt = self.convert_to_chat(messages)
+
+        async with self._client_session.post(f"{self.server_url}/api/generate", json={
+            "prompt": prompt,
+            "model": self.model,
+        }) as resp:
+            async for line in resp.content.iter_any():
+                if line:
+                    try:
+                        json_chunk = line.decode("utf-8")
+                        chunks = json_chunk.split("\n")
+                        for chunk in chunks:
+                            if chunk.strip() != "":
+                                j = json.loads(chunk)
+                                if "response" in j:
+                                    yield j["response"]
+                    except Exception:
+                        raise Exception(f"Error parsing Ollama response: {line}")
+
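stream_complete above treats Ollama's reply as newline-delimited JSON rather than text/event-stream. The same parsing as a self-contained sketch (helper name hypothetical); like the diff, it assumes a JSON object never straddles two network reads:

    import json
    from typing import List

    def parse_ndjson_chunk(raw: bytes) -> List[str]:
        # A chunk may carry several newline-separated JSON objects; collect
        # the "response" field of each, as the streaming methods here do.
        tokens = []
        for part in raw.decode("utf-8").split("\n"):
            if part.strip():
                obj = json.loads(part)
                if "response" in obj:
                    tokens.append(obj["response"])
        return tokens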
+    async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        args = {**self.default_args, **kwargs}
+        messages = compile_chat_messages(
+            self.name, messages, self.context_length, args["max_tokens"], None, functions=None, system_message=self.system_message)
+        prompt = self.convert_to_chat(messages)
+
+        async with self._client_session.post(f"{self.server_url}/api/generate", json={
+            "prompt": prompt,
+            "model": self.model,
+        }) as resp:
+            # This is streaming application/json instead of text/event-stream
+            async for line in resp.content.iter_chunks():
+                if line[1]:
+                    try:
+                        json_chunk = line[0].decode("utf-8")
+                        chunks = json_chunk.split("\n")
+                        for chunk in chunks:
+                            if chunk.strip() != "":
+                                j = json.loads(chunk)
+                                if "response" in j:
+                                    yield {
+                                        "role": "assistant",
+                                        "content": j["response"]
+                                    }
+                    except Exception:
+                        raise Exception(f"Error parsing Ollama response: {line[0]}")
+
+    async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
+        completion = ""
+
+        async with self._client_session.post(f"{self.server_url}/api/generate", json={
+            "prompt": prompt,
+            "model": self.model,
+        }) as resp:
+            async for line in resp.content.iter_any():
+                if line:
+                    try:
+                        json_chunk = line.decode("utf-8")
+                        chunks = json_chunk.split("\n")
+                        for chunk in chunks:
+                            if chunk.strip() != "":
+                                j = json.loads(chunk)
+                                if "response" in j:
+                                    completion += j["response"]
+                    except Exception:
+                        raise Exception(f"Error parsing Ollama response: {line}")
+
+        return completion
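Taken together, the new class could be exercised roughly as follows (a sketch: Continue's own plumbing normally calls start()/stop(), and a local Ollama server must already be serving llama2):

    import asyncio

    async def main():
        llm = Ollama(model="llama2", server_url="http://localhost:11434")
        await llm.start()
        try:
            async for token in llm.stream_complete("Write a haiku about llamas."):
                print(token, end="", flush=True)
        finally:
            await llm.stop()

    asyncio.run(main())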
diff --git a/extension/package-lock.json b/extension/package-lock.json
index 4c0b7093..2ab3ad94 100644
--- a/extension/package-lock.json
+++ b/extension/package-lock.json
@@ -1,12 +1,12 @@
 {
   "name": "continue",
-  "version": "0.0.228",
+  "version": "0.0.229",
   "lockfileVersion": 2,
   "requires": true,
   "packages": {
     "": {
       "name": "continue",
-      "version": "0.0.228",
+      "version": "0.0.229",
       "license": "Apache-2.0",
       "dependencies": {
         "@electron/rebuild": "^3.2.10",
diff --git a/extension/package.json b/extension/package.json
index df54bb4f..481fbdd9 100644
--- a/extension/package.json
+++ b/extension/package.json
@@ -14,7 +14,7 @@
"displayName": "Continue",
"pricing": "Free",
"description": "The open-source coding autopilot",
- "version": "0.0.228",
+ "version": "0.0.229",
"publisher": "Continue",
"engines": {
"vscode": "^1.67.0"