Spaces: Running on Zero

slslslrhfem committed
Commit · c410f34
1 Parent(s): a2657b4

change probability func
Browse files
- inference.py +3 -3
inference.py CHANGED
@@ -165,7 +165,7 @@ def scaled_sigmoid(x, scale_factor=0.2, linear_property=0.3):
     # Combine sigmoid with linear component
     raw_prob = torch.sigmoid(scaled_x) * (1-linear_property) + linear_property * ((x + 25) / 50)
     # Clip to ensure bounds
-    return torch.clamp(raw_prob, min=0.
+    return torch.clamp(raw_prob, min=0.01, max=0.99)
 
 # Apply the scaled sigmoid
 
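For context, a minimal sketch of what the full scaled_sigmoid plausibly looks like after this commit. Only the combination, clipping, and return lines are visible in the hunk, so the scaled_x = x * scale_factor step is an assumption:

import torch

def scaled_sigmoid(x, scale_factor=0.2, linear_property=0.3):
    # Assumed pre-scaling step; not visible in the hunk.
    scaled_x = x * scale_factor
    # Combine sigmoid with a linear component (context line from the diff);
    # the linear term maps x in [-25, 25] onto [0, 1].
    raw_prob = torch.sigmoid(scaled_x) * (1 - linear_property) + linear_property * ((x + 25) / 50)
    # After this commit the output is clipped to [0.01, 0.99], so
    # downstream code never sees a hard 0 or 1.
    return torch.clamp(raw_prob, min=0.01, max=0.99)

Clamping away exact 0/1 probabilities is a common guard against log(0) in downstream losses and against over-confident displayed scores.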
@@ -175,7 +175,7 @@ def get_model(model_type, device):
     if model_type == "MERT":
         #from model import MusicAudioClassifier
         #model = MusicAudioClassifier(input_dim=768, is_emb=True, mode = 'both', share_parameter = False).to(device)
-        ckpt_file = 'checkpoints/step=
+        ckpt_file = 'checkpoints/step=003432-val_loss=0.0216-val_acc=0.9963.ckpt'#'mert_finetune_10.pth'
         model = MERT_AudioCNN.load_from_checkpoint(ckpt_file).to(device)
         model.eval()
         # model.load_state_dict(torch.load(ckpt_file, map_location=device))
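A hedged sketch of the loading pattern this hunk relies on, assuming MERT_AudioCNN is a pytorch_lightning.LightningModule defined in model.py (the import itself is not shown in the diff):

import torch
from model import MERT_AudioCNN  # assumed location of the LightningModule

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# load_from_checkpoint rebuilds the module from the hyperparameters stored
# in the .ckpt file and restores its weights in one call, which is why the
# manual load_state_dict line stays commented out after this commit.
ckpt_file = 'checkpoints/step=003432-val_loss=0.0216-val_acc=0.9963.ckpt'
model = MERT_AudioCNN.load_from_checkpoint(ckpt_file).to(device)
model.eval()  # inference mode: freezes dropout and batch-norm statistics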
@@ -210,7 +210,7 @@ def inference(audio_path):
 
     # 모델 로드 부분 추가 (add the model-loading section)
     model = MusicAudioClassifier.load_from_checkpoint(
-        checkpoint_path = 'checkpoints/
+        checkpoint_path = 'checkpoints/EmbeddingModel_MERT_768_2class_weighted-epoch=0014-val_loss=0.0099-val_acc=0.9993-val_f1=0.9978-val_precision=0.9967-val_recall=0.9989.ckpt',
         input_dim=input_dim,
         #emb_model=backbone_model
     )
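One detail worth noting about this last hunk: Lightning's load_from_checkpoint forwards extra keyword arguments to the module's constructor, overriding hyperparameters saved in the checkpoint. A sketch of the completed call under that assumption, with input_dim=768 as an assumed value (the source passes the variable input_dim):

from model import MusicAudioClassifier  # assumed import, mirroring get_model

ckpt = ('checkpoints/EmbeddingModel_MERT_768_2class_weighted-epoch=0014-'
        'val_loss=0.0099-val_acc=0.9993-val_f1=0.9978-val_precision=0.9967-'
        'val_recall=0.9989.ckpt')

# checkpoint_path names the file to restore; input_dim is forwarded to
# MusicAudioClassifier.__init__, overriding the saved hyperparameter.
model = MusicAudioClassifier.load_from_checkpoint(
    checkpoint_path=ckpt,
    input_dim=768,  # assumed value; the source passes the variable input_dim
    # emb_model=backbone_model  # still disabled in this commit
)
model.eval()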