diff options
| author | Nate Sesti <33237525+sestinj@users.noreply.github.com> | 2023-06-25 13:50:23 -0700 | 
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-06-25 13:50:23 -0700 | 
| commit | f28a82dc2b929ddebdcf2589e24fc31a3b1078cc (patch) | |
| tree | 8a87781af84b4acaf3bec75be7d42e93796b25cf | |
| parent | ad462728afc4e6a9e1402aff295010ced9cf2f7a (diff) | |
| parent | 8db3dcb7b9f138b09d7cceedfa830fd150795b30 (diff) | |
| download | sncontinue-f28a82dc2b929ddebdcf2589e24fc31a3b1078cc.tar.gz sncontinue-f28a82dc2b929ddebdcf2589e24fc31a3b1078cc.tar.bz2 sncontinue-f28a82dc2b929ddebdcf2589e24fc31a3b1078cc.zip | |
Merge pull request #153 from continuedev/function-calling
Function calling
| -rw-r--r-- | continuedev/poetry.lock | 74 | ||||
| -rw-r--r-- | continuedev/pyproject.toml | 2 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/abstract_sdk.py | 4 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/main.py | 95 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/policy.py | 5 | ||||
| -rw-r--r-- | continuedev/src/continuedev/core/sdk.py | 60 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/__init__.py | 8 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/openai.py | 213 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/proxy_server.py | 98 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/util/count_tokens.py | 84 | ||||
| -rw-r--r-- | continuedev/src/continuedev/steps/chat.py | 214 | ||||
| -rw-r--r-- | continuedev/src/continuedev/steps/core/core.py | 9 | ||||
| -rw-r--r-- | continuedev/src/continuedev/steps/main.py | 6 | ||||
| -rw-r--r-- | continuedev/src/continuedev/steps/on_traceback.py | 8 | ||||
| -rw-r--r-- | extension/react-app/src/components/HeaderButtonWithText.tsx | 2 | ||||
| -rw-r--r-- | extension/react-app/src/components/StepContainer.tsx | 1 | ||||
| -rw-r--r-- | extension/react-app/src/tabs/gui.tsx | 4 | 
17 files changed, 593 insertions, 294 deletions
| diff --git a/continuedev/poetry.lock b/continuedev/poetry.lock index 017f12f9..a49a570f 100644 --- a/continuedev/poetry.lock +++ b/continuedev/poetry.lock @@ -360,6 +360,21 @@ files = [  dev = ["attribution (==1.6.2)", "black (==23.3.0)", "flit (==3.8.0)", "mypy (==1.2.0)", "ufmt (==2.1.0)", "usort (==1.0.6)"]  [[package]] +name = "directory-tree" +version = "0.0.3.1" +description = "Utility Package that Displays out the Tree Structure of a Particular Directory." +category = "main" +optional = false +python-versions = "*" +files = [ +    {file = "directory_tree-0.0.3.1-py3-none-any.whl", hash = "sha256:72411e4f1534afaaccadb21fc082c727a680b6a74e8d21a1406ffbe51389cd85"}, +    {file = "directory_tree-0.0.3.1.tar.gz", hash = "sha256:e4f40d60a45c4cdc0bc8e9ee29311f554dee6c969241c0eef8bcd92b4d4bcd4a"}, +] + +[package.extras] +dev = ["pytest (>=3.7)"] + +[[package]]  name = "fastapi"  version = "0.95.1"  description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" @@ -598,6 +613,26 @@ files = [  ]  [[package]] +name = "jsonschema" +version = "4.17.3" +description = "An implementation of JSON Schema validation for Python" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ +    {file = "jsonschema-4.17.3-py3-none-any.whl", hash = "sha256:a870ad254da1a8ca84b6a2905cac29d265f805acc57af304784962a2aa6508f6"}, +    {file = "jsonschema-4.17.3.tar.gz", hash = "sha256:0f864437ab8b6076ba6707453ef8f98a6a0d512a80e93f8abdb676f737ecb60d"}, +] + +[package.dependencies] +attrs = ">=17.4.0" +pyrsistent = ">=0.14.0,<0.17.0 || >0.17.0,<0.17.1 || >0.17.1,<0.17.2 || >0.17.2" + +[package.extras] +format = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3987", "uri-template", "webcolors (>=1.11)"] +format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "uri-template", "webcolors (>=1.11)"] + +[[package]]  name = "langchain"  version = "0.0.171"  description = "Building applications with LLMs through composability" @@ -1067,6 +1102,43 @@ dotenv = ["python-dotenv (>=0.10.4)"]  email = ["email-validator (>=1.0.3)"]  [[package]] +name = "pyrsistent" +version = "0.19.3" +description = "Persistent/Functional/Immutable data structures" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ +    {file = "pyrsistent-0.19.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:20460ac0ea439a3e79caa1dbd560344b64ed75e85d8703943e0b66c2a6150e4a"}, +    {file = "pyrsistent-0.19.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c18264cb84b5e68e7085a43723f9e4c1fd1d935ab240ce02c0324a8e01ccb64"}, +    {file = "pyrsistent-0.19.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4b774f9288dda8d425adb6544e5903f1fb6c273ab3128a355c6b972b7df39dcf"}, +    {file = "pyrsistent-0.19.3-cp310-cp310-win32.whl", hash = "sha256:5a474fb80f5e0d6c9394d8db0fc19e90fa540b82ee52dba7d246a7791712f74a"}, +    {file = "pyrsistent-0.19.3-cp310-cp310-win_amd64.whl", hash = "sha256:49c32f216c17148695ca0e02a5c521e28a4ee6c5089f97e34fe24163113722da"}, +    {file = "pyrsistent-0.19.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f0774bf48631f3a20471dd7c5989657b639fd2d285b861237ea9e82c36a415a9"}, +    {file = "pyrsistent-0.19.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3ab2204234c0ecd8b9368dbd6a53e83c3d4f3cab10ecaf6d0e772f456c442393"}, +    {file = "pyrsistent-0.19.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e42296a09e83028b3476f7073fcb69ffebac0e66dbbfd1bd847d61f74db30f19"}, +    {file = "pyrsistent-0.19.3-cp311-cp311-win32.whl", hash = "sha256:64220c429e42a7150f4bfd280f6f4bb2850f95956bde93c6fda1b70507af6ef3"}, +    {file = "pyrsistent-0.19.3-cp311-cp311-win_amd64.whl", hash = "sha256:016ad1afadf318eb7911baa24b049909f7f3bb2c5b1ed7b6a8f21db21ea3faa8"}, +    {file = "pyrsistent-0.19.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c4db1bd596fefd66b296a3d5d943c94f4fac5bcd13e99bffe2ba6a759d959a28"}, +    {file = "pyrsistent-0.19.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aeda827381f5e5d65cced3024126529ddc4289d944f75e090572c77ceb19adbf"}, +    {file = "pyrsistent-0.19.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:42ac0b2f44607eb92ae88609eda931a4f0dfa03038c44c772e07f43e738bcac9"}, +    {file = "pyrsistent-0.19.3-cp37-cp37m-win32.whl", hash = "sha256:e8f2b814a3dc6225964fa03d8582c6e0b6650d68a232df41e3cc1b66a5d2f8d1"}, +    {file = "pyrsistent-0.19.3-cp37-cp37m-win_amd64.whl", hash = "sha256:c9bb60a40a0ab9aba40a59f68214eed5a29c6274c83b2cc206a359c4a89fa41b"}, +    {file = "pyrsistent-0.19.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:a2471f3f8693101975b1ff85ffd19bb7ca7dd7c38f8a81701f67d6b4f97b87d8"}, +    {file = "pyrsistent-0.19.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc5d149f31706762c1f8bda2e8c4f8fead6e80312e3692619a75301d3dbb819a"}, +    {file = "pyrsistent-0.19.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3311cb4237a341aa52ab8448c27e3a9931e2ee09561ad150ba94e4cfd3fc888c"}, +    {file = "pyrsistent-0.19.3-cp38-cp38-win32.whl", hash = "sha256:f0e7c4b2f77593871e918be000b96c8107da48444d57005b6a6bc61fb4331b2c"}, +    {file = "pyrsistent-0.19.3-cp38-cp38-win_amd64.whl", hash = "sha256:c147257a92374fde8498491f53ffa8f4822cd70c0d85037e09028e478cababb7"}, +    {file = "pyrsistent-0.19.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:b735e538f74ec31378f5a1e3886a26d2ca6351106b4dfde376a26fc32a044edc"}, +    {file = "pyrsistent-0.19.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:99abb85579e2165bd8522f0c0138864da97847875ecbd45f3e7e2af569bfc6f2"}, +    {file = "pyrsistent-0.19.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a8cb235fa6d3fd7aae6a4f1429bbb1fec1577d978098da1252f0489937786f3"}, +    {file = "pyrsistent-0.19.3-cp39-cp39-win32.whl", hash = "sha256:c74bed51f9b41c48366a286395c67f4e894374306b197e62810e0fdaf2364da2"}, +    {file = "pyrsistent-0.19.3-cp39-cp39-win_amd64.whl", hash = "sha256:878433581fc23e906d947a6814336eee031a00e6defba224234169ae3d3d6a98"}, +    {file = "pyrsistent-0.19.3-py3-none-any.whl", hash = "sha256:ccf0d6bd208f8111179f0c26fdf84ed7c3891982f2edaeae7422575f47e66b64"}, +    {file = "pyrsistent-0.19.3.tar.gz", hash = "sha256:1a2994773706bbb4995c31a97bc94f1418314923bd1048c6d964837040376440"}, +] + +[[package]]  name = "python-dateutil"  version = "2.8.2"  description = "Extensions to the standard Python datetime module" @@ -1749,4 +1821,4 @@ multidict = ">=4.0"  [metadata]  lock-version = "2.0"  python-versions = "^3.9" -content-hash = "9406bc70d0463b354c294bd9548897a33270b8a04f55141a763d45af8d6928b8" +content-hash = "3ba2a7278fda36a059d76e227be94b0cb5e2efc9396b47a9642b916680214d9f" diff --git a/continuedev/pyproject.toml b/continuedev/pyproject.toml index bbd8a687..64d88b8c 100644 --- a/continuedev/pyproject.toml +++ b/continuedev/pyproject.toml @@ -22,6 +22,8 @@ gpt-index = "^0.6.8"  posthog = "^3.0.1"  tiktoken = "^0.4.0"  jsonref = "^1.1.0" +jsonschema = "^4.17.3" +directory-tree = "^0.0.3.1"  [tool.poetry.scripts]  typegen = "src.continuedev.models.generate_json_schema:main"  diff --git a/continuedev/src/continuedev/core/abstract_sdk.py b/continuedev/src/continuedev/core/abstract_sdk.py index 017e75ef..7bd3da6c 100644 --- a/continuedev/src/continuedev/core/abstract_sdk.py +++ b/continuedev/src/continuedev/core/abstract_sdk.py @@ -85,9 +85,5 @@ class AbstractContinueSDK(ABC):          pass      @abstractmethod -    def add_chat_context(self, content: str, role: ChatMessageRole = "assistant"): -        pass - -    @abstractmethod      async def get_chat_context(self) -> List[ChatMessage]:          pass diff --git a/continuedev/src/continuedev/core/main.py b/continuedev/src/continuedev/core/main.py index d6412ece..b9ae9eba 100644 --- a/continuedev/src/continuedev/core/main.py +++ b/continuedev/src/continuedev/core/main.py @@ -1,18 +1,96 @@ +import json  from textwrap import dedent  from typing import Callable, Coroutine, Dict, Generator, List, Literal, Tuple, Union  from ..models.main import ContinueBaseModel  from pydantic import validator  from .observation import Observation +from pydantic.schema import schema -ChatMessageRole = Literal["assistant", "user", "system"] +ChatMessageRole = Literal["assistant", "user", "system", "function"] + + +class FunctionCall(ContinueBaseModel): +    name: str +    arguments: str  class ChatMessage(ContinueBaseModel):      role: ChatMessageRole -    content: str +    content: str | None +    name: str | None      # A summary for pruning chat context to fit context window. Often the Step name.      summary: str +    function_call: FunctionCall | None = None + +    def to_dict(self, with_functions: bool) -> Dict: +        d = self.dict() +        del d["summary"] +        if d["function_call"] is not None: +            d["function_call"]["name"] = d["function_call"]["name"].replace( +                " ", "") + +        if d["content"] is None: +            d["content"] = "" +        for key, value in list(d.items()): +            if value is None: +                del d[key] + +        if not with_functions: +            d["role"] = "assistant" +            if "name" in d: +                del d["name"] +            if "function_call" in d: +                del d["function_call"] +        return d + + +def resolve_refs(schema_data): +    def traverse(obj): +        if isinstance(obj, dict): +            if '$ref' in obj: +                ref = obj['$ref'] +                parts = ref.split('/') +                ref_obj = schema_data +                for part in parts[1:]: +                    ref_obj = ref_obj[part] +                return traverse(ref_obj) +            else: +                for key, value in obj.items(): +                    obj[key] = traverse(value) +        elif isinstance(obj, list): +            for i in range(len(obj)): +                obj[i] = traverse(obj[i]) +        return obj + +    return traverse(schema_data) + + +unincluded_parameters = ["system_message", "chat_context", +                         "manage_own_chat_context", "hide", "name", "description"] + + +def step_to_json_schema(step) -> str: +    pydantic_class = step.__class__ +    schema_data = schema([pydantic_class]) +    resolved_schema = resolve_refs(schema_data) +    parameters = resolved_schema["definitions"][pydantic_class.__name__] +    for parameter in unincluded_parameters: +        if parameter in parameters["properties"]: +            del parameters["properties"][parameter] +    return { +        "name": step.name.replace(" ", ""), +        "description": step.description or "", +        "parameters": parameters +    } + + +def step_to_fn_call_arguments(step: "Step") -> str: +    args = step.dict() +    for parameter in unincluded_parameters: +        if parameter in args: +            del args[parameter] +    return json.dumps(args)  class HistoryNode(ContinueBaseModel): @@ -24,9 +102,17 @@ class HistoryNode(ContinueBaseModel):      active: bool = True      def to_chat_messages(self) -> List[ChatMessage]: -        if self.step.description is None: +        if self.step.description is None or self.step.manage_own_chat_context:              return self.step.chat_context -        return self.step.chat_context + [ChatMessage(role="assistant", content=self.step.description, summary=self.step.name)] +        return self.step.chat_context + [ +            ChatMessage( +                role="function", +                name=self.step.__class__.__name__, +                content=json.dumps({ +                    "description": self.step.description or "Function complete", +                }), +                summary=f"Ran function {self.step.name}" +            )]  class History(ContinueBaseModel): @@ -144,6 +230,7 @@ class Step(ContinueBaseModel):      system_message: Union[str, None] = None      chat_context: List[ChatMessage] = [] +    manage_own_chat_context: bool = False      class Config:          copy_on_model_validation = False diff --git a/continuedev/src/continuedev/core/policy.py b/continuedev/src/continuedev/core/policy.py index fb13dd27..0b417959 100644 --- a/continuedev/src/continuedev/core/policy.py +++ b/continuedev/src/continuedev/core/policy.py @@ -15,7 +15,7 @@ from ..recipes.WritePytestsRecipe.main import WritePytestsRecipe  from ..recipes.ContinueRecipeRecipe.main import ContinueStepStep  from ..steps.comment_code import CommentCodeStep  from ..steps.react import NLDecisionStep -from ..steps.chat import SimpleChatStep +from ..steps.chat import SimpleChatStep, ChatWithFunctions, EditFileStep, AddFileStep  from ..recipes.DDtoBQRecipe.main import DDtoBQRecipe  from ..steps.core.core import MessageStep  from ..libs.util.step_name_to_steps import get_step_from_name @@ -28,7 +28,7 @@ class DemoPolicy(Policy):          # At the very start, run initial Steps spcecified in the config          if history.get_current() is None:              return ( -                MessageStep(name="Welcome to Continue!", message=dedent("""\ +                MessageStep(name="Welcome to Continue", message=dedent("""\                      Type '/' to see the list of available slash commands. If you highlight code, edits and explanations will be localized to the highlighted range. Otherwise, the currently open file is used. In both cases, the code is combined with the previous steps to construct the context.""")) >>                  WelcomeStep() >>                  # SetupContinueWorkspaceStep() >> @@ -50,6 +50,7 @@ class DemoPolicy(Policy):                          return get_step_from_name(slash_command.step_name, params)              # return EditHighlightedCodeStep(user_input=user_input) +            return ChatWithFunctions(user_input=user_input)              return NLDecisionStep(user_input=user_input, steps=[                  (EditHighlightedCodeStep(user_input=user_input),                   "Edit the highlighted code"), diff --git a/continuedev/src/continuedev/core/sdk.py b/continuedev/src/continuedev/core/sdk.py index 192552e7..62361250 100644 --- a/continuedev/src/continuedev/core/sdk.py +++ b/continuedev/src/continuedev/core/sdk.py @@ -26,6 +26,15 @@ class Models:      def __init__(self, sdk: "ContinueSDK"):          self.sdk = sdk +    def __load_openai_model(self, model: str) -> OpenAI: +        async def load_openai_model(): +            api_key = await self.sdk.get_user_secret( +                'OPENAI_API_KEY', 'Enter your OpenAI API key, OR press enter to try for free') +            if api_key == "": +                return ProxyServer(self.sdk.ide.unique_id, model) +            return OpenAI(api_key=api_key, default_model=model) +        return asyncio.get_event_loop().run_until_complete(load_openai_model()) +      @cached_property      def starcoder(self):          async def load_starcoder(): @@ -36,33 +45,19 @@ class Models:      @cached_property      def gpt35(self): -        async def load_gpt35(): -            api_key = await self.sdk.get_user_secret( -                'OPENAI_API_KEY', 'Enter your OpenAI API key, OR press enter to try for free') -            if api_key == "": -                return ProxyServer(self.sdk.ide.unique_id, "gpt-3.5-turbo") -            return OpenAI(api_key=api_key, default_model="gpt-3.5-turbo") -        return asyncio.get_event_loop().run_until_complete(load_gpt35()) +        return self.__load_openai_model("gpt-3.5-turbo") + +    @cached_property +    def gpt350613(self): +        return self.__load_openai_model("gpt-3.5-turbo-0613")      @cached_property      def gpt3516k(self): -        async def load_gpt3516k(): -            api_key = await self.sdk.get_user_secret( -                'OPENAI_API_KEY', 'Enter your OpenAI API key, OR press enter to try for free') -            if api_key == "": -                return ProxyServer(self.sdk.ide.unique_id, "gpt-3.5-turbo-16k") -            return OpenAI(api_key=api_key, default_model="gpt-3.5-turbo-16k") -        return asyncio.get_event_loop().run_until_complete(load_gpt3516k()) +        return self.__load_openai_model("gpt-3.5-turbo-16k")      @cached_property      def gpt4(self): -        async def load_gpt4(): -            api_key = await self.sdk.get_user_secret( -                'OPENAI_API_KEY', 'Enter your OpenAI API key, OR press enter to try for free') -            if api_key == "": -                return ProxyServer(self.sdk.ide.unique_id, "gpt-4") -            return OpenAI(api_key=api_key, default_model="gpt-4") -        return asyncio.get_event_loop().run_until_complete(load_gpt4()) +        return self.__load_openai_model("gpt-4")      def __model_from_name(self, model_name: str):          if model_name == "starcoder": @@ -102,7 +97,7 @@ class ContinueSDK(AbstractContinueSDK):      async def _ensure_absolute_path(self, path: str) -> str:          if os.path.isabs(path):              return path -        return os.path.join(await self.ide.getWorkspaceDirectory(), path) +        return os.path.join(self.ide.workspace_directory, path)      async def run_step(self, step: Step) -> Coroutine[Observation, None, None]:          return await self.__autopilot._run_singular_step(step) @@ -144,15 +139,15 @@ class ContinueSDK(AbstractContinueSDK):          return await self.run_step(FileSystemEditStep(edit=AddFile(filepath=filepath, content=content)))      async def delete_file(self, filename: str): -        filepath = await self._ensure_absolute_path(filename) +        filename = await self._ensure_absolute_path(filename)          return await self.run_step(FileSystemEditStep(edit=DeleteFile(filepath=filename)))      async def add_directory(self, path: str): -        filepath = await self._ensure_absolute_path(path) +        path = await self._ensure_absolute_path(path)          return await self.run_step(FileSystemEditStep(edit=AddDirectory(path=path)))      async def delete_directory(self, path: str): -        filepath = await self._ensure_absolute_path(path) +        path = await self._ensure_absolute_path(path)          return await self.run_step(FileSystemEditStep(edit=DeleteDirectory(path=path)))      async def get_user_secret(self, env_var: str, prompt: str) -> str: @@ -182,10 +177,6 @@ class ContinueSDK(AbstractContinueSDK):      def raise_exception(self, message: str, title: str, with_step: Union[Step, None] = None):          raise ContinueCustomException(message, title, with_step) -    def add_chat_context(self, content: str, summary: Union[str, None] = None, role: ChatMessageRole = "assistant"): -        self.history.timeline[self.history.current_index].step.chat_context.append( -            ChatMessage(content=content, role=role, summary=summary)) -      async def get_chat_context(self) -> List[ChatMessage]:          history_context = self.history.to_chat_history()          highlighted_code = await self.ide.getHighlightedCode() @@ -203,8 +194,15 @@ class ContinueSDK(AbstractContinueSDK):          for rif in highlighted_code:              code = await self.ide.readRangeInFile(rif) -            history_context.append(ChatMessage( -                content=f"{preface} ({rif.filepath}):\n```\n{code}\n```", role="user", summary=f"{preface}: {rif.filepath}")) +            msg = ChatMessage(content=f"{preface} ({rif.filepath}):\n```\n{code}\n```", +                              role="user", summary=f"{preface}: {rif.filepath}") + +            # Don't insert after latest user message or function call +            i = -1 +            if history_context[i].role == "user" or history_context[i].role == "function": +                i -= 1 +            history_context.insert(i, msg) +          return history_context      async def update_ui(self): diff --git a/continuedev/src/continuedev/libs/llm/__init__.py b/continuedev/src/continuedev/libs/llm/__init__.py index 108eedf1..4c4de213 100644 --- a/continuedev/src/continuedev/libs/llm/__init__.py +++ b/continuedev/src/continuedev/libs/llm/__init__.py @@ -13,12 +13,12 @@ class LLM(ABC):          """Return the completion of the text with the given temperature."""          raise NotImplementedError -    def stream_chat(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]: -        """Yield a stream of chat messages.""" +    def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]: +        """Stream the completion through generator."""          raise NotImplementedError -    def with_system_message(self, system_message: Union[str, None]): -        """Return a new model with the given system message.""" +    async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]: +        """Stream the chat through generator."""          raise NotImplementedError      def count_tokens(self, text: str): diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index 095cbe51..7621111f 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -1,34 +1,21 @@ -import asyncio  from functools import cached_property  import time  from typing import Any, Coroutine, Dict, Generator, List, Union  from ...core.main import ChatMessage  import openai -import aiohttp  from ..llm import LLM -from pydantic import BaseModel, validator -import tiktoken - -DEFAULT_MAX_TOKENS = 2048 -MAX_TOKENS_FOR_MODEL = { -    "gpt-3.5-turbo": 4096 - DEFAULT_MAX_TOKENS, -    "gpt-3.5-turbo-16k": 16384 - DEFAULT_MAX_TOKENS, -    "gpt-4": 8192 - DEFAULT_MAX_TOKENS -} -CHAT_MODELS = { -    "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4" -} +from ..util.count_tokens import DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, DEFAULT_ARGS, count_tokens  class OpenAI(LLM):      api_key: str -    completion_count: int = 0      default_model: str      def __init__(self, api_key: str, default_model: str, system_message: str = None):          self.api_key = api_key          self.default_model = default_model -        self.system_message = system_message +        self.system_message = (system_message or "") + \ +            "\nDo not ever call the 'python' function."          openai.api_key = api_key @@ -36,115 +23,52 @@ class OpenAI(LLM):      def name(self):          return self.default_model -    @cached_property -    def __encoding_for_model(self): -        aliases = {} -        return tiktoken.encoding_for_model(self.default_model) +    @property +    def default_args(self): +        return DEFAULT_ARGS | {"model": self.default_model}      def count_tokens(self, text: str): -        return len(self.__encoding_for_model.encode(text, disallowed_special=())) - -    def __prune_chat_history(self, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int): -        total_tokens = tokens_for_completion + \ -            sum(self.count_tokens(message.content) for message in chat_history) - -        # 1. Replace beyond last 5 messages with summary -        i = 0 -        while total_tokens > max_tokens and i < len(chat_history) - 5: -            message = chat_history[0] -            total_tokens -= self.count_tokens(message.content) -            total_tokens += self.count_tokens(message.summary) -            message.content = message.summary -            i += 1 - -        # 2. Remove entire messages until the last 5 -        while len(chat_history) > 5 and total_tokens > max_tokens: -            message = chat_history.pop(0) -            total_tokens -= self.count_tokens(message.content) +        return count_tokens(self.default_model, text) -        # 3. Truncate message in the last 5 -        i = 0 -        while total_tokens > max_tokens: -            message = chat_history[0] -            total_tokens -= self.count_tokens(message.content) -            total_tokens += self.count_tokens(message.summary) -            message.content = message.summary -            i += 1 - -        # 4. Remove entire messages in the last 5 -        while total_tokens > max_tokens and len(chat_history) > 0: -            message = chat_history.pop(0) -            total_tokens -= self.count_tokens(message.content) - -        return chat_history +    async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]: +        args = self.default_args | kwargs +        args["stream"] = True -    def with_system_message(self, system_message: Union[str, None]): -        return OpenAI(api_key=self.api_key, default_model=self.default_model, system_message=system_message) +        if args["model"] in CHAT_MODELS: +            async for chunk in await openai.ChatCompletion.acreate( +                messages=compile_chat_messages( +                    args["model"], with_history, prompt, with_functions=False), +                **args, +            ): +                if "content" in chunk.choices[0].delta: +                    yield chunk.choices[0].delta.content +                else: +                    continue +        else: +            async for chunk in await openai.Completion.acreate(prompt=prompt, **args): +                yield chunk.choices[0].text -    async def stream_chat(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]: -        self.completion_count += 1 -        args = {"max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.5, "top_p": 1, -                "frequency_penalty": 0, "presence_penalty": 0} | kwargs +    async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]: +        args = self.default_args | kwargs          args["stream"] = True -        args["model"] = self.default_model +        args["model"] = self.default_model if self.default_model in CHAT_MODELS else "gpt-3.5-turbo-0613" +        if not args["model"].endswith("0613") and "functions" in args: +            del args["functions"]          async for chunk in await openai.ChatCompletion.acreate( -            messages=self.compile_chat_messages(with_history, prompt), +            messages=compile_chat_messages( +                args["model"], messages, with_functions=args["model"].endswith("0613")),              **args,          ): -            if "content" in chunk.choices[0].delta: -                yield chunk.choices[0].delta.content -            else: -                continue - -    def compile_chat_messages(self, msgs: List[ChatMessage], prompt: str) -> List[Dict]: -        msgs = self.__prune_chat_history(msgs, MAX_TOKENS_FOR_MODEL[self.default_model], self.count_tokens( -            prompt) + 1000 + self.count_tokens(self.system_message or "")) -        history = [] -        if self.system_message: -            history.append({ -                "role": "system", -                "content": self.system_message -            }) -        history += [{"role": msg.role, "content": msg.content} for msg in msgs] -        history.append({ -            "role": "user", -            "content": prompt -        }) - -        return history - -    def stream_complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]: -        self.completion_count += 1 -        args = {"model": self.default_model, "max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.5, -                "top_p": 1, "frequency_penalty": 0, "presence_penalty": 0, "suffix": None} | kwargs -        args["stream"] = True - -        if args["model"] in CHAT_MODELS: -            generator = openai.ChatCompletion.create( -                messages=self.compile_chat_messages(with_history, prompt), -                **args, -            ) -            for chunk in generator: -                yield chunk.choices[0].message.content -        else: -            generator = openai.Completion.create( -                prompt=prompt, -                **args, -            ) -            for chunk in generator: -                yield chunk.choices[0].text +            yield chunk.choices[0].delta      async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]: -        t1 = time.time() - -        self.completion_count += 1 -        args = {"model": self.default_model, "max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.5, "top_p": 1, -                "frequency_penalty": 0, "presence_penalty": 0, "stream": False} | kwargs +        args = self.default_args | kwargs          if args["model"] in CHAT_MODELS:              resp = (await openai.ChatCompletion.acreate( -                messages=self.compile_chat_messages(with_history, prompt), +                messages=compile_chat_messages( +                    args["model"], with_history, prompt, with_functions=False),                  **args,              )).choices[0].message.content          else: @@ -153,73 +77,4 @@ class OpenAI(LLM):                  **args,              )).choices[0].text -        t2 = time.time() -        print("Completion time:", t2 - t1)          return resp - -    def edit(self, inp: str, instruction: str) -> str: -        try: -            resp = openai.Edit.create( -                input=inp, -                instruction=instruction, -                model='text-davinci-edit-001' -            ).choices[0].text -            return resp -        except Exception as e: -            print("OpenAI error:", e) -            raise e - -    def parallel_edit(self, inputs: list[str], instructions: Union[List[str], str], **kwargs) -> list[str]: -        args = {"temperature": 0.5, "top_p": 1} | kwargs -        args['model'] = 'text-davinci-edit-001' - -        async def fn(): -            async with aiohttp.ClientSession() as session: -                tasks = [] - -                async def get(input, instruction): -                    async with session.post("https://api.openai.com/v1/edits", headers={ -                        "Content-Type": "application/json", -                        "Authorization": "Bearer " + self.api_key -                    }, json={"model": args["model"], "input": input, "instruction": instruction, "temperature": args["temperature"], "max_tokens": args["max_tokens"], "suffix": args["suffix"]}) as resp: -                        json = await resp.json() -                        if "error" in json: -                            print("ERROR IN GPT-3 RESPONSE: ", json) -                            return None -                        return json["choices"][0]["text"] - -                for i in range(len(inputs)): -                    tasks.append(get(inputs[i], instructions[i] if isinstance( -                        instructions, list) else instructions)) - -                return await asyncio.gather(*tasks) - -        return asyncio.run(fn()) - -    def parallel_complete(self, prompts: list[str], suffixes: Union[list[str], None] = None, **kwargs) -> list[str]: -        self.completion_count += len(prompts) -        args = {"model": self.default_model, "max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.5, -                "top_p": 1, "frequency_penalty": 0, "presence_penalty": 0} | kwargs - -        async def fn(): -            async with aiohttp.ClientSession() as session: -                tasks = [] - -                async def get(prompt, suffix): -                    async with session.post("https://api.openai.com/v1/completions", headers={ -                        "Content-Type": "application/json", -                        "Authorization": "Bearer " + self.api_key -                    }, json={"model": args["model"], "prompt": prompt, "temperature": args["temperature"], "max_tokens": args["max_tokens"], "suffix": suffix}) as resp: -                        json = await resp.json() -                        if "error" in json: -                            print("ERROR IN GPT-3 RESPONSE: ", json) -                            return None -                        return json["choices"][0]["text"] - -                for i in range(len(prompts)): -                    tasks.append(asyncio.ensure_future( -                        get(prompts[i], suffixes[i] if suffixes else None))) - -                return await asyncio.gather(*tasks) - -        return asyncio.run(fn()) diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py index a29f5c89..6b0336a7 100644 --- a/continuedev/src/continuedev/libs/llm/proxy_server.py +++ b/continuedev/src/continuedev/libs/llm/proxy_server.py @@ -1,24 +1,13 @@  from functools import cached_property  import json  from typing import Any, Coroutine, Dict, Generator, List, Literal, Union -import requests -import tiktoken  import aiohttp -  from ...core.main import ChatMessage  from ..llm import LLM +from ..util.count_tokens import DEFAULT_ARGS, DEFAULT_MAX_TOKENS, compile_chat_messages, CHAT_MODELS, count_tokens -MAX_TOKENS_FOR_MODEL = { -    "gpt-3.5-turbo": 4097, -    "gpt-4": 4097, -} -DEFAULT_MAX_TOKENS = 2048 -CHAT_MODELS = { -    "gpt-3.5-turbo", "gpt-4" -} - -# SERVER_URL = "http://127.0.0.1:8080" -SERVER_URL = "https://proxy-server-l6vsfbzhba-uw.a.run.app" +SERVER_URL = "http://127.0.0.1:8080" +# SERVER_URL = "https://proxy-server-l6vsfbzhba-uw.a.run.app"  class ProxyServer(LLM): @@ -29,67 +18,66 @@ class ProxyServer(LLM):      def __init__(self, unique_id: str, default_model: Literal["gpt-3.5-turbo", "gpt-4"], system_message: str = None):          self.unique_id = unique_id          self.default_model = default_model -        self.system_message = system_message +        self.system_message = (system_message or "") + \ +            "\nDo not ever call the 'python' function."          self.name = default_model -    @cached_property -    def __encoding_for_model(self): -        aliases = { -            "gpt-3.5-turbo": "gpt3" -        } -        return tiktoken.encoding_for_model(self.default_model) +    @property +    def default_args(self): +        return DEFAULT_ARGS | {"model": self.default_model}      def count_tokens(self, text: str): -        return len(self.__encoding_for_model.encode(text, disallowed_special=())) - -    def __prune_chat_history(self, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int): -        tokens = tokens_for_completion -        for i in range(len(chat_history) - 1, -1, -1): -            message = chat_history[i] -            tokens += self.count_tokens(message.content) -            if tokens > max_tokens: -                return chat_history[i + 1:] -        return chat_history - -    def compile_chat_messages(self, msgs: List[ChatMessage], prompt: str) -> List[Dict]: -        msgs = self.__prune_chat_history(msgs, MAX_TOKENS_FOR_MODEL[self.default_model], self.count_tokens( -            prompt) + 1000 + self.count_tokens(self.system_message or "")) -        history = [] -        if self.system_message: -            history.append({ -                "role": "system", -                "content": self.system_message -            }) -        history += [{"role": msg.role, "content": msg.content} for msg in msgs] -        history.append({ -            "role": "user", -            "content": prompt -        }) - -        return history +        return count_tokens(self.default_model, text)      async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]: +        args = self.default_args | kwargs +          async with aiohttp.ClientSession() as session:              async with session.post(f"{SERVER_URL}/complete", json={ -                "chat_history": self.compile_chat_messages(with_history, prompt), -                "model": self.default_model, +                "messages": compile_chat_messages(args["model"], with_history, prompt, with_functions=False),                  "unique_id": self.unique_id, +                **args              }) as resp:                  try:                      return json.loads(await resp.text()) -                except json.JSONDecodeError: +                except:                      raise Exception(await resp.text()) -    async def stream_chat(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]: +    async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, Generator[Any | List | Dict, None, None]]: +        args = self.default_args | kwargs +        messages = compile_chat_messages( +            self.default_model, messages, None, with_functions=args["model"].endswith("0613")) + +        async with aiohttp.ClientSession() as session: +            async with session.post(f"{SERVER_URL}/stream_chat", json={ +                "messages": messages, +                "unique_id": self.unique_id, +                **args +            }) as resp: +                # This is streaming application/json instaed of text/event-stream +                async for line in resp.content.iter_chunks(): +                    if line[1]: +                        try: +                            json_chunk = line[0].decode("utf-8") +                            json_chunk = "{}" if json_chunk == "" else json_chunk +                            yield json.loads(json_chunk) +                        except: +                            raise Exception(str(line[0])) + +    async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]: +        args = self.default_args | kwargs +        messages = compile_chat_messages( +            self.default_model, with_history, prompt, with_functions=args["model"].endswith("0613")) +          async with aiohttp.ClientSession() as session:              async with session.post(f"{SERVER_URL}/stream_complete", json={ -                "chat_history": self.compile_chat_messages(with_history, prompt), -                "model": self.default_model, +                "messages": messages,                  "unique_id": self.unique_id, +                **args              }) as resp:                  async for line in resp.content.iter_any():                      if line:                          try:                              yield line.decode("utf-8") -                        except json.JSONDecodeError: +                        except:                              raise Exception(str(line)) diff --git a/continuedev/src/continuedev/libs/util/count_tokens.py b/continuedev/src/continuedev/libs/util/count_tokens.py new file mode 100644 index 00000000..6038b68d --- /dev/null +++ b/continuedev/src/continuedev/libs/util/count_tokens.py @@ -0,0 +1,84 @@ +from typing import Dict, List, Union +from ...core.main import ChatMessage +import tiktoken + +aliases = {} +DEFAULT_MAX_TOKENS = 2048 +MAX_TOKENS_FOR_MODEL = { +    "gpt-3.5-turbo": 4096 - DEFAULT_MAX_TOKENS, +    "gpt-3.5-turbo-0613": 4096 - DEFAULT_MAX_TOKENS, +    "gpt-3.5-turbo-16k": 16384 - DEFAULT_MAX_TOKENS, +    "gpt-4": 8192 - DEFAULT_MAX_TOKENS +} +CHAT_MODELS = { +    "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-3.5-turbo-0613" +} +DEFAULT_ARGS = {"max_tokens": DEFAULT_MAX_TOKENS, "temperature": 0.5, "top_p": 1, +                "frequency_penalty": 0, "presence_penalty": 0} + + +def encoding_for_model(model: str): +    return tiktoken.encoding_for_model(aliases.get(model, model)) + + +def count_tokens(model: str, text: str | None): +    if text is None: +        return 0 +    encoding = encoding_for_model(model) +    return len(encoding.encode(text, disallowed_special=())) + + +def prune_chat_history(model: str, chat_history: List[ChatMessage], max_tokens: int, tokens_for_completion: int): +    total_tokens = tokens_for_completion + \ +        sum(count_tokens(model, message.content) +            for message in chat_history) + +    # 1. Replace beyond last 5 messages with summary +    i = 0 +    while total_tokens > max_tokens and i < len(chat_history) - 5: +        message = chat_history[0] +        total_tokens -= count_tokens(model, message.content) +        total_tokens += count_tokens(model, message.summary) +        message.content = message.summary +        i += 1 + +    # 2. Remove entire messages until the last 5 +    while len(chat_history) > 5 and total_tokens > max_tokens: +        message = chat_history.pop(0) +        total_tokens -= count_tokens(model, message.content) + +    # 3. Truncate message in the last 5 +    i = 0 +    while total_tokens > max_tokens: +        message = chat_history[0] +        total_tokens -= count_tokens(model, message.content) +        total_tokens += count_tokens(model, message.summary) +        message.content = message.summary +        i += 1 + +    # 4. Remove entire messages in the last 5 +    while total_tokens > max_tokens and len(chat_history) > 0: +        message = chat_history.pop(0) +        total_tokens -= count_tokens(model, message.content) + +    return chat_history + + +def compile_chat_messages(model: str, msgs: List[ChatMessage], prompt: str | None = None, with_functions: bool = False, system_message: Union[str, None] = None) -> List[Dict]: +    prompt_tokens = count_tokens(model, prompt) +    msgs = prune_chat_history(model, +                              msgs, MAX_TOKENS_FOR_MODEL[model], prompt_tokens + 1000 + count_tokens(model, system_message)) +    history = [] +    if system_message: +        history.append({ +            "role": "system", +            "content": system_message +        }) +    history += [msg.to_dict(with_functions=with_functions) for msg in msgs] +    if prompt: +        history.append({ +            "role": "user", +            "content": prompt +        }) + +    return history diff --git a/continuedev/src/continuedev/steps/chat.py b/continuedev/src/continuedev/steps/chat.py index fd7457d9..a940c3ba 100644 --- a/continuedev/src/continuedev/steps/chat.py +++ b/continuedev/src/continuedev/steps/chat.py @@ -1,8 +1,19 @@ -from textwrap import dedent -from typing import List -from ..core.main import Step -from ..core.sdk import ContinueSDK +import json +from typing import Any, Coroutine, List + +from .main import EditHighlightedCodeStep  from .core.core import MessageStep +from ..core.main import FunctionCall, Models +from ..core.main import ChatMessage, Step, step_to_json_schema +from ..core.sdk import ContinueSDK +import openai +import os +from dotenv import load_dotenv +from directory_tree import display_tree + +load_dotenv() +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +openai.api_key = OPENAI_API_KEY  class SimpleChatStep(Step): @@ -13,9 +24,202 @@ class SimpleChatStep(Step):          self.description = f"```{self.user_input}```\n\n"          await sdk.update_ui() -        async for chunk in sdk.models.default.stream_chat(self.user_input, with_history=await sdk.get_chat_context()): +        async for chunk in sdk.models.default.stream_complete(self.user_input, with_history=await sdk.get_chat_context()):              self.description += chunk              await sdk.update_ui()          self.name = (await sdk.models.gpt35.complete(              f"Write a short title for the following chat message: {self.description}")).strip() + + +class AddFileStep(Step): +    name: str = "Add File" +    description = "Add a file to the workspace." +    filename: str +    file_contents: str + +    async def describe(self, models: Models) -> Coroutine[Any, Any, Coroutine[str, None, None]]: +        return f"Added a file named `{self.filename}` to the workspace." + +    async def run(self, sdk: ContinueSDK): +        try: +            await sdk.add_file(self.filename, self.file_contents) +        except FileNotFoundError: +            self.description = f"File {self.filename} does not exist." +            return +        currently_open_file = (await sdk.ide.getOpenFiles())[0] +        await sdk.ide.setFileOpen(os.path.join(sdk.ide.workspace_directory, self.filename)) +        await sdk.ide.setFileOpen(currently_open_file) + + +class DeleteFileStep(Step): +    name: str = "Delete File" +    description = "Delete a file from the workspace." +    filename: str + +    async def describe(self, models: Models) -> Coroutine[Any, Any, Coroutine[str, None, None]]: +        return f"Deleted a file named `{self.filename}` from the workspace." + +    async def run(self, sdk: ContinueSDK): +        await sdk.delete_file(self.filename) + + +class AddDirectoryStep(Step): +    name: str = "Add Directory" +    description = "Add a directory to the workspace." +    directory_name: str + +    async def describe(self, models: Models) -> Coroutine[Any, Any, Coroutine[str, None, None]]: +        return f"Added a directory named `{self.directory_name}` to the workspace." + +    async def run(self, sdk: ContinueSDK): +        try: +            await sdk.add_directory(self.directory_name) +        except FileExistsError: +            self.description = f"Directory {self.directory_name} already exists." + + +class RunTerminalCommandStep(Step): +    name: str = "Run Terminal Command" +    description: str = "Run a terminal command." +    command: str + +    async def describe(self, models: Models) -> Coroutine[Any, Any, Coroutine[str, None, None]]: +        return f"Ran the terminal command `{self.command}`." + +    async def run(self, sdk: ContinueSDK): +        await sdk.wait_for_user_confirmation(f"Run the following terminal command?\n\n```bash\n{self.command}\n```") +        await sdk.run(self.command) + + +class ViewDirectoryTreeStep(Step): +    name: str = "View Directory Tree" +    description: str = "View the directory tree to learn which folder and files exist." + +    async def describe(self, models: Models) -> Coroutine[Any, Any, Coroutine[str, None, None]]: +        return f"Viewed the directory tree." + +    async def run(self, sdk: ContinueSDK): +        self.description = f"```\n{display_tree(sdk.ide.workspace_directory, True)}\n```" + + +class EditFileStep(Step): +    name: str = "Edit File" +    description: str = "Edit a file in the workspace that is not currently open." +    filename: str +    instructions: str +    hide: bool = True + +    async def run(self, sdk: ContinueSDK): +        await sdk.edit_file(self.filename, self.instructions) + + +class ChatWithFunctions(Step): +    user_input: str +    functions: List[Step] = [AddFileStep(filename="", file_contents=""), +                             EditFileStep(filename="", instructions=""), +                             EditHighlightedCodeStep(user_input=""), +                             ViewDirectoryTreeStep(), AddDirectoryStep(directory_name=""), +                             DeleteFileStep(filename=""), RunTerminalCommandStep(command="")] +    name: str = "Chat" +    manage_own_chat_context: bool = True + +    async def run(self, sdk: ContinueSDK): +        self.description = f"```{self.user_input}```\n\nDeciding next steps...\n\n" +        await sdk.update_ui() + +        step_name_step_class_map = { +            step.name.replace(" ", ""): step.__class__ for step in self.functions} + +        functions = [step_to_json_schema( +            function) for function in self.functions] + +        self.chat_context.append(ChatMessage( +            role="user", +            content=self.user_input + "\n**DO NOT EVER call the 'python' function.**", +            summary=self.user_input +        )) + +        last_function_called_index_in_history = None +        while True: +            was_function_called = False +            func_args = "" +            func_name = "" +            msg_content = "" +            msg_step = None + +            async for msg_chunk in sdk.models.gpt350613.stream_chat(await sdk.get_chat_context(), functions=functions): +                if "content" in msg_chunk and msg_chunk["content"] is not None: +                    msg_content += msg_chunk["content"] +                    # if last_function_called_index_in_history is not None: +                    #     while sdk.history.timeline[last_function_called_index].step.hide: +                    #         last_function_called_index += 1 +                    #     sdk.history.timeline[last_function_called_index_in_history].step.description = msg_content +                    if msg_step is None: +                        msg_step = MessageStep( +                            name="Chat", +                            message=msg_chunk["content"] +                        ) +                        await sdk.run_step(msg_step) +                    else: +                        msg_step.description = msg_content +                    await sdk.update_ui() +                elif "function_call" in msg_chunk or func_name != "": +                    was_function_called = True +                    if "function_call" in msg_chunk: +                        if "arguments" in msg_chunk["function_call"]: +                            func_args += msg_chunk["function_call"]["arguments"] +                        if "name" in msg_chunk["function_call"]: +                            func_name += msg_chunk["function_call"]["name"] + +            if not was_function_called: +                self.chat_context.append(ChatMessage( +                    role="assistant", +                    content=msg_content, +                    summary=msg_content +                )) +                break +            else: +                if func_name == "python" and "python" not in step_name_step_class_map: +                    # GPT must be fine-tuned to believe this exists, but it doesn't always +                    self.chat_context.append(ChatMessage( +                        role="assistant", +                        content=None, +                        function_call=FunctionCall( +                            name=func_name, +                            arguments=func_args +                        ), +                        summary=f"Ran function {func_name}" +                    )) +                    self.chat_context.append(ChatMessage( +                        role="user", +                        content="The 'python' function does not exist. Don't call it.", +                        summary="'python' function does not exist." +                    )) +                    continue +                # Call the function, then continue to chat +                func_args = "{}" if func_args == "" else func_args +                fn_call_params = json.loads(func_args) +                self.chat_context.append(ChatMessage( +                    role="assistant", +                    content=None, +                    function_call=FunctionCall( +                        name=func_name, +                        arguments=func_args +                    ), +                    summary=f"Ran function {func_name}" +                )) +                last_function_called_index_in_history = sdk.history.current_index + 1 +                step_to_run = step_name_step_class_map[func_name]( +                    **fn_call_params) + +                if func_name == "AddFileStep": +                    step_to_run.hide = True +                    self.description += f"\nAdded file `{func_args['filename']}`" +                elif func_name == "AddDirectoryStep": +                    step_to_run.hide = True +                    self.description += f"\nAdded directory `{func_args['directory_name']}`" +                else: +                    self.description += f"\n`Running function {func_name}`\n\n" +                await sdk.run_step(step_to_run) +                await sdk.update_ui() diff --git a/continuedev/src/continuedev/steps/core/core.py b/continuedev/src/continuedev/steps/core/core.py index 71a5b5b2..f146c94a 100644 --- a/continuedev/src/continuedev/steps/core/core.py +++ b/continuedev/src/continuedev/steps/core/core.py @@ -10,7 +10,7 @@ from ...models.filesystem_edit import EditDiff, FileEdit, FileEditWithFullConten  from ...models.filesystem import FileSystem, RangeInFile, RangeInFileWithContents  from ...core.observation import Observation, TextObservation, TracebackObservation, UserInputObservation  from ...core.main import Step, SequentialStep -from ...libs.llm.openai import MAX_TOKENS_FOR_MODEL +from ...libs.util.count_tokens import MAX_TOKENS_FOR_MODEL  import difflib @@ -160,7 +160,7 @@ class DefaultModelEditCodeStep(Step):          return f"`{self.user_input}`\n\n" + description      async def run(self, sdk: ContinueSDK) -> Coroutine[Observation, None, None]: -        self.description = f"`{self.user_input}`" +        self.description = f"{self.user_input}"          await sdk.update_ui()          rif_with_contents = [] @@ -213,7 +213,8 @@ class DefaultModelEditCodeStep(Step):              if model_to_use.name == "gpt-4": -                total_tokens = model_to_use.count_tokens(full_file_contents + self._prompt) +                total_tokens = model_to_use.count_tokens( +                    full_file_contents + self._prompt)                  cur_start_line, cur_end_line = cut_context(                      model_to_use, total_tokens, cur_start_line, cur_end_line) @@ -316,7 +317,7 @@ class DefaultModelEditCodeStep(Step):              lines_of_prefix_copied = 0              line_below_highlighted_range = segs[1].lstrip().split("\n")[0]              should_stop = False -            async for chunk in model_to_use.stream_chat(prompt, with_history=await sdk.get_chat_context(), temperature=0): +            async for chunk in model_to_use.stream_complete(prompt, with_history=await sdk.get_chat_context(), temperature=0):                  if should_stop:                      break                  chunk_lines = chunk.split("\n") diff --git a/continuedev/src/continuedev/steps/main.py b/continuedev/src/continuedev/steps/main.py index 5ba86c53..5caac180 100644 --- a/continuedev/src/continuedev/steps/main.py +++ b/continuedev/src/continuedev/steps/main.py @@ -1,7 +1,7 @@  import os  from typing import Coroutine, List, Union -from pydantic import BaseModel +from pydantic import BaseModel, Field  from ..libs.llm import LLM  from ..models.main import Traceback, Range @@ -246,8 +246,10 @@ class StarCoderEditHighlightedCodeStep(Step):  class EditHighlightedCodeStep(Step): -    user_input: str +    user_input: str = Field( +        ..., title="User Input", description="The natural language request describing how to edit the code")      hide = True +    description: str = "Change the contents of the currently highlighted code or open file"      async def describe(self, models: Models) -> Coroutine[str, None, None]:          return "Editing code" diff --git a/continuedev/src/continuedev/steps/on_traceback.py b/continuedev/src/continuedev/steps/on_traceback.py index 053b4ef4..3f8c5a76 100644 --- a/continuedev/src/continuedev/steps/on_traceback.py +++ b/continuedev/src/continuedev/steps/on_traceback.py @@ -1,5 +1,5 @@  import os -from ..core.main import Step +from ..core.main import ChatMessage, Step  from ..core.sdk import ContinueSDK  from .chat import SimpleChatStep @@ -16,7 +16,11 @@ class DefaultOnTracebackStep(Step):              for seg in segs:                  if seg.startswith(os.path.sep) and os.path.exists(seg) and os.path.commonprefix([seg, sdk.ide.workspace_directory]) == sdk.ide.workspace_directory:                      file_contents = await sdk.ide.readFile(seg) -                    await sdk.add_chat_context(f"The contents of {seg}:\n```\n{file_contents}\n```", "", "user") +                    self.chat_context.append(ChatMessage( +                        role="user", +                        content=f"The contents of {seg}:\n```\n{file_contents}\n```", +                        summary="" +                    ))          await sdk.run_step(SimpleChatStep(              name="Help With Traceback", diff --git a/extension/react-app/src/components/HeaderButtonWithText.tsx b/extension/react-app/src/components/HeaderButtonWithText.tsx index acaca9ce..5901c5d8 100644 --- a/extension/react-app/src/components/HeaderButtonWithText.tsx +++ b/extension/react-app/src/components/HeaderButtonWithText.tsx @@ -6,12 +6,14 @@ interface HeaderButtonWithTextProps {    text: string;    onClick?: (e: any) => void;    children: React.ReactNode; +  disabled?: boolean;  }  const HeaderButtonWithText = (props: HeaderButtonWithTextProps) => {    const [hover, setHover] = useState(false);    return (      <HeaderButton +      disabled={props.disabled}        style={{ padding: "1px", paddingLeft: hover ? "4px" : "1px" }}        onMouseEnter={() => setHover(true)}        onMouseLeave={() => { diff --git a/extension/react-app/src/components/StepContainer.tsx b/extension/react-app/src/components/StepContainer.tsx index 74a1c4e8..827d2d5f 100644 --- a/extension/react-app/src/components/StepContainer.tsx +++ b/extension/react-app/src/components/StepContainer.tsx @@ -200,6 +200,7 @@ function StepContainer(props: StepContainerProps) {              <>                <HeaderButtonWithText +                disabled={props.historyNode.active as boolean}                  onClick={(e) => {                    e.stopPropagation();                    props.onDelete(); diff --git a/extension/react-app/src/tabs/gui.tsx b/extension/react-app/src/tabs/gui.tsx index a457382d..5001fe4b 100644 --- a/extension/react-app/src/tabs/gui.tsx +++ b/extension/react-app/src/tabs/gui.tsx @@ -391,7 +391,7 @@ function GUI(props: GUIProps) {              />            );          })} -        {/* {waitingForSteps && <Loader></Loader>} */} +        {waitingForSteps && <Loader></Loader>}          <div>            {userInputQueue.map((input) => { @@ -491,6 +491,8 @@ function GUI(props: GUIProps) {          <HeaderButtonWithText            onClick={() => {              client?.sendClear(); +            // Reload the window to get completely fresh session +            window.location.reload();            }}            text="Clear All"          > | 
