summaryrefslogtreecommitdiff
path: root/server/continuedev/libs/llm/google_palm_api.py
diff options
context:
space:
mode:
Diffstat (limited to 'server/continuedev/libs/llm/google_palm_api.py')
-rw-r--r--server/continuedev/libs/llm/google_palm_api.py50
1 files changed, 50 insertions, 0 deletions
diff --git a/server/continuedev/libs/llm/google_palm_api.py b/server/continuedev/libs/llm/google_palm_api.py
new file mode 100644
index 00000000..3379fefe
--- /dev/null
+++ b/server/continuedev/libs/llm/google_palm_api.py
@@ -0,0 +1,50 @@
+from typing import List
+
+import requests
+from pydantic import Field
+
+from ...core.main import ChatMessage
+from .base import LLM
+
+
+class GooglePaLMAPI(LLM):
+ """
+ The Google PaLM API is currently in public preview, so production applications are not supported yet. However, you can [create an API key in Google MakerSuite](https://makersuite.google.com/u/2/app/apikey) and begin trying out the `chat-bison-001` model. Change `~/.continue/config.py` to look like this:
+
+ ```python title="~/.continue/config.py"
+ from continuedev.core.models import Models
+ from continuedev.libs.llm.hf_inference_api import GooglePaLMAPI
+
+ config = ContinueConfig(
+ ...
+ models=Models(
+ default=GooglePaLMAPI(
+ model="chat-bison-001"
+ api_key="<MAKERSUITE_API_KEY>",
+ )
+ )
+ ```
+ """
+
+ api_key: str = Field(..., description="Google PaLM API key")
+
+ model: str = "chat-bison-001"
+
+ async def _stream_complete(self, prompt, options):
+ api_url = f"https://generativelanguage.googleapis.com/v1beta2/models/{self.model}:generateMessage?key={self.api_key}"
+ body = {"prompt": {"messages": [{"content": prompt}]}}
+ response = requests.post(api_url, json=body)
+ yield response.json()["candidates"][0]["content"]
+
+ async def _stream_chat(self, messages: List[ChatMessage], options):
+ msg_lst = []
+ for message in messages:
+ msg_lst.append({"content": message["content"]})
+
+ api_url = f"https://generativelanguage.googleapis.com/v1beta2/models/{self.model}:generateMessage?key={self.api_key}"
+ body = {"prompt": {"messages": msg_lst}}
+ response = requests.post(api_url, json=body)
+ yield {
+ "content": response.json()["candidates"][0]["content"],
+ "role": "assistant",
+ }