diff options
Diffstat (limited to 'continuedev/src')
| -rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 19 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/anthropic.py | 33 | 
2 files changed, 40 insertions, 12 deletions
| diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 42cfbcb9..d040ea41 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -18,6 +18,7 @@ from ..libs.util.telemetry import posthog_logger  from ..libs.util.paths import getConfigFilePath  from .models import Models  from ..libs.util.logging import logger +# __import__("anthropic", globals(), locals(), ["AsyncAnthropic"], 0)  class Autopilot: @@ -46,7 +47,8 @@ class ContinueSDK(AbstractContinueSDK):              config = sdk._load_config_dot_py()              sdk.config = config          except Exception as e: -            logger.error(f"Failed to load config.py: {e}") +            logger.error( +                f"Failed to load config.py: {traceback.format_exception(e)}")              sdk.config = ContinueConfig(              ) if sdk._last_valid_config is None else sdk._last_valid_config @@ -170,9 +172,15 @@ class ContinueSDK(AbstractContinueSDK):          def load_module(module_name: str, class_names: List[str]):              # from anthropic import AsyncAnthropic -            module = importlib.import_module(module_name) -            for class_name in class_names: -                globals()[class_name] = getattr(module, class_name) +            print("IMPORTING") +            # exec("from anthropic import AsyncAnthropic", globals(), locals()) +            # imports = __import__("anthropic", globals(), locals(), ["AsyncAnthropic"], 0) +            # print("IMPORTS: ", imports) +            # for class_name in class_names: +            #     globals()[class_name] = getattr(imports, class_name) +            # module = importlib.import_module(module_name) +            # for class_name in class_names: +            #     globals()[class_name] = getattr(module, class_name)          while True:              # Execute the file content @@ -200,7 +208,8 @@ class ContinueSDK(AbstractContinueSDK):                  # Get the module name                  module_name = line[1]                  # Get the class name -                class_names = list(map(lambda x: x.replace(",", ""), filter(lambda x: x.strip() != "", line[3:]))) +                class_names = list(map(lambda x: x.replace( +                    ",", ""), filter(lambda x: x.strip() != "", line[3:])))                  # Load the module                  print( diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py index e6b88d03..ac5c56a4 100644 --- a/continuedev/src/continuedev/libs/llm/anthropic.py +++ b/continuedev/src/continuedev/libs/llm/anthropic.py @@ -1,29 +1,32 @@ -  from functools import cached_property  import time -from typing import Any, Coroutine, Dict, Generator, List, Optional, Union +from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union  from ...core.main import ChatMessage  from anthropic import HUMAN_PROMPT, AI_PROMPT, AsyncAnthropic  from ..llm import LLM -from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens +from ..util.count_tokens import compile_chat_messages, DEFAULT_ARGS, count_tokens, format_chat_messages  class AnthropicLLM(LLM):      model: str = "claude-2"      requires_api_key: str = "ANTHROPIC_API_KEY" +    requires_write_log = True      _async_client: AsyncAnthropic = None      class Config:          arbitrary_types_allowed = True -    async def start(self, *, api_key: Optional[str] = None, **kwargs): +    write_log: Optional[Callable[[str], None]] = None + +    async def start(self, *, api_key: Optional[str] = None, write_log: Callable[[str], None], **kwargs): +        self.write_log = write_log          self._async_client = AsyncAnthropic(api_key=api_key)      async def stop(self):          pass -    @cached_property +    @property      def name(self):          return self.model @@ -68,12 +71,18 @@ class AnthropicLLM(LLM):          args.update(kwargs)          args["stream"] = True          args = self._transform_args(args) +        prompt = f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}" +        self.write_log(f"Prompt: \n\n{prompt}") +        completion = ""          async for chunk in await self._async_client.completions.create( -            prompt=f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}", +            prompt=prompt,              **args          ):              yield chunk.completion +            completion += chunk.completion + +        self.write_log(f"Completion: \n\n{completion}")      async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:          args = self.default_args.copy() @@ -82,7 +91,10 @@ class AnthropicLLM(LLM):          args = self._transform_args(args)          messages = compile_chat_messages( -            args["model"], messages, self.context_length, self.context_length, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message) +            args["model"], messages, self.context_length, args["max_tokens_to_sample"], functions=args.get("functions", None), system_message=self.system_message) + +        completion = "" +        self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")          async for chunk in await self._async_client.completions.create(              prompt=self.__messages_to_prompt(messages),              **args @@ -91,6 +103,9 @@ class AnthropicLLM(LLM):                  "role": "assistant",                  "content": chunk.completion              } +            completion += chunk.completion + +        self.write_log(f"Completion: \n\n{completion}")      async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:          args = {**self.default_args, **kwargs} @@ -98,9 +113,13 @@ class AnthropicLLM(LLM):          messages = compile_chat_messages(              args["model"], with_history, self.context_length, args["max_tokens_to_sample"], prompt, functions=None, system_message=self.system_message) + +        completion = "" +        self.write_log(f"Prompt: \n\n{format_chat_messages(messages)}")          resp = (await self._async_client.completions.create(              prompt=self.__messages_to_prompt(messages),              **args          )).completion +        self.write_log(f"Completion: \n\n{resp}")          return resp | 
