|
import gradio as gr
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
import json
|
|
import os
|
|
from train import ModelTrainer
|
|
|
|
class NovelAIApp:
    """Gradio front end for a mood-aware role-play chat model.

    The instance holds the loaded model/tokenizer, a lazily created
    ``ModelTrainer``, the system prompts read from
    ``configs/system_prompts.json`` and the current conversational mood.

    NOTE(review): ``generate_text`` refuses to run until ``load_model`` has
    been called, and no caller of ``load_model`` is visible in this file.
    """

    # Substrings that vote for a mood when found in the user's message.
    # Insertion order matters: it breaks ties in ``_detect_mood`` and is the
    # order of the choices shown in the mood dropdown.
    _MOOD_KEYWORDS = {
        "撒娇": ["想你", "抱抱", "亲亲", "摸摸"],
        "害羞": ["害羞", "不好意思", "脸红"],
        "粘人": ["陪我", "一起", "不要走"],
        "暗示": ["热", "难受", "想要", "摸"],
        "调皮": ["玩", "闹", "捣乱", "逗"],
        "吃醋": ["别人", "不理我", "生气"],
    }

    # Sampling temperature used for each mood.
    _MOOD_TEMPERATURES = {
        "撒娇": 0.8,
        "害羞": 0.6,
        "粘人": 0.7,
        "暗示": 0.9,
        "调皮": 0.85,
        "吃醋": 0.75,
    }

    # Temperature used when the current mood has no entry above.
    _DEFAULT_TEMPERATURE = 0.7

    def __init__(self):
        """Read the system prompts and initialise empty model/mood state.

        Raises:
            OSError: if ``configs/system_prompts.json`` cannot be opened.
            json.JSONDecodeError: if the file is not valid JSON.
        """
        self.model = None
        self.tokenizer = None
        self.trainer = None

        with open('configs/system_prompts.json', 'r', encoding='utf-8') as f:
            self.system_prompts = json.load(f)

        self.current_mood = "暗示"
        self.mood_history = []

    def load_model(self, model_path):
        """Load the tokenizer and an 8-bit quantised model from ``model_path``.

        Args:
            model_path: Path or hub identifier passed to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True,
            load_in_8bit=True,
            device_map="auto",
        )

    def train_model(self, files):
        """Train on the uploaded files and return a status message.

        The trainer is created on first use and reused afterwards.

        Args:
            files: Uploaded files, forwarded unchanged to
                ``ModelTrainer.prepare_dataset``.

        Returns:
            The status string shown in the UI once training has finished.
        """
        if self.trainer is None:
            self.trainer = ModelTrainer(
                "THUDM/chatglm2-6b",
                "configs/system_prompts.json",
            )

        dataset = self.trainer.prepare_dataset(files)
        self.trainer.train(dataset)
        return "训练完成!"

    def generate_text(self, message, history, mood=None):
        """Generate the assistant reply for ``message``.

        Args:
            message: The latest user message.
            history: Previous turns as ``(user, assistant)`` pairs.
            mood: Mood to use for this turn. When falsy, the mood is inferred
                from ``message`` by ``_detect_mood``.

        Returns:
            The generated reply, or a hint string when no model is loaded.
        """
        if self.model is None:
            return "请先加载模型!"

        self.current_mood = mood if mood else self._detect_mood(message)

        prompt = self._build_prompt(message, history)
        temperature = self._get_dynamic_temperature()

        encoded = self.tokenizer(prompt, return_tensors="pt")
        # The model is loaded with device_map="auto", so the prompt tensor
        # must be moved to the model's device before generation.
        input_ids = encoded["input_ids"].to(self.model.device)

        # NOTE(review): max_length counts prompt tokens as well, so a long
        # history leaves little or no room for the reply.
        outputs = self.model.generate(
            input_ids,
            max_length=1024,
            temperature=temperature,
            top_p=0.9,
            repetition_penalty=1.1,
            do_sample=True,
            num_return_sequences=1,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # Decode only the newly generated tokens; the output sequence starts
        # with the prompt, which must not leak into the reply.
        new_tokens = outputs[0][input_ids.shape[-1]:]
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        response = response.split("<|assistant|>")[-1].strip()

        self.mood_history.append(self.current_mood)
        return response

    def _build_prompt(self, message, history):
        """Assemble the system prompt, past turns and new message into one prompt."""
        system_prompt = (
            f"<|system|>{self.system_prompts['base_prompt']}\n"
            f"当前情境:{self.current_mood}</|system|>"
        )
        full_history = "".join(
            f"<|user|>{turn[0]}</|user|>\n<|assistant|>{turn[1]}</|assistant|>\n"
            for turn in history
        )
        return f"{system_prompt}\n{full_history}<|user|>{message}</|user|>\n<|assistant|>"

    def _detect_mood(self, message):
        """Infer the mood from keywords contained in ``message``.

        Each mood scores one point per keyword found. The highest-scoring mood
        wins, with ties going to the mood listed first. When nothing matches,
        the current mood is kept.
        """
        scores = {
            mood: sum(keyword in message for keyword in keywords)
            for mood, keywords in self._MOOD_KEYWORDS.items()
        }
        best_mood = max(scores, key=scores.get)
        if scores[best_mood] == 0:
            return self.current_mood
        return best_mood

    def _get_dynamic_temperature(self):
        """Return the sampling temperature for the current mood."""
        return self._MOOD_TEMPERATURES.get(
            self.current_mood, self._DEFAULT_TEMPERATURE
        )

    def create_interface(self):
        """Build and return the Gradio UI with a training tab and a chat tab."""
        example_messages = [
            "今天好想你呀~",
            "主人在做什么呢?",
            "要不要一起玩?",
            "人家身体有点奇怪...",
            "主人不要理别人啦!",
        ]

        with gr.Blocks() as interface:
            gr.Markdown("# 猫娘对话助手")

            with gr.Tab("模型训练"):
                with gr.Row():
                    file_output = gr.File(
                        file_count="multiple",
                        label="上传小说文本文件",
                    )
                    train_button = gr.Button("开始训练")

                train_output = gr.Textbox(label="训练状态")

                # Without this handler the button does nothing.
                train_button.click(
                    fn=self.train_model,
                    inputs=file_output,
                    outputs=train_output,
                )

            with gr.Tab("对话"):
                with gr.Row():
                    mood_selector = gr.Dropdown(
                        choices=list(self._MOOD_KEYWORDS),
                        value=self.current_mood,
                        label="选择当前情境",
                    )

                # A component's ``.value`` is only its initial value, so the
                # dropdown is passed as an additional input to receive the
                # user's live selection on every call.
                gr.ChatInterface(
                    fn=self.generate_text,
                    additional_inputs=[mood_selector],
                    title="与猫娘对话",
                    description="来和可爱的猫娘聊天吧~",
                    theme="soft",
                    # With additional inputs each example needs one value per
                    # input; the initial mood is used for the dropdown column.
                    examples=[[text, self.current_mood] for text in example_messages],
                    cache_examples=False,
                )

        return interface
|
|
|
|
|
|
# Build the application and its UI, then start serving.
app = NovelAIApp()
interface = app.create_interface()

# NOTE(review): this binds to all network interfaces, requests a public share
# link and disables SSL verification -- confirm that exposure is intended.
_launch_options = {
    "server_name": "0.0.0.0",
    "share": True,
    "ssl_verify": False,
}
interface.launch(**_launch_options)