diff options
Diffstat (limited to 'continuedev/src')
| -rw-r--r-- | continuedev/src/continuedev/core/policy.py | 6 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/openai.py | 38 | ||||
| -rw-r--r-- | continuedev/src/continuedev/steps/chat.py | 7 | ||||
| -rw-r--r-- | continuedev/src/continuedev/steps/react.py | 16 | 
4 files changed, 25 insertions, 42 deletions
| diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py index e71a1cb2..00b5427c 100644 --- a/continuedev/src/continuedev/core/policy.py +++ b/continuedev/src/continuedev/core/policy.py @@ -57,10 +57,12 @@ class DemoPolicy(Policy):                  return ContinueStepStep(prompt=" ".join(user_input.split(" ")[1:]))              # return EditHighlightedCodeStep(user_input=user_input)              return NLDecisionStep(user_input=user_input, steps=[ -                EditHighlightedCodeStep(user_input=user_input), +                (EditHighlightedCodeStep(user_input=user_input), +                 "Edit the highlighted code"),                  # AnswerQuestionChroma(question=user_input),                  # EditFileChroma(request=user_input), -                SimpleChatStep(user_input=user_input) +                (SimpleChatStep(user_input=user_input), +                 "Respond to the user with a chat message"),              ], default_step=EditHighlightedCodeStep(user_input=user_input))          state = history.get_current() diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index 6a537afd..9b8d3447 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -23,7 +23,7 @@ class OpenAI(LLM):      def with_system_message(self, system_message: Union[str, None]):          return OpenAI(api_key=self.api_key, system_message=system_message) -    def stream_chat(self, messages, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: +    def stream_chat(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:          self.completion_count += 1          args = {"max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.5, "top_p": 1,                  "frequency_penalty": 0, "presence_penalty": 0} | kwargs @@ -31,7 +31,7 @@ class OpenAI(LLM):          args["model"] = "gpt-3.5-turbo"          for chunk in openai.ChatCompletion.create( -            messages=messages, +            messages=self.compile_chat_messages(with_history, prompt),              **args,          ):              if "content" in chunk.choices[0].delta: @@ -39,7 +39,21 @@ class OpenAI(LLM):              else:                  continue -    def stream_complete(self, prompt: str, **kwargs) -> Generator[Union[Any, List, Dict], None, None]: +    def compile_chat_messages(self, msgs: List[ChatMessage], prompt: str) -> List[Dict]: +        history = [] +        if self.system_message: +            history.append({ +                "role": "system", +                "content": self.system_message +            }) +        history += [msg.dict() for msg in msgs] +        history.append({ +            "role": "user", +            "content": prompt +        }) +        return history + +    def stream_complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:          self.completion_count += 1          args = {"model": self.default_model, "max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.5,                  "top_p": 1, "frequency_penalty": 0, "presence_penalty": 0, "suffix": None} | kwargs @@ -47,10 +61,7 @@ class OpenAI(LLM):          if args["model"] == "gpt-3.5-turbo":              generator = openai.ChatCompletion.create( -                messages=[{ -                    "role": "user", -                    "content": prompt -                }], +                messages=self.compile_chat_messages(with_history, prompt),                  **args,              )              for chunk in generator: @@ -71,19 +82,8 @@ class OpenAI(LLM):                  "frequency_penalty": 0, "presence_penalty": 0, "stream": False} | kwargs          if args["model"] == "gpt-3.5-turbo": -            messages = [] -            if self.system_message: -                messages.append({ -                    "role": "system", -                    "content": self.system_message -                }) -            messages += [msg.dict() for msg in with_history] -            messages.append({ -                "role": "user", -                "content": prompt -            })              resp = openai.ChatCompletion.create( -                messages=messages, +                messages=self.compile_chat_messages(with_history, prompt),                  **args,              ).choices[0].message.content          else: diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py index 56e49223..aadcfa8e 100644 --- a/continuedev/src/continuedev/steps/chat.py +++ b/continuedev/src/continuedev/steps/chat.py @@ -10,15 +10,10 @@ class SimpleChatStep(Step):      name: str = "Chat"      async def run(self, sdk: ContinueSDK): -<<<<<<< Updated upstream -        self.description = sdk.models.gpt35.complete(self.user_input, with_history=await sdk.get_chat_context()) -======= -        # TODO: With history          self.description = "" -        for chunk in sdk.models.gpt35.stream_chat([{"role": "user", "content": self.user_input}]): +        for chunk in sdk.models.gpt35.stream_chat(self.user_input, with_history=await sdk.get_chat_context()):              self.description += chunk              await sdk.update_ui()          self.name = sdk.models.gpt35.complete(              f"Write a short title for the following chat message: {self.description}").strip() ->>>>>>> Stashed changes diff --git a/continuedev/src/continuedev/steps/react.py b/continuedev/src/continuedev/steps/react.py index d98b41c6..d825d424 100644 --- a/continuedev/src/continuedev/steps/react.py +++ b/continuedev/src/continuedev/steps/react.py @@ -1,9 +1,5 @@  from textwrap import dedent -<<<<<<< Updated upstream -from typing import List, Union -======= -from typing import List, Tuple ->>>>>>> Stashed changes +from typing import List, Union, Tuple  from ..core.main import Step  from ..core.sdk import ContinueSDK  from .core.core import MessageStep @@ -11,15 +7,10 @@ from .core.core import MessageStep  class NLDecisionStep(Step):      user_input: str -<<<<<<< Updated upstream -    steps: List[Step] -    hide: bool = True      default_step: Union[Step, None] = None -=======      steps: List[Tuple[Step, str]]      hide: bool = True ->>>>>>> Stashed changes      async def run(self, sdk: ContinueSDK):          step_descriptions = "\n".join([ @@ -40,13 +31,8 @@ class NLDecisionStep(Step):          step_to_run = None          for step in self.steps: -<<<<<<< Updated upstream -            if step.name.lower() in resp: -                step_to_run = step -=======              if step[0].name.lower() in resp:                  step_to_run = step[0] ->>>>>>> Stashed changes          step_to_run = step_to_run or self.default_step or self.steps[0] | 
