funkykong888 commited on
Commit
7a648cd
·
verified ·
1 Parent(s): 7777037

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -17
app.py CHANGED
@@ -1,22 +1,66 @@
1
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
2
  import gradio as gr
3
 
4
- # モデルとトークナイザーのロード
 
 
 
5
  model_name = "inu-ai/dolly-japanese-gpt-1b"
6
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
8
-
9
- # チャットインターフェースの作成
10
- def generate_text(user_input):
11
- inputs = tokenizer(user_input, return_tensors="pt")
12
- outputs = model.generate(**inputs)
13
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
14
-
15
- iface = gr.Interface(
16
- fn=generate_text,
17
- inputs="text",
18
- outputs="text",
19
- title="Dolly Japanese GPT-1b Chatbot"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  )
21
 
22
- iface.launch()
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import gradio as gr
4
 
5
+ # デバイス設定
6
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
+
8
+ # モデルとトークナイザーの読み込み
9
  model_name = "inu-ai/dolly-japanese-gpt-1b"
10
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
11
+ model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
12
+
13
+ # チャットボット関数
14
+ def chatbot(input_text, chat_history):
15
+ if chat_history is None:
16
+ chat_history = []
17
+
18
+ # 入力の前処理
19
+ new_input = "ユーザー: " + input_text + " ボット:"
20
+ print(f"Input to the model: {new_input}") # デバッグ用
21
+
22
+ # トークナイズ
23
+ inputs = tokenizer(new_input, return_tensors="pt", padding=True).to(device)
24
+ print(f"Tokenized input: {inputs}") # トークン化された入力の確認
25
+
26
+ # 応答の生成
27
+ outputs = model.generate(
28
+ inputs.input_ids,
29
+ attention_mask=inputs.attention_mask,
30
+ max_length=512,
31
+ pad_token_id=tokenizer.eos_token_id,
32
+ do_sample=True,
33
+ top_p=0.95, # 生成におけるランダム性を調整
34
+ temperature=0.7 # ランダム性の調整
35
+ )
36
+
37
+ # 応答のデコード(skip_special_tokens=Falseにして特殊トークンをデバッグ)
38
+ response = tokenizer.decode(outputs[0], skip_special_tokens=False)
39
+ print(f"Generated response (with special tokens): {response}") # 生成された応答の確認
40
+
41
+ # 応答の整形
42
+ response = response.split("ボット:")[-1].strip()
43
+
44
+ # チャット履歴に追加(辞書形式に変換)
45
+ chat_history.append({"role": "user", "content": input_text})
46
+ chat_history.append({"role": "assistant", "content": response})
47
+
48
+ return chat_history, chat_history
49
+
50
+ # Gradioインターフェース設定
51
+ interface = gr.Interface(
52
+ fn=chatbot,
53
+ inputs=[
54
+ gr.Textbox(label="ユーザー入力", placeholder="ここに入力してください"),
55
+ gr.State() # チャット履歴
56
+ ],
57
+ outputs=[
58
+ gr.Chatbot(label="ボット応答", type="messages"), # 出力形式をmessagesに指定
59
+ gr.State() # チャット履歴の状態
60
+ ],
61
+ title="日本語チャットボット",
62
+ description="inu-ai/dolly-japanese-gpt-1b を使用した日本語チャットボットです。",
63
  )
64
 
65
+ # アプリの起動
66
+ interface.launch()