/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the Chameleon License found in the
 * LICENSE file in the root directory of this source tree.
 */

import { useEffect, useState, useRef } from "react";
import { LexicalComposer } from "@lexical/react/LexicalComposer";
import { ContentEditable } from "@lexical/react/LexicalContentEditable";
import { HistoryPlugin } from "@lexical/react/LexicalHistoryPlugin";
import { RichTextPlugin } from "@lexical/react/LexicalRichTextPlugin";
import { OnChangePlugin } from "@lexical/react/LexicalOnChangePlugin";
import DragDropPaste from "../lexical/DragDropPastePlugin";
import { ImagesPlugin } from "../lexical/ImagesPlugin";
import { ImageNode } from "../lexical/ImageNode";
import { ReplaceContentPlugin } from "../lexical/ReplaceContentPlugin";
import LexicalErrorBoundary from "@lexical/react/LexicalErrorBoundary";
import useWebSocket, { ReadyState } from "react-use-websocket";
import { z } from "zod";
import JsonView from "react18-json-view";
import { InputRange } from "../inputs/InputRange";
import { Config } from "../../Config";
import axios from "axios";
import { useHotkeys } from "react-hotkeys-hook";
import {
  COMPLETE,
  FULL_OUTPUT,
  FrontendMultimodalSequencePair,
  GENERATE_MULTIMODAL,
  IMAGE,
  PARTIAL_OUTPUT,
  QUEUE_STATUS,
  TEXT,
  WSContent,
  WSMultimodalMessage,
  WSOptions,
  ZWSMultimodalMessage,
  mergeTextContent,
  readableWsState,
} from "../../DataTypes";
import { StatusBadge, StatusCategory } from "../output/StatusBadge";
import { SettingsAdjust, Close, Idea } from "@carbon/icons-react";
import { useAdvancedMode } from "../hooks/useAdvancedMode";
import { InputShowHide } from "../inputs/InputShowHide";
import { InputToggle } from "../inputs/InputToggle";
import Markdown from "react-markdown";
import remarkGfm from "remark-gfm";
import { EOT_TOKEN } from "../../DataTypes";
import { ImageResult } from "../output/ImageResult";

/** UI-level state of the generation socket, shown in the "UI" status badge. */
enum GenerationSocketState {
  Generating = "GENERATING",
  UserWriting = "USER_WRITING",
  NotReady = "NOT_READY",
}

/** Badge category used for each generation state. */
const UI_STATUS_BY_STATE: Record<GenerationSocketState, StatusCategory> = {
  [GenerationSocketState.UserWriting]: "success",
  [GenerationSocketState.Generating]: "info",
  [GenerationSocketState.NotReady]: "error",
};

/**
 * Minimal structural view of a serialized editor node: only the fields that
 * `flattenContents` reads.
 */
interface SerializedNodeLike {
  type?: string;
  text?: string;
  src?: string;
  children?: SerializedNodeLike[];
}

const ID_CHARACTERS =
  "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";

/**
 * Build a random alphanumeric identifier.
 *
 * Uses `Math.random`, so the result is not cryptographically secure.
 *
 * @param length Number of characters to generate.
 * @returns A string of `length` characters drawn from `[A-Za-z0-9]`.
 */
function makeid(length: number): string {
  let result = "";
  for (let counter = 0; counter < length; counter += 1) {
    result += ID_CHARACTERS.charAt(
      Math.floor(Math.random() * ID_CHARACTERS.length),
    );
  }
  return result;
}

/**
 * Prepend an arbitrary text prompt to an existing list of contents.
 *
 * @param toPrepend Text to put in front; when empty, `contents` is returned
 *   unchanged (same array instance).
 * @param contents Existing contents; never mutated.
 * @returns A new array starting with the text prompt, or `contents` itself.
 */
export function prependTextPrompt(
  toPrepend: string,
  contents: WSContent[],
): WSContent[] {
  if (toPrepend.length === 0) {
    return contents;
  }
  const promptContent: WSContent = {
    content: toPrepend,
    content_type: TEXT,
  };
  return [promptContent, ...contents];
}

/**
 * Extract a flat list of text and image contents from the editor state.
 *
 * The tree is walked depth-first: each child's own content (if it is a text or
 * image node) is emitted before the contents of its descendants. Nodes of any
 * other type contribute nothing themselves but are still descended into.
 *
 * @param obj Serialized node to walk (typically the editor root); may be
 *   null/undefined.
 * @returns Contents in document order; empty when `obj` has no children.
 */
export function flattenContents(
  obj: SerializedNodeLike | null | undefined,
): WSContent[] {
  const result: WSContent[] = [];
  collectContents(obj, result);
  return result;
}

/** Depth-first accumulator behind `flattenContents`; appends into `into`. */
function collectContents(
  node: SerializedNodeLike | null | undefined,
  into: WSContent[],
): void {
  if (!node || !node.children || node.children.length === 0) return;
  for (const child of node.children) {
    // Only take text and image contents
    if (child.type === "text") {
      into.push({ content: child.text, content_type: TEXT } as WSContent);
    } else if (child.type === "image") {
      into.push({
        // TODO: Convert the src from URL to base64 image
        content: child.src,
        content_type: IMAGE,
      } as WSContent);
    }
    collectContents(child, into);
  }
}

/**
 * Render one content item: text as GitHub-flavoured Markdown, images through
 * `ImageResult`, anything else as an "Unknown content type" paragraph.
 *
 * The signature is compatible with `Array.prototype.map`, whose index is used
 * to build the React key.
 */
export function contentToHtml(content: WSContent, index?: number) {
  if (content.content_type === TEXT) {
    // Plain-text alternative to the Markdown rendering:
    // <code style={{ whiteSpace: "pre-wrap" }} key={`code${index}`}>
    //   {content.content}
    // </code>
    return (
      <Markdown remarkPlugins={[remarkGfm]} key={`t${index}`}>
        {content.content}
      </Markdown>
    );
  } else if (content.content_type === IMAGE) {
    return <ImageResult src={content.content} key={`img${index}`} />;
  } else {
    return <p key={`p${index}`}>Unknown content type</p>;
  }
}

/**
 * Make a content item ready to send over the socket.
 *
 * Text and non-http images are returned as-is. Images referenced by an http(s)
 * URL are downloaded and re-encoded as a data URL.
 *
 * @throws If the download fails, the server answers with a non-2xx status, or
 *   the downloaded blob cannot be read. Previously a read failure left the
 *   promise pending forever and an HTTP error page was encoded as the image.
 */
async function prepareContent(content: WSContent): Promise<WSContent> {
  if (content.content_type === TEXT) {
    return content;
  }
  if (content.content_type !== IMAGE) {
    console.error("Unknown content type");
    return content;
  }
  if (!content.content.startsWith("http")) {
    return content;
  }

  const response = await fetch(content.content);
  if (!response.ok) {
    throw new Error(
      `Failed to fetch image (HTTP ${response.status}): ${content.content}`,
    );
  }
  const blob = await response.blob();

  // FileReader is callback-based, so it has to be wrapped in a promise.
  return new Promise<WSContent>((resolve, reject) => {
    const reader = new FileReader();
    reader.onload = () => {
      const result = reader.result;
      resolve(
        typeof result === "string" ? { ...content, content: result } : content,
      );
    };
    reader.onerror = () => {
      reject(reader.error ?? new Error("Failed to read image data"));
    };
    reader.readAsDataURL(blob);
  });
}

/** Error handler handed to Lexical; errors are only logged. */
function onEditorError(error: Error) {
  console.error(error);
}

// Defined at module scope so the same object is passed on every render.
const EDITOR_CONFIG = {
  namespace: "MyEditor",
  theme: {
    heading: {
      h1: "text-24 text-red-500",
    },
  },
  onError: onEditorError,
  nodes: [ImageNode],
};

/**
 * Hint shown behind the empty editor.
 *
 * Declared at module scope: a component declared inside another component's
 * body gets a new identity on every render, which makes React remount it.
 */
function Placeholder() {
  return (
    <div className="absolute top-4 left-4 z-0 select-none pointer-events-none opacity-50 prose">
      You can edit text and drag/paste images in the input above.
      <br />
      It's just like writing a mini document.
    </div>
  );
}

/**
 * The mixed-modal playground: a rich-text/image input editor, the streamed
 * model output, an advanced-settings side panel and a status bar.
 *
 * Declared at module scope (instead of inside `GenerateMixedModal`) so that a
 * re-render of the parent does not remount it and drop its state and socket.
 */
function Editor() {
  // Lazy initializer: the id only needs to be generated once per mount.
  const [clientId, setClientId] = useState<string>(() => makeid(8));
  const [generationState, setGenerationState] =
    useState<GenerationSocketState>(GenerationSocketState.NotReady);
  const [contents, setContents] = useState<WSContent[]>([]);
  const [partialImage, setPartialImage] = useState<string>("");

  // Model hyperparams
  const [temp, setTemp] = useState<number>(0.7);
  const [topP, setTopP] = useState<number>(0.9);
  const [cfgImageWeight, setCfgImageWeight] = useState<number>(1.2);
  const [cfgTextWeight, setCfgTextWeight] = useState<number>(3.0);
  const [yieldEveryN, setYieldEveryN] = useState<number>(32);
  const [seed, setSeed] = useState<number | null>(Config.default_seed);
  const [maxGenTokens, setMaxGenTokens] = useState<number>(4096);
  const [repetitionPenalty, setRepetitionPenalty] = useState<number>(1.2);
  const [showSeed, setShowSeed] = useState<boolean>(true);
  const [numberInQueue, setNumberInQueue] = useState<number>();

  const [advancedMode, setAdvancedMode] = useAdvancedMode();
  const [tutorialBanner, setTutorialBanner] = useState(true);
  const [examplePrompt, setExamplePrompt] = useState<string | null>(null);
  const chatRef = useRef<HTMLDivElement>(null);

  // The client id is part of the URL, so changing it opens a new connection.
  const socketUrl = `${Config.ws_address}/ws/chameleon/v2/${clientId}`;

  // Array of text string or html string (i.e., an image)
  const [modelOutput, setModelOutput] = useState<Array<WSContent>>([]);

  const { readyState, sendJsonMessage, lastJsonMessage, getWebSocket } =
    useWebSocket(socketUrl, {
      onOpen: () => {
        console.log("WS Opened");
        setGenerationState(GenerationSocketState.UserWriting);
      },
      onClose: (e) => {
        console.log("WS Closed", e);
        setGenerationState(GenerationSocketState.NotReady);
      },
      onError: (e) => {
        console.log("WS Error", e);
        setGenerationState(GenerationSocketState.NotReady);
      },
      // TODO: Inspect error a bit
      shouldReconnect: () => true,
      heartbeat: false,
    });

  const runButtonDisabled =
    readyState !== ReadyState.OPEN ||
    generationState !== GenerationSocketState.UserWriting;

  /**
   * Stop the current generation: close the socket, clear everything shown in
   * the output pane and switch to a fresh client id.
   */
  function abortGeneration() {
    getWebSocket()?.close();
    setModelOutput([]);
    // Also drop the in-progress image and the queue banner; otherwise they
    // stay on screen although the output they belong to is gone.
    setPartialImage("");
    setNumberInQueue(undefined);
    setGenerationState(GenerationSocketState.UserWriting);
    setClientId(makeid(8));
  }

  // Apply each incoming socket message to the output / queue / UI state.
  useEffect(() => {
    if (lastJsonMessage == null) {
      console.log("Null message");
      return;
    }

    const maybeMessage = ZWSMultimodalMessage.safeParse(lastJsonMessage);
    console.log("Message", lastJsonMessage, "Parsed", maybeMessage.success);
    if (!maybeMessage.success) {
      return;
    }

    const message = maybeMessage.data;
    if (message.content.length !== 1 && message.message_type !== COMPLETE) {
      console.error("Too few or too many content");
    }
    console.log("parsed message", maybeMessage);

    if (message.message_type === COMPLETE) {
      setGenerationState(GenerationSocketState.UserWriting);
      return;
    }

    // Currently, the backend only sends one content piece at a time
    const content = message.content[0];
    if (content === undefined) {
      // Malformed (empty) message, already reported above. Reading
      // `content.content_type` here would throw.
      return;
    }

    if (message.message_type === PARTIAL_OUTPUT) {
      if (content.content_type === IMAGE) {
        setPartialImage(content.content);
      } else if (content.content_type === TEXT) {
        setModelOutput((prev) => prev.concat(message.content));
      }
      setNumberInQueue(undefined);
    } else if (message.message_type === FULL_OUTPUT) {
      // Only image gives full output, text is rendered as it
      // comes.
      if (content.content_type === IMAGE) {
        setPartialImage("");
        setModelOutput((prev) => {
          console.log("Set model image output");
          return prev.concat(message.content);
        });
      }
    } else if (message.message_type === QUEUE_STATUS) {
      console.log("Queue Status Message", maybeMessage);
      // expects payload to be n_requests=<number>
      // Take the first run of digits; `Number(str.match(/\d+/g))` produced 0
      // for "no match" and NaN when more than one number was present.
      const queueMatch = content.content.match(/\d+/);
      setNumberInQueue(queueMatch ? Number(queueMatch[0]) : undefined);
    }
  }, [lastJsonMessage]);

  /** Mirror the editor document into `contents` on every change. */
  function onChange(editorState: {
    toJSON(): { root?: SerializedNodeLike | null };
  }) {
    // Call toJSON on the EditorState object, which produces a serialization safe string
    const editorStateJSON = editorState.toJSON();
    setContents(flattenContents(editorStateJSON.root));
    setExamplePrompt(null);
  }

  /** Send the current editor contents to the model, if the UI allows it. */
  function onRunModelClick() {
    if (runButtonDisabled) return;

    async function prepareAndRun() {
      if (contents.length === 0) {
        return;
      }
      setModelOutput([]);
      setPartialImage("");
      setGenerationState(GenerationSocketState.Generating);

      const processedContents = await Promise.all(
        contents.map(prepareContent),
      );
      const suffix_tokens: Array<string> = [EOT_TOKEN];
      const options: WSOptions = {
        message_type: GENERATE_MULTIMODAL,
        temp: temp,
        top_p: topP,
        cfg_image_weight: cfgImageWeight,
        cfg_text_weight: cfgTextWeight,
        repetition_penalty: repetitionPenalty,
        yield_every_n: yieldEveryN,
        max_gen_tokens: maxGenTokens,
        suffix_tokens: suffix_tokens,
        // NOTE(review): the seed is sent even when the "Set seed" toggle is
        // off (the toggle only hides the slider) -- confirm this is intended.
        seed: seed,
      };
      const message: WSMultimodalMessage = {
        message_type: GENERATE_MULTIMODAL,
        content: processedContents,
        options: options,
        debug_info: {},
      };
      setContents(processedContents);
      sendJsonMessage(message);
    }

    prepareAndRun().catch((error: unknown) => {
      console.error(error);
      // Nothing was sent, so no COMPLETE message will ever arrive: leave the
      // "generating" state here or the run button stays disabled for good.
      setGenerationState((prev) =>
        prev === GenerationSocketState.Generating
          ? GenerationSocketState.UserWriting
          : prev,
      );
    });
  }

  useHotkeys("ctrl+enter, cmd+enter", () => {
    console.log("Run Model by hotkey");
    onRunModelClick();
  });

  const readableSocketState = readableWsState(readyState);
  let socketStatus: StatusCategory;
  if (readableSocketState === "Open") {
    socketStatus = "success";
  } else if (readableSocketState === "Connecting") {
    socketStatus = "warning";
  } else {
    // "Closed" and every other state are reported as errors.
    socketStatus = "error";
  }

  const runButtonText = runButtonDisabled ? (
    <div className="loading loading-infinity loading-lg text-neutral"></div>
  ) : (
    <div className="flex flex-row items-center">
      Run Model
      {/* Use the following label when hot-key is implemented
      <span className="flex flex-row items-center ml-2 text-[10px] text-gray-600">
        <MacCommand size={12} className="inline" />
        +ENTER
      </span>
      */}
    </div>
  );
  const runButtonColor = runButtonDisabled
    ? "btn-neutral opacity-60"
    : "btn-success";

  const uiStatus = UI_STATUS_BY_STATE[generationState];

  useEffect(() => {
    chatRef?.current?.scrollIntoView({
      behavior: "smooth",
      block: "end",
      inline: "end",
    });
  }, [modelOutput]);

  return (
    <div className="flex-1 flex flex-col min-h-[calc(100vh-150px)] max-h-[calc(100vh-150px)]">
      <div className="flex-1 flex flex-col relative overflow-x-hidden mb-10">
        <div
          className={`flex-1 flex flex-row items-stretch gap-4 max-h-[calc(100vh-200px)] ${advancedMode ? "ml-[500px]" : "ml-0"} transition-all`}
        >
          <div className="flex-1 flex flex-col relative rounded-md px-6 py-4 bg-purple-50 gap-8">
            <div className="flex flex-row items-center justify-between">
              <div className="prose">
                <h4>Input</h4>
              </div>
              <SettingsAdjust
                onClick={() => setAdvancedMode(!advancedMode)}
                size={24}
                className="hover:fill-primary cursor-pointer"
              />
            </div>

            <div className="flex flex-col flex-1 items-stretch overflow-y-scroll h-full">
              <LexicalComposer initialConfig={EDITOR_CONFIG}>
                {/* Toolbar on top, if needed */}
                {/* <ToolbarPlugin /> */}
                <div className="relative flex-1">
                  <RichTextPlugin
                    contentEditable={
                      <ContentEditable
                        className={`relative bg-white ${tutorialBanner ? "rounded-t-md" : "rounded-md"} block p-4 leading-5 text-md h-full`}
                      />
                    }
                    placeholder={<Placeholder />}
                    ErrorBoundary={LexicalErrorBoundary}
                  />
                </div>
                <DragDropPaste />
                <HistoryPlugin />
                <ImagesPlugin />
                <OnChangePlugin onChange={onChange} />
                <ReplaceContentPlugin payload={examplePrompt} />
              </LexicalComposer>
            </div>

            <div className="flex flex-row items-center justify-between my-4 gap-2">
              <div className="flex flex-row items-center gap-2">
                <button
                  onClick={onRunModelClick}
                  disabled={runButtonDisabled}
                  className={`btn ${runButtonColor}`}
                >
                  {runButtonText}
                </button>
                <button onClick={abortGeneration} className="btn btn-ghost">
                  Abort
                </button>
              </div>
              {!tutorialBanner && (
                <button
                  className="btn btn-circle bg-white border-none"
                  onClick={() => setTutorialBanner(true)}
                >
                  <Idea size={24} />
                </button>
              )}
            </div>
          </div>

          {/* Results */}
          <div className="flex-1 flex flex-col bg-gray-50 rounded-md overflow-x-hidden px-6 py-4 max-h-[calc(100vh-200px)] ">
            <div className="prose">
              <h4>Output</h4>
            </div>
            <div className="mt-6 overflow-scroll flex-1 leading-relaxed markdown">
              {/* Explicit undefined check: `0 && ...` / `NaN && ...` would
                  make React render the number itself. */}
              {numberInQueue !== undefined && numberInQueue > 0 && (
                <div
                  role="alert"
                  className="p-4 mb-4 text-med rounded-lg bg-purple-50"
                >
                  There are {numberInQueue} other users in the queue for generation.
                </div>
              )}
              <div className="prose leading-snug">
                {mergeTextContent(modelOutput).map(contentToHtml)}
              </div>
              <ImageResult src={partialImage} completed={false} />
            </div>
          </div>
        </div>

        {/* Side panel */}
        <div
          className={`absolute top-0 bottom-11 w-[490px] max-h-[calc(100vh-200px)] rounded-md px-6 py-4 overflow-y-scroll ${advancedMode ? "left-0" : "left-[-500px]"} bg-gray-100 transition-all`}
        >
          <div className="prose flex flex-row items-center justify-between">
            <h3>Advanced settings</h3>
            <Close
              size={32}
              className="cursor-pointer hover:fill-primary"
              onClick={() => setAdvancedMode(false)}
            />
          </div>

          <InputRange
            value={temp}
            onValueChange={setTemp}
            label="Temperature"
            min={0.01}
            step={0.01}
            max={1}
          />
          <InputRange
            value={topP}
            onValueChange={setTopP}
            label="Top P"
            min={0.01}
            step={0.01}
            max={1}
          />
          <InputRange
            value={maxGenTokens}
            onValueChange={setMaxGenTokens}
            label="Max Gen Tokens"
            integerOnly
            step={1}
            min={1}
            max={4096}
          />
          <InputRange
            value={repetitionPenalty}
            onValueChange={setRepetitionPenalty}
            label="Text Repetition Penalty"
            min={0}
            max={10}
          />
          <InputRange
            value={cfgImageWeight}
            onValueChange={setCfgImageWeight}
            label="CFG Image Weight"
            min={0.01}
            max={10}
          />
          <InputRange
            value={cfgTextWeight}
            onValueChange={setCfgTextWeight}
            label="CFG Text Weight"
            min={0.01}
            max={10}
          />
          <InputToggle
            label="Set seed"
            value={showSeed}
            onValueChange={(checked) => {
              setShowSeed(checked);
            }}
          />
          {showSeed && seed != null && (
            <InputRange
              value={seed}
              step={1}
              integerOnly={true}
              onValueChange={setSeed}
              label="Seed"
              min={1}
              max={1000}
            />
          )}

          {/* Input preview */}
          <InputShowHide
            labelShow="Show input data"
            labelHide="Hide input data"
          >
            <div className="overflow-auto bg-white text-xs font-mono p-4 rounded-md mt-4">
              <JsonView
                src={contents}
                collapsed={({ indexOrName, depth }) =>
                  indexOrName !== "data" && depth > 3
                }
              />
            </div>
          </InputShowHide>
        </div>
      </div>

      <div className="absolute bottom-0 left-20 right-20 bg-white flex flex-row items-center gap-4 text-xs h-10">
        <StatusBadge
          label="Connection"
          category={socketStatus}
          status={readableSocketState}
        />
        <StatusBadge
          label="UI"
          category={uiStatus}
          status={generationState}
        />
      </div>
    </div>
  );
}

/** Page component for mixed text + image generation. */
export function GenerateMixedModal() {
  return <Editor />;
}