|
import os
import tempfile

import librosa
import numpy as np
import streamlit as st
import torch
import torchaudio
from huggingface_hub import login
from transformers import AutoModelForCTC, AutoProcessor
|
|
|
|
|
|
|
HF_TOKEN = os.getenv("hf_token") |
|
|
|
if HF_TOKEN is None: |
|
raise ValueError("β Hugging Face API token not found. Please set it in Secrets.") |
|
|
|
login(token=HF_TOKEN) |
|
|
|
|
|
|
|
|
|
MODEL_NAME = "deepl-project/conformer-finetunning" |
|
processor = AutoProcessor.from_pretrained(MODEL_NAME) |
|
model = AutoModelForCTC.from_pretrained(MODEL_NAME) |
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
model.to(device) |
|
print(f"β
Conformer Model loaded on {device}") |
|
|
|
|
|
|
|
|
|
# Sidebar controls.
# NOTE(review): num_epochs, learning_rate and batch_size are not read anywhere in
# the visible script (no fine-tuning is performed here) — confirm they are still needed.
st.sidebar.title("🔧 Fine-Tuning Hyperparameters")

num_epochs = st.sidebar.slider("Epochs", min_value=1, max_value=10, value=3)

learning_rate = st.sidebar.select_slider("Learning Rate", options=[5e-4, 1e-4, 5e-5, 1e-5], value=5e-5)

batch_size = st.sidebar.select_slider("Batch Size", options=[2, 4, 8, 16], value=8)

# Scale (standard deviation) of the Gaussian noise mixed into the uploaded audio.
attack_strength = st.sidebar.slider("Attack Strength", 0.0, 0.9, 0.1)
|
|
|
|
|
|
|
|
|
st.title("🎙️ Speech-to-Text ASR Conformer Model Fine-tuned on LibriSpeech with Security Features 🎶")

# Accepted upload formats; the bytes are decoded with librosa when a file is provided.
audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
|
|
|
def _load_uploaded_audio(uploaded, sample_rate=16000):
    """Decode a Streamlit upload into a mono waveform resampled to *sample_rate*.

    The bytes go to a private temporary file instead of a fixed shared path, so
    concurrent sessions cannot overwrite each other's upload. The file keeps the
    upload's extension (mp3/flac are not mislabelled as .wav) and is always deleted.

    Returns:
        ``(speech, sr)`` exactly as returned by ``librosa.load``.
    """
    suffix = os.path.splitext(uploaded.name)[1] or ".wav"
    fd, tmp_path = tempfile.mkstemp(suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as tmp:
            # getvalue() returns the whole upload regardless of the read cursor.
            tmp.write(uploaded.getvalue())
        return librosa.load(tmp_path, sr=sample_rate)
    finally:
        os.remove(tmp_path)


def _add_noise(speech, strength):
    """Return *speech* plus Gaussian noise scaled by *strength*, clipped to [-1, 1]."""
    noisy = speech + strength * np.random.randn(*speech.shape)
    # randn yields float64; cast back to float32 (librosa.load's default output dtype).
    return np.clip(noisy, -1.0, 1.0).astype(np.float32)


def _transcribe(speech, sample_rate):
    """Run greedy (argmax) CTC decoding on one waveform; return the decoded batch as a list."""
    inputs = processor(speech, sampling_rate=sample_rate, return_tensors="pt", padding=True)
    input_values = inputs.input_values.to(device)
    with torch.no_grad():
        logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids, skip_special_tokens=True)


if audio_file:
    try:
        speech, sr = _load_uploaded_audio(audio_file, sample_rate=16000)
    except Exception as err:  # app boundary: decoder errors are backend-specific types
        st.error(f"Could not decode the uploaded audio file: {err}")
        st.stop()

    if speech.size == 0:
        st.error("The uploaded audio file contains no samples.")
        st.stop()

    adversarial_speech = _add_noise(speech, attack_strength)
    transcription = _transcribe(adversarial_speech, sr)

    # NOTE(review): this only compares the sidebar setting against 0.2; nothing here
    # inspects the audio, so the message reflects the configured noise level rather
    # than an actual detection.
    if attack_strength > 0.2:
        st.warning("⚠️ Adversarial attack detected! Transcription may be affected.")

    st.success("🔒 Secure Transcription:")
    st.write(transcription[0])
|
|