-rw-r--r--  continuedev/src/continuedev/core/autopilot.py  |   4 +-
-rw-r--r--  continuedev/src/continuedev/core/sdk.py        |   1 +
-rw-r--r--  continuedev/src/continuedev/libs/llm/ollama.py | 139 ++++++++++++++++++
-rw-r--r--  extension/package-lock.json                    |   4 +-
-rw-r--r--  extension/package.json                         |   2 +-
5 files changed, 145 insertions(+), 5 deletions(-)
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index de95a259..d92c51cd 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -98,8 +98,8 @@ class Autopilot(ContinueBaseModel):
             user_input_queue=self._main_user_input_queue,
             slash_commands=self.get_available_slash_commands(),
             adding_highlighted_code=self.context_manager.context_providers[
-                "code"].adding_highlighted_code,
-            selected_context_items=await self.context_manager.get_selected_items()
+                "code"].adding_highlighted_code if self.context_manager is not None else False,
+            selected_context_items=await self.context_manager.get_selected_items() if self.context_manager is not None else [],
         )
         self.full_state = full_state
         return full_state
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index bf22d696..a5b16168 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -40,6 +40,7 @@ class ContinueSDK(AbstractContinueSDK):
     @classmethod
     async def create(cls, autopilot: Autopilot) -> "ContinueSDK":
         sdk = ContinueSDK(autopilot)
+        autopilot.continue_sdk = sdk

         try:
             config = sdk._load_config_dot_py()
diff --git a/continuedev/src/continuedev/libs/llm/ollama.py b/continuedev/src/continuedev/libs/llm/ollama.py
new file mode 100644
index 00000000..a9f9f7aa
--- /dev/null
+++ b/continuedev/src/continuedev/libs/llm/ollama.py
@@ -0,0 +1,139 @@
+from functools import cached_property
+import json
+from typing import Any, Coroutine, Dict, Generator, List, Union
+
+import aiohttp
+from ...core.main import ChatMessage
+from ..llm import LLM
+from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens
+
+
+class Ollama(LLM):
+    model: str = "llama2"
+    server_url: str = "http://localhost:11434"
+    max_context_length: int = 2048
+
+    _client_session: aiohttp.ClientSession = None
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    async def start(self, **kwargs):
+        self._client_session = aiohttp.ClientSession()
+
+    async def stop(self):
+        await self._client_session.close()
+
+    @property
+    def name(self):
+        return self.model
+
+    @property
+    def context_length(self) -> int:
+        return self.max_context_length
+
+    @property
+    def default_args(self):
+        return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
+
+    def count_tokens(self, text: str):
+        return count_tokens(self.name, text)
+
+    def convert_to_chat(self, msgs: List[ChatMessage]) -> str:
+        if len(msgs) == 0:
+            return ""
+
+        prompt = ""
+        has_system = msgs[0]["role"] == "system"
+        if has_system:
+            system_message = f"""\
+                <<SYS>>
+                {self.system_message}
+                <</SYS>>
+
+                """
+            if len(msgs) > 1:
+                prompt += f"[INST] {system_message}{msgs[1]['content']} [/INST]"
+            else:
+                prompt += f"[INST] {system_message} [/INST]"
+                return prompt
+
+        for i in range(2 if has_system else 0, len(msgs)):
+            if msgs[i]["role"] == "user":
+                prompt += f"[INST] {msgs[i]['content']} [/INST]"
+            else:
+                prompt += msgs[i]['content']
+
+        return prompt
+
+    async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        args = {**self.default_args, **kwargs}
+        messages = compile_chat_messages(
+            self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
+        prompt = self.convert_to_chat(messages)
+
+        async with self._client_session.post(f"{self.server_url}/api/generate", json={
"prompt": prompt, + "model": self.model, + }) as resp: + async for line in resp.content.iter_any(): + if line: + try: + json_chunk = line.decode("utf-8") + chunks = json_chunk.split("\n") + for chunk in chunks: + if chunk.strip() != "": + j = json.loads(chunk) + if "response" in j: + yield j["response"] + except: + raise Exception(str(line[0])) + + async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: + args = {**self.default_args, **kwargs} + messages = compile_chat_messages( + self.name, messages, self.context_length, args["max_tokens"], None, functions=None, system_message=self.system_message) + prompt = self.convert_to_chat(messages) + + async with self._client_session.post(f"{self.server_url}/api/generate", json={ + "prompt": prompt, + "model": self.model, + }) as resp: + # This is streaming application/json instaed of text/event-stream + async for line in resp.content.iter_chunks(): + if line[1]: + try: + json_chunk = line[0].decode("utf-8") + chunks = json_chunk.split("\n") + for chunk in chunks: + if chunk.strip() != "": + j = json.loads(chunk) + if "response" in j: + yield { + "role": "assistant", + "content": j["response"] + } + except: + raise Exception(str(line[0])) + + async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: + completion = "" + + async with self._client_session.post(f"{self.server_url}/api/generate", json={ + "prompt": prompt, + "model": self.model, + }) as resp: + async for line in resp.content.iter_any(): + if line: + try: + json_chunk = line.decode("utf-8") + chunks = json_chunk.split("\n") + for chunk in chunks: + if chunk.strip() != "": + j = json.loads(chunk) + if "response" in j: + completion += j["response"] + except: + raise Exception(str(line[0])) + + return completion diff --git a/extension/package-lock.json b/extension/package-lock.json index 4c0b7093..2ab3ad94 100644 --- a/extension/package-lock.json +++ b/extension/package-lock.json @@ -1,12 +1,12 @@ { "name": "continue", - "version": "0.0.228", + "version": "0.0.229", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "continue", - "version": "0.0.228", + "version": "0.0.229", "license": "Apache-2.0", "dependencies": { "@electron/rebuild": "^3.2.10", diff --git a/extension/package.json b/extension/package.json index df54bb4f..481fbdd9 100644 --- a/extension/package.json +++ b/extension/package.json @@ -14,7 +14,7 @@ "displayName": "Continue", "pricing": "Free", "description": "The open-source coding autopilot", - "version": "0.0.228", + "version": "0.0.229", "publisher": "Continue", "engines": { "vscode": "^1.67.0" |