From 868e0b7ef5357b89186119c3c2fa8bd427b8db30 Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Sun, 16 Jul 2023 00:21:56 -0700 Subject: Anthropic support --- continuedev/pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'continuedev/pyproject.toml') diff --git a/continuedev/pyproject.toml b/continuedev/pyproject.toml index 6727e29a..08c3fd04 100644 --- a/continuedev/pyproject.toml +++ b/continuedev/pyproject.toml @@ -6,7 +6,7 @@ authors = ["Nate Sesti "] readme = "README.md" [tool.poetry.dependencies] -python = "^3.8" +python = "^3.8.1" diff-match-patch = "^20230430" fastapi = "^0.95.1" typer = "^0.7.0" @@ -24,6 +24,7 @@ tiktoken = "^0.4.0" jsonref = "^1.1.0" jsonschema = "^4.17.3" directory-tree = "^0.0.3.1" +anthropic = "^0.3.4" [tool.poetry.scripts] typegen = "src.continuedev.models.generate_json_schema:main" -- cgit v1.2.3-70-g09d2 From 73e1cfbefbf450ab6564aba653e0132843223c7a Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Sun, 16 Jul 2023 22:12:44 -0700 Subject: templated system messages --- continuedev/poetry.lock | 65 +++++++++++++++++++++- continuedev/pyproject.toml | 3 +- continuedev/src/continuedev/core/config.py | 1 + continuedev/src/continuedev/core/sdk.py | 10 ++-- continuedev/src/continuedev/libs/llm/ggml.py | 6 +- .../src/continuedev/libs/llm/hf_inference_api.py | 3 +- continuedev/src/continuedev/libs/llm/openai.py | 6 +- .../src/continuedev/libs/llm/proxy_server.py | 6 +- .../src/continuedev/libs/util/count_tokens.py | 7 ++- .../src/continuedev/libs/util/templating.py | 39 +++++++++++++ 10 files changed, 127 insertions(+), 19 deletions(-) create mode 100644 continuedev/src/continuedev/libs/util/templating.py (limited to 'continuedev/pyproject.toml') diff --git a/continuedev/poetry.lock b/continuedev/poetry.lock index a49a570f..625aabc9 100644 --- a/continuedev/poetry.lock +++ b/continuedev/poetry.lock @@ -297,6 +297,18 @@ files = [ {file = "charset_normalizer-3.1.0-py3-none-any.whl", hash = "sha256:3d9098b479e78c85080c98e1e35ff40b4a31d8953102bb0fd7d1b6f8a2111a3d"}, ] +[[package]] +name = "chevron" +version = "0.14.0" +description = "Mustache templating language renderer" +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "chevron-0.14.0-py3-none-any.whl", hash = "sha256:fbf996a709f8da2e745ef763f482ce2d311aa817d287593a5b990d6d6e4f0443"}, + {file = "chevron-0.14.0.tar.gz", hash = "sha256:87613aafdf6d77b6a90ff073165a61ae5086e21ad49057aa0e53681601800ebf"}, +] + [[package]] name = "click" version = "8.1.3" @@ -600,6 +612,25 @@ files = [ {file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"}, ] +[[package]] +name = "importlib-resources" +version = "6.0.0" +description = "Read resources from Python packages" +category = "main" +optional = false +python-versions = ">=3.8" +files = [ + {file = "importlib_resources-6.0.0-py3-none-any.whl", hash = "sha256:d952faee11004c045f785bb5636e8f885bed30dc3c940d5d42798a2a4541c185"}, + {file = "importlib_resources-6.0.0.tar.gz", hash = "sha256:4cf94875a8368bd89531a756df9a9ebe1f150e0f885030b461237bc7f2d905f2"}, +] + +[package.dependencies] +zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""} + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +testing = ["pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-ruff"] + [[package]] name = "jsonref" version = "1.1.0" @@ -626,6 +657,8 @@ files = [ [package.dependencies] attrs = ">=17.4.0" +importlib-resources = {version = ">=1.4.0", markers = "python_version < \"3.9\""} +pkgutil-resolve-name = {version = ">=1.3.10", markers = "python_version < \"3.9\""} pyrsistent = ">=0.14.0,<0.17.0 || >0.17.0,<0.17.1 || >0.17.1,<0.17.2 || >0.17.2" [package.extras] @@ -1024,6 +1057,18 @@ sql-other = ["SQLAlchemy (>=1.4.16)"] test = ["hypothesis (>=6.34.2)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)"] xml = ["lxml (>=4.6.3)"] +[[package]] +name = "pkgutil-resolve-name" +version = "1.3.10" +description = "Resolve a name to an object." +category = "main" +optional = false +python-versions = ">=3.6" +files = [ + {file = "pkgutil_resolve_name-1.3.10-py3-none-any.whl", hash = "sha256:ca27cc078d25c5ad71a9de0a7a330146c4e014c2462d9af19c6b828280649c5e"}, + {file = "pkgutil_resolve_name-1.3.10.tar.gz", hash = "sha256:357d6c9e6a755653cfd78893817c0853af365dd51ec97f3d358a819373bbd174"}, +] + [[package]] name = "posthog" version = "3.0.1" @@ -1818,7 +1863,23 @@ files = [ idna = ">=2.0" multidict = ">=4.0" +[[package]] +name = "zipp" +version = "3.16.2" +description = "Backport of pathlib-compatible object wrapper for zip files" +category = "main" +optional = false +python-versions = ">=3.8" +files = [ + {file = "zipp-3.16.2-py3-none-any.whl", hash = "sha256:679e51dd4403591b2d6838a48de3d283f3d188412a9782faadf845f298736ba0"}, + {file = "zipp-3.16.2.tar.gz", hash = "sha256:ebc15946aa78bd63458992fc81ec3b6f7b1e92d51c35e6de1c3804e73b799147"}, +] + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"] + [metadata] lock-version = "2.0" -python-versions = "^3.9" -content-hash = "3ba2a7278fda36a059d76e227be94b0cb5e2efc9396b47a9642b916680214d9f" +python-versions = "^3.8.1" +content-hash = "82510deb9f4afb5bc38db0dfd88ad88005fa0b6221c24e8c1700c006360f3f88" diff --git a/continuedev/pyproject.toml b/continuedev/pyproject.toml index 6727e29a..3077de1c 100644 --- a/continuedev/pyproject.toml +++ b/continuedev/pyproject.toml @@ -6,7 +6,7 @@ authors = ["Nate Sesti "] readme = "README.md" [tool.poetry.dependencies] -python = "^3.8" +python = "^3.8.1" diff-match-patch = "^20230430" fastapi = "^0.95.1" typer = "^0.7.0" @@ -24,6 +24,7 @@ tiktoken = "^0.4.0" jsonref = "^1.1.0" jsonschema = "^4.17.3" directory-tree = "^0.0.3.1" +chevron = "^0.14.0" [tool.poetry.scripts] typegen = "src.continuedev.models.generate_json_schema:main" diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index 957609c5..91a47c8e 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -85,6 +85,7 @@ class ContinueConfig(BaseModel): slash_commands: Optional[List[SlashCommand]] = DEFAULT_SLASH_COMMANDS on_traceback: Optional[List[OnTracebackSteps]] = [ OnTracebackSteps(step_name="DefaultOnTracebackStep")] + system_message: Optional[str] = None # Want to force these to be the slash commands for now @validator('slash_commands', pre=True) diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index eb60109c..ac57c122 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -34,10 +34,12 @@ MODEL_PROVIDER_TO_ENV_VAR = { class Models: provider_keys: Dict[ModelProvider, str] = {} model_providers: List[ModelProvider] + system_message: str def __init__(self, sdk: "ContinueSDK", model_providers: List[ModelProvider]): self.sdk = sdk self.model_providers = model_providers + self.system_message = sdk.config.system_message @classmethod async def create(cls, sdk: "ContinueSDK", with_providers: List[ModelProvider] = ["openai"]) -> "Models": @@ -53,12 +55,12 @@ class Models: def __load_openai_model(self, model: str) -> OpenAI: api_key = self.provider_keys["openai"] if api_key == "": - return ProxyServer(self.sdk.ide.unique_id, model) - return OpenAI(api_key=api_key, default_model=model) + return ProxyServer(self.sdk.ide.unique_id, model, system_message=self.system_message) + return OpenAI(api_key=api_key, default_model=model, system_message=self.system_message) def __load_hf_inference_api_model(self, model: str) -> HuggingFaceInferenceAPI: api_key = self.provider_keys["hf_inference_api"] - return HuggingFaceInferenceAPI(api_key=api_key, model=model) + return HuggingFaceInferenceAPI(api_key=api_key, model=model, system_message=self.system_message) @cached_property def starcoder(self): @@ -82,7 +84,7 @@ class Models: @cached_property def ggml(self): - return GGML() + return GGML(system_message=self.system_message) def __model_from_name(self, model_name: str): if model_name == "starcoder": diff --git a/continuedev/src/continuedev/libs/llm/ggml.py b/continuedev/src/continuedev/libs/llm/ggml.py index d3589b70..6007fdb4 100644 --- a/continuedev/src/continuedev/libs/llm/ggml.py +++ b/continuedev/src/continuedev/libs/llm/ggml.py @@ -33,7 +33,7 @@ class GGML(LLM): args = {**self.default_args, **kwargs} messages = compile_chat_messages( - self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None)) + self.name, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) async with aiohttp.ClientSession() as session: async with session.post(f"{SERVER_URL}/v1/completions", json={ @@ -50,7 +50,7 @@ class GGML(LLM): async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]: args = {**self.default_args, **kwargs} messages = compile_chat_messages( - self.name, messages, args["max_tokens"], None, functions=args.get("functions", None)) + self.name, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) args["stream"] = True async with aiohttp.ClientSession() as session: @@ -77,7 +77,7 @@ class GGML(LLM): async with aiohttp.ClientSession() as session: async with session.post(f"{SERVER_URL}/v1/completions", json={ - "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None), + "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message), **args }) as resp: try: diff --git a/continuedev/src/continuedev/libs/llm/hf_inference_api.py b/continuedev/src/continuedev/libs/llm/hf_inference_api.py index 803ba122..7e11fbbe 100644 --- a/continuedev/src/continuedev/libs/llm/hf_inference_api.py +++ b/continuedev/src/continuedev/libs/llm/hf_inference_api.py @@ -11,9 +11,10 @@ class HuggingFaceInferenceAPI(LLM): api_key: str model: str - def __init__(self, api_key: str, model: str): + def __init__(self, api_key: str, model: str, system_message: str = None): self.api_key = api_key self.model = model + self.system_message = system_message # TODO: Nothing being done with this def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs): """Return the completion of the text with the given temperature.""" diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index f0877d90..d973f19e 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -37,7 +37,7 @@ class OpenAI(LLM): if args["model"] in CHAT_MODELS: async for chunk in await openai.ChatCompletion.acreate( messages=compile_chat_messages( - args["model"], with_history, args["max_tokens"], prompt, functions=None), + args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message), **args, ): if "content" in chunk.choices[0].delta: @@ -58,7 +58,7 @@ class OpenAI(LLM): async for chunk in await openai.ChatCompletion.acreate( messages=compile_chat_messages( - args["model"], messages, args["max_tokens"], functions=args.get("functions", None)), + args["model"], messages, args["max_tokens"], functions=args.get("functions", None), system_message=self.system_message), **args, ): yield chunk.choices[0].delta @@ -69,7 +69,7 @@ class OpenAI(LLM): if args["model"] in CHAT_MODELS: resp = (await openai.ChatCompletion.acreate( messages=compile_chat_messages( - args["model"], with_history, args["max_tokens"], prompt, functions=None), + args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message), **args, )).choices[0].message.content else: diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py index eab6e441..3ec492f3 100644 --- a/continuedev/src/continuedev/libs/llm/proxy_server.py +++ b/continuedev/src/continuedev/libs/llm/proxy_server.py @@ -38,7 +38,7 @@ class ProxyServer(LLM): async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session: async with session.post(f"{SERVER_URL}/complete", json={ - "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None), + "messages": compile_chat_messages(args["model"], with_history, args["max_tokens"], prompt, functions=None, system_message=self.system_message), "unique_id": self.unique_id, **args }) as resp: @@ -50,7 +50,7 @@ class ProxyServer(LLM): async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]: args = {**self.default_args, **kwargs} messages = compile_chat_messages( - self.default_model, messages, args["max_tokens"], None, functions=args.get("functions", None)) + self.default_model, messages, args["max_tokens"], None, functions=args.get("functions", None), system_message=self.system_message) async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session: async with session.post(f"{SERVER_URL}/stream_chat", json={ @@ -74,7 +74,7 @@ class ProxyServer(LLM): async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]: args = {**self.default_args, **kwargs} messages = compile_chat_messages( - self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None)) + self.default_model, with_history, args["max_tokens"], prompt, functions=args.get("functions", None), system_message=self.system_message) async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session: async with session.post(f"{SERVER_URL}/stream_complete", json={ diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py index e1baeca1..1ca98fe6 100644 --- a/continuedev/src/continuedev/libs/util/count_tokens.py +++ b/continuedev/src/continuedev/libs/util/count_tokens.py @@ -1,6 +1,7 @@ import json from typing import Dict, List, Union from ...core.main import ChatMessage +from .templating import render_system_message import tiktoken aliases = { @@ -85,13 +86,15 @@ def compile_chat_messages(model: str, msgs: List[ChatMessage], max_tokens: int, for function in functions: prompt_tokens += count_tokens(model, json.dumps(function)) + rendered_system_message = render_system_message(system_message) + msgs = prune_chat_history(model, - msgs, MAX_TOKENS_FOR_MODEL[model], prompt_tokens + max_tokens + count_tokens(model, system_message)) + msgs, MAX_TOKENS_FOR_MODEL[model], prompt_tokens + max_tokens + count_tokens(model, rendered_system_message)) history = [] if system_message: history.append({ "role": "system", - "content": system_message + "content": rendered_system_message }) history += [msg.to_dict(with_functions=functions is not None) for msg in msgs] diff --git a/continuedev/src/continuedev/libs/util/templating.py b/continuedev/src/continuedev/libs/util/templating.py new file mode 100644 index 00000000..ebfc2e31 --- /dev/null +++ b/continuedev/src/continuedev/libs/util/templating.py @@ -0,0 +1,39 @@ +import os +import chevron + + +def get_vars_in_template(template): + """ + Get the variables in a template + """ + return [token[1] for token in chevron.tokenizer.tokenize(template) if token[0] == 'variable'] + + +def escape_var(var: str) -> str: + """ + Escape a variable so it can be used in a template + """ + return var.replace(os.path.sep, '').replace('.', '') + + +def render_system_message(system_message: str) -> str: + """ + Render system message with mustache syntax. + Right now it only supports rendering absolute file paths as their contents. + """ + vars = get_vars_in_template(system_message) + + args = {} + for var in vars: + if var.startswith(os.path.sep): + # Escape vars which are filenames, because mustache doesn't allow / in variable names + escaped_var = escape_var(var) + system_message = system_message.replace( + var, escaped_var) + + if os.path.exists(var): + args[escaped_var] = open(var, 'r').read() + else: + args[escaped_var] = '' + + return chevron.render(system_message, args) -- cgit v1.2.3-70-g09d2 From db1e35497de924c001f421d1d3277f02258b55db Mon Sep 17 00:00:00 2001 From: Nate Sesti Date: Tue, 18 Jul 2023 21:10:08 -0700 Subject: psutil profiling, temperature in config.json --- continuedev/poetry.lock | 29 +++++++++++++++++++++- continuedev/pyproject.toml | 1 + continuedev/src/continuedev/core/config.py | 5 ++++ .../src/continuedev/libs/llm/proxy_server.py | 8 ++++-- continuedev/src/continuedev/server/main.py | 25 ++++++++++++++++++- continuedev/src/continuedev/steps/chat.py | 3 ++- continuedev/src/continuedev/steps/core/core.py | 2 +- 7 files changed, 67 insertions(+), 6 deletions(-) (limited to 'continuedev/pyproject.toml') diff --git a/continuedev/poetry.lock b/continuedev/poetry.lock index e8927fe7..1cd4a591 100644 --- a/continuedev/poetry.lock +++ b/continuedev/poetry.lock @@ -1171,6 +1171,33 @@ dev = ["black", "flake8", "flake8-print", "isort", "pre-commit"] sentry = ["django", "sentry-sdk"] test = ["coverage", "flake8", "freezegun (==0.3.15)", "mock (>=2.0.0)", "pylint", "pytest"] +[[package]] +name = "psutil" +version = "5.9.5" +description = "Cross-platform lib for process and system monitoring in Python." +category = "main" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "psutil-5.9.5-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:be8929ce4313f9f8146caad4272f6abb8bf99fc6cf59344a3167ecd74f4f203f"}, + {file = "psutil-5.9.5-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:ab8ed1a1d77c95453db1ae00a3f9c50227ebd955437bcf2a574ba8adbf6a74d5"}, + {file = "psutil-5.9.5-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:4aef137f3345082a3d3232187aeb4ac4ef959ba3d7c10c33dd73763fbc063da4"}, + {file = "psutil-5.9.5-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:ea8518d152174e1249c4f2a1c89e3e6065941df2fa13a1ab45327716a23c2b48"}, + {file = "psutil-5.9.5-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:acf2aef9391710afded549ff602b5887d7a2349831ae4c26be7c807c0a39fac4"}, + {file = "psutil-5.9.5-cp27-none-win32.whl", hash = "sha256:5b9b8cb93f507e8dbaf22af6a2fd0ccbe8244bf30b1baad6b3954e935157ae3f"}, + {file = "psutil-5.9.5-cp27-none-win_amd64.whl", hash = "sha256:8c5f7c5a052d1d567db4ddd231a9d27a74e8e4a9c3f44b1032762bd7b9fdcd42"}, + {file = "psutil-5.9.5-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:3c6f686f4225553615612f6d9bc21f1c0e305f75d7d8454f9b46e901778e7217"}, + {file = "psutil-5.9.5-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7a7dd9997128a0d928ed4fb2c2d57e5102bb6089027939f3b722f3a210f9a8da"}, + {file = "psutil-5.9.5-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89518112647f1276b03ca97b65cc7f64ca587b1eb0278383017c2a0dcc26cbe4"}, + {file = "psutil-5.9.5-cp36-abi3-win32.whl", hash = "sha256:104a5cc0e31baa2bcf67900be36acde157756b9c44017b86b2c049f11957887d"}, + {file = "psutil-5.9.5-cp36-abi3-win_amd64.whl", hash = "sha256:b258c0c1c9d145a1d5ceffab1134441c4c5113b2417fafff7315a917a026c3c9"}, + {file = "psutil-5.9.5-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:c607bb3b57dc779d55e1554846352b4e358c10fff3abf3514a7a6601beebdb30"}, + {file = "psutil-5.9.5.tar.gz", hash = "sha256:5410638e4df39c54d957fc51ce03048acd8e6d60abc0f5107af51e5fb566eb3c"}, +] + +[package.extras] +test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] + [[package]] name = "pydantic" version = "1.10.7" @@ -2015,4 +2042,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.8.1" -content-hash = "87dbf6d1e56ce6ba81a01a59c0de2d3717925bac9639710bf3ff3ce30f5f5e2c" +content-hash = "3fcd19c11b9c338a181e591b56e21d59c7834abff431fb9f40cc1ea874b64557" diff --git a/continuedev/pyproject.toml b/continuedev/pyproject.toml index 6a646cbe..0abc9504 100644 --- a/continuedev/pyproject.toml +++ b/continuedev/pyproject.toml @@ -26,6 +26,7 @@ jsonschema = "^4.17.3" directory-tree = "^0.0.3.1" anthropic = "^0.3.4" chevron = "^0.14.0" +psutil = "^5.9.5" [tool.poetry.scripts] typegen = "src.continuedev.models.generate_json_schema:main" diff --git a/continuedev/src/continuedev/core/config.py b/continuedev/src/continuedev/core/config.py index 6af0878d..70c4876e 100644 --- a/continuedev/src/continuedev/core/config.py +++ b/continuedev/src/continuedev/core/config.py @@ -82,6 +82,7 @@ class ContinueConfig(BaseModel): allow_anonymous_telemetry: Optional[bool] = True default_model: Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "claude-2", "ggml"] = 'gpt-4' + temperature: Optional[float] = 0.5 custom_commands: Optional[List[CustomCommand]] = [CustomCommand( name="test", description="This is an example custom command. Use /config to edit it and create more", @@ -98,6 +99,10 @@ class ContinueConfig(BaseModel): def default_slash_commands_validator(cls, v): return DEFAULT_SLASH_COMMANDS + @validator('temperature', pre=True) + def temperature_validator(cls, v): + return max(0.0, min(1.0, v)) + def load_config(config_file: str) -> ContinueConfig: """ diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py index 18e0e6f4..bd50fe02 100644 --- a/continuedev/src/continuedev/libs/llm/proxy_server.py +++ b/continuedev/src/continuedev/libs/llm/proxy_server.py @@ -1,8 +1,10 @@ from functools import cached_property import json +import traceback from typing import Any, Callable, Coroutine, Dict, Generator, List, Literal, Union import aiohttp +from ..util.telemetry import capture_event from ...core.main import ChatMessage from ..llm import LLM from ..util.count_tokens import DEFAULT_ARGS, DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, count_tokens, format_chat_messages @@ -81,8 +83,10 @@ class ProxyServer(LLM): yield loaded_chunk if "content" in loaded_chunk: completion += loaded_chunk["content"] - except: - raise Exception(str(line[0])) + except Exception as e: + capture_event(self.unique_id, "proxy_server_parse_error", { + "error_title": "Proxy server stream_chat parsing failed", "error_message": '\n'.join(traceback.format_exception(e))}) + self.write_log(f"Completion: \n\n{completion}") async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]: diff --git a/continuedev/src/continuedev/server/main.py b/continuedev/src/continuedev/server/main.py index aa093853..42dc0cc1 100644 --- a/continuedev/src/continuedev/server/main.py +++ b/continuedev/src/continuedev/server/main.py @@ -1,5 +1,6 @@ +import time +import psutil import os -import sys from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from .ide import router as ide_router @@ -51,9 +52,31 @@ def cleanup(): session_manager.persist_session(session_id) +def cpu_usage_report(): + process = psutil.Process(os.getpid()) + # Call cpu_percent once to start measurement, but ignore the result + process.cpu_percent(interval=None) + # Wait for a short period of time + time.sleep(1) + # Call cpu_percent again to get the CPU usage over the interval + cpu_usage = process.cpu_percent(interval=None) + print(f"CPU usage: {cpu_usage}%") + + atexit.register(cleanup) + if __name__ == "__main__": try: + # import threading + + # def cpu_usage_loop(): + # while True: + # cpu_usage_report() + # time.sleep(2) + + # cpu_thread = threading.Thread(target=cpu_usage_loop) + # cpu_thread.start() + run_server() except Exception as e: cleanup() diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py index 8c03969e..aade1ea1 100644 --- a/continuedev/src/continuedev/steps/chat.py +++ b/continuedev/src/continuedev/steps/chat.py @@ -29,7 +29,8 @@ class SimpleChatStep(Step): completion = "" messages = self.messages or await sdk.get_chat_context() - generator = sdk.models.default.stream_chat(messages, temperature=0.5) + generator = sdk.models.default.stream_chat( + messages, temperature=sdk.config.temperature) try: async for chunk in generator: if sdk.current_step_was_deleted(): diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py index 2b049ecc..4afc36e8 100644 --- a/continuedev/src/continuedev/steps/core/core.py +++ b/continuedev/src/continuedev/steps/core/core.py @@ -492,7 +492,7 @@ Please output the code to be inserted at the cursor in order to fulfill the user role="user", content=f"```\n{rif.contents}\n```\n\nUser request: \"{self.user_input}\"\n\nThis is the code after changing to perfectly comply with the user request. It does not include any placeholder code, only real implementations:\n\n```\n", summary=self.user_input)] generator = model_to_use.stream_chat( - messages, temperature=0, max_tokens=max_tokens) + messages, temperature=sdk.config.temperature, max_tokens=max_tokens) try: async for chunk in generator: -- cgit v1.2.3-70-g09d2