金融聊天机器人
使用方式
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# 设置环境变量,解决OpenMP错误
# 这个环境变量设置允许程序在检测到多个OpenMP库时继续运行,避免出现冲突错误
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# 设置模型和数据路径
MODEL_PATH = "Fintech-Dreamer/FinSynth_model_chatbot"
def generate_response(model, tokenizer, instruction, input_text, max_length=512):
"""
使用模型生成客服回答
参数:
model: 加载的语言模型实例
tokenizer: 模型对应的分词器
instruction: 指令部分文本,一般是客户的问题
input_text: 输入文本,作为参考上下文或背景信息
max_length: 生成文本的最大长度,默认为512个token
返回:
prompt: 完整的输入提示词
response: 模型生成的回答内容(仅包含模型生成部分,不包含输入提示词)
"""
# 构造提示词格式 - 使用特殊标记组织对话形式
# <|begin of sentence|>标记句子开始,和/分别标记用户和助手角色
# 这种特殊标记格式是某些模型预训练时使用的对话格式,需要严格遵循
prompt = f"<|begin of sentence|> {instruction}\n{input_text} <|Assistant|>"
# 编码输入,将文本转换为模型可以理解的token序列
# add_special_tokens=True确保添加特殊标记如开始和结束标记
# truncation=True确保输入不超过模型的最大处理长度
# padding=True确保所有输入长度一致
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True, add_special_tokens=True)
# 将输入移动到模型所在的设备(CPU/GPU)
# 这确保了模型和输入在同一设备上,避免跨设备操作导致的错误
inputs = inputs.to(model.device)
# 使用torch.no_grad()避免计算梯度,节省内存并加速推理过程
# 在推理阶段不需要计算梯度,这可以显著减少内存使用并提高速度
with torch.no_grad():
# 调用模型的generate方法生成回答
# 这里设置了多个生成参数来控制输出的质量和特性
outputs = model.generate(
**inputs,
max_length=max_length, # 设置生成文本的最大长度
num_return_sequences=1, # 只返回一个生成序列
do_sample=True, # 使用采样策略,增加多样性
temperature=0.6, # 温度参数,控制生成文本的随机性(与模型配置一致)
top_p=0.95, # 使用nucleus sampling,只考虑概率和超过0.95的token(与模型配置一致)
top_k=20, # 只考虑概率最高的20个token,增加生成文本的可控性
repetition_penalty=1.1, # 重复惩罚系数,降低模型重复同一内容的可能性
pad_token_id=151643, # 填充标记ID(与模型配置中的eos_token_id一致)
bos_token_id=151646, # 句子开始标记ID(与模型配置一致)
eos_token_id=151643, # 句子结束标记ID(与模型配置一致)
use_cache=True, # 使用缓存加速生成过程
)
# 将生成的token序列解码为文本
# skip_special_tokens=True会跳过特殊标记,只保留实际文本内容
# clean_up_tokenization_spaces=True会清理分词过程中产生的额外空格
full_response = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
# 分离模型输入和输出
# 如果在完整响应中找到了助手标记后的内容,则提取出来
# 否则尝试找出与输入不同的部分作为输出
if "<|Assistant|>" in full_response:
response = full_response.split("<|Assistant|>")[1].strip()
else:
input_without_assistant = prompt.split("<|Assistant|>")[0]
if full_response.startswith(input_without_assistant):
response = full_response[len(input_without_assistant) :].strip()
else:
response = "[无法分离模型生成内容] " + full_response
# 返回两个独立的结果:输入提示词和模型生成的回答
return prompt, response
def process_test_data():
"""
处理测试数据集并生成客服回答
功能:
- 加载客服问答测试数据集
- 初始化模型和分词器
- 对每个测试样本生成客服回答
- 清晰区分并打印模型的输入提示词和输出结果
返回:
None,结果直接打印
"""
# 加载测试数据
# 加载模型和分词器
print(f"加载模型: {MODEL_PATH}")
print("正在加载分词器...")
# 加载预训练的分词器,使用local_files_only=True确保只从本地加载
# 分词器负责将文本转换为数字token序列,这是模型处理文本的第一步
tokenizer = AutoTokenizer.from_pretrained(
MODEL_PATH,
trust_remote_code=True, # 允许使用模型自定义的代码
padding_side="left", # 在左侧进行填充,适合生成任务
truncation_side="left", # 在左侧进行截断,保留最新的内容
)
print("正在加载模型...")
# 加载预训练的语言模型,同样使用local_files_only=True
# 模型是实际执行推理的部分,加载到合适的设备(CPU/GPU)上
model = AutoModelForCausalLM.from_pretrained(
MODEL_PATH,
trust_remote_code=True, # 允许使用模型自定义的代码
device_map="auto", # 自动选择可用的设备(CPU/GPU)
torch_dtype=torch.bfloat16, # 使用bfloat16精度,在保持准确性的同时减少内存占用
use_cache=True, # 启用缓存以提高生成速度
)
# 设置模型为评估模式,关闭dropout等训练特性,提高推理性能
# 评估模式下模型行为更加确定,适合推理任务
model.eval()
# 处理每个测试样本
print("开始生成客服回答...")
try:
# 提取指令和输入文本
instruction = "What types of grants are included in the full grant date fair value calculation?" # 指令部分,通常是客户问题
input_text = "(1) Amounts shown in this column do not reflect dollar amounts actually received by the NEO. Instead, these amounts reflect the aggregate full grant date fair value calculated in accordance with ASC 718 for the respective fiscal year for grants of RSUs, SY PSUs, and MY PSUs, as applicable. The assumptions used in the calculation of values of the awards are set forth under Note 4 to our consolidated financial statements titled Stock-Based Compensation in our Form 10-K. With regard to the stock awards with performance-based vesting conditions, the reported grant date fair value assumes the probable outcome of the conditions at Base Compensation Plan for SY PSUs and MY PSUs, determined in accordance with applicable accounting standards."
# 生成回答
print("\n===== 模型预测 =====")
print("\n正在生成客服回答...")
# 构造完整提示并生成回答
prompt, response = generate_response(model, tokenizer, instruction, input_text)
# 清晰区分模型输入和输出(只打印一次)
print("\n\n==================== 模型输入 ====================")
print(prompt)
print("\n\n==================== 模型输出(仅包含生成部分)====================")
print(response)
except Exception as e:
# 异常处理,确保一个样本的错误不会导致整个程序崩溃
# 这对于批量处理多个样本时非常重要
print(f"\n处理样本时出错: {str(e)}")
import traceback
traceback.print_exc() # 打印详细错误信息,便于调试
print("\n客服回答生成完成!")
return None
def main():
"""
主函数,程序入口点
功能:
- 启动客服聊天机器人测试流程
- 处理测试数据并生成回答
- 控制整个程序的执行流程
"""
print("===== 客服聊天机器人模型调用 =====")
process_test_data()
if __name__ == "__main__":
main()
数据集参考
Fintech-Dreamer/FinSynth_data · Datasets at Hugging Face
前端框架参考
数据处理方式参考
- Downloads last month
- 46
Inference Providers
NEW
This model is not currently available via any of the supported Inference Providers.
Model tree for Fintech-Dreamer/FinSynth_model_chatbot
Base model
deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B