#!/usr/bin/env python
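# Gradio demo "真空ジェネレータ" (roughly "真空 generator") that imitates 真空 (@vericava)'s
# posts with the fine-tuned model vericava/gpt2-medium-vericava-posts-v3; meant to run
# as a Hugging Face Space on GPU (hence the spaces.GPU decorator below).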
import os
from collections.abc import Iterator

import gradio as gr
import spaces
import torch
from transformers import pipeline

DESCRIPTION = "# 真空ジェネレータ\n<p>Imitate 真空 (@vericava)'s posts interactively</p>"

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "32768"))
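
# Build the text-generation pipeline only when a GPU is available; on CPU the
# warning appended to DESCRIPTION above applies and generation is not expected to work.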
if torch.cuda.is_available():
    my_pipeline = pipeline(
        task="text-generation",
        model="vericava/gpt2-medium-vericava-posts-v3",
        do_sample=True,
        num_beams=1,
    )

@spaces.GPU
@torch.inference_mode()
def generate(
    message: str,
    chat_history: list[tuple[str, str]],  # required by ChatInterface; each turn uses only the current message
    max_new_tokens: int = 1024,
    temperature: float = 0.7,
    top_p: float = 0.95,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    # Collapse the message onto one line and make sure it ends with sentence-final
    # punctuation so the model continues from a finished sentence.
    user_input = " ".join(message.strip().split("\n"))
    if not user_input.endswith(("。", "?", "!", "?", "!")):
        user_input += "。"
    output = my_pipeline(
        user_input,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
        top_k=top_k,
        top_p=top_p,
    )
    print(output)
    # The pipeline returns a list of dicts whose "generated_text" includes the
    # prompt; drop the prompt, keep only the first line, and cut at the last "。".
    gen_text = output[0]["generated_text"][len(user_input):]
    gen_text = gen_text[:gen_text.find("\n")] if "\n" in gen_text else gen_text
    gen_text = gen_text[:(gen_text.rfind("。") + 1)] if "。" in gen_text else gen_text
    yield gen_text
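
# Gradio chat UI: the "詳細設定" ("advanced settings") accordion exposes the sampling
# parameters, and the examples are Japanese prompts in the style of the imitated posts.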
demo = gr.ChatInterface(
    fn=generate,
    type="tuples",
    additional_inputs_accordion=gr.Accordion(label="詳細設定", open=False),
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=1.0,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.5,
        ),
    ],
    stop_btn=None,
    examples=[
        ["サマリーを作る男の人,サマリーマン。"],
        ["やばい場所にクリティカルな配線ができてしまったので掲示した。"],
        ["にゃん"],
        ["Wikipedia の情報は入っているのかもしれない"],
    ],
    description=DESCRIPTION,
    css_paths="style.css",
    fill_height=True,
)

if __name__ == "__main__":
    demo.launch()