from typing import Callable, List, Literal, Optional
import certifi
import openai
from pydantic import Field
from ...core.main import ChatMessage
from .base import LLM
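
# Models that must be called through the chat-completions endpoint rather than
# the legacy completions endpoint.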
CHAT_MODELS = {
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-4",
"gpt-3.5-turbo-0613",
"gpt-4-32k",
}
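
# Default context lengths (in tokens), used when `context_length` is not set
# explicitly. The "gpt-35-*" keys cover Azure OpenAI deployment names, which
# omit the dot.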
MAX_TOKENS_FOR_MODEL = {
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-0613": 4096,
"gpt-3.5-turbo-16k": 16_384,
"gpt-4": 8192,
"gpt-35-turbo-16k": 16_384,
"gpt-35-turbo-0613": 4096,
"gpt-35-turbo": 4096,
"gpt-4-32k": 32_768,
}


class OpenAI(LLM):
"""
The OpenAI class can be used to access OpenAI models like gpt-4 and gpt-3.5-turbo.
    If you are serving a model locally with an OpenAI-compatible server, you can point this class at it by changing `api_base`:
```python title="~/.continue/config.py"
from continuedev.libs.llm.openai import OpenAI
config = ContinueConfig(
...
models=Models(
default=OpenAI(
api_key="EMPTY",
model="<MODEL_NAME>",
api_base="http://localhost:8000", # change to your server
)
)
)
```
Options for serving models locally with an OpenAI-compatible server include:
- [text-gen-webui](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/openai#setup--installation)
- [FastChat](https://github.com/lm-sys/FastChat/blob/main/docs/openai_api.md)
- [LocalAI](https://localai.io/basics/getting_started/)
- [llama-cpp-python](https://github.com/abetlen/llama-cpp-python#web-server)
"""
api_key: str = Field(
...,
description="OpenAI API key",
)
proxy: Optional[str] = Field(None, description="Proxy URL to use for requests.")
api_base: Optional[str] = Field(None, description="OpenAI API base URL.")
api_type: Optional[Literal["azure", "openai"]] = Field(
None, description="OpenAI API type."
)
api_version: Optional[str] = Field(
None, description="OpenAI API version. For use with Azure OpenAI Service."
)
engine: Optional[str] = Field(
None, description="OpenAI engine. For use with Azure OpenAI Service."
)
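    # A typical Azure OpenAI Service configuration (all values below are
    # placeholders, not defaults) would look like:
    #   OpenAI(
    #       api_key="...",
    #       model="gpt-4",
    #       api_type="azure",
    #       api_base="https://<resource>.openai.azure.com",
    #       api_version="2023-07-01-preview",
    #       engine="<deployment-name>",
    #   )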

    async def start(
        self,
        unique_id: Optional[str] = None,
        write_log: Optional[Callable[[str], None]] = None,
    ):
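        """Resolve a default `context_length` for the model and copy connection
        settings onto the global `openai` module (the 0.x SDK is configured
        through module-level attributes)."""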
await super().start(write_log=write_log, unique_id=unique_id)
if self.context_length is None:
self.context_length = MAX_TOKENS_FOR_MODEL.get(self.model, 4096)
openai.api_key = self.api_key
if self.api_type is not None:
openai.api_type = self.api_type
if self.api_base is not None:
openai.api_base = self.api_base
if self.api_version is not None:
openai.api_version = self.api_version
        if self.verify_ssl is False:
            openai.verify_ssl_certs = False
if self.proxy is not None:
openai.proxy = self.proxy
openai.ca_bundle_path = self.ca_bundle_path or certifi.where()

    def collect_args(self, options):
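        """Build the keyword arguments for the OpenAI SDK call. Adds the Azure
        `engine` when set, and strips `functions` for any model other than the
        *-0613 snapshots, which are the only ones this class assumes support
        function calling."""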
args = super().collect_args(options)
if self.engine is not None:
args["engine"] = self.engine
if not args["model"].endswith("0613") and "functions" in args:
del args["functions"]
return args

    async def _stream_complete(self, prompt, options):
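        """Stream a completion for a raw prompt. Chat-only models are wrapped
        in a single user message and routed to the chat endpoint; all other
        models use the legacy completions endpoint."""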
args = self.collect_args(options)
args["stream"] = True
if args["model"] in CHAT_MODELS:
async for chunk in await openai.ChatCompletion.acreate(
messages=[{"role": "user", "content": prompt}],
**args,
headers=self.headers,
):
if len(chunk.choices) > 0 and "content" in chunk.choices[0].delta:
yield chunk.choices[0].delta.content
else:
            async for chunk in await openai.Completion.acreate(
                prompt=prompt, **args, headers=self.headers
            ):
if len(chunk.choices) > 0:
yield chunk.choices[0].text

    async def _stream_chat(self, messages: List[ChatMessage], options):
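        """Stream message deltas for a chat conversation via the
        chat-completions endpoint, skipping chunks with no choices."""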
args = self.collect_args(options)
async for chunk in await openai.ChatCompletion.acreate(
messages=messages,
stream=True,
**args,
headers=self.headers,
):
if not hasattr(chunk, "choices") or len(chunk.choices) == 0:
continue
yield chunk.choices[0].delta

    async def _complete(self, prompt: str, options):
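        """Return a full (non-streaming) completion, routing chat-only models
        through the chat endpoint as a single user message."""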
args = self.collect_args(options)
if args["model"] in CHAT_MODELS:
resp = await openai.ChatCompletion.acreate(
messages=[{"role": "user", "content": prompt}],
**args,
headers=self.headers,
)
return resp.choices[0].message.content
else:
            resp = await openai.Completion.acreate(
                prompt=prompt, **args, headers=self.headers
            )
            return resp.choices[0].text
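

# --- Usage sketch ------------------------------------------------------------
# A minimal, hedged example of driving this class directly; it only runs when
# the module is executed as a script. It assumes the base `LLM` class exposes
# an async `complete(prompt)` wrapper around `_complete` (an assumption, not
# guaranteed by this file), and the API key and model below are placeholders.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        llm = OpenAI(api_key="sk-...", model="gpt-3.5-turbo")
        await llm.start(unique_id="demo", write_log=print)
        print(await llm.complete("Say hello in one word."))

    asyncio.run(_demo())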