author     Nate Sesti <sestinj@gmail.com>  2023-06-12 11:00:34 -0700
committer  Nate Sesti <sestinj@gmail.com>  2023-06-12 11:00:34 -0700
commit     57a6d8fc0a808ca160b5f691882a7893ed438c97 (patch)
tree       35d6ad35e74f2e09c2d34ca6b2c421aea964cf7d /continuedev/src
parent     40ba9eaf82a1386ccacf5046c072df3d131d5284 (diff)
streaming and highlighted code to chat context
Diffstat (limited to 'continuedev/src')
-rw-r--r--  continuedev/src/continuedev/core/policy.py      6
-rw-r--r--  continuedev/src/continuedev/libs/llm/openai.py  38
-rw-r--r--  continuedev/src/continuedev/steps/chat.py        7
-rw-r--r--  continuedev/src/continuedev/steps/react.py      16
4 files changed, 25 insertions, 42 deletions
diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py
index e71a1cb2..00b5427c 100644
--- a/continuedev/src/continuedev/core/policy.py
+++ b/continuedev/src/continuedev/core/policy.py
@@ -57,10 +57,12 @@ class DemoPolicy(Policy):
return ContinueStepStep(prompt=" ".join(user_input.split(" ")[1:]))
# return EditHighlightedCodeStep(user_input=user_input)
return NLDecisionStep(user_input=user_input, steps=[
- EditHighlightedCodeStep(user_input=user_input),
+ (EditHighlightedCodeStep(user_input=user_input),
+ "Edit the highlighted code"),
# AnswerQuestionChroma(question=user_input),
# EditFileChroma(request=user_input),
- SimpleChatStep(user_input=user_input)
+ (SimpleChatStep(user_input=user_input),
+ "Respond to the user with a chat message"),
], default_step=EditHighlightedCodeStep(user_input=user_input))
state = history.get_current()
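
For orientation, the policy.py change pairs each candidate step with a natural-language description that NLDecisionStep can present to the model. A minimal sketch of the new call shape, using stand-in classes since the real Step implementations live elsewhere in the repo:

# Stand-ins for the repo's Step subclasses; only the (step, description)
# pairing introduced by this diff is modeled here.
from typing import List, Tuple


class Step:
    def __init__(self, user_input: str):
        self.user_input = user_input
        self.name = type(self).__name__


class EditHighlightedCodeStep(Step):
    pass


class SimpleChatStep(Step):
    pass


user_input = "explain this function"
steps: List[Tuple[Step, str]] = [
    (EditHighlightedCodeStep(user_input), "Edit the highlighted code"),
    (SimpleChatStep(user_input), "Respond to the user with a chat message"),
]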
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 6a537afd..9b8d3447 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -23,7 +23,7 @@ class OpenAI(LLM):
def with_system_message(self, system_message: Union[str, None]):
return OpenAI(api_key=self.api_key, system_message=system_message)
- def stream_chat(self, messages, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ def stream_chat(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
self.completion_count += 1
args = {"max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.5, "top_p": 1,
"frequency_penalty": 0, "presence_penalty": 0} | kwargs
@@ -31,7 +31,7 @@ class OpenAI(LLM):
args["model"] = "gpt-3.5-turbo"
for chunk in openai.ChatCompletion.create(
- messages=messages,
+ messages=self.compile_chat_messages(with_history, prompt),
**args,
):
if "content" in chunk.choices[0].delta:
@@ -39,7 +39,21 @@ class OpenAI(LLM):
else:
continue
- def stream_complete(self, prompt: str, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ def compile_chat_messages(self, msgs: List[ChatMessage], prompt: str) -> List[Dict]:
+ history = []
+ if self.system_message:
+ history.append({
+ "role": "system",
+ "content": self.system_message
+ })
+ history += [msg.dict() for msg in msgs]
+ history.append({
+ "role": "user",
+ "content": prompt
+ })
+ return history
+
+ def stream_complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
self.completion_count += 1
args = {"model": self.default_model, "max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.5,
"top_p": 1, "frequency_penalty": 0, "presence_penalty": 0, "suffix": None} | kwargs
@@ -47,10 +61,7 @@ class OpenAI(LLM):
if args["model"] == "gpt-3.5-turbo":
generator = openai.ChatCompletion.create(
- messages=[{
- "role": "user",
- "content": prompt
- }],
+ messages=self.compile_chat_messages(with_history, prompt),
**args,
)
for chunk in generator:
@@ -71,19 +82,8 @@ class OpenAI(LLM):
"frequency_penalty": 0, "presence_penalty": 0, "stream": False} | kwargs
if args["model"] == "gpt-3.5-turbo":
- messages = []
- if self.system_message:
- messages.append({
- "role": "system",
- "content": self.system_message
- })
- messages += [msg.dict() for msg in with_history]
- messages.append({
- "role": "user",
- "content": prompt
- })
resp = openai.ChatCompletion.create(
- messages=messages,
+ messages=self.compile_chat_messages(with_history, prompt),
**args,
).choices[0].message.content
else:
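
The openai.py refactor extracts message assembly into compile_chat_messages so that stream_chat, stream_complete, and the non-streaming path all build the same list. A standalone sketch of the helper, with a stand-in ChatMessage since the real class lives in the repo's core types:

from typing import Dict, List, Optional


class ChatMessage:
    # Stand-in: only the .dict() method the helper relies on is modeled.
    def __init__(self, role: str, content: str):
        self.role = role
        self.content = content

    def dict(self) -> Dict[str, str]:
        return {"role": self.role, "content": self.content}


def compile_chat_messages(system_message: Optional[str],
                          msgs: List[ChatMessage],
                          prompt: str) -> List[Dict[str, str]]:
    # Order matters for the OpenAI chat API: optional system message first,
    # then prior history, then the new user prompt.
    history: List[Dict[str, str]] = []
    if system_message:
        history.append({"role": "system", "content": system_message})
    history += [msg.dict() for msg in msgs]
    history.append({"role": "user", "content": prompt})
    return history


print(compile_chat_messages("Be terse.",
                            [ChatMessage("user", "hi"),
                             ChatMessage("assistant", "hello")],
                            "what changed?"))

One caveat worth noting: the new with_history: List[ChatMessage] = [] defaults are shared mutable defaults, which looks harmless here only because the list is read, never mutated.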
diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py
index 56e49223..aadcfa8e 100644
--- a/continuedev/src/continuedev/steps/chat.py
+++ b/continuedev/src/continuedev/steps/chat.py
@@ -10,15 +10,10 @@ class SimpleChatStep(Step):
name: str = "Chat"
async def run(self, sdk: ContinueSDK):
-<<<<<<< Updated upstream
- self.description = sdk.models.gpt35.complete(self.user_input, with_history=await sdk.get_chat_context())
-=======
- # TODO: With history
self.description = ""
- for chunk in sdk.models.gpt35.stream_chat([{"role": "user", "content": self.user_input}]):
+ for chunk in sdk.models.gpt35.stream_chat(self.user_input, with_history=await sdk.get_chat_context()):
self.description += chunk
await sdk.update_ui()
self.name = sdk.models.gpt35.complete(
f"Write a short title for the following chat message: {self.description}").strip()
->>>>>>> Stashed changes
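
chat.py resolves the stash conflict in favor of the streaming branch: SimpleChatStep now accumulates chunks into its description, refreshes the UI after each one, and then asks gpt35 for a short title. A simplified synchronous sketch of that loop (the real step is async and awaits sdk.update_ui()):

from typing import Iterator


def stream_chat(prompt: str) -> Iterator[str]:
    # Stand-in for sdk.models.gpt35.stream_chat; yields content chunks.
    yield from ["Streaming ", "keeps the UI ", "responsive."]


description = ""
for chunk in stream_chat("hello"):
    description += chunk
    print(description)  # the real step awaits sdk.update_ui() here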
diff --git a/continuedev/src/continuedev/steps/react.py b/continuedev/src/continuedev/steps/react.py
index d98b41c6..d825d424 100644
--- a/continuedev/src/continuedev/steps/react.py
+++ b/continuedev/src/continuedev/steps/react.py
@@ -1,9 +1,5 @@
from textwrap import dedent
-<<<<<<< Updated upstream
-from typing import List, Union
-=======
-from typing import List, Tuple
->>>>>>> Stashed changes
+from typing import List, Union, Tuple
from ..core.main import Step
from ..core.sdk import ContinueSDK
from .core.core import MessageStep
@@ -11,15 +7,10 @@ from .core.core import MessageStep
class NLDecisionStep(Step):
user_input: str
-<<<<<<< Updated upstream
- steps: List[Step]
- hide: bool = True
default_step: Union[Step, None] = None
-=======
steps: List[Tuple[Step, str]]
hide: bool = True
->>>>>>> Stashed changes
async def run(self, sdk: ContinueSDK):
step_descriptions = "\n".join([
@@ -40,13 +31,8 @@ class NLDecisionStep(Step):
step_to_run = None
for step in self.steps:
-<<<<<<< Updated upstream
- if step.name.lower() in resp:
- step_to_run = step
-=======
if step[0].name.lower() in resp:
step_to_run = step[0]
->>>>>>> Stashed changes
step_to_run = step_to_run or self.default_step or self.steps[0]
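
react.py completes the tuple migration: the selection loop indexes step[0] to compare step names against the model's response. Note that the final fallback self.steps[0] still evaluates to a tuple rather than a Step; the sketch below unpacks it with steps[0][0], and treating the original as a latent type mismatch is my reading, not something the diff states:

from typing import List, Optional, Tuple


class Step:
    # Stand-in for the repo's Step; only .name is needed here.
    def __init__(self, name: str):
        self.name = name


steps: List[Tuple[Step, str]] = [
    (Step("EditHighlightedCodeStep"), "Edit the highlighted code"),
    (Step("SimpleChatStep"), "Respond to the user with a chat message"),
]
default_step: Optional[Step] = None
resp = "respond with a simplechatstep"  # models the lowercased LLM reply

step_to_run: Optional[Step] = None
for step in steps:
    if step[0].name.lower() in resp:
        step_to_run = step[0]

# steps[0][0] unpacks the tuple; the diff's self.steps[0] would not.
step_to_run = step_to_run or default_step or steps[0][0]
print(step_to_run.name)  # -> SimpleChatStep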