Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -334,21 +334,26 @@ def CompareMIDIs(input_src_midi, input_trg_midi, input_sampling_resolution, inpu
|
|
334 |
model = TransformerWrapper(
|
335 |
num_tokens = PAD_IDX+1,
|
336 |
max_seq_len = SEQ_LEN,
|
337 |
-
attn_layers = Decoder(dim =
|
338 |
)
|
339 |
|
340 |
model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
|
341 |
|
342 |
-
model.to(DEVICE)
|
343 |
print('=' * 70)
|
344 |
|
345 |
print('Loading model checkpoint...')
|
346 |
|
347 |
-
|
348 |
-
|
349 |
-
|
|
|
|
|
|
|
|
|
|
|
350 |
print('=' * 70)
|
351 |
-
|
|
|
352 |
model.eval()
|
353 |
|
354 |
if DEVICE == 'cpu':
|
|
|
334 |
model = TransformerWrapper(
|
335 |
num_tokens = PAD_IDX+1,
|
336 |
max_seq_len = SEQ_LEN,
|
337 |
+
attn_layers = Decoder(dim = 2048, depth = 8, heads = 32, attn_flash = True)
|
338 |
)
|
339 |
|
340 |
model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
|
341 |
|
|
|
342 |
print('=' * 70)
|
343 |
|
344 |
print('Loading model checkpoint...')
|
345 |
|
346 |
+
model_checkpoint = hf_hub_download(repo_id='asigalov61/Giant-Music-Transformer',
|
347 |
+
filename='Giant_Music_Transformer_Medium_Trained_Model_42174_steps_0.5211_loss_0.8542_acc.pth'
|
348 |
+
)
|
349 |
+
|
350 |
+
model.load_state_dict(torch.load(kar_model_checkpoint, map_location='cpu', weights_only=True))
|
351 |
+
|
352 |
+
model = torch.compile(model, mode='max-autotune')
|
353 |
+
|
354 |
print('=' * 70)
|
355 |
+
|
356 |
+
model.to(DEVICE)
|
357 |
model.eval()
|
358 |
|
359 |
if DEVICE == 'cpu':
|