asigalov61 committed · verified
Commit 9172209 · Parent: 7f7c6fe

Update app.py

Files changed (1):
  1. app.py +11 -6
app.py CHANGED
@@ -334,21 +334,26 @@ def CompareMIDIs(input_src_midi, input_trg_midi, input_sampling_resolution, inpu
     model = TransformerWrapper(
         num_tokens = PAD_IDX+1,
         max_seq_len = SEQ_LEN,
-        attn_layers = Decoder(dim = 1024, depth = 32, heads = 32, attn_flash = True)
+        attn_layers = Decoder(dim = 2048, depth = 8, heads = 32, attn_flash = True)
     )
 
     model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
 
-    model.to(DEVICE)
     print('=' * 70)
 
     print('Loading model checkpoint...')
 
-    model.load_state_dict(
-        torch.load('Giant_Music_Transformer_Large_Trained_Model_36074_steps_0.3067_loss_0.927_acc.pth',
-                   map_location=DEVICE))
+    model_checkpoint = hf_hub_download(repo_id='asigalov61/Giant-Music-Transformer',
+                                       filename='Giant_Music_Transformer_Medium_Trained_Model_42174_steps_0.5211_loss_0.8542_acc.pth'
+                                       )
+
+    model.load_state_dict(torch.load(model_checkpoint, map_location='cpu', weights_only=True))
+
+    model = torch.compile(model, mode='max-autotune')
+
     print('=' * 70)
-
+
+    model.to(DEVICE)
     model.eval()
 
     if DEVICE == 'cpu':
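
For reference, below is a minimal, self-contained sketch of the loading flow the new code adopts: rebuild the x-transformers model with the updated Decoder settings, fetch the Medium checkpoint from the Hugging Face Hub with hf_hub_download, load the weights on CPU, compile the model, and only then move it to the target device and switch it to eval mode. The PAD_IDX, SEQ_LEN and DEVICE values here are illustrative placeholders, not the ones defined elsewhere in app.py, and they must match the checkpoint's training configuration for load_state_dict to succeed.

import torch
from huggingface_hub import hf_hub_download
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

# Placeholders for illustration only; the real values live elsewhere in app.py
# and must match the configuration the checkpoint was trained with.
PAD_IDX = 19077
SEQ_LEN = 8192
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Rebuild the architecture with the updated settings (dim=2048, depth=8).
model = TransformerWrapper(
    num_tokens = PAD_IDX + 1,
    max_seq_len = SEQ_LEN,
    attn_layers = Decoder(dim = 2048, depth = 8, heads = 32, attn_flash = True)
)
model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)

# Fetch the Medium checkpoint from the Hugging Face Hub instead of a bundled file.
model_checkpoint = hf_hub_download(
    repo_id = 'asigalov61/Giant-Music-Transformer',
    filename = 'Giant_Music_Transformer_Medium_Trained_Model_42174_steps_0.5211_loss_0.8542_acc.pth'
)

# Deserialize on CPU; weights_only=True restricts loading to plain tensors.
model.load_state_dict(torch.load(model_checkpoint, map_location='cpu', weights_only=True))

# Compile, then move to the target device and switch to inference mode.
model = torch.compile(model, mode='max-autotune')
model.to(DEVICE)
model.eval()

Loading on CPU first and moving to DEVICE afterwards keeps deserialization off the GPU; weights_only=True and torch.compile(mode='max-autotune') both assume a reasonably recent PyTorch 2.x install.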