Diffstat (limited to 'continuedev')
-rw-r--r--  continuedev/src/continuedev/libs/llm/__init__.py         |  6
-rw-r--r--  continuedev/src/continuedev/libs/llm/anthropic.py        |  6
-rw-r--r--  continuedev/src/continuedev/libs/llm/ggml.py             |  6
-rw-r--r--  continuedev/src/continuedev/libs/llm/hf_inference_api.py |  2
-rw-r--r--  continuedev/src/continuedev/libs/llm/openai.py           |  6
-rw-r--r--  continuedev/src/continuedev/libs/llm/proxy_server.py     |  6
-rw-r--r--  continuedev/src/continuedev/libs/util/count_tokens.py    | 14
7 files changed, 24 insertions(+), 22 deletions(-)
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 4c4de213..2766db4b 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -9,15 +9,15 @@ from pydantic import BaseModel
class LLM(ABC):
system_message: Union[str, None] = None
- async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
+ async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
"""Return the completion of the text with the given temperature."""
raise NotImplementedError
- def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
"""Stream the completion through generator."""
raise NotImplementedError
- async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
"""Stream the chat through generator."""
raise NotImplementedError
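
Note: the signature changes in this hunk (and the ones repeated below for anthropic.py, ggml.py, hf_inference_api.py, openai.py, and proxy_server.py) replace Python's mutable-default-argument pitfall: a default `[]` is created once at function definition time and shared by every call, so any mutation of it leaks between invocations. The following is a minimal sketch of the pitfall and of the `None`-sentinel pattern the diff adopts; the helper names `complete_buggy` and `complete_fixed` are hypothetical and stand in for the real `LLM.complete`.

from typing import List, Optional

# Pitfall: the [] default is evaluated once and shared across all calls.
def complete_buggy(prompt: str, with_history: List[str] = []) -> List[str]:
    with_history.append(prompt)   # mutates the single shared default list
    return with_history

print(complete_buggy("a"))  # ['a']
print(complete_buggy("b"))  # ['a', 'b']  <- history leaks across calls

# Sentinel pattern used by this diff: default to None, build a fresh list inside.
def complete_fixed(prompt: str, with_history: Optional[List[str]] = None) -> List[str]:
    history = [] if with_history is None else list(with_history)
    history.append(prompt)
    return history

print(complete_fixed("a"))  # ['a']
print(complete_fixed("b"))  # ['b']
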
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
index c82895c6..625d4e57 100644
--- a/continuedev/src/continuedev/libs/llm/anthropic.py
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -54,7 +54,7 @@ class AnthropicLLM(LLM):
prompt += AI_PROMPT
return prompt
- async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = self.default_args.copy()
args.update(kwargs)
args["stream"] = True
@@ -66,7 +66,7 @@ class AnthropicLLM(LLM):
):
yield chunk.completion
- async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = self.default_args.copy()
args.update(kwargs)
args["stream"] = True
@@ -83,7 +83,7 @@ class AnthropicLLM(LLM):
"content": chunk.completion
}
- async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
+ async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
args = {**self.default_args, **kwargs}
args = self._transform_args(args)
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index 6007fdb4..4889a556 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -26,7 +26,7 @@ class GGML(LLM):
def count_tokens(self, text: str):
return count_tokens(self.name, text)
- async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = self.default_args.copy()
args.update(kwargs)
args["stream"] = True
@@ -47,7 +47,7 @@ class GGML(LLM):
except:
raise Exception(str(line))
- async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
self.name, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
@@ -72,7 +72,7 @@ class GGML(LLM):
except:
raise Exception(str(line[0]))
- async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
+ async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
args = {**self.default_args, **kwargs}
async with aiohttp.ClientSession() as session:
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 7e11fbbe..36f03270 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -16,7 +16,7 @@ class HuggingFaceInferenceAPI(LLM):
self.model = model
self.system_message = system_message # TODO: Nothing being done with this
- def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs):
+ def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs):
"""Return the completion of the text with the given temperature."""
API_URL = f"https://api-inference.huggingface.co/models/{self.model}"
headers = {
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index 64bb39a2..96a4ab71 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -42,7 +42,7 @@ class OpenAI(LLM):
def count_tokens(self, text: str):
return count_tokens(self.default_model, text)
- async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = self.default_args.copy()
args.update(kwargs)
args["stream"] = True
@@ -72,7 +72,7 @@ class OpenAI(LLM):
self.write_log(f"Completion:\n\n{completion}")
- async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = self.default_args.copy()
args.update(kwargs)
args["stream"] = True
@@ -93,7 +93,7 @@ class OpenAI(LLM):
completion += chunk.choices[0].delta.content
self.write_log(f"Completion: \n\n{completion}")
- async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
+ async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
args = {**self.default_args, **kwargs}
if args["model"] in CHAT_MODELS:
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index bd50fe02..b1bb8f06 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -38,7 +38,7 @@ class ProxyServer(LLM):
def count_tokens(self, text: str):
return count_tokens(self.default_model, text)
- async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]:
+ async def complete(self, prompt: str, with_history: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, str]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
@@ -57,7 +57,7 @@ class ProxyServer(LLM):
except:
raise Exception(await resp.text())
- async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
+ async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
args["model"], messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message)
@@ -89,7 +89,7 @@ class ProxyServer(LLM):
self.write_log(f"Completion: \n\n{completion}")
- async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
+ async def stream_complete(self, prompt, with_history: List[ChatMessage] = None, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message)
diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py
index 987aa722..6e0a3b88 100644
--- a/continuedev/src/continuedev/libs/util/count_tokens.py
+++ b/continuedev/src/continuedev/libs/util/count_tokens.py
@@ -101,13 +101,15 @@ def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens:
TOKEN_BUFFER_FOR_SAFETY = 100
-def compile_chat_messages(model: str, msgs: List[ChatMessage], max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:
+def compile_chat_messages(model: str, msgs: Union[List[ChatMessage], None], max_tokens: int, prompt: Union[str, None] = None, functions: Union[List, None] = None, system_message: Union[str, None] = None) -> List[Dict]:
"""
The total number of tokens is system_message + sum(msgs) + functions + prompt after it is converted to a message
"""
+ msgs_copy = msgs.copy() if msgs is not None else []
+
if prompt is not None:
prompt_msg = ChatMessage(role="user", content=prompt, summary=prompt)
- msgs += [prompt_msg]
+ msgs_copy += [prompt_msg]
if system_message is not None:
# NOTE: System message takes second precedence to user prompt, so it is placed just before
@@ -116,7 +118,7 @@ def compile_chat_messages(model: str, msgs: List[ChatMessage], max_tokens: int,
system_chat_msg = ChatMessage(
role="system", content=rendered_system_message, summary=rendered_system_message)
# insert at second-to-last position
- msgs.insert(-1, system_chat_msg)
+ msgs_copy.insert(-1, system_chat_msg)
# Add tokens from functions
function_tokens = 0
@@ -124,11 +126,11 @@ def compile_chat_messages(model: str, msgs: List[ChatMessage], max_tokens: int,
for function in functions:
function_tokens += count_tokens(model, json.dumps(function))
- msgs = prune_chat_history(
- model, msgs, MAX_TOKENS_FOR_MODEL[model], function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY)
+ msgs_copy = prune_chat_history(
+ model, msgs_copy, MAX_TOKENS_FOR_MODEL[model], function_tokens + max_tokens + TOKEN_BUFFER_FOR_SAFETY)
history = [msg.to_dict(with_functions=functions is not None)
- for msg in msgs]
+ for msg in msgs_copy]
# Move system message back to start
if system_message is not None and len(history) >= 2 and history[-2]["role"] == "system":
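
Note: the count_tokens.py hunk fixes a related aliasing bug: compile_chat_messages previously appended the prompt message and inserted the system message directly into the caller's msgs list, so repeated calls kept accumulating entries in the caller's history. Working on a copy (msgs_copy) keeps the function free of side effects and also tolerates msgs=None. The sketch below is a simplified stand-in, not the project's actual API: plain dicts replace ChatMessage and the pruning/token logic is omitted.

from typing import Dict, List, Optional

# Simplified stand-in for compile_chat_messages: operate on a copy so the
# caller's message list is never mutated (mirrors the msgs -> msgs_copy change).
def compile_messages(msgs: Optional[List[Dict]], prompt: Optional[str] = None) -> List[Dict]:
    msgs_copy = list(msgs) if msgs is not None else []
    if prompt is not None:
        msgs_copy.append({"role": "user", "content": prompt})
    return msgs_copy

history = [{"role": "user", "content": "hi"}]
compile_messages(history, prompt="follow-up")
compile_messages(history, prompt="another follow-up")
assert len(history) == 1  # the caller's list is untouched across calls
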