summaryrefslogtreecommitdiff
path: root/server/continuedev/libs/llm/openai.py
blob: ba29279bc59321bc1d15963c1c93348a951dc03d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
from typing import Callable, List, Literal, Optional

import certifi
import openai
from pydantic import Field

from ...core.main import ChatMessage
from .base import LLM

# Model names served through the chat-completions endpoint. In this module,
# `_complete` / `_stream_complete` send any model name NOT in this set to the
# legacy text-completions endpoint instead.
CHAT_MODELS = {
    "gpt-3.5-turbo",
    "gpt-3.5-turbo-0613",
    "gpt-3.5-turbo-16k",
    "gpt-3.5-turbo-16k-0613",
    "gpt-4",
    "gpt-4-0613",
    "gpt-4-32k",
    "gpt-4-32k-0613",
    # Azure OpenAI spells the 3.5 models without the dot.
    "gpt-35-turbo",
    "gpt-35-turbo-0613",
    "gpt-35-turbo-16k",
}

# Context window (in tokens) for each known model name. `OpenAI.start` uses
# this when no explicit `context_length` is configured and falls back to 4096
# for names missing from the table.
MAX_TOKENS_FOR_MODEL = {
    "gpt-3.5-turbo": 4096,
    "gpt-3.5-turbo-0613": 4096,
    "gpt-3.5-turbo-16k": 16_384,
    "gpt-3.5-turbo-16k-0613": 16_384,
    "gpt-4": 8192,
    "gpt-4-0613": 8192,
    "gpt-4-32k": 32_768,
    "gpt-4-32k-0613": 32_768,
    "gpt-35-turbo": 4096,
    "gpt-35-turbo-0613": 4096,
    "gpt-35-turbo-16k": 16_384,
}


class OpenAI(LLM):
    """
    The OpenAI class can be used to access OpenAI models like gpt-4 and gpt-3.5-turbo.

    If you are locally serving a model that uses an OpenAI-compatible server, you can simply change the `api_base` in the `OpenAI` class like this:

    ```python title="~/.continue/config.py"
    from continuedev.libs.llm.openai import OpenAI

    config = ContinueConfig(
        ...
        models=Models(
            default=OpenAI(
                api_key="EMPTY",
                model="<MODEL_NAME>",
                api_base="http://localhost:8000", # change to your server
            )
        )
    )
    ```

    Options for serving models locally with an OpenAI-compatible server include:

    - [text-generation-webui](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/openai#setup--installation)
    - [FastChat](https://github.com/lm-sys/FastChat/blob/main/docs/openai_api.md)
    - [LocalAI](https://localai.io/basics/getting_started/)
    - [llama-cpp-python](https://github.com/abetlen/llama-cpp-python#web-server)
    """

    api_key: str = Field(
        ...,
        description="OpenAI API key",
    )

    proxy: Optional[str] = Field(None, description="Proxy URL to use for requests.")

    api_base: Optional[str] = Field(None, description="OpenAI API base URL.")

    api_type: Optional[Literal["azure", "openai"]] = Field(
        None, description="OpenAI API type."
    )

    api_version: Optional[str] = Field(
        None, description="OpenAI API version. For use with Azure OpenAI Service."
    )

    engine: Optional[str] = Field(
        None, description="OpenAI engine. For use with Azure OpenAI Service."
    )

    async def start(
        self,
        unique_id: Optional[str] = None,
        write_log: Optional[Callable[[str], None]] = None,
    ):
        """Start the base LLM, then apply this instance's settings to the ``openai`` module.

        Args:
            unique_id: Forwarded to ``LLM.start``.
            write_log: Forwarded to ``LLM.start``.

        NOTE: configuration is written to module-level attributes of ``openai``,
        so it is shared by every ``OpenAI`` instance in the process. Optional
        settings left as ``None`` (api_type, api_base, api_version, proxy) are
        not reset here. A value written by an earlier instance therefore stays
        in effect.
        """
        await super().start(write_log=write_log, unique_id=unique_id)

        if self.context_length is None:
            # Unknown model names fall back to a 4096-token window.
            self.context_length = MAX_TOKENS_FOR_MODEL.get(self.model, 4096)

        openai.api_key = self.api_key
        if self.api_type is not None:
            openai.api_type = self.api_type
        if self.api_base is not None:
            openai.api_base = self.api_base
        if self.api_version is not None:
            openai.api_version = self.api_version

        # Only an explicit False disables certificate verification.
        if self.verify_ssl is False:
            openai.verify_ssl_certs = False

        if self.proxy is not None:
            openai.proxy = self.proxy

        # Use a custom CA bundle if configured, otherwise certifi's bundle.
        openai.ca_bundle_path = self.ca_bundle_path or certifi.where()

    def collect_args(self, options):
        """Build the keyword arguments passed to the ``openai`` create calls.

        Adds ``engine`` (used with Azure OpenAI Service) when configured. Drops
        ``functions`` unless the model name ends in ``0613``, because only those
        snapshots are treated here as supporting function calling.
        """
        args = super().collect_args(options)
        if self.engine is not None:
            args["engine"] = self.engine

        if not args["model"].endswith("0613") and "functions" in args:
            del args["functions"]

        return args

    async def _stream_complete(self, prompt, options):
        """Stream completion text for ``prompt``, yielding text fragments.

        For models in ``CHAT_MODELS`` the prompt is sent as a single user
        message to the chat endpoint. Other models use the text-completions
        endpoint.
        """
        args = self.collect_args(options)
        args["stream"] = True

        if args["model"] in CHAT_MODELS:
            async for chunk in await openai.ChatCompletion.acreate(
                messages=[{"role": "user", "content": prompt}],
                **args,
                headers=self.headers,
            ):
                if not chunk.choices:
                    continue
                # A delta may omit "content" or carry content=None (e.g. the
                # final chunk or a function-call delta). Never yield None.
                content = chunk.choices[0].delta.get("content")
                if content:
                    yield content
        else:
            async for chunk in await openai.Completion.acreate(
                prompt=prompt, **args, headers=self.headers
            ):
                if chunk.choices:
                    yield chunk.choices[0].text

    async def _stream_chat(self, messages: List[ChatMessage], options):
        """Stream a chat completion, yielding the first choice's ``delta`` of each chunk."""
        args = self.collect_args(options)
        # Set via args (as in _stream_complete) so a "stream" key coming from
        # collect_args cannot clash with a separate keyword argument.
        args["stream"] = True

        async for chunk in await openai.ChatCompletion.acreate(
            messages=messages,
            **args,
            headers=self.headers,
        ):
            # Some streamed chunks may carry no choices at all; skip them.
            if not getattr(chunk, "choices", None):
                continue
            yield chunk.choices[0].delta

    async def _complete(self, prompt: str, options):
        """Return the full (non-streamed) completion text for ``prompt``.

        Chat models in ``CHAT_MODELS`` receive the prompt as a single user
        message. Other models use the text-completions endpoint.
        """
        args = self.collect_args(options)

        if args["model"] in CHAT_MODELS:
            resp = await openai.ChatCompletion.acreate(
                messages=[{"role": "user", "content": prompt}],
                **args,
                headers=self.headers,
            )
            return resp.choices[0].message.content

        resp = await openai.Completion.acreate(
            prompt=prompt, **args, headers=self.headers
        )
        return resp.choices[0].text