# File size: 1,856 bytes
# NOTE: an extraction artifact (commit hashes and a line-number gutter)
# originally appeared here; converted to comments so the file parses as Python.
print("start to run")
import streamlit as st
import os
import subprocess
import sys

# Best-effort runtime dependency install.  NOTE(review): installing packages
# at runtime is fragile — prefer a requirements.txt.  The original command
# listed "torch" three times and the non-existent package "torch.utils";
# the list below is deduplicated and valid.  subprocess.run with an argument
# list (shell=False) replaces the shell-string os.system call, and
# sys.executable guarantees we install into the interpreter actually running
# this script.
subprocess.run(
    [
        sys.executable, "-m", "pip", "install",
        "torch", "torchvision", "transformers", "sentencepiece", "accelerate",
    ],
    check=False,  # best-effort, matching the original os.system behavior
)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
print("[code] All module has imported.")
# Model and tokenizer initialization.
# Streamlit re-executes this whole script on every user interaction; without
# caching, the 1B-parameter model would be re-downloaded/re-loaded on every
# button press.  st.cache_resource runs the loader once per server process.
@st.cache_resource
def _load_model_and_tokenizer():
    """Load open-calm-1b in fp16 with automatic device placement.

    Returns:
        tuple: (model, tokenizer) for "cyberagent/open-calm-1b".
    """
    model = AutoModelForCausalLM.from_pretrained(
        "cyberagent/open-calm-1b", device_map="auto", torch_dtype=torch.float16
    )
    tokenizer = AutoTokenizer.from_pretrained("cyberagent/open-calm-1b")
    return model, tokenizer

model, tokenizer = _load_model_and_tokenizer()
print("[code] model loaded")
# Inference helper.
def generate_text(input_text, max_new_tokens, temperature, top_p, repetition_penalty):
    """Generate a sampled continuation of *input_text* with the global model.

    Args:
        input_text: Prompt string to continue.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature (higher = more random).
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty (>1 discourages repetition).

    Returns:
        str: The decoded text, prompt included, with special tokens stripped.
    """
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    # Some tokenizers (including GPT-NeoX-style ones) define no pad token;
    # fall back to EOS so model.generate() does not warn or misbehave.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        tokens = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            pad_token_id=pad_id,
        )
    output = tokenizer.decode(tokens[0], skip_special_tokens=True)
    return output
# --- Streamlit app UI -------------------------------------------------------
st.title("Causal Language Modeling")
st.write("AIによる文章生成")

# Generation parameters entered by the user.
prompt = st.text_area("入力テキスト")
token_limit = st.slider("生成する最大トークン数", min_value=1, max_value=512, value=64)
sampling_temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, value=0.7)
nucleus_p = st.slider("Top-p", min_value=0.1, max_value=1.0, value=0.9)
rep_penalty = st.slider("Repetition Penalty", min_value=0.1, max_value=2.0, value=1.05)

# Run generation and display the result when the button is pressed.
if st.button("生成"):
    result = generate_text(prompt, token_limit, sampling_temperature, nucleus_p, rep_penalty)
    st.write("生成されたテキスト:")
    st.write(result)