diff options
Diffstat (limited to 'server/continuedev/libs/llm/google_palm_api.py')
-rw-r--r-- | server/continuedev/libs/llm/google_palm_api.py | 50 |
1 files changed, 50 insertions, 0 deletions
import asyncio
import functools
from typing import Any, AsyncGenerator, Dict, List

import requests
from pydantic import Field

from ...core.main import ChatMessage
from .base import LLM

# Root of the PaLM (Generative Language) REST API version used by this class.
_API_BASE_URL = "https://generativelanguage.googleapis.com/v1beta2"


class GooglePaLMAPI(LLM):
    """
    The Google PaLM API is currently in public preview, so production applications are not supported yet. However, you can [create an API key in Google MakerSuite](https://makersuite.google.com/app/apikey) and begin trying out the `chat-bison-001` model. Change `~/.continue/config.py` to look like this:

    ```python title="~/.continue/config.py"
    from continuedev.core.models import Models
    from continuedev.libs.llm.google_palm_api import GooglePaLMAPI

    config = ContinueConfig(
        ...
        models=Models(
            default=GooglePaLMAPI(
                model="chat-bison-001",
                api_key="<MAKERSUITE_API_KEY>",
            )
        )
    )
    ```
    """

    api_key: str = Field(..., description="Google PaLM API key")

    model: str = "chat-bison-001"

    async def _generate_message(self, messages: List[Dict[str, Any]]) -> str:
        """Call the `generateMessage` endpoint and return the first candidate's text.

        Args:
            messages: Message dicts sent as `prompt.messages`; this class only
                ever populates the `"content"` key of each one.

        Returns:
            The `"content"` of the first entry in the response's `"candidates"`.

        Raises:
            requests.HTTPError: If the API answers with a 4xx/5xx status.
            KeyError, IndexError: If the response body has no candidates.
        """
        api_url = f"{_API_BASE_URL}/models/{self.model}:generateMessage"
        body = {"prompt": {"messages": messages}}
        send = functools.partial(
            requests.post,
            api_url,
            # Passing the key via `params` lets requests URL-encode it.
            params={"key": self.api_key},
            json=body,
            # Use a `timeout` attribute when the instance has one (the base
            # class is not visible from this file); otherwise this is None,
            # which keeps requests' default of waiting indefinitely.
            timeout=getattr(self, "timeout", None),
        )
        # `requests` is synchronous: run it in the default executor so the
        # event loop is not blocked for the duration of the HTTP round trip.
        response = await asyncio.get_running_loop().run_in_executor(None, send)
        # Surface HTTP-level failures directly instead of as a KeyError on
        # the missing "candidates" field of an error payload.
        response.raise_for_status()
        return response.json()["candidates"][0]["content"]

    async def _stream_complete(self, prompt, options) -> AsyncGenerator[str, None]:
        """Yield the completion for `prompt` as a single chunk.

        The whole response is fetched in one request, so exactly one item is
        yielded.

        Args:
            prompt: Text sent as the only message of the conversation.
            options: Currently unused by this implementation.

        Yields:
            The text of the first candidate returned by the API.
        """
        yield await self._generate_message([{"content": prompt}])

    async def _stream_chat(
        self, messages: List[ChatMessage], options
    ) -> AsyncGenerator[Dict[str, str], None]:
        """Yield the assistant's reply to `messages` as a single chunk.

        Args:
            messages: Conversation so far; only `message["content"]` of each
                entry is forwarded to the API.
            options: Currently unused by this implementation.

        Yields:
            A dict with the reply text under `"content"` and `"role"` set to
            `"assistant"`.
        """
        msg_lst = [{"content": message["content"]} for message in messages]
        yield {
            "content": await self._generate_message(msg_lst),
            "role": "assistant",
        }