slslslrhfem committed on
Commit
693d2c7
·
1 Parent(s): c410f34

change probability func

Browse files
Files changed (1) hide show
  1. inference.py +18 -27
inference.py CHANGED
@@ -169,16 +169,16 @@ def scaled_sigmoid(x, scale_factor=0.2, linear_property=0.3):
169
 
170
  # Apply the scaled sigmoid
171
 
172
-
173
  def get_model(model_type, device):
174
  """Load the specified model."""
175
  if model_type == "MERT":
176
- #from model import MusicAudioClassifier
177
- #model = MusicAudioClassifier(input_dim=768, is_emb=True, mode = 'both', share_parameter = False).to(device)
178
- ckpt_file = 'checkpoints/step=003432-val_loss=0.0216-val_acc=0.9963.ckpt'#'mert_finetune_10.pth'
179
- model = MERT_AudioCNN.load_from_checkpoint(ckpt_file).to(device)
 
 
180
  model.eval()
181
- # model.load_state_dict(torch.load(ckpt_file, map_location=device))
182
  embed_dim = 768
183
 
184
  elif model_type == "pure_MERT":
@@ -189,42 +189,33 @@ def get_model(model_type, device):
189
  else:
190
  raise ValueError(f"Unknown model type: {model_type}")
191
 
192
-
193
  model.eval()
194
  return model, embed_dim
195
-
196
 
197
  def inference(audio_path):
198
- backbone_model, input_dim = get_model('MERT', 'cuda')
199
- segments, padding_mask = load_audio(audio_path, sr=24000)
200
- segments = segments.to('cuda').to(torch.float32)
201
- padding_mask = padding_mask.to('cuda').unsqueeze(0)
202
- logits,embedding = backbone_model(segments.squeeze(1))
203
- # test_dataset = FakeMusicCapsDataset([audio_path], [0], target_duration=10.0)
204
- # test_data, test_target = test_dataset[0]
205
- # test_data = test_data.to('cuda').to(torch.float32)
206
- # test_target = test_target.to('cuda')
207
- # output, _ = backbone_model(test_data.unsqueeze(0))
208
 
 
 
 
 
 
209
 
210
-
211
- # ๋ชจ๋ธ ๋กœ๋“œ ๋ถ€๋ถ„ ์ถ”๊ฐ€
212
  model = MusicAudioClassifier.load_from_checkpoint(
213
- checkpoint_path = 'checkpoints/EmbeddingModel_MERT_768_2class_weighted-epoch=0014-val_loss=0.0099-val_acc=0.9993-val_f1=0.9978-val_precision=0.9967-val_recall=0.9989.ckpt',
214
- input_dim=input_dim,
215
- #emb_model=backbone_model
216
  )
217
 
218
-
219
  # Run inference
220
  print(f"Segments shape: {segments.shape}")
221
  print("Running inference...")
222
- results = run_inference(model, embedding, padding_mask, 'cuda')
223
 
224
  # 결과 출력
225
  print(f"Results: {results}")
226
-
227
-
228
 
229
  return results
230
 
 
169
 
170
  # Apply the scaled sigmoid
171
 
 
172
  def get_model(model_type, device):
173
  """Load the specified model."""
174
  if model_type == "MERT":
175
+ ckpt_file = 'checkpoints/step=003432-val_loss=0.0216-val_acc=0.9963.ckpt'
176
+ # map_location 추가
177
+ model = MERT_AudioCNN.load_from_checkpoint(
178
+ ckpt_file,
179
+ map_location=device # ๋˜๋Š” 'cuda:0' ๋˜๋Š” 'cpu'
180
+ ).to(device)
181
  model.eval()
 
182
  embed_dim = 768
183
 
184
  elif model_type == "pure_MERT":
 
189
  else:
190
  raise ValueError(f"Unknown model type: {model_type}")
191
 
 
192
  model.eval()
193
  return model, embed_dim
 
194
 
195
  def inference(audio_path):
196
+ # device ์„ค์ •์„ ๋ช…ํ™•ํžˆ ํ•˜๊ธฐ
197
+ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
 
 
 
 
 
 
 
 
198
 
199
+ backbone_model, input_dim = get_model('MERT', device)
200
+ segments, padding_mask = load_audio(audio_path, sr=24000)
201
+ segments = segments.to(device).to(torch.float32)
202
+ padding_mask = padding_mask.to(device).unsqueeze(0)
203
+ logits, embedding = backbone_model(segments.squeeze(1))
204
 
205
+ # ๋ชจ๋ธ ๋กœ๋“œํ•  ๋•Œ๋„ map_location ์ถ”๊ฐ€
 
206
  model = MusicAudioClassifier.load_from_checkpoint(
207
+ checkpoint_path='checkpoints/EmbeddingModel_MERT_768_2class_weighted-epoch=0014-val_loss=0.0099-val_acc=0.9993-val_f1=0.9978-val_precision=0.9967-val_recall=0.9989.ckpt',
208
+ input_dim=input_dim,
209
+ map_location=device # ์ด ๋ถ€๋ถ„ ์ถ”๊ฐ€
210
  )
211
 
 
212
  # Run inference
213
  print(f"Segments shape: {segments.shape}")
214
  print("Running inference...")
215
+ results = run_inference(model, embedding, padding_mask, device)
216
 
217
  # 결과 출력
218
  print(f"Results: {results}")
 
 
219
 
220
  return results
221