File size: 3,728 Bytes
298e414
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3423340
 
 
 
298e414
 
3423340
298e414
 
 
 
 
 
 
ca56c80
298e414
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c265738
298e414
 
 
 
 
46fbd30
298e414
 
 
 
 
 
 
46fbd30
 
 
 
 
503cd4f
46fbd30
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import os
from train import ModelTrainer

class NovelAIApp:
    """Gradio app offering a chat tab backed by a causal LM and a training tab.

    The chat tab only produces model output after ``load_model`` has been
    called; until then ``generate_text`` returns a "load the model first" notice.
    NOTE(review): nothing in this file's launch code calls ``load_model`` —
    confirm how the model is expected to be loaded.
    """

    def __init__(self, prompts_path="configs/system_prompts.json",
                 base_model="THUDM/chatglm2-6b"):
        """Create the app and read the system-prompt JSON.

        Args:
            prompts_path: Path to the system-prompt JSON file; also handed to
                the trainer.
            base_model: Model id passed to ``ModelTrainer`` when training starts.
        """
        self.model = None
        self.tokenizer = None
        self.trainer = None
        self.prompts_path = prompts_path
        self.base_model = base_model

        # Load system prompts; only the "base_prompt" key is read in this class.
        with open(prompts_path, 'r', encoding='utf-8') as f:
            self.system_prompts = json.load(f)

        # Default mood/scenario; not read anywhere else in this class.
        self.current_mood = "暗示"

    def load_model(self, model_path):
        """Load the tokenizer and an 8-bit, auto-device-mapped model from *model_path*."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True,
            load_in_8bit=True,
            device_map="auto"
        )

    def train_model(self, files):
        """Train on the uploaded files and return a status message for the UI.

        Args:
            files: Uploaded files from the ``gr.File`` component; ``None`` or
                empty when nothing was uploaded.

        Returns:
            A status string shown in the training-status textbox.

        NOTE(review): this does not update ``self.model``; whether trained
        weights are reloaded depends on ``ModelTrainer`` (not visible here).
        """
        # gr.File yields None when the user clicks train without uploading.
        if not files:
            return "请先上传训练文件!"

        # Create the trainer lazily, and only once.
        if self.trainer is None:
            self.trainer = ModelTrainer(self.base_model, self.prompts_path)

        dataset = self.trainer.prepare_dataset(files)
        self.trainer.train(dataset)
        return "训练完成!"

    @staticmethod
    def _format_history(history):
        """Render prior chat turns into the tagged prompt format.

        Accepts both Gradio history shapes: ``type="messages"`` (a list of
        ``{"role": ..., "content": ...}`` dicts, as configured in
        ``create_interface``) and the legacy list of ``[user, assistant]`` pairs.
        Non-text message content and ``None`` entries are skipped.
        """
        parts = []
        for turn in history or []:
            if isinstance(turn, dict):
                role = turn.get("role")
                content = turn.get("content")
                # Skip system/tool messages and non-text (file/component) content.
                if role not in ("user", "assistant") or not isinstance(content, str):
                    continue
                parts.append(f"<|{role}|>{content}</|{role}|>\n")
            else:
                user_msg, bot_msg = turn[0], turn[1]
                if user_msg is not None:
                    parts.append(f"<|user|>{user_msg}</|user|>\n")
                if bot_msg is not None:
                    parts.append(f"<|assistant|>{bot_msg}</|assistant|>\n")
        return "".join(parts)

    @classmethod
    def _build_prompt(cls, system_prompt, history, message):
        """Assemble the full prompt: system block, prior turns, the new user turn,
        and an open assistant tag for the model to continue from."""
        return (
            f"<|system|>{system_prompt}</|system|>\n"
            f"{cls._format_history(history)}<|user|>{message}</|user|>\n"
            "<|assistant|>"
        )

    @staticmethod
    def _extract_reply(text):
        """Cut generated text at the end of the assistant turn.

        This drops any further turns the model may have hallucinated.
        """
        for stop in ("</|assistant|>", "<|user|>"):
            text = text.split(stop, 1)[0]
        return text.strip()

    def generate_text(self, message, history):
        """Generate one assistant reply; the signature matches ``gr.ChatInterface``.

        Args:
            message: The latest user message.
            history: Prior turns as supplied by Gradio (see ``_format_history``).

        Returns:
            The reply text, or a notice string when no model is loaded.
        """
        if self.model is None or self.tokenizer is None:
            return "请先加载模型!"

        # `or ""` keeps a missing/null key from being rendered as "None".
        system_prompt = self.system_prompts.get("base_prompt") or ""
        prompt = self._build_prompt(system_prompt, history, message)

        # With device_map="auto" the weights live on the accelerator, while the
        # tokenizer returns CPU tensors; move the inputs to the model's device.
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        prompt_len = inputs["input_ids"].shape[-1]

        outputs = self.model.generate(
            inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            # Budget for the reply itself. max_length would count the prompt
            # and leave no room once the history grows.
            max_new_tokens=512,
            # Without do_sample, temperature/top_p are ignored (greedy decoding).
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
        )

        # Decode only the newly generated tokens instead of re-parsing the prompt.
        reply = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        return self._extract_reply(reply)

    def create_interface(self):
        """Build and return the Gradio Blocks UI (training tab + chat tab)."""
        # The theme goes on the outer Blocks: a theme passed to a ChatInterface
        # rendered inside another Blocks does not take effect.
        with gr.Blocks(theme="soft") as interface:
            gr.Markdown("# 猫娘对话助手")

            with gr.Tab("模型训练"):
                file_output = gr.File(
                    file_count="multiple",
                    label="上传小说文本文件"
                )
                train_button = gr.Button("开始训练")
                train_output = gr.Textbox(label="训练状态")

                train_button.click(
                    fn=self.train_model,
                    inputs=[file_output],
                    outputs=[train_output]
                )

            with gr.Tab("对话"):
                # type="messages": history arrives as role/content dicts.
                gr.ChatInterface(
                    fn=self.generate_text,
                    title="与猫娘对话",
                    description="来和可爱的猫娘聊天吧~",
                    examples=["今天天气真好呢", "你在做什么呢?", "要不要一起玩?"],
                    cache_examples=False,
                    type="messages"
                )

        return interface

# Build the application object and its Gradio UI at import time.
app = NovelAIApp()
interface = app.create_interface()

# Launch settings.
# NOTE(review): listening on 0.0.0.0 together with share=True exposes the UI
# (including the training tab) to anyone who has the link, and no auth is
# configured here — confirm this is intended.
launch_kwargs = dict(
    server_name="0.0.0.0",  # listen on all network interfaces (allow external access)
    share=True,             # request a public share link
    ssl_verify=False,       # skip SSL certificate verification
)
interface.launch(**launch_kwargs)