Spaces: Running on Zero

liuyang committed
Commit · 1656d98
1 Parent(s): 233e4b4

transcribe by word

Browse files
app.py CHANGED

@@ -36,125 +36,9 @@ pipe = pipeline(
     torch_dtype=torch.float16,
     device="cuda",
     model_kwargs={"attn_implementation": "flash_attention_2"},
-    return_timestamps=
+    return_timestamps="word",
 )
 
-def comprehensive_flash_attention_verification():
-    """Comprehensive verification of flash attention setup"""
-    print("🔍 Running Flash Attention Verification...")
-    print("=" * 50)
-
-    verification_results = {}
-
-    # Check 1: Package Installation
-    print("📦 Checking Python packages...")
-    try:
-        import flash_attn
-        print(f"✅ flash-attn: {flash_attn.__version__}")
-        verification_results["flash_attn_installed"] = True
-    except ImportError:
-        print("❌ flash-attn: Not installed")
-        verification_results["flash_attn_installed"] = False
-
-    try:
-        import transformers
-        print(f"✅ transformers: {transformers.__version__}")
-        verification_results["transformers_available"] = True
-    except ImportError:
-        print("❌ transformers: Not installed")
-        verification_results["transformers_available"] = False
-
-    # Check 2: CUDA Availability
-    print("\n🔍 Checking CUDA availability...")
-    cuda_available = torch.cuda.is_available()
-    print(f"✅ CUDA available: {cuda_available}")
-    if cuda_available:
-        print(f"✅ CUDA version: {torch.version.cuda}")
-        print(f"✅ GPU count: {torch.cuda.device_count()}")
-        for i in range(torch.cuda.device_count()):
-            print(f"✅ GPU {i}: {torch.cuda.get_device_name(i)}")
-    verification_results["cuda_available"] = cuda_available
-
-    # Check 3: Flash Attention Import
-    print("\n🔍 Testing flash attention imports...")
-    try:
-        from flash_attn import flash_attn_func
-        print("✅ flash_attn_func imported successfully")
-
-        if flash_attn_func is None:
-            print("❌ flash_attn_func is None")
-            verification_results["flash_attn_import"] = False
-        else:
-            print("✅ flash_attn_func is callable")
-            verification_results["flash_attn_import"] = True
-    except ImportError as e:
-        print(f"❌ Import error: {e}")
-        verification_results["flash_attn_import"] = False
-    except Exception as e:
-        print(f"❌ Unexpected error: {e}")
-        verification_results["flash_attn_import"] = False
-
-    # Check 4: Flash Attention Functionality Test
-    print("\n🔍 Testing flash attention functionality...")
-    if not cuda_available:
-        print("⚠️ Skipping functionality test - CUDA not available")
-        verification_results["flash_attn_functional"] = False
-    elif not verification_results.get("flash_attn_import", False):
-        print("⚠️ Skipping functionality test - Import failed")
-        verification_results["flash_attn_functional"] = False
-    else:
-        try:
-            from flash_attn import flash_attn_func
-
-            # Create small dummy tensors
-            batch_size, seq_len, num_heads, head_dim = 1, 16, 4, 32
-            device = "cuda:0"
-            dtype = torch.float16
-
-            print(f"Creating tensors: batch={batch_size}, seq_len={seq_len}, heads={num_heads}, dim={head_dim}")
-
-            q = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype, device=device)
-            k = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype, device=device)
-            v = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype, device=device)
-
-            print("✅ Tensors created successfully")
-
-            # Test flash attention
-            output = flash_attn_func(q, k, v, dropout_p=0.0, causal=False)
-
-            print(f"✅ Flash attention output shape: {output.shape}")
-            print("✅ Flash attention test passed!")
-            verification_results["flash_attn_functional"] = True
-
-        except Exception as e:
-            print(f"❌ Flash attention test failed: {e}")
-            import traceback
-            traceback.print_exc()
-            verification_results["flash_attn_functional"] = False
-
-    # Summary
-    print("\n" + "=" * 50)
-    print("📊 VERIFICATION SUMMARY")
-    print("=" * 50)
-
-    all_passed = True
-    for check_name, result in verification_results.items():
-        status = "✅ PASS" if result else "❌ FAIL"
-        print(f"{check_name}: {status}")
-        if not result:
-            all_passed = False
-
-    if all_passed:
-        print("\n🎉 All checks passed! Flash attention should work.")
-        return True
-    else:
-        print("\n⚠️ Some checks failed. Flash attention may not work properly.")
-        print("\nRecommendations:")
-        print("1. Try reinstalling flash-attn: pip uninstall flash-attn && pip install flash-attn --no-build-isolation")
-        print("2. Check CUDA compatibility with your PyTorch version")
-        print("3. Consider using default attention as fallback")
-        return False
-
 class WhisperTranscriber:
     def __init__(self):
         self.pipe = pipe  # Use global pipeline
@@ -207,12 +91,6 @@ class WhisperTranscriber:
     def transcribe_audio(self, audio_path, language=None, translate=False, prompt=None):
         """Transcribe audio using Whisper with flash attention"""
 
-        # Run comprehensive flash attention verification
-        #flash_attention_working = comprehensive_flash_attention_verification()
-        #if not flash_attention_working:
-        #    print("⚠️ Flash attention verification failed, but proceeding with transcription...")
-        #    print("You may encounter the TypeError: 'NoneType' object is not callable error")
-
         '''
         #if self.pipe is None:
         #    self.setup_models()