Spaces:
Runtime error
Runtime error
| import os | |
| import json | |
| import random | |
| import string | |
| import requests | |
| import torch | |
| import gradio as gr | |
| from PIL import Image | |
| from diffusers import StableDiffusionPipeline | |
| from pingpong import PingPong | |
| from pingpong.pingpong import PPManager | |
| from pingpong.pingpong import PromptFmt | |
| from pingpong.pingpong import UIFmt | |
| from pingpong.gradio import GradioChatUIFmt | |
| from fpdf import FPDF | |
class PDF(FPDF):
    """FPDF subclass that renders a titled story with optional cover art.

    Callers assign ``title`` (str) and ``art`` (path to an image file, or
    ``None``) on the instance before adding pages; ``header`` reads both.
    """

    # Default so header() works even when a caller never assigns ``art``.
    art = None

    # Typographic characters common in LLM output that have plain equivalents
    # inside the Latin-1 range used by the core fonts below.
    _PLAIN_PUNCT = str.maketrans({
        "\u2018": "'", "\u2019": "'",
        "\u201c": '"', "\u201d": '"',
        "\u2013": "-", "\u2014": "-",
        "\u2026": "...",
    })

    @classmethod
    def _latin1(cls, text):
        """Return *text* restricted to characters the core fonts can encode.

        The fonts used here (Arial/Times) are FPDF core fonts, which only cover
        Latin-1. Common smart punctuation is mapped to ASCII first; anything
        else outside Latin-1 becomes '?' instead of aborting the export.
        """
        plain = text.translate(cls._PLAIN_PUNCT)
        return plain.encode("latin-1", "replace").decode("latin-1")

    def header(self):
        """Draw the centred title and, when ``art`` is set, the cover image."""
        title = self._latin1(self.title or "")
        self.set_font('Arial', 'B', 15)
        # Title cell width: text width plus 3 mm of padding on each side.
        w = self.get_string_width(title) + 6
        # Centre on the real page width instead of assuming 210 mm (A4).
        self.set_x((self.w - w) / 2)
        # White frame and fill: the framed cell is invisible on white paper.
        self.set_draw_color(255, 255, 255)
        self.set_fill_color(255, 255, 255)
        # Frame thickness (1 mm).
        self.set_line_width(1)
        self.cell(w, 9, title, 1, 1, 'C', 1)
        self.ln(10)
        if self.art is not None:
            # 50 mm wide image, horizontally centred.
            self.image(self.art, x=self.w / 2.0 - 25, w=50)
            self.ln(10)

    def footer(self):
        """Print the page number, centred, 1.5 cm from the bottom."""
        self.set_y(-15)
        self.set_font('Arial', 'I', 8)
        # Gray text.
        self.set_text_color(128)
        self.cell(0, 10, 'Page ' + str(self.page_no()), 0, 0, 'C')

    def chapter_title(self, num, label):
        """Write a shaded 'Chapter <num> : <label>' heading line."""
        self.set_font('Arial', '', 12)
        # Light-blue background for the heading.
        self.set_fill_color(200, 220, 255)
        self.cell(0, 6, 'Chapter %d : %s' % (num, self._latin1(label)), 0, 1, 'L', 1)
        self.ln(4)

    def chapter_body(self, content):
        """Write *content* as a Times 12 multi-line text block."""
        self.set_font('Times', '', 12)
        self.multi_cell(0, 5, self._latin1(content))
        self.ln()

    def print_chapter(self, content):
        """Start a new page and write *content* on it."""
        self.add_page()
        self.chapter_body(content)
class LLaMA2ChatPromptFmt(PromptFmt):
    """Prompt formatter producing LLaMA-2-chat style ``[INST]`` turns.

    Both hooks are classmethods: managers call them on the class itself
    (``fmt.ctx(...)`` / ``fmt.prompt(...)`` with ``fmt`` defaulting to this
    class), so without the decorator the first argument would be bound to
    ``cls`` and the call would fail.
    """

    # NOTE(review): Meta's reference Llama-2 chat template places the <<SYS>>
    # block *inside* the first [INST]; here it precedes it. Confirm intended.

    @classmethod
    def ctx(cls, context):
        """Wrap *context* in ``<<SYS>>`` tags; return "" for None or ""."""
        if context is None or context == "":
            return ""
        else:
            return f"""<<SYS>>
{context}
<</SYS>>
"""

    @classmethod
    def prompt(cls, pingpong, truncate_size):
        """Render one exchange as ``[INST] ping [/INST] pong``.

        Both sides are cut to *truncate_size* characters (``None`` keeps them
        whole); a missing pong renders as the empty string.
        """
        ping = pingpong.ping[:truncate_size]
        pong = "" if pingpong.pong is None else pingpong.pong[:truncate_size]
        return f"""[INST] {ping} [/INST] {pong}"""
class LLaMA2ChatPPManager(PPManager):
    """PPManager that renders its history with LLaMA2ChatPromptFmt by default."""

    def build_prompts(self, from_idx: int=0, to_idx: int=-1, fmt: PromptFmt=LLaMA2ChatPromptFmt, truncate_size: int=None):
        """Render ``self.ctx`` followed by ``pingpongs[from_idx:to_idx]``.

        Args:
            from_idx: First pingpong to include.
            to_idx: End index (exclusive); -1 or anything past the end means
                "through the last pingpong".
            fmt: Formatter providing ``ctx`` and ``prompt``.
            truncate_size: Forwarded to ``fmt.prompt``; None keeps messages whole.

        Returns:
            The concatenated prompt string.
        """
        if to_idx == -1 or to_idx >= len(self.pingpongs):
            to_idx = len(self.pingpongs)
        header = fmt.ctx(self.ctx)
        turns = [
            fmt.prompt(pingpong, truncate_size=truncate_size)
            for pingpong in self.pingpongs[from_idx:to_idx]
        ]
        return header + "".join(turns)
class GradioLLaMA2ChatPPManager(LLaMA2ChatPPManager):
    """LLaMA-2 chat manager that can also render its history for gr.Chatbot."""

    def build_uis(self, from_idx: int=0, to_idx: int=-1, fmt: UIFmt=GradioChatUIFmt):
        """Return ``fmt.ui(pingpong)`` for every pingpong in [from_idx, to_idx).

        A *to_idx* of -1, or one beyond the history length, selects through
        the end of the history.
        """
        total = len(self.pingpongs)
        end = total if (to_idx == -1 or to_idx >= total) else to_idx
        return [fmt.ui(pingpong) for pingpong in self.pingpongs[from_idx:end]]
# Bearer token for the Hugging Face Inference API (None when HF_TOKEN is unset).
TOKEN = os.environ.get('HF_TOKEN')

# Text-generation model queried through the Inference API.
MODEL_ID = 'meta-llama/Llama-2-70b-chat-hf'

# Cover-art generator, loaded once at import time with half-precision weights.
# It is not placed on any device here.
pipe = StableDiffusionPipeline.from_pretrained(
    "nota-ai/bk-sdm-small",
    torch_dtype=torch.float16,
)
# Custom CSS handed to gr.Blocks(css=...). The selectors match names attached
# to components through elem_classes in the UI definition, e.g. 'group-border'
# for the dashed control panels, 'no-label' to hide a component's label span,
# 'alt-button' for the alternative-text buttons and 'highlighted-text' for the
# selection preview.
STYLES = """
.left-panel {
min-width: min(290px, 100%) !important;
}
.small-big {
font-size: 12pt !important;
}
.small-big-textarea > label > textarea {
font-size: 12pt !important;
}
.highlighted-text {
background: yellow;
overflow-wrap: break-word;
}
.no-gap {
gap: 0px !important;
}
.group-border {
padding: 10px;
border-width: 1px;
border-radius: 10px;
border-color: gray;
border-style: dashed;
}
.control-label-font {
font-size: 13pt !important;
}
.control-button {
background: none !important;
border-color: #69ade2 !important;
border-width: 2px !important;
color: #69ade2 !important;
}
.center {
text-align: center;
}
.right {
text-align: right;
}
.no-label {
padding: 0px !important;
}
.no-label > label > span {
display: none;
}
.no-label-chatbot {
border: none !important;
box-shadow: none !important;
height: 520px !important;
}
.no-label-chatbot > div > div:nth-child(1) {
display: none;
}
.no-label-image > div:nth-child(2) {
display: none;
}
.left-margin-30 {
padding-left: 30px !important;
}
.left {
text-align: left !important;
}
.alt-button {
color: gray !important;
border-width: 1px !important;
background: none !important;
border-color: gray !important;
text-align: justify !important;
}
.white-text {
color: #000 !important;
}
"""
def id_generator(size=6, chars=string.ascii_uppercase + string.digits):
    """Return *size* characters drawn independently from *chars*.

    Draws come from the module-level ``random`` generator, so results are
    reproducible under ``random.seed`` and are not suitable as secrets. In this
    file the result names temporary export files.
    """
    picked = []
    for _ in range(size):
        picked.append(random.choice(chars))
    return "".join(picked)
def get_new_ppm(ping):
    """Build a single-turn story-writing conversation.

    The returned manager carries a fixed "writing helper" system context and
    one pingpong whose request is *ping* and whose reply is the empty string.
    """
    system_ctx = """\
You are a helpful, respectful and honest writing helper. Always write stories that suites to query.
You DO NOT give explanation but just stories. For instance, do not say such as "Sure! Here's a short paragraph to start a short story:"""
    manager = LLaMA2ChatPPManager()
    manager.ctx = system_ctx
    manager.add_pingpong(PingPong(ping, ''))
    return manager
def get_new_ppm_for_chat():
    """Return a fresh, empty chat-history manager (used as the chat gr.State)."""
    return GradioLLaMA2ChatPPManager()
def gen_text(prompt, hf_model='meta-llama/Llama-2-70b-chat-hf', hf_token=None, parameters=None, timeout=300):
    """Generate a completion for *prompt* through the Hugging Face Inference API.

    Args:
        prompt: Fully formatted prompt string, sent as ``inputs``.
        hf_model: Model repo id appended to the Inference API URL.
        hf_token: Bearer token; required.
        parameters: Generation parameters. Defaults to sampling up to 512 new
            tokens with temperature 1.0, top_k 50 and repetition_penalty 1.2,
            returning only the newly generated text.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        ``generated_text`` of the first result in the response body.

    Raises:
        ValueError: if the token is missing, the request fails or times out,
            the status is not 200, or the body is not shaped like
            ``[{"generated_text": ...}]``. The handlers in this app catch
            ValueError to degrade gracefully, so all failures map to it.
    """
    if hf_token is None:
        raise ValueError("Hugging Face Token is not set")
    if parameters is None:
        parameters = {
            'max_new_tokens': 512,
            'do_sample': True,
            'return_full_text': False,
            'temperature': 1.0,
            'top_k': 50,
            'repetition_penalty': 1.2
        }
    url = f'https://api-inference.huggingface.co/models/{hf_model}'
    headers = {
        'Authorization': f'Bearer {hf_token}',
        'Content-type': 'application/json'
    }
    data = {
        'inputs': prompt,
        'stream': False,
        'options': {
            'use_cache': False,
        },
        'parameters': parameters
    }
    try:
        r = requests.post(url, headers=headers, data=json.dumps(data), timeout=timeout)
    except requests.exceptions.RequestException as err:
        raise ValueError(f"Request to {url} failed: {err}") from err
    # Check the numeric status: the reason phrase is informational text and
    # need not be "OK" for a successful response.
    if r.status_code != 200:
        raise ValueError(f"Response other than 200 (got {r.status_code}): {r.text[:200]}")
    try:
        return json.loads(r.content.decode("utf-8"))[0]['generated_text']
    except (ValueError, LookupError, TypeError) as err:
        raise ValueError(f"Unexpected response body: {r.text[:200]}") from err
def gen_art(editor, cover_art_image, gen_cover_art_prompt):
    """Exporting-tab handler: suggest a cover-art prompt, or render the art.

    With a blank prompt box, asks the LLM to caption the story in *editor* as
    a movie poster and returns ``[cover_art_image, suggested_prompt]`` (the
    image is unchanged). Otherwise renders *gen_cover_art_prompt* with the
    Stable Diffusion ``pipe`` and returns ``[image, gen_cover_art_prompt]``.
    If the LLM call fails, the error is printed and both inputs are returned
    unchanged, matching the other handlers' best-effort policy.
    """
    global pipe
    if (gen_cover_art_prompt or "").strip() == "":
        ppm = get_new_ppm(f"""describe the story below as a movie poster. give me the caption ONLY.
--------------------------------
{editor}""")
        try:
            cover_art_prompt = gen_text(ppm.build_prompts(), hf_model=MODEL_ID, hf_token=TOKEN)
        except ValueError as e:
            print(f"something went wrong - {e}")
            return [cover_art_image, gen_cover_art_prompt]
        return [cover_art_image, cover_art_prompt]
    # Place the pipeline lazily. Without a GPU, run on CPU in float32, since
    # half-precision inference is poorly supported by CPU kernels.
    if torch.cuda.is_available():
        pipe = pipe.to("cuda")
    else:
        pipe = pipe.to("cpu", torch.float32)
    return [
        pipe(gen_cover_art_prompt).images[0],
        gen_cover_art_prompt
    ]
def generate_pdf(title, editor, concept_art):
    """Export the story as a PDF and reveal it in the download component.

    Args:
        title: Story title; blank (or None) falls back to "Untitled".
        editor: Story body text.
        concept_art: Cover image as an array accepted by ``Image.fromarray``,
            or None for no cover.

    Returns:
        ``(gr.update(value=<pdf path>, visible=True), " ")`` for the
        ``pdf_file`` and ``progress_bar`` outputs.
    """
    tmp_filename = id_generator()
    pdf_path = f'{tmp_filename}.pdf'
    art_path = None
    if concept_art is not None:
        art_path = f"{tmp_filename}.png"
        Image.fromarray(concept_art).save(art_path)
    try:
        pdf = PDF()
        pdf.title = "Untitled" if (title or "").strip() == "" else title
        pdf.art = art_path
        pdf.print_chapter(editor or "")
        pdf.output(pdf_path, 'F')
    finally:
        # The PNG is only needed while the PDF is built. The PDF stays on disk
        # because its path is returned to the File component.
        if art_path is not None and os.path.exists(art_path):
            os.remove(art_path)
    return (
        gr.update(value=pdf_path, visible=True),
        " "
    )
def select(editor, evt: gr.SelectData):
    """Textbox ``.select`` handler: report the selected text and its span.

    *editor* is supplied by Gradio but not needed. Returns
    ``[selected text, start index, end index]`` taken from the event.
    """
    chosen = evt.value
    start, end = evt.index[0], evt.index[1]
    return [chosen, start, end]
def get_gen_txt(title, editor, prompt, only_gen_text=False):
    """Ask the LLM for story text, optionally appending it to *editor*.

    An empty *editor* requests an opening paragraph for *title*; otherwise
    *prompt* is sent together with the story so far.

    Returns:
        ``txt + "\\n\\n"`` when *only_gen_text* is true, else
        ``editor + txt + "\\n\\n"``. On a ValueError from generation the error
        is printed and nothing is added: "" in *only_gen_text* mode, *editor*
        otherwise.
    """
    if editor.strip() == '':
        ppm = get_new_ppm(f'Write a short paragraph to start a short story titled "{title}" for me')
    else:
        ppm = get_new_ppm(f"""{prompt}
--------------------------------
{editor}""")
    try:
        txt = gen_text(ppm.build_prompts(), hf_model=MODEL_ID, hf_token=TOKEN)
    except ValueError as e:
        print(f"something went wrong - {e}")
        # In only_gen_text mode the result is shown as a standalone alternative
        # and may later be appended to the story; returning the whole story
        # here would duplicate it, so report "nothing generated".
        return "" if only_gen_text else editor
    if only_gen_text:
        return txt + "\n\n"
    return editor + txt + "\n\n"
def gen_txt(title, editor, prompt):
    """"generate text" handler: extend the story and reset the alternatives UI.

    A blank *prompt* falls back to the default continuation instruction.
    Returns values for [editor, num_enabled_alts, gen_alt_btn, first_alt,
    second_alt, third_alt, gen_btn, replace_sel_btn].
    """
    if prompt.strip() == "":
        instruction = "Write the next paragraph based on the following stories so far."
    else:
        instruction = prompt
    story = get_gen_txt(title, editor, instruction)
    return [
        story,
        0,  # no alternatives are shown any more
        gr.update(interactive=True),
        *(gr.update(visible=False) for _ in range(3)),
        gr.update(interactive=True),
        gr.update(interactive=True),
    ]
def chat_gen(editor, chat_txt, chatbot, ppm, regen=False):
    """Run one chat turn about the current story and return refreshed UI values.

    Args:
        editor: Story text, embedded in the system context so answers can use it.
        chat_txt: The user's message (ignored when *regen* is true).
        chatbot: Current Chatbot value; unused, the history lives in *ppm*.
        ppm: Chat-history manager held in gr.State; mutated in place.
        regen: Drop the last turn and ask its question again.

    Returns:
        Values for [chat_txt, chatbot, chat_history]: normally
        ``["", ppm.build_uis(), ppm]``. A blank message, or a regenerate
        request with no history, leaves everything as it was. A ValueError from
        generation is printed and the turn keeps an empty reply.
    """
    ppm.ctx = f"""\
You are a helpful, respectful and honest assistant.
you must consider multi-turn conversations.
Answer to questions based on the written stories so far as below
----------------
{editor}
"""
    if regen:
        if not ppm.pingpongs:
            # Nothing to regenerate yet; popping an empty history would
            # presumably raise inside the pingpong library.
            return [chat_txt, ppm.build_uis(), ppm]
        last_pingpong = ppm.pop_pingpong()
        chat_txt = last_pingpong.ping
    elif not (chat_txt or "").strip():
        # Don't send blank messages to the model.
        return [chat_txt, ppm.build_uis(), ppm]
    ppm.add_pingpong(PingPong(chat_txt, ''))
    try:
        txt = gen_text(ppm.build_prompts(), hf_model=MODEL_ID, hf_token=TOKEN)
        ppm.add_pong(txt)
    except ValueError as e:
        print(f"something went wrong - {e}")
    return [
        "",
        ppm.build_uis(),
        ppm
    ]
def chat(editor, chat_txt, chatbot, ppm):
    """Chat box submit handler: ask *chat_txt* about the story via chat_gen."""
    result = chat_gen(editor, chat_txt, chatbot, ppm)
    return result
def regen_chat(editor, chat_txt, chatbot, ppm):
    """"regenerate" button handler: redo the most recent chat turn via chat_gen."""
    return chat_gen(
        editor=editor,
        chat_txt=chat_txt,
        chatbot=chatbot,
        ppm=ppm,
        regen=True,
    )
def get_new_ppm_for_range():
    """Return an empty LLaMA-2 manager primed to rewrite a selected span."""
    manager = LLaMA2ChatPPManager()
    manager.ctx = (
        "You are a helpful, respectful and honest writing helper. Always write text that suites to query.\n"
        "You DO NOT give explanation but just stories. DO NOT say such as 'Sure! Here's a short paragraph to start a short story:' or 'Sure, here is a revised version of ....:'\n"
    )
    return manager
def replace_sel(editor, replace_type, selected_text, sel_index_from, sel_index_to):
    """Replace the selected span of the story with an LLM-written alternative.

    Args:
        editor: Full story text.
        replace_type: Unit requested from the model ("word", "phrase", ...).
        selected_text: Text of the current selection.
        sel_index_from: Start of the selection in *editor*.
        sel_index_to: End (exclusive) of the selection in *editor*.

    Returns:
        Values for [editor, selected_text, sel_index_from, sel_index_to]:
        the spliced story with the selection cleared (``"", 0, 0``). When
        nothing is selected, or generation fails, the inputs come back unchanged.
    """
    unchanged = [editor, selected_text, sel_index_from, sel_index_to]
    # editor.change() clears selected_text but leaves the indices stale, and an
    # empty span means nothing is selected: there is nothing to replace.
    if not (selected_text or "").strip() or sel_index_to <= sel_index_from:
        return unchanged
    ppm = get_new_ppm_for_range()
    ping = f"""replace {selected_text} in a single {replace_type} based on the story below
----------------
{editor}
"""
    ppm.add_pingpong(PingPong(ping, ''))
    try:
        txt = gen_text(ppm.build_prompts(), hf_model=MODEL_ID, hf_token=TOKEN)
    except ValueError as e:
        # Without generated text there is nothing to splice in.
        print(f"something went wrong - {e}")
        return unchanged
    return [
        f"{editor[:sel_index_from]} {txt} {editor[sel_index_to:]}",
        "",
        0,
        0
    ]
def gen_alt(title, editor, num_enabled_alts, alt_btn1, alt_btn2, alt_btn3):
    """"generate alternatives" handler: fill the next of three alternative slots.

    Args:
        title: Story title (used when the story is still empty).
        editor: Story so far.
        num_enabled_alts: Number of alternative slots already shown (0-3).
        alt_btn1, alt_btn2, alt_btn3: Current texts of the alternative buttons.

    Returns:
        Values for [num_enabled_alts, gen_alt_btn, first_alt, alt_btn1,
        second_alt, alt_btn2, third_alt, alt_btn3, progress_bar, gen_btn,
        replace_sel_btn]. The new text goes into slot ``num_enabled_alts``;
        the generate-alternatives button is disabled once all three are used.
    """
    if num_enabled_alts >= 3:
        # All slots are full. Gradio needs one value per output, so keep the
        # slots as they are and re-enable the buttons the click handler disabled.
        return [
            num_enabled_alts,
            gr.update(interactive=False),
            gr.update(visible=True),
            gr.update(value=alt_btn1),
            gr.update(visible=True),
            gr.update(value=alt_btn2),
            gr.update(visible=True),
            gr.update(value=alt_btn3),
            " ",
            gr.update(interactive=True),
            gr.update(interactive=True),
        ]
    alt_txt = get_gen_txt(title, editor, "Write the next paragraph based on the following stories so far.", only_gen_text=True)
    slot_updates = []
    for slot, current_text in enumerate((alt_btn1, alt_btn2, alt_btn3)):
        # Slots up to and including the new one are visible; only the new one changes text.
        slot_updates.append(gr.update(visible=num_enabled_alts >= slot))
        slot_updates.append(gr.update(value=alt_txt if num_enabled_alts == slot else current_text))
    return [
        min(num_enabled_alts + 1, 3),
        gr.update(interactive=num_enabled_alts < 2),
        *slot_updates,
        " ",
        gr.update(interactive=True),
        gr.update(interactive=True),
    ]
def fill_with_gen(alt_txt, editor):
    """Alternative-button handler: append the chosen text and reset the slots.

    Returns values for [editor, num_enabled_alts, gen_alt_btn, first_alt,
    second_alt, third_alt]: the extended story, zero shown alternatives, an
    enabled "generate alternatives" button and three hidden slots.
    """
    updated_story = editor + alt_txt
    outputs = [updated_story, 0, gr.update(interactive=True)]
    outputs.extend(gr.update(visible=False) for _ in range(3))
    return outputs
with gr.Blocks(css=STYLES) as demo:
    # Per-session state.
    num_enabled_alts = gr.State(0)  # alternative slots currently shown (0-3)
    sel_index_from = gr.State(0)    # start of the current editor selection
    sel_index_to = gr.State(0)      # end of the current editor selection
    chat_history = gr.State(get_new_ppm_for_chat())

    gr.Markdown("# Co-writing with AI", elem_classes=['center'])
    gr.Markdown(
        "This application is designed for you to collaborate with LLM to co-write stories. It is inspired by [Wordcraft project](https://wordcraft-writers-workshop.appspot.com/) from Google's PAIR and Magenta teams. "
        "This application is built on [Gradio](https://www.gradio.app), and the underlying text generation is powered by [Hugging Face Inference API](https://huggingface.co/inference-api). The text generation model might "
        "be changed over time, but [meta-llama/Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) is selected for now. On the `exporting` tab, you can generate cover art for your story. You can "
        "design your own prompt to generate the cover art, or you can let LLaMA2 generate one for you. Currently, the compressed Stable Diffusion model `nota-ai/bk-sdm-small` by [Nota AI](https://www.nota.ai/) is "
        "used.",
        elem_classes=['center', 'small-big'])

    with gr.Row():
        with gr.Column(scale=2):
            editor = gr.Textbox(lines=32, max_lines=32, elem_classes=['no-label', 'small-big-textarea'])
            word_counter = gr.Markdown("0 words", elem_classes=['right'])

        with gr.Column(scale=1):
            with gr.Tab("Control"):
                with gr.Column(elem_classes=['group-border']):
                    gr.Markdown('### title')
                    title = gr.Textbox("pokemon training story", elem_classes=['no-label'])

                with gr.Column(elem_classes=['group-border']):
                    with gr.Column():
                        gr.Markdown("For instant generation and concatenation, use `generate text` button. "
                                    "Want to explore alternative choices? use `generate alternatives` button.")
                        with gr.Accordion("longer guideline", open=False):
                            gr.Markdown("`generate text` button generates continued text and attaches it to the end. "
                                        "On the other hand, `generate alternatives` button generates alternate texts "
                                        "up to 3 and lets you choose one of them. In both cases, **Write the next paragraph based on "
                                        "the following stories so far.** is the default prompt. If you want to try your own designed "
                                        "prompt, enter it in the textbox below.")
                        prompt = gr.Textbox(placeholder="design your own prompt", elem_classes=['no-label'])
                        with gr.Row():
                            gen_btn = gr.Button("generate text", elem_classes=['control-label-font', 'control-button'])
                            gen_alt_btn = gr.Button("generate alternatives", elem_classes=['control-label-font', 'control-button'])

                    with gr.Column():
                        with gr.Row(visible=False) as first_alt:
                            gr.Markdown("↳", scale=1, elem_classes=['wrap'])
                            alt_btn1 = gr.Button("Alternative 1", elem_classes=['alt-button'], scale=8)
                        with gr.Row(visible=False) as second_alt:
                            gr.Markdown("↳", scale=1, elem_classes=['wrap'])
                            alt_btn2 = gr.Button("Alternative 2", elem_classes=['alt-button'], scale=8)
                        with gr.Row(visible=False) as third_alt:
                            gr.Markdown("↳", scale=1, elem_classes=['wrap'])
                            alt_btn3 = gr.Button("Alternative 3", elem_classes=['alt-button'], scale=8)

                with gr.Column(elem_classes=['group-border']):
                    with gr.Row():
                        selected_text = gr.Markdown("Selected text will be displayed in this area", elem_classes=['highlighted-text'])
                    with gr.Row():
                        with gr.Column(elem_classes=['no-gap']):
                            replace_sel_btn = gr.Button("replace selection", elem_classes=['control-label-font', 'control-button'])
                            replace_type = gr.Dropdown(choices=['word', 'sentence', 'phrase', 'paragraph'], value='sentence', interactive=True, elem_classes=['no-label'])

            with gr.Tab("Chatting"):
                chatbot = gr.Chatbot([], elem_classes=['no-label-chatbot'])
                chat_txt = gr.Textbox(placeholder="enter question", elem_classes=['no-label'])
                with gr.Row():
                    clear_btn = gr.Button("clear", elem_classes=['control-label-font', 'control-button'])
                    regen_btn = gr.Button("regenerate", elem_classes=['control-label-font', 'control-button'])

            with gr.Tab("Exporting"):
                gr.Markdown("generate cover art with [`nota-ai/bk-sdm-small`](https://huggingface.co/nota-ai/bk-sdm-small) model. "
                            "design your own prompt in the textbox below, or just hit the `generate prompt for cover art` button. LLaMA2 "
                            "model will suggest a prompt for you based on your story.")
                cover_art = gr.Image(interactive=False, elem_classes=['no-label-image'])
                gen_cover_art_prompt = gr.Textbox(lines=5, max_lines=5, elem_classes=['no-label'])
                # Label toggles between "generate prompt for cover art" and
                # "generate cover art" depending on whether the prompt box is empty.
                gen_cover_art_btn = gr.Button("generate prompt for cover art", elem_classes=['control-label-font', 'control-button'])
                gen_pdf_btn = gr.Button("export as PDF", elem_classes=['control-label-font', 'control-button'])
                pdf_file = gr.File(visible=False)
                progress_bar = gr.Textbox(elem_classes=['no-label'])

    def _lock_generation_buttons():
        """Disable gen_btn, gen_alt_btn and replace_sel_btn (one update each)."""
        return tuple(gr.update(interactive=False) for _ in range(3))

    gen_pdf_btn.click(
        generate_pdf,
        inputs=[title, editor, cover_art],
        outputs=[pdf_file, progress_bar]
    )

    gen_cover_art_btn.click(
        gen_art,
        inputs=[editor, cover_art, gen_cover_art_prompt],
        outputs=[cover_art, gen_cover_art_prompt]
    )

    gen_cover_art_prompt.change(
        fn=None,
        inputs=[gen_cover_art_prompt],
        outputs=[gen_cover_art_btn],
        _js="(t) => t.trim() == '' ? 'generate prompt for cover art' : 'generate cover art'"
    )

    # Raw string so the regex's \s reaches JavaScript without a Python
    # invalid-escape warning. Blank text counts as 0 words, and the " words"
    # suffix matches the counter's initial "0 words".
    editor.change(
        fn=None,
        inputs=[editor],
        outputs=[word_counter, selected_text],
        _js=r"(e) => [(e.trim() === '' ? 0 : e.trim().split(/\s+/).length) + ' words', '']"
    )

    editor.select(
        fn=select,
        inputs=[editor],
        outputs=[selected_text, sel_index_from, sel_index_to],
        show_progress='minimal'
    )

    # Lock the generation buttons while a request is in flight; the follow-up
    # handler re-enables them.
    gen_btn.click(
        _lock_generation_buttons,
        inputs=None,
        outputs=[gen_btn, gen_alt_btn, replace_sel_btn]
    ).then(
        fn=gen_txt,
        inputs=[title, editor, prompt],
        outputs=[editor, num_enabled_alts, gen_alt_btn, first_alt, second_alt, third_alt, gen_btn, replace_sel_btn]
    )

    gen_alt_btn.click(
        _lock_generation_buttons,
        inputs=None,
        outputs=[gen_btn, gen_alt_btn, replace_sel_btn]
    ).then(
        fn=gen_alt,
        inputs=[title, editor, num_enabled_alts, alt_btn1, alt_btn2, alt_btn3],
        outputs=[num_enabled_alts, gen_alt_btn, first_alt, alt_btn1, second_alt, alt_btn2, third_alt, alt_btn3, progress_bar, gen_btn, replace_sel_btn],
    )

    # Each alternative button appends its own text to the story.
    for alt_btn in (alt_btn1, alt_btn2, alt_btn3):
        alt_btn.click(
            fn=fill_with_gen,
            inputs=[alt_btn, editor],
            outputs=[editor, num_enabled_alts, gen_alt_btn, first_alt, second_alt, third_alt]
        )

    replace_sel_btn.click(
        fn=replace_sel,
        inputs=[editor, replace_type, selected_text, sel_index_from, sel_index_to],
        outputs=[editor, selected_text, sel_index_from, sel_index_to],
        show_progress='minimal'
    )

    chat_txt.submit(
        fn=chat,
        inputs=[editor, chat_txt, chatbot, chat_history],
        outputs=[chat_txt, chatbot, chat_history]
    )

    regen_btn.click(
        fn=regen_chat,
        inputs=[editor, chat_txt, chatbot, chat_history],
        outputs=[chat_txt, chatbot, chat_history]
    )

    # Reset the chat: empty input box, empty transcript, fresh history manager.
    clear_btn.click(
        lambda: ("", [], get_new_ppm_for_chat()),
        inputs=None,
        outputs=[chat_txt, chatbot, chat_history]
    )

demo.launch()