# Gradio demo app for Text2midi: turns a text caption into a MIDI file and a
# FluidSynth-rendered WAV preview.
import gradio as gr
import spaces
import librosa
import soundfile as sf
import wavio
import os
import subprocess
import pickle
import torch
import torch.nn as nn
from transformers import T5Tokenizer
from transformer_model import Transformer
from miditok import REMI, TokenizerConfig
from pathlib import Path
from huggingface_hub import hf_hub_download

# Hugging Face Hub repository that hosts every asset this app downloads.
repo_id = "amaai-lab/text2midi"

# Fetch, in this order, the model weights, the pickled tokenizer and the
# soundfont; each name is bound to whatever hf_hub_download returns for it.
model_path, tokenizer_path, soundfont_path = (
    hf_hub_download(repo_id=repo_id, filename=asset_name)
    for asset_name in ("pytorch_model.bin", "vocab_remi.pkl", "soundfont.sf2")
)


def save_wav(filepath):
    """Render a MIDI file to WAV with FluidSynth.

    Only the directory and stem of ``filepath`` are used: ``<dir>/<stem>.mid``
    is read and ``<dir>/<stem>.wav`` is written at a 16 kHz sample rate, using
    the module-level ``soundfont_path``.

    Args:
        filepath: Path whose directory and stem identify the MIDI file.

    Returns:
        The path of the WAV file FluidSynth was asked to write.

    Raises:
        subprocess.CalledProcessError: If FluidSynth exits with a non-zero
            status (previously the exit status was ignored).
        FileNotFoundError: If the ``fluidsynth`` executable cannot be found.
    """
    directory = os.path.dirname(filepath)
    stem = os.path.splitext(os.path.basename(filepath))[0]

    midi_filepath = os.path.join(directory, f"{stem}.mid")
    wav_filepath = os.path.join(directory, f"{stem}.wav")

    # Pass an argument list instead of a shell string so paths containing
    # spaces or shell metacharacters are handed to FluidSynth verbatim.
    # Options come first, followed by the soundfont and the MIDI file.
    command = [
        "fluidsynth",
        "-r", "16000",
        "-g", "1.0",
        "--quiet",
        "--no-shell",
        "-T", "wav",
        "-F", wav_filepath,
        soundfont_path,
        midi_filepath,
    ]
    # stdout is discarded, mirroring the former "> /dev/null" redirection;
    # stderr is left attached so FluidSynth errors remain visible in the logs.
    subprocess.run(command, stdout=subprocess.DEVNULL, check=True)

    return wav_filepath


# def post_processing(input_midi_path: str, output_midi_path: str):
#     # Define tokenizer configuration
#     config = TokenizerConfig(
#         pitch_range=(21, 109),
#         beat_res={(0, 4): 8, (4, 12): 4},
#         num_velocities=32,
#         special_tokens=["PAD", "BOS", "EOS", "MASK"],
#         use_chords=True,
#         use_rests=False,
#         use_tempos=True,
#         use_time_signatures=False,
#         use_programs=True
#     )

#     # Initialize tokenizer
#     tokenizer = REMI(config)

#     # Tokenize the input MIDI
#     tokens = tokenizer(Path(input_midi_path))

#     # Remove notes in the first bar
#     modified_tokens = []
#     bar_count = 0
#     bars_after = 2
#     for token in tokens.tokens:
#         if token == "Bar_None":
#             bar_count += 1
#         if bar_count > bars_after:
#             modified_tokens.append(token)

#     # Decode tokens back into MIDI
#     modified_midi = tokenizer(modified_tokens)
#     modified_midi.dump_midi(Path(output_midi_path))


def generate_midi(caption, temperature=0.9, max_len=500, output_path="output.mid"):
    """Generate a MIDI file from a text caption.

    On every call this loads the pickled tokenizer from ``tokenizer_path``,
    builds the ``Transformer`` model, loads its weights from ``model_path``,
    encodes ``caption`` with the ``google/flan-t5-base`` T5 tokenizer, runs
    ``model.generate`` and writes the decoded result with ``dump_midi``.

    Args:
        caption: Text description of the music to generate.
        temperature: Sampling temperature forwarded to ``model.generate``.
        max_len: Maximum generation length forwarded to ``model.generate``;
            coerced to ``int`` because UI number fields may deliver a float.
        output_path: Where the MIDI file is written. Defaults to
            ``"output.mid"``, the path that was previously hard-coded.

    Returns:
        ``output_path``, for the caller's convenience.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # NOTE(review): pickle.load can execute arbitrary code while unpickling.
    # The file is the one downloaded into ``tokenizer_path`` at module load;
    # keep that source trusted.
    with open(tokenizer_path, "rb") as f:
        r_tokenizer = pickle.load(f)

    vocab_size = len(r_tokenizer)
    print("Vocab size: ", vocab_size)

    # NOTE(review): the positional hyperparameters presumably have to match
    # the checkpoint being loaded -- confirm against transformer_model.Transformer.
    model = Transformer(vocab_size, 768, 8, 2048, 18, 1024, False, 8, device=device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")

    inputs = tokenizer(caption, return_tensors='pt', padding=True, truncation=True)
    input_ids = nn.utils.rnn.pad_sequence(
        inputs.input_ids, batch_first=True, padding_value=0
    ).to(device)
    attention_mask = nn.utils.rnn.pad_sequence(
        inputs.attention_mask, batch_first=True, padding_value=0
    ).to(device)

    # Inference only: no autograd bookkeeping is needed while sampling.
    with torch.no_grad():
        output = model.generate(
            input_ids,
            attention_mask,
            max_len=int(max_len),
            temperature=temperature,
        )

    # Only the first sequence of the batch is decoded and written out.
    output_list = output[0].tolist()
    generated_midi = r_tokenizer.decode(output_list)
    generated_midi.dump_midi(output_path)
    return output_path


@spaces.GPU(duration=120)
def gradio_generate(prompt, temperature=1.0, max_length=1500):
    """Gradio callback: generate music for ``prompt`` and return file paths.

    Calls ``generate_midi`` (which writes ``output.mid``), renders that file
    to WAV with ``save_wav``, then re-writes the audio to ``temp.wav`` with a
    sample width of 2 bytes.

    Args:
        prompt: Text description of the music to generate.
        temperature: Sampling temperature; the default mirrors the initial
            value of the temperature slider.
        max_length: Maximum generation length; the default mirrors the initial
            value of the "Max Length" field. Coerced to ``int`` because number
            fields may deliver a float.

    Returns:
        A ``(wav_path, midi_path)`` tuple matching the Interface outputs.

    NOTE(review): the Interface examples supply only a prompt. When Gradio
    caches such examples it appears to call the function with just the inputs
    that have example values, so the defaults keep that call from failing --
    confirm against the installed Gradio version.
    """
    midi_filename = "output.mid"
    generate_midi(prompt, temperature, int(max_length))

    # Use the path save_wav reports instead of re-deriving it by string
    # replacement, which would also rewrite any ".mid" inside a directory name.
    wav_filename = save_wav(midi_filename)

    # Re-encode with the sample rate actually read from the rendered file
    # rather than a second hard-coded copy of it.
    output_wave, samplerate = sf.read(wav_filename, dtype='float32')
    temp_wav_filename = "temp.wav"
    wavio.write(temp_wav_filename, output_wave, rate=samplerate, sampwidth=2)

    return temp_wav_filename, midi_filename  # WAV preview and downloadable MIDI


# Page heading text; rendered as an <h1> when the layout is assembled.
title = "Text2midi: Generating Symbolic Music from Captions"

# HTML blurb shown under the heading (passed as the Interface description).
# The paragraph is now closed with </p>; the original ended in "<p/>".
description_text = """
<p><a href="https://huggingface.co/spaces/amaai-lab/text2midi/blob/main/app.py?duplicate=true"> <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a> For faster inference without waiting in queue, you may duplicate the space and upgrade to a GPU in the settings. <br/><br/>
Generate midi music using Text2midi by providing a text prompt.
<br/><br/> This is the demo for Text2midi for controllable text to midi generation: <a href="https://arxiv.org/abs/tbd">Read our paper.</a>
</p>
"""

# Gradio input and output components
input_text = gr.Textbox(lines=2, label="Prompt")
output_audio = gr.Audio(label="Generated Music", type="filepath")
output_midi = gr.File(label="Download MIDI File")
temperature = gr.Slider(minimum=0.9, maximum=1.1, value=1.0, step=0.01, label="Temperature", interactive=True)
# precision=0 makes the field yield an integer: the value is a token budget,
# not a fractional quantity.
max_length = gr.Number(value=1500, label="Max Length", minimum=500, maximum=2000, step=100, precision=0)

# Stylesheet for the element with id "duplicate-button" (the Duplicate button).
css = "\n".join([
    "",
    "#duplicate-button {",
    "  margin: auto;",
    "  color: white;",
    "  background: #1565c0;",
    "  border-radius: 100vh;",
    "}",
    "",
])

# Gradio interface: wires the prompt/temperature/max-length inputs to
# gradio_generate and shows the resulting audio preview and MIDI download.
gr_interface = gr.Interface(
    fn=gradio_generate,
    inputs=[input_text, temperature, max_length],
    outputs=[output_audio, output_midi],
    description=description_text,
    # String mode instead of the deprecated boolean False; flagging stays off.
    allow_flagging="never",
    # Each example row supplies only the prompt. The original list contained
    # the trance-track caption twice; the duplicate row has been removed.
    examples=[
        ["A haunting electronic ambient piece that evokes a sense of darkness and space, perfect for a film soundtrack. The string ensemble, trumpet, piano, timpani, and synth pad weave together to create a meditative atmosphere. Set in F minor with a 4/4 time signature, the song progresses at an Andante tempo, with the chords F, Fdim, and F/C recurring throughout."],
        ["A slow and emotional classical piece, likely used in a film soundtrack, featuring a church organ as the sole instrument. Written in the key of Eb major with a 3/4 time signature, it evokes a sense of drama and romance. The chord progression of Bb7, Eb, and Ab contributes to the relaxing atmosphere throughout the song."],
        ["An energetic and melodic electronic trance track with a space and retro vibe, featuring drums, distortion guitar, flute, synth bass, and slap bass. Set in A minor with a fast tempo of 138 BPM, the song maintains a 4/4 time signature throughout its duration."],
        ["This short electronic song in C minor features a brass section, string ensemble, tenor saxophone, clean electric guitar, and slap bass, creating a melodic and slightly dark atmosphere. With a tempo of 124 BPM (Allegro) and a 4/4 time signature, the track incorporates a chord progression of C7/E, Eb6, and Bbm6, adding a touch of corporate and motivational vibes to the overall composition."],
        ["A short but energetic rock fragment in C minor, featuring overdriven guitars, electric bass, and drums, with a vivacious tempo of 155 BPM and a 4/4 time signature, evoking a blend of dark and melodic tones."],
    ],
    cache_examples="lazy",
    css=".example-caption { text-align: left; }"
)

# Assemble the page: heading, "Duplicate Space" button, then the Interface.
with gr.Blocks(css=css) as demo:
    # Bound to its own name so the module-level ``title`` string is not
    # overwritten by the HTML component built from it.
    title_html = gr.HTML(f"<h1><center>{title}</center></h1>")
    dupe = gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    gr_interface.render()


# Launch Gradio app.
demo.queue().launch()