keshavbhandari committed
Commit bdf45ad · 1 Parent(s): 626a7e2

Refactor generate_midi function and remove commented code

Files changed (1)
  app.py  +45 -45
app.py CHANGED
@@ -68,52 +68,52 @@ def save_wav(filepath):
 # modified_midi.dump_midi(Path(output_midi_path))
 
 
-# def generate_midi(caption, temperature=0.9, max_len=500):
-#     device = 'cuda' if torch.cuda.is_available() else 'cpu'
-#     artifact_folder = 'artifacts'
-
-#     tokenizer_filepath = os.path.join(artifact_folder, "vocab_remi.pkl")
-#     # Load the tokenizer dictionary
-#     with open(tokenizer_filepath, "rb") as f:
-#         r_tokenizer = pickle.load(f)
-
-#     # Get the vocab size
-#     vocab_size = len(r_tokenizer)
-#     print("Vocab size: ", vocab_size)
-#     model = Transformer(vocab_size, 768, 8, 5000, 18, 1024, False, 8, device=device)
-#     model_path = os.path.join(artifact_folder, "pytorch_model_140.bin")
-#     model.load_state_dict(torch.load(model_path, map_location=device))
-#     model.eval()
-#     tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
-
-#     inputs = tokenizer(caption, return_tensors='pt', padding=True, truncation=True)
-#     input_ids = nn.utils.rnn.pad_sequence(inputs.input_ids, batch_first=True, padding_value=0)
-#     input_ids = input_ids.to(device)
-#     attention_mask = nn.utils.rnn.pad_sequence(inputs.attention_mask, batch_first=True, padding_value=0)
-#     attention_mask = attention_mask.to(device)
-#     output = model.generate(input_ids, attention_mask, max_len=max_len, temperature=temperature)
-#     output_list = output[0].tolist()
-#     generated_midi = r_tokenizer.decode(output_list)
-#     generated_midi.dump_midi("output.mid")
-#     post_processing("output.mid", "output.mid")
-
-
-# @spaces.GPU(duration=120)
-# def gradio_generate(prompt, temperature, max_length):
-#     # Generate midi
-#     generate_midi(prompt, temperature, max_length)
-
-#     # Convert midi to wav
-#     midi_filename = "output.mid"
-#     save_wav(midi_filename)
-#     wav_filename = midi_filename.replace(".mid", ".wav")
-
-#     # Read the generated WAV file
-#     output_wave, samplerate = sf.read(wav_filename, dtype='float32')
-#     temp_wav_filename = "temp.wav"
-#     wavio.write(temp_wav_filename, output_wave, rate=16000, sampwidth=2)
+def generate_midi(caption, temperature=0.9, max_len=500):
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    artifact_folder = 'artifacts'
+
+    tokenizer_filepath = os.path.join(artifact_folder, "vocab_remi.pkl")
+    # Load the tokenizer dictionary
+    with open(tokenizer_filepath, "rb") as f:
+        r_tokenizer = pickle.load(f)
+
+    # Get the vocab size
+    vocab_size = len(r_tokenizer)
+    print("Vocab size: ", vocab_size)
+    model = Transformer(vocab_size, 768, 8, 2048, 18, 1024, False, 8, device=device)
+    model_path = os.path.join("amaai-lab/text2midi", "pytorch_model.bin")
+    model.load_state_dict(torch.load(model_path, map_location=device))
+    model.eval()
+    tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
+
+    inputs = tokenizer(caption, return_tensors='pt', padding=True, truncation=True)
+    input_ids = nn.utils.rnn.pad_sequence(inputs.input_ids, batch_first=True, padding_value=0)
+    input_ids = input_ids.to(device)
+    attention_mask = nn.utils.rnn.pad_sequence(inputs.attention_mask, batch_first=True, padding_value=0)
+    attention_mask = attention_mask.to(device)
+    output = model.generate(input_ids, attention_mask, max_len=max_len, temperature=temperature)
+    output_list = output[0].tolist()
+    generated_midi = r_tokenizer.decode(output_list)
+    generated_midi.dump_midi("output.mid")
+    # post_processing("output.mid", "output.mid")
+
+
+@spaces.GPU(duration=120)
+def gradio_generate(prompt, temperature, max_length):
+    # Generate midi
+    generate_midi(prompt, temperature, max_length)
+
+    # Convert midi to wav
+    midi_filename = "output.mid"
+    save_wav(midi_filename)
+    wav_filename = midi_filename.replace(".mid", ".wav")
+
+    # Read the generated WAV file
+    output_wave, samplerate = sf.read(wav_filename, dtype='float32')
+    temp_wav_filename = "temp.wav"
+    wavio.write(temp_wav_filename, output_wave, rate=16000, sampwidth=2)
 
-#     return temp_wav_filename, midi_filename  # Return both WAV and MIDI file paths
+    return temp_wav_filename, midi_filename  # Return both WAV and MIDI file paths
 
 @spaces.GPU(duration=120)
 def gradio_generate(prompt, temperature, max_length):
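
For context, a minimal sketch of how the refactored gradio_generate could be exposed through a Gradio Interface. This wiring is not part of the commit; the component labels, slider ranges, and default values below are assumptions rather than values taken from app.py.

    # Hypothetical UI wiring (assumed, not from this commit); gradio_generate
    # is the function added above and returns (wav_path, midi_path).
    import gradio as gr

    demo = gr.Interface(
        fn=gradio_generate,
        inputs=[
            gr.Textbox(label="Caption", placeholder="Describe the music to generate"),
            gr.Slider(0.5, 1.5, value=0.9, label="Temperature"),        # assumed range
            gr.Slider(100, 2000, value=500, step=100, label="Max length"),  # assumed range
        ],
        outputs=[
            gr.Audio(type="filepath", label="Generated audio"),
            gr.File(label="Generated MIDI"),
        ],
    )

    if __name__ == "__main__":
        demo.launch()

Because gradio_generate returns both file paths, a single generation pass can feed the audio preview and the downloadable MIDI output.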