liuyang committed
Commit 1656d98 · 1 Parent(s): 233e4b4

transcribe by word

Files changed (1)
  1. app.py +1 -123
app.py CHANGED
@@ -36,125 +36,9 @@ pipe = pipeline(
     torch_dtype=torch.float16,
     device="cuda",
     model_kwargs={"attn_implementation": "flash_attention_2"},
-    return_timestamps=True,
+    return_timestamps="word",
 )
 
-def comprehensive_flash_attention_verification():
-    """Comprehensive verification of flash attention setup"""
-    print("🔍 Running Flash Attention Verification...")
-    print("=" * 50)
-
-    verification_results = {}
-
-    # Check 1: Package Installation
-    print("🔍 Checking Python packages...")
-    try:
-        import flash_attn
-        print(f"✅ flash-attn: {flash_attn.__version__}")
-        verification_results["flash_attn_installed"] = True
-    except ImportError:
-        print("❌ flash-attn: Not installed")
-        verification_results["flash_attn_installed"] = False
-
-    try:
-        import transformers
-        print(f"✅ transformers: {transformers.__version__}")
-        verification_results["transformers_available"] = True
-    except ImportError:
-        print("❌ transformers: Not installed")
-        verification_results["transformers_available"] = False
-
-    # Check 2: CUDA Availability
-    print("\n🔍 Checking CUDA availability...")
-    cuda_available = torch.cuda.is_available()
-    print(f"✅ CUDA available: {cuda_available}")
-    if cuda_available:
-        print(f"✅ CUDA version: {torch.version.cuda}")
-        print(f"✅ GPU count: {torch.cuda.device_count()}")
-        for i in range(torch.cuda.device_count()):
-            print(f"✅ GPU {i}: {torch.cuda.get_device_name(i)}")
-    verification_results["cuda_available"] = cuda_available
-
-    # Check 3: Flash Attention Import
-    print("\n🔍 Testing flash attention imports...")
-    try:
-        from flash_attn import flash_attn_func
-        print("✅ flash_attn_func imported successfully")
-
-        if flash_attn_func is None:
-            print("❌ flash_attn_func is None")
-            verification_results["flash_attn_import"] = False
-        else:
-            print("✅ flash_attn_func is callable")
-            verification_results["flash_attn_import"] = True
-    except ImportError as e:
-        print(f"❌ Import error: {e}")
-        verification_results["flash_attn_import"] = False
-    except Exception as e:
-        print(f"❌ Unexpected error: {e}")
-        verification_results["flash_attn_import"] = False
-
-    # Check 4: Flash Attention Functionality Test
-    print("\n🔍 Testing flash attention functionality...")
-    if not cuda_available:
-        print("⚠️ Skipping functionality test - CUDA not available")
-        verification_results["flash_attn_functional"] = False
-    elif not verification_results.get("flash_attn_import", False):
-        print("⚠️ Skipping functionality test - Import failed")
-        verification_results["flash_attn_functional"] = False
-    else:
-        try:
-            from flash_attn import flash_attn_func
-
-            # Create small dummy tensors
-            batch_size, seq_len, num_heads, head_dim = 1, 16, 4, 32
-            device = "cuda:0"
-            dtype = torch.float16
-
-            print(f"Creating tensors: batch={batch_size}, seq_len={seq_len}, heads={num_heads}, dim={head_dim}")
-
-            q = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype, device=device)
-            k = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype, device=device)
-            v = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype, device=device)
-
-            print("✅ Tensors created successfully")
-
-            # Test flash attention
-            output = flash_attn_func(q, k, v, dropout_p=0.0, causal=False)
-
-            print(f"✅ Flash attention output shape: {output.shape}")
-            print("✅ Flash attention test passed!")
-            verification_results["flash_attn_functional"] = True
-
-        except Exception as e:
-            print(f"❌ Flash attention test failed: {e}")
-            import traceback
-            traceback.print_exc()
-            verification_results["flash_attn_functional"] = False
-
-    # Summary
-    print("\n" + "=" * 50)
-    print("📊 VERIFICATION SUMMARY")
-    print("=" * 50)
-
-    all_passed = True
-    for check_name, result in verification_results.items():
-        status = "✅ PASS" if result else "❌ FAIL"
-        print(f"{check_name}: {status}")
-        if not result:
-            all_passed = False
-
-    if all_passed:
-        print("\n🎉 All checks passed! Flash attention should work.")
-        return True
-    else:
-        print("\n⚠️ Some checks failed. Flash attention may not work properly.")
-        print("\nRecommendations:")
-        print("1. Try reinstalling flash-attn: pip uninstall flash-attn && pip install flash-attn --no-build-isolation")
-        print("2. Check CUDA compatibility with your PyTorch version")
-        print("3. Consider using default attention as fallback")
-        return False
-
 class WhisperTranscriber:
     def __init__(self):
         self.pipe = pipe  # Use global pipeline
@@ -207,12 +91,6 @@ class WhisperTranscriber:
     def transcribe_audio(self, audio_path, language=None, translate=False, prompt=None):
         """Transcribe audio using Whisper with flash attention"""
 
-        # Run comprehensive flash attention verification
-        #flash_attention_working = comprehensive_flash_attention_verification()
-        #if not flash_attention_working:
-        #    print("⚠️ Flash attention verification failed, but proceeding with transcription...")
-        #    print("You may encounter the TypeError: 'NoneType' object is not callable error")
-
         '''
         #if self.pipe is None:
        #    self.setup_models()
 
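The functional change in this commit is the timestamp granularity of the pipeline: return_timestamps=True (segment-level) becomes return_timestamps="word", so the Whisper pipeline reports a start and end time for every word, matching the commit message "transcribe by word". Below is a minimal sketch of the resulting call; the checkpoint openai/whisper-large-v3 and the file sample.wav are illustrative assumptions and are not taken from this repo.

# Minimal sketch (not the app's full code): word-level timestamps with the
# transformers ASR pipeline. The model id and the audio path are assumptions
# for illustration; flash_attention_2 additionally needs a compatible GPU and
# the flash-attn package installed.
import torch
from transformers import pipeline

pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3",   # assumed checkpoint, not shown in this diff
    torch_dtype=torch.float16,
    device="cuda",
    model_kwargs={"attn_implementation": "flash_attention_2"},
    return_timestamps="word",          # per-word instead of per-segment timestamps
)

result = pipe("sample.wav")            # assumed input file
# With return_timestamps="word", result["chunks"] holds one entry per word, e.g.
# [{"text": " Hello", "timestamp": (0.0, 0.42)}, {"text": " world", "timestamp": (0.42, 0.81)}, ...]
for chunk in result["chunks"]:
    print(chunk["timestamp"], chunk["text"])

Word-level chunks are what make per-word alignment (for example, karaoke-style subtitles) possible, whereas the previous return_timestamps=True only yielded timestamps per transcribed segment.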