Spaces: Running on Zero

slslslrhfem committed · 693d2c7
Parent(s): c410f34
change probability func

inference.py  +18 -27  CHANGED
@@ -169,16 +169,16 @@ def scaled_sigmoid(x, scale_factor=0.2, linear_property=0.3):
 
     # Apply the scaled sigmoid
 
-
 def get_model(model_type, device):
     """Load the specified model."""
     if model_type == "MERT":
-
-        #
-
-
+        ckpt_file = 'checkpoints/step=003432-val_loss=0.0216-val_acc=0.9963.ckpt'
+        # Add map_location
+        model = MERT_AudioCNN.load_from_checkpoint(
+            ckpt_file,
+            map_location=device  # or 'cuda:0' or 'cpu'
+        ).to(device)
         model.eval()
-        # model.load_state_dict(torch.load(ckpt_file, map_location=device))
         embed_dim = 768
 
     elif model_type == "pure_MERT":
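The hunk above swaps a commented-out state_dict load for Lightning's load_from_checkpoint with an explicit map_location, so a checkpoint saved on GPU can also be deserialized on a CPU-only host (for example a Space without an attached GPU). A minimal sketch of the same pattern, with TinyClassifier as a hypothetical stand-in for MERT_AudioCNN:

import torch
import pytorch_lightning as pl

class TinyClassifier(pl.LightningModule):
    # Hypothetical stand-in for MERT_AudioCNN, which lives elsewhere in the repo.
    def __init__(self, embed_dim: int = 768, n_classes: int = 2):
        super().__init__()
        self.save_hyperparameters()  # stored in the ckpt, restored on load
        self.head = torch.nn.Linear(embed_dim, n_classes)

    def forward(self, x):
        return self.head(x)

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# map_location remaps tensors that were saved on 'cuda:0' onto the current
# device, so the same checkpoint loads on a GPU box and on a CPU-only machine.
model = TinyClassifier.load_from_checkpoint(
    'checkpoints/step=003432-val_loss=0.0216-val_acc=0.9963.ckpt',  # path from the diff
    map_location=device,
).to(device)
model.eval()  # inference mode: disables dropout, freezes batch-norm stats

Without map_location, the torch.load call inside load_from_checkpoint tries to restore tensors to the device recorded in the file and fails on hosts where CUDA is unavailable.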
@@ -189,42 +189,33 @@ def get_model(model_type, device):
     else:
         raise ValueError(f"Unknown model type: {model_type}")
 
-
     model.eval()
     return model, embed_dim
-
 
 def inference(audio_path):
-
-
-    segments = segments.to('cuda').to(torch.float32)
-    padding_mask = padding_mask.to('cuda').unsqueeze(0)
-    logits,embedding = backbone_model(segments.squeeze(1))
-    # test_dataset = FakeMusicCapsDataset([audio_path], [0], target_duration=10.0)
-    # test_data, test_target = test_dataset[0]
-    # test_data = test_data.to('cuda').to(torch.float32)
-    # test_target = test_target.to('cuda')
-    # output, _ = backbone_model(test_data.unsqueeze(0))
+    # Make the device setting explicit
+    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
 
+    backbone_model, input_dim = get_model('MERT', device)
+    segments, padding_mask = load_audio(audio_path, sr=24000)
+    segments = segments.to(device).to(torch.float32)
+    padding_mask = padding_mask.to(device).unsqueeze(0)
+    logits, embedding = backbone_model(segments.squeeze(1))
 
-
-    # Added model loading section
+    # Add map_location when loading the model, too
     model = MusicAudioClassifier.load_from_checkpoint(
-        checkpoint_path
-        input_dim=input_dim,
-        #
+        checkpoint_path='checkpoints/EmbeddingModel_MERT_768_2class_weighted-epoch=0014-val_loss=0.0099-val_acc=0.9993-val_f1=0.9978-val_precision=0.9967-val_recall=0.9989.ckpt',
+        input_dim=input_dim,
+        map_location=device  # added this part
     )
 
-
     # Run inference
     print(f"Segments shape: {segments.shape}")
    print("Running inference...")
-    results = run_inference(model, embedding, padding_mask,
+    results = run_inference(model, embedding, padding_mask, device)
 
     # Print results
     print(f"Results: {results}")
-
-
 
     return results
 
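The second hunk applies the same idea inside inference(): one device string is derived once and every model and tensor is moved through it, rather than hard-coding 'cuda'. A small self-contained sketch of that movement pattern; the shapes are invented stand-ins for what the repo's load_audio() returns:

import torch

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Invented shapes standing in for load_audio()'s output: 4 segments of
# 10 s mono audio at 24 kHz, plus a per-segment padding mask.
segments = torch.randn(4, 1, 24000 * 10)
padding_mask = torch.zeros(4, dtype=torch.bool)

# Same moves as the diff: route everything through the one device variable
# so the code runs unchanged on GPU and CPU.
segments = segments.to(device).to(torch.float32)
padding_mask = padding_mask.to(device).unsqueeze(0)  # add batch dim -> (1, 4)

print(segments.shape, segments.device, padding_mask.shape)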
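The commit title is "change probability func", and the first hunk's context shows only the signature scaled_sigmoid(x, scale_factor=0.2, linear_property=0.3); the body lies outside the diff. Purely as a hypothetical illustration of what a sigmoid blended with a linear term can look like, and not the repo's actual implementation:

import torch

def scaled_sigmoid(x, scale_factor=0.2, linear_property=0.3):
    # Hypothetical reconstruction: a temperature-scaled sigmoid mixed with a
    # clamped linear ramp, so scores near zero map to less saturated
    # probabilities than a plain sigmoid would give.
    sig = torch.sigmoid(x * scale_factor)        # small scale_factor flattens the curve
    lin = 0.5 + 0.5 * torch.clamp(x, -1.0, 1.0)  # linear ramp squeezed into [0, 1]
    return (1.0 - linear_property) * sig + linear_property * lin

print(scaled_sigmoid(torch.tensor([-2.0, 0.0, 2.0])))  # monotone, stays in [0, 1]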