summary refs log tree commit diff
path: root/continuedev/src/continuedev/libs/llm
diff options
context:
space:
mode:
authorNate Sesti <sestinj@gmail.com>2023-05-28 12:05:32 -0400
committerNate Sesti <sestinj@gmail.com>2023-05-28 12:05:32 -0400
commit78513ba9a63635a777262806793394131ad43744 (patch)
tree73113812a7bebf476966f48d34a57a7fcd67c91f /continuedev/src/continuedev/libs/llm
parent91f86ca0fdcd6dbbc8470fc41ef1ecf83ffa480f (diff)
download sncontinue-78513ba9a63635a777262806793394131ad43744.tar.gz
sncontinue-78513ba9a63635a777262806793394131ad43744.tar.bz2
sncontinue-78513ba9a63635a777262806793394131ad43744.zip
bug fixes, build script, sdk methods
Diffstat (limited to 'continuedev/src/continuedev/libs/llm')
-rw-r--r--continuedev/src/continuedev/libs/llm/openai.py10
1 file changed, 6 insertions, 4 deletions
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index bb745e75..10801465 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -6,6 +6,8 @@ import aiohttp
from ..llm import LLM
from pydantic import BaseModel, validator
+DEFAULT_MAX_TOKENS = 2048
+
class OpenAI(LLM):
api_key: str
@@ -22,7 +24,7 @@ class OpenAI(LLM):
def stream_chat(self, messages, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
self.completion_count += 1
- args = {"max_tokens": 512, "temperature": 0.5, "top_p": 1,
+ args = {"max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.5, "top_p": 1,
"frequency_penalty": 0, "presence_penalty": 0} | kwargs
args["stream"] = True
args["model"] = "gpt-3.5-turbo"
@@ -38,7 +40,7 @@ class OpenAI(LLM):
def stream_complete(self, prompt: str, **kwargs) -> Generator[Union[Any, List, Dict], None, None]:
self.completion_count += 1
- args = {"model": self.default_model, "max_tokens": 512, "temperature": 0.5,
+ args = {"model": self.default_model, "max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.5,
"top_p": 1, "frequency_penalty": 0, "presence_penalty": 0, "suffix": None} | kwargs
args["stream"] = True
@@ -64,7 +66,7 @@ class OpenAI(LLM):
t1 = time.time()
self.completion_count += 1
- args = {"model": self.default_model, "max_tokens": 512, "temperature": 0.5, "top_p": 1,
+ args = {"model": self.default_model, "max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.5, "top_p": 1,
"frequency_penalty": 0, "presence_penalty": 0, "stream": False} | kwargs
if args["model"] == "gpt-3.5-turbo":
@@ -132,7 +134,7 @@ class OpenAI(LLM):
def parallel_complete(self, prompts: list[str], suffixes: Union[list[str], None] = None, **kwargs) -> list[str]:
self.completion_count += len(prompts)
- args = {"model": self.default_model, "max_tokens": 512, "temperature": 0.5,
+ args = {"model": self.default_model, "max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.5,
"top_p": 1, "frequency_penalty": 0, "presence_penalty": 0} | kwargs
async def fn():