slslslrhfem commited on
Commit
c410f34
·
1 Parent(s): a2657b4

change probability func

Browse files
Files changed (1) hide show
  1. inference.py +3 -3
inference.py CHANGED
@@ -165,7 +165,7 @@ def scaled_sigmoid(x, scale_factor=0.2, linear_property=0.3):
165
  # Combine sigmoid with linear component
166
  raw_prob = torch.sigmoid(scaled_x) * (1-linear_property) + linear_property * ((x + 25) / 50)
167
  # Clip to ensure bounds
168
- return torch.clamp(raw_prob, min=0.001, max=0.999)
169
 
170
  # Apply the scaled sigmoid
171
 
@@ -175,7 +175,7 @@ def get_model(model_type, device):
175
  if model_type == "MERT":
176
  #from model import MusicAudioClassifier
177
  #model = MusicAudioClassifier(input_dim=768, is_emb=True, mode = 'both', share_parameter = False).to(device)
178
- ckpt_file = 'checkpoints/step=007000-val_loss=0.1831-val_acc=0.9278.ckpt'#'mert_finetune_10.pth'
179
  model = MERT_AudioCNN.load_from_checkpoint(ckpt_file).to(device)
180
  model.eval()
181
  # model.load_state_dict(torch.load(ckpt_file, map_location=device))
@@ -210,7 +210,7 @@ def inference(audio_path):
210
 
211
  # 모델 로드 부분 추가
212
  model = MusicAudioClassifier.load_from_checkpoint(
213
- checkpoint_path = 'checkpoints/EmbeddingModel_MERT_768-epoch=0073-val_loss=0.1058-val_acc=0.9585-val_f1=0.9366-val_precision=0.9936-val_recall=0.8857.ckpt',
214
  input_dim=input_dim,
215
  #emb_model=backbone_model
216
  )
 
165
  # Combine sigmoid with linear component
166
  raw_prob = torch.sigmoid(scaled_x) * (1-linear_property) + linear_property * ((x + 25) / 50)
167
  # Clip to ensure bounds
168
+ return torch.clamp(raw_prob, min=0.01, max=0.99)
169
 
170
  # Apply the scaled sigmoid
171
 
 
175
  if model_type == "MERT":
176
  #from model import MusicAudioClassifier
177
  #model = MusicAudioClassifier(input_dim=768, is_emb=True, mode = 'both', share_parameter = False).to(device)
178
+ ckpt_file = 'checkpoints/step=003432-val_loss=0.0216-val_acc=0.9963.ckpt'#'mert_finetune_10.pth'
179
  model = MERT_AudioCNN.load_from_checkpoint(ckpt_file).to(device)
180
  model.eval()
181
  # model.load_state_dict(torch.load(ckpt_file, map_location=device))
 
210
 
211
  # 모델 로드 부분 추가
212
  model = MusicAudioClassifier.load_from_checkpoint(
213
+ checkpoint_path = 'checkpoints/EmbeddingModel_MERT_768_2class_weighted-epoch=0014-val_loss=0.0099-val_acc=0.9993-val_f1=0.9978-val_precision=0.9967-val_recall=0.9989.ckpt',
214
  input_dim=input_dim,
215
  #emb_model=backbone_model
216
  )