#=========================================================================
# https://huggingface.co/spaces/asigalov61/Score-2-Performance-Transformer
#=========================================================================

import os
import time as reqtime
import datetime
from pytz import timezone

import copy
from itertools import groupby
import random

import tqdm

import spaces
import gradio as gr

import torch

# provides TransformerWrapper, Decoder, AutoregressiveWrapper and top_k
from x_transformer_1_23_2 import *

import TMIDIX
from midi_to_colab_audio import midi_to_colab_audio

from huggingface_hub import hf_hub_download

# =================================================================================================
print('Loading model...')

SEQ_LEN = 1802
PAD_IDX = 771
DEVICE = 'cuda' # 'cpu'
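
# Token vocabulary layout (inferred from the encoding/decoding code below):
#   0..255     delta start-times, in 16 ms steps (timings_divider=16)
#   256..383   pitches (pitch + 256)
#   384..639   durations (duration + 384), in 16 ms steps
#   640..767   velocities (velocity + 640)
#   768 / 769 / 770   start-of-score / start-of-performance / end-of-sequence markers
#   771        padding (PAD_IDX)
# A tokenized 300-note chunk is 1803 tokens (768 + 600 score tokens + 769 +
# 1200 performance tokens + 770); SEQ_LEN = 1802 presumably reflects the
# one-token input/target shift of the autoregressive wrapper.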
# instantiate the model
model = TransformerWrapper(num_tokens = PAD_IDX+1,
                           max_seq_len = SEQ_LEN,
                           attn_layers = Decoder(dim = 1024,
                                                 depth = 8,
                                                 heads = 8,
                                                 rotary_pos_emb = True,
                                                 attn_flash = True
                                                 )
                           )

model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)

print('=' * 70)
print('Loading model checkpoint...')

model_checkpoint = hf_hub_download(repo_id='asigalov61/Score-2-Performance-Transformer',
                                   filename='Score_2_Performance_Transformer_Final_Small_Trained_Model_4496_steps_1.5185_loss_0.5589_acc.pth'
                                   )

model.load_state_dict(torch.load(model_checkpoint, map_location='cpu', weights_only=True))

model = torch.compile(model, mode='max-autotune')
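
# Note: with mode='max-autotune', torch.compile defers compilation to the first
# forward pass, so the first conversion request is noticeably slower than
# subsequent ones.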
dtype = torch.bfloat16

ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)

print('=' * 70)
print('Done!')
print('=' * 70)
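
# Optional smoke test (a minimal sketch, commented out): move the model to the
# device and generate a few tokens from the start-of-score marker (768) to
# verify that the checkpoint loaded correctly.
#
# model.to(DEVICE)
# with ctx:
#     _ = model.generate(torch.LongTensor([768]).to(DEVICE), 4, verbose=False)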
# =================================================================================================
def load_midi(midi_file):

    print('Loading MIDI...')

    raw_score = TMIDIX.midi2single_track_ms_score(midi_file)

    escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)

    if escore_notes[0]:
        escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes[0], timings_divider=16)

    pe = escore_notes[0]

    melody_chords = []

    seen = []

    for e in escore_notes:

        if e[3] != 9: # skip the drums channel

            #=======================================================

            # Delta start-times
            dtime = max(0, min(255, e[1]-pe[1]))

            # New start-time resets the seen-pitches filter
            if dtime != 0:
                seen = []

            # Durations
            dur = max(1, min(255, e[2]))

            # Pitches
            ptc = max(1, min(127, e[4]))

            # Velocities
            vel = max(1, min(127, e[5]))

            if ptc not in seen:

                melody_chords.append([dtime, dur, ptc, vel])
                seen.append(ptc)

            pe = e

    print('=' * 70)
    print('Number of notes in a composition:', len(melody_chords))
    print('=' * 70)

    src_melody_chords_f = []
    melody_chords_f = []

    # Split the composition into complete 300-note chunks and tokenize each chunk
    for i in range(0, len(melody_chords), 300):

        chunk = melody_chords[i:i+300]

        src = []
        src1 = []
        trg = []

        if len(chunk) == 300:

            for mm in chunk:

                src.extend([mm[0], mm[2]+256])
                src1.append([mm[0], mm[2]+256, mm[1]+384, mm[3]+640])
                trg.extend([mm[0], mm[2]+256, mm[1]+384, mm[3]+640])

            src_melody_chords_f.append(src1)
            melody_chords_f.append([768] + src + [769] + trg + [770])

    print('Done!')
    print('=' * 70)

    print('Number of composition chunks:', len(melody_chords_f))
    print('=' * 70)

    return melody_chords_f, src_melody_chords_f
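
# Minimal usage sketch (commented out; 'example.mid' is a hypothetical local
# MIDI file with at least 300 non-drum notes):
#
# melody_chords_f, src_melody_chords_f = load_midi('example.mid')
# print(len(melody_chords_f[0]))     # 1803 tokens per chunk
# print(len(src_melody_chords_f[0])) # 300 [dtime, pitch, dur, vel] token quadruples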
# =================================================================================================
# ZeroGPU: the Space runs on HF Zero, so the GPU is requested per call
@spaces.GPU
def Convert_Score_to_Performance(input_midi,
                                 input_conv_type,
                                 input_number_prime_notes,
                                 input_number_conv_notes,
                                 input_model_dur_top_k,
                                 input_model_dur_temperature,
                                 input_model_vel_temperature
                                 ):
    #===============================================================================

    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = reqtime.time()
    print('=' * 70)

    fn = os.path.basename(input_midi)
    fn1 = fn.split('.')[0]

    print('=' * 70)
    print('Requested settings:')
    print('=' * 70)
    print('Input MIDI file name:', fn)
    print('Conversion type:', input_conv_type)
    print('Number of prime notes:', input_number_prime_notes)
    print('Number of notes to convert:', input_number_conv_notes)
    print('Model durations sampling top value:', input_model_dur_top_k)
    print('Model durations temperature:', input_model_dur_temperature)
    print('Model velocities temperature:', input_model_vel_temperature)
    print('=' * 70)

    #==================================================================

    melody_chords_f, src_melody_chords_f = load_midi(input_midi.name)

    #==================================================================

    print('Sample output events', melody_chords_f[0][:16])
    print('=' * 70)
    print('Generating...')

    model.to(DEVICE)
    model.eval()
    #==================================================================

    composition_chunk_idx = 0 # Composition chunk idx to generate durations and velocities for. Each chunk is 300 notes

    num_prime_notes = input_number_prime_notes # Priming improves the results, but it is not necessary and you can set it to zero
    dur_top_k = input_model_dur_top_k # Use k == 1 if the src composition is a score and k > 1 if it is a performance
    dur_temperature = input_model_dur_temperature # For best results, durations temperature should be more than 1.0 but less than velocities temperature
    vel_temperature = input_model_vel_temperature # For best results, velocities temperature must be larger than 1.3 and larger than durations temperature

    #==================================================================

    song_chunk = src_melody_chords_f[composition_chunk_idx]

    # Prompt: start-of-score marker, the full score (dtime/pitch pairs), then
    # the start-of-performance marker
    song = [768]

    for m in song_chunk:
        song.extend(m[:2])

    song.append(769)
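
    # Interleaved generation: for each note, append its score tokens (dtime and
    # pitch) to the sequence, then sample a duration token and a velocity token
    # from the model; out-of-range samples are rejected and redrawn.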
    for i in tqdm.tqdm(range(len(song_chunk))):

        song.extend(song_chunk[i][:2])

        # Durations
        if i < num_prime_notes:
            song.append(song_chunk[i][2])

        else:
            x = torch.LongTensor(song).to(DEVICE)

            # Resample until the model emits a token from the durations range (385-639)
            y = 0

            while not 384 < y < 640:

                with ctx:
                    out = model.generate(x,
                                         1,
                                         temperature=dur_temperature,
                                         filter_logits_fn=top_k,
                                         filter_kwargs={'k': dur_top_k},
                                         return_prime=False,
                                         verbose=False)

                y = out.tolist()[0][0]

            song.append(y)

        # Velocities
        if i < num_prime_notes:
            song.append(song_chunk[i][3])

        else:
            x = torch.LongTensor(song).to(DEVICE)

            # Resample until the model emits a token from the velocities range (641-767)
            y = 0

            while not 640 < y < 768:

                with ctx:
                    out = model.generate(x,
                                         1,
                                         temperature=vel_temperature,
                                         #filter_logits_fn=top_k,
                                         #filter_kwargs={'k': 10},
                                         return_prime=False,
                                         verbose=False)

                y = out.tolist()[0][0]

            song.append(y)
    print('=' * 70)
    print('Done!')
    print('=' * 70)

    #===============================================================================
#=============================================================================== | |
print('Rendering results...') | |
print('=' * 70) | |
print('Sample INTs', song[:15]) | |
print('=' * 70) | |
song_f = [] | |
if len(song) != 0: | |
time = 0 | |
dur = 0 | |
vel = 90 | |
pitch = 60 | |
channel = 0 | |
patch = 0 | |
patches = [0] * 16 | |
for ss in song[602:]: | |
if 0 <= ss < 256: | |
time += ss * 16 | |
if 256 <= ss < 384: | |
pitch = ss-256 | |
if 384 <= ss < 640: | |
dur = (ss-384) * 16 | |
if 640 <= ss < 768: | |
vel = (ss-640) | |
song_f.append(['note', time, dur, channel, pitch, vel, patch]) | |
fn1 = "Score-2-Performance-Transformer-Composition" | |
detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f, | |
output_signature = 'Score 2 Performance Transformer', | |
output_file_name = fn1, | |
track_name='Project Los Angeles', | |
list_of_MIDI_patches=patches | |
) | |
    new_fn = fn1+'.mid'

    audio = midi_to_colab_audio(new_fn,
                                soundfont_path=soundfont,
                                sample_rate=16000,
                                volume_scale=10,
                                output_for_gradio=True
                                )

    print('Done!')
    print('=' * 70)

    #========================================================

    output_midi_title = str(fn1)
    output_midi_summary = str(song_f[:3])
    output_midi = str(new_fn)
    output_audio = (16000, audio) # gr.Audio accepts a (sample_rate, waveform) tuple

    output_plot = TMIDIX.plot_ms_SONG(song_f, plot_title=output_midi, return_plt=True)

    print('Output MIDI file name:', output_midi)
    print('Output MIDI title:', output_midi_title)
    print('Output MIDI summary:', output_midi_summary)
    print('=' * 70)

    #========================================================

    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')

    return output_midi_title, output_midi_summary, output_midi, output_audio, output_plot
# =================================================================================================
if __name__ == "__main__":

    PDT = timezone('US/Pacific')

    print('=' * 70)
    print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('=' * 70)

    soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"

    app = gr.Blocks()

    with app:

        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Score 2 Performance Transformer</h1>")
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Convert any MIDI score to a nice performance</h1>")

        gr.Markdown("## Upload your MIDI or select a sample example MIDI below")
        gr.Markdown("### PLEASE NOTE that the score MIDI MUST HAVE at least 300 notes for this demo to work")

        input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"])

        gr.Markdown("## Select conversion type")

        input_conv_type = gr.Radio(["Durations and Velocities", "Durations", "Velocities"],
                                   value="Durations and Velocities",
                                   label="Conversion type"
                                   )

        gr.Markdown("## Conversion options")

        input_number_prime_notes = gr.Slider(0, 512, value=0, step=8, label="Number of prime notes")
        input_number_conv_notes = gr.Slider(0, 3072, value=1024, step=16, label="Number of notes to convert")

        gr.Markdown("## Model options")

        input_model_dur_top_k = gr.Slider(1, 100, value=1, step=1, label="Model sampling top k value for durations")
        input_model_dur_temperature = gr.Slider(0.5, 1.5, value=1.1, step=0.05, label="Model temperature for durations")
        input_model_vel_temperature = gr.Slider(0.5, 1.5, value=1.5, step=0.05, label="Model temperature for velocities")

        run_btn = gr.Button("convert", variant="primary")

        gr.Markdown("## Generation results")

        output_midi_title = gr.Textbox(label="Output MIDI title")
        output_midi_summary = gr.Textbox(label="Output MIDI summary")
        output_audio = gr.Audio(label="Output MIDI audio", format="wav", elem_id="midi_audio")
        output_plot = gr.Plot(label="Output MIDI score plot")
        output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])

        run_event = run_btn.click(Convert_Score_to_Performance, [input_midi,
                                                                 input_conv_type,
                                                                 input_number_prime_notes,
                                                                 input_number_conv_notes,
                                                                 input_model_dur_top_k,
                                                                 input_model_dur_temperature,
                                                                 input_model_vel_temperature
                                                                 ],
                                  [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot])

        gr.Examples(
            [["asap_midi_score_21.mid", "Durations and Velocities", 8, 600, 1, 1.1, 1.5],
             ["asap_midi_score_45.mid", "Durations and Velocities", 8, 600, 1, 1.1, 1.5],
             ["asap_midi_score_69.mid", "Durations and Velocities", 8, 600, 1, 1.1, 1.5],
             ["asap_midi_score_118.mid", "Durations and Velocities", 8, 600, 1, 1.1, 1.5],
             ["asap_midi_score_167.mid", "Durations and Velocities", 8, 600, 1, 1.1, 1.5],
             ],
            [input_midi,
             input_conv_type,
             input_number_prime_notes,
             input_number_conv_notes,
             input_model_dur_top_k,
             input_model_dur_temperature,
             input_model_vel_temperature
             ],
            [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot],
            Convert_Score_to_Performance
        )

    app.queue().launch()