File size: 1,856 Bytes
acdbcde
8483b93
 
a3a82f7
8483b93
 
64bdad5
8483b93
68c99bd
 
8483b93
64bdad5
8483b93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
print("start to run")
import os  # kept for compatibility; the install below no longer uses os.system
import subprocess
import sys

import streamlit as st

# Runtime dependencies for this demo.
# They are installed with the interpreter that runs the app (`sys.executable -m pip`)
# and without a shell. The previous list contained "torch.utils" (a torch
# sub-module, not a PyPI distribution, so pip rejected the whole command) and
# listed "torch" twice.
_RUNTIME_REQUIREMENTS = ("torch", "transformers", "sentencepiece", "accelerate", "torchvision")


@st.cache_resource
def _install_runtime_requirements():
    """Run ``pip install`` for the runtime requirements once per server process.

    Streamlit re-executes this script on every widget interaction; caching the
    call keeps pip from being re-run on each rerun.
    """
    subprocess.run(
        [sys.executable, "-m", "pip", "install", *_RUNTIME_REQUIREMENTS],
        # Best effort, like the original os.system call: a pip failure is not
        # raised here; a missing package surfaces at the imports below.
        check=False,
    )


_install_runtime_requirements()

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

print("[code] All module has imported.")

# Model and tokenizer initialisation
MODEL_NAME = "cyberagent/open-calm-1b"


@st.cache_resource
def _load_model_and_tokenizer(model_name=MODEL_NAME):
    """Load the causal LM and its tokenizer, cached for the server process.

    Without caching, every Streamlit rerun (each slider move or button press)
    reloaded the model checkpoint from scratch.

    Args:
        model_name: Hugging Face model id passed to both ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is loaded in float16 with
        ``device_map="auto"``.
    """
    loaded_model = AutoModelForCausalLM.from_pretrained(
        model_name, device_map="auto", torch_dtype=torch.float16
    )
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    return loaded_model, loaded_tokenizer


# Module-level names used by generate_text().
model, tokenizer = _load_model_and_tokenizer()
print("[code] model loaded")
# Inference helper
def generate_text(input_text, max_new_tokens, temperature, top_p, repetition_penalty):
    """Sample a continuation of ``input_text`` from the module-level ``model``.

    Args:
        input_text: Prompt string, tokenized with the module-level ``tokenizer``.
        max_new_tokens: Upper bound on the number of tokens to generate.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling probability mass passed to ``model.generate``.
        repetition_penalty: Repetition penalty passed to ``model.generate``.

    Returns:
        The first generated sequence decoded to text with special tokens
        removed. Because ``generate`` on a decoder-only model returns prompt
        plus new tokens, this text starts with the prompt.
    """
    encoded = tokenizer(input_text, return_tensors="pt")
    # Put the input tensors on the same device as the model weights.
    encoded = encoded.to(model.device)

    sampling_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=True,  # stochastic sampling, so temperature/top_p take effect
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.pad_token_id,
    )

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        generated = model.generate(**encoded, **sampling_kwargs)

    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

# Streamlit page setup
st.title("Causal Language Modeling")
st.write("AIによる文章生成")

# Generation parameters
input_text = st.text_area("入力テキスト")
max_new_tokens = st.slider("生成する最大トークン数", min_value=1, max_value=512, value=64)
temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, value=0.7)
top_p = st.slider("Top-p", min_value=0.1, max_value=1.0, value=0.9)
# NOTE(review): in transformers a repetition penalty below 1.0 favours repeated
# tokens rather than penalising them; confirm the 0.1 lower bound is intended.
repetition_penalty = st.slider("Repetition Penalty", min_value=0.1, max_value=2.0, value=1.05)

# Run inference and show the result
if st.button("生成"):
    if not input_text.strip():
        # An empty prompt tokenizes to zero input ids, which model.generate
        # cannot continue from; ask for input instead of failing.
        st.warning("入力テキストを入力してください。")
    else:
        # Generation can take a while; show progress instead of a frozen page.
        with st.spinner("生成中..."):
            output = generate_text(input_text, max_new_tokens, temperature, top_p, repetition_penalty)
        st.write("生成されたテキスト:")
        st.write(output)