-rw-r--r--  .github/ISSUE_TEMPLATE/bug-report-🐛.md | 36
-rw-r--r--  .github/ISSUE_TEMPLATE/feature-request-💪.md | 6
-rw-r--r--  continuedev/src/continuedev/core/abstract_sdk.py | 4
-rw-r--r--  continuedev/src/continuedev/core/autopilot.py | 2
-rw-r--r--  continuedev/src/continuedev/core/sdk.py | 13
-rw-r--r--  continuedev/src/continuedev/libs/llm/__init__.py | 55
-rw-r--r--  continuedev/src/continuedev/libs/llm/anthropic.py | 55
-rw-r--r--  continuedev/src/continuedev/libs/llm/ggml.py | 55
-rw-r--r--  continuedev/src/continuedev/libs/llm/hf_inference_api.py | 29
-rw-r--r--  continuedev/src/continuedev/libs/llm/hf_tgi.py | 53
-rw-r--r--  continuedev/src/continuedev/libs/llm/llamacpp.py | 50
-rw-r--r--  continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py | 39
-rw-r--r--  continuedev/src/continuedev/libs/llm/ollama.py | 39
-rw-r--r--  continuedev/src/continuedev/libs/llm/openai.py | 93
-rw-r--r--  continuedev/src/continuedev/libs/llm/proxy_server.py | 52
-rw-r--r--  continuedev/src/continuedev/libs/llm/replicate.py | 31
-rw-r--r--  continuedev/src/continuedev/libs/llm/text_gen_interface.py | 52
-rw-r--r--  continuedev/src/continuedev/libs/llm/together.py | 39
-rw-r--r--  continuedev/src/continuedev/libs/util/edit_config.py | 3
-rw-r--r--  continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py | 8
-rw-r--r--  continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py | 2
-rw-r--r--  continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py | 2
-rw-r--r--  continuedev/src/continuedev/plugins/steps/chat.py | 14
-rw-r--r--  continuedev/src/continuedev/plugins/steps/chroma.py | 2
-rw-r--r--  continuedev/src/continuedev/plugins/steps/core/core.py | 9
-rw-r--r--  continuedev/src/continuedev/plugins/steps/help.py | 2
-rw-r--r--  continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py | 2
-rw-r--r--  continuedev/src/continuedev/plugins/steps/main.py | 4
-rw-r--r--  continuedev/src/continuedev/plugins/steps/react.py | 2
-rw-r--r--  continuedev/src/continuedev/plugins/steps/search_directory.py | 2
30 files changed, 242 insertions(+), 513 deletions(-)
diff --git a/.github/ISSUE_TEMPLATE/bug-report-🐛.md b/.github/ISSUE_TEMPLATE/bug-report-🐛.md
index ab37cfbe..1070d7e2 100644
--- a/.github/ISSUE_TEMPLATE/bug-report-🐛.md
+++ b/.github/ISSUE_TEMPLATE/bug-report-🐛.md
@@ -1,10 +1,9 @@
---
name: "Bug report \U0001F41B"
about: Create a report to help us fix your bug
-title: ''
+title: ""
labels: bug
-assignees: ''
-
+assignees: ""
---
**Describe the bug**
@@ -12,32 +11,45 @@ A clear and concise description of what the bug is.
**To Reproduce**
Steps to reproduce the behavior:
+
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error
-**Expected behavior**
-A clear and concise description of what you expected to happen.
-
-**Screenshots**
-If applicable, add screenshots to help explain your problem.
-
**Environment**
+
- Operating System: [e.g. MacOS]
- Python Version: [e.g. 3.10.6]
- Continue Version: [e.g. v0.0.207]
-**Console logs**
+**Logs**
+
```
REPLACE THIS SECTION WITH CONSOLE LOGS OR A SCREENSHOT...
+```
+
+To get the Continue server logs:
+
+1. cmd+shift+p (MacOS) / ctrl+shift+p (Windows)
+2. Search for and then select "Continue: View Continue Server Logs"
+3. Scroll to the bottom of `continue.log` and copy the last 100 lines or so
+
+To get the VS Code console logs:
-To get the console logs in VS Code:
1. cmd+shift+p (MacOS) / ctrl+shift+p (Windows)
2. Search for and then select "Developer: Toggle Developer Tools"
3. Select Console
4. Read the console logs
-```
+
+If the problem is related to LLM prompting:
+
+1. Hover the problematic response in the Continue UI
+2. Click the "magnifying glass" icon
+3. Copy the contents of the `continue_logs.txt` file that opens
+
+**Screenshots**
+If applicable, add screenshots to help explain your problem.
**Additional context**
Add any other context about the problem here.
diff --git a/.github/ISSUE_TEMPLATE/feature-request-💪.md b/.github/ISSUE_TEMPLATE/feature-request-💪.md
index 2b138a9a..b356d488 100644
--- a/.github/ISSUE_TEMPLATE/feature-request-💪.md
+++ b/.github/ISSUE_TEMPLATE/feature-request-💪.md
@@ -1,10 +1,9 @@
---
name: "Feature request \U0001F4AA"
about: Suggest an idea for this project
-title: ''
+title: ""
labels: enhancement
assignees: TyDunn
-
---
**Is your feature request related to a problem? Please describe.**
@@ -13,8 +12,5 @@ A clear and concise description of what the problem is. Ex. I'm always frustrate
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
-**Describe alternatives you've considered**
-A clear and concise description of any alternative solutions or features you've considered.
-
**Additional context**
Add any other context or screenshots about the feature request here.
diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py
index 98730d38..fdb99d47 100644
--- a/continuedev/src/continuedev/core/abstract_sdk.py
+++ b/continuedev/src/continuedev/core/abstract_sdk.py
@@ -71,10 +71,6 @@ class AbstractContinueSDK(ABC):
async def delete_directory(self, path: str):
pass
- @abstractmethod
- async def get_user_secret(self, env_var: str) -> str:
- pass
-
config: ContinueConfig
@abstractmethod
diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py
index bae82739..de0b8c53 100644
--- a/continuedev/src/continuedev/core/autopilot.py
+++ b/continuedev/src/continuedev/core/autopilot.py
@@ -507,7 +507,7 @@ class Autopilot(ContinueBaseModel):
if self.session_info is None:
async def create_title():
- title = await self.continue_sdk.models.medium.complete(
+ title = await self.continue_sdk.models.medium._complete(
f'Give a short title to describe the current chat session. Do not put quotes around the title. The first message was: "{user_input}". Do not use more than 10 words. The title is: ',
max_tokens=20,
)
diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py
index 37992b67..9b1c2cd0 100644
--- a/continuedev/src/continuedev/core/sdk.py
+++ b/continuedev/src/continuedev/core/sdk.py
@@ -94,14 +94,7 @@ class ContinueSDK(AbstractContinueSDK):
self.history.timeline[self.history.current_index].logs.append(message)
async def start_model(self, llm: LLM):
- kwargs = {}
- if llm.requires_api_key:
- kwargs["api_key"] = await self.get_user_secret(llm.requires_api_key)
- if llm.requires_unique_id:
- kwargs["unique_id"] = self.ide.unique_id
- if llm.requires_write_log:
- kwargs["write_log"] = self.write_log
- await llm.start(**kwargs)
+ await llm.start(unique_id=self.ide.unique_id, write_log=self.write_log)
async def _ensure_absolute_path(self, path: str) -> str:
if os.path.isabs(path):
@@ -211,10 +204,6 @@ class ContinueSDK(AbstractContinueSDK):
path = await self._ensure_absolute_path(path)
return await self.run_step(FileSystemEditStep(edit=DeleteDirectory(path=path)))
- async def get_user_secret(self, env_var: str) -> str:
- # TODO support error prompt dynamically set on env_var
- return await self.ide.getUserSecret(env_var)
-
_last_valid_config: ContinueConfig = None
def _load_config_dot_py(self) -> ContinueConfig:
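The hunk above removes the `get_user_secret` indirection: API keys are now plain fields on each model, and `start_model` only injects IDE context. A minimal sketch of the resulting flow, assuming an already-constructed `sdk` and the `OpenAI` class changed later in this diff (the key value is a placeholder):

```python
# Keys live on the model itself; start_model only passes IDE context through.
llm = OpenAI(api_key="sk-placeholder", model="gpt-4")
await sdk.start_model(llm)
# equivalent to: await llm.start(unique_id=sdk.ide.unique_id, write_log=sdk.write_log)
```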
diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py
index 1e77a691..6a321a41 100644
--- a/continuedev/src/continuedev/libs/llm/__init__.py
+++ b/continuedev/src/continuedev/libs/llm/__init__.py
@@ -1,19 +1,32 @@
-from abc import ABC, abstractproperty
-from typing import Any, Coroutine, Dict, Generator, List, Optional, Union
+from abc import ABC
+from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union
from ...core.main import ChatMessage
from ...models.main import ContinueBaseModel
+from ..util.count_tokens import DEFAULT_ARGS, count_tokens
class LLM(ContinueBaseModel, ABC):
- requires_api_key: Optional[str] = None
- requires_unique_id: bool = False
- requires_write_log: bool = False
title: Optional[str] = None
system_message: Optional[str] = None
+ context_length: int = 2048
+ "The maximum context length of the LLM in tokens, as counted by count_tokens."
+
+ unique_id: Optional[str] = None
+ "The unique ID of the user."
+
+ model: str
+ "The model name"
+
prompt_templates: dict = {}
+ write_log: Optional[Callable[[str], None]] = None
+ "A function that takes a string and writes it to the log."
+
+ api_key: Optional[str] = None
+ "The API key for the LLM provider."
+
class Config:
arbitrary_types_allowed = True
extra = "allow"
@@ -21,36 +34,39 @@ class LLM(ContinueBaseModel, ABC):
def dict(self, **kwargs):
original_dict = super().dict(**kwargs)
original_dict.pop("write_log", None)
- original_dict["name"] = self.name
original_dict["class_name"] = self.__class__.__name__
return original_dict
- @abstractproperty
- def name(self):
- """Return the name of the LLM."""
- raise NotImplementedError
+ def collect_args(self, **kwargs) -> Any:
+ """Collect the arguments for the LLM."""
+ args = {**DEFAULT_ARGS.copy(), "model": self.model, "max_tokens": 1024}
+ args.update(kwargs)
+ return args
- async def start(self, *, api_key: Optional[str] = None, **kwargs):
+ async def start(
+ self, write_log: Callable[[str], None] = None, unique_id: Optional[str] = None
+ ):
"""Start the connection to the LLM."""
- raise NotImplementedError
+ self.write_log = write_log
+ self.unique_id = unique_id
async def stop(self):
"""Stop the connection to the LLM."""
- raise NotImplementedError
+ pass
- async def complete(
+ async def _complete(
self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
) -> Coroutine[Any, Any, str]:
"""Return the completion of the text with the given temperature."""
raise NotImplementedError
- def stream_complete(
+ def _stream_complete(
self, prompt, with_history: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
"""Stream the completion through generator."""
raise NotImplementedError
- async def stream_chat(
+ async def _stream_chat(
self, messages: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
"""Stream the chat through generator."""
@@ -58,9 +74,4 @@ class LLM(ContinueBaseModel, ABC):
def count_tokens(self, text: str):
"""Return the number of tokens in the given text."""
- raise NotImplementedError
-
- @abstractproperty
- def context_length(self) -> int:
- """Return the context length of the LLM in tokens, as counted by count_tokens."""
- raise NotImplementedError
+ return count_tokens(self.model, text)
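With `name`, `context_length`, `default_args`, `count_tokens`, `start`, and `stop` now handled by the base class, a provider only has to declare its `model` and implement the private completion hooks. A hypothetical minimal subclass, for illustration only (the real providers follow below):

```python
from typing import List

from ...core.main import ChatMessage
from ..llm import LLM


class EchoLLM(LLM):
    """Toy provider: echoes the prompt back instead of calling an API."""

    model: str = "echo"

    async def _complete(
        self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
    ) -> str:
        args = self.collect_args(**kwargs)  # DEFAULT_ARGS + model + max_tokens=1024
        self.write_log(f"Prompt: \n\n{prompt}")
        return prompt  # a real provider would call its API with `args` here
```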
diff --git a/continuedev/src/continuedev/libs/llm/anthropic.py b/continuedev/src/continuedev/libs/llm/anthropic.py
index 16bc2fce..b5aff63a 100644
--- a/continuedev/src/continuedev/libs/llm/anthropic.py
+++ b/continuedev/src/continuedev/libs/llm/anthropic.py
@@ -1,47 +1,36 @@
-from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union
+from typing import Any, Coroutine, Dict, Generator, List, Union
from anthropic import AI_PROMPT, HUMAN_PROMPT, AsyncAnthropic
from ...core.main import ChatMessage
from ..llm import LLM
-from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
+from ..util.count_tokens import compile_chat_messages
class AnthropicLLM(LLM):
api_key: str
+ "Anthropic API key"
+
model: str = "claude-2"
- requires_write_log = True
_async_client: AsyncAnthropic = None
class Config:
arbitrary_types_allowed = True
- write_log: Optional[Callable[[str], None]] = None
-
async def start(
self,
- *,
- api_key: Optional[str] = None,
- write_log: Callable[[str], None],
**kwargs,
):
- self.write_log = write_log
+ await super().start(**kwargs)
self._async_client = AsyncAnthropic(api_key=self.api_key)
- async def stop(self):
- pass
-
- @property
- def name(self):
- return self.model
+ if self.model == "claude-2":
+ self.context_length = 100_000
- @property
- def default_args(self):
- return {**DEFAULT_ARGS, "model": self.model}
+ def collect_args(self, **kwargs) -> Any:
+ args = super().collect_args(**kwargs)
- def _transform_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
- args = args.copy()
if "max_tokens" in args:
args["max_tokens_to_sample"] = args["max_tokens"]
del args["max_tokens"]
@@ -51,15 +40,6 @@ class AnthropicLLM(LLM):
del args["presence_penalty"]
return args
- def count_tokens(self, text: str):
- return count_tokens(self.model, text)
-
- @property
- def context_length(self):
- if self.model == "claude-2":
- return 100000
- raise Exception(f"Unknown Anthropic model {self.model}")
-
def __messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
prompt = ""
@@ -76,13 +56,11 @@ class AnthropicLLM(LLM):
prompt += AI_PROMPT
return prompt
- async def stream_complete(
+ async def _stream_complete(
self, prompt, with_history: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.default_args.copy()
- args.update(kwargs)
+ args = self.collect_args(**kwargs)
args["stream"] = True
- args = self._transform_args(args)
prompt = f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}"
self.write_log(f"Prompt: \n\n{prompt}")
@@ -95,13 +73,11 @@ class AnthropicLLM(LLM):
self.write_log(f"Completion: \n\n{completion}")
- async def stream_chat(
+ async def _stream_chat(
self, messages: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.default_args.copy()
- args.update(kwargs)
+ args = self.collect_args(**kwargs)
args["stream"] = True
- args = self._transform_args(args)
messages = compile_chat_messages(
args["model"],
@@ -123,11 +99,10 @@ class AnthropicLLM(LLM):
self.write_log(f"Completion: \n\n{completion}")
- async def complete(
+ async def _complete(
self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
) -> Coroutine[Any, Any, str]:
- args = {**self.default_args, **kwargs}
- args = self._transform_args(args)
+ args = self.collect_args(**kwargs)
messages = compile_chat_messages(
args["model"],
diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py
index db3aaed7..1668fb65 100644
--- a/continuedev/src/continuedev/libs/llm/ggml.py
+++ b/continuedev/src/continuedev/libs/llm/ggml.py
@@ -1,62 +1,29 @@
import json
-from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union
+from typing import Any, Coroutine, Dict, Generator, List, Optional, Union
import aiohttp
from ...core.main import ChatMessage
from ..llm import LLM
-from ..util.count_tokens import (
- DEFAULT_ARGS,
- compile_chat_messages,
- count_tokens,
- format_chat_messages,
-)
+from ..util.count_tokens import compile_chat_messages, format_chat_messages
class GGML(LLM):
- # this is model-specific
- max_context_length: int = 2048
server_url: str = "http://localhost:8000"
verify_ssl: Optional[bool] = None
-
- requires_write_log = True
-
- write_log: Optional[Callable[[str], None]] = None
+ model: str = "ggml"
class Config:
arbitrary_types_allowed = True
- async def start(self, write_log: Callable[[str], None], **kwargs):
- self.write_log = write_log
-
- async def stop(self):
- pass
-
- @property
- def name(self):
- return "ggml"
-
- @property
- def context_length(self):
- return self.max_context_length
-
- @property
- def default_args(self):
- return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
-
- def count_tokens(self, text: str):
- return count_tokens(self.name, text)
-
- async def stream_complete(
+ async def _stream_complete(
self, prompt, with_history: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.default_args.copy()
- args.update(kwargs)
+ args = self.collect_args(**kwargs)
args["stream"] = True
- args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.name,
+ self.model,
with_history,
self.context_length,
args["max_tokens"],
@@ -84,12 +51,12 @@ class GGML(LLM):
self.write_log(f"Completion: \n\n{completion}")
- async def stream_chat(
+ async def _stream_chat(
self, messages: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
- args = {**self.default_args, **kwargs}
+ args = self.collect_args(**kwargs)
messages = compile_chat_messages(
- self.name,
+ self.model,
messages,
self.context_length,
args["max_tokens"],
@@ -142,10 +109,10 @@ class GGML(LLM):
self.write_log(f"Completion: \n\n{completion}")
- async def complete(
+ async def _complete(
self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
) -> Coroutine[Any, Any, str]:
- args = {**self.default_args, **kwargs}
+ args = self.collect_args(**kwargs)
self.write_log(f"Prompt: \n\n{prompt}")
async with aiohttp.ClientSession(
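GGML likewise drops its private `max_context_length`; `context_length` is an ordinary field inherited from `LLM`, so a larger window can be set at construction time (sketch with the default local server URL):

```python
ggml = GGML(server_url="http://localhost:8000", context_length=4096)
```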
diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
index 4b78a247..3a586a43 100644
--- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py
+++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py
@@ -5,17 +5,14 @@ import requests
from ...core.main import ChatMessage
from ..llm import LLM
-from ..util.count_tokens import DEFAULT_ARGS, count_tokens
DEFAULT_MAX_TIME = 120.0
class HuggingFaceInferenceAPI(LLM):
- model: str
hf_token: str
self_hosted_url: str = None
- max_context_length: int = 2048
verify_ssl: Optional[bool] = None
_client_session: aiohttp.ClientSession = None
@@ -24,6 +21,7 @@ class HuggingFaceInferenceAPI(LLM):
arbitrary_types_allowed = True
async def start(self, **kwargs):
+ await super().start(**kwargs)
self._client_session = aiohttp.ClientSession(
connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
)
@@ -31,22 +29,7 @@ class HuggingFaceInferenceAPI(LLM):
async def stop(self):
await self._client_session.close()
- @property
- def name(self):
- return self.model
-
- @property
- def context_length(self):
- return self.max_context_length
-
- @property
- def default_args(self):
- return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
-
- def count_tokens(self, text: str):
- return count_tokens(self.name, text)
-
- async def complete(
+ async def _complete(
self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
):
"""Return the completion of the text with the given temperature."""
@@ -77,14 +60,14 @@ class HuggingFaceInferenceAPI(LLM):
return data[0]["generated_text"]
- async def stream_chat(
+ async def _stream_chat(
self, messages: List[ChatMessage] = None, **kwargs
) -> Coroutine[Any, Any, Generator[Any | List | Dict, None, None]]:
- response = await self.complete(messages[-1].content, messages[:-1])
+ response = await self._complete(messages[-1].content, messages[:-1])
yield {"content": response, "role": "assistant"}
- async def stream_complete(
+ async def _stream_complete(
self, prompt, with_history: List[ChatMessage] = None, **kwargs
) -> Generator[Any | List | Dict, None, None]:
- response = await self.complete(prompt, with_history)
+ response = await self._complete(prompt, with_history)
yield response
diff --git a/continuedev/src/continuedev/libs/llm/hf_tgi.py b/continuedev/src/continuedev/libs/llm/hf_tgi.py
index f04e700d..f106f83f 100644
--- a/continuedev/src/continuedev/libs/llm/hf_tgi.py
+++ b/continuedev/src/continuedev/libs/llm/hf_tgi.py
@@ -5,44 +5,22 @@ import aiohttp
from ...core.main import ChatMessage
from ..llm import LLM
-from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
+from ..util.count_tokens import compile_chat_messages
from .prompts.chat import code_llama_template_messages
class HuggingFaceTGI(LLM):
model: str = "huggingface-tgi"
- max_context_length: int = 2048
server_url: str = "http://localhost:8080"
verify_ssl: Optional[bool] = None
template_messages: Callable[[List[ChatMessage]], str] = code_llama_template_messages
- requires_write_log = True
-
- write_log: Optional[Callable[[str], None]] = None
-
class Config:
arbitrary_types_allowed = True
- async def start(self, write_log: Callable[[str], None], **kwargs):
- self.write_log = write_log
-
- async def stop(self):
- pass
-
- @property
- def name(self):
- return self.model
-
- @property
- def context_length(self):
- return self.max_context_length
-
- @property
- def default_args(self):
- return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
-
- def _transform_args(self, args):
+ def collect_args(self, **kwargs) -> Any:
+ args = super().collect_args(**kwargs)
args = {
**args,
"max_new_tokens": args.get("max_tokens", 1024),
@@ -50,19 +28,14 @@ class HuggingFaceTGI(LLM):
args.pop("max_tokens", None)
return args
- def count_tokens(self, text: str):
- return count_tokens(self.name, text)
-
- async def stream_complete(
+ async def _stream_complete(
self, prompt, with_history: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.default_args.copy()
- args.update(kwargs)
+ args = self.collect_args(**kwargs)
args["stream"] = True
- args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.name,
+ self.model,
with_history,
self.context_length,
args["max_tokens"],
@@ -93,12 +66,12 @@ class HuggingFaceTGI(LLM):
self.write_log(f"Completion: \n\n{completion}")
- async def stream_chat(
+ async def _stream_chat(
self, messages: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
- args = {**self.default_args, **kwargs}
+ args = self.collect_args(**kwargs)
messages = compile_chat_messages(
- self.name,
+ self.model,
messages,
self.context_length,
args["max_tokens"],
@@ -107,7 +80,7 @@ class HuggingFaceTGI(LLM):
system_message=self.system_message,
)
- async for chunk in self.stream_complete(
+ async for chunk in self._stream_complete(
None, self.template_messages(messages), **args
):
yield {
@@ -115,13 +88,13 @@ class HuggingFaceTGI(LLM):
"content": chunk,
}
- async def complete(
+ async def _complete(
self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
) -> Coroutine[Any, Any, str]:
- args = {**self.default_args, **kwargs}
+ args = self.collect_args(**kwargs)
completion = ""
- async for chunk in self.stream_complete(prompt, with_history, **args):
+ async for chunk in self._stream_complete(prompt, with_history, **args):
completion += chunk
return completion
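As with the other providers, TGI's argument handling moves into `collect_args`, which maps the generic `max_tokens` onto the server's `max_new_tokens`. A small illustration with placeholder values:

```python
tgi = HuggingFaceTGI(server_url="http://localhost:8080")
args = tgi.collect_args(max_tokens=256)
assert args["max_new_tokens"] == 256 and "max_tokens" not in args
```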
diff --git a/continuedev/src/continuedev/libs/llm/llamacpp.py b/continuedev/src/continuedev/libs/llm/llamacpp.py
index e6f38cd0..7940c4c9 100644
--- a/continuedev/src/continuedev/libs/llm/llamacpp.py
+++ b/continuedev/src/continuedev/libs/llm/llamacpp.py
@@ -6,12 +6,12 @@ import aiohttp
from ...core.main import ChatMessage
from ..llm import LLM
-from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
+from ..util.count_tokens import compile_chat_messages
from .prompts.chat import llama2_template_messages
class LlamaCpp(LLM):
- max_context_length: int = 2048
+ model: str = "llamacpp"
server_url: str = "http://localhost:8080"
verify_ssl: Optional[bool] = None
@@ -20,9 +20,6 @@ class LlamaCpp(LLM):
use_command: Optional[str] = None
- requires_write_log = True
- write_log: Optional[Callable[[str], None]] = None
-
class Config:
arbitrary_types_allowed = True
@@ -31,29 +28,8 @@ class LlamaCpp(LLM):
d.pop("template_messages")
return d
- async def start(self, write_log: Callable[[str], None], **kwargs):
- self.write_log = write_log
-
- async def stop(self):
- await self._client_session.close()
-
- @property
- def name(self):
- return "llamacpp"
-
- @property
- def context_length(self):
- return self.max_context_length
-
- @property
- def default_args(self):
- return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
-
- def count_tokens(self, text: str):
- return count_tokens(self.name, text)
-
- def _transform_args(self, args: Dict[str, Any]) -> Dict[str, Any]:
- args = args.copy()
+ def collect_args(self, **kwargs) -> Any:
+ args = super().collect_args(**kwargs)
if "max_tokens" in args:
args["n_predict"] = args["max_tokens"]
del args["max_tokens"]
@@ -85,16 +61,14 @@ class LlamaCpp(LLM):
await process.wait()
- async def stream_complete(
+ async def _stream_complete(
self, prompt, with_history: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.default_args.copy()
- args.update(kwargs)
+ args = self.collect_args(**kwargs)
args["stream"] = True
- args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.name,
+ self.model,
with_history,
self.context_length,
args["max_tokens"],
@@ -125,12 +99,12 @@ class LlamaCpp(LLM):
self.write_log(f"Completion: \n\n{completion}")
- async def stream_chat(
+ async def _stream_chat(
self, messages: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
- args = {**self.default_args, **kwargs}
+ args = self.collect_args(**kwargs)
messages = compile_chat_messages(
- self.name,
+ self.model,
messages,
self.context_length,
args["max_tokens"],
@@ -177,10 +151,10 @@ class LlamaCpp(LLM):
self.write_log(f"Completion: \n\n{completion}")
- async def complete(
+ async def _complete(
self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
) -> Coroutine[Any, Any, str]:
- args = {**self.default_args, **kwargs}
+ args = self.collect_args(**kwargs)
self.write_log(f"Prompt: \n\n{prompt}")
diff --git a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
index daffe41f..99b7c47f 100644
--- a/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
+++ b/continuedev/src/continuedev/libs/llm/maybe_proxy_openai.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union
+from typing import Any, Coroutine, Dict, Generator, List, Optional, Union
from ...core.main import ChatMessage
from . import LLM
@@ -10,63 +10,42 @@ class MaybeProxyOpenAI(LLM):
model: str
api_key: Optional[str] = None
- requires_write_log: bool = True
- requires_unique_id: bool = True
- system_message: Union[str, None] = None
-
llm: Optional[LLM] = None
def update_llm_properties(self):
if self.llm is not None:
self.llm.system_message = self.system_message
- @property
- def name(self):
- if self.llm is not None:
- return self.llm.name
- else:
- return None
-
- @property
- def context_length(self):
- return self.llm.context_length
-
- async def start(
- self,
- *,
- api_key: Optional[str] = None,
- unique_id: str,
- write_log: Callable[[str], None]
- ):
+ async def start(self, **kwargs):
if self.api_key is None or self.api_key.strip() == "":
self.llm = ProxyServer(model=self.model)
else:
self.llm = OpenAI(api_key=self.api_key, model=self.model)
- await self.llm.start(write_log=write_log, unique_id=unique_id)
+ await self.llm.start(**kwargs)
async def stop(self):
await self.llm.stop()
- async def complete(
+ async def _complete(
self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
) -> Coroutine[Any, Any, str]:
self.update_llm_properties()
- return await self.llm.complete(prompt, with_history=with_history, **kwargs)
+ return await self.llm._complete(prompt, with_history=with_history, **kwargs)
- async def stream_complete(
+ async def _stream_complete(
self, prompt, with_history: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
self.update_llm_properties()
- resp = self.llm.stream_complete(prompt, with_history=with_history, **kwargs)
+ resp = self.llm._stream_complete(prompt, with_history=with_history, **kwargs)
async for item in resp:
yield item
- async def stream_chat(
+ async def _stream_chat(
self, messages: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
self.update_llm_properties()
- resp = self.llm.stream_chat(messages=messages, **kwargs)
+ resp = self.llm._stream_chat(messages=messages, **kwargs)
async for item in resp:
yield item
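`MaybeProxyOpenAI` keeps its fallback behavior (no key means the Continue proxy server, otherwise direct OpenAI) but now simply forwards the shared `start` kwargs and the renamed private methods. A sketch, assuming the proxy is reachable at runtime:

```python
wrapper = MaybeProxyOpenAI(model="gpt-4")               # api_key is None here
await wrapper.start(write_log=print, unique_id="dev")   # builds a ProxyServer internally
text = await wrapper._complete("Say hello")             # forwarded to self.llm._complete
```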
diff --git a/continuedev/src/continuedev/libs/llm/ollama.py b/continuedev/src/continuedev/libs/llm/ollama.py
index 03300435..ef8ed47b 100644
--- a/continuedev/src/continuedev/libs/llm/ollama.py
+++ b/continuedev/src/continuedev/libs/llm/ollama.py
@@ -7,17 +7,15 @@ import aiohttp
from ...core.main import ChatMessage
from ..llm import LLM
-from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
+from ..util.count_tokens import compile_chat_messages
from .prompts.chat import llama2_template_messages
class Ollama(LLM):
model: str = "llama2"
server_url: str = "http://localhost:11434"
- max_context_length: int = 2048
_client_session: aiohttp.ClientSession = None
- requires_write_log = True
prompt_templates = {
"edit": dedent(
@@ -36,34 +34,19 @@ class Ollama(LLM):
class Config:
arbitrary_types_allowed = True
- async def start(self, write_log, **kwargs):
+ async def start(self, **kwargs):
+ await super().start(**kwargs)
self._client_session = aiohttp.ClientSession()
- self.write_log = write_log
async def stop(self):
await self._client_session.close()
- @property
- def name(self):
- return self.model
-
- @property
- def context_length(self) -> int:
- return self.max_context_length
-
- @property
- def default_args(self):
- return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
-
- def count_tokens(self, text: str):
- return count_tokens(self.name, text)
-
- async def stream_complete(
+ async def _stream_complete(
self, prompt, with_history: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
- args = {**self.default_args, **kwargs}
+ args = self.collect_args(**kwargs)
messages = compile_chat_messages(
- self.name,
+ self.model,
with_history,
self.context_length,
args["max_tokens"],
@@ -102,12 +85,12 @@ class Ollama(LLM):
yield urllib.parse.unquote(url_decode_buffer)
url_decode_buffer = ""
- async def stream_chat(
+ async def _stream_chat(
self, messages: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
- args = {**self.default_args, **kwargs}
+ args = self.collect_args(**kwargs)
messages = compile_chat_messages(
- self.name,
+ self.model,
messages,
self.context_length,
args["max_tokens"],
@@ -143,11 +126,11 @@ class Ollama(LLM):
completion += j["response"]
self.write_log(f"Completion:\n{completion}")
- async def complete(
+ async def _complete(
self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
) -> Coroutine[Any, Any, str]:
completion = ""
- args = {**self.default_args, **kwargs}
+ args = self.collect_args(**kwargs)
async with self._client_session.post(
f"{self.server_url}/api/generate",
json={
diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py
index c5d19ed2..a017af22 100644
--- a/continuedev/src/continuedev/libs/llm/openai.py
+++ b/continuedev/src/continuedev/libs/llm/openai.py
@@ -12,26 +12,15 @@ from typing import (
import certifi
import openai
-from pydantic import BaseModel
from ...core.main import ChatMessage
from ..llm import LLM
from ..util.count_tokens import (
- DEFAULT_ARGS,
compile_chat_messages,
- count_tokens,
format_chat_messages,
prune_raw_prompt_from_top,
)
-
-class OpenAIServerInfo(BaseModel):
- api_base: Optional[str] = None
- engine: Optional[str] = None
- api_version: Optional[str] = None
- api_type: Literal["azure", "openai"] = "openai"
-
-
CHAT_MODELS = {"gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613"}
MAX_TOKENS_FOR_MODEL = {
"gpt-3.5-turbo": 4096,
@@ -47,32 +36,43 @@ MAX_TOKENS_FOR_MODEL = {
class OpenAI(LLM):
api_key: str
- model: str
- openai_server_info: Optional[OpenAIServerInfo] = None
+ "OpenAI API key"
+
verify_ssl: Optional[bool] = None
+ "Whether to verify SSL certificates for requests."
+
ca_bundle_path: Optional[str] = None
+ "Path to CA bundle to use for requests."
+
proxy: Optional[str] = None
+ "Proxy URL to use for requests."
- requires_write_log = True
+ api_base: Optional[str] = None
+ "OpenAI API base URL."
- write_log: Optional[Callable[[str], None]] = None
+ api_type: Optional[Literal["azure", "openai"]] = None
+ "OpenAI API type."
+
+ api_version: Optional[str] = None
+ "OpenAI API version. For use with Azure OpenAI Service."
+
+ engine: Optional[str] = None
+ "OpenAI engine. For use with Azure OpenAI Service."
async def start(
- self,
- *,
- api_key: Optional[str] = None,
- write_log: Callable[[str], None],
- **kwargs,
+ self, unique_id: Optional[str] = None, write_log: Callable[[str], None] = None
):
- self.write_log = write_log
- openai.api_key = self.api_key
+ await super().start(write_log=write_log, unique_id=unique_id)
+
+ self.context_length = MAX_TOKENS_FOR_MODEL.get(self.model, 4096)
- if self.openai_server_info is not None:
- openai.api_type = self.openai_server_info.api_type
- if self.openai_server_info.api_base is not None:
- openai.api_base = self.openai_server_info.api_base
- if self.openai_server_info.api_version is not None:
- openai.api_version = self.openai_server_info.api_version
+ openai.api_key = self.api_key
+ if self.api_type is not None:
+ openai.api_type = self.api_type
+ if self.api_base is not None:
+ openai.api_base = self.api_base
+ if self.api_version is not None:
+ openai.api_version = self.api_version
if self.verify_ssl is not None and self.verify_ssl is False:
openai.verify_ssl_certs = False
@@ -82,32 +82,16 @@ class OpenAI(LLM):
openai.ca_bundle_path = self.ca_bundle_path or certifi.where()
- async def stop(self):
- pass
-
- @property
- def name(self):
- return self.model
-
- @property
- def context_length(self):
- return MAX_TOKENS_FOR_MODEL.get(self.model, 4096)
-
- @property
- def default_args(self):
- args = {**DEFAULT_ARGS, "model": self.model}
- if self.openai_server_info is not None:
- args["engine"] = self.openai_server_info.engine
+ def collect_args(self, **kwargs):
+ args = super().collect_args()
+ if self.engine is not None:
+ args["engine"] = self.engine
return args
- def count_tokens(self, text: str):
- return count_tokens(self.model, text)
-
- async def stream_complete(
+ async def _stream_complete(
self, prompt, with_history: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.default_args.copy()
- args.update(kwargs)
+ args = self.collect_args(**kwargs)
args["stream"] = True
if args["model"] in CHAT_MODELS:
@@ -142,11 +126,10 @@ class OpenAI(LLM):
self.write_log(f"Completion:\n\n{completion}")
- async def stream_chat(
+ async def _stream_chat(
self, messages: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.default_args.copy()
- args.update(kwargs)
+ args = self.collect_args(**kwargs)
args["stream"] = True
if not args["model"].endswith("0613") and "functions" in args:
@@ -174,10 +157,10 @@ class OpenAI(LLM):
completion += chunk.choices[0].delta.content
self.write_log(f"Completion: \n\n{completion}")
- async def complete(
+ async def _complete(
self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
) -> Coroutine[Any, Any, str]:
- args = {**self.default_args, **kwargs}
+ args = self.collect_args(**kwargs)
if args["model"] in CHAT_MODELS:
messages = compile_chat_messages(
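`OpenAIServerInfo` is folded into the `OpenAI` class itself, so Azure OpenAI Service is configured with plain fields. A hedged sketch with placeholder endpoint, version, and deployment name:

```python
azure_gpt4 = OpenAI(
    api_key="<azure-key>",
    model="gpt-4",
    api_type="azure",
    api_base="https://my-resource.openai.azure.com",
    api_version="2023-07-01-preview",
    engine="my-gpt4-deployment",
)
await azure_gpt4.start(write_log=print)
# start() applies these to the openai module and sets context_length from MAX_TOKENS_FOR_MODEL
```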
diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py
index fa77a22a..3ac6371f 100644
--- a/continuedev/src/continuedev/libs/llm/proxy_server.py
+++ b/continuedev/src/continuedev/libs/llm/proxy_server.py
@@ -1,19 +1,14 @@
import json
import ssl
import traceback
-from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union
+from typing import Any, Coroutine, Dict, Generator, List, Union
import aiohttp
import certifi
from ...core.main import ChatMessage
from ..llm import LLM
-from ..util.count_tokens import (
- DEFAULT_ARGS,
- compile_chat_messages,
- count_tokens,
- format_chat_messages,
-)
+from ..util.count_tokens import compile_chat_messages, format_chat_messages
from ..util.telemetry import posthog_logger
ca_bundle_path = certifi.where()
@@ -31,59 +26,32 @@ MAX_TOKENS_FOR_MODEL = {
class ProxyServer(LLM):
- model: str
- system_message: Optional[str]
-
- unique_id: str = None
- write_log: Callable[[str], None] = None
_client_session: aiohttp.ClientSession
- requires_unique_id = True
- requires_write_log = True
-
class Config:
arbitrary_types_allowed = True
async def start(
self,
- *,
- api_key: Optional[str] = None,
- write_log: Callable[[str], None],
- unique_id: str,
**kwargs,
):
+ await super().start(**kwargs)
self._client_session = aiohttp.ClientSession(
connector=aiohttp.TCPConnector(ssl_context=ssl_context)
)
- self.write_log = write_log
- self.unique_id = unique_id
+ self.context_length = MAX_TOKENS_FOR_MODEL[self.model]
async def stop(self):
await self._client_session.close()
- @property
- def name(self):
- return self.model
-
- @property
- def context_length(self):
- return MAX_TOKENS_FOR_MODEL[self.model]
-
- @property
- def default_args(self):
- return {**DEFAULT_ARGS, "model": self.model}
-
- def count_tokens(self, text: str):
- return count_tokens(self.model, text)
-
def get_headers(self):
# headers with unique id
return {"unique_id": self.unique_id}
- async def complete(
+ async def _complete(
self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
) -> Coroutine[Any, Any, str]:
- args = {**self.default_args, **kwargs}
+ args = self.collect_args(**kwargs)
messages = compile_chat_messages(
args["model"],
@@ -107,10 +75,10 @@ class ProxyServer(LLM):
self.write_log(f"Completion: \n\n{response_text}")
return response_text
- async def stream_chat(
+ async def _stream_chat(
self, messages: List[ChatMessage] = None, **kwargs
) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]:
- args = {**self.default_args, **kwargs}
+ args = self.collect_args(**kwargs)
messages = compile_chat_messages(
args["model"],
messages,
@@ -158,10 +126,10 @@ class ProxyServer(LLM):
self.write_log(f"Completion: \n\n{completion}")
- async def stream_complete(
+ async def _stream_complete(
self, prompt, with_history: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
- args = {**self.default_args, **kwargs}
+ args = self.collect_args(**kwargs)
messages = compile_chat_messages(
self.model,
with_history,
diff --git a/continuedev/src/continuedev/libs/llm/replicate.py b/continuedev/src/continuedev/libs/llm/replicate.py
index 0424d827..fb0d3f5c 100644
--- a/continuedev/src/continuedev/libs/llm/replicate.py
+++ b/continuedev/src/continuedev/libs/llm/replicate.py
@@ -4,39 +4,22 @@ from typing import List
import replicate
from ...core.main import ChatMessage
-from ..util.count_tokens import DEFAULT_ARGS, count_tokens
from . import LLM
class ReplicateLLM(LLM):
api_key: str
+ "Replicate API key"
+
model: str = "replicate/llama-2-70b-chat:58d078176e02c219e11eb4da5a02a7830a283b14cf8f94537af893ccff5ee781"
- max_context_length: int = 2048
_client: replicate.Client = None
- @property
- def name(self):
- return self.model
-
- @property
- def context_length(self):
- return self.max_context_length
-
- @property
- def default_args(self):
- return {**DEFAULT_ARGS, "model": self.model, "max_tokens": 1024}
-
- def count_tokens(self, text: str):
- return count_tokens(self.name, text)
-
- async def start(self):
+ async def start(self, **kwargs):
+ await super().start(**kwargs)
self._client = replicate.Client(api_token=self.api_key)
- async def stop(self):
- pass
-
- async def complete(
+ async def _complete(
self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
):
def helper():
@@ -55,7 +38,7 @@ class ReplicateLLM(LLM):
return completion
- async def stream_complete(
+ async def _stream_complete(
self, prompt, with_history: List[ChatMessage] = None, **kwargs
):
for item in self._client.run(
@@ -63,7 +46,7 @@ class ReplicateLLM(LLM):
):
yield item
- async def stream_chat(self, messages: List[ChatMessage] = None, **kwargs):
+ async def _stream_chat(self, messages: List[ChatMessage] = None, **kwargs):
for item in self._client.run(
self.model,
input={"message": messages[-1].content, "prompt": messages[-1].content},
diff --git a/continuedev/src/continuedev/libs/llm/text_gen_interface.py b/continuedev/src/continuedev/libs/llm/text_gen_interface.py
index 380f7b48..59627629 100644
--- a/continuedev/src/continuedev/libs/llm/text_gen_interface.py
+++ b/continuedev/src/continuedev/libs/llm/text_gen_interface.py
@@ -1,51 +1,23 @@
import json
-from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union
+from typing import Any, Coroutine, Dict, Generator, List, Optional, Union
import websockets
from ...core.main import ChatMessage
-from ..util.count_tokens import (
- DEFAULT_ARGS,
- compile_chat_messages,
- count_tokens,
- format_chat_messages,
-)
+from ..util.count_tokens import compile_chat_messages, format_chat_messages
from . import LLM
class TextGenUI(LLM):
# this is model-specific
model: str = "text-gen-ui"
- max_context_length: int = 2048
server_url: str = "http://localhost:5000"
streaming_url: str = "http://localhost:5005"
verify_ssl: Optional[bool] = None
- requires_write_log = True
-
- write_log: Optional[Callable[[str], None]] = None
-
class Config:
arbitrary_types_allowed = True
- async def start(self, write_log: Callable[[str], None], **kwargs):
- self.write_log = write_log
-
- async def stop(self):
- pass
-
- @property
- def name(self):
- return self.model
-
- @property
- def context_length(self):
- return self.max_context_length
-
- @property
- def default_args(self):
- return {**DEFAULT_ARGS, "model": self.name, "max_tokens": 1024}
-
def _transform_args(self, args):
args = {
**args,
@@ -54,18 +26,12 @@ class TextGenUI(LLM):
args.pop("max_tokens", None)
return args
- def count_tokens(self, text: str):
- return count_tokens(self.name, text)
-
- async def stream_complete(
+ async def _stream_complete(
self, prompt, with_history: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.default_args.copy()
- args.update(kwargs)
+ args = self.collect_args(**kwargs)
args["stream"] = True
- args = {**self.default_args, **kwargs}
-
self.write_log(f"Prompt: \n\n{prompt}")
completion = ""
@@ -89,12 +55,12 @@ class TextGenUI(LLM):
self.write_log(f"Completion: \n\n{completion}")
- async def stream_chat(
+ async def _stream_chat(
self, messages: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
- args = {**self.default_args, **kwargs}
+ args = self.collect_args(**kwargs)
messages = compile_chat_messages(
- self.name,
+ self.model,
messages,
self.context_length,
args["max_tokens"],
@@ -146,10 +112,10 @@ class TextGenUI(LLM):
self.write_log(f"Completion: \n\n{completion}")
- async def complete(
+ async def _complete(
self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
) -> Coroutine[Any, Any, str]:
- generator = self.stream_chat(
+ generator = self._stream_chat(
[ChatMessage(role="user", content=prompt, summary=prompt)], **kwargs
)
diff --git a/continuedev/src/continuedev/libs/llm/together.py b/continuedev/src/continuedev/libs/llm/together.py
index ddae91a9..d8c7334b 100644
--- a/continuedev/src/continuedev/libs/llm/together.py
+++ b/continuedev/src/continuedev/libs/llm/together.py
@@ -5,21 +5,23 @@ import aiohttp
from ...core.main import ChatMessage
from ..llm import LLM
-from ..util.count_tokens import DEFAULT_ARGS, compile_chat_messages, count_tokens
+from ..util.count_tokens import compile_chat_messages
from .prompts.chat import llama2_template_messages
class TogetherLLM(LLM):
# this is model-specific
api_key: str
+ "Together API key"
+
model: str = "togethercomputer/RedPajama-INCITE-7B-Instruct"
- max_context_length: int = 2048
base_url: str = "https://api.together.xyz"
verify_ssl: Optional[bool] = None
_client_session: aiohttp.ClientSession = None
async def start(self, **kwargs):
+ await super().start(**kwargs)
self._client_session = aiohttp.ClientSession(
connector=aiohttp.TCPConnector(verify_ssl=self.verify_ssl)
)
@@ -27,31 +29,14 @@ class TogetherLLM(LLM):
async def stop(self):
await self._client_session.close()
- @property
- def name(self):
- return self.model
-
- @property
- def context_length(self):
- return self.max_context_length
-
- @property
- def default_args(self):
- return {**DEFAULT_ARGS, "model": self.model, "max_tokens": 1024}
-
- def count_tokens(self, text: str):
- return count_tokens(self.name, text)
-
- async def stream_complete(
+ async def _stream_complete(
self, prompt, with_history: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
- args = self.default_args.copy()
- args.update(kwargs)
+ args = self.collect_args(**kwargs)
args["stream_tokens"] = True
- args = {**self.default_args, **kwargs}
messages = compile_chat_messages(
- self.name,
+ self.model,
with_history,
self.context_length,
args["max_tokens"],
@@ -72,12 +57,12 @@ class TogetherLLM(LLM):
except:
raise Exception(str(line))
- async def stream_chat(
+ async def _stream_chat(
self, messages: List[ChatMessage] = None, **kwargs
) -> Generator[Union[Any, List, Dict], None, None]:
- args = {**self.default_args, **kwargs}
+ args = self.collect_args(**kwargs)
messages = compile_chat_messages(
- self.name,
+ self.model,
messages,
self.context_length,
args["max_tokens"],
@@ -112,10 +97,10 @@ class TogetherLLM(LLM):
"content": json_chunk["choices"][0]["text"],
}
- async def complete(
+ async def _complete(
self, prompt: str, with_history: List[ChatMessage] = None, **kwargs
) -> Coroutine[Any, Any, str]:
- args = {**self.default_args, **kwargs}
+ args = self.collect_args(**kwargs)
messages = compile_chat_messages(
args["model"],
diff --git a/continuedev/src/continuedev/libs/util/edit_config.py b/continuedev/src/continuedev/libs/util/edit_config.py
index eed43054..45a4a599 100644
--- a/continuedev/src/continuedev/libs/util/edit_config.py
+++ b/continuedev/src/continuedev/libs/util/edit_config.py
@@ -74,9 +74,6 @@ def add_config_import(line: str):
filtered_attrs = {
- "requires_api_key",
- "requires_unique_id",
- "requires_write_log",
"class_name",
"name",
"llm",
diff --git a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py
index 43a2b800..fe049268 100644
--- a/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py
+++ b/continuedev/src/continuedev/plugins/recipes/CreatePipelineRecipe/steps.py
@@ -30,7 +30,7 @@ class SetupPipelineStep(Step):
sdk.context.set("api_description", self.api_description)
source_name = (
- await sdk.models.medium.complete(
+ await sdk.models.medium._complete(
f"Write a snake_case name for the data source described by {self.api_description}: "
)
).strip()
@@ -115,7 +115,7 @@ class ValidatePipelineStep(Step):
if "Traceback" in output or "SyntaxError" in output:
output = "Traceback" + output.split("Traceback")[-1]
file_content = await sdk.ide.readFile(os.path.join(workspace_dir, filename))
- suggestion = await sdk.models.medium.complete(
+ suggestion = await sdk.models.medium._complete(
dedent(
f"""\
```python
@@ -131,7 +131,7 @@ class ValidatePipelineStep(Step):
)
)
- api_documentation_url = await sdk.models.medium.complete(
+ api_documentation_url = await sdk.models.medium._complete(
dedent(
f"""\
The API I am trying to call is the '{sdk.context.get('api_description')}'. I tried calling it in the @resource function like this:
@@ -216,7 +216,7 @@ class RunQueryStep(Step):
)
if "Traceback" in output or "SyntaxError" in output:
- suggestion = await sdk.models.medium.complete(
+ suggestion = await sdk.models.medium._complete(
dedent(
f"""\
```python
diff --git a/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py
index d6769148..44065d22 100644
--- a/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py
+++ b/continuedev/src/continuedev/plugins/recipes/DDtoBQRecipe/steps.py
@@ -92,7 +92,7 @@ class LoadDataStep(Step):
docs = f.read()
output = "Traceback" + output.split("Traceback")[-1]
- suggestion = await sdk.models.default.complete(
+ suggestion = await sdk.models.default._complete(
dedent(
f"""\
When trying to load data into BigQuery, the following error occurred:
diff --git a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py
index e2712746..4727c994 100644
--- a/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py
+++ b/continuedev/src/continuedev/plugins/recipes/WritePytestsRecipe/main.py
@@ -45,7 +45,7 @@ class WritePytestsRecipe(Step):
Here is a complete set of pytest unit tests:"""
)
- tests = await sdk.models.medium.complete(prompt)
+ tests = await sdk.models.medium._complete(prompt)
await sdk.apply_filesystem_edit(AddFile(filepath=path, content=tests))
diff --git a/continuedev/src/continuedev/plugins/steps/chat.py b/continuedev/src/continuedev/plugins/steps/chat.py
index 857183bc..d580f886 100644
--- a/continuedev/src/continuedev/plugins/steps/chat.py
+++ b/continuedev/src/continuedev/plugins/steps/chat.py
@@ -83,11 +83,17 @@ class SimpleChatStep(Step):
messages = self.messages or await sdk.get_chat_context()
- generator = sdk.models.chat.stream_chat(
+ generator = sdk.models.chat._stream_chat(
messages, temperature=sdk.config.temperature
)
- posthog_logger.capture_event("model_use", {"model": sdk.models.default.name})
+ posthog_logger.capture_event(
+ "model_use",
+ {
+ "model": sdk.models.default.model,
+ "provider": sdk.models.default.__class__.__name__,
+ },
+ )
async for chunk in generator:
if sdk.current_step_was_deleted():
@@ -112,7 +118,7 @@ class SimpleChatStep(Step):
await sdk.update_ui()
self.name = add_ellipsis(
remove_quotes_and_escapes(
- await sdk.models.medium.complete(
+ await sdk.models.medium._complete(
f'"{self.description}"\n\nPlease write a short title summarizing the message quoted above. Use no more than 10 words:',
max_tokens=20,
)
@@ -254,7 +260,7 @@ class ChatWithFunctions(Step):
gpt350613 = OpenAI(model="gpt-3.5-turbo-0613")
await sdk.start_model(gpt350613)
- async for msg_chunk in gpt350613.stream_chat(
+ async for msg_chunk in gpt350613._stream_chat(
await sdk.get_chat_context(), functions=functions
):
if sdk.current_step_was_deleted():
diff --git a/continuedev/src/continuedev/plugins/steps/chroma.py b/continuedev/src/continuedev/plugins/steps/chroma.py
index 25633942..9ee2a48d 100644
--- a/continuedev/src/continuedev/plugins/steps/chroma.py
+++ b/continuedev/src/continuedev/plugins/steps/chroma.py
@@ -58,7 +58,7 @@ class AnswerQuestionChroma(Step):
Here is the answer:"""
)
- answer = await sdk.models.medium.complete(prompt)
+ answer = await sdk.models.medium._complete(prompt)
# Make paths relative to the workspace directory
answer = answer.replace(await sdk.ide.getWorkspaceDirectory(), "")
diff --git a/continuedev/src/continuedev/plugins/steps/core/core.py b/continuedev/src/continuedev/plugins/steps/core/core.py
index 2c7416aa..9d40822b 100644
--- a/continuedev/src/continuedev/plugins/steps/core/core.py
+++ b/continuedev/src/continuedev/plugins/steps/core/core.py
@@ -275,7 +275,7 @@ class DefaultModelEditCodeStep(Step):
)
# If using 3.5 and overflows, upgrade to 3.5.16k
- if model_to_use.name == "gpt-3.5-turbo":
+ if model_to_use.model == "gpt-3.5-turbo":
if total_tokens > model_to_use.context_length:
model_to_use = MaybeProxyOpenAI(model="gpt-3.5-turbo-0613")
await sdk.start_model(model_to_use)
@@ -663,11 +663,14 @@ Please output the code to be inserted at the cursor in order to fulfill the user
else:
messages = rendered
- generator = model_to_use.stream_chat(
+ generator = model_to_use._stream_chat(
messages, temperature=sdk.config.temperature, max_tokens=max_tokens
)
- posthog_logger.capture_event("model_use", {"model": model_to_use.name})
+ posthog_logger.capture_event(
+ "model_use",
+ {"model": model_to_use.model, "provider": model_to_use.__class__.__name__},
+ )
try:
async for chunk in generator:
diff --git a/continuedev/src/continuedev/plugins/steps/help.py b/continuedev/src/continuedev/plugins/steps/help.py
index 148dddb8..c73d7eef 100644
--- a/continuedev/src/continuedev/plugins/steps/help.py
+++ b/continuedev/src/continuedev/plugins/steps/help.py
@@ -59,7 +59,7 @@ class HelpStep(Step):
ChatMessage(role="user", content=prompt, summary="Help")
)
messages = await sdk.get_chat_context()
- generator = sdk.models.default.stream_chat(messages)
+ generator = sdk.models.default._stream_chat(messages)
async for chunk in generator:
if "content" in chunk:
self.description += chunk["content"]
diff --git a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py
index 721f1306..001876d0 100644
--- a/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py
+++ b/continuedev/src/continuedev/plugins/steps/input/nl_multiselect.py
@@ -26,7 +26,7 @@ class NLMultiselectStep(Step):
if first_try is not None:
return first_try
- gpt_parsed = await sdk.models.default.complete(
+ gpt_parsed = await sdk.models.default._complete(
f"These are the available options are: [{', '.join(self.options)}]. The user requested {user_response}. This is the exact string from the options array that they selected:"
)
return extract_option(gpt_parsed) or self.options[0]
diff --git a/continuedev/src/continuedev/plugins/steps/main.py b/continuedev/src/continuedev/plugins/steps/main.py
index ca15aaab..7762666c 100644
--- a/continuedev/src/continuedev/plugins/steps/main.py
+++ b/continuedev/src/continuedev/plugins/steps/main.py
@@ -105,7 +105,7 @@ class FasterEditHighlightedCodeStep(Step):
for rif in range_in_files:
rif_dict[rif.filepath] = rif.contents
- completion = await sdk.models.medium.complete(prompt)
+ completion = await sdk.models.medium._complete(prompt)
# Temporarily doing this to generate description.
self._prompt = prompt
@@ -180,7 +180,7 @@ class StarCoderEditHighlightedCodeStep(Step):
_prompt_and_completion: str = ""
async def describe(self, models: Models) -> Coroutine[str, None, None]:
- return await models.medium.complete(
+ return await models.medium._complete(
f"{self._prompt_and_completion}\n\nPlease give brief a description of the changes made above using markdown bullet points:"
)
diff --git a/continuedev/src/continuedev/plugins/steps/react.py b/continuedev/src/continuedev/plugins/steps/react.py
index a2612731..2ed2d3d7 100644
--- a/continuedev/src/continuedev/plugins/steps/react.py
+++ b/continuedev/src/continuedev/plugins/steps/react.py
@@ -29,7 +29,7 @@ class NLDecisionStep(Step):
Select the step which should be taken next to satisfy the user input. Say only the name of the selected step. You must choose one:"""
)
- resp = (await sdk.models.medium.complete(prompt)).lower()
+ resp = (await sdk.models.medium._complete(prompt)).lower()
step_to_run = None
for step in self.steps:
diff --git a/continuedev/src/continuedev/plugins/steps/search_directory.py b/continuedev/src/continuedev/plugins/steps/search_directory.py
index 04fb98b7..9317bfe1 100644
--- a/continuedev/src/continuedev/plugins/steps/search_directory.py
+++ b/continuedev/src/continuedev/plugins/steps/search_directory.py
@@ -46,7 +46,7 @@ class WriteRegexPatternStep(Step):
async def run(self, sdk: ContinueSDK):
# Ask the user for a regex pattern
- pattern = await sdk.models.medium.complete(
+ pattern = await sdk.models.medium._complete(
dedent(
f"""\
This is the user request: