author	Nate Sesti <33237525+sestinj@users.noreply.github.com>	2023-10-09 18:37:27 -0700
committer	GitHub <noreply@github.com>	2023-10-09 18:37:27 -0700
commit	f09150617ed2454f3074bcf93f53aae5ae637d40 (patch)
tree	5cfe614a64d921dfe58b049f426d67a8b832c71f /server/continuedev/libs/llm/openai.py
parent	985304a213f620cdff3f8f65f74ed7e3b79be29d (diff)
Preview (#541)
* Strong typing (#533)
* refactor: :recycle: get rid of continuedev.src.continuedev structure
* refactor: :recycle: switching back to server folder
* feat: :sparkles: make config.py imports shorter
* feat: :bookmark: publish as pre-release vscode extension
* refactor: :recycle: refactor and add more completion params to ui
* build: :building_construction: download from preview S3
* fix: :bug: fix paths
* fix: :green_heart: package:pre-release
* ci: :green_heart: more time for tests
* fix: :green_heart: fix build scripts
* fix: :bug: fix import in run.py
* fix: :bookmark: update version to try again
* ci: 💚 Update package.json version [skip ci]
* refactor: :fire: don't check for old extensions version
* fix: :bug: small bug fixes
* fix: :bug: fix config.py import paths
* ci: 💚 Update package.json version [skip ci]
* ci: :green_heart: platform-specific builds test #1
* feat: :green_heart: ship with binary
* fix: :green_heart: fix copy statement to include .exe for windows
* fix: :green_heart: cd extension before packaging
* chore: :loud_sound: count tokens generated
* fix: :green_heart: remove npm_config_arch
* fix: :green_heart: publish as pre-release!
* chore: :bookmark: update version
* perf: :green_heart: hardcode distro paths
* fix: :bug: fix yaml syntax error
* chore: :bookmark: update version
* fix: :green_heart: update permissions and version
* feat: :bug: kill old server if needed
* feat: :lipstick: update marketplace icon for pre-release
* ci: 💚 Update package.json version [skip ci]
* feat: :sparkles: auto-reload for config.py
* feat: :wrench: update default config.py imports
* feat: :sparkles: codelens in config.py
* feat: :sparkles: select model param count from UI
* ci: 💚 Update package.json version [skip ci]
* feat: :sparkles: more model options, ollama error handling
* perf: :zap: don't show server loading immediately
* fix: :bug: fixing small UI details
* ci: 💚 Update package.json version [skip ci]
* feat: :rocket: headers param on LLM class
* fix: :bug: fix headers for openai.py
* feat: :sparkles: highlight code on cmd+shift+L
* ci: 💚 Update package.json version [skip ci]
* feat: :lipstick: sticky top bar in gui.tsx
* fix: :loud_sound: websocket logging and horizontal scrollbar
* ci: 💚 Update package.json version [skip ci]
* feat: :sparkles: allow AzureOpenAI Service through GGML
* ci: 💚 Update package.json version [skip ci]
* fix: :bug: fix automigration
* ci: 💚 Update package.json version [skip ci]
* ci: :green_heart: upload binaries in ci, download apple silicon
* chore: :fire: remove notes
* fix: :green_heart: use curl to download binary
* fix: :green_heart: set permissions on apple silicon binary
* fix: :green_heart: testing
* fix: :green_heart: cleanup file
* fix: :green_heart: fix preview.yaml
* fix: :green_heart: only upload once per binary
* fix: :green_heart: install rosetta
* ci: :green_heart: download binary after tests
* ci: 💚 Update package.json version [skip ci]
* ci: :green_heart: prepare ci for merge to main

---------

Co-authored-by: GitHub Action <action@github.com>
Diffstat (limited to 'server/continuedev/libs/llm/openai.py')
-rw-r--r--	server/continuedev/libs/llm/openai.py	156
1 file changed, 156 insertions(+), 0 deletions(-)
diff --git a/server/continuedev/libs/llm/openai.py b/server/continuedev/libs/llm/openai.py
new file mode 100644
index 00000000..ba29279b
--- /dev/null
+++ b/server/continuedev/libs/llm/openai.py
@@ -0,0 +1,156 @@
+from typing import Callable, List, Literal, Optional
+
+import certifi
+import openai
+from pydantic import Field
+
+from ...core.main import ChatMessage
+from .base import LLM
+
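+# Models that must be called through the chat completions endpoint rather than legacy completions.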
+CHAT_MODELS = {
+ "gpt-3.5-turbo",
+ "gpt-3.5-turbo-16k",
+ "gpt-4",
+ "gpt-3.5-turbo-0613",
+ "gpt-4-32k",
+}
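+# Context window sizes in tokens; the dot-less "gpt-35-turbo" spellings are Azure-style model names.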
+MAX_TOKENS_FOR_MODEL = {
+ "gpt-3.5-turbo": 4096,
+ "gpt-3.5-turbo-0613": 4096,
+ "gpt-3.5-turbo-16k": 16_384,
+ "gpt-4": 8192,
+ "gpt-35-turbo-16k": 16_384,
+ "gpt-35-turbo-0613": 4096,
+ "gpt-35-turbo": 4096,
+ "gpt-4-32k": 32_768,
+}
+
+
+class OpenAI(LLM):
+ """
+ The OpenAI class can be used to access OpenAI models like gpt-4 and gpt-3.5-turbo.
+
+ If you are locally serving a model that uses an OpenAI-compatible server, you can simply change the `api_base` in the `OpenAI` class like this:
+
+ ```python title="~/.continue/config.py"
+ from continuedev.libs.llm.openai import OpenAI
+
+ config = ContinueConfig(
+ ...
+ models=Models(
+ default=OpenAI(
+ api_key="EMPTY",
+ model="<MODEL_NAME>",
+ api_base="http://localhost:8000", # change to your server
+ )
+ )
+ )
+ ```
+
+ Options for serving models locally with an OpenAI-compatible server include:
+
+ - [text-gen-webui](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/openai#setup--installation)
+ - [FastChat](https://github.com/lm-sys/FastChat/blob/main/docs/openai_api.md)
+ - [LocalAI](https://localai.io/basics/getting_started/)
+ - [llama-cpp-python](https://github.com/abetlen/llama-cpp-python#web-server)
+ """
+
+    api_key: str = Field(
+        ...,
+        description="OpenAI API key",
+    )
+
+    proxy: Optional[str] = Field(None, description="Proxy URL to use for requests.")
+
+    api_base: Optional[str] = Field(None, description="OpenAI API base URL.")
+
+    api_type: Optional[Literal["azure", "openai"]] = Field(
+        None, description="OpenAI API type."
+    )
+
+    api_version: Optional[str] = Field(
+        None, description="OpenAI API version. For use with Azure OpenAI Service."
+    )
+
+    engine: Optional[str] = Field(
+        None, description="OpenAI engine. For use with Azure OpenAI Service."
+    )
+
+    async def start(
+        self,
+        unique_id: Optional[str] = None,
+        write_log: Optional[Callable[[str], None]] = None,
+    ):
+        await super().start(write_log=write_log, unique_id=unique_id)
+
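+        # If no context window was configured, look one up by model name (defaulting to 4096 tokens).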
+        if self.context_length is None:
+            self.context_length = MAX_TOKENS_FOR_MODEL.get(self.model, 4096)
+
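+        # Configure the global openai module with the key and any Azure-specific settings.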
+        openai.api_key = self.api_key
+        if self.api_type is not None:
+            openai.api_type = self.api_type
+        if self.api_base is not None:
+            openai.api_base = self.api_base
+        if self.api_version is not None:
+            openai.api_version = self.api_version
+
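+        # Optionally disable certificate verification (e.g. for self-signed local servers).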
+        if self.verify_ssl is False:
+            openai.verify_ssl_certs = False
+
+        if self.proxy is not None:
+            openai.proxy = self.proxy
+
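+        # Use the configured CA bundle, falling back to certifi's default bundle.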
+        openai.ca_bundle_path = self.ca_bundle_path or certifi.where()
+
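+    # Assemble per-request arguments, adding the Azure engine when one is configured.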
+    def collect_args(self, options):
+        args = super().collect_args(options)
+        if self.engine is not None:
+            args["engine"] = self.engine
+
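+        # Only the "-0613" model snapshots support OpenAI function calling; drop the parameter otherwise.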
+        if not args["model"].endswith("0613") and "functions" in args:
+            del args["functions"]
+
+        return args
+
+    async def _stream_complete(self, prompt, options):
+        args = self.collect_args(options)
+        args["stream"] = True
+
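+        # Chat models wrap the prompt in a single user message; other models hit the legacy completions API.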
+ if args["model"] in CHAT_MODELS:
+ async for chunk in await openai.ChatCompletion.acreate(
+ messages=[{"role": "user", "content": prompt}],
+ **args,
+ headers=self.headers,
+ ):
+ if len(chunk.choices) > 0 and "content" in chunk.choices[0].delta:
+ yield chunk.choices[0].delta.content
+ else:
+ async for chunk in await openai.Completion.acreate(prompt=prompt, **args, headers=self.headers):
+ if len(chunk.choices) > 0:
+ yield chunk.choices[0].text
+
+    async def _stream_chat(self, messages: List[ChatMessage], options):
+        args = self.collect_args(options)
+
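+        # Yield raw message deltas as they arrive, skipping keep-alive chunks that have no choices.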
+        async for chunk in await openai.ChatCompletion.acreate(
+            messages=messages,
+            stream=True,
+            **args,
+            headers=self.headers,
+        ):
+            if not hasattr(chunk, "choices") or len(chunk.choices) == 0:
+                continue
+            yield chunk.choices[0].delta
+
+    async def _complete(self, prompt: str, options):
+        args = self.collect_args(options)
+
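+        # Non-streaming request: chat models return a full message, other models return raw text.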
+ if args["model"] in CHAT_MODELS:
+ resp = await openai.ChatCompletion.acreate(
+ messages=[{"role": "user", "content": prompt}],
+ **args,
+ headers=self.headers,
+ )
+ return resp.choices[0].message.content
+ else:
+ return (
+ (await openai.Completion.acreate(prompt=prompt, **args, headers=self.headers)).choices[0].text
+ )
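
The `api_type`, `api_base`, `api_version`, and `engine` fields above exist to support Azure OpenAI Service alongside the standard OpenAI API. A minimal configuration sketch, assuming an Azure deployment; the key, resource name, deployment name, and API version below are illustrative placeholders, not values from this commit:

```python title="~/.continue/config.py"
from continuedev.libs.llm.openai import OpenAI

# A minimal sketch for Azure OpenAI Service; every value below is a placeholder.
default = OpenAI(
    api_key="<AZURE_OPENAI_KEY>",
    model="gpt-4",
    api_type="azure",
    api_base="https://<RESOURCE_NAME>.openai.azure.com",
    api_version="2023-05-15",  # any valid Azure OpenAI API version string
    engine="<DEPLOYMENT_NAME>",  # Azure deployment name, forwarded by collect_args()
)
```

`start()` copies these fields onto the global `openai` module and `collect_args()` adds `engine` to each request, so no other wiring is needed.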