diff options
| author | Nate Sesti <sestinj@gmail.com> | 2023-08-02 11:35:44 -0700 | 
|---|---|---|
| committer | Nate Sesti <sestinj@gmail.com> | 2023-08-02 11:35:44 -0700 | 
| commit | c7cfc1be10c9875804cd295bbcccb0184a97ac10 (patch) | |
| tree | a9886477a4742ef953c7b07622bf381dfaf4f987 /continuedev | |
| parent | f96a430b2c36ffa3511ffb015a86d5fdfae7d606 (diff) | |
| download | sncontinue-c7cfc1be10c9875804cd295bbcccb0184a97ac10.tar.gz sncontinue-c7cfc1be10c9875804cd295bbcccb0184a97ac10.tar.bz2 sncontinue-c7cfc1be10c9875804cd295bbcccb0184a97ac10.zip | |
anthropic fixes
Diffstat (limited to 'continuedev')
| -rw-r--r-- | continuedev/src/continuedev/core/context.py | 5 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 4 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/__init__.py | 2 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/anthropic.py | 34 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/openai.py | 2 | ||||
| -rw-r--r-- | continuedev/src/continuedev/plugins/policies/default.py | 7 | 
6 files changed, 34 insertions, 20 deletions
| diff --git a/continuedev/src/continuedev/core/context.py b/continuedev/src/continuedev/core/context.py index 3f5f6fd3..20725216 100644 --- a/continuedev/src/continuedev/core/context.py +++ b/continuedev/src/continuedev/core/context.py @@ -159,8 +159,9 @@ class ContextManager:              if not meilisearch_running:                  logger.warning(                      "MeiliSearch not running, avoiding any dependent context providers") -                self.context_providers = list( -                    filter(lambda cp: cp.title == "code", self.context_providers)) +                self.context_providers = { +                    title: provider for title, provider in self.context_providers.items() if title == "code" +                }      async def load_index(self, workspace_dir: str):          for _, provider in self.context_providers.items(): diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index a5b16168..30fcc144 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -195,10 +195,8 @@ class ContinueSDK(AbstractContinueSDK):          context_messages: List[ChatMessage] = await self.__autopilot.context_manager.get_chat_messages()          # Insert at the end, but don't insert after latest user message or function call -        i = -2 if (len(history_context) > 0 and ( -            history_context[-1].role == "user" or history_context[-1].role == "function")) else -1          for msg in context_messages: -            history_context.insert(i, msg) +            history_context.insert(-1, msg)          return history_context diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py index 50577993..40edb99b 100644 --- a/continuedev/src/continuedev/libs/llm/__init__.py +++ b/continuedev/src/continuedev/libs/llm/__init__.py @@ -10,7 +10,7 @@ class LLM(ContinueBaseModel, ABC):      requires_unique_id: bool = False      requires_write_log: bool = False -    system_message: Union[str, None] = None +    system_message: Optional[str] = None      @abstractproperty      def name(self): diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py index ec1b7e40..8a548223 100644 --- a/continuedev/src/continuedev/libs/llm/anthropic.py +++ b/continuedev/src/continuedev/libs/llm/anthropic.py @@ -1,33 +1,33 @@  from functools import cached_property  import time -from typing import Any, Coroutine, Dict, Generator, List, Optional, Union +from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union  from ...core.main import ChatMessage  from anthropic import HUMAN_PROMPT, AI_PROMPT, AsyncAnthropic  from ..llm import LLM -from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens +from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens, format_chat_messages  class AnthropicLLM(LLM):      model: str = "claude-2"      requires_api_key: str = "ANTHROPIC_API_KEY" +    requires_write_log = True      _async_client: AsyncAnthropic = None      class Config:          arbitrary_types_allowed = True -    def __init__(self, model: str, system_message: str = None): -        self.model = model -        self.system_message = system_message +    write_log: Optional[Callable[[str], None]] = None -    async def start(self, *, api_key: Optional[str] = None, **kwargs): +    async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], **kwargs): +        self.write_log = write_log          self._async_client = AsyncAnthropic(api_key=api_key)      async def stop(self):          pass -    @cached_property +    @property      def name(self):          return self.model @@ -72,12 +72,18 @@ class AnthropicLLM(LLM):          args.update(kwargs)          args["stream"] = True          args = self._transform_args(args) +        prompt = f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}" +        self.write_log(f"Prompt: \n\n{prompt}") +        completion = ""          async for chunk in await self._async_client.completions.create( -            prompt=f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}", +            prompt=prompt,              **args          ):              yield chunk.completion +            completion += chunk.completion + +        self.write_log(f"Completion: \n\n{completion}")      async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:          args = self.default_args.copy() @@ -86,7 +92,10 @@ class AnthropicLLM(LLM):          args = self._transform_args(args)          messages = compile_chat_messages( -            args["model"], messages, self.context_length, self.context_length, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message) +            args["model"], messages, self.context_length, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message) + +        completion = "" +        self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")          async for chunk in await self._async_client.completions.create(              prompt=self.__messages_to_prompt(messages),              **args @@ -95,6 +104,9 @@ class AnthropicLLM(LLM):                  "role": "assistant",                  "content": chunk.completion              } +            completion += chunk.completion + +        self.write_log(f"Completion: \n\n{completion}")      async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:          args = {**self.default_args, **kwargs} @@ -102,9 +114,13 @@ class AnthropicLLM(LLM):          messages = compile_chat_messages(              args["model"], with_history, self.context_length, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message) + +        completion = "" +        self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")          resp = (await self._async_client.completions.create(              prompt=self.__messages_to_prompt(messages),              **args          )).completion +        self.write_log(f"Completion: \n\n{resp}")          return resp diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index fce6e8ab..99c851ca 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -1,5 +1,3 @@ -from functools import cached_property -import json  from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union, Optional  from pydantic import BaseModel diff --git a/continuedev/src/continuedev/plugins/policies/default.py b/continuedev/src/continuedev/plugins/policies/default.py index 523c2cf4..0d74fa3f 100644 --- a/continuedev/src/continuedev/plugins/policies/default.py +++ b/continuedev/src/continuedev/plugins/policies/default.py @@ -1,5 +1,5 @@  from textwrap import dedent -from typing import Union +from typing import Type, Union  from ..steps.chat import SimpleChatStep  from ..steps.welcome import WelcomeStep @@ -46,7 +46,8 @@ def parse_custom_command(inp: str, config: ContinueConfig) -> Union[None, Step]:  class DefaultPolicy(Policy): -    default_step: Step = SimpleChatStep() +    default_step: Type[Step] = SimpleChatStep +    default_params: dict = {}      def next(self, config: ContinueConfig, history: History) -> Step:          # At the very start, run initial Steps spcecified in the config @@ -75,6 +76,6 @@ class DefaultPolicy(Policy):              if user_input.startswith("/edit"):                  return EditHighlightedCodeStep(user_input=user_input[5:]) -            return self.default_step.copy() +            return self.default_step(**self.default_params)          return None | 
