dipankar53 committed on
Commit cfa5958
1 Parent(s): 978917b

Add requirements.txt, Modify app.py

Files changed (2)
  1. app.py +110 -4
  2. requirements.txt +5 -0
app.py CHANGED
@@ -1,7 +1,113 @@
 
 
 
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import torch
+ import whisper
+ import torchaudio
  import gradio as gr
+ import torch.nn as nn
+ from huggingface_hub import hf_hub_download
+
+
+ # Define the same model class used during training
+ class DialectClassifier(nn.Module):
+     def __init__(self, input_dim, num_classes):
+         super(DialectClassifier, self).__init__()
+         self.fc1 = nn.Linear(input_dim, 128)
+         self.fc2 = nn.Linear(128, 64)
+         self.fc3 = nn.Linear(64, num_classes)
+         self.relu = nn.ReLU()
+
+     def forward(self, x):
+         x = x.view(x.size(0), -1)  # Flatten the input tensor
+         x = self.relu(self.fc1(x))
+         x = self.relu(self.fc2(x))
+         x = self.fc3(x)
+         return x
+
+ # Function to preprocess audio and extract features
+ def preprocess_audio(file_path, whisper_model, device):
+     def load_audio(file_path):
+         waveform, sample_rate = torchaudio.load(file_path)
+         if sample_rate != 16000:
+             waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
+
+         # Convert to single channel (mono) if necessary
+         if waveform.shape[0] > 1:
+             waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+         # Pad or trim audio to 30 seconds
+         desired_length = 16000 * 30  # 30 seconds at 16 kHz
+         current_length = waveform.shape[1]
+
+         if current_length < desired_length:
+             # Pad with zeros
+             padding = desired_length - current_length
+             waveform = torch.nn.functional.pad(waveform, (0, padding))
+         elif current_length > desired_length:
+             # Trim to desired length
+             waveform = waveform[:, :desired_length]
+
+         return waveform
+
+     audio = load_audio(file_path)
+     audio = whisper.pad_or_trim(audio.flatten())
+     mel = whisper.log_mel_spectrogram(audio).to_dense()
+
+     with torch.no_grad():
+         mel = mel.unsqueeze(0).to(device)  # Add batch dimension and move to device
+         features = whisper_model.encoder(mel)
+     return features
+
+ repo_id = "dipankar53/assamese_dialect_classifier_model"
+ model_filename = "dialect_classifier_model.pth"
+ model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
+
+ label_to_idx = {"Darrangiya Accent": 0, "Kamrupiya Accent": 1, "Upper Assam": 2, "Nalbaria Accent": 3}
+
+ # Load Whisper model
+ whisper_model = whisper.load_model("medium")
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load the trained model
+ num_classes = len(label_to_idx)
+ sample_input = torch.randn(1, 80, 3000).to(device)
+ with torch.no_grad():
+     sample_output = whisper_model.encoder(sample_input)
+ input_dim = sample_output.view(1, -1).shape[1]  # Flatten and get dimension
+
+ model = DialectClassifier(input_dim, num_classes)
+ model.load_state_dict(torch.load(model_path, map_location=device))
+ model.to(device)
+ model.eval()
+
+ # Function to predict the dialect of a single audio file
+ def predict_dialect(audio_path):
+     try:
+         # Preprocess audio and extract features
+         features = preprocess_audio(audio_path, whisper_model, device)
+         features = features.view(1, -1)  # Flatten features
+
+         # Perform prediction
+         with torch.no_grad():
+             outputs = model(features)
+             _, predicted = torch.max(outputs, 1)
+
+         # Map predicted index back to dialect label
+         idx_to_label = {idx: label for label, idx in label_to_idx.items()}
+         predicted_label = idx_to_label[predicted.item()]
+
+         return f"Predicted Dialect: {predicted_label}"
+     except Exception as e:
+         return f"Error: {str(e)}"
+
+ # Define Gradio interface
+ interface = gr.Interface(
+     fn=predict_dialect,
+     inputs=gr.Audio(sources=["upload", "microphone"], type="filepath"),
+     outputs="text",
+     title="Assamese Dialect Prediction",
+     description="Upload an Assamese audio file to predict its dialect.",
+ )
+
+ # Launch the interface
+ if __name__ == "__main__":
+     interface.launch()
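
For a quick check outside the Gradio UI, a minimal sketch along these lines can exercise the new prediction path directly. The module name app and the clip sample.wav are only illustrative, and the first call downloads both the Whisper "medium" checkpoint and the classifier weights from the Hub:

# Minimal smoke test, assuming the file above is saved as app.py and a local
# Assamese clip named sample.wav (hypothetical) sits in the working directory.
from app import predict_dialect

print(predict_dialect("sample.wav"))
# -> "Predicted Dialect: <one of the four accents>", or "Error: ..." on failure

Because the launch call sits behind the __main__ guard, importing the module only loads the models; it does not start the web UI.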
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch
+ openai-whisper
+ gradio
+ huggingface_hub
+ torchaudio
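
None of the five entries are pinned, so a short import check after installing the requirements can confirm the resolved packages expose what app.py relies on; the version prints below are only for inspection:

# Sanity-check the installed dependencies against the imports used in app.py.
import gradio
import huggingface_hub
import torch
import torchaudio
import whisper

print(torch.__version__, torchaudio.__version__, gradio.__version__)
print(hasattr(whisper, "load_model"))  # True only with the OpenAI Whisper package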