Commit 9aa735a0
Duplicate from team7/talk_with_wind
Co-authored-by: Omar Sanseviero <[email protected]>
- .gitattributes +37 -0
- .gitignore +1 -0
- README.md +13 -0
- app.py +121 -0
- efficientat/helpers/flop_count.py +162 -0
- efficientat/helpers/init.py +33 -0
- efficientat/helpers/utils.py +104 -0
- efficientat/metadata/class_labels_indices.csv +528 -0
- efficientat/models/MobileNetV3.py +349 -0
- efficientat/models/attention_pooling.py +56 -0
- efficientat/models/block_types.py +182 -0
- efficientat/models/preprocess.py +67 -0
- efficientat/models/utils.py +59 -0
- efficientat/resources/README.md +1 -0
- efficientat/resources/metro_station-paris.wav +3 -0
- logo.png +3 -0
- metro_station-paris.wav +3 -0
- packages.txt +1 -0
- requirements.txt +11 -0
.gitattributes
ADDED
@@ -0,0 +1,37 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*.wav filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
efficientat/resources/metro_station-paris.wav filter=lfs diff=lfs merge=lfs -text
logo.png filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1 @@
__pycache__
README.md
ADDED
@@ -0,0 +1,13 @@
---
title: Talk With Wind
emoji: 🔥
colorFrom: pink
colorTo: yellow
sdk: gradio
sdk_version: 3.16.1
app_file: app.py
pinned: false
duplicated_from: team7/talk_with_wind
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,121 @@
import os

import gradio as gr
import torch
import numpy as np
import librosa

from efficientat.models.MobileNetV3 import get_model as get_mobilenet, get_ensemble_model
from efficientat.models.preprocess import AugmentMelSTFT
from efficientat.helpers.utils import NAME_TO_WIDTH, labels

from torch import autocast
from contextlib import nullcontext

from langchain import OpenAI, ConversationChain, LLMChain, PromptTemplate
from langchain.chains.conversation.memory import ConversationalBufferWindowMemory

MODEL_NAME = "mn40_as"

session_token = os.environ["SESSION_TOKEN"]

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = get_mobilenet(width_mult=NAME_TO_WIDTH(MODEL_NAME), pretrained_name=MODEL_NAME)
model.to(device)
model.eval()

cached_audio_class = "c"
template = None
prompt = None
chain = None


def format_classname(classname):
    return classname.capitalize()


def audio_tag(
        audio_path,
        human_input,
        sample_rate=32000,
        window_size=800,
        hop_size=320,
        n_mels=128,
        cuda=True,
):
    (waveform, _) = librosa.core.load(audio_path, sr=sample_rate, mono=True)
    mel = AugmentMelSTFT(n_mels=n_mels, sr=sample_rate, win_length=window_size, hopsize=hop_size)
    mel.to(device)
    mel.eval()
    waveform = torch.from_numpy(waveform[None, :]).to(device)

    # our models are trained in half precision mode (torch.float16)
    # run on cuda with torch.float16 to get the best performance
    # running on cpu with torch.float32 gives similar performance, using torch.bfloat16 is worse
    with torch.no_grad(), autocast(device_type=device.type) if cuda and torch.cuda.is_available() else nullcontext():
        spec = mel(waveform)
        preds, features = model(spec.unsqueeze(0))
    preds = torch.sigmoid(preds.float()).squeeze().cpu().numpy()

    sorted_indexes = np.argsort(preds)[::-1]

    # take the top-scoring audio tag as the entity to converse with
    label = labels[sorted_indexes[0]]
    return formatted_message(format_classname(label), human_input)


def formatted_message(audio_class, human_input):
    global cached_audio_class
    global session_token
    global chain
    formatted_classname = format_classname(audio_class)
    # rebuild the prompt and LLM chain only when the detected audio class changes;
    # the chain is kept in the module-level `chain` so follow-up turns reuse it
    if cached_audio_class != formatted_classname:
        cached_audio_class = formatted_classname

        prefix = f"""You are going to act as a magical tool that allows for humans to communicate with non-human entities like rocks, crackling fire, trees, animals, and the wind. In order to do this, we're going to provide you the human's text input for the conversation. The goal is for you to embody that non-human entity and converse with the human.
Examples:
Non-human Entity: Tree
Human Input: Hello tree
Tree: Hello human, I am a tree
Let's begin:
Non-human Entity: {formatted_classname}"""
        suffix = f'''
{{history}}
Human Input: {{human_input}}
{formatted_classname}:'''
        template = prefix + suffix

        prompt = PromptTemplate(
            input_variables=["history", "human_input"],
            template=template
        )

        chain = LLMChain(
            llm=OpenAI(temperature=.5, openai_api_key=session_token),
            prompt=prompt,
            verbose=True,
            memory=ConversationalBufferWindowMemory(k=2, ai=formatted_classname),
        )
    output = chain.predict(human_input=human_input)

    return output


with gr.Blocks() as demo:
    gr.HTML("""
        <div style='text-align: center; width:100%; margin: auto;'>
        <img src='./logo.png' alt='anychat' width='250px' />
        <h3>Non-Human entities have many things to say, listen to them!</h3>
        </div>
    """)
    with gr.Row():
        with gr.Column():
            aud = gr.Audio(source="upload", type="filepath", label="Your audio")
            inp = gr.Textbox()
    out = gr.Textbox()
    btn = gr.Button("Run")
    btn.click(fn=audio_tag, inputs=[aud, inp], outputs=out)

demo.launch()
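For reference, a minimal sketch (not part of the commit) of how the pipeline above might be exercised outside the Gradio UI, assuming SESSION_TOKEN is set in the environment, the definitions above have been run without reaching demo.launch(), and the bundled sample clip from this commit's file list is in the working directory:

# Hedged sketch: drive audio_tag() directly, mirroring what the "Run" button does.
reply = audio_tag(
    "metro_station-paris.wav",   # sample clip shipped with this commit
    "Hello, what are you?",      # human side of the conversation
)
print(reply)  # the detected entity answering in first person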
efficientat/helpers/flop_count.py
ADDED
@@ -0,0 +1,162 @@
import torch
import torch.nn as nn


# adapted from PANNs (https://github.com/qiuqiangkong/audioset_tagging_cnn)

def count_macs(model, spec_size):
    list_conv2d = []

    def conv2d_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        assert batch_size == 1
        output_channels, output_height, output_width = output[0].size()

        kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups)
        bias_ops = 1 if self.bias is not None else 0

        params = output_channels * (kernel_ops + bias_ops)
        # overall macs count is:
        # kernel**2 * in_channels/groups * out_channels * out_width * out_height
        macs = batch_size * params * output_height * output_width

        list_conv2d.append(macs)

    list_linear = []

    def linear_hook(self, input, output):
        batch_size = input[0].size(0) if input[0].dim() == 2 else 1
        assert batch_size == 1
        weight_ops = self.weight.nelement()
        bias_ops = self.bias.nelement()

        # overall macs count is equal to the number of parameters in layer
        macs = batch_size * (weight_ops + bias_ops)
        list_linear.append(macs)

    def foo(net):
        if net.__class__.__name__ == 'Conv2dStaticSamePadding':
            net.register_forward_hook(conv2d_hook)
        childrens = list(net.children())
        if not childrens:
            if isinstance(net, nn.Conv2d):
                net.register_forward_hook(conv2d_hook)
            elif isinstance(net, nn.Linear):
                net.register_forward_hook(linear_hook)
            else:
                print('Warning: flop of module {} is not counted!'.format(net))
            return
        for c in childrens:
            foo(c)

    # Register hook
    foo(model)

    device = next(model.parameters()).device
    input = torch.rand(spec_size).to(device)
    with torch.no_grad():
        model(input)

    total_macs = sum(list_conv2d) + sum(list_linear)

    print("*************Computational Complexity (multiply-adds) **************")
    print("Number of Convolutional Layers: ", len(list_conv2d))
    print("Number of Linear Layers: ", len(list_linear))
    print("Relative Share of Convolutional Layers: {:.2f}".format((sum(list_conv2d) / total_macs)))
    print("Relative Share of Linear Layers: {:.2f}".format(sum(list_linear) / total_macs))
    print("Total MACs (multiply-accumulate operations in Billions): {:.2f}".format(total_macs/10**9))
    print("********************************************************************")
    return total_macs


def count_macs_transformer(model, spec_size):
    """Count macs. Code modified from others' implementation.
    """
    list_conv2d = []

    def conv2d_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        assert batch_size == 1
        output_channels, output_height, output_width = output[0].size()

        kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups)
        bias_ops = 1 if self.bias is not None else 0

        params = output_channels * (kernel_ops + bias_ops)
        # overall macs count is:
        # kernel**2 * in_channels/groups * out_channels * out_width * out_height
        macs = batch_size * params * output_height * output_width

        list_conv2d.append(macs)

    list_linear = []

    def linear_hook(self, input, output):
        batch_size = input[0].size(0) if input[0].dim() >= 2 else 1
        assert batch_size == 1
        if input[0].dim() == 3:
            # (batch size, sequence length, embeddings size)
            batch_size, seq_len, embed_size = input[0].size()

            weight_ops = self.weight.nelement()
            bias_ops = self.bias.nelement() if self.bias is not None else 0
            # linear layer applied position-wise, multiply with sequence length
            macs = batch_size * (weight_ops + bias_ops) * seq_len
        else:
            # classification head
            # (batch size, embeddings size)
            batch_size, embed_size = input[0].size()
            weight_ops = self.weight.nelement()
            bias_ops = self.bias.nelement() if self.bias is not None else 0
            # overall macs count is equal to the number of parameters in layer
            macs = batch_size * (weight_ops + bias_ops)
        list_linear.append(macs)

    list_att = []

    def attention_hook(self, input, output):
        # here we only calculate the attention macs; linear layers are processed in linear_hook
        batch_size, seq_len, embed_size = input[0].size()

        # 2 times embed_size * seq_len**2
        # - computing the attention matrix: embed_size * seq_len**2
        # - multiply attention matrix with value matrix: embed_size * seq_len**2
        macs = batch_size * embed_size * seq_len * seq_len * 2
        list_att.append(macs)

    def foo(net):
        childrens = list(net.children())
        if net.__class__.__name__ == "MultiHeadAttention":
            net.register_forward_hook(attention_hook)
        if not childrens:
            if isinstance(net, nn.Conv2d):
                net.register_forward_hook(conv2d_hook)
            elif isinstance(net, nn.Linear):
                net.register_forward_hook(linear_hook)
            else:
                print('Warning: flop of module {} is not counted!'.format(net))
            return
        for c in childrens:
            foo(c)

    # Register hook
    foo(model)

    device = next(model.parameters()).device
    input = torch.rand(spec_size).to(device)

    with torch.no_grad():
        model(input)

    total_macs = sum(list_conv2d) + sum(list_linear) + sum(list_att)

    print("*************Computational Complexity (multiply-adds) **************")
    print("Number of Convolutional Layers: ", len(list_conv2d))
    print("Number of Linear Layers: ", len(list_linear))
    print("Number of Attention Layers: ", len(list_att))
    print("Relative Share of Convolutional Layers: {:.2f}".format((sum(list_conv2d) / total_macs)))
    print("Relative Share of Linear Layers: {:.2f}".format(sum(list_linear) / total_macs))
    print("Relative Share of Attention Layers: {:.2f}".format(sum(list_att) / total_macs))
    print("Total MACs (multiply-accumulate operations in Billions): {:.2f}".format(total_macs/10**9))
    print("********************************************************************")
    return total_macs
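A small sketch (not part of the commit) of how count_macs might be invoked with the model loader added later in this commit; the (1, 1, 128, 1000) shape follows the (batch, channel, mel bands, time frames) defaults used elsewhere here, and the hooks assert a batch size of 1:

# Hedged sketch: profile a pretrained model's multiply-accumulate count.
from efficientat.models.MobileNetV3 import get_model
from efficientat.helpers.flop_count import count_macs

model = get_model(pretrained_name="mn10_as")
model.eval()
# spec_size = (batch, channels, mel bands, time frames); batch must be 1
total_macs = count_macs(model, (1, 1, 128, 1000))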
efficientat/helpers/init.py
ADDED
@@ -0,0 +1,33 @@
import torch
import numpy as np
import random


def worker_init_fn(wid):
    seed_sequence = np.random.SeedSequence(
        [torch.initial_seed(), wid]
    )

    to_seed = spawn_get(seed_sequence, 2, dtype=int)
    torch.random.manual_seed(to_seed)

    np_seed = spawn_get(seed_sequence, 2, dtype=np.ndarray)
    np.random.seed(np_seed)

    py_seed = spawn_get(seed_sequence, 2, dtype=int)
    random.seed(py_seed)


def spawn_get(seedseq, n_entropy, dtype):
    child = seedseq.spawn(1)[0]
    state = child.generate_state(n_entropy, dtype=np.uint32)

    if dtype == np.ndarray:
        return state
    elif dtype == int:
        state_as_int = 0
        for shift, s in enumerate(state):
            state_as_int = state_as_int + int((2 ** (32 * shift) * s))
        return state_as_int
    else:
        raise ValueError(f'not a valid dtype "{dtype}"')
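worker_init_fn derives independent per-worker seeds for torch, NumPy, and Python's random module from the process's torch seed and the worker id. A sketch (not part of the commit) of wiring it into a DataLoader; the dataset here is a stand-in:

# Hedged sketch: give each DataLoader worker its own derived seeds so that
# random augmentations do not repeat across workers.
import torch
from torch.utils.data import DataLoader, TensorDataset
from efficientat.helpers.init import worker_init_fn

dataset = TensorDataset(torch.randn(64, 1, 128, 1000))  # placeholder data
loader = DataLoader(dataset, batch_size=8, num_workers=4,
                    worker_init_fn=worker_init_fn)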
efficientat/helpers/utils.py
ADDED
@@ -0,0 +1,104 @@
def NAME_TO_WIDTH(name):
    map = {
        'mn04': 0.4,
        'mn05': 0.5,
        'mn10': 1.0,
        'mn20': 2.0,
        'mn30': 3.0,
        'mn40': 4.0
    }
    try:
        w = map[name[:4]]
    except:
        w = 1.0

    return w


import csv

# Load label
with open('efficientat/metadata/class_labels_indices.csv', 'r') as f:
    reader = csv.reader(f, delimiter=',')
    lines = list(reader)

labels = []
ids = []    # Each label has a unique id such as "/m/068hy"
for i1 in range(1, len(lines)):
    id = lines[i1][1]
    label = lines[i1][2]
    ids.append(id)
    labels.append(label)

classes_num = len(labels)


import numpy as np


def exp_warmup_linear_down(warmup, rampdown_length, start_rampdown, last_value):
    rampup = exp_rampup(warmup)
    rampdown = linear_rampdown(rampdown_length, start_rampdown, last_value)
    def wrapper(epoch):
        return rampup(epoch) * rampdown(epoch)
    return wrapper


def exp_rampup(rampup_length):
    """Exponential rampup from https://arxiv.org/abs/1610.02242"""
    def wrapper(epoch):
        if epoch < rampup_length:
            epoch = np.clip(epoch, 0.5, rampup_length)
            phase = 1.0 - epoch / rampup_length
            return float(np.exp(-5.0 * phase * phase))
        else:
            return 1.0
    return wrapper


def linear_rampdown(rampdown_length, start=0, last_value=0):
    def wrapper(epoch):
        if epoch <= start:
            return 1.
        elif epoch - start < rampdown_length:
            return last_value + (1. - last_value) * (rampdown_length - epoch + start) / rampdown_length
        else:
            return last_value
    return wrapper


import torch


def mixup(size, alpha):
    rn_indices = torch.randperm(size)
    lambd = np.random.beta(alpha, alpha, size).astype(np.float32)
    lambd = np.concatenate([lambd[:, None], 1 - lambd[:, None]], 1).max(1)
    lam = torch.FloatTensor(lambd)
    return rn_indices, lam


from torch.distributions.beta import Beta


def mixstyle(x, p=0.4, alpha=0.4, eps=1e-6, mix_labels=False):
    if np.random.rand() > p:
        return x
    batch_size = x.size(0)

    # changed from dim=[2,3] to dim=[1,3] - from channel-wise statistics to frequency-wise statistics
    f_mu = x.mean(dim=[1, 3], keepdim=True)
    f_var = x.var(dim=[1, 3], keepdim=True)

    f_sig = (f_var + eps).sqrt()  # compute instance standard deviation
    f_mu, f_sig = f_mu.detach(), f_sig.detach()  # block gradients
    x_normed = (x - f_mu) / f_sig  # normalize input
    lmda = Beta(alpha, alpha).sample((batch_size, 1, 1, 1)).to(x.device)  # sample instance-wise convex weights
    perm = torch.randperm(batch_size).to(x.device)  # generate shuffling indices
    f_mu_perm, f_sig_perm = f_mu[perm], f_sig[perm]  # shuffling
    mu_mix = f_mu * lmda + f_mu_perm * (1 - lmda)  # generate mixed mean
    sig_mix = f_sig * lmda + f_sig_perm * (1 - lmda)  # generate mixed standard deviation
    x = x_normed * sig_mix + mu_mix  # denormalize input using the mixed statistics
    if mix_labels:
        return x, perm, lmda
    return x
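The schedule helpers return epoch-to-factor callables, which compose naturally with torch.optim.lr_scheduler.LambdaLR. A sketch (not part of the commit); the optimizer, schedule constants, tensor shapes, and the way the mixup weights are applied to the batch are illustrative assumptions, not taken from this repository:

# Hedged sketch: use the warmup/rampdown schedule as a LambdaLR multiplier,
# and apply mixup() to a batch of spectrograms.
import torch
from efficientat.helpers.utils import exp_warmup_linear_down, mixup

opt = torch.optim.Adam([torch.zeros(1, requires_grad=True)], lr=1e-3)
sched = torch.optim.lr_scheduler.LambdaLR(
    opt, exp_warmup_linear_down(warmup=8, rampdown_length=80,
                                start_rampdown=80, last_value=0.01))

x = torch.randn(16, 1, 128, 1000)        # (batch, channel, mels, frames)
idx, lam = mixup(x.size(0), alpha=0.3)   # partner indices + convex weights
lam = lam.reshape(-1, 1, 1, 1)
x = x * lam + x[idx] * (1. - lam)        # convex combination of sample pairs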
efficientat/metadata/class_labels_indices.csv
ADDED
@@ -0,0 +1,528 @@
index,mid,display_name
0,/m/09x0r,"Speech"
1,/m/05zppz,"Male speech, man speaking"
2,/m/02zsn,"Female speech, woman speaking"
3,/m/0ytgt,"Child speech, kid speaking"
4,/m/01h8n0,"Conversation"
5,/m/02qldy,"Narration, monologue"
6,/m/0261r1,"Babbling"
7,/m/0brhx,"Speech synthesizer"
8,/m/07p6fty,"Shout"
9,/m/07q4ntr,"Bellow"
10,/m/07rwj3x,"Whoop"
11,/m/07sr1lc,"Yell"
12,/m/04gy_2,"Battle cry"
13,/t/dd00135,"Children shouting"
14,/m/03qc9zr,"Screaming"
15,/m/02rtxlg,"Whispering"
16,/m/01j3sz,"Laughter"
17,/t/dd00001,"Baby laughter"
18,/m/07r660_,"Giggle"
19,/m/07s04w4,"Snicker"
20,/m/07sq110,"Belly laugh"
21,/m/07rgt08,"Chuckle, chortle"
22,/m/0463cq4,"Crying, sobbing"
23,/t/dd00002,"Baby cry, infant cry"
24,/m/07qz6j3,"Whimper"
25,/m/07qw_06,"Wail, moan"
26,/m/07plz5l,"Sigh"
27,/m/015lz1,"Singing"
28,/m/0l14jd,"Choir"
29,/m/01swy6,"Yodeling"
30,/m/02bk07,"Chant"
31,/m/01c194,"Mantra"
32,/t/dd00003,"Male singing"
33,/t/dd00004,"Female singing"
34,/t/dd00005,"Child singing"
35,/t/dd00006,"Synthetic singing"
36,/m/06bxc,"Rapping"
37,/m/02fxyj,"Humming"
38,/m/07s2xch,"Groan"
39,/m/07r4k75,"Grunt"
40,/m/01w250,"Whistling"
41,/m/0lyf6,"Breathing"
42,/m/07mzm6,"Wheeze"
43,/m/01d3sd,"Snoring"
44,/m/07s0dtb,"Gasp"
45,/m/07pyy8b,"Pant"
46,/m/07q0yl5,"Snort"
47,/m/01b_21,"Cough"
48,/m/0dl9sf8,"Throat clearing"
49,/m/01hsr_,"Sneeze"
50,/m/07ppn3j,"Sniff"
51,/m/06h7j,"Run"
52,/m/07qv_x_,"Shuffle"
53,/m/07pbtc8,"Walk, footsteps"
54,/m/03cczk,"Chewing, mastication"
55,/m/07pdhp0,"Biting"
56,/m/0939n_,"Gargling"
57,/m/01g90h,"Stomach rumble"
58,/m/03q5_w,"Burping, eructation"
59,/m/02p3nc,"Hiccup"
60,/m/02_nn,"Fart"
61,/m/0k65p,"Hands"
62,/m/025_jnm,"Finger snapping"
63,/m/0l15bq,"Clapping"
64,/m/01jg02,"Heart sounds, heartbeat"
65,/m/01jg1z,"Heart murmur"
66,/m/053hz1,"Cheering"
67,/m/028ght,"Applause"
68,/m/07rkbfh,"Chatter"
69,/m/03qtwd,"Crowd"
70,/m/07qfr4h,"Hubbub, speech noise, speech babble"
71,/t/dd00013,"Children playing"
72,/m/0jbk,"Animal"
73,/m/068hy,"Domestic animals, pets"
74,/m/0bt9lr,"Dog"
75,/m/05tny_,"Bark"
76,/m/07r_k2n,"Yip"
77,/m/07qf0zm,"Howl"
78,/m/07rc7d9,"Bow-wow"
79,/m/0ghcn6,"Growling"
80,/t/dd00136,"Whimper (dog)"
81,/m/01yrx,"Cat"
82,/m/02yds9,"Purr"
83,/m/07qrkrw,"Meow"
84,/m/07rjwbb,"Hiss"
85,/m/07r81j2,"Caterwaul"
86,/m/0ch8v,"Livestock, farm animals, working animals"
87,/m/03k3r,"Horse"
88,/m/07rv9rh,"Clip-clop"
89,/m/07q5rw0,"Neigh, whinny"
90,/m/01xq0k1,"Cattle, bovinae"
91,/m/07rpkh9,"Moo"
92,/m/0239kh,"Cowbell"
93,/m/068zj,"Pig"
94,/t/dd00018,"Oink"
95,/m/03fwl,"Goat"
96,/m/07q0h5t,"Bleat"
97,/m/07bgp,"Sheep"
98,/m/025rv6n,"Fowl"
99,/m/09b5t,"Chicken, rooster"
100,/m/07st89h,"Cluck"
101,/m/07qn5dc,"Crowing, cock-a-doodle-doo"
102,/m/01rd7k,"Turkey"
103,/m/07svc2k,"Gobble"
104,/m/09ddx,"Duck"
105,/m/07qdb04,"Quack"
106,/m/0dbvp,"Goose"
107,/m/07qwf61,"Honk"
108,/m/01280g,"Wild animals"
109,/m/0cdnk,"Roaring cats (lions, tigers)"
110,/m/04cvmfc,"Roar"
111,/m/015p6,"Bird"
112,/m/020bb7,"Bird vocalization, bird call, bird song"
113,/m/07pggtn,"Chirp, tweet"
114,/m/07sx8x_,"Squawk"
115,/m/0h0rv,"Pigeon, dove"
116,/m/07r_25d,"Coo"
117,/m/04s8yn,"Crow"
118,/m/07r5c2p,"Caw"
119,/m/09d5_,"Owl"
120,/m/07r_80w,"Hoot"
121,/m/05_wcq,"Bird flight, flapping wings"
122,/m/01z5f,"Canidae, dogs, wolves"
123,/m/06hps,"Rodents, rats, mice"
124,/m/04rmv,"Mouse"
125,/m/07r4gkf,"Patter"
126,/m/03vt0,"Insect"
127,/m/09xqv,"Cricket"
128,/m/09f96,"Mosquito"
129,/m/0h2mp,"Fly, housefly"
130,/m/07pjwq1,"Buzz"
131,/m/01h3n,"Bee, wasp, etc."
132,/m/09ld4,"Frog"
133,/m/07st88b,"Croak"
134,/m/078jl,"Snake"
135,/m/07qn4z3,"Rattle"
136,/m/032n05,"Whale vocalization"
137,/m/04rlf,"Music"
138,/m/04szw,"Musical instrument"
139,/m/0fx80y,"Plucked string instrument"
140,/m/0342h,"Guitar"
141,/m/02sgy,"Electric guitar"
142,/m/018vs,"Bass guitar"
143,/m/042v_gx,"Acoustic guitar"
144,/m/06w87,"Steel guitar, slide guitar"
145,/m/01glhc,"Tapping (guitar technique)"
146,/m/07s0s5r,"Strum"
147,/m/018j2,"Banjo"
148,/m/0jtg0,"Sitar"
149,/m/04rzd,"Mandolin"
150,/m/01bns_,"Zither"
151,/m/07xzm,"Ukulele"
152,/m/05148p4,"Keyboard (musical)"
153,/m/05r5c,"Piano"
154,/m/01s0ps,"Electric piano"
155,/m/013y1f,"Organ"
156,/m/03xq_f,"Electronic organ"
157,/m/03gvt,"Hammond organ"
158,/m/0l14qv,"Synthesizer"
159,/m/01v1d8,"Sampler"
160,/m/03q5t,"Harpsichord"
161,/m/0l14md,"Percussion"
162,/m/02hnl,"Drum kit"
163,/m/0cfdd,"Drum machine"
164,/m/026t6,"Drum"
165,/m/06rvn,"Snare drum"
166,/m/03t3fj,"Rimshot"
167,/m/02k_mr,"Drum roll"
168,/m/0bm02,"Bass drum"
169,/m/011k_j,"Timpani"
170,/m/01p970,"Tabla"
171,/m/01qbl,"Cymbal"
172,/m/03qtq,"Hi-hat"
173,/m/01sm1g,"Wood block"
174,/m/07brj,"Tambourine"
175,/m/05r5wn,"Rattle (instrument)"
176,/m/0xzly,"Maraca"
177,/m/0mbct,"Gong"
178,/m/016622,"Tubular bells"
179,/m/0j45pbj,"Mallet percussion"
180,/m/0dwsp,"Marimba, xylophone"
181,/m/0dwtp,"Glockenspiel"
182,/m/0dwt5,"Vibraphone"
183,/m/0l156b,"Steelpan"
184,/m/05pd6,"Orchestra"
185,/m/01kcd,"Brass instrument"
186,/m/0319l,"French horn"
187,/m/07gql,"Trumpet"
188,/m/07c6l,"Trombone"
189,/m/0l14_3,"Bowed string instrument"
190,/m/02qmj0d,"String section"
191,/m/07y_7,"Violin, fiddle"
192,/m/0d8_n,"Pizzicato"
193,/m/01xqw,"Cello"
194,/m/02fsn,"Double bass"
195,/m/085jw,"Wind instrument, woodwind instrument"
196,/m/0l14j_,"Flute"
197,/m/06ncr,"Saxophone"
198,/m/01wy6,"Clarinet"
199,/m/03m5k,"Harp"
200,/m/0395lw,"Bell"
201,/m/03w41f,"Church bell"
202,/m/027m70_,"Jingle bell"
203,/m/0gy1t2s,"Bicycle bell"
204,/m/07n_g,"Tuning fork"
205,/m/0f8s22,"Chime"
206,/m/026fgl,"Wind chime"
207,/m/0150b9,"Change ringing (campanology)"
208,/m/03qjg,"Harmonica"
209,/m/0mkg,"Accordion"
210,/m/0192l,"Bagpipes"
211,/m/02bxd,"Didgeridoo"
212,/m/0l14l2,"Shofar"
213,/m/07kc_,"Theremin"
214,/m/0l14t7,"Singing bowl"
215,/m/01hgjl,"Scratching (performance technique)"
216,/m/064t9,"Pop music"
217,/m/0glt670,"Hip hop music"
218,/m/02cz_7,"Beatboxing"
219,/m/06by7,"Rock music"
220,/m/03lty,"Heavy metal"
221,/m/05r6t,"Punk rock"
222,/m/0dls3,"Grunge"
223,/m/0dl5d,"Progressive rock"
224,/m/07sbbz2,"Rock and roll"
225,/m/05w3f,"Psychedelic rock"
226,/m/06j6l,"Rhythm and blues"
227,/m/0gywn,"Soul music"
228,/m/06cqb,"Reggae"
229,/m/01lyv,"Country"
230,/m/015y_n,"Swing music"
231,/m/0gg8l,"Bluegrass"
232,/m/02x8m,"Funk"
233,/m/02w4v,"Folk music"
234,/m/06j64v,"Middle Eastern music"
235,/m/03_d0,"Jazz"
236,/m/026z9,"Disco"
237,/m/0ggq0m,"Classical music"
238,/m/05lls,"Opera"
239,/m/02lkt,"Electronic music"
240,/m/03mb9,"House music"
241,/m/07gxw,"Techno"
242,/m/07s72n,"Dubstep"
243,/m/0283d,"Drum and bass"
244,/m/0m0jc,"Electronica"
245,/m/08cyft,"Electronic dance music"
246,/m/0fd3y,"Ambient music"
247,/m/07lnk,"Trance music"
248,/m/0g293,"Music of Latin America"
249,/m/0ln16,"Salsa music"
250,/m/0326g,"Flamenco"
251,/m/0155w,"Blues"
252,/m/05fw6t,"Music for children"
253,/m/02v2lh,"New-age music"
254,/m/0y4f8,"Vocal music"
255,/m/0z9c,"A capella"
256,/m/0164x2,"Music of Africa"
257,/m/0145m,"Afrobeat"
258,/m/02mscn,"Christian music"
259,/m/016cjb,"Gospel music"
260,/m/028sqc,"Music of Asia"
261,/m/015vgc,"Carnatic music"
262,/m/0dq0md,"Music of Bollywood"
263,/m/06rqw,"Ska"
264,/m/02p0sh1,"Traditional music"
265,/m/05rwpb,"Independent music"
266,/m/074ft,"Song"
267,/m/025td0t,"Background music"
268,/m/02cjck,"Theme music"
269,/m/03r5q_,"Jingle (music)"
270,/m/0l14gg,"Soundtrack music"
271,/m/07pkxdp,"Lullaby"
272,/m/01z7dr,"Video game music"
273,/m/0140xf,"Christmas music"
274,/m/0ggx5q,"Dance music"
275,/m/04wptg,"Wedding music"
276,/t/dd00031,"Happy music"
277,/t/dd00032,"Funny music"
278,/t/dd00033,"Sad music"
279,/t/dd00034,"Tender music"
280,/t/dd00035,"Exciting music"
281,/t/dd00036,"Angry music"
282,/t/dd00037,"Scary music"
283,/m/03m9d0z,"Wind"
284,/m/09t49,"Rustling leaves"
285,/t/dd00092,"Wind noise (microphone)"
286,/m/0jb2l,"Thunderstorm"
287,/m/0ngt1,"Thunder"
288,/m/0838f,"Water"
289,/m/06mb1,"Rain"
290,/m/07r10fb,"Raindrop"
291,/t/dd00038,"Rain on surface"
292,/m/0j6m2,"Stream"
293,/m/0j2kx,"Waterfall"
294,/m/05kq4,"Ocean"
295,/m/034srq,"Waves, surf"
296,/m/06wzb,"Steam"
297,/m/07swgks,"Gurgling"
298,/m/02_41,"Fire"
299,/m/07pzfmf,"Crackle"
300,/m/07yv9,"Vehicle"
301,/m/019jd,"Boat, Water vehicle"
302,/m/0hsrw,"Sailboat, sailing ship"
303,/m/056ks2,"Rowboat, canoe, kayak"
304,/m/02rlv9,"Motorboat, speedboat"
305,/m/06q74,"Ship"
306,/m/012f08,"Motor vehicle (road)"
307,/m/0k4j,"Car"
308,/m/0912c9,"Vehicle horn, car horn, honking"
309,/m/07qv_d5,"Toot"
310,/m/02mfyn,"Car alarm"
311,/m/04gxbd,"Power windows, electric windows"
312,/m/07rknqz,"Skidding"
313,/m/0h9mv,"Tire squeal"
314,/t/dd00134,"Car passing by"
315,/m/0ltv,"Race car, auto racing"
316,/m/07r04,"Truck"
317,/m/0gvgw0,"Air brake"
318,/m/05x_td,"Air horn, truck horn"
319,/m/02rhddq,"Reversing beeps"
320,/m/03cl9h,"Ice cream truck, ice cream van"
321,/m/01bjv,"Bus"
322,/m/03j1ly,"Emergency vehicle"
323,/m/04qvtq,"Police car (siren)"
324,/m/012n7d,"Ambulance (siren)"
325,/m/012ndj,"Fire engine, fire truck (siren)"
326,/m/04_sv,"Motorcycle"
327,/m/0btp2,"Traffic noise, roadway noise"
328,/m/06d_3,"Rail transport"
329,/m/07jdr,"Train"
330,/m/04zmvq,"Train whistle"
331,/m/0284vy3,"Train horn"
332,/m/01g50p,"Railroad car, train wagon"
333,/t/dd00048,"Train wheels squealing"
334,/m/0195fx,"Subway, metro, underground"
335,/m/0k5j,"Aircraft"
336,/m/014yck,"Aircraft engine"
337,/m/04229,"Jet engine"
338,/m/02l6bg,"Propeller, airscrew"
339,/m/09ct_,"Helicopter"
340,/m/0cmf2,"Fixed-wing aircraft, airplane"
341,/m/0199g,"Bicycle"
342,/m/06_fw,"Skateboard"
343,/m/02mk9,"Engine"
344,/t/dd00065,"Light engine (high frequency)"
345,/m/08j51y,"Dental drill, dentist's drill"
346,/m/01yg9g,"Lawn mower"
347,/m/01j4z9,"Chainsaw"
348,/t/dd00066,"Medium engine (mid frequency)"
349,/t/dd00067,"Heavy engine (low frequency)"
350,/m/01h82_,"Engine knocking"
351,/t/dd00130,"Engine starting"
352,/m/07pb8fc,"Idling"
353,/m/07q2z82,"Accelerating, revving, vroom"
354,/m/02dgv,"Door"
355,/m/03wwcy,"Doorbell"
356,/m/07r67yg,"Ding-dong"
357,/m/02y_763,"Sliding door"
358,/m/07rjzl8,"Slam"
359,/m/07r4wb8,"Knock"
360,/m/07qcpgn,"Tap"
361,/m/07q6cd_,"Squeak"
362,/m/0642b4,"Cupboard open or close"
363,/m/0fqfqc,"Drawer open or close"
364,/m/04brg2,"Dishes, pots, and pans"
365,/m/023pjk,"Cutlery, silverware"
366,/m/07pn_8q,"Chopping (food)"
367,/m/0dxrf,"Frying (food)"
368,/m/0fx9l,"Microwave oven"
369,/m/02pjr4,"Blender"
370,/m/02jz0l,"Water tap, faucet"
371,/m/0130jx,"Sink (filling or washing)"
372,/m/03dnzn,"Bathtub (filling or washing)"
373,/m/03wvsk,"Hair dryer"
374,/m/01jt3m,"Toilet flush"
375,/m/012xff,"Toothbrush"
376,/m/04fgwm,"Electric toothbrush"
377,/m/0d31p,"Vacuum cleaner"
378,/m/01s0vc,"Zipper (clothing)"
379,/m/03v3yw,"Keys jangling"
380,/m/0242l,"Coin (dropping)"
381,/m/01lsmm,"Scissors"
382,/m/02g901,"Electric shaver, electric razor"
383,/m/05rj2,"Shuffling cards"
384,/m/0316dw,"Typing"
385,/m/0c2wf,"Typewriter"
386,/m/01m2v,"Computer keyboard"
387,/m/081rb,"Writing"
388,/m/07pp_mv,"Alarm"
389,/m/07cx4,"Telephone"
390,/m/07pp8cl,"Telephone bell ringing"
391,/m/01hnzm,"Ringtone"
392,/m/02c8p,"Telephone dialing, DTMF"
393,/m/015jpf,"Dial tone"
394,/m/01z47d,"Busy signal"
395,/m/046dlr,"Alarm clock"
396,/m/03kmc9,"Siren"
397,/m/0dgbq,"Civil defense siren"
398,/m/030rvx,"Buzzer"
399,/m/01y3hg,"Smoke detector, smoke alarm"
400,/m/0c3f7m,"Fire alarm"
401,/m/04fq5q,"Foghorn"
402,/m/0l156k,"Whistle"
403,/m/06hck5,"Steam whistle"
404,/t/dd00077,"Mechanisms"
405,/m/02bm9n,"Ratchet, pawl"
406,/m/01x3z,"Clock"
407,/m/07qjznt,"Tick"
408,/m/07qjznl,"Tick-tock"
409,/m/0l7xg,"Gears"
410,/m/05zc1,"Pulleys"
411,/m/0llzx,"Sewing machine"
412,/m/02x984l,"Mechanical fan"
413,/m/025wky1,"Air conditioning"
414,/m/024dl,"Cash register"
415,/m/01m4t,"Printer"
416,/m/0dv5r,"Camera"
417,/m/07bjf,"Single-lens reflex camera"
418,/m/07k1x,"Tools"
419,/m/03l9g,"Hammer"
420,/m/03p19w,"Jackhammer"
421,/m/01b82r,"Sawing"
422,/m/02p01q,"Filing (rasp)"
423,/m/023vsd,"Sanding"
424,/m/0_ksk,"Power tool"
425,/m/01d380,"Drill"
426,/m/014zdl,"Explosion"
427,/m/032s66,"Gunshot, gunfire"
428,/m/04zjc,"Machine gun"
429,/m/02z32qm,"Fusillade"
430,/m/0_1c,"Artillery fire"
431,/m/073cg4,"Cap gun"
432,/m/0g6b5,"Fireworks"
433,/g/122z_qxw,"Firecracker"
434,/m/07qsvvw,"Burst, pop"
435,/m/07pxg6y,"Eruption"
436,/m/07qqyl4,"Boom"
437,/m/083vt,"Wood"
438,/m/07pczhz,"Chop"
439,/m/07pl1bw,"Splinter"
440,/m/07qs1cx,"Crack"
441,/m/039jq,"Glass"
442,/m/07q7njn,"Chink, clink"
443,/m/07rn7sz,"Shatter"
444,/m/04k94,"Liquid"
445,/m/07rrlb6,"Splash, splatter"
446,/m/07p6mqd,"Slosh"
447,/m/07qlwh6,"Squish"
448,/m/07r5v4s,"Drip"
449,/m/07prgkl,"Pour"
450,/m/07pqc89,"Trickle, dribble"
451,/t/dd00088,"Gush"
452,/m/07p7b8y,"Fill (with liquid)"
453,/m/07qlf79,"Spray"
454,/m/07ptzwd,"Pump (liquid)"
455,/m/07ptfmf,"Stir"
456,/m/0dv3j,"Boiling"
457,/m/0790c,"Sonar"
458,/m/0dl83,"Arrow"
459,/m/07rqsjt,"Whoosh, swoosh, swish"
460,/m/07qnq_y,"Thump, thud"
461,/m/07rrh0c,"Thunk"
462,/m/0b_fwt,"Electronic tuner"
463,/m/02rr_,"Effects unit"
464,/m/07m2kt,"Chorus effect"
465,/m/018w8,"Basketball bounce"
466,/m/07pws3f,"Bang"
467,/m/07ryjzk,"Slap, smack"
468,/m/07rdhzs,"Whack, thwack"
469,/m/07pjjrj,"Smash, crash"
470,/m/07pc8lb,"Breaking"
471,/m/07pqn27,"Bouncing"
472,/m/07rbp7_,"Whip"
473,/m/07pyf11,"Flap"
474,/m/07qb_dv,"Scratch"
475,/m/07qv4k0,"Scrape"
476,/m/07pdjhy,"Rub"
477,/m/07s8j8t,"Roll"
478,/m/07plct2,"Crushing"
479,/t/dd00112,"Crumpling, crinkling"
480,/m/07qcx4z,"Tearing"
481,/m/02fs_r,"Beep, bleep"
482,/m/07qwdck,"Ping"
483,/m/07phxs1,"Ding"
484,/m/07rv4dm,"Clang"
485,/m/07s02z0,"Squeal"
486,/m/07qh7jl,"Creak"
487,/m/07qwyj0,"Rustle"
488,/m/07s34ls,"Whir"
489,/m/07qmpdm,"Clatter"
490,/m/07p9k1k,"Sizzle"
491,/m/07qc9xj,"Clicking"
492,/m/07rwm0c,"Clickety-clack"
493,/m/07phhsh,"Rumble"
494,/m/07qyrcz,"Plop"
495,/m/07qfgpx,"Jingle, tinkle"
496,/m/07rcgpl,"Hum"
497,/m/07p78v5,"Zing"
498,/t/dd00121,"Boing"
499,/m/07s12q4,"Crunch"
500,/m/028v0c,"Silence"
501,/m/01v_m0,"Sine wave"
502,/m/0b9m1,"Harmonic"
503,/m/0hdsk,"Chirp tone"
504,/m/0c1dj,"Sound effect"
505,/m/07pt_g0,"Pulse"
506,/t/dd00125,"Inside, small room"
507,/t/dd00126,"Inside, large room or hall"
508,/t/dd00127,"Inside, public space"
509,/t/dd00128,"Outside, urban or manmade"
510,/t/dd00129,"Outside, rural or natural"
511,/m/01b9nn,"Reverberation"
512,/m/01jnbd,"Echo"
513,/m/096m7z,"Noise"
514,/m/06_y0by,"Environmental noise"
515,/m/07rgkc5,"Static"
516,/m/06xkwv,"Mains hum"
517,/m/0g12c5,"Distortion"
518,/m/08p9q4,"Sidetone"
519,/m/07szfh9,"Cacophony"
520,/m/0chx_,"White noise"
521,/m/0cj0r,"Pink noise"
522,/m/07p_0gm,"Throbbing"
523,/m/01jwx6,"Vibration"
524,/m/07c52,"Television"
525,/m/06bz3,"Radio"
526,/m/07hvw1,"Field recording"
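The third column of this CSV is what helpers/utils.py exposes as the `labels` list, indexed by the first column; a tiny sketch (not part of the commit):

# Hedged sketch: index 283 is the class the Space converses with when wind is detected.
from efficientat.helpers.utils import labels
print(labels[283])  # "Wind"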
efficientat/models/MobileNetV3.py
ADDED
@@ -0,0 +1,349 @@
from functools import partial
from typing import Any, Callable, List, Optional, Sequence, Tuple
from torch import nn, Tensor
import torch.nn.functional as F
from torchvision.ops.misc import ConvNormActivation
from torch.hub import load_state_dict_from_url
import urllib.parse


from efficientat.models.utils import cnn_out_size
from efficientat.models.block_types import InvertedResidualConfig, InvertedResidual
from efficientat.models.attention_pooling import MultiHeadAttentionPooling
from efficientat.helpers.utils import NAME_TO_WIDTH

# Adapted version of MobileNetV3 pytorch implementation
# https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv3.py

# points to github releases
model_url = "https://github.com/fschmid56/EfficientAT/releases/download/v0.0.1/"
# folder to store downloaded models to
model_dir = "resources"


pretrained_models = {
    # pytorch ImageNet pre-trained model
    # own ImageNet pre-trained models will follow
    # NOTE: for easy loading we provide the adapted state dict ready for AudioSet training (1 input channel,
    # 527 output classes)
    # NOTE: the classifier is just a random initialization, feature extractor (conv layers) is pre-trained
    "mn10_im_pytorch": urllib.parse.urljoin(model_url, "mn10_im_pytorch.pt"),
    # Models trained on AudioSet
    "mn04_as": urllib.parse.urljoin(model_url, "mn04_as_mAP_432.pt"),
    "mn05_as": urllib.parse.urljoin(model_url, "mn05_as_mAP_443.pt"),
    "mn10_as": urllib.parse.urljoin(model_url, "mn10_as_mAP_471.pt"),
    "mn20_as": urllib.parse.urljoin(model_url, "mn20_as_mAP_478.pt"),
    "mn30_as": urllib.parse.urljoin(model_url, "mn30_as_mAP_482.pt"),
    "mn40_as": urllib.parse.urljoin(model_url, "mn40_as_mAP_484.pt"),
    "mn40_as(2)": urllib.parse.urljoin(model_url, "mn40_as_mAP_483.pt"),
    "mn40_as(3)": urllib.parse.urljoin(model_url, "mn40_as_mAP_483(2).pt"),
    "mn40_as_no_im_pre": urllib.parse.urljoin(model_url, "mn40_as_no_im_pre_mAP_483.pt"),
    "mn40_as_no_im_pre(2)": urllib.parse.urljoin(model_url, "mn40_as_no_im_pre_mAP_483(2).pt"),
    "mn40_as_no_im_pre(3)": urllib.parse.urljoin(model_url, "mn40_as_no_im_pre_mAP_482.pt"),
    "mn40_as_ext": urllib.parse.urljoin(model_url, "mn40_as_ext_mAP_487.pt"),
    "mn40_as_ext(2)": urllib.parse.urljoin(model_url, "mn40_as_ext_mAP_486.pt"),
    "mn40_as_ext(3)": urllib.parse.urljoin(model_url, "mn40_as_ext_mAP_485.pt"),
    # varying hop size (time resolution)
    "mn10_as_hop_15": urllib.parse.urljoin(model_url, "mn10_as_hop_15_mAP_463.pt"),
    "mn10_as_hop_20": urllib.parse.urljoin(model_url, "mn10_as_hop_20_mAP_456.pt"),
    "mn10_as_hop_25": urllib.parse.urljoin(model_url, "mn10_as_hop_25_mAP_447.pt"),
    # varying n_mels (frequency resolution)
    "mn10_as_mels_40": urllib.parse.urljoin(model_url, "mn10_as_mels_40_mAP_453.pt"),
    "mn10_as_mels_64": urllib.parse.urljoin(model_url, "mn10_as_mels_64_mAP_461.pt"),
    "mn10_as_mels_256": urllib.parse.urljoin(model_url, "mn10_as_mels_256_mAP_474.pt"),
}


class MobileNetV3(nn.Module):
    def __init__(
            self,
            inverted_residual_setting: List[InvertedResidualConfig],
            last_channel: int,
            num_classes: int = 1000,
            block: Optional[Callable[..., nn.Module]] = None,
            norm_layer: Optional[Callable[..., nn.Module]] = None,
            dropout: float = 0.2,
            in_conv_kernel: int = 3,
            in_conv_stride: int = 2,
            in_channels: int = 1,
            **kwargs: Any,
    ) -> None:
        """
        MobileNet V3 main class

        Args:
            inverted_residual_setting (List[InvertedResidualConfig]): Network structure
            last_channel (int): The number of channels on the penultimate layer
            num_classes (int): Number of classes
            block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for models
            norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
            dropout (float): The dropout probability
            in_conv_kernel (int): Size of kernel for first convolution
            in_conv_stride (int): Size of stride for first convolution
            in_channels (int): Number of input channels
        """
        super(MobileNetV3, self).__init__()

        if not inverted_residual_setting:
            raise ValueError("The inverted_residual_setting should not be empty")
        elif not (
            isinstance(inverted_residual_setting, Sequence)
            and all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])
        ):
            raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]")

        if block is None:
            block = InvertedResidual

        depthwise_norm_layer = norm_layer = \
            norm_layer if norm_layer is not None else partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)

        layers: List[nn.Module] = []

        kernel_sizes = [in_conv_kernel]
        strides = [in_conv_stride]

        # building first layer
        firstconv_output_channels = inverted_residual_setting[0].input_channels
        layers.append(
            ConvNormActivation(
                in_channels,
                firstconv_output_channels,
                kernel_size=in_conv_kernel,
                stride=in_conv_stride,
                norm_layer=norm_layer,
                activation_layer=nn.Hardswish,
            )
        )

        # get squeeze excitation config
        se_cnf = kwargs.get('se_conf', None)

        # building inverted residual blocks
        # - keep track of size of frequency and time dimensions for possible application of Squeeze-and-Excitation
        # on the frequency/time dimension
        # - applying Squeeze-and-Excitation on the time dimension is not recommended as this constrains the network to
        # a particular length of the audio clip, whereas Squeeze-and-Excitation on the frequency bands is fine,
        # as the number of frequency bands is usually not changing
        f_dim, t_dim = kwargs.get('input_dims', (128, 1000))
        # take into account first conv layer
        f_dim = cnn_out_size(f_dim, 1, 1, 3, 2)
        t_dim = cnn_out_size(t_dim, 1, 1, 3, 2)
        for cnf in inverted_residual_setting:
            f_dim = cnf.out_size(f_dim)
            t_dim = cnf.out_size(t_dim)
            cnf.f_dim, cnf.t_dim = f_dim, t_dim  # update dimensions in block config
            layers.append(block(cnf, se_cnf, norm_layer, depthwise_norm_layer))
            kernel_sizes.append(cnf.kernel)
            strides.append(cnf.stride)

        # building last several layers
        lastconv_input_channels = inverted_residual_setting[-1].out_channels
        lastconv_output_channels = 6 * lastconv_input_channels
        layers.append(
            ConvNormActivation(
                lastconv_input_channels,
                lastconv_output_channels,
                kernel_size=1,
                norm_layer=norm_layer,
                activation_layer=nn.Hardswish,
            )
        )

        self.features = nn.Sequential(*layers)
        self.head_type = kwargs.get("head_type", False)
        if self.head_type == "multihead_attention_pooling":
            self.classifier = MultiHeadAttentionPooling(lastconv_output_channels, num_classes,
                                                        num_heads=kwargs.get("multihead_attention_heads"))
        elif self.head_type == "fully_convolutional":
            self.classifier = nn.Sequential(
                nn.Conv2d(
                    lastconv_output_channels,
                    num_classes,
                    kernel_size=(1, 1),
                    stride=(1, 1),
                    padding=(0, 0),
                    bias=False),
                nn.BatchNorm2d(num_classes),
                nn.AdaptiveAvgPool2d((1, 1)),
            )
        elif self.head_type == "mlp":
            self.classifier = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(start_dim=1),
                nn.Linear(lastconv_output_channels, last_channel),
                nn.Hardswish(inplace=True),
                nn.Dropout(p=dropout, inplace=True),
                nn.Linear(last_channel, num_classes),
            )
        else:
            raise NotImplementedError(f"Head '{self.head_type}' unknown. Must be one of: 'mlp', "
                                      f"'fully_convolutional', 'multihead_attention_pooling'")

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> (Tensor, Tensor):
        x = self.features(x)
        features = F.adaptive_avg_pool2d(x, (1, 1)).squeeze()
        x = self.classifier(x).squeeze()
        if features.dim() == 1 and x.dim() == 1:
            # squeezed batch dimension
            features = features.unsqueeze(0)
            x = x.unsqueeze(0)
        return x, features

    def forward(self, x: Tensor) -> (Tensor, Tensor):
        return self._forward_impl(x)


def _mobilenet_v3_conf(
        width_mult: float = 1.0,
        reduced_tail: bool = False,
        dilated: bool = False,
        c4_stride: int = 2,
        **kwargs: Any
):
    reduce_divider = 2 if reduced_tail else 1
    dilation = 2 if dilated else 1

    bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
    adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)

    # InvertedResidualConfig:
    # input_channels, kernel, expanded_channels, out_channels, use_se, activation, stride, dilation, width_mult
    inverted_residual_setting = [
        bneck_conf(16, 3, 16, 16, False, "RE", 1, 1),
        bneck_conf(16, 3, 64, 24, False, "RE", 2, 1),  # C1
        bneck_conf(24, 3, 72, 24, False, "RE", 1, 1),
        bneck_conf(24, 5, 72, 40, True, "RE", 2, 1),  # C2
        bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
        bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
        bneck_conf(40, 3, 240, 80, False, "HS", 2, 1),  # C3
        bneck_conf(80, 3, 200, 80, False, "HS", 1, 1),
        bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
        bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
        bneck_conf(80, 3, 480, 112, True, "HS", 1, 1),
        bneck_conf(112, 3, 672, 112, True, "HS", 1, 1),
        bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", c4_stride, dilation),  # C4
        bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
        bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
    ]
    last_channel = adjust_channels(1280 // reduce_divider)

    return inverted_residual_setting, last_channel


def _mobilenet_v3(
        inverted_residual_setting: List[InvertedResidualConfig],
        last_channel: int,
        pretrained_name: str,
        **kwargs: Any,
):
    model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)

    if pretrained_name in pretrained_models:
        model_url = pretrained_models.get(pretrained_name)
        state_dict = load_state_dict_from_url(model_url, model_dir=model_dir, map_location="cpu")
        if kwargs['num_classes'] != state_dict['classifier.5.bias'].size(0):
            # if the number of logits is not matching the state dict,
            # drop the corresponding pre-trained part
            print(f"Number of classes defined: {kwargs['num_classes']}, "
                  f"but try to load pre-trained layer with logits: {state_dict['classifier.5.bias'].size(0)}\n"
                  "Dropping last layer.")
            del state_dict['classifier.5.weight']
            del state_dict['classifier.5.bias']
        try:
            model.load_state_dict(state_dict)
        except RuntimeError as e:
            print(str(e))
            print("Loading pre-trained weights in a non-strict manner.")
            model.load_state_dict(state_dict, strict=False)
    elif pretrained_name:
        raise NotImplementedError(f"Model name '{pretrained_name}' unknown.")
    return model


def mobilenet_v3(pretrained_name: str = None, **kwargs: Any) \
        -> MobileNetV3:
    """
    Constructs a MobileNetV3 architecture from
    "Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>.
    """
    inverted_residual_setting, last_channel = _mobilenet_v3_conf(**kwargs)
    return _mobilenet_v3(inverted_residual_setting, last_channel, pretrained_name, **kwargs)


def get_model(num_classes: int = 527, pretrained_name: str = None, width_mult: float = 1.0,
              reduced_tail: bool = False, dilated: bool = False, c4_stride: int = 2, head_type: str = "mlp",
              multihead_attention_heads: int = 4, input_dim_f: int = 128,
              input_dim_t: int = 1000, se_dims: str = 'c', se_agg: str = "max", se_r: int = 4):
    """
    Arguments to modify the instantiation of a MobileNetv3

    Args:
        num_classes (int): Specifies number of classes to predict
        pretrained_name (str): Specifies name of pre-trained model to load
        width_mult (float): Scales width of network
        reduced_tail (bool): Scales down network tail
        dilated (bool): Applies dilated convolution to network tail
        c4_stride (int): Set to '2' in original implementation;
            might be changed to modify the size of receptive field
        head_type (str): decides which classification head to use
        multihead_attention_heads (int): number of heads in case 'multihead_attention_pooling' is used
        input_dim_f (int): number of frequency bands
        input_dim_t (int): number of time frames
        se_dims (str): choose dimensions to apply squeeze-excitation on; contains letters corresponding to
            the dimensions 'c' - channel, 'f' - frequency, 't' - time; if multiple dimensions are chosen,
            squeeze-excitation is applied concurrently and the se layer outputs are fused by the se_agg operation
        se_agg (str): operation to fuse output of concurrent se layers
        se_r (int): squeeze excitation bottleneck size
    """
312 |
+
|
313 |
+
dim_map = {'c': 1, 'f': 2, 't': 3}
|
314 |
+
assert len(se_dims) <= 3 and all([s in dim_map.keys() for s in se_dims]) or se_dims == 'none'
|
315 |
+
input_dims = (input_dim_f, input_dim_t)
|
316 |
+
if se_dims == 'none':
|
317 |
+
se_dims = None
|
318 |
+
else:
|
319 |
+
se_dims = [dim_map[s] for s in se_dims]
|
320 |
+
se_conf = dict(se_dims=se_dims, se_agg=se_agg, se_r=se_r)
|
321 |
+
m = mobilenet_v3(pretrained_name=pretrained_name, num_classes=num_classes,
|
322 |
+
width_mult=width_mult, reduced_tail=reduced_tail, dilated=dilated, c4_stride=c4_stride,
|
323 |
+
head_type=head_type, multihead_attention_heads=multihead_attention_heads,
|
324 |
+
input_dims=input_dims, se_conf=se_conf
|
325 |
+
)
|
326 |
+
print(m)
|
327 |
+
return m
|
328 |
+
|
329 |
+
|
330 |
+
class EnsemblerModel(nn.Module):
|
331 |
+
def __init__(self, model_names):
|
332 |
+
super(EnsemblerModel, self).__init__()
|
333 |
+
self.models = nn.ModuleList([get_model(width_mult=NAME_TO_WIDTH(model_name), pretrained_name=model_name)
|
334 |
+
for model_name in model_names])
|
335 |
+
|
336 |
+
def forward(self, x):
|
337 |
+
all_out = None
|
338 |
+
for m in self.models:
|
339 |
+
out, _ = m(x)
|
340 |
+
if all_out is None:
|
341 |
+
all_out = out
|
342 |
+
else:
|
343 |
+
all_out = out + all_out
|
344 |
+
all_out = all_out / len(self.models)
|
345 |
+
return all_out, all_out
|
346 |
+
|
347 |
+
|
348 |
+
def get_ensemble_model(model_names):
|
349 |
+
return EnsemblerModel(model_names)
|
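`get_model` above is the entry point the rest of the Space builds on. Below is a minimal, hedged usage sketch: it instantiates the default MLP-head model *without* pre-trained weights and runs a dummy log-mel spectrogram through it. The input shape (batch, 1 channel, 128 mel bands, 1000 frames) just mirrors the `input_dim_f`/`input_dim_t` defaults; the checkpoints this Space actually loads may expect different sizes.

```python
# Sketch only: untrained default model on a dummy log-mel spectrogram.
import torch

from efficientat.models.MobileNetV3 import get_model

model = get_model(num_classes=527, pretrained_name=None)  # 527 AudioSet classes
model.eval()

dummy_mel = torch.zeros(1, 1, 128, 1000)  # (batch, channel, mel bands, time frames)
with torch.no_grad():
    logits, features = model(dummy_mel)

print(logits.shape)    # torch.Size([1, 527]) - one logit per class
print(features.shape)  # torch.Size([1, C]) - pooled embedding from the last conv block
```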
efficientat/models/attention_pooling.py
ADDED
@@ -0,0 +1,56 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+from efficientat.models.utils import collapse_dim
+
+
+class MultiHeadAttentionPooling(nn.Module):
+    """Multi-Head Attention as used in PSLA paper (https://arxiv.org/pdf/2102.01243.pdf)
+    """
+    def __init__(self, in_dim, out_dim, att_activation: str = 'sigmoid',
+                 clf_activation: str = 'ident', num_heads: int = 4, epsilon: float = 1e-7):
+        super(MultiHeadAttentionPooling, self).__init__()
+
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+        self.num_heads = num_heads
+        self.epsilon = epsilon
+
+        self.att_activation = att_activation
+        self.clf_activation = clf_activation
+
+        # out size: out dim x 2 (att and clf paths) x num_heads
+        self.subspace_proj = nn.Linear(self.in_dim, self.out_dim * 2 * self.num_heads)
+        self.head_weight = nn.Parameter(torch.tensor([1.0 / self.num_heads] * self.num_heads).view(1, -1, 1))
+
+    def activate(self, x, activation):
+        if activation == 'linear':
+            return x
+        elif activation == 'relu':
+            return F.relu(x)
+        elif activation == 'sigmoid':
+            return torch.sigmoid(x)
+        elif activation == 'softmax':
+            return F.softmax(x, dim=1)
+        elif activation == 'ident':
+            return x
+
+    def forward(self, x) -> Tensor:
+        """x: Tensor of size (batch_size, channels, frequency bands, sequence length)
+        """
+        x = collapse_dim(x, dim=2)  # results in tensor of size (batch_size, channels, sequence_length)
+        x = x.transpose(1, 2)  # results in tensor of size (batch_size, sequence_length, channels)
+        b, n, c = x.shape
+
+        x = self.subspace_proj(x).reshape(b, n, 2, self.num_heads, self.out_dim).permute(2, 0, 3, 1, 4)
+        att, val = x[0], x[1]
+        val = self.activate(val, self.clf_activation)
+        att = self.activate(att, self.att_activation)
+        att = torch.clamp(att, self.epsilon, 1. - self.epsilon)
+        att = att / torch.sum(att, dim=2, keepdim=True)
+
+        out = torch.sum(att * val, dim=2) * self.head_weight
+        out = torch.sum(out, dim=1)
+        return out
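For orientation, here is a small shape check of the pooling head in isolation. The sizes (960 input channels, 527 output classes, 4 heads) are illustrative assumptions matching typical MobileNetV3/AudioSet dimensions, not values fixed by this file.

```python
# Sketch only: shape check for MultiHeadAttentionPooling on a fake feature map.
import torch

from efficientat.models.attention_pooling import MultiHeadAttentionPooling

pool = MultiHeadAttentionPooling(in_dim=960, out_dim=527, num_heads=4)
feature_map = torch.randn(2, 960, 4, 31)  # (batch, channels, freq bands, time frames)
out = pool(feature_map)
print(out.shape)  # torch.Size([2, 527]): attention-weighted class scores per clip
```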
efficientat/models/block_types.py
ADDED
@@ -0,0 +1,182 @@
+from typing import Dict, Callable, List
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torchvision.ops.misc import ConvNormActivation
+
+from efficientat.models.utils import make_divisible, cnn_out_size
+
+
+class ConcurrentSEBlock(torch.nn.Module):
+    def __init__(
+            self,
+            c_dim: int,
+            f_dim: int,
+            t_dim: int,
+            se_cnf: Dict
+    ) -> None:
+        super().__init__()
+        dims = [c_dim, f_dim, t_dim]
+        self.conc_se_layers = nn.ModuleList()
+        for d in se_cnf['se_dims']:
+            input_dim = dims[d - 1]
+            squeeze_dim = make_divisible(input_dim // se_cnf['se_r'], 8)
+            self.conc_se_layers.append(SqueezeExcitation(input_dim, squeeze_dim, d))
+        if se_cnf['se_agg'] == "max":
+            self.agg_op = lambda x: torch.max(x, dim=0)[0]
+        elif se_cnf['se_agg'] == "avg":
+            self.agg_op = lambda x: torch.mean(x, dim=0)
+        elif se_cnf['se_agg'] == "add":
+            self.agg_op = lambda x: torch.sum(x, dim=0)
+        elif se_cnf['se_agg'] == "min":
+            self.agg_op = lambda x: torch.min(x, dim=0)[0]
+        else:
+            raise NotImplementedError(f"SE aggregation operation '{se_cnf['se_agg']}' not implemented")
+
+    def forward(self, input: Tensor) -> Tensor:
+        # apply all concurrent se layers and fuse their outputs
+        se_outs = []
+        for se_layer in self.conc_se_layers:
+            se_outs.append(se_layer(input))
+        out = self.agg_op(torch.stack(se_outs, dim=0))
+        return out
+
+
+class SqueezeExcitation(torch.nn.Module):
+    """
+    This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507.
+    Args:
+        input_dim (int): Input dimension
+        squeeze_dim (int): Size of Bottleneck
+        activation (Callable): activation applied to bottleneck
+        scale_activation (Callable): activation applied to the output
+    """
+
+    def __init__(
+            self,
+            input_dim: int,
+            squeeze_dim: int,
+            se_dim: int,
+            activation: Callable[..., torch.nn.Module] = torch.nn.ReLU,
+            scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,
+    ) -> None:
+        super().__init__()
+        self.fc1 = torch.nn.Linear(input_dim, squeeze_dim)
+        self.fc2 = torch.nn.Linear(squeeze_dim, input_dim)
+        assert se_dim in [1, 2, 3]
+        self.se_dim = [1, 2, 3]
+        self.se_dim.remove(se_dim)
+        self.activation = activation()
+        self.scale_activation = scale_activation()
+
+    def _scale(self, input: Tensor) -> Tensor:
+        scale = torch.mean(input, self.se_dim, keepdim=True)
+        shape = scale.size()
+        scale = self.fc1(scale.squeeze(2).squeeze(2))
+        scale = self.activation(scale)
+        scale = self.fc2(scale)
+        return self.scale_activation(scale).view(shape)
+
+    def forward(self, input: Tensor) -> Tensor:
+        scale = self._scale(input)
+        return scale * input
+
+
+class InvertedResidualConfig:
+    # Stores information listed at Tables 1 and 2 of the MobileNetV3 paper
+    def __init__(
+            self,
+            input_channels: int,
+            kernel: int,
+            expanded_channels: int,
+            out_channels: int,
+            use_se: bool,
+            activation: str,
+            stride: int,
+            dilation: int,
+            width_mult: float,
+    ):
+        self.input_channels = self.adjust_channels(input_channels, width_mult)
+        self.kernel = kernel
+        self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
+        self.out_channels = self.adjust_channels(out_channels, width_mult)
+        self.use_se = use_se
+        self.use_hs = activation == "HS"
+        self.stride = stride
+        self.dilation = dilation
+        self.f_dim = None
+        self.t_dim = None
+
+    @staticmethod
+    def adjust_channels(channels: int, width_mult: float):
+        return make_divisible(channels * width_mult, 8)
+
+    def out_size(self, in_size):
+        padding = (self.kernel - 1) // 2 * self.dilation
+        return cnn_out_size(in_size, padding, self.dilation, self.kernel, self.stride)
+
+
+class InvertedResidual(nn.Module):
+    def __init__(
+            self,
+            cnf: InvertedResidualConfig,
+            se_cnf: Dict,
+            norm_layer: Callable[..., nn.Module],
+            depthwise_norm_layer: Callable[..., nn.Module]
+    ):
+        super().__init__()
+        if not (1 <= cnf.stride <= 2):
+            raise ValueError("illegal stride value")
+
+        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels
+
+        layers: List[nn.Module] = []
+        activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU
+
+        # expand
+        if cnf.expanded_channels != cnf.input_channels:
+            layers.append(
+                ConvNormActivation(
+                    cnf.input_channels,
+                    cnf.expanded_channels,
+                    kernel_size=1,
+                    norm_layer=norm_layer,
+                    activation_layer=activation_layer,
+                )
+            )
+
+        # depthwise
+        stride = 1 if cnf.dilation > 1 else cnf.stride
+        layers.append(
+            ConvNormActivation(
+                cnf.expanded_channels,
+                cnf.expanded_channels,
+                kernel_size=cnf.kernel,
+                stride=stride,
+                dilation=cnf.dilation,
+                groups=cnf.expanded_channels,
+                norm_layer=depthwise_norm_layer,
+                activation_layer=activation_layer,
+            )
+        )
+        if cnf.use_se and se_cnf['se_dims'] is not None:
+            layers.append(ConcurrentSEBlock(cnf.expanded_channels, cnf.f_dim, cnf.t_dim, se_cnf))
+
+        # project
+        layers.append(
+            ConvNormActivation(
+                cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
+            )
+        )
+
+        self.block = nn.Sequential(*layers)
+        self.out_channels = cnf.out_channels
+        self._is_cn = cnf.stride > 1
+
+    def forward(self, inp: Tensor) -> Tensor:
+        result = self.block(inp)
+        if self.use_res_connect:
+            result += inp
+        return result
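As a quick sanity check of the squeeze-excitation path, the sketch below exercises the classic channel configuration (`se_dim=1`: squeeze over frequency and time, re-weight channels). All tensor sizes are illustrative assumptions.

```python
# Sketch only: channel-wise squeeze-excitation on a fake feature map.
import torch

from efficientat.models.block_types import SqueezeExcitation

se = SqueezeExcitation(input_dim=64, squeeze_dim=16, se_dim=1)
x = torch.randn(8, 64, 32, 100)  # (batch, channels, freq bands, time frames)
y = se(x)
print(y.shape)  # torch.Size([8, 64, 32, 100]): same shape, channels rescaled by learned gates
```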
efficientat/models/preprocess.py
ADDED
@@ -0,0 +1,67 @@
+import torch.nn as nn
+import torchaudio
+import torch
+
+
+class AugmentMelSTFT(nn.Module):
+    def __init__(self, n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=48, timem=192,
+                 fmin=0.0, fmax=None, fmin_aug_range=10, fmax_aug_range=2000):
+        torch.nn.Module.__init__(self)
+        # adapted from: https://github.com/CPJKU/kagglebirds2020/commit/70f8308b39011b09d41eb0f4ace5aa7d2b0e806e
+
+        self.win_length = win_length
+        self.n_mels = n_mels
+        self.n_fft = n_fft
+        self.sr = sr
+        self.fmin = fmin
+        if fmax is None:
+            fmax = sr // 2 - fmax_aug_range // 2
+            print(f"Warning: fmax is None, setting it to {fmax}")
+        self.fmax = fmax
+        self.hopsize = hopsize
+        self.register_buffer('window',
+                             torch.hann_window(win_length, periodic=False),
+                             persistent=False)
+        assert fmin_aug_range >= 1, f"fmin_aug_range={fmin_aug_range} should be >=1; 1 means no augmentation"
+        assert fmax_aug_range >= 1, f"fmax_aug_range={fmax_aug_range} should be >=1; 1 means no augmentation"
+        self.fmin_aug_range = fmin_aug_range
+        self.fmax_aug_range = fmax_aug_range
+
+        self.register_buffer("preemphasis_coefficient", torch.as_tensor([[[-.97, 1]]]), persistent=False)
+        if freqm == 0:
+            self.freqm = torch.nn.Identity()
+        else:
+            self.freqm = torchaudio.transforms.FrequencyMasking(freqm, iid_masks=True)
+        if timem == 0:
+            self.timem = torch.nn.Identity()
+        else:
+            self.timem = torchaudio.transforms.TimeMasking(timem, iid_masks=True)
+
+    def forward(self, x):
+        x = nn.functional.conv1d(x.unsqueeze(1), self.preemphasis_coefficient).squeeze(1)
+        x = torch.stft(x, self.n_fft, hop_length=self.hopsize, win_length=self.win_length,
+                       center=True, normalized=False, window=self.window, return_complex=False)
+        x = (x ** 2).sum(dim=-1)  # power magnitude
+        fmin = self.fmin + torch.randint(self.fmin_aug_range, (1,)).item()
+        fmax = self.fmax + self.fmax_aug_range // 2 - torch.randint(self.fmax_aug_range, (1,)).item()
+        # don't augment eval data
+        if not self.training:
+            fmin = self.fmin
+            fmax = self.fmax
+
+        mel_basis, _ = torchaudio.compliance.kaldi.get_mel_banks(self.n_mels, self.n_fft, self.sr,
+                                                                 fmin, fmax, vtln_low=100.0, vtln_high=-500.,
+                                                                 vtln_warp_factor=1.0)
+        mel_basis = torch.as_tensor(torch.nn.functional.pad(mel_basis, (0, 1), mode='constant', value=0),
+                                    device=x.device)
+        with torch.cuda.amp.autocast(enabled=False):
+            melspec = torch.matmul(mel_basis, x)
+
+        melspec = (melspec + 0.00001).log()
+
+        if self.training:
+            melspec = self.freqm(melspec)
+            melspec = self.timem(melspec)
+
+        melspec = (melspec + 4.5) / 5.  # fast normalization
+
+        return melspec
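A short sketch of this front end on a one-second dummy waveform: in eval mode the SpecAugment masks and the fmin/fmax jitter are disabled, so the output is deterministic. With `sr=32000` and `hopsize=320`, one second of audio yields roughly 100 frames (the pre-emphasis filter drops one sample).

```python
# Sketch only: compute a log-mel spectrogram for one second of (silent) audio.
import torch

from efficientat.models.preprocess import AugmentMelSTFT

mel = AugmentMelSTFT(n_mels=128, sr=32000, win_length=800, hopsize=320)
mel.eval()  # no masking, no fmin/fmax augmentation outside training

waveform = torch.zeros(1, 32000)  # (batch, samples) at 32 kHz
with torch.no_grad():
    spec = mel(waveform)

print(spec.shape)  # torch.Size([1, 128, 100]): (batch, mel bands, time frames)
```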
efficientat/models/utils.py
ADDED
@@ -0,0 +1,59 @@
+import math
+from typing import Optional, Callable
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+
+def make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
+    """
+    This function is taken from the original tf repo.
+    It ensures that all layers have a channel number that is divisible by 8.
+    It can be seen here:
+    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+    """
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
+
+
+def cnn_out_size(in_size, padding, dilation, kernel, stride):
+    s = in_size + 2 * padding - dilation * (kernel - 1) - 1
+    return math.floor(s / stride + 1)
+
+
+def collapse_dim(x: Tensor, dim: int, mode: str = "pool", pool_fn: Callable[[Tensor, int], Tensor] = torch.mean,
+                 combine_dim: int = None):
+    """
+    Collapses dimension of multi-dimensional tensor by pooling or combining dimensions
+    :param x: input Tensor
+    :param dim: dimension to collapse
+    :param mode: 'pool' or 'combine'
+    :param pool_fn: function to be applied in case of pooling
+    :param combine_dim: dimension to join 'dim' to
+    :return: collapsed tensor
+    """
+    if mode == "pool":
+        return pool_fn(x, dim)
+    elif mode == "combine":
+        s = list(x.size())
+        s[combine_dim] *= dim
+        s[dim] //= dim
+        return x.view(s)
+
+
+class CollapseDim(nn.Module):
+    def __init__(self, dim: int, mode: str = "pool", pool_fn: Callable[[Tensor, int], Tensor] = torch.mean,
+                 combine_dim: int = None):
+        super(CollapseDim, self).__init__()
+        self.dim = dim
+        self.mode = mode
+        self.pool_fn = pool_fn
+        self.combine_dim = combine_dim
+
+    def forward(self, x):
+        return collapse_dim(x, dim=self.dim, mode=self.mode, pool_fn=self.pool_fn, combine_dim=self.combine_dim)
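A few spot checks of the helpers above; the values are worked examples, not tests shipped with the repo.

```python
# Sketch only: spot checks for the helper functions.
import torch

from efficientat.models.utils import make_divisible, cnn_out_size, collapse_dim

print(make_divisible(16 * 0.75, 8))   # 12 -> 16: nearest multiple of 8, ties round up
print(make_divisible(100, 8))         # 100 -> 104
print(cnn_out_size(128, 1, 1, 3, 2))  # 64: standard conv output-size formula

x = torch.randn(2, 4, 8, 16)
print(collapse_dim(x, dim=2).shape)   # torch.Size([2, 4, 16]): dim 2 mean-pooled away
```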
efficientat/resources/README.md
ADDED
@@ -0,0 +1 @@
+Download the latest version from this repo's GitHub Releases and place it inside this folder.
efficientat/resources/metro_station-paris.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:75d28a33f45fd6eebd862bb25a3738dd83b7aa92ad64c26d5b1879ff2a715b3f
+size 1323044
logo.png
ADDED
(binary image stored with Git LFS; preview omitted)
metro_station-paris.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:75d28a33f45fd6eebd862bb25a3738dd83b7aa92ad64c26d5b1879ff2a715b3f
+size 1323044
packages.txt
ADDED
@@ -0,0 +1 @@
+libsndfile1
requirements.txt
ADDED
@@ -0,0 +1,11 @@
+av==10.0.0
+h5py==3.7.0
+langchain
+librosa==0.9.2
+torch
+torchvision
+pandas
+numpy
+scikit_learn
+torchaudio
+