|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
import json |
|
import os |
|
from train import ModelTrainer |
|
|
|
class NovelAIApp:
    """Gradio front end that fine-tunes and chats with a causal language model.

    Holds a Hugging Face tokenizer/model pair plus a lazily created
    ``ModelTrainer`` and exposes them through a Gradio UI with a training tab
    and a chat tab.
    """

    # Resolved relative to the current working directory (as before).
    SYSTEM_PROMPTS_PATH = "configs/system_prompts.json"
    # Base checkpoint handed to ModelTrainer and pre-filled in the load box.
    DEFAULT_BASE_MODEL = "THUDM/chatglm2-6b"
    # Tokens generated per reply, NOT counting the prompt, so long chat
    # histories cannot consume the whole generation budget.
    MAX_NEW_TOKENS = 512

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.trainer = None

        with open(self.SYSTEM_PROMPTS_PATH, "r", encoding="utf-8") as f:
            # presumably a dict of named prompts; only "base_prompt" is read
            # in this class (see generate_text) — TODO confirm schema
            self.system_prompts = json.load(f)

        # Not read anywhere in this class; kept so existing callers still work.
        self.current_mood = "暗示"

    def load_model(self, model_path):
        """Load the tokenizer and an 8-bit quantized model from ``model_path``.

        Args:
            model_path: Hub id or local directory accepted by
                ``from_pretrained``.

        Returns:
            A status string, so the method can be wired directly to a
            Gradio button.
        """
        # Local import: only needed when a model is actually loaded.
        from transformers import BitsAndBytesConfig

        if isinstance(model_path, str):
            model_path = model_path.strip()
        if not model_path:
            return "请输入模型路径!"

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True,
            # The bare ``load_in_8bit=`` kwarg is deprecated in transformers;
            # the quantization config is the supported spelling.
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            device_map="auto",
        )
        return "模型加载完成!"

    def train_model(self, files):
        """Fine-tune the base model on the uploaded novel text files.

        Args:
            files: Value of the multi-file ``gr.File`` component; ``None`` or
                empty when nothing was uploaded.

        Returns:
            A status message for the training-status textbox.
        """
        if not files:
            return "请先上传训练文件!"

        if self.trainer is None:
            self.trainer = ModelTrainer(
                self.DEFAULT_BASE_MODEL,
                self.SYSTEM_PROMPTS_PATH,
            )

        dataset = self.trainer.prepare_dataset(files)
        self.trainer.train(dataset)
        # NOTE(review): training does not update self.model; the chat tab
        # keeps using whatever load_model loaded — confirm this is intended.
        return "训练完成!"

    @staticmethod
    def _format_history(history):
        """Render earlier chat turns using ``<|role|>...</|role|>`` tags.

        Accepts both history shapes Gradio produces: ``{"role", "content"}``
        dicts (``type="messages"``) and legacy ``[user, assistant]`` pairs.
        """
        parts = []
        for turn in history or []:
            if isinstance(turn, dict):
                role = turn.get("role")
                content = turn.get("content")
                # Skip system/tool entries and non-text (e.g. file) content.
                if role in ("user", "assistant") and isinstance(content, str):
                    parts.append(f"<|{role}|>{content}</|{role}|>\n")
            else:
                user_msg, bot_msg = turn[0], turn[1]
                # A pending reply is None in the pair format.
                bot_text = "" if bot_msg is None else bot_msg
                parts.append(
                    f"<|user|>{user_msg}</|user|>\n"
                    f"<|assistant|>{bot_text}</|assistant|>\n"
                )
        return "".join(parts)

    def generate_text(self, message, history):
        """Generate the assistant reply for ``gr.ChatInterface``.

        Args:
            message: The latest user message.
            history: Earlier turns, as role/content dicts or user/assistant
                pairs (see ``_format_history``).

        Returns:
            The reply text, or a hint to load a model first.
        """
        if self.model is None or self.tokenizer is None:
            return "请先加载模型!"

        # A missing key would otherwise render the literal "None".
        system_prompt = self.system_prompts.get("base_prompt") or ""
        # NOTE(review): tag layout assumed to match the format ModelTrainer
        # trains on — confirm against train.py.
        formatted_prompt = (
            f"<|system|>{system_prompt}</|system|>\n"
            f"{self._format_history(history)}<|user|>{message}</|user|>\n"
            "<|assistant|>"
        )

        inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
        device = self.model.device
        input_ids = inputs["input_ids"].to(device)
        extra_kwargs = {}
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            extra_kwargs["attention_mask"] = attention_mask.to(device)

        outputs = self.model.generate(
            input_ids,
            max_new_tokens=self.MAX_NEW_TOKENS,
            # Without do_sample, temperature/top_p are ignored (greedy).
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
            **extra_kwargs,
        )

        # generate() returns prompt + continuation; decode only the new part.
        new_tokens = outputs[0][input_ids.shape[-1]:]
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        # Drop anything the model emits past its own turn.
        response = response.split("<|assistant|>")[-1]
        response = response.split("</|assistant|>")[0]
        return response.strip()

    def create_interface(self):
        """Build the Gradio UI: a training tab and a chat tab."""
        with gr.Blocks() as interface:
            gr.Markdown("# 猫娘对话助手")

            with gr.Tab("模型训练"):
                file_output = gr.File(
                    file_count="multiple",
                    label="上传小说文本文件",
                )
                train_button = gr.Button("开始训练")
                train_output = gr.Textbox(label="训练状态")
                train_button.click(
                    fn=self.train_model,
                    inputs=[file_output],
                    outputs=[train_output],
                )

            with gr.Tab("对话"):
                # generate_text refuses to answer until a model is loaded, so
                # the chat tab needs a way to call load_model.
                with gr.Row():
                    model_path_box = gr.Textbox(
                        value=self.DEFAULT_BASE_MODEL,
                        label="模型路径",
                    )
                    load_button = gr.Button("加载模型")
                load_output = gr.Textbox(label="加载状态")
                load_button.click(
                    fn=self.load_model,
                    inputs=[model_path_box],
                    outputs=[load_output],
                )

                gr.ChatInterface(
                    fn=self.generate_text,
                    title="与猫娘对话",
                    description="来和可爱的猫娘聊天吧~",
                    theme="soft",
                    examples=["今天天气真好呢", "你在做什么呢?", "要不要一起玩?"],
                    cache_examples=False,
                    # History arrives as {"role", "content"} dicts.
                    type="messages",
                )

        return interface
|
|
|
|
|
app = NovelAIApp()
interface = app.create_interface()

# Only start the (blocking) server when run as a script, so importing this
# module does not hang on launch().
if __name__ == "__main__":
    interface.launch(
        # Binds on all network interfaces, not just localhost.
        server_name="0.0.0.0",
        # Also creates a public Gradio share link — anyone with the URL can
        # reach the app; review before exposing a loaded model this way.
        share=True,
        # Skips certificate validation (allows self-signed certs when SSL
        # files are configured).
        ssl_verify=False,
    )