-rw-r--r--  continuedev/src/continuedev/core/autopilot.py  |   4
-rw-r--r--  continuedev/src/continuedev/core/sdk.py        |   1
-rw-r--r--  continuedev/src/continuedev/libs/llm/ollama.py | 133
-rw-r--r--  extension/package-lock.json                    |   4
-rw-r--r--  extension/package.json                         |   2
5 files changed, 84 insertions, 60 deletions
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index de95a259..d92c51cd 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -98,8 +98,8 @@ class Autopilot(ContinueBaseModel):
             user_input_queue=self._main_user_input_queue,
             slash_commands=self.get_available_slash_commands(),
             adding_highlighted_code=self.context_manager.context_providers[
-                "code"].adding_highlighted_code,
-            selected_context_items=await self.context_manager.get_selected_items()
+                "code"].adding_highlighted_code if self.context_manager is not None else False,
+            selected_context_items=await self.context_manager.get_selected_items() if self.context_manager is not None else [],
         )
         self.full_state = full_state
         return full_state
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index bf22d696..a5b16168 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -40,6 +40,7 @@ class ContinueSDK(AbstractContinueSDK):
     @classmethod
     async def create(cls, autopilot: Autopilot) -> "ContinueSDK":
         sdk = ContinueSDK(autopilot)
+        autopilot.continue_sdk = sdk
 
         try:
             config = sdk._load_config_dot_py()
diff --git a/continuedev/src/continuedev/libs/llm/ollama.py b/continuedev/src/continuedev/libs/llm/ollama.py
index 86da4115..a9f9f7aa 100644
--- a/continuedev/src/continuedev/libs/llm/ollama.py
+++ b/continuedev/src/continuedev/libs/llm/ollama.py
@@ -11,13 +11,28 @@ from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_token
 class Ollama(LLM):
     model: str = "llama2"
     server_url: str = "http://localhost:11434"
-    max_context_length: int
+    max_context_length: int = 2048
 
-    @cached_property
+    _client_session: aiohttp.ClientSession = None
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    async def start(self, **kwargs):
+        self._client_session = aiohttp.ClientSession()
+
+    async def stop(self):
+        await self._client_session.close()
+
+    @property
     def name(self):
         return self.model
 
     @property
+    def context_length(self) -> int:
+        return self.max_context_length
+
+    @property
     def default_args(self):
         return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
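Two things happen in this hunk: the model gains an explicit start/stop lifecycle around a single shared `aiohttp.ClientSession` (one connection pool reused across requests, instead of a new session per call), and `Config.arbitrary_types_allowed` is set because pydantic has no validator for `ClientSession` and would otherwise reject the field when the class is defined. A minimal standalone sketch of the pattern, assuming pydantic v1 (which the repo used at the time); the public `session` field name is illustrative, where the diff uses the private attribute `_client_session`:

```python
import asyncio

import aiohttp
from pydantic import BaseModel


class SessionHolder(BaseModel):
    """Sketch of the pattern the hunk introduces: a pydantic model that
    owns one shared aiohttp.ClientSession with an explicit lifecycle."""

    server_url: str = "http://localhost:11434"
    # Illustrative public field; the diff stores this as _client_session.
    session: aiohttp.ClientSession = None

    class Config:
        # Without this, pydantic v1 fails at class-definition time with
        # "no validator found for <class 'aiohttp.client.ClientSession'>".
        arbitrary_types_allowed = True

    async def start(self):
        # One session, one connection pool, reused by every request.
        self.session = aiohttp.ClientSession()

    async def stop(self):
        # Close explicitly to avoid "Unclosed client session" warnings.
        await self.session.close()


async def main():
    holder = SessionHolder()
    await holder.start()
    try:
        print(type(holder.session).__name__)  # ClientSession
    finally:
        await holder.stop()


asyncio.run(main())
```

Reusing one session matters for streaming workloads: aiohttp keeps connections alive inside the session, so each completion no longer pays connection-setup costs, but callers take on the obligation to await `stop()` at shutdown.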
session.post(f"{self.server_urlL}/api/generate", json={ - "prompt": prompt, - "model": self.model, - }) as resp: - async for line in resp.content.iter_any(): - if line: - try: - j = json.dumps(line.decode("utf-8")) - yield j["response"] - if j["done"]: - break - except: - raise Exception(str(line)) + async with self._client_session.post(f"{self.server_url}/api/generate", json={ + "prompt": prompt, + "model": self.model, + }) as resp: + async for line in resp.content.iter_any(): + if line: + try: + json_chunk = line.decode("utf-8") + chunks = json_chunk.split("\n") + for chunk in chunks: + if chunk.strip() != "": + j = json.loads(chunk) + if "response" in j: + yield j["response"] + except: + raise Exception(str(line[0])) async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: + args = {**self.default_args, **kwargs} messages = compile_chat_messages( - self.name, messages, self.max_context_length, prompt, system_message=self.system_message) + self.name, messages, self.context_length, args["max_tokens"], None, functions=None, system_message=self.system_message) prompt = self.convert_to_chat(messages) - async with aiohttp.ClientSession() as session: - async with session.post(f"{self.server_urlL}/api/generate", json={ - "prompt": prompt, - "model": self.model, - }) as resp: - # This is streaming application/json instaed of text/event-stream - async for line in resp.content.iter_chunks(): - if line[1]: - try: - j = json.dumps(line.decode("utf-8")) - yield { - "role": "assistant", - "content": j["response"] - } - if j["done"]: - break - except: - raise Exception(str(line[0])) + async with self._client_session.post(f"{self.server_url}/api/generate", json={ + "prompt": prompt, + "model": self.model, + }) as resp: + # This is streaming application/json instaed of text/event-stream + async for line in resp.content.iter_chunks(): + if line[1]: + try: + json_chunk = line[0].decode("utf-8") + chunks = json_chunk.split("\n") + for chunk in chunks: + if chunk.strip() != "": + j = json.loads(chunk) + if "response" in j: + yield { + "role": "assistant", + "content": j["response"] + } + except: + raise Exception(str(line[0])) async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]: completion = "" - async with aiohttp.ClientSession() as session: - async with session.post(f"{self.server_urlL}/api/generate", json={ - "prompt": prompt, - "model": self.model, - }) as resp: - async for line in resp.content.iter_any(): - if line: - try: - j = json.dumps(line.decode("utf-8")) - completion += j["response"] - if j["done"]: - break - except: - raise Exception(str(line)) + async with self._client_session.post(f"{self.server_url}/api/generate", json={ + "prompt": prompt, + "model": self.model, + }) as resp: + async for line in resp.content.iter_any(): + if line: + try: + json_chunk = line.decode("utf-8") + chunks = json_chunk.split("\n") + for chunk in chunks: + if chunk.strip() != "": + j = json.loads(chunk) + if "response" in j: + completion += j["response"] + except: + raise Exception(str(line[0])) return completion diff --git a/extension/package-lock.json b/extension/package-lock.json index 4c0b7093..2ab3ad94 100644 --- a/extension/package-lock.json +++ b/extension/package-lock.json @@ -1,12 +1,12 @@ { "name": "continue", - "version": "0.0.228", + "version": "0.0.229", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "continue", - "version": "0.0.228", + "version": 
"0.0.229", "license": "Apache-2.0", "dependencies": { "@electron/rebuild": "^3.2.10", diff --git a/extension/package.json b/extension/package.json index df54bb4f..481fbdd9 100644 --- a/extension/package.json +++ b/extension/package.json @@ -14,7 +14,7 @@ "displayName": "Continue", "pricing": "Free", "description": "The open-source coding autopilot", - "version": "0.0.228", + "version": "0.0.229", "publisher": "Continue", "engines": { "vscode": "^1.67.0" |