jljiu committed on
Commit 298e414 · verified · 1 Parent(s): b95f55d

Update app.py

Files changed (1)
  1. app.py +108 -92
app.py CHANGED
@@ -1,92 +1,108 @@
- import gradio as gr
- from transformers import AutoTokenizer, AutoModelForCausalLM
- import json
- import os
- from train import ModelTrainer
-
- class NovelAIApp:
-     def __init__(self):
-         self.model = None
-         self.tokenizer = None
-         self.trainer = None
-
-         # Load the system prompts
-         with open('configs/system_prompts.json', 'r', encoding='utf-8') as f:
-             self.system_prompts = json.load(f)
-
-     def load_model(self, model_path):
-         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
-         self.model = AutoModelForCausalLM.from_pretrained(
-             model_path,
-             load_in_8bit=True,
-             device_map="auto"
-         )
-
-     def train_model(self, files):
-         if not self.trainer:
-             self.trainer = ModelTrainer(
-                 "CohereForAI/c4ai-command-r-plus-08-2024",
-                 "configs/system_prompts.json"
-             )
-
-         dataset = self.trainer.prepare_dataset(files)
-         self.trainer.train(dataset)
-         return "Training complete!"
-
-     def generate_text(self, prompt, system_prompt_type="creative"):
-         if not self.model:
-             return "Please load the model first!"
-
-         system_prompt = self.system_prompts.get(system_prompt_type, self.system_prompts["base_prompt"])
-
-         formatted_prompt = f"""<|system|>{system_prompt}</|system|>
- <|user|>{prompt}</|user|>
- <|assistant|>"""
-
-         inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
-         outputs = self.model.generate(
-             inputs["input_ids"],
-             max_length=512,
-             temperature=0.7,
-             top_p=0.9,
-             repetition_penalty=1.1,
-             num_return_sequences=1,
-             pad_token_id=self.tokenizer.eos_token_id
-         )
-
-         return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-     def create_interface(self):
-         with gr.Blocks() as interface:
-             gr.Markdown("# Stylized Chat Assistant")
-
-             with gr.Tab("Model Training"):
-                 gr.Markdown("""### Upload reference texts
- Upload text files to train the model on a specific language style.
- Texts with a distinctive voice are recommended.""")
-
-                 file_output = gr.File(
-                     file_count="multiple",
-                     label="Upload reference text files"
-                 )
-                 train_button = gr.Button("Start Training")
-                 train_output = gr.Textbox(label="Training Status")
-
-             with gr.Tab("Chat"):
-                 gr.Markdown("Chat with the assistant and experience its stylized language")
-                 style_select = gr.Dropdown(
-                     choices=["formal", "casual"],
-                     label="Select a conversation style",
-                     value="formal"
-                 )
-                 chat_interface = gr.ChatInterface(
-                     fn=self.generate_text,
-                     additional_inputs=[style_select]
-                 )
-
-         return interface
-
- # Create the application instance
- app = NovelAIApp()
- interface = app.create_interface()
- interface.launch()
+ import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import json
+ import os
+ from train import ModelTrainer
+
+ class NovelAIApp:
+     def __init__(self):
+         self.model = None
+         self.tokenizer = None
+         self.trainer = None
+
+         # Load the system prompts
+         with open('configs/system_prompts.json', 'r', encoding='utf-8') as f:
+             self.system_prompts = json.load(f)
+
+         # Initialize the default mood/scenario
+         self.current_mood = "暗示"
+
+     def load_model(self, model_path):
+         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+         self.model = AutoModelForCausalLM.from_pretrained(
+             model_path,
+             load_in_8bit=True,
+             device_map="auto"
+         )
+
+     def train_model(self, files):
+         if not self.trainer:
+             self.trainer = ModelTrainer(
+                 "CohereForAI/c4ai-command-r-plus-08-2024",
+                 "configs/system_prompts.json"
+             )
+
+         dataset = self.trainer.prepare_dataset(files)
+         self.trainer.train(dataset)
+         return "Training complete!"
+
+     def generate_text(self, message, history):
+         """Text generation method reworked to match the ChatInterface signature"""
+         if not self.model:
+             return "Please load the model first!"
+
+         system_prompt = self.system_prompts.get("base_prompt")
+
+         # Build the full conversation history
+         full_history = ""
+         for msg in history:
+             full_history += f"<|user|>{msg[0]}</|user|>\n<|assistant|>{msg[1]}</|assistant|>\n"
+
+         formatted_prompt = f"""<|system|>{system_prompt}</|system|>
+ {full_history}<|user|>{message}</|user|>
+ <|assistant|>"""
+
+         inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
+         outputs = self.model.generate(
+             inputs["input_ids"],
+             max_length=1024,
+             temperature=0.7,
+             top_p=0.9,
+             repetition_penalty=1.1
+         )
+
+         response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+         # Extract the assistant's part of the reply
+         response = response.split("<|assistant|>")[-1].strip()
+         return response
+
+     def create_interface(self):
+         """Build the Gradio interface"""
+         with gr.Blocks() as interface:
+             gr.Markdown("# Catgirl Chat Assistant")
+
+             with gr.Tab("Model Training"):
+                 file_output = gr.File(
+                     file_count="multiple",
+                     label="Upload novel text files"
+                 )
+                 train_button = gr.Button("Start Training")
+                 train_output = gr.Textbox(label="Training Status")
+
+                 train_button.click(
+                     fn=self.train_model,
+                     inputs=[file_output],
+                     outputs=[train_output]
+                 )
+
+             with gr.Tab("Chat"):
+                 chatbot = gr.ChatInterface(
+                     self.generate_text,
+                     chatbot=gr.Chatbot(height=400),
+                     textbox=gr.Textbox(placeholder="Type your message...", container=False),
+                     title="Chat with the Catgirl",
+                     description="Come chat with the cute catgirl~",
+                     theme="soft",
+                     examples=["The weather is lovely today", "What are you doing?", "Want to play together?"],
+                     cache_examples=False,
+                     retry_btn=None,
+                     undo_btn=None,
+                     clear_btn="Clear conversation"
+                 )
+
+         return interface
+
+ # Create the application instance
+ app = NovelAIApp()
+ interface = app.create_interface()
+ interface.launch(server_name="0.0.0.0", share=False)
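
Note that this commit touches only `app.py`; `train.py` is not included, so the `ModelTrainer` contract is visible only through the three calls made above: `ModelTrainer(model_name, prompts_path)`, `prepare_dataset(files)`, and `train(dataset)`. The following is a minimal, hypothetical stub that satisfies that interface; it is a sketch of assumptions, not the actual training code in this repository:

```python
# Hypothetical sketch only: train.py is not part of this commit, and the real
# ModelTrainer may differ. The shape below is inferred from how app.py calls it.
import json


class ModelTrainer:
    def __init__(self, base_model_name, system_prompts_path):
        # Base model id (e.g. "CohereForAI/c4ai-command-r-plus-08-2024") and
        # the same system-prompt config that app.py loads.
        self.base_model_name = base_model_name
        with open(system_prompts_path, "r", encoding="utf-8") as f:
            self.system_prompts = json.load(f)

    def prepare_dataset(self, files):
        # app.py passes the gr.File upload result; a real implementation would
        # read and tokenize the files. Here we only normalize them to paths.
        return [getattr(f, "name", f) for f in (files or [])]

    def train(self, dataset):
        # Placeholder for the actual fine-tuning loop (e.g. LoRA on the base
        # model); intentionally a no-op in this sketch.
        print(f"Would fine-tune {self.base_model_name} on {len(dataset)} file(s).")
```

Under that assumed interface, the `train_model` handler above runs end to end, with the `gr.File` upload list flowing into `prepare_dataset` and the result passed to `train`.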