| author | Nate Sesti <sestinj@gmail.com> | 2023-07-31 13:23:11 -0700 |
|---|---|---|
| committer | Nate Sesti <sestinj@gmail.com> | 2023-07-31 13:23:11 -0700 |
| commit | 72d18fb8aaac9d192a508cd54fdb296321972379 (patch) | |
| tree | 2fecad4589d4cbe1bdd2df2d6ed8630d28312dd4 /continuedev/src | |
| parent | 2459afd95d80ab92a61ea4df37d07db5c99fb6fc (diff) | |
| download | sncontinue-72d18fb8aaac9d192a508cd54fdb296321972379.tar.gz<br>sncontinue-72d18fb8aaac9d192a508cd54fdb296321972379.tar.bz2<br>sncontinue-72d18fb8aaac9d192a508cd54fdb296321972379.zip | |
feat: :sparkles: llama-2 support
Diffstat (limited to 'continuedev/src')
| Mode | Path | Lines changed |
|---|---|---|
| -rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 4 |
| -rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 1 |
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/ollama.py | 133 |

3 files changed, 81 insertions(+), 57 deletions(-)
```diff
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index de95a259..d92c51cd 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -98,8 +98,8 @@ class Autopilot(ContinueBaseModel):
             user_input_queue=self._main_user_input_queue,
             slash_commands=self.get_available_slash_commands(),
             adding_highlighted_code=self.context_manager.context_providers[
-                "code"].adding_highlighted_code,
-            selected_context_items=await self.context_manager.get_selected_items()
+                "code"].adding_highlighted_code if self.context_manager is not None else False,
+            selected_context_items=await self.context_manager.get_selected_items() if self.context_manager is not None else [],
         )
         self.full_state = full_state
         return full_state
```

```diff
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index bf22d696..a5b16168 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -40,6 +40,7 @@ class ContinueSDK(AbstractContinueSDK):
     @classmethod
     async def create(cls, autopilot: Autopilot) -> "ContinueSDK":
         sdk = ContinueSDK(autopilot)
+        autopilot.continue_sdk = sdk
 
         try:
             config = sdk._load_config_dot_py()
```

```diff
diff --git a/continuedev/src/continuedev/libs/llm/ollama.py b/continuedev/src/continuedev/libs/llm/ollama.py
index 86da4115..a9f9f7aa 100644
--- a/continuedev/src/continuedev/libs/llm/ollama.py
+++ b/continuedev/src/continuedev/libs/llm/ollama.py
@@ -11,13 +11,28 @@ from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_token
 class Ollama(LLM):
     model: str = "llama2"
     server_url: str = "http://localhost:11434"
-    max_context_length: int
+    max_context_length: int = 2048
 
-    @cached_property
+    _client_session: aiohttp.ClientSession = None
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    async def start(self, **kwargs):
+        self._client_session = aiohttp.ClientSession()
+
+    async def stop(self):
+        await self._client_session.close()
+
+    @property
     def name(self):
         return self.model
 
     @property
+    def context_length(self) -> int:
+        return self.max_context_length
+
+    @property
     def default_args(self):
         return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
@@ -29,7 +44,7 @@ class Ollama(LLM):
             return ""
 
         prompt = ""
-        has_system = msgs[0].role == "system"
+        has_system = msgs[0]["role"] == "system"
         if has_system:
             system_message = f"""\
                 <<SYS>>
@@ -38,79 +53,87 @@
                 """
             if len(msgs) > 1:
-                prompt += f"[INST] {system_message}{msgs[1].content} [/INST]"
+                prompt += f"[INST] {system_message}{msgs[1]['content']} [/INST]"
             else:
                 prompt += f"[INST] {system_message} [/INST]"
                 return
 
         for i in range(2 if has_system else 0, len(msgs)):
-            if msgs[i].role == "user":
-                prompt += f"[INST] {msgs[i].content} [/INST]"
+            if msgs[i]["role"] == "user":
+                prompt += f"[INST] {msgs[i]['content']} [/INST]"
             else:
-                prompt += msgs[i].content
+                prompt += msgs[i]['content']
 
         return prompt
 
     async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            self.name, with_history, self.max_context_length, prompt, system_message=self.system_message)
+            self.name, with_history, self.context_length, args["max_tokens"], prompt, functions=None, system_message=self.system_message)
         prompt = self.convert_to_chat(messages)
 
-        async with aiohttp.ClientSession() as session:
-            async with session.post(f"{self.server_urlL}/api/generate", json={
-                "prompt": prompt,
-                "model": self.model,
-            }) as resp:
-                async for line in resp.content.iter_any():
-                    if line:
-                        try:
-                            j = json.dumps(line.decode("utf-8"))
-                            yield j["response"]
-                            if j["done"]:
-                                break
-                        except:
-                            raise Exception(str(line))
+        async with self._client_session.post(f"{self.server_url}/api/generate", json={
+            "prompt": prompt,
+            "model": self.model,
+        }) as resp:
+            async for line in resp.content.iter_any():
+                if line:
+                    try:
+                        json_chunk = line.decode("utf-8")
+                        chunks = json_chunk.split("\n")
+                        for chunk in chunks:
+                            if chunk.strip() != "":
+                                j = json.loads(chunk)
+                                if "response" in j:
+                                    yield j["response"]
+                    except:
+                        raise Exception(str(line[0]))
 
     async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+        args = {**self.default_args, **kwargs}
         messages = compile_chat_messages(
-            self.name, messages, self.max_context_length, prompt, system_message=self.system_message)
+            self.name, messages, self.context_length, args["max_tokens"], None, functions=None, system_message=self.system_message)
         prompt = self.convert_to_chat(messages)
 
-        async with aiohttp.ClientSession() as session:
-            async with session.post(f"{self.server_urlL}/api/generate", json={
-                "prompt": prompt,
-                "model": self.model,
-            }) as resp:
-                # This is streaming application/json instaed of text/event-stream
-                async for line in resp.content.iter_chunks():
-                    if line[1]:
-                        try:
-                            j = json.dumps(line.decode("utf-8"))
-                            yield {
-                                "role": "assistant",
-                                "content": j["response"]
-                            }
-                            if j["done"]:
-                                break
-                        except:
-                            raise Exception(str(line[0]))
+        async with self._client_session.post(f"{self.server_url}/api/generate", json={
+            "prompt": prompt,
+            "model": self.model,
+        }) as resp:
+            # This is streaming application/json instaed of text/event-stream
+            async for line in resp.content.iter_chunks():
+                if line[1]:
+                    try:
+                        json_chunk = line[0].decode("utf-8")
+                        chunks = json_chunk.split("\n")
+                        for chunk in chunks:
+                            if chunk.strip() != "":
+                                j = json.loads(chunk)
+                                if "response" in j:
+                                    yield {
+                                        "role": "assistant",
+                                        "content": j["response"]
+                                    }
+                    except:
+                        raise Exception(str(line[0]))
 
     async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
         completion = ""
-        async with aiohttp.ClientSession() as session:
-            async with session.post(f"{self.server_urlL}/api/generate", json={
-                "prompt": prompt,
-                "model": self.model,
-            }) as resp:
-                async for line in resp.content.iter_any():
-                    if line:
-                        try:
-                            j = json.dumps(line.decode("utf-8"))
-                            completion += j["response"]
-                            if j["done"]:
-                                break
-                        except:
-                            raise Exception(str(line))
+        async with self._client_session.post(f"{self.server_url}/api/generate", json={
+            "prompt": prompt,
+            "model": self.model,
+        }) as resp:
+            async for line in resp.content.iter_any():
+                if line:
+                    try:
+                        json_chunk = line.decode("utf-8")
+                        chunks = json_chunk.split("\n")
+                        for chunk in chunks:
+                            if chunk.strip() != "":
+                                j = json.loads(chunk)
+                                if "response" in j:
+                                    completion += j["response"]
+                    except:
+                        raise Exception(str(line[0]))
         return completion
```
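For context, `convert_to_chat` implements Llama 2's chat template: the system prompt is wrapped in `<<SYS>>` tags and folded into the first `[INST]` block, each user turn gets its own `[INST] ... [/INST]` wrapper, and assistant turns are appended verbatim between instructions. Below is a minimal standalone sketch of that template; it is not part of the commit, and it assumes plain `{"role": ..., "content": ...}` dicts like those the patched code indexes into. Note that the committed version ends the system-message-only branch with a bare `return`, so it yields `None` there; the sketch returns the string instead.

```python
from typing import Dict, List

def to_llama2_prompt(msgs: List[Dict[str, str]]) -> str:
    """Render chat messages in Llama 2's [INST]/<<SYS>> format (sketch)."""
    prompt = ""
    has_system = bool(msgs) and msgs[0]["role"] == "system"
    if has_system:
        system = f"<<SYS>>\n{msgs[0]['content']}\n<</SYS>>\n\n"
        if len(msgs) > 1:
            # The system prompt rides along inside the first user instruction.
            prompt += f"[INST] {system}{msgs[1]['content']} [/INST]"
        else:
            return f"[INST] {system} [/INST]"
    for msg in msgs[2 if has_system else 0:]:
        if msg["role"] == "user":
            prompt += f"[INST] {msg['content']} [/INST]"
        else:
            # Assistant replies sit between [INST] blocks, unwrapped.
            prompt += msg["content"]
    return prompt

# Example:
# to_llama2_prompt([{"role": "system", "content": "Be brief."},
#                   {"role": "user", "content": "Hi"}])
# -> "[INST] <<SYS>>\nBe brief.\n<</SYS>>\n\nHi [/INST]"
```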

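All three request methods (`stream_complete`, `stream_chat`, `complete`) now share the same transport: a POST to Ollama's `/api/generate`, which streams newline-delimited JSON objects whose `"response"` field carries the next piece of generated text. The sketch below isolates that parsing loop. It is not part of the commit, assumes a local Ollama server on the default port, and, like the committed code, assumes each read from the socket ends on a line boundary (a hardened version would buffer partial lines before calling `json.loads`).

```python
import asyncio
import json

import aiohttp

async def stream_ollama(prompt: str, model: str = "llama2",
                        server_url: str = "http://localhost:11434") -> str:
    """Stream a completion from Ollama's /api/generate endpoint (sketch)."""
    completion = ""
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{server_url}/api/generate",
                                json={"prompt": prompt, "model": model}) as resp:
            async for raw in resp.content.iter_any():
                # One read may carry several newline-delimited JSON objects.
                for chunk in raw.decode("utf-8").split("\n"):
                    if chunk.strip():
                        j = json.loads(chunk)
                        if "response" in j:
                            print(j["response"], end="", flush=True)
                            completion += j["response"]
    return completion

if __name__ == "__main__":
    asyncio.run(stream_ollama("Why is the sky blue?"))
```

Reusing one `ClientSession` across calls, as the commit's new `start()`/`stop()` hooks do, avoids rebuilding the connection pool on every request; the `async with` session here is just the simplest self-contained form.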