# nekoa / app.py
# jljiu's picture
# Update app.py
# 503cd4f verified
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import os
from train import ModelTrainer
class NovelAIApp:
    """Gradio app that trains a model on uploaded text files and chats with a loaded model.

    Attributes:
        model: causal-LM set by ``load_model``; ``None`` until then.
        tokenizer: tokenizer set by ``load_model``; ``None`` until then.
        trainer: ``ModelTrainer`` created lazily on the first training request.
        system_prompts: object parsed from ``configs/system_prompts.json``;
            ``generate_text`` reads its ``"base_prompt"`` key.
        current_mood: initial situation label; not read anywhere in this file.
    """

    # Upper bound on tokens generated per reply. Counting only new tokens
    # (instead of total length, as max_length does) keeps replies possible
    # as the conversation history grows.
    max_new_tokens = 512

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.trainer = None
        # Load the system prompts (path is relative to the working directory).
        with open('configs/system_prompts.json', 'r', encoding='utf-8') as f:
            self.system_prompts = json.load(f)
        # Initial default situation/mood.
        self.current_mood = "暗示"

    def load_model(self, model_path):
        """Load the tokenizer and an 8-bit causal-LM from ``model_path``.

        Args:
            model_path: Hugging Face Hub id or local directory of the model.

        NOTE(review): nothing in this file calls this method, so the chat tab
        keeps answering "please load the model first" until something does.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True,
            load_in_8bit=True,
            device_map="auto"
        )

    def train_model(self, files):
        """Train on the uploaded files and return a status message for the UI.

        Args:
            files: files from the multi-file ``gr.File`` component; ``None`` or
                empty when the button is pressed without uploading anything.

        Returns:
            A status string shown in the training-status textbox.
        """
        if not files:
            # Don't hand None/[] to the trainer when nothing was uploaded.
            return "请先上传训练文件!"
        if self.trainer is None:
            self.trainer = ModelTrainer(
                "THUDM/chatglm2-6b",
                "configs/system_prompts.json"
            )
        dataset = self.trainer.prepare_dataset(files)
        self.trainer.train(dataset)
        # NOTE(review): the trained weights are not loaded into self.model here;
        # confirm how ModelTrainer persists them and wire that into load_model.
        return "训练完成!"

    @staticmethod
    def _history_pairs(history):
        """Normalize chat history into a list of ``(user, assistant)`` text pairs.

        Accepts both Gradio history formats:

        - "messages": dicts with ``role`` and ``content`` keys (this is what the
          ``ChatInterface(type="messages")`` built below passes in);
        - legacy "tuples": ``[user, assistant]`` sequences.

        A missing side of a turn becomes ``""``. Roles other than "user" and
        "assistant" are skipped, because the system prompt is added separately.
        """
        pairs = []
        for msg in history or []:
            if not isinstance(msg, dict):
                pairs.append((msg[0] or "", msg[1] or ""))
                continue
            role = msg.get("role")
            content = msg.get("content") or ""
            if role == "user":
                pairs.append((content, ""))
            elif role == "assistant":
                if pairs and not pairs[-1][1]:
                    # Attach the reply to the most recent unanswered user turn.
                    pairs[-1] = (pairs[-1][0], content)
                else:
                    pairs.append(("", content))
        return pairs

    def generate_text(self, message, history):
        """Generate the assistant's reply; the signature matches ``gr.ChatInterface``'s ``fn``.

        Args:
            message: the user's latest message.
            history: prior turns, in Gradio "messages" or legacy "tuples" format.

        Returns:
            The reply text, or "请先加载模型!" ("please load the model first")
            when no model has been loaded.
        """
        if self.model is None:
            return "请先加载模型!"
        system_prompt = self.system_prompts.get("base_prompt", "")
        # Build the full conversation history in the app's tag format.
        full_history = "".join(
            f"<|user|>{user}</|user|>\n<|assistant|>{assistant}</|assistant|>\n"
            for user, assistant in self._history_pairs(history)
        )
        formatted_prompt = (
            f"<|system|>{system_prompt}</|system|>\n"
            f"{full_history}<|user|>{message}</|user|>\n"
            "<|assistant|>"
        )
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
        # The tokenizer returns CPU tensors, while device_map="auto" may have put
        # the model on an accelerator, so move the ids to the model's device.
        input_ids = inputs["input_ids"].to(self.model.device)
        outputs = self.model.generate(
            input_ids,
            max_new_tokens=self.max_new_tokens,
            # Without sampling, generate() decodes greedily and ignores
            # temperature/top_p (unless the model's generation config enables it).
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
        )
        # Decode only the newly generated tokens so the prompt is never echoed.
        new_tokens = outputs[0][input_ids.shape[-1]:]
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        # Cut off any further turn the model invents after its own reply.
        for stop in ("</|assistant|>", "<|user|>"):
            response = response.split(stop, 1)[0]
        return response.strip()

    def create_interface(self):
        """Build the Gradio UI with a training tab and a chat tab.

        Returns:
            The ``gr.Blocks`` app; the caller is responsible for launching it.
        """
        with gr.Blocks() as interface:
            gr.Markdown("# 猫娘对话助手")
            with gr.Tab("模型训练"):
                file_output = gr.File(
                    file_count="multiple",
                    label="上传小说文本文件",
                )
                train_button = gr.Button("开始训练")
                train_output = gr.Textbox(label="训练状态")
                train_button.click(
                    fn=self.train_model,
                    inputs=[file_output],
                    outputs=[train_output],
                )
            with gr.Tab("对话"):
                # type="messages" delivers history as role/content dicts;
                # generate_text normalizes it via _history_pairs.
                gr.ChatInterface(
                    fn=self.generate_text,
                    title="与猫娘对话",
                    description="来和可爱的猫娘聊天吧~",
                    theme="soft",
                    examples=["今天天气真好呢", "你在做什么呢?", "要不要一起玩?"],
                    cache_examples=False,
                    type="messages",
                )
        return interface
# Create the application instance and build its UI. These stay at module
# level so the `app` / `interface` names remain importable.
app = NovelAIApp()
interface = app.create_interface()

# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch(
        server_name="0.0.0.0",  # listen on all interfaces (allow external access)
        share=True,  # request a public share link
        # NOTE(review): per Gradio docs, ssl_verify only controls certificate
        # validation when SSL cert/key files are supplied; none are here.
        ssl_verify=False,
    )