import streamlit as st
import os
import json
import wave
from datetime import timedelta
import base64
import tempfile
# Page configuration
st.set_page_config(
    page_title="ASR Annotation Tool",
    page_icon="🎤",
    layout="wide",
    initial_sidebar_state="expanded"
)
# Initialize session state
if 'annotation_type' not in st.session_state:
    st.session_state.annotation_type = None
if 'audio_file' not in st.session_state:
    st.session_state.audio_file = None
if 'transcript' not in st.session_state:
    st.session_state.transcript = ""
if 'segments' not in st.session_state:
    st.session_state.segments = []
if 'current_page' not in st.session_state:
    st.session_state.current_page = "home"
if 'audio_duration' not in st.session_state:
    st.session_state.audio_duration = 0
def get_audio_duration(audio_file):
    """Get audio duration in seconds (WAV only; other formats return 0)"""
    try:
        with wave.open(audio_file, 'rb') as wav_file:
            frames = wav_file.getnframes()
            sample_rate = wav_file.getframerate()
            duration = frames / float(sample_rate)
            return duration
    except Exception:
        # The wave module only parses WAV; MP3/FLAC/M4A uploads fall
        # through to a duration of 0.
        return 0
def format_time(seconds):
    """Format seconds to HH:MM:SS.mmm"""
    td = timedelta(seconds=seconds)
    total_seconds = int(td.total_seconds())
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    milliseconds = int((td.total_seconds() - total_seconds) * 1000)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
def create_audio_player_html(audio_data, audio_id="audio_player"):
    """Create HTML audio player with controls"""
    audio_base64 = base64.b64encode(audio_data).decode()
    html = f"""
    <div style="margin: 20px 0;">
        <audio id="{audio_id}" controls style="width: 100%; height: 40px;">
            <source src="data:audio/wav;base64,{audio_base64}" type="audio/wav">
            Your browser does not support the audio element.
        </audio>
        <div style="margin-top: 10px;">
            <button onclick="document.getElementById('{audio_id}').currentTime -= 5"
                    style="margin-right: 5px; padding: 5px 10px; background: #ff4b4b; color: white; border: none; border-radius: 3px; cursor: pointer;">
                ⏪ -5s
            </button>
            <button onclick="document.getElementById('{audio_id}').currentTime -= 1"
                    style="margin-right: 5px; padding: 5px 10px; background: #ff4b4b; color: white; border: none; border-radius: 3px; cursor: pointer;">
                ⏪ -1s
            </button>
            <button onclick="var audio = document.getElementById('{audio_id}'); audio.paused ? audio.play() : audio.pause()"
                    style="margin-right: 5px; padding: 5px 15px; background: #00cc44; color: white; border: none; border-radius: 3px; cursor: pointer;">
                ⏯️ Play/Pause
            </button>
            <button onclick="document.getElementById('{audio_id}').currentTime += 1"
                    style="margin-right: 5px; padding: 5px 10px; background: #ff4b4b; color: white; border: none; border-radius: 3px; cursor: pointer;">
                +1s ⏩
            </button>
            <button onclick="document.getElementById('{audio_id}').currentTime += 5"
                    style="padding: 5px 10px; background: #ff4b4b; color: white; border: none; border-radius: 3px; cursor: pointer;">
                +5s ⏩
            </button>
        </div>
    </div>
    """
    return html
def create_waveform_html(audio_data, segments=None):
    """Create interactive waveform with region selection"""
    audio_base64 = base64.b64encode(audio_data).decode()
    segments_json = json.dumps(segments or [])
    html = f"""
    <div id="waveform-container" style="margin: 20px 0;">
        <div id="waveform" style="height: 200px; border: 1px solid #ddd;"></div>
        <div style="margin-top: 10px;">
            <button id="play-pause" style="margin-right: 5px; padding: 8px 15px; background: #00cc44; color: white; border: none; border-radius: 3px; cursor: pointer;">
                ⏯️ Play/Pause
            </button>
            <button id="add-region" style="margin-right: 5px; padding: 8px 15px; background: #0066cc; color: white; border: none; border-radius: 3px; cursor: pointer;">
                ➕ Add Region
            </button>
            <button id="clear-regions" style="padding: 8px 15px; background: #cc0000; color: white; border: none; border-radius: 3px; cursor: pointer;">
                🗑️ Clear All
            </button>
        </div>
        <div id="regions-list" style="margin-top: 15px; max-height: 200px; overflow-y: auto; color: white;">
            <h4>Segments:</h4>
            <div id="segments-container"></div>
        </div>
    </div>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/wavesurfer.js/6.6.4/wavesurfer.min.js"></script>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/wavesurfer.js/6.6.4/plugin/wavesurfer.regions.min.js"></script>
    <script>
        let wavesurfer;
        let regions = {segments_json};
        let speakerColors = {{}};

        // Initialize WaveSurfer
        wavesurfer = WaveSurfer.create({{
            container: '#waveform',
            waveColor: '#4FC3F7',
            progressColor: '#1976D2',
            height: 200,
            responsive: true,
            plugins: [
                WaveSurfer.regions.create({{
                    dragSelection: true,
                    color: 'rgba(255, 75, 75, 0.3)'
                }})
            ]
        }});

        // Load audio
        wavesurfer.load('data:audio/wav;base64,{audio_base64}');

        // Play/Pause button
        document.getElementById('play-pause').addEventListener('click', function() {{
            wavesurfer.playPause();
        }});

        // Add region button
        document.getElementById('add-region').addEventListener('click', function() {{
            const start = wavesurfer.getCurrentTime();
            const end = Math.min(start + 2, wavesurfer.getDuration());
            // Ask for speaker ID first
            const speakerId = prompt("Enter speaker ID (e.g., SPK001):", "SPK" + (Object.keys(speakerColors).length + 1).toString().padStart(3, '0'));
            if (speakerId) {{
                addRegion(start, end, speakerId);
            }}
        }});

        // Clear regions button
        document.getElementById('clear-regions').addEventListener('click', function() {{
            wavesurfer.clearRegions();
            regions = [];
            updateRegionsList();
        }});

        // Add region function
        function addRegion(start, end, speaker_id) {{
            // Get or assign color for this speaker
            if (!speakerColors[speaker_id]) {{
                speakerColors[speaker_id] = getColorForSpeaker(speaker_id);
            }}
            const region = wavesurfer.addRegion({{
                start: start,
                end: end,
                color: speakerColors[speaker_id],
                drag: true,
                resize: true
            }});
            regions.push({{
                id: region.id,
                start: start,
                end: end,
                speaker_id: speaker_id
            }});
            updateRegionsList();
        }}
        // Update regions list
        function updateRegionsList() {{
            const container = document.getElementById('segments-container');
            container.innerHTML = '';
            regions.forEach((region, index) => {{
                const div = document.createElement('div');
                div.style.cssText = 'border: 1px solid #ddd; padding: 10px; margin: 5px 0; border-radius: 5px; color: white;';
                div.innerHTML = `
                    <div style="display: flex; justify-content: space-between; align-items: center;">
                        <div>
                            <strong>Segment ${{index + 1}}</strong><br>
                            Start: ${{region.start.toFixed(2)}}s | End: ${{region.end.toFixed(2)}}s<br>
                            <input type="text" value="${{region.speaker_id}}"
                                   onchange="updateSpeakerId('${{region.id}}', this.value, '${{region.speaker_id}}')"
                                   style="margin-top: 5px; padding: 3px; border: 1px solid #ccc; border-radius: 3px; color: black;">
                        </div>
                        <button onclick="removeRegion('${{region.id}}')"
                                style="background: #cc0000; color: white; border: none; border-radius: 3px; padding: 5px 8px; cursor: pointer;">
                            ✕
                        </button>
                    </div>
                `;
                container.appendChild(div);
            }});
        }}
        // Remove region
        function removeRegion(regionId) {{
            wavesurfer.regions.list[regionId].remove();
            regions = regions.filter(r => r.id !== regionId);
            updateRegionsList();
        }}

        // Update speaker ID
        function updateSpeakerId(regionId, newId, oldId) {{
            const region = regions.find(r => r.id === regionId);
            if (region) {{
                region.speaker_id = newId;
                // Update color if speaker ID changed
                if (newId !== oldId) {{
                    if (!speakerColors[newId]) {{
                        speakerColors[newId] = getColorForSpeaker(newId);
                    }}
                    wavesurfer.regions.list[regionId].color = speakerColors[newId];
                    wavesurfer.regions.list[regionId].updateRender();
                }}
            }}
        }}

        // Get consistent color for a specific speaker
        function getColorForSpeaker(speakerId) {{
            const colors = [
                'rgba(255, 75, 75, 0.3)',   // Red
                'rgba(75, 192, 75, 0.3)',   // Green
                'rgba(75, 75, 255, 0.3)',   // Blue
                'rgba(255, 192, 75, 0.3)',  // Yellow
                'rgba(255, 75, 255, 0.3)',  // Magenta
                'rgba(75, 192, 192, 0.3)',  // Cyan
                'rgba(192, 75, 192, 0.3)',  // Purple
                'rgba(192, 192, 75, 0.3)'   // Olive
            ];
            // Generate a deterministic index based on the speaker ID string
            let hash = 0;
            for (let i = 0; i < speakerId.length; i++) {{
                hash = ((hash << 5) - hash) + speakerId.charCodeAt(i);
                hash |= 0; // Convert to 32bit integer
            }}
            // Use the absolute value of hash to select a color
            const index = Math.abs(hash) % colors.length;
            return colors[index];
        }}
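        // Note: (hash << 5) - hash is the common "hash * 31" string hash, so a
        // given ID such as "SPK001" maps to the same palette entry on every
        // render; distinct IDs can still collide since there are only 8 colors.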
        // Update region on change
        wavesurfer.on('region-update-end', function(region) {{
            const regionData = regions.find(r => r.id === region.id);
            if (regionData) {{
                regionData.start = region.start;
                regionData.end = region.end;
                updateRegionsList();
            }}
        }});

        // Load existing regions
        wavesurfer.on('ready', function() {{
            // First, create color mappings for existing speakers
            regions.forEach(regionData => {{
                if (!speakerColors[regionData.speaker_id]) {{
                    speakerColors[regionData.speaker_id] = getColorForSpeaker(regionData.speaker_id);
                }}
            }});
            // Then create the regions with their colors
            regions.forEach(regionData => {{
                const region = wavesurfer.addRegion({{
                    start: regionData.start,
                    end: regionData.end,
                    color: speakerColors[regionData.speaker_id],
                    drag: true,
                    resize: true
                }});
                regionData.id = region.id;
            }});
            updateRegionsList();
        }});

        // Export regions function for Streamlit
        window.getRegions = function() {{
            return regions.map(r => ({{
                start: r.start,
                end: r.end,
                speaker_id: r.speaker_id
            }}));
        }}
    </script>
    """
    return html
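# Note: st.components.v1.html renders this widget in a sandboxed iframe with no
# return channel, so window.getRegions() is not reachable from Python. Regions
# drawn on the waveform stay inside the component; the manual form on the
# segmentation page is what actually writes to st.session_state.segments.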
def generate_srt(segments, transcript):
    """Generate SRT format from segments and transcript"""
    srt_content = ""
    for i, segment in enumerate(segments):
        start_time = format_srt_time(segment['start'])
        end_time = format_srt_time(segment['end'])
        # Extract corresponding text (simplified - in real app you'd need better text matching)
        text = f"{segment['speaker_id']}: [Segment {i+1} text]"
        srt_content += f"{i+1}\n"
        srt_content += f"{start_time} --> {end_time}\n"
        srt_content += f"{text}\n\n"
    return srt_content
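# This placeholder version emits "[Segment N text]" markers and ignores the
# transcript argument; the export path below uses generate_srt_with_text,
# which fills in the text actually assigned on the assignment page.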
def format_srt_time(seconds):
    """Format time for SRT format (HH:MM:SS,mmm)"""
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)
    millisecs = int((seconds % 1) * 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millisecs:03d}"
def get_download_link(content, filename, label="Download file"):
    """Generate download link for text content"""
    b64 = base64.b64encode(content.encode()).decode()
    href = f'<a href="data:text/plain;base64,{b64}" download="{filename}">{label}</a>'
    return href
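# Note: Streamlit also provides st.download_button(label, data, file_name),
# which would avoid hand-building base64 data URIs; the HTML-link approach is
# kept here since the app already renders custom HTML elsewhere.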
# Main App Layout
def main():
    st.title("🎤 ASR Annotation Tool")
    st.markdown("Simple tool for transcribing, segmenting, and annotating audio for ASR dataset creation.")

    # Sidebar for navigation and settings
    with st.sidebar:
        st.header("Navigation")
        if st.button("🏠 Home", use_container_width=True):
            st.session_state.current_page = "home"
        if st.session_state.audio_file and st.session_state.annotation_type:
            if st.button("📝 Transcription", use_container_width=True):
                st.session_state.current_page = "transcription"
        if st.session_state.annotation_type == "multi_speaker" and st.session_state.transcript:
            if st.button("🎯 Segmentation", use_container_width=True):
                st.session_state.current_page = "segmentation"
        if st.session_state.segments:
            if st.button("🔗 Assignment", use_container_width=True):
                st.session_state.current_page = "assignment"

    if st.session_state.current_page == "home":
        show_home_page()
    elif st.session_state.current_page == "transcription":
        show_transcription_page()
    elif st.session_state.current_page == "segmentation":
        show_segmentation_page()
    elif st.session_state.current_page == "assignment":
        show_assignment_page()
def show_home_page():
    """Home page - annotation type selection and file upload"""
    # Annotation type selection
    st.subheader("1. Select Annotation Type")
    annotation_type = st.radio(
        "How many speakers are in your audio?",
        ["single_speaker", "multi_speaker"],
        format_func=lambda x: "Single Speaker (Simple ASR)" if x == "single_speaker" else "Multi Speaker (Diarization)",
        key="annotation_type_radio"
    )
    st.session_state.annotation_type = annotation_type

    # File upload with better error handling
    st.subheader("2. Upload Audio File")
    # Add file size warning for Hugging Face Spaces
    st.info("💡 **Tip for Hugging Face Spaces:** Large files (>10MB) may fail to upload. Try smaller audio files or compress your audio if you encounter issues.")
    uploaded_file = st.file_uploader(
        "Choose an audio file",
        type=['wav', 'mp3', 'flac', 'm4a'],
        help="Supported formats: WAV, MP3, FLAC, M4A"
    )
    if uploaded_file is not None:
        # getvalue() is safe across reruns; read() would return b"" once the
        # buffer has been consumed
        st.session_state.audio_file = uploaded_file.getvalue()
        # Write a temporary file to measure duration; close it before reading
        # so the bytes are flushed to disk
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
            tmp_file.write(st.session_state.audio_file)
            tmp_path = tmp_file.name
        st.session_state.audio_duration = get_audio_duration(tmp_path)
        os.unlink(tmp_path)
        st.success("✅ Audio file uploaded successfully!")
        st.info(f"Duration: {format_time(st.session_state.audio_duration)}")

        # Show audio player
        st.subheader("Audio Preview")
        audio_html = create_audio_player_html(st.session_state.audio_file)
        st.components.v1.html(audio_html, height=120)

        # Continue button
        if st.button("Continue to Transcription →", type="primary"):
            st.session_state.current_page = "transcription"
            st.rerun()
def show_transcription_page():
    """Transcription page - text annotation"""
    st.header("📝 Text Transcription")
    if not st.session_state.audio_file:
        st.error("Please upload an audio file first!")
        return

    # Audio player
    st.subheader("Audio Player")
    audio_html = create_audio_player_html(st.session_state.audio_file)
    st.components.v1.html(audio_html, height=120)

    # Transcription area
    st.subheader("Transcript")
    transcript = st.text_area(
        "Write your transcription here:",
        value=st.session_state.transcript,
        height=300,
        help="Check the guidelines below to help you transcribe accurately."
    )
    st.session_state.transcript = transcript

    # Guidelines reminder
    with st.expander("📋 Transcription Guidelines"):
        st.markdown("""
        **Key Guidelines:**
        - Transcribe exactly what is said
        - Use standard punctuation and capitalization (tip: let natural pauses in the dialogue guide punctuation)
        - Write numbers 1-10 as words, 11+ as digits
        - Mark unclear or inaudible speech as [unclear] or [inaudible]
        - For multi-speaker audio: transcribe all audible speech without identifying speakers
        """)

    # Action buttons
    col1, col2, col3 = st.columns(3)
    with col1:
        if transcript.strip():
            download_link = get_download_link(transcript, "transcript.txt", "💾 Download Transcript")
            st.markdown(download_link, unsafe_allow_html=True)
        else:
            st.button("💾 Download Transcript", disabled=True)
    with col2:
        if st.session_state.annotation_type == "multi_speaker" and transcript.strip():
            if st.button("🎯 Continue to Segmentation →"):
                st.session_state.current_page = "segmentation"
                st.rerun()
    with col3:
        if st.session_state.annotation_type == "single_speaker" and transcript.strip():
            if st.button("✅ Finish Annotation"):
                st.balloons()
                st.success("🎉 Single speaker annotation completed!")
                download_link = get_download_link(transcript, "transcript.txt", "📥 Download Final Transcript")
                st.markdown(download_link, unsafe_allow_html=True)
def show_segmentation_page():
    """Segmentation page - audio region selection"""
    st.header("🎯 Audio Segmentation")
    if not st.session_state.audio_file:
        st.error("Please upload an audio file first!")
        return

    st.info("Click and drag on the waveform to create segments. Resize a segment by dragging its edges; remove it with the ✕ button.")
    # Interactive waveform
    waveform_html = create_waveform_html(st.session_state.audio_file, st.session_state.segments)
    st.components.v1.html(waveform_html, height=500)

    # Manual segment addition
    st.subheader("Manual Segment Addition")
    st.info("Segments drawn on the waveform are not synced back to the app automatically, so enter their start/end times here. Replay and pause the audio as needed to pin down exact boundaries.")
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        start_time = st.number_input("Start (seconds)", min_value=0.0, max_value=st.session_state.audio_duration, step=0.1)
    with col2:
        end_time = st.number_input("End (seconds)", min_value=0.0, max_value=st.session_state.audio_duration, step=0.1)
    with col3:
        speaker_id = st.text_input("Speaker ID", value="SPK001")
    with col4:
        if st.button("➕ Add Segment"):
            if start_time < end_time:
                new_segment = {
                    "start": start_time,
                    "end": end_time,
                    "speaker_id": speaker_id
                }
                st.session_state.segments.append(new_segment)
                st.success("Segment added!")
                st.rerun()
            else:
                st.error("End time must be greater than start time!")

    # Current segments display
    if st.session_state.segments:
        st.subheader("Current Segments")
        for i, segment in enumerate(st.session_state.segments):
            col1, col2 = st.columns([4, 1])
            with col1:
                st.write(f"**Segment {i+1}:** {segment['speaker_id']} | {segment['start']:.2f}s - {segment['end']:.2f}s")
            with col2:
                if st.button("🗑️", key=f"remove_{i}"):
                    st.session_state.segments.pop(i)
                    st.rerun()

    # Continue button
    if st.session_state.segments:
        if st.button("🔗 Continue to Assignment →", type="primary"):
            st.session_state.current_page = "assignment"
            st.rerun()
def show_assignment_page():
    """Assignment page - text-to-segment mapping and final export"""
    st.header("🔗 Text-Segment Assignment")
    if not st.session_state.segments:
        st.error("Please create segments first!")
        return

    st.info("Assign portions of your text transcript to each audio segment to create the final annotation.")
    # Display transcript
    st.subheader("Original Transcript")
    st.text_area("Reference transcript:", value=st.session_state.transcript, height=150, disabled=True)

    # Segment assignment
    st.subheader("Segment Text Assignment")
    assigned_segments = []
    for i, segment in enumerate(st.session_state.segments):
        st.write(f"**Segment {i+1}:** {segment['speaker_id']} ({segment['start']:.2f}s - {segment['end']:.2f}s)")
        segment_text = st.text_area(
            f"Text for segment {i+1}:",
            key=f"segment_text_{i}",
            height=100,
            help="Copy and paste the relevant portion of the text transcript for this segment"
        )
        assigned_segments.append({
            **segment,
            "text": segment_text
        })
        st.divider()

    # Preview SRT
    if st.button("📋 Preview SRT"):
        srt_preview = generate_srt_with_text(assigned_segments)
        st.subheader("SRT Preview")
        st.code(srt_preview, language="text")

    # Final save
    st.subheader("Download Final Annotation")
    col1, col2 = st.columns(2)
    with col1:
        # Create enhanced transcript with speaker labels
        enhanced_transcript = create_speaker_transcript(assigned_segments)
        download_transcript = get_download_link(enhanced_transcript, "final_transcript.txt", "💾 Download Transcript")
        st.markdown(download_transcript, unsafe_allow_html=True)
    with col2:
        srt_content = generate_srt_with_text(assigned_segments)
        download_srt = get_download_link(srt_content, "final_transcript.srt", "💾 Download SRT")
        st.markdown(download_srt, unsafe_allow_html=True)

    if st.button("🎉 Finish Annotation", type="primary"):
        st.balloons()
        st.success("🎉 Multi-speaker annotation completed!")
        # Final downloads
        st.subheader("Download your files:")
        download_transcript = get_download_link(enhanced_transcript, "final_transcript.txt", "📥 Download Transcript")
        download_srt = get_download_link(srt_content, "final_transcript.srt", "📥 Download SRT")
        st.markdown(download_transcript, unsafe_allow_html=True)
        st.markdown(download_srt, unsafe_allow_html=True)

    if st.button("🔙 Back to Segmentation"):
        st.session_state.current_page = "segmentation"
        st.rerun()
def generate_srt_with_text(segments):
    """Generate SRT with actual text content"""
    srt_content = ""
    for i, segment in enumerate(segments):
        start_time = format_srt_time(segment['start'])
        end_time = format_srt_time(segment['end'])
        text = segment.get('text', '').strip() or f"[Segment {i+1} - No text assigned]"
        srt_content += f"{i+1}\n"
        srt_content += f"{start_time} --> {end_time}\n"
        srt_content += f"{segment['speaker_id']}: {text}\n\n"
    return srt_content

def create_speaker_transcript(segments):
    """Create speaker-labeled transcript"""
    transcript_lines = []
    for segment in sorted(segments, key=lambda x: x['start']):
        text = segment.get('text', '').strip()
        if text:
            transcript_lines.append(f"{segment['speaker_id']}: {text}")
    return "\n\n".join(transcript_lines)
if __name__ == "__main__":
    main()