Diffstat (limited to 'server/continuedev/libs/llm/openai.py')
-rw-r--r--  server/continuedev/libs/llm/openai.py  156
1 file changed, 0 insertions(+), 156 deletions(-)
diff --git a/server/continuedev/libs/llm/openai.py b/server/continuedev/libs/llm/openai.py
deleted file mode 100644
index ba29279b..00000000
--- a/server/continuedev/libs/llm/openai.py
+++ /dev/null
@@ -1,156 +0,0 @@
-from typing import Callable, List, Literal, Optional
-
-import certifi
-import openai
-from pydantic import Field
-
-from ...core.main import ChatMessage
-from .base import LLM
-
-CHAT_MODELS = {
- "gpt-3.5-turbo",
- "gpt-3.5-turbo-16k",
- "gpt-4",
- "gpt-3.5-turbo-0613",
- "gpt-4-32k",
-}
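-# Note: the "gpt-35-*" keys are Azure OpenAI model names, which omit the
-# dot used in the standard "gpt-3.5-*" names.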
-MAX_TOKENS_FOR_MODEL = {
- "gpt-3.5-turbo": 4096,
- "gpt-3.5-turbo-0613": 4096,
- "gpt-3.5-turbo-16k": 16_384,
- "gpt-4": 8192,
- "gpt-35-turbo-16k": 16_384,
- "gpt-35-turbo-0613": 4096,
- "gpt-35-turbo": 4096,
- "gpt-4-32k": 32_768,
-}
-
-
-class OpenAI(LLM):
- """
- The OpenAI class can be used to access OpenAI models like gpt-4 and gpt-3.5-turbo.
-
- If you are locally serving a model that uses an OpenAI-compatible server, you can simply change the `api_base` in the `OpenAI` class like this:
-
- ```python title="~/.continue/config.py"
- from continuedev.libs.llm.openai import OpenAI
-
- config = ContinueConfig(
- ...
- models=Models(
- default=OpenAI(
- api_key="EMPTY",
- model="<MODEL_NAME>",
- api_base="http://localhost:8000", # change to your server
- )
- )
- )
- ```
-
- Options for serving models locally with an OpenAI-compatible server include:
-
- - [text-gen-webui](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/openai#setup--installation)
- - [FastChat](https://github.com/lm-sys/FastChat/blob/main/docs/openai_api.md)
- - [LocalAI](https://localai.io/basics/getting_started/)
- - [llama-cpp-python](https://github.com/abetlen/llama-cpp-python#web-server)
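-
-    The class also supports Azure OpenAI Service via the `api_type`, `api_version`, and `engine` fields. A minimal sketch (the resource name, API version, and deployment name below are placeholders for your own values):
-
-    ```python title="~/.continue/config.py"
-    config = ContinueConfig(
-        ...
-        models=Models(
-            default=OpenAI(
-                api_key="<AZURE_OPENAI_KEY>",
-                model="gpt-4",
-                api_type="azure",
-                api_base="https://<RESOURCE_NAME>.openai.azure.com",
-                api_version="2023-07-01-preview",  # example version; use your deployment's
-                engine="<DEPLOYMENT_NAME>",
-            )
-        )
-    )
-    ```
-
-    Requests can also be routed through an HTTP(S) proxy with the `proxy` field, and certificate checking is controlled by the inherited `verify_ssl` and `ca_bundle_path` settings.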
- """
-
-    api_key: str = Field(
-        ...,
-        description="OpenAI API key",
-    )
-
-    proxy: Optional[str] = Field(None, description="Proxy URL to use for requests.")
-
-    api_base: Optional[str] = Field(None, description="OpenAI API base URL.")
-
-    api_type: Optional[Literal["azure", "openai"]] = Field(
-        None, description="OpenAI API type."
-    )
-
-    api_version: Optional[str] = Field(
-        None, description="OpenAI API version. For use with Azure OpenAI Service."
-    )
-
-    engine: Optional[str] = Field(
-        None, description="OpenAI engine. For use with Azure OpenAI Service."
-    )
-
-    async def start(
-        self,
-        unique_id: Optional[str] = None,
-        write_log: Optional[Callable[[str], None]] = None,
-    ):
-        await super().start(write_log=write_log, unique_id=unique_id)
-
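-        # Infer the context length from the model name, defaulting to 4096 tokens.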
-        if self.context_length is None:
-            self.context_length = MAX_TOKENS_FOR_MODEL.get(self.model, 4096)
-
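-        # These settings are module-level globals on `openai`, so they are
-        # shared by every client in the process.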
-        openai.api_key = self.api_key
-        if self.api_type is not None:
-            openai.api_type = self.api_type
-        if self.api_base is not None:
-            openai.api_base = self.api_base
-        if self.api_version is not None:
-            openai.api_version = self.api_version
-
-        if self.verify_ssl is False:
-            openai.verify_ssl_certs = False
-
-        if self.proxy is not None:
-            openai.proxy = self.proxy
-
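-        # Use the configured CA bundle, falling back to certifi's default.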
-        openai.ca_bundle_path = self.ca_bundle_path or certifi.where()
-
-    def collect_args(self, options):
-        args = super().collect_args(options)
-        if self.engine is not None:
-            args["engine"] = self.engine
-
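-        # Function calling was introduced with the "-0613" snapshots; drop the
-        # "functions" argument for models that don't support it.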
- if not args["model"].endswith("0613") and "functions" in args:
- del args["functions"]
-
- return args
-
-    async def _stream_complete(self, prompt, options):
-        args = self.collect_args(options)
-        args["stream"] = True
-
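-        # Chat models stream through the chat endpoint; everything else uses
-        # the legacy text-completion endpoint.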
- if args["model"] in CHAT_MODELS:
- async for chunk in await openai.ChatCompletion.acreate(
- messages=[{"role": "user", "content": prompt}],
- **args,
- headers=self.headers,
- ):
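-                # Deltas may be empty or carry only a role; yield text content only.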
-                if len(chunk.choices) > 0 and "content" in chunk.choices[0].delta:
-                    yield chunk.choices[0].delta.content
-        else:
-            async for chunk in await openai.Completion.acreate(
-                prompt=prompt, **args, headers=self.headers
-            ):
-                if len(chunk.choices) > 0:
-                    yield chunk.choices[0].text
-
-    async def _stream_chat(self, messages: List[ChatMessage], options):
-        args = self.collect_args(options)
-
-        async for chunk in await openai.ChatCompletion.acreate(
-            messages=messages,
-            stream=True,
-            **args,
-            headers=self.headers,
-        ):
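-            # Some chunks arrive without choices; skip them and yield the raw delta.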
-            if not hasattr(chunk, "choices") or len(chunk.choices) == 0:
-                continue
-            yield chunk.choices[0].delta
-
-    async def _complete(self, prompt: str, options):
-        args = self.collect_args(options)
-
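-        # As above: chat models answer via the chat endpoint, other models via
-        # the legacy text-completion endpoint.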
- if args["model"] in CHAT_MODELS:
- resp = await openai.ChatCompletion.acreate(
- messages=[{"role": "user", "content": prompt}],
- **args,
- headers=self.headers,
- )
- return resp.choices[0].message.content
- else:
- return (
- (await openai.Completion.acreate(prompt=prompt, **args, headers=self.headers)).choices[0].text
- )