diff options
| -rw-r--r-- | continuedev/src/continuedev/core/autopilot.py | 8 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/hugging_face.py | 11 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/openai.py | 10 | ||||
| -rw-r--r-- | continuedev/src/continuedev/libs/llm/proxy_server.py | 8 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/gui.py | 10 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/ide.py | 6 | ||||
| -rw-r--r-- | continuedev/src/continuedev/server/ide_protocol.py | 4 | ||||
| -rw-r--r-- | continuedev/src/continuedev/steps/search_directory.py | 4 | ||||
| -rw-r--r-- | extension/react-app/src/components/ComboBox.tsx | 57 | ||||
| -rw-r--r-- | extension/react-app/src/components/HeaderButtonWithText.tsx | 2 | ||||
| -rw-r--r-- | extension/react-app/src/components/PillButton.tsx | 9 | ||||
| -rw-r--r-- | extension/react-app/src/components/index.ts | 13 | ||||
| -rw-r--r-- | extension/react-app/src/hooks/ContinueGUIClientProtocol.ts | 2 | ||||
| -rw-r--r-- | extension/react-app/src/hooks/useContinueGUIProtocol.ts | 4 | ||||
| -rw-r--r-- | extension/react-app/src/tabs/gui.tsx | 23 | ||||
| -rw-r--r-- | extension/src/continueIdeClient.ts | 39 | 
16 files changed, 164 insertions, 46 deletions
| diff --git a/continuedev/src/continuedev/core/autopilot.py b/continuedev/src/continuedev/core/autopilot.py index 5193a02b..05e48f40 100644 --- a/continuedev/src/continuedev/core/autopilot.py +++ b/continuedev/src/continuedev/core/autopilot.py @@ -173,8 +173,12 @@ class Autopilot(ContinueBaseModel):          self.history.timeline[index].deleted = True          await self.update_subscribers() -    async def delete_context_item_at_index(self, index: int): -        self._highlighted_ranges.pop(index) +    async def delete_context_at_indices(self, indices: List[int]): +        kept_ranges = [] +        for i, rif in enumerate(self._highlighted_ranges): +            if i not in indices: +                kept_ranges.append(rif) +        self._highlighted_ranges = kept_ranges          await self.update_subscribers()      async def _run_singular_step(self, step: "Step", is_future_step: bool = False) -> Coroutine[Observation, None, None]: diff --git a/continuedev/src/continuedev/libs/llm/hugging_face.py b/continuedev/src/continuedev/libs/llm/hugging_face.py index 868cb560..b0db585b 100644 --- a/continuedev/src/continuedev/libs/llm/hugging_face.py +++ b/continuedev/src/continuedev/libs/llm/hugging_face.py @@ -1,14 +1,17 @@  from .llm import LLM  from transformers import AutoTokenizer, AutoModelForCausalLM +  class HuggingFace(LLM):      def __init__(self, model_path: str = "Salesforce/codegen-2B-mono"):          self.model_path = model_path          self.tokenizer = AutoTokenizer.from_pretrained(model_path)          self.model = AutoModelForCausalLM.from_pretrained(model_path) -     +      def complete(self, prompt: str, **kwargs): -        args = { "max_tokens": 100 } | kwargs +        args = {"max_tokens": 100} +        args.update(kwargs)          input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids -        generated_ids = self.model.generate(input_ids, max_length=args["max_tokens"]) -        return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
\ No newline at end of file +        generated_ids = self.model.generate( +            input_ids, max_length=args["max_tokens"]) +        return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True) diff --git a/continuedev/src/continuedev/libs/llm/openai.py b/continuedev/src/continuedev/libs/llm/openai.py index a3ca5c80..c4e4139f 100644 --- a/continuedev/src/continuedev/libs/llm/openai.py +++ b/continuedev/src/continuedev/libs/llm/openai.py @@ -24,13 +24,14 @@ class OpenAI(LLM):      @property      def default_args(self): -        return DEFAULT_ARGS | {"model": self.default_model} +        return {**DEFAULT_ARGS, "model": self.default_model}      def count_tokens(self, text: str):          return count_tokens(self.default_model, text)      async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]: -        args = self.default_args | kwargs +        args = self.default_args.copy() +        args.update(kwargs)          args["stream"] = True          if args["model"] in CHAT_MODELS: @@ -48,7 +49,8 @@ class OpenAI(LLM):                  yield chunk.choices[0].text      async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]: -        args = self.default_args | kwargs +        args = self.default_args.copy() +        args.update(kwargs)          args["stream"] = True          args["model"] = self.default_model if self.default_model in CHAT_MODELS else "gpt-3.5-turbo-0613"          if not args["model"].endswith("0613") and "functions" in args: @@ -62,7 +64,7 @@ class OpenAI(LLM):              yield chunk.choices[0].delta      async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]: -        args = self.default_args | kwargs +        args = {**self.default_args, **kwargs}          if args["model"] in CHAT_MODELS:              resp = (await openai.ChatCompletion.acreate( diff --git a/continuedev/src/continuedev/libs/llm/proxy_server.py b/continuedev/src/continuedev/libs/llm/proxy_server.py index 69c96ee8..05ece394 100644 --- a/continuedev/src/continuedev/libs/llm/proxy_server.py +++ b/continuedev/src/continuedev/libs/llm/proxy_server.py @@ -28,13 +28,13 @@ class ProxyServer(LLM):      @property      def default_args(self): -        return DEFAULT_ARGS | {"model": self.default_model} +        return {**DEFAULT_ARGS, "model": self.default_model}      def count_tokens(self, text: str):          return count_tokens(self.default_model, text)      async def complete(self, prompt: str, with_history: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, str]: -        args = self.default_args | kwargs +        args = {**self.default_args, **kwargs}          async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(ssl_context=ssl_context)) as session:              async with session.post(f"{SERVER_URL}/complete", json={ @@ -48,7 +48,7 @@ class ProxyServer(LLM):                      raise Exception(await resp.text())      async def stream_chat(self, messages: List[ChatMessage] = [], **kwargs) -> Coroutine[Any, Any, Generator[Union[Any, List, Dict], None, None]]: -        args = self.default_args | kwargs +        args = {**self.default_args, **kwargs}          messages = compile_chat_messages(              self.default_model, messages, None, functions=args.get("functions", None)) @@ -72,7 +72,7 @@ class ProxyServer(LLM):                              raise Exception(str(line[0]))      async def stream_complete(self, prompt, with_history: List[ChatMessage] = [], **kwargs) -> Generator[Union[Any, List, Dict], None, None]: -        args = self.default_args | kwargs +        args = {**self.default_args, **kwargs}          messages = compile_chat_messages(              self.default_model, with_history, prompt, functions=args.get("functions", None)) diff --git a/continuedev/src/continuedev/server/gui.py b/continuedev/src/continuedev/server/gui.py index b2f23bac..4e960f7c 100644 --- a/continuedev/src/continuedev/server/gui.py +++ b/continuedev/src/continuedev/server/gui.py @@ -1,6 +1,6 @@  import json  from fastapi import Depends, Header, WebSocket, APIRouter -from typing import Any, Type, TypeVar, Union +from typing import Any, List, Type, TypeVar, Union  from pydantic import BaseModel  from uvicorn.main import Server @@ -83,8 +83,8 @@ class GUIProtocolServer(AbstractGUIProtocolServer):                  self.on_clear_history()              elif message_type == "delete_at_index":                  self.on_delete_at_index(data["index"]) -            elif message_type == "delete_context_item_at_index": -                self.on_delete_context_item_at_index(data["index"]) +            elif message_type == "delete_context_at_indices": +                self.on_delete_context_at_indices(data["indices"])          except Exception as e:              print(e) @@ -123,9 +123,9 @@ class GUIProtocolServer(AbstractGUIProtocolServer):      def on_delete_at_index(self, index: int):          asyncio.create_task(self.session.autopilot.delete_at_index(index)) -    def on_delete_context_item_at_index(self, index: int): +    def on_delete_context_at_indices(self, indices: List[int]):          asyncio.create_task( -            self.session.autopilot.delete_context_item_at_index(index) +            self.session.autopilot.delete_context_at_indices(indices)          ) diff --git a/continuedev/src/continuedev/server/ide.py b/continuedev/src/continuedev/server/ide.py index e2685493..ea355d3c 100644 --- a/continuedev/src/continuedev/server/ide.py +++ b/continuedev/src/continuedev/server/ide.py @@ -160,6 +160,12 @@ class IdeProtocolServer(AbstractIdeProtocolServer):              "edit": file_edit.dict()          }) +    async def showDiff(self, filepath: str, replacement: str): +        await self._send_json("showDiff", { +            "filepath": filepath, +            "replacement": replacement +        }) +      async def setFileOpen(self, filepath: str, open: bool = True):          # Autopilot needs access to this.          await self._send_json("setFileOpen", { diff --git a/continuedev/src/continuedev/server/ide_protocol.py b/continuedev/src/continuedev/server/ide_protocol.py index de2eea27..2e1f78d7 100644 --- a/continuedev/src/continuedev/server/ide_protocol.py +++ b/continuedev/src/continuedev/server/ide_protocol.py @@ -95,6 +95,10 @@ class AbstractIdeProtocolServer(ABC):      def onHighlightedCodeUpdate(self, range_in_files: List[RangeInFileWithContents]):          """Called when highlighted code is updated""" +    @abstractmethod +    async def showDiff(self, filepath: str, replacement: str): +        """Show a diff""" +      @abstractproperty      def workspace_directory(self) -> str:          """Get the workspace directory""" diff --git a/continuedev/src/continuedev/steps/search_directory.py b/continuedev/src/continuedev/steps/search_directory.py index d2966f46..2eecc99c 100644 --- a/continuedev/src/continuedev/steps/search_directory.py +++ b/continuedev/src/continuedev/steps/search_directory.py @@ -1,6 +1,6 @@  import asyncio  from textwrap import dedent -from typing import List +from typing import List, Union  from ..models.filesystem import RangeInFile  from ..models.main import Range @@ -54,7 +54,7 @@ class WriteRegexPatternStep(Step):  class EditAllMatchesStep(Step):      pattern: str      user_request: str -    directory: str | None = None +    directory: Union[str, None] = None      async def run(self, sdk: ContinueSDK):          # Search all files for a given string diff --git a/extension/react-app/src/components/ComboBox.tsx b/extension/react-app/src/components/ComboBox.tsx index 742c643b..bdb8850d 100644 --- a/extension/react-app/src/components/ComboBox.tsx +++ b/extension/react-app/src/components/ComboBox.tsx @@ -1,4 +1,4 @@ -import React, { useCallback, useEffect } from "react"; +import React, { useCallback, useEffect, useState } from "react";  import { useCombobox } from "downshift";  import styled from "styled-components";  import { @@ -10,7 +10,10 @@ import {  import CodeBlock from "./CodeBlock";  import { RangeInFile } from "../../../src/client";  import PillButton from "./PillButton"; +import HeaderButtonWithText from "./HeaderButtonWithText"; +import { Trash, LockClosed, LockOpen } from "@styled-icons/heroicons-outline"; +// #region styled components  const mainInputFontSize = 16;  const ContextDropdown = styled.div` @@ -87,13 +90,16 @@ const Li = styled.li<{    cursor: pointer;  `; +// #endregion +  interface ComboBoxProps {    items: { name: string; description: string }[];    onInputValueChange: (inputValue: string) => void;    disabled?: boolean; -  onEnter?: (e: React.KeyboardEvent<HTMLInputElement>) => void; -  highlightedCodeSections?: (RangeInFile & { contents: string })[]; -  deleteContextItem?: (idx: number) => void; +  onEnter: (e: React.KeyboardEvent<HTMLInputElement>) => void; +  highlightedCodeSections: (RangeInFile & { contents: string })[]; +  deleteContextItems: (indices: number[]) => void; +  onTogglePin: () => void;  }  const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => { @@ -104,6 +110,7 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => {    const [hoveringButton, setHoveringButton] = React.useState(false);    const [hoveringContextDropdown, setHoveringContextDropdown] =      React.useState(false); +  const [pinned, setPinned] = useState(false);    const [highlightedCodeSections, setHighlightedCodeSections] = React.useState(      props.highlightedCodeSections || [        { @@ -242,12 +249,46 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => {          </Ul>        </div>        <div className="px-2 flex gap-2 items-center flex-wrap"> +        {highlightedCodeSections.length > 0 && ( +          <> +            <HeaderButtonWithText +              text="Clear Context" +              onClick={() => { +                props.deleteContextItems( +                  highlightedCodeSections.map((_, idx) => idx) +                ); +              }} +            > +              <Trash size="1.6em" /> +            </HeaderButtonWithText> +            <HeaderButtonWithText +              text={pinned ? "Unpin Context" : "Pin Context"} +              inverted={pinned} +              onClick={() => { +                setPinned((prev) => !prev); +                props.onTogglePin(); +              }} +            > +              {pinned ? ( +                <LockClosed size="1.6em"></LockClosed> +              ) : ( +                <LockOpen size="1.6em"></LockOpen> +              )} +            </HeaderButtonWithText> +          </> +        )}          {highlightedCodeSections.map((section, idx) => (            <PillButton -            title={section.filepath} +            title={ +              hoveringButton +                ? `${section.filepath} (${section.range.start.line + 1}-${ +                    section.range.end.line + 1 +                  })` +                : section.filepath +            }              onDelete={() => { -              if (props.deleteContextItem) { -                props.deleteContextItem(idx); +              if (props.deleteContextItems) { +                props.deleteContextItems([idx]);                }                setHighlightedCodeSections((prev) => {                  const newSections = [...prev]; @@ -280,7 +321,7 @@ const ComboBox = React.forwardRef((props: ComboBoxProps, ref) => {          onMouseLeave={() => {            setHoveringContextDropdown(false);          }} -        hidden={!hoveringContextDropdown && !hoveringButton} +        hidden={true || (!hoveringContextDropdown && !hoveringButton)}        >          {highlightedCodeSections.map((section, idx) => (            <> diff --git a/extension/react-app/src/components/HeaderButtonWithText.tsx b/extension/react-app/src/components/HeaderButtonWithText.tsx index 30931f86..3ddac93c 100644 --- a/extension/react-app/src/components/HeaderButtonWithText.tsx +++ b/extension/react-app/src/components/HeaderButtonWithText.tsx @@ -7,12 +7,14 @@ interface HeaderButtonWithTextProps {    onClick?: (e: any) => void;    children: React.ReactNode;    disabled?: boolean; +  inverted?: boolean;  }  const HeaderButtonWithText = (props: HeaderButtonWithTextProps) => {    const [hover, setHover] = useState(false);    return (      <HeaderButton +      inverted={props.inverted}        disabled={props.disabled}        style={{ padding: "1px", paddingLeft: hover ? "4px" : "1px" }}        onMouseEnter={() => { diff --git a/extension/react-app/src/components/PillButton.tsx b/extension/react-app/src/components/PillButton.tsx index 33451db5..55fe1ac6 100644 --- a/extension/react-app/src/components/PillButton.tsx +++ b/extension/react-app/src/components/PillButton.tsx @@ -1,6 +1,7 @@  import { useState } from "react";  import styled from "styled-components";  import { defaultBorderRadius } from "."; +import { XMark } from "@styled-icons/heroicons-outline";  const Button = styled.button`    border: none; @@ -42,13 +43,12 @@ const PillButton = (props: PillButtonProps) => {        <div          style={{ display: "grid", gridTemplateColumns: "1fr auto", gap: "4px" }}        > -        <span>{props.title}</span>          <span            style={{              cursor: "pointer",              color: "red", -            borderLeft: "1px solid black", -            paddingLeft: "4px", +            borderRight: "1px solid black", +            paddingRight: "4px",            }}            hidden={!isHovered}            onClick={() => { @@ -56,8 +56,9 @@ const PillButton = (props: PillButtonProps) => {              props.onHover?.(false);            }}          > -          X +          <XMark style={{ padding: "0px" }} size="1.2em" strokeWidth="2px" />          </span> +        <span>{props.title}</span>        </div>      </Button>    ); diff --git a/extension/react-app/src/components/index.ts b/extension/react-app/src/components/index.ts index 429a7df5..db1925ed 100644 --- a/extension/react-app/src/components/index.ts +++ b/extension/react-app/src/components/index.ts @@ -124,16 +124,19 @@ export const appear = keyframes`      }  `; -export const HeaderButton = styled.button` -  background-color: transparent; +export const HeaderButton = styled.button<{ inverted: boolean | undefined }>` +  background-color: ${({ inverted }) => (inverted ? "white" : "transparent")}; +  color: ${({ inverted }) => (inverted ? "black" : "white")}; +    border: 1px solid white;    border-radius: ${defaultBorderRadius};    cursor: pointer; -  color: white;    &:hover { -    background-color: white; -    color: black; +    background-color: ${({ inverted }) => +      typeof inverted === "undefined" || inverted ? "white" : "transparent"}; +    color: ${({ inverted }) => +      typeof inverted === "undefined" || inverted ? "black" : "white"};    }    display: flex;    align-items: center; diff --git a/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts b/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts index 228e9a53..96ea7ab3 100644 --- a/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts +++ b/extension/react-app/src/hooks/ContinueGUIClientProtocol.ts @@ -21,7 +21,7 @@ abstract class AbstractContinueGUIClientProtocol {    abstract deleteAtIndex(index: number): void; -  abstract deleteContextItemAtIndex(index: number): void; +  abstract deleteContextAtIndices(indices: number[]): void;  }  export default AbstractContinueGUIClientProtocol; diff --git a/extension/react-app/src/hooks/useContinueGUIProtocol.ts b/extension/react-app/src/hooks/useContinueGUIProtocol.ts index a0c38c0f..e950387c 100644 --- a/extension/react-app/src/hooks/useContinueGUIProtocol.ts +++ b/extension/react-app/src/hooks/useContinueGUIProtocol.ts @@ -71,8 +71,8 @@ class ContinueGUIClientProtocol extends AbstractContinueGUIClientProtocol {      this.messenger.send("delete_at_index", { index });    } -  deleteContextItemAtIndex(index: number) { -    this.messenger.send("delete_context_item_at_index", { index }); +  deleteContextAtIndices(indices: number[]) { +    this.messenger.send("delete_context_at_indices", { indices });    }  } diff --git a/extension/react-app/src/tabs/gui.tsx b/extension/react-app/src/tabs/gui.tsx index 40256f86..c2ff101a 100644 --- a/extension/react-app/src/tabs/gui.tsx +++ b/extension/react-app/src/tabs/gui.tsx @@ -70,10 +70,13 @@ function GUI(props: GUIProps) {    const [usingFastModel, setUsingFastModel] = useState(false);    const [waitingForSteps, setWaitingForSteps] = useState(false);    const [userInputQueue, setUserInputQueue] = useState<string[]>([]); -  const [highlightedRanges, setHighlightedRanges] = useState([]); +  const [highlightedRanges, setHighlightedRanges] = useState([ +    { filepath: "abc.txt", range: { start: { line: 2 }, end: { line: 4 } } }, +  ]);    const [availableSlashCommands, setAvailableSlashCommands] = useState<      { name: string; description: string }[]    >([]); +  const [pinned, setPinned] = useState(false);    const [showDataSharingInfo, setShowDataSharingInfo] = useState(false);    const [stepsOpen, setStepsOpen] = useState<boolean[]>([      true, @@ -185,9 +188,9 @@ function GUI(props: GUIProps) {    const mainTextInputRef = useRef<HTMLInputElement>(null); -  const deleteContextItem = useCallback( -    (idx: number) => { -      client?.deleteContextItemAtIndex(idx); +  const deleteContextItems = useCallback( +    (indices: number[]) => { +      client?.deleteContextAtIndices(indices);      },      [client]    ); @@ -241,6 +244,13 @@ function GUI(props: GUIProps) {        setUserInputQueue((queue) => {          return [...queue, input];        }); + +      // Delete all context items unless locked +      if (!pinned) { +        client?.deleteContextAtIndices( +          highlightedRanges.map((_, index) => index) +        ); +      }      }    }; @@ -345,7 +355,10 @@ function GUI(props: GUIProps) {            onInputValueChange={() => {}}            items={availableSlashCommands}            highlightedCodeSections={highlightedRanges} -          deleteContextItem={deleteContextItem} +          deleteContextItems={deleteContextItems} +          onTogglePin={() => { +            setPinned((prev: boolean) => !prev); +          }}          />          <ContinueButton onClick={onMainTextInput} />        </TopGUIDiv> diff --git a/extension/src/continueIdeClient.ts b/extension/src/continueIdeClient.ts index 999bca88..c517eb98 100644 --- a/extension/src/continueIdeClient.ts +++ b/extension/src/continueIdeClient.ts @@ -159,6 +159,9 @@ class IdeProtocolClient {        case "showSuggestion":          this.showSuggestion(data.edit);          break; +      case "showDiff": +        this.showDiff(data.filepath, data.replacement); +        break;        case "openGUI":        case "connected":          break; @@ -236,6 +239,42 @@ class IdeProtocolClient {      );    } +  contentProvider: vscode.Disposable | null = null; + +  showDiff(filepath: string, replacement: string) { +    const myProvider = new (class +      implements vscode.TextDocumentContentProvider +    { +      onDidChangeEmitter = new vscode.EventEmitter<vscode.Uri>(); +      onDidChange = this.onDidChangeEmitter.event; +      provideTextDocumentContent = (uri: vscode.Uri) => { +        return replacement; +      }; +    })(); +    this.contentProvider = vscode.workspace.registerTextDocumentContentProvider( +      "continueDiff", +      myProvider +    ); + +    // Call the event fire +    const diffFilename = `continueDiff://${filepath}`; +    myProvider.onDidChangeEmitter.fire(vscode.Uri.parse(diffFilename)); + +    const leftUri = vscode.Uri.file(filepath); +    const rightUri = vscode.Uri.parse(diffFilename); +    const title = "Continue Diff"; +    vscode.commands +      .executeCommand("vscode.diff", leftUri, rightUri, title) +      .then( +        () => { +          console.log("Diff view opened successfully"); +        }, +        (error) => { +          console.error("Error opening diff view:", error); +        } +      ); +  } +    openFile(filepath: string) {      // vscode has a builtin open/get open files      openEditorAndRevealRange(filepath, undefined, vscode.ViewColumn.One); | 
