summaryrefslogtreecommitdiff
path: root/server/continuedev/libs/llm/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'server/continuedev/libs/llm/base.py')
-rw-r--r--server/continuedev/libs/llm/base.py458
1 files changed, 458 insertions, 0 deletions
diff --git a/server/continuedev/libs/llm/base.py b/server/continuedev/libs/llm/base.py
new file mode 100644
index 00000000..d77cb9fc
--- /dev/null
+++ b/server/continuedev/libs/llm/base.py
@@ -0,0 +1,458 @@
+import ssl
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union
+
+import aiohttp
+import certifi
+from pydantic import Field, validator
+
+from ...core.main import ChatMessage
+from ...models.main import ContinueBaseModel
+from ..util.count_tokens import (
+ DEFAULT_ARGS,
+ DEFAULT_MAX_TOKENS,
+ compile_chat_messages,
+ count_tokens,
+ format_chat_messages,
+ prune_raw_prompt_from_top,
+)
+from ..util.devdata import dev_data_logger
+from ..util.telemetry import posthog_logger
+
+
+class CompletionOptions(ContinueBaseModel):
+ """Options for the completion."""
+
+ @validator(
+ "*",
+ pre=True,
+ always=True,
+ )
+ def ignore_none_and_set_default(cls, value, field):
+ return value if value is not None else field.default
+
+ model: Optional[str] = Field(None, description="The model name")
+ temperature: Optional[float] = Field(
+ None, description="The temperature of the completion."
+ )
+ top_p: Optional[float] = Field(None, description="The top_p of the completion.")
+ top_k: Optional[int] = Field(None, description="The top_k of the completion.")
+ presence_penalty: Optional[float] = Field(
+ None, description="The presence penalty Aof the completion."
+ )
+ frequency_penalty: Optional[float] = Field(
+ None, description="The frequency penalty of the completion."
+ )
+ stop: Optional[List[str]] = Field(
+ None, description="The stop tokens of the completion."
+ )
+ max_tokens: int = Field(
+ DEFAULT_MAX_TOKENS, description="The maximum number of tokens to generate."
+ )
+ functions: Optional[List[Any]] = Field(
+ None, description="The functions/tools to make available to the model."
+ )
+
+
+class LLM(ContinueBaseModel):
+ title: Optional[str] = Field(
+ None,
+ description="A title that will identify this model in the model selection dropdown",
+ )
+
+ unique_id: Optional[str] = Field(None, description="The unique ID of the user.")
+ model: str = Field(
+ ..., description="The name of the model to be used (e.g. gpt-4, codellama)"
+ )
+
+ system_message: Optional[str] = Field(
+ None, description="A system message that will always be followed by the LLM"
+ )
+
+ context_length: int = Field(
+ 2048,
+ description="The maximum context length of the LLM in tokens, as counted by count_tokens.",
+ )
+
+ stop_tokens: Optional[List[str]] = Field(
+ None, description="Tokens that will stop the completion."
+ )
+ temperature: Optional[float] = Field(
+ None, description="The temperature of the completion."
+ )
+ top_p: Optional[float] = Field(None, description="The top_p of the completion.")
+ top_k: Optional[int] = Field(None, description="The top_k of the completion.")
+ presence_penalty: Optional[float] = Field(
+ None, description="The presence penalty Aof the completion."
+ )
+ frequency_penalty: Optional[float] = Field(
+ None, description="The frequency penalty of the completion."
+ )
+
+ timeout: Optional[int] = Field(
+ 300,
+ description="Set the timeout for each request to the LLM. If you are running a local LLM that takes a while to respond, you might want to set this to avoid timeouts.",
+ )
+ verify_ssl: Optional[bool] = Field(
+ None, description="Whether to verify SSL certificates for requests."
+ )
+ ca_bundle_path: str = Field(
+ None,
+ description="Path to a custom CA bundle to use when making the HTTP request",
+ )
+ proxy: Optional[str] = Field(
+ None,
+ description="Proxy URL to use when making the HTTP request",
+ )
+ headers: Optional[Dict[str, str]] = Field(
+ None,
+ description="Headers to use when making the HTTP request",
+ )
+ prompt_templates: dict = Field(
+ {},
+ description='A dictionary of prompt templates that can be used to customize the behavior of the LLM in certain situations. For example, set the "edit" key in order to change the prompt that is used for the /edit slash command. Each value in the dictionary is a string templated in mustache syntax, and filled in at runtime with the variables specific to the situation. See the documentation for more information.',
+ )
+
+ template_messages: Optional[Callable[[List[Dict[str, str]]], str]] = Field(
+ None,
+ description="A function that takes a list of messages and returns a prompt. This ensures that models like llama2, which are trained on specific chat formats, will always receive input in that format.",
+ )
+ write_log: Optional[Callable[[str], None]] = Field(
+ None,
+ description="A function that is called upon every prompt and completion, by default to log to the file which can be viewed by clicking on the magnifying glass.",
+ )
+
+ api_key: Optional[str] = Field(
+ None, description="The API key for the LLM provider."
+ )
+
+ class Config:
+ arbitrary_types_allowed = True
+ extra = "allow"
+ fields = {
+ "title": {
+ "description": "A title that will identify this model in the model selection dropdown"
+ },
+ "system_message": {
+ "description": "A system message that will always be followed by the LLM"
+ },
+ "context_length": {
+ "description": "The maximum context length of the LLM in tokens, as counted by count_tokens."
+ },
+ "unique_id": {"description": "The unique ID of the user."},
+ "model": {
+ "description": "The name of the model to be used (e.g. gpt-4, codellama)"
+ },
+ "timeout": {
+ "description": "Set the timeout for each request to the LLM. If you are running a local LLM that takes a while to respond, you might want to set this to avoid timeouts."
+ },
+ "prompt_templates": {
+ "description": 'A dictionary of prompt templates that can be used to customize the behavior of the LLM in certain situations. For example, set the "edit" key in order to change the prompt that is used for the /edit slash command. Each value in the dictionary is a string templated in mustache syntax, and filled in at runtime with the variables specific to the situation. See the documentation for more information.'
+ },
+ "template_messages": {
+ "description": "A function that takes a list of messages and returns a prompt. This ensures that models like llama2, which are trained on specific chat formats, will always receive input in that format."
+ },
+ "write_log": {
+ "description": "A function that is called upon every prompt and completion, by default to log to the file which can be viewed by clicking on the magnifying glass."
+ },
+ "api_key": {"description": "The API key for the LLM provider."},
+ "verify_ssl": {
+ "description": "Whether to verify SSL certificates for requests."
+ },
+ "ca_bundle_path": {
+ "description": "Path to a custom CA bundle to use when making the HTTP request"
+ },
+ "headers": {
+ "description": "Headers to use when making the HTTP request"
+ },
+ "proxy": {"description": "Proxy URL to use when making the HTTP request"},
+ "stop_tokens": {"description": "Tokens that will stop the completion."},
+ "temperature": {
+ "description": "The sampling temperature used for generation."
+ },
+ "top_p": {
+ "description": "The top_p sampling parameter used for generation."
+ },
+ "top_k": {
+ "description": "The top_k sampling parameter used for generation."
+ },
+ "presence_penalty": {
+ "description": "The presence penalty used for completions."
+ },
+ "frequency_penalty": {
+ "description": "The frequency penalty used for completions."
+ },
+ }
+
+ def dict(self, **kwargs):
+ original_dict = super().dict(**kwargs)
+ original_dict.pop("write_log")
+ if self.template_messages is not None:
+ original_dict["template_messages"] = self.template_messages.__name__
+ original_dict.pop("unique_id")
+ original_dict["class_name"] = self.__class__.__name__
+ return original_dict
+
+ async def start(
+ self, write_log: Callable[[str], None] = None, unique_id: Optional[str] = None
+ ):
+ """Start the connection to the LLM."""
+ self.write_log = write_log
+ self.unique_id = unique_id
+
+ async def stop(self):
+ """Stop the connection to the LLM."""
+ pass
+
+ def create_client_session(self):
+ if self.verify_ssl is False:
+ return aiohttp.ClientSession(
+ connector=aiohttp.TCPConnector(verify_ssl=False),
+ timeout=aiohttp.ClientTimeout(total=self.timeout),
+ headers=self.headers
+ )
+ else:
+ ca_bundle_path = (
+ certifi.where() if self.ca_bundle_path is None else self.ca_bundle_path
+ )
+ ssl_context = ssl.create_default_context(cafile=ca_bundle_path)
+ return aiohttp.ClientSession(
+ connector=aiohttp.TCPConnector(ssl_context=ssl_context),
+ timeout=aiohttp.ClientTimeout(total=self.timeout),
+ headers=self.headers,
+ )
+
+ def collect_args(self, options: CompletionOptions) -> Dict[str, Any]:
+ """Collect the arguments for the LLM."""
+ args = {**DEFAULT_ARGS.copy(), "model": self.model}
+ args.update(options.dict(exclude_unset=True, exclude_none=True))
+ return args
+
+ def compile_chat_messages(
+ self,
+ options: CompletionOptions,
+ msgs: List[ChatMessage],
+ functions: Optional[List[Any]] = None,
+ ) -> List[Dict]:
+ return compile_chat_messages(
+ model_name=options.model,
+ msgs=msgs,
+ context_length=self.context_length,
+ max_tokens=options.max_tokens,
+ functions=functions,
+ system_message=self.system_message,
+ )
+
+ def template_prompt_like_messages(self, prompt: str) -> str:
+ if self.template_messages is None:
+ return prompt
+
+ msgs = [{"role": "user", "content": prompt}]
+ if self.system_message is not None:
+ msgs.insert(0, {"role": "system", "content": self.system_message})
+
+ return self.template_messages(msgs)
+
+ async def stream_complete(
+ self,
+ prompt: str,
+ raw: bool = False,
+ model: str = None,
+ temperature: float = None,
+ top_p: float = None,
+ top_k: int = None,
+ presence_penalty: float = None,
+ frequency_penalty: float = None,
+ stop: Optional[List[str]] = None,
+ max_tokens: Optional[int] = None,
+ functions: Optional[List[Any]] = None,
+ log: bool = True,
+ ) -> Generator[Union[Any, List, Dict], None, None]:
+ """Yield completion response, either streamed or not."""
+ options = CompletionOptions(
+ model=model or self.model,
+ temperature=temperature or self.temperature,
+ top_p=top_p or self.top_p,
+ top_k=top_k or self.top_k,
+ presence_penalty=presence_penalty or self.presence_penalty,
+ frequency_penalty=frequency_penalty or self.frequency_penalty,
+ stop=stop or self.stop_tokens,
+ max_tokens=max_tokens,
+ functions=functions,
+ )
+
+ prompt = prune_raw_prompt_from_top(
+ self.model, self.context_length, prompt, options.max_tokens
+ )
+
+ if not raw:
+ prompt = self.template_prompt_like_messages(prompt)
+
+ if log:
+ self.write_log(prompt)
+
+ completion = ""
+ async for chunk in self._stream_complete(prompt=prompt, options=options):
+ yield chunk
+ completion += chunk
+
+ # if log:
+ # self.write_log(f"Completion: \n\n{completion}")
+
+ dev_data_logger.capture(
+ "tokens_generated",
+ {"model": self.model, "tokens": self.count_tokens(completion)},
+ )
+ posthog_logger.capture_event(
+ "tokens_generated",
+ {"model": self.model, "tokens": self.count_tokens(completion)},
+ )
+
+ async def complete(
+ self,
+ prompt: str,
+ raw: bool = False,
+ model: str = None,
+ temperature: float = None,
+ top_p: float = None,
+ top_k: int = None,
+ presence_penalty: float = None,
+ frequency_penalty: float = None,
+ stop: Optional[List[str]] = None,
+ max_tokens: Optional[int] = None,
+ functions: Optional[List[Any]] = None,
+ log: bool = True,
+ ) -> str:
+ """Yield completion response, either streamed or not."""
+ options = CompletionOptions(
+ model=model or self.model,
+ temperature=temperature or self.temperature,
+ top_p=top_p or self.top_p,
+ top_k=top_k or self.top_k,
+ presence_penalty=presence_penalty or self.presence_penalty,
+ frequency_penalty=frequency_penalty or self.frequency_penalty,
+ stop=stop or self.stop_tokens,
+ max_tokens=max_tokens,
+ functions=functions,
+ )
+
+ prompt = prune_raw_prompt_from_top(
+ self.model, self.context_length, prompt, options.max_tokens
+ )
+
+ if not raw:
+ prompt = self.template_prompt_like_messages(prompt)
+
+ if log:
+ self.write_log(prompt)
+
+ completion = await self._complete(prompt=prompt, options=options)
+
+ # if log:
+ # self.write_log(f"Completion: \n\n{completion}")
+
+ dev_data_logger.capture(
+ "tokens_generated",
+ {"model": self.model, "tokens": self.count_tokens(completion)},
+ )
+ posthog_logger.capture_event(
+ "tokens_generated",
+ {"model": self.model, "tokens": self.count_tokens(completion)},
+ )
+
+ return completion
+
+ async def stream_chat(
+ self,
+ messages: List[ChatMessage],
+ model: str = None,
+ temperature: float = None,
+ top_p: float = None,
+ top_k: int = None,
+ presence_penalty: float = None,
+ frequency_penalty: float = None,
+ stop: Optional[List[str]] = None,
+ max_tokens: Optional[int] = None,
+ functions: Optional[List[Any]] = None,
+ log: bool = True,
+ ) -> Generator[Union[Any, List, Dict], None, None]:
+ """Yield completion response, either streamed or not."""
+ options = CompletionOptions(
+ model=model or self.model,
+ temperature=temperature or self.temperature,
+ top_p=top_p or self.top_p,
+ top_k=top_k or self.top_k,
+ presence_penalty=presence_penalty or self.presence_penalty,
+ frequency_penalty=frequency_penalty or self.frequency_penalty,
+ stop=stop or self.stop_tokens,
+ max_tokens=max_tokens,
+ functions=functions,
+ )
+
+ messages = self.compile_chat_messages(
+ options=options, msgs=messages, functions=functions
+ )
+ if self.template_messages is not None:
+ prompt = self.template_messages(messages)
+ else:
+ prompt = format_chat_messages(messages)
+
+ if log:
+ self.write_log(prompt)
+
+ completion = ""
+
+ # Use the template_messages function if it exists and do a raw completion
+ if self.template_messages is None:
+ async for chunk in self._stream_chat(messages=messages, options=options):
+ yield chunk
+ if "content" in chunk:
+ completion += chunk["content"]
+ else:
+ async for chunk in self._stream_complete(prompt=prompt, options=options):
+ yield {"role": "assistant", "content": chunk}
+ completion += chunk
+
+ # if log:
+ # self.write_log(f"Completion: \n\n{completion}")
+
+ dev_data_logger.capture(
+ "tokens_generated",
+ {"model": self.model, "tokens": self.count_tokens(completion)},
+ )
+ posthog_logger.capture_event(
+ "tokens_generated",
+ {"model": self.model, "tokens": self.count_tokens(completion)},
+ )
+
+ def _stream_complete(
+ self, prompt, options: CompletionOptions
+ ) -> Generator[str, None, None]:
+ """Stream the completion through generator."""
+ raise NotImplementedError
+
+ async def _complete(
+ self, prompt: str, options: CompletionOptions
+ ) -> Coroutine[Any, Any, str]:
+ """Return the completion of the text with the given temperature."""
+ completion = ""
+ async for chunk in self._stream_complete(prompt=prompt, options=options):
+ completion += chunk
+ return completion
+
+ async def _stream_chat(
+ self, messages: List[ChatMessage], options: CompletionOptions
+ ) -> Generator[Union[Any, List, Dict], None, None]:
+ """Stream the chat through generator."""
+ if self.template_messages is None:
+ raise NotImplementedError(
+ "You must either implement template_messages or _stream_chat"
+ )
+
+ async for chunk in self._stream_complete(
+ prompt=self.template_messages(messages), options=options
+ ):
+ yield {"role": "assistant", "content": chunk}
+
+ def count_tokens(self, text: str):
+ """Return the number of tokens in the given text."""
+ return count_tokens(self.model, text)