ALLARD Marc-Antoine committed on
Commit
c30da2c
·
1 Parent(s): 4afd720

updating app

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +666 -0
src/streamlit_app.py ADDED
@@ -0,0 +1,666 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import json
4
+ import wave
5
+ import numpy as np
6
+ from datetime import timedelta
7
+ import base64
8
+ from io import BytesIO
9
+ import tempfile
10
+
11
# Streamlit page configuration (must run before any other st.* call).
st.set_page_config(
    page_title="ASR Annotation Tool",
    page_icon="🎀",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Default values for every session-state key the app relies on.
_SESSION_DEFAULTS = {
    'annotation_type': None,   # "single_speaker" | "multi_speaker"
    'audio_file': None,        # raw bytes of the uploaded audio
    'transcript': "",          # free-text transcript draft
    'segments': [],            # list of {start, end, speaker_id} dicts
    'current_page': "home",    # simple page router key
    'audio_duration': 0,       # seconds, derived from the uploaded file
    'save_path': "",           # output directory typed in the sidebar
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
34
+
35
def get_audio_duration(audio_file):
    """Return the duration of a WAV file in seconds.

    Parameters
    ----------
    audio_file : str | file-like
        Path or seekable binary file object accepted by ``wave.open``.

    Returns
    -------
    float
        Duration in seconds, or 0 if the input cannot be parsed as WAV.
    """
    try:
        with wave.open(audio_file, 'rb') as wav_file:
            frames = wav_file.getnframes()
            sample_rate = wav_file.getframerate()
            return frames / float(sample_rate)
    # Replaces the original bare ``except:`` — only the failures that
    # wave header parsing / file access can realistically raise are
    # swallowed; genuine bugs (e.g. TypeError) now surface.
    except (wave.Error, EOFError, OSError, ZeroDivisionError):
        return 0
45
+
46
def format_time(seconds):
    """Format a duration in seconds as ``HH:MM:SS.mmm``."""
    # Round-trip through timedelta (mirrors the precision of the
    # original implementation), then split into whole/fractional parts.
    elapsed = timedelta(seconds=seconds).total_seconds()
    whole = int(elapsed)
    millis = int((elapsed - whole) * 1000)
    hrs, rem = divmod(whole, 3600)
    mins, secs = divmod(rem, 60)
    return f"{hrs:02d}:{mins:02d}:{secs:02d}.{millis:03d}"
54
+
55
def create_audio_player_html(audio_data, audio_id="audio_player"):
    """Build an HTML5 audio player with skip/seek buttons.

    Args:
        audio_data: Raw WAV bytes; embedded in the page as a base64 data URI.
        audio_id: DOM id for the ``<audio>`` element, so several players
            can coexist on one page.

    Returns:
        HTML markup string for ``st.components.v1.html``.
    """
    encoded = base64.b64encode(audio_data).decode()
    return f"""
    <div style="margin: 20px 0;">
      <audio id="{audio_id}" controls style="width: 100%; height: 40px;">
        <source src="data:audio/wav;base64,{encoded}" type="audio/wav">
        Your browser does not support the audio element.
      </audio>
      <div style="margin-top: 10px;">
        <button onclick="document.getElementById('{audio_id}').currentTime -= 5"
                style="margin-right: 5px; padding: 5px 10px; background: #ff4b4b; color: white; border: none; border-radius: 3px; cursor: pointer;">
          βͺ -5s
        </button>
        <button onclick="document.getElementById('{audio_id}').currentTime -= 1"
                style="margin-right: 5px; padding: 5px 10px; background: #ff4b4b; color: white; border: none; border-radius: 3px; cursor: pointer;">
          βͺ -1s
        </button>
        <button onclick="var audio = document.getElementById('{audio_id}'); audio.paused ? audio.play() : audio.pause()"
                style="margin-right: 5px; padding: 5px 15px; background: #00cc44; color: white; border: none; border-radius: 3px; cursor: pointer;">
          ⏯️ Play/Pause
        </button>
        <button onclick="document.getElementById('{audio_id}').currentTime += 1"
                style="margin-right: 5px; padding: 5px 10px; background: #ff4b4b; color: white; border: none; border-radius: 3px; cursor: pointer;">
          +1s ⏩
        </button>
        <button onclick="document.getElementById('{audio_id}').currentTime += 5"
                style="padding: 5px 10px; background: #ff4b4b; color: white; border: none; border-radius: 3px; cursor: pointer;">
          +5s ⏩
        </button>
      </div>
    </div>
    """
90
+
91
def create_waveform_html(audio_data, segments=None):
    """Render a wavesurfer.js waveform with drag-to-select speaker regions.

    Args:
        audio_data: Raw WAV bytes, embedded in the page as base64.
        segments: Optional list of dicts with ``start``/``end``/``speaker_id``
            keys used to pre-populate regions once the waveform is ready.

    Returns:
        HTML + JS markup string for ``st.components.v1.html``.
    """
    encoded = base64.b64encode(audio_data).decode()
    preload = json.dumps(segments or [])

    return f"""
    <div id="waveform-container" style="margin: 20px 0;">
      <div id="waveform" style="height: 200px; border: 1px solid #ddd;"></div>
      <div style="margin-top: 10px;">
        <button id="play-pause" style="margin-right: 5px; padding: 8px 15px; background: #00cc44; color: white; border: none; border-radius: 3px; cursor: pointer;">
          ⏯️ Play/Pause
        </button>
        <button id="add-region" style="margin-right: 5px; padding: 8px 15px; background: #0066cc; color: white; border: none; border-radius: 3px; cursor: pointer;">
          βž• Add Region
        </button>
        <button id="clear-regions" style="padding: 8px 15px; background: #cc0000; color: white; border: none; border-radius: 3px; cursor: pointer;">
          πŸ—‘οΈ Clear All
        </button>
      </div>
      <div id="regions-list" style="margin-top: 15px; max-height: 200px; overflow-y: auto; color: white;">
        <h4>Segments:</h4>
        <div id="segments-container"></div>
      </div>
    </div>

    <script src="https://cdnjs.cloudflare.com/ajax/libs/wavesurfer.js/6.6.4/wavesurfer.min.js"></script>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/wavesurfer.js/6.6.4/plugin/wavesurfer.regions.min.js"></script>

    <script>
    let wavesurfer;
    let regions = {preload};
    let speakerColors = {{}};

    // Waveform with the regions plugin; drag-selection creates regions.
    wavesurfer = WaveSurfer.create({{
        container: '#waveform',
        waveColor: '#4FC3F7',
        progressColor: '#1976D2',
        height: 200,
        responsive: true,
        plugins: [
            WaveSurfer.regions.create({{
                dragSelection: true,
                color: 'rgba(255, 75, 75, 0.3)'
            }})
        ]
    }});

    wavesurfer.load('data:audio/wav;base64,{encoded}');

    document.getElementById('play-pause').addEventListener('click', function() {{
        wavesurfer.playPause();
    }});

    // Add a 2s region at the playhead after prompting for the speaker ID.
    document.getElementById('add-region').addEventListener('click', function() {{
        const start = wavesurfer.getCurrentTime();
        const end = Math.min(start + 2, wavesurfer.getDuration());
        const speakerId = prompt("Enter speaker ID (e.g., SPK001):", "SPK" + (Object.keys(speakerColors).length + 1).toString().padStart(3, '0'));
        if (speakerId) {{
            addRegion(start, end, speakerId);
        }}
    }});

    document.getElementById('clear-regions').addEventListener('click', function() {{
        wavesurfer.clearRegions();
        regions = [];
        updateRegionsList();
    }});

    // Create a region on the waveform and mirror it in the `regions` array.
    function addRegion(start, end, speaker_id) {{
        if (!speakerColors[speaker_id]) {{
            speakerColors[speaker_id] = getColorForSpeaker(speaker_id);
        }}
        const region = wavesurfer.addRegion({{
            start: start,
            end: end,
            color: speakerColors[speaker_id],
            drag: true,
            resize: true
        }});
        regions.push({{
            id: region.id,
            start: start,
            end: end,
            speaker_id: speaker_id
        }});
        updateRegionsList();
    }}

    // Re-render the segment list below the waveform from `regions`.
    function updateRegionsList() {{
        const container = document.getElementById('segments-container');
        container.innerHTML = '';
        regions.forEach((region, index) => {{
            const div = document.createElement('div');
            div.style.cssText = 'border: 1px solid #ddd; padding: 10px; margin: 5px 0; border-radius: 5px; color: white;';
            div.innerHTML = `
                <div style="display: flex; justify-content: space-between; align-items: center;">
                    <div>
                        <strong>Segment ${{index + 1}}</strong><br>
                        Start: ${{region.start.toFixed(2)}}s | End: ${{region.end.toFixed(2)}}s<br>
                        <input type="text" value="${{region.speaker_id}}"
                               onchange="updateSpeakerId('${{region.id}}', this.value, '${{region.speaker_id}}')"
                               style="margin-top: 5px; padding: 3px; border: 1px solid #ccc; border-radius: 3px; color: black;">
                    </div>
                    <button onclick="removeRegion('${{region.id}}')"
                            style="background: #cc0000; color: white; border: none; border-radius: 3px; padding: 5px 8px; cursor: pointer;">
                        βœ•
                    </button>
                </div>
            `;
            container.appendChild(div);
        }});
    }}

    function removeRegion(regionId) {{
        wavesurfer.regions.list[regionId].remove();
        regions = regions.filter(r => r.id !== regionId);
        updateRegionsList();
    }}

    // Rename a speaker on one region; recolor it if the ID actually changed.
    function updateSpeakerId(regionId, newId, oldId) {{
        const region = regions.find(r => r.id === regionId);
        if (region) {{
            region.speaker_id = newId;
            if (newId !== oldId) {{
                if (!speakerColors[newId]) {{
                    speakerColors[newId] = getColorForSpeaker(newId);
                }}
                wavesurfer.regions.list[regionId].color = speakerColors[newId];
                wavesurfer.regions.list[regionId].updateRender();
            }}
        }}
    }}

    // Deterministic palette pick per speaker ID (string hash mod palette).
    function getColorForSpeaker(speakerId) {{
        const colors = [
            'rgba(255, 75, 75, 0.3)', // Red
            'rgba(75, 192, 75, 0.3)', // Green
            'rgba(75, 75, 255, 0.3)', // Blue
            'rgba(255, 192, 75, 0.3)', // Yellow
            'rgba(255, 75, 255, 0.3)', // Magenta
            'rgba(75, 192, 192, 0.3)', // Cyan
            'rgba(192, 75, 192, 0.3)', // Purple
            'rgba(192, 192, 75, 0.3)' // Olive
        ];
        let hash = 0;
        for (let i = 0; i < speakerId.length; i++) {{
            hash = ((hash << 5) - hash) + speakerId.charCodeAt(i);
            hash |= 0; // Convert to 32bit integer
        }}
        return colors[Math.abs(hash) % colors.length];
    }}

    // Keep the JS-side segment data in sync after a drag/resize.
    wavesurfer.on('region-update-end', function(region) {{
        const regionData = regions.find(r => r.id === region.id);
        if (regionData) {{
            regionData.start = region.start;
            regionData.end = region.end;
            updateRegionsList();
        }}
    }});

    // Once audio is decoded, materialize any preloaded segments.
    wavesurfer.on('ready', function() {{
        regions.forEach(regionData => {{
            if (!speakerColors[regionData.speaker_id]) {{
                speakerColors[regionData.speaker_id] = getColorForSpeaker(regionData.speaker_id);
            }}
        }});
        regions.forEach(regionData => {{
            const region = wavesurfer.addRegion({{
                start: regionData.start,
                end: regionData.end,
                color: speakerColors[regionData.speaker_id],
                drag: true,
                resize: true
            }});
            regionData.id = region.id;
        }});
        updateRegionsList();
    }});

    // Expose the current segments for the host page.
    window.getRegions = function() {{
        return regions.map(r => ({{
            start: r.start,
            end: r.end,
            speaker_id: r.speaker_id
        }}));
    }}
    </script>
    """
311
+
312
def generate_srt(segments, transcript):
    """Generate SRT cues from segments, one numbered cue per segment.

    NOTE: *transcript* is accepted for interface compatibility but not yet
    used — each cue carries a placeholder body until proper text/segment
    alignment is implemented.
    """
    cues = []
    for idx, seg in enumerate(segments, start=1):
        window = f"{format_srt_time(seg['start'])} --> {format_srt_time(seg['end'])}"
        body = f"{seg['speaker_id']}: [Segment {idx} text]"
        cues.append(f"{idx}\n{window}\n{body}\n\n")
    return "".join(cues)
328
+
329
def format_srt_time(seconds):
    """Format a duration in seconds as an SRT timestamp (``HH:MM:SS,mmm``)."""
    whole = int(seconds)
    frac = seconds % 1  # fractional part -> milliseconds
    hrs, rem = divmod(whole, 3600)
    mins, secs = divmod(rem, 60)
    return f"{hrs:02d}:{mins:02d}:{secs:02d},{int(frac * 1000):03d}"
336
+
337
def save_files(transcript, segments=None, save_path=""):
    """Save the transcript (and, if segments are given, an SRT file).

    Parameters
    ----------
    transcript : str
        Full transcript text written to ``transcript.txt``.
    segments : list[dict] | None
        Optional segment dicts; when non-empty an SRT file is written too.
    save_path : str
        Target directory; defaults to the current directory when empty.

    Returns
    -------
    tuple[str, str | None]
        ``(transcript_path, srt_path)``; ``srt_path`` is ``None`` when no
        segments were provided.
    """
    save_path = save_path or "."
    # Create the target directory if it does not exist yet, so a freshly
    # typed path in the sidebar does not make the save fail.
    os.makedirs(save_path, exist_ok=True)

    transcript_path = os.path.join(save_path, "transcript.txt")
    with open(transcript_path, "w", encoding="utf-8") as f:
        f.write(transcript)

    srt_path = None
    if segments:
        srt_content = generate_srt(segments, transcript)
        srt_path = os.path.join(save_path, "transcript.srt")
        with open(srt_path, "w", encoding="utf-8") as f:
            f.write(srt_content)

    return transcript_path, srt_path
357
+
358
# Main App Layout
def main():
    """Top-level layout: title, sidebar (settings + navigation), page router."""
    st.title("🎀 ASR Annotation Tool")
    st.markdown("Professional tool for creating ASR evaluation datasets")

    with st.sidebar:
        st.header("Settings")

        # Output directory used by all save actions.
        st.session_state.save_path = st.text_input(
            "Save Path",
            value=st.session_state.save_path,
            help="Directory where files will be saved"
        )

        # Navigation buttons appear progressively as prerequisites are met.
        st.header("Navigation")
        if st.button("🏠 Home", use_container_width=True):
            st.session_state.current_page = "home"

        if st.session_state.audio_file and st.session_state.annotation_type:
            if st.button("πŸ“ Transcription", use_container_width=True):
                st.session_state.current_page = "transcription"

            if st.session_state.annotation_type == "multi_speaker" and st.session_state.transcript:
                if st.button("🎯 Segmentation", use_container_width=True):
                    st.session_state.current_page = "segmentation"

                if st.session_state.segments:
                    if st.button("πŸ“Š Assignment", use_container_width=True):
                        st.session_state.current_page = "assignment"

    # Route to the renderer for the page recorded in session state.
    renderers = {
        "home": show_home_page,
        "transcription": show_transcription_page,
        "segmentation": show_segmentation_page,
        "assignment": show_assignment_page,
    }
    renderer = renderers.get(st.session_state.current_page)
    if renderer is not None:
        renderer()
400
+
401
def show_home_page():
    """Landing page: pick the annotation type and upload the audio file.

    Stores the uploaded bytes, measures the duration, previews the audio,
    and offers a jump to the transcription page.
    """
    st.header("Welcome to ASR Annotation Tool")

    # Step 1 — annotation mode.
    st.subheader("1. Select Annotation Type")
    annotation_type = st.radio(
        "Choose the type of annotation:",
        ["single_speaker", "multi_speaker"],
        format_func=lambda x: "Single Speaker (Simple ASR)" if x == "single_speaker" else "Multi Speaker (Diarization)",
        key="annotation_type_radio"
    )
    st.session_state.annotation_type = annotation_type

    # Step 2 — audio upload.
    st.subheader("2. Upload Audio File")
    uploaded_file = st.file_uploader(
        "Choose an audio file",
        type=['wav', 'mp3', 'flac', 'm4a'],
        help="Supported formats: WAV, MP3, FLAC, M4A"
    )

    if uploaded_file is None:
        return

    st.session_state.audio_file = uploaded_file.read()

    # Write the bytes to a temp file so wave.open can read a real path.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
        tmp_file.write(st.session_state.audio_file)
    # BUGFIX: measure only after the with-block closes the file — reading
    # while the write buffer was still unflushed could yield duration 0.
    st.session_state.audio_duration = get_audio_duration(tmp_file.name)
    os.unlink(tmp_file.name)

    st.success(f"βœ… Audio file uploaded successfully!")
    st.info(f"Duration: {format_time(st.session_state.audio_duration)}")

    # Inline preview player.
    st.subheader("Audio Preview")
    audio_html = create_audio_player_html(st.session_state.audio_file)
    st.components.v1.html(audio_html, height=120)

    if st.button("Continue to Transcription β†’", type="primary"):
        st.session_state.current_page = "transcription"
        st.rerun()
444
+
445
def show_transcription_page():
    """Transcription page: audio playback plus free-text transcript entry."""
    st.header("πŸ“ Text Transcription")

    # Guard: an audio file is required before transcribing.
    if not st.session_state.audio_file:
        st.error("Please upload an audio file first!")
        return

    st.subheader("Audio Player")
    st.components.v1.html(
        create_audio_player_html(st.session_state.audio_file),
        height=120,
    )

    # Editable transcript, persisted to session state on every rerun.
    st.subheader("Transcript")
    draft = st.text_area(
        "Write your transcription here:",
        value=st.session_state.transcript,
        height=300,
        help="Follow the annotation guidelines for accurate transcription"
    )
    st.session_state.transcript = draft

    with st.expander("πŸ“‹ Transcription Guidelines"):
        st.markdown("""
        **Key Guidelines:**
        - Transcribe exactly what is said (verbatim)
        - Include false starts, filled pauses (um, uh)
        - Use standard punctuation
        - Write numbers 1-10 as words, 11+ as digits
        - Mark unclear speech as [unclear] or [inaudible]
        - For multi-speaker: transcribe all audible speech
        """)

    # Action row: save / continue / finish, depending on mode and content.
    save_col, segment_col, finish_col = st.columns(3)
    has_text = bool(draft.strip())

    with save_col:
        if st.button("πŸ’Ύ Save Transcript", type="primary"):
            if has_text:
                try:
                    transcript_path, _ = save_files(draft, save_path=st.session_state.save_path)
                    st.success(f"βœ… Transcript saved to: {transcript_path}")
                except Exception as e:
                    st.error(f"Error saving file: {e}")
            else:
                st.warning("Please write a transcript first!")

    with segment_col:
        if st.session_state.annotation_type == "multi_speaker" and has_text:
            if st.button("🎯 Continue to Segmentation β†’"):
                st.session_state.current_page = "segmentation"
                st.rerun()

    with finish_col:
        if st.session_state.annotation_type == "single_speaker" and has_text:
            if st.button("βœ… Finish Annotation"):
                try:
                    transcript_path, _ = save_files(draft, save_path=st.session_state.save_path)
                    st.balloons()
                    st.success(f"πŸŽ‰ Single speaker annotation completed!\nSaved to: {transcript_path}")
                except Exception as e:
                    st.error(f"Error saving file: {e}")
509
+
510
def show_segmentation_page():
    """Segmentation page: waveform region editing plus manual segment entry."""
    st.header("🎯 Audio Segmentation")

    # Guard: segmentation needs audio.
    if not st.session_state.audio_file:
        st.error("Please upload an audio file first!")
        return

    st.info("Click and drag on the waveform to create segments. Resize by dragging edges, remove with βœ• button.")

    # Interactive wavesurfer view, pre-populated with existing segments.
    st.components.v1.html(
        create_waveform_html(st.session_state.audio_file, st.session_state.segments),
        height=500,
    )

    # Fallback: type segment bounds by hand.
    st.subheader("Manual Segment Addition")
    start_col, end_col, spk_col, add_col = st.columns(4)

    with start_col:
        start_time = st.number_input("Start (seconds)", min_value=0.0, max_value=st.session_state.audio_duration, step=0.1)
    with end_col:
        end_time = st.number_input("End (seconds)", min_value=0.0, max_value=st.session_state.audio_duration, step=0.1)
    with spk_col:
        speaker_id = st.text_input("Speaker ID", value="SPK001")
    with add_col:
        if st.button("βž• Add Segment"):
            if start_time < end_time:
                st.session_state.segments.append({
                    "start": start_time,
                    "end": end_time,
                    "speaker_id": speaker_id,
                })
                st.success("Segment added!")
                st.rerun()
            else:
                st.error("End time must be greater than start time!")

    if st.session_state.segments:
        # List every segment with a per-row delete button.
        st.subheader("Current Segments")
        for i, segment in enumerate(st.session_state.segments):
            info_col, delete_col = st.columns([4, 1])
            with info_col:
                st.write(f"**Segment {i+1}:** {segment['speaker_id']} | {segment['start']:.2f}s - {segment['end']:.2f}s")
            with delete_col:
                if st.button("πŸ—‘οΈ", key=f"remove_{i}"):
                    st.session_state.segments.pop(i)
                    st.rerun()

        if st.button("πŸ“Š Continue to Assignment β†’", type="primary"):
            st.session_state.current_page = "assignment"
            st.rerun()
565
+
566
def show_assignment_page():
    """Assignment page: attach transcript text to each segment and export."""
    st.header("πŸ“Š Text-Segment Assignment")

    # Guard: assignment needs segments.
    if not st.session_state.segments:
        st.error("Please create segments first!")
        return

    st.info("Assign portions of your transcript to each audio segment to create the final annotation.")

    # Read-only reference copy of the transcript.
    st.subheader("Original Transcript")
    st.text_area("Reference transcript:", value=st.session_state.transcript, height=150, disabled=True)

    # One text area per segment; the entered text rides along with the
    # segment dict into the export helpers.
    st.subheader("Segment Text Assignment")

    assigned_segments = []
    for i, segment in enumerate(st.session_state.segments):
        st.write(f"**Segment {i+1}:** {segment['speaker_id']} ({segment['start']:.2f}s - {segment['end']:.2f}s)")

        segment_text = st.text_area(
            f"Text for segment {i+1}:",
            key=f"segment_text_{i}",
            height=100,
            help="Copy and paste the relevant portion of the transcript for this segment"
        )

        assigned_segments.append({**segment, "text": segment_text})
        st.divider()

    if st.button("πŸ” Preview SRT"):
        st.subheader("SRT Preview")
        st.code(generate_srt_with_text(assigned_segments), language="text")

    # Final export: speaker-labelled transcript + SRT file.
    st.subheader("Save Final Annotation")
    save_col, back_col = st.columns(2)

    with save_col:
        if st.button("πŸ’Ύ Save Transcript + SRT", type="primary"):
            try:
                enhanced_transcript = create_speaker_transcript(assigned_segments)
                out_dir = st.session_state.save_path or "."
                transcript_path = os.path.join(out_dir, "final_transcript.txt")
                srt_path = os.path.join(out_dir, "final_transcript.srt")

                with open(transcript_path, "w", encoding="utf-8") as f:
                    f.write(enhanced_transcript)

                with open(srt_path, "w", encoding="utf-8") as f:
                    f.write(generate_srt_with_text(assigned_segments))

                st.balloons()
                st.success(f"πŸŽ‰ Multi-speaker annotation completed!\n\nFiles saved:\n- {transcript_path}\n- {srt_path}")

            except Exception as e:
                st.error(f"Error saving files: {e}")

    with back_col:
        if st.button("πŸ”„ Back to Segmentation"):
            st.session_state.current_page = "segmentation"
            st.rerun()
638
+
639
def generate_srt_with_text(segments):
    """Generate SRT cues carrying the text assigned to each segment.

    Segments lacking assigned text get an explicit placeholder body so the
    gap is visible in the exported file.
    """
    cues = []
    for idx, seg in enumerate(segments, start=1):
        window = f"{format_srt_time(seg['start'])} --> {format_srt_time(seg['end'])}"
        body = seg.get('text', '').strip() or f"[Segment {idx} - No text assigned]"
        cues.append(f"{idx}\n{window}\n{seg['speaker_id']}: {body}\n\n")
    return "".join(cues)
653
+
654
def create_speaker_transcript(segments):
    """Build a speaker-labelled transcript ordered by segment start time.

    Segments with no (or only-whitespace) text are skipped; each remaining
    segment becomes one ``SPEAKER: text`` paragraph.
    """
    labelled = [
        f"{seg['speaker_id']}: {seg.get('text', '').strip()}"
        for seg in sorted(segments, key=lambda s: s['start'])
        if seg.get('text', '').strip()
    ]
    return "\n\n".join(labelled)
664
+
665
+ if __name__ == "__main__":
666
+ main()