Commit
·
a5a2048
1
Parent(s):
893d32f
🔧 修复FlashAttention2错误和场景加载问题
Browse files🐛 主要修复:
1. FlashAttention2 兼容性问题
- 添加自动回退机制:优先尝试 FlashAttention2,失败则使用标准SDPA
- 修改默认注意力实现为 'sdpa',适配HF Spaces环境
- 添加友好的日志提示,明确使用的注意力机制
2. 场景下拉菜单问题
- 使用预定义场景列表,确保界面稳定加载
- 修复动态获取场景时的竞态条件问题
- 特殊处理'默认示例'选项,直接调用默认音频加载
✨ 改进:
- 增强错误处理和用户反馈
- 优化场景加载逻辑,支持多种场景类型
- 确保即使在场景文件缺失的情况下也能正常工作
🎯 效果:
- 解决 ImportError: FlashAttention2 安装问题
- 修复场景选择器显示异常
- 提升 Space 在 CPU/GPU 环境下的兼容性
- 确保所有预设场景都能正确加载
- __pycache__/app.cpython-313.pyc +0 -0
- app.py +23 -5
- generation_utils.py +8 -2
__pycache__/app.cpython-313.pyc
ADDED
|
Binary file (25.7 kB). View file
|
|
|
app.py
CHANGED
|
@@ -433,13 +433,21 @@ def create_space_ui() -> gr.Blocks:
|
|
| 433 |
|
| 434 |
with gr.Group():
|
| 435 |
gr.Markdown("### 🚀 快速操作")
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
|
| 440 |
scenario_dropdown = gr.Dropdown(
|
| 441 |
-
choices=
|
| 442 |
-
value=
|
| 443 |
label="🎭 选择场景",
|
| 444 |
info="选择一个预设场景,自动填充对话文本和参考音频"
|
| 445 |
)
|
|
@@ -514,6 +522,16 @@ def create_space_ui() -> gr.Blocks:
|
|
| 514 |
gr.Warning("⚠️ 请先选择一个场景")
|
| 515 |
return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
|
| 516 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 517 |
scenarios = get_scenario_examples()
|
| 518 |
if name not in scenarios:
|
| 519 |
gr.Error(f"❌ 场景不存在: {name}")
|
|
|
|
| 433 |
|
| 434 |
with gr.Group():
|
| 435 |
gr.Markdown("### 🚀 快速操作")
|
| 436 |
+
|
| 437 |
+
# 预定义场景选项,确保界面稳定
|
| 438 |
+
predefined_scenarios = [
|
| 439 |
+
"🎧 默认示例",
|
| 440 |
+
"🤖 科技播客 - AI发展趋势",
|
| 441 |
+
"📚 教育播客 - 高效学习方法",
|
| 442 |
+
"🍜 生活播客 - 美食文化探索",
|
| 443 |
+
"💼 商业播客 - 创业经验分享",
|
| 444 |
+
"🏃 健康播客 - 运动健身指南",
|
| 445 |
+
"🧠 心理播客 - 情绪管理技巧"
|
| 446 |
+
]
|
| 447 |
|
| 448 |
scenario_dropdown = gr.Dropdown(
|
| 449 |
+
choices=predefined_scenarios,
|
| 450 |
+
value=predefined_scenarios[0],
|
| 451 |
label="🎭 选择场景",
|
| 452 |
info="选择一个预设场景,自动填充对话文本和参考音频"
|
| 453 |
)
|
|
|
|
| 522 |
gr.Warning("⚠️ 请先选择一个场景")
|
| 523 |
return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
|
| 524 |
|
| 525 |
+
# 处理默认示例的特殊情况
|
| 526 |
+
if name == "🎧 默认示例":
|
| 527 |
+
try:
|
| 528 |
+
result = load_default_audio()
|
| 529 |
+
gr.Info("✅ 成功加载默认示例")
|
| 530 |
+
return result
|
| 531 |
+
except Exception as e:
|
| 532 |
+
gr.Error(f"❌ 加载默认示例时出错: {str(e)}")
|
| 533 |
+
return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
|
| 534 |
+
|
| 535 |
scenarios = get_scenario_examples()
|
| 536 |
if name not in scenarios:
|
| 537 |
gr.Error(f"❌ 场景不存在: {name}")
|
generation_utils.py
CHANGED
|
@@ -12,10 +12,16 @@ from XY_Tokenizer.xy_tokenizer.model import XY_Tokenizer
|
|
| 12 |
MAX_CHANNELS = 8
|
| 13 |
SILENCE_DURATION = 0.0 # Fixed silence duration: 0 seconds
|
| 14 |
|
| 15 |
-
def load_model(model_path, spt_config_path, spt_checkpoint_path, torch_dtype=torch.bfloat16, attn_implementation="
|
| 16 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 17 |
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
spt = XY_Tokenizer.load_from_checkpoint(config_path=spt_config_path, ckpt_path=spt_checkpoint_path)
|
| 21 |
|
|
|
|
| 12 |
MAX_CHANNELS = 8
|
| 13 |
SILENCE_DURATION = 0.0 # Fixed silence duration: 0 seconds
|
| 14 |
|
| 15 |
+
def load_model(model_path, spt_config_path, spt_checkpoint_path, torch_dtype=torch.bfloat16, attn_implementation="sdpa"):
|
| 16 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 17 |
|
| 18 |
+
# 尝试使用 FlashAttention2,失败则回退到标准实现
|
| 19 |
+
try:
|
| 20 |
+
model = AsteroidTTSInstruct.from_pretrained(model_path, torch_dtype=torch_dtype, attn_implementation="flash_attention_2")
|
| 21 |
+
print("✅ 使用 FlashAttention2")
|
| 22 |
+
except ImportError:
|
| 23 |
+
print("⚠️ FlashAttention2 不可用,使用标准注意力机制")
|
| 24 |
+
model = AsteroidTTSInstruct.from_pretrained(model_path, torch_dtype=torch_dtype, attn_implementation=attn_implementation)
|
| 25 |
|
| 26 |
spt = XY_Tokenizer.load_from_checkpoint(config_path=spt_config_path, ckpt_path=spt_checkpoint_path)
|
| 27 |
|