asigalov61 commited on
Commit
676e005
·
verified ·
1 Parent(s): 661aab1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -262
app.py CHANGED
@@ -10,48 +10,31 @@ print('Loading core MIDI Loops Mixer modules...')
10
 
11
  import os
12
  import copy
 
 
13
 
14
  import time as reqtime
15
  import datetime
16
  from pytz import timezone
17
 
 
 
18
  print('=' * 70)
19
  print('Loading main MIDI Loops Mixer modules...')
20
 
21
- os.environ['USE_FLASH_ATTENTION'] = '1'
22
-
23
- import torch
24
-
25
- torch.set_float32_matmul_precision('high')
26
- torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
27
- torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
28
- torch.backends.cuda.enable_mem_efficient_sdp(True)
29
- torch.backends.cuda.enable_math_sdp(True)
30
- torch.backends.cuda.enable_flash_sdp(True)
31
- torch.backends.cuda.enable_cudnn_sdp(True)
32
-
33
- from huggingface_hub import hf_hub_download
34
 
35
  import TMIDIX
36
 
37
  from midi_to_colab_audio import midi_to_colab_audio
38
 
39
- from x_transformer_1_23_2 import *
40
-
41
- import random
42
-
43
- import tqdm
44
 
45
  print('=' * 70)
46
  print('Loading aux MIDI Loops Mixer modules...')
47
 
48
  import matplotlib.pyplot as plt
49
 
50
- import gradio as gr
51
- import spaces
52
-
53
- print('=' * 70)
54
- print('PyTorch version:', torch.__version__)
55
  print('=' * 70)
56
  print('Done!')
57
  print('Enjoy! :)')
@@ -59,182 +42,32 @@ print('=' * 70)
59
 
60
  #==================================================================================
61
 
62
- MODEL_CHECKPOINT = 'Guided_Accompaniment_Transformer_Trained_Model_36457_steps_0.5384_loss_0.8417_acc.pth'
63
-
64
  SOUDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
65
 
66
  #==================================================================================
67
 
68
  print('=' * 70)
69
- print('Instantiating model...')
70
 
71
- device_type = 'cuda'
72
- dtype = 'bfloat16'
73
-
74
- ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
75
- ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
76
-
77
- SEQ_LEN = 4096
78
- PAD_IDX = 1794
79
-
80
- model = TransformerWrapper(
81
- num_tokens = PAD_IDX+1,
82
- max_seq_len = SEQ_LEN,
83
- attn_layers = Decoder(dim = 2048,
84
- depth = 4,
85
- heads = 32,
86
- rotary_pos_emb = True,
87
- attn_flash = True
88
- )
89
- )
90
-
91
- model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
92
-
93
- print('=' * 70)
94
- print('Loading model checkpoint...')
95
-
96
- model_checkpoint = hf_hub_download(repo_id='asigalov61/MIDI-Loops-Mixer', filename=MODEL_CHECKPOINT)
97
-
98
- model.load_state_dict(torch.load(model_checkpoint, map_location='cpu', weights_only=True))
99
-
100
- model = torch.compile(model, mode='max-autotune')
101
 
102
  print('=' * 70)
103
  print('Done!')
104
  print('=' * 70)
105
- print('Model will use', dtype, 'precision...')
106
  print('=' * 70)
107
 
108
  #==================================================================================
109
 
110
- def load_midi(input_midi, melody_patch=-1):
111
-
112
- raw_score = TMIDIX.midi2single_track_ms_score(input_midi)
113
-
114
- escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
115
- escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes, timings_divider=32)
116
-
117
- sp_escore_notes = TMIDIX.solo_piano_escore_notes(escore_notes, keep_drums=False)
118
-
119
- if melody_patch == -1:
120
- zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)
121
-
122
- else:
123
- mel_score = [e for e in sp_escore_notes if e[6] == melody_patch]
124
-
125
- if mel_score:
126
- zscore = TMIDIX.recalculate_score_timings(mel_score)
127
-
128
- else:
129
- zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)
130
-
131
- cscore = TMIDIX.chordify_score([1000, zscore])
132
-
133
- score = []
134
-
135
- score_list = []
136
-
137
- pc = cscore[0]
138
-
139
- for c in cscore:
140
- score.append(max(0, min(127, c[0][1]-pc[0][1])))
141
-
142
- scl = [[max(0, min(127, c[0][1]-pc[0][1]))]]
143
-
144
- n = c[0]
145
-
146
- score.extend([max(1, min(127, n[2]))+128, max(1, min(127, n[4]))+256])
147
- scl.append([max(1, min(127, n[2]))+128, max(1, min(127, n[4]))+256])
148
-
149
- score_list.append(scl)
150
-
151
- pc = c
152
-
153
- score_list.append(scl)
154
-
155
- return score, score_list
156
-
157
- #==================================================================================
158
-
159
- @spaces.GPU
160
- def Generate_Accompaniment(input_midi,
161
- generation_type,
162
- melody_patch,
163
- model_temperature
164
- ):
165
 
166
  #===============================================================================
167
 
168
- def generate_full_seq(input_seq, temperature=0.9, verbose=True):
169
-
170
- seq_abs_run_time = sum([t for t in input_seq if t < 128])
171
-
172
- cur_time = 0
173
-
174
- full_seq = copy.deepcopy(input_seq)
175
-
176
- toks_counter = 0
177
-
178
- while cur_time <= seq_abs_run_time:
179
-
180
- if verbose:
181
- if toks_counter % 128 == 0:
182
- print('Generated', toks_counter, 'tokens')
183
-
184
- x = torch.LongTensor(full_seq).cuda()
185
-
186
- with ctx:
187
- out = model.generate(x,
188
- 1,
189
- temperature=temperature,
190
- return_prime=False,
191
- verbose=False)
192
-
193
- y = out.tolist()[0][0]
194
-
195
- if y < 128:
196
- cur_time += y
197
-
198
- full_seq.append(y)
199
-
200
- toks_counter += 1
201
-
202
- return full_seq
203
-
204
- #===============================================================================
205
 
206
- def generate_block_seq(input_seq, trg_dtime, temperature=0.9):
207
-
208
- inp_seq = copy.deepcopy(input_seq)
209
-
210
- block_seq = []
211
-
212
- cur_time = 0
213
-
214
- while cur_time < trg_dtime:
215
-
216
- x = torch.LongTensor(inp_seq).cuda()
217
-
218
- with ctx:
219
- out = model.generate(x,
220
- 1,
221
- temperature=temperature,
222
- return_prime=False,
223
- verbose=False)
224
-
225
- y = out.tolist()[0][0]
226
-
227
- if y < 128:
228
- cur_time += y
229
-
230
- inp_seq.append(y)
231
- block_seq.append(y)
232
-
233
- if cur_time != trg_dtime:
234
- return []
235
-
236
- else:
237
- return block_seq
238
 
239
  #===============================================================================
240
 
@@ -267,69 +100,7 @@ def Generate_Accompaniment(input_midi,
267
  print('=' * 70)
268
  print('Generating...')
269
 
270
- model.to(device_type)
271
- model.eval()
272
-
273
- #==================================================================
274
 
275
- start_score_seq = [1792] + score + [1793]
276
-
277
- #==================================================================
278
-
279
- if generation_type == 'Guided':
280
-
281
- input_seq = []
282
-
283
- input_seq.extend(start_score_seq)
284
- input_seq.extend(score_list[0][0])
285
-
286
- block_seq_lens = []
287
-
288
- idx = 0
289
-
290
- max_retries = 3
291
- mrt = 0
292
-
293
- while idx < len(score_list)-1:
294
-
295
- if idx % 10 == 0:
296
- print('Generating', idx, 'block')
297
-
298
- input_seq.extend(score_list[idx][1])
299
-
300
- block_seq = []
301
-
302
- for _ in range(max_retries):
303
-
304
- block_seq = generate_block_seq(input_seq, score_list[idx+1][0][0])
305
-
306
- if block_seq:
307
- break
308
-
309
- if block_seq:
310
- input_seq.extend(block_seq)
311
- block_seq_lens.append(len(block_seq))
312
- idx += 1
313
- mrt = 0
314
-
315
- else:
316
-
317
- if block_seq_lens:
318
- input_seq = input_seq[:-(block_seq_lens[-1]+2)]
319
- block_seq_lens.pop()
320
- idx -= 1
321
- mrt += 1
322
-
323
- else:
324
- break
325
-
326
- if mrt == max_retries:
327
- break
328
-
329
- else:
330
- input_seq = generate_full_seq(start_score_seq, temperature=model_temperature)
331
-
332
- final_song = input_seq[len(start_score_seq):]
333
 
334
  print('=' * 70)
335
  print('Done!')
@@ -432,7 +203,7 @@ with gr.Blocks() as demo:
432
  #==================================================================================
433
 
434
  gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>MIDI Loops Mixer</h1>")
435
- gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Guided melody accompaniment generation with transformers</h1>")
436
  gr.HTML("""
437
  <p>
438
  <a href="https://huggingface.co/spaces/asigalov61/MIDI-Loops-Mixer?duplicate=true">
@@ -474,24 +245,7 @@ with gr.Blocks() as demo:
474
  output_midi
475
  ]
476
  )
477
-
478
- gr.Examples(
479
- [["USSR-National-Anthem-Seed-Melody.mid", "Guided", -1, 0.9],
480
- ["Hotel-California-Seed-Melody.mid", "Guided", -1, 0.9],
481
- ["Sparks-Fly-Seed-Melody.mid", "Guided", -1, 0.9]
482
- ],
483
- [input_midi,
484
- generation_type,
485
- melody_patch,
486
- model_temperature
487
- ],
488
- [output_audio,
489
- output_plot,
490
- output_midi
491
- ],
492
- Generate_Accompaniment
493
- )
494
-
495
  #==================================================================================
496
 
497
  demo.launch()
 
10
 
11
  import os
12
  import copy
13
+ import statistics
14
+ import random
15
 
16
  import time as reqtime
17
  import datetime
18
  from pytz import timezone
19
 
20
+ import tqdm
21
+
22
  print('=' * 70)
23
  print('Loading main MIDI Loops Mixer modules...')
24
 
25
+ import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  import TMIDIX
28
 
29
  from midi_to_colab_audio import midi_to_colab_audio
30
 
31
+ import gradio as gr
 
 
 
 
32
 
33
  print('=' * 70)
34
  print('Loading aux MIDI Loops Mixer modules...')
35
 
36
  import matplotlib.pyplot as plt
37
 
 
 
 
 
 
38
  print('=' * 70)
39
  print('Done!')
40
  print('Enjoy! :)')
 
42
 
43
  #==================================================================================
44
 
 
 
45
  SOUDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
46
 
47
  #==================================================================================
48
 
49
  print('=' * 70)
50
+ print('Loading MIDI Loops Small Dataset model...')
51
 
52
+ midi_loops_dataset = TMIDIX.Tegridy_Any_Pickle_File_Reader('MIDI-Loops-Dataset-Small-CC-BY-NC-SA.pickle')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  print('=' * 70)
55
  print('Done!')
56
  print('=' * 70)
57
+ print('Loaded', len(midi_loops_dataset), 'MIDI Loops')
58
  print('=' * 70)
59
 
60
  #==================================================================================
61
 
62
+ def Mix_Loops(input_midi,
63
+ generation_type,
64
+ melody_patch,
65
+ model_temperature
66
+ ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  #===============================================================================
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  #===============================================================================
73
 
 
100
  print('=' * 70)
101
  print('Generating...')
102
 
 
 
 
 
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  print('=' * 70)
106
  print('Done!')
 
203
  #==================================================================================
204
 
205
  gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>MIDI Loops Mixer</h1>")
206
+ gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Mix random MIDI loops into one coherent music composition</h1>")
207
  gr.HTML("""
208
  <p>
209
  <a href="https://huggingface.co/spaces/asigalov61/MIDI-Loops-Mixer?duplicate=true">
 
245
  output_midi
246
  ]
247
  )
248
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  #==================================================================================
250
 
251
  demo.launch()