#!/usr/bin/env python
"""Gradio chat demo that imitates 真空 (@vericava)'s posts.

Replies are sampled from the ``vericava/gpt2-medium-vericava-posts-v3``
text-generation model; the demo only loads the model when a GPU is present.
"""

import os
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline
from peft import PeftModel

DESCRIPTION = "# 真空ジェネレータ\n\nImitate 真空 (@vericava)'s posts interactively\n"

if not torch.cuda.is_available():
    DESCRIPTION += "\nRunning on CPU 🥶 This demo does not work on CPU.\n"

MAX_MAX_NEW_TOKENS = 768
DEFAULT_MAX_NEW_TOKENS = 512
# NOTE(review): read from the environment but not used below -- prompts are
# currently passed to the pipeline without any length check.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "32768"))

# Sentence-final punctuation (ASCII and full-width) that already terminates a
# prompt; any other prompt gets "。" appended before generation.
_SENTENCE_ENDINGS = ("。", "?", "!", "？", "！")

# The model is only loaded when a GPU is available; generate() checks for None
# so a CPU-only deployment reports a clear error instead of a NameError.
my_pipeline = None
if torch.cuda.is_available():
    my_pipeline = pipeline(
        task="text-generation",
        model="vericava/gpt2-medium-vericava-posts-v3",
        do_sample=True,
        num_beams=1,
    )


def _normalize_prompt(message: str) -> str:
    """Join *message* onto one line and ensure it ends with sentence punctuation."""
    # splitlines() also handles "\r\n", which split("\n") would leave as "\r".
    prompt = " ".join(message.strip().splitlines())
    return prompt if prompt.endswith(_SENTENCE_ENDINGS) else prompt + "。"


def _extract_reply(output: str, prompt: str) -> str:
    """Return the text generated after *prompt*, trimmed to whole sentences.

    The continuation is cut at its first newline and, when it contains "。",
    right after the last "。" so that no half-finished sentence is shown.
    """
    reply = output[len(prompt):]
    reply = reply.split("\n", 1)[0]
    head, sep, _ = reply.rpartition("。")
    return head + sep if sep else reply


@spaces.GPU
@torch.inference_mode()
def generate(
    message: str,
    chat_history,
    max_new_tokens: int = 1024,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    """Generate one post-style continuation of *message*.

    Args:
        message: The user's chat input.
        chat_history: Previous turns passed by ``gr.ChatInterface``; not used,
            every reply is generated from *message* alone.
        max_new_tokens, temperature, top_p, top_k, repetition_penalty:
            Sampling settings forwarded to the text-generation pipeline. The
            chat UI supplies them from the sliders configured in ``demo``, so
            these defaults only matter for direct calls.

    Yields:
        The generated continuation, cut at its first newline and after its
        last "。" (when it contains one).

    Raises:
        gr.Error: If the model was not loaded because no GPU was available.
    """
    if my_pipeline is None:
        raise gr.Error("This demo does not work on CPU.")

    prompt = _normalize_prompt(message)
    output = my_pipeline(
        prompt,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
        top_k=top_k,
        top_p=top_p,
    )[-1]["generated_text"]
    print(output)
    yield _extract_reply(output, prompt)


demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    # Accordion label reads "Advanced settings".
    additional_inputs_accordion=gr.Accordion(label="詳細設定", open=False),
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=1.0,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.5,
        ),
    ],
    stop_btn=None,
    examples=[
        ["サマリーを作る男の人,サマリーマン。"],
        ["やばい場所にクリティカルな配線ができてしまったので掲示した。"],
        ["にゃん"],
        ["Wikipedia の情報は入っているのかもしれない"],
    ],
    description=DESCRIPTION,
    css_paths="style.css",
    fill_height=True,
)

if __name__ == "__main__":
    demo.launch()