Spaces:
Runtime error
Runtime error
| import json | |
| import random | |
| from dataclasses import dataclass | |
| from functools import lru_cache | |
| from math import ceil, floor, log | |
| from typing import Dict, Iterator, List, Optional, Tuple | |
| import mido | |
@dataclass
class VocabConfig:
    """Vocabulary settings for converting between MIDI data and text tokens.

    Usually loaded from a JSON file with :meth:`from_json`. After the fields
    are set, ``__post_init__`` validates them and derives the lookup tables
    used by the rest of the module (``_bin_str_to_int``,
    ``short_instr_bin_names``, ``bin_channel_map`` and friends).
    """

    # Number of note events. Should be 128.
    note_events: int
    # Number of wait events. Configurable, must evenly divide max_wait_time.
    wait_events: int
    # Max wait time in milliseconds to be represented by a single token.
    max_wait_time: int
    # Number of velocity events. Should be 128 (or 100? need to check midi standard)
    velocity_events: int
    # Number of bins to quantize velocity into. Should evenly divide velocity_events.
    velocity_bins: int
    # Exponential scaling factor for velocity bin sizes. 1.0 = linear scaling.
    velocity_exp: float
    # Whether to sort tokens by instrument, note. This should improve data reducibility.
    do_token_sorting: bool
    # Whether tokens should be represented as combined instrument/note/velocity tokens, or separate tokens for each.
    unrolled_tokens: bool
    # If non-zero, notes held for this many seconds will be automatically released during str->midi decoding.
    decode_end_held_note_delay: float
    # If true, repeated notes will be automatically released before playing again during str->midi decoding.
    decode_fix_repeated_notes: bool
    # List of instrument names to use for binning. Must have at most 16 values.
    bin_instrument_names: List[str]
    # Indicates which bin name represents percussion instruments on MIDI channel 10.
    ch10_instrument_bin_name: str
    # Mapping from instrument name to bin name.
    program_name_to_bin_name: Dict[str, str]
    # Mapping from bin name to program name.
    bin_name_to_program_name: Dict[str, str]
    # Mapping from program number to instrument name.
    instrument_names: Dict[str, str]
    # Manual override for velocity bins. Each element is the max velocity value for that bin by index.
    velocity_bins_override: Optional[List[int]] = None

    def __post_init__(self):
        """Validate the configuration and build the derived lookup tables.

        Raises:
            ValueError: if validation fails, or if no unique short name can
                be derived for one of the instrument bins.
        """
        self.validate()

        # Program number <-> instrument name (JSON keys are strings, hence int()).
        self._instrument_names_str_to_int = {name: int(i) for i, name in self.instrument_names.items()}
        self._instrument_names_int_to_str = {int(i): name for i, name in self.instrument_names.items()}
        # Bin name -> bin index (position in bin_instrument_names).
        self._bin_str_to_int = {name: int(i) for i, name in enumerate(self.bin_instrument_names)}
        # Bin index -> representative program number; the channel-10 bin uses program 0.
        self._bin_int_to_instrument_int = [
            self._instrument_names_str_to_int[self.bin_name_to_program_name[name]]
            if name != self.ch10_instrument_bin_name
            else 0
            for name in self.bin_instrument_names
        ]
        # One entry per key of program_name_to_bin_name, in dict order; -1 marks
        # instruments mapped to the empty bin name. It is indexed by program
        # number elsewhere, so the dict order is expected to follow program order.
        self._instrument_int_to_bin_int = [
            self._bin_str_to_int[self.program_name_to_bin_name[instr]]
            if self.program_name_to_bin_name[instr] != ""
            else -1
            for instr in self.program_name_to_bin_name.keys()
        ]
        self._ch10_bin_int = self._bin_str_to_int[self.ch10_instrument_bin_name] if self.ch10_instrument_bin_name else -1

        # Shortest prefix of each bin name that is not already in use.
        self.short_instr_bin_names = []
        for instr in self.bin_instrument_names:
            i = min(1, len(instr))
            while instr[:i] in self.short_instr_bin_names:
                if i >= len(instr):
                    # Every prefix (including the full name) is taken; growing
                    # the slice further would never change it.
                    raise ValueError(f"Cannot derive a unique short name for instrument bin {instr!r}")
                i += 1
            self.short_instr_bin_names.append(instr[:i])
        self._short_instrument_names_str_to_int = {name: int(i) for i, name in enumerate(self.short_instr_bin_names)}

        # Assign output channels to bins, skipping channel index 9, which is
        # reserved for the channel-10 bin.
        range_excluding_ch10 = [(i if i < 9 else i + 1) for i in range(len(self.bin_instrument_names))]
        bins_excluding_ch10 = [n for n in self.bin_instrument_names if n != self.ch10_instrument_bin_name]
        self.bin_channel_map = {
            bin_name: channel for channel, bin_name in zip(range_excluding_ch10, bins_excluding_ch10)
        }
        if self.ch10_instrument_bin_name:
            self.bin_channel_map[self.ch10_instrument_bin_name] = 9

    def validate(self):
        """Check field consistency.

        Raises:
            ValueError: when a field violates one of the constraints below.
        """
        if self.max_wait_time % self.wait_events != 0:
            raise ValueError("max_wait_time must be exactly divisible by wait_events")
        if self.velocity_bins < 2:
            raise ValueError("velocity_bins must be at least 2")
        if len(self.bin_instrument_names) > 16:
            raise ValueError("bin_instruments must have at most 16 values")
        if self.velocity_bins_override:
            print("VocabConfig is using velocity_bins_override. Ignoring velocity_exp.")
            if len(self.velocity_bins_override) != self.velocity_bins:
                raise ValueError("velocity_bins_override must have same length as velocity_bins")
        if self.ch10_instrument_bin_name and self.ch10_instrument_bin_name not in self.bin_instrument_names:
            raise ValueError("ch10_instrument_bin_name must be in bin_instruments")
        if self.velocity_exp <= 0:
            raise ValueError("velocity_exp must be greater than 0")

    @classmethod
    def from_json(cls, path: str):
        """Build a config from a JSON file whose keys match the field names."""
        with open(path, "r", encoding="utf-8") as f:
            config = json.load(f)
        return cls(**config)
class VocabUtils:
    """Token formatting, parsing and quantization helpers driven by a VocabConfig."""

    def __init__(self, cfg: "VocabConfig") -> None:
        self.cfg = cfg

    def format_wait_token(self, wait: int) -> str:
        """Return the wait token for ``wait`` time steps (decimal, ``t`` prefix)."""
        return f"t{wait}"

    def format_note_token(self, instrument_bin: int, note: int, velocity_bin: int) -> str:
        """Return the combined ``<short name>:<note hex>:<velocity bin hex>`` token."""
        return f"{self.cfg.short_instr_bin_names[instrument_bin]}:{note:x}:{velocity_bin:x}"

    def format_unrolled_note(self, note: int) -> str:
        """Return the unrolled note token (hex, ``n`` prefix)."""
        return f"n{note:x}"

    def format_unrolled_velocity(self, velocity_bin: int) -> str:
        """Return the unrolled velocity token (hex, ``v`` prefix)."""
        return f"v{velocity_bin:x}"

    def format_unrolled_instrument_bin(self, instrument_bin: int) -> str:
        """Return the unrolled instrument token (``i`` prefix + short bin name)."""
        return f"i{self.cfg.short_instr_bin_names[instrument_bin]}"

    def velocity_to_bin(self, velocity: float) -> int:
        """Quantize a velocity into a bin index.

        The velocity is first clamped to ``[0, velocity_events - 1]``. With
        ``velocity_bins_override`` set, the result is the first bin whose
        maximum is >= the velocity; a velocity above every listed maximum
        falls into the last bin.
        """
        velocity = max(0, min(velocity, self.cfg.velocity_events - 1))
        override = self.cfg.velocity_bins_override
        if override:
            for i, v in enumerate(override):
                if velocity <= v:
                    return i
            # Louder than every configured maximum: use the loudest bin rather
            # than bin 0, which would turn a loud note into the quietest one.
            return len(override) - 1
        binsize = self.cfg.velocity_events / (self.cfg.velocity_bins - 1)
        if self.cfg.velocity_exp == 1.0:
            return ceil(velocity / binsize)
        else:
            return ceil((self.cfg.velocity_events*((self.cfg.velocity_exp**(velocity/self.cfg.velocity_events)-1.0) / (self.cfg.velocity_exp-1.0))) / binsize)

    def bin_to_velocity(self, bin: int) -> int:
        """Map a velocity bin index back to a velocity value (never negative)."""
        if self.cfg.velocity_bins_override:
            return self.cfg.velocity_bins_override[bin]
        binsize = self.cfg.velocity_events / (self.cfg.velocity_bins - 1)
        if self.cfg.velocity_exp == 1.0:
            return max(0, ceil(bin * binsize - 1))
        else:
            return max(0, ceil(self.cfg.velocity_events*log(((self.cfg.velocity_exp-1)*binsize*bin)/self.cfg.velocity_events+1, self.cfg.velocity_exp) - 1))

    def delta_to_wait_ids(self, delta_ms: float) -> Iterator[int]:
        """Yield wait-token ids covering ``delta_ms`` milliseconds.

        Delays longer than ten times ``max_wait_time`` are truncated to that
        limit. Full ``max_wait_time`` chunks come first, then the rounded
        remainder if it is non-zero.
        """
        def roundi(f: float):
            # Round half down to the nearest integer.
            return ceil(f - 0.5)

        max_wait_ms = self.cfg.max_wait_time
        div = max_wait_ms / self.cfg.wait_events  # milliseconds per wait step

        if delta_ms > max_wait_ms * 10:
            delta_ms = max_wait_ms * 10  # truncate time

        for _ in range(floor(delta_ms / max_wait_ms)):
            yield roundi(max_wait_ms / div)
        leftover_time_shift = roundi((delta_ms % max_wait_ms) / div)
        if leftover_time_shift > 0:
            yield leftover_time_shift

    def prog_data_to_token_data(self, program: int, channel: int, note: int, velocity: float) -> Optional[Tuple[int, int, int]]:
        """Convert (program, channel, note, velocity) to (bin, note, velocity bin).

        Returns ``None`` when the event has no bin: channel index 9 without a
        configured channel-10 bin, or a program mapped to ``-1``.
        """
        if channel == 9:
            if self.cfg._ch10_bin_int == -1:
                return None
            return self.cfg._ch10_bin_int, note, self.velocity_to_bin(velocity)

        instrument_bin = self.cfg._instrument_int_to_bin_int[program]
        if instrument_bin != -1:
            return instrument_bin, note, self.velocity_to_bin(velocity)
        return None

    def prog_data_list_to_token_data_list(self, data: List[Tuple[int, int, int, float]]) -> Iterator[Tuple[int, int, int]]:
        """Yield token data for each entry of ``data`` that maps to a bin."""
        for d in data:
            token_data = self.prog_data_to_token_data(*d)
            if token_data is not None:
                yield token_data

    def sort_token_data(self, data: List[Tuple[int, int, int]]) -> List[Tuple[int, int, int]]:
        """Sort by (channel-10 bin first, bin, note), keeping input order for ties."""
        # ensure order is preserved for tokens with the same instrument, note
        data = [(i, n, v, x) for x, (i, n, v) in enumerate(data)]
        data.sort(key=lambda x: (x[0] != self.cfg._ch10_bin_int, x[0], x[1], x[3]))
        return [(i, n, v) for i, n, v, _ in data]

    def data_to_wait_tokens(self, delta_ms: float) -> List[str]:
        """Return the wait tokens representing ``delta_ms`` (empty for zero)."""
        if delta_ms == 0.0:
            return []
        return [self.format_wait_token(i) for i in self.delta_to_wait_ids(delta_ms)]

    def wait_token_to_delta(self, token: str) -> float:
        """Return the duration in milliseconds encoded by a wait token."""
        return self.cfg.max_wait_time / self.cfg.wait_events * int(token[1:])

    def note_token_to_data(self, token: str) -> Tuple[int, int, int]:
        """Parse a combined note token into (bin index, note, velocity)."""
        instr_str, note_str, velocity_str = token.strip().split(":")
        instr_bin = self.cfg._short_instrument_names_str_to_int[instr_str]
        note = int(note_str, base=16)
        velocity = self.bin_to_velocity(int(velocity_str, base=16))
        return instr_bin, note, velocity
@dataclass
class AugmentValues:
    """One concrete set of augmentation parameters applied to a MIDI file."""

    # Instrument bin index -> replacement bin index; bins not listed are unchanged.
    instrument_bin_remap: Dict[int, int]
    # Multiplier applied to note velocities (1.0 = unchanged).
    velocity_mod_factor: float
    # Semitones added to note numbers (0 = unchanged).
    transpose_semitones: int
    # Multiplier applied to the tempo value (1.0 = unchanged).
    time_stretch_factor: float

    @classmethod
    def default(cls) -> "AugmentValues":
        """Return the identity augmentation (nothing is changed)."""
        return cls(
            instrument_bin_remap={},
            velocity_mod_factor=1.0,
            transpose_semitones=0,
            time_stretch_factor=1.0,
        )
@dataclass
class AugmentConfig:
    """Settings that control how augmented variants of a MIDI file are drawn."""

    # The number of times to augment each MIDI file. The dataset size will be multiplied by this number.
    augment_data_factor: int
    # A list of instrument names to randomly swap with each other.
    instrument_mixups: List[List[str]]
    # A list of percentages to change the note velocity by. 0.0 = no change. 0 is included by default.
    velocity_mod_pct: List[float]
    # A list of semitones to transpose by. 0 is included by default.
    transpose_semitones: List[int]
    # A list of percentages to stretch the tempo by. 0.0 = no stretch. 0 is included by default.
    time_stretch_pct: List[float]
    # Random seed to use for reproducibility.
    seed: int

    cfg: "VocabConfig"

    def __post_init__(self):
        """Validate, fill in empty option lists, and resolve mixup names to bin indices."""
        self.validate()
        if not self.velocity_mod_pct:
            self.velocity_mod_pct = [0.0]
        if not self.transpose_semitones:
            self.transpose_semitones = [0]
        if not self.time_stretch_pct:
            self.time_stretch_pct = [0.0]

        # Names that are not instrument bins of the vocabulary are dropped.
        self._instrument_mixups_int = [
            [self.cfg._bin_str_to_int[i] for i in l if i in self.cfg._bin_str_to_int]
            for l in self.instrument_mixups
        ]
        self._instrument_mixups_int = [l for l in self._instrument_mixups_int if len(l) > 0]  # remove empty lists
        self._instrument_pool_assignments = {}
        self._mixup_pools = []
        for pool_i, mixup_list in enumerate(self._instrument_mixups_int):
            pool = set()
            for i in mixup_list:
                pool.add(i)
                self._instrument_pool_assignments[i] = pool_i
            self._mixup_pools.append(pool)

    def validate(self):
        """Check the configuration.

        Raises:
            ValueError: if ``augment_data_factor`` is below 1 or an instrument
                name appears more than once across all mixup lists.
        """
        if self.augment_data_factor < 1:
            raise ValueError("augment_data_factor must be at least 1")
        used_instruments = set()
        for mixup_list in self.instrument_mixups:
            for n in mixup_list:
                if n in used_instruments:
                    raise ValueError(f"Duplicate instrument name: {n}")
                used_instruments.add(n)

    @classmethod
    def from_json(cls, path: str, cfg: "VocabConfig"):
        """Build a config from a JSON file; a random seed is drawn if none is given."""
        with open(path, "r", encoding="utf-8") as f:
            config = json.load(f)
        config["cfg"] = cfg
        if "seed" not in config:
            config["seed"] = random.randint(0, 2**32 - 1)
        return cls(**config)

    def _rng_for_file(self, filename: str) -> random.Random:
        """Return a generator seeded from ``seed`` and a stable digest of ``filename``."""
        import zlib

        # The built-in hash() of a str is salted per interpreter process, so it
        # cannot give the same augmentations on a later run; CRC-32 can.
        return random.Random(self.seed + zlib.crc32(filename.encode("utf-8")))

    def _random_instrument_bin_remap(self, rng: random.Random) -> Dict[int, int]:
        """Return a random permutation of the bins inside each mixup pool."""
        remap: Dict[int, int] = {}
        for pool in self._mixup_pools:
            members = sorted(pool)  # fixed order so the shuffle alone decides the mapping
            shuffled = list(members)
            rng.shuffle(shuffled)
            remap.update(zip(members, shuffled))
        return remap

    def get_augment_values(self, filename: str) -> Iterator["AugmentValues"]:
        """Yield ``augment_data_factor`` augmentations for ``filename``.

        The first value is always the identity augmentation; the rest are
        drawn from a generator seeded by ``seed`` and the file name.
        """
        # first yield default values
        yield AugmentValues.default()

        rng = self._rng_for_file(filename)
        for _ in range(int(self.augment_data_factor - 1)):
            instrument_bin_remap = self._random_instrument_bin_remap(rng)
            yield AugmentValues(
                instrument_bin_remap=instrument_bin_remap,
                velocity_mod_factor=1.0 + rng.choice(self.velocity_mod_pct),
                transpose_semitones=rng.choice(self.transpose_semitones),
                time_stretch_factor=1.0 + rng.choice(self.time_stretch_pct),
            )
@dataclass
class FilterConfig:
    """Dataset filtering options, usually loaded with :meth:`from_json`."""

    # Whether to filter out MIDI files with duplicate MD5 hashes.
    deduplicate_md5: bool
    # Minimum time delay between notes in a file before splitting into multiple documents.
    piece_split_delay: float
    # Minimum length of a piece in milliseconds.
    min_piece_length: float

    @classmethod
    def from_json(cls, path: str):
        """Build a config from a JSON file whose keys match the field names."""
        with open(path, "r", encoding="utf-8") as f:
            config = json.load(f)
        return cls(**config)
def mix_volume(velocity: int, volume: int, expression: int) -> float:
    """Scale a note velocity by the channel volume and expression controllers.

    Both controller values are normalised by 127 before being applied.
    """
    volume_scale = volume / 127.0
    expression_scale = expression / 127.0
    return velocity * volume_scale * expression_scale
def convert_midi_to_str(cfg: VocabConfig, filter_cfg: FilterConfig, mid: mido.MidiFile, augment: AugmentValues = None) -> List[str]:
    """Tokenize a MIDI file into a list of text documents.

    All tracks are merged into one (``mid.tracks`` is modified in place) and
    the messages are replayed while tracking per-channel program, volume,
    expression, held notes and sustain pedal. Each document is a
    space-separated token string wrapped in ``<start>`` / ``<end>``. A new
    document is started when a note arrives after more than
    ``filter_cfg.piece_split_delay`` seconds while no note is held. Documents
    whose accumulated wait time does not exceed ``filter_cfg.min_piece_length``
    (scaled by 1000) are dropped.

    Args:
        cfg: vocabulary configuration.
        filter_cfg: piece splitting / minimum length settings.
        mid: the MIDI file to convert.
        augment: augmentation to apply; ``None`` means no augmentation.

    Returns:
        The token strings of the pieces that passed the length filter.
    """
    utils = VocabUtils(cfg)
    if augment is None:
        augment = AugmentValues.default()

    # filter out unknown meta messages before merge (https://github.com/mido/mido/pull/286)
    for track_index, track in enumerate(mid.tracks):
        mid.tracks[track_index] = [m for m in track if m.type != "unknown_meta"]

    if len(mid.tracks) > 1:
        mid.tracks = [mido.merge_tracks(mid.tracks)]

    all_channels = range(16)
    delta_time_ms = 0.0
    tempo = 500000
    channel_program = dict.fromkeys(all_channels, 0)
    channel_volume = dict.fromkeys(all_channels, 127)
    # unlikely to be useful. expression usually modifies an already played note.
    channel_expression = dict.fromkeys(all_channels, 127)
    channel_notes = {ch: {} for ch in all_channels}
    channel_pedal_on = dict.fromkeys(all_channels, False)
    channel_pedal_events = {ch: {} for ch in all_channels}  # {channel: {(note, program) -> True}}
    started_flag = False

    output_list = []
    output = ["<start>"]
    output_length_ms = 0.0
    token_data_buffer: List[Tuple[int, int, int, float]] = []  # need to sort notes between wait tokens

    def flush_token_data_buffer():
        """Convert the buffered note events to tokens and append them to the piece."""
        nonlocal token_data_buffer, output
        token_data = list(utils.prog_data_list_to_token_data_list(token_data_buffer))

        if augment.instrument_bin_remap or augment.transpose_semitones:
            # TODO put transpose in a real function
            def transpose(bin_index, pitch):
                # The channel-10 bin is never transposed.
                if bin_index != cfg._ch10_bin_int:
                    pitch = pitch + augment.transpose_semitones
                # TODO handle ranges beyond 12
                if pitch < 0:
                    return pitch + 12
                if pitch >= cfg.note_events:
                    return pitch - 12
                return pitch

            token_data = [
                (augment.instrument_bin_remap.get(i, i), transpose(i, n), v)
                for i, n, v in token_data
            ]

        if cfg.do_token_sorting:
            token_data = utils.sort_token_data(token_data)

        if cfg.unrolled_tokens:
            for entry in token_data:
                output += [
                    utils.format_unrolled_instrument_bin(entry[0]),
                    utils.format_unrolled_note(entry[1]),
                    utils.format_unrolled_velocity(entry[2]),
                ]
        else:
            output += [utils.format_note_token(*entry) for entry in token_data]
        token_data_buffer = []

    def consume_note_program_data(prog: int, chan: int, note: int, vel: float):
        """Record one note event, emitting wait tokens and piece splits as needed."""
        nonlocal output, output_length_ms, started_flag, delta_time_ms
        if utils.prog_data_to_token_data(prog, chan, note, vel) is None:
            return

        if delta_time_ms > filter_cfg.piece_split_delay * 1000.0:
            # only split when no note is still held on any channel
            silent = not any(channel_notes.values())
            if silent:
                flush_token_data_buffer()
                output.append("<end>")
                if output_length_ms > filter_cfg.min_piece_length * 1000.0:
                    output_list.append(" ".join(output))
                output = ["<start>"]
                output_length_ms = 0.0
                started_flag = False

        if started_flag:
            wait_tokens = utils.data_to_wait_tokens(delta_time_ms)
            if wait_tokens:
                flush_token_data_buffer()
                output_length_ms += delta_time_ms
                output += wait_tokens

        delta_time_ms = 0.0
        token_data_buffer.append((prog, chan, note, vel * augment.velocity_mod_factor))
        started_flag = True

    def handle_note_off(ch, prog, n):
        """Release a note now, or defer it while the sustain pedal is down."""
        if channel_pedal_on[ch]:
            channel_pedal_events[ch][(n, prog)] = True
            return
        consume_note_program_data(prog, ch, n, 0)
        if n in channel_notes[ch]:
            del channel_notes[ch][n]

    for msg in mid.tracks[0]:
        delta_time_ms += mido.tick2second(msg.time, mid.ticks_per_beat, tempo) * 1000.0
        kind = msg.type

        if msg.is_meta:
            if kind == "set_tempo":
                tempo = msg.tempo * augment.time_stretch_factor
            continue

        if kind == "program_change":
            channel_program[msg.channel] = msg.program
        elif kind == "note_on":
            if msg.velocity == 0:
                # note_on with zero velocity is treated as a note_off
                handle_note_off(msg.channel, channel_program[msg.channel], msg.note)
            else:
                pedal_key = (msg.note, channel_program[msg.channel])
                if pedal_key in channel_pedal_events[msg.channel]:
                    del channel_pedal_events[msg.channel][pedal_key]
                consume_note_program_data(
                    channel_program[msg.channel],
                    msg.channel,
                    msg.note,
                    mix_volume(msg.velocity, channel_volume[msg.channel], channel_expression[msg.channel]),
                )
                channel_notes[msg.channel][msg.note] = True
        elif kind == "note_off":
            handle_note_off(msg.channel, channel_program[msg.channel], msg.note)
        elif kind == "control_change":
            control = msg.control
            if control == 7 or control == 39:  # volume
                channel_volume[msg.channel] = msg.value
            elif control == 11:  # expression
                channel_expression[msg.channel] = msg.value
            elif control == 64:  # sustain pedal
                channel_pedal_on[msg.channel] = msg.value >= 64
                if not channel_pedal_on[msg.channel]:
                    # pedal lifted: release every note-off deferred while it was down
                    for (note, program) in channel_pedal_events[msg.channel]:
                        handle_note_off(msg.channel, program, note)
                    channel_pedal_events[msg.channel] = {}
            elif control == 123:  # all notes off
                for ch, held in channel_notes.items():
                    # snapshot the keys: handle_note_off may delete from `held`
                    for note in list(held):
                        handle_note_off(ch, channel_program[ch], note)

    flush_token_data_buffer()
    output.append("<end>")
    if output_length_ms > filter_cfg.min_piece_length * 1000.0:
        output_list.append(" ".join(output))
    return output_list
def generate_program_change_messages(cfg: VocabConfig):
    """Yield the program_change messages that set up every output channel.

    Each instrument bin's channel gets the program named by
    ``cfg.bin_name_to_program_name``; channel index 9 is skipped in the loop
    and always receives program 0 as the final message.
    """
    for bin_name, channel in cfg.bin_channel_map.items():
        if channel != 9:
            program_name = cfg.bin_name_to_program_name[bin_name]
            program = cfg._instrument_names_str_to_int[program_name]
            yield mido.Message("program_change", program=program, time=0, channel=channel)
    yield mido.Message("program_change", program=0, time=0, channel=9)
@dataclass
class DecodeState:
    """Mutable state carried from token to token while decoding text to MIDI."""

    total_time: float  # milliseconds
    delta_accum: float  # milliseconds
    current_bin: int
    current_note: int
    active_notes: Dict[Tuple[int, int], float]  # { (channel, note): time started, ... }
def token_to_midi_message(utils: VocabUtils, token: str, state: DecodeState, end_token_pause: float = 3.0) -> Iterator[Tuple[Optional[mido.Message], DecodeState]]:
    """Decode one token into zero or more MIDI messages.

    Yields ``(message, state)`` pairs; ``message`` is ``None`` when the token
    produced no MIDI output on that step. ``state`` is mutated in place and a
    fresh one is created when ``None`` is passed in. Message times are in
    ticks, computed with 480 ticks per beat and a tempo of 500000.

    Args:
        utils: vocabulary helpers (and, through ``utils.cfg``, the config).
        token: a single token; surrounding whitespace is ignored.
        state: decoding state from the previous token, or ``None`` to start.
        end_token_pause: seconds of silence added when ``<end>`` is seen.
    """
    if state is None:
        first_short_name = utils.cfg.short_instr_bin_names[0]
        state = DecodeState(
            total_time=0.0,
            delta_accum=0.0,
            current_bin=utils.cfg._short_instrument_names_str_to_int[first_short_name],
            current_note=0,
            active_notes={},
        )

    def advance_clock(milliseconds):
        # Time waits until the next emitted message consumes it.
        state.delta_accum += milliseconds
        state.total_time += milliseconds

    def take_pending_ticks():
        # Convert the accumulated wait into ticks and reset the accumulator.
        pending = int(mido.second2tick(state.delta_accum / 1000.0, 480, 500000))
        state.delta_accum = 0.0
        return pending

    token = token.strip()
    if not token:
        yield None, state
        return

    if token == "<end>":
        advance_clock(end_token_pause * 1000.0)
        if utils.cfg.decode_end_held_note_delay != 0.0:
            # end held notes
            for (channel, note), _start_time in list(state.active_notes.items()):
                ticks = take_pending_ticks()
                del state.active_notes[(channel, note)]
                yield mido.Message("note_off", note=note, time=ticks, channel=channel), state
        yield None, state
        return

    if token.startswith("<"):
        # any other special token carries no MIDI data
        yield None, state
        return

    head = token[0]
    if utils.cfg.unrolled_tokens:
        if head == "t":
            advance_clock(utils.wait_token_to_delta(token))
        elif head == "n":
            state.current_note = int(token[1:], base=16)
        elif head == "i":
            state.current_bin = utils.cfg._short_instrument_names_str_to_int[token[1:]]
        elif head == "v":
            # the velocity token completes an instrument/note/velocity triple
            current_velocity = utils.bin_to_velocity(int(token[1:], base=16))
            channel = utils.cfg.bin_channel_map[utils.cfg.bin_instrument_names[state.current_bin]]
            ticks = take_pending_ticks()
            if current_velocity > 0:
                yield mido.Message("note_on", note=state.current_note, velocity=current_velocity, time=ticks, channel=channel), state
            else:
                yield mido.Message("note_off", note=state.current_note, velocity=0, time=ticks, channel=channel), state
    elif head == "t" and token[1].isdigit():  # wait token
        advance_clock(utils.wait_token_to_delta(token))
        if utils.cfg.decode_end_held_note_delay != 0.0:
            held_limit_ms = utils.cfg.decode_end_held_note_delay * 1000.0
            # remove notes that have been held for too long
            for (channel, note), start_time in list(state.active_notes.items()):
                if state.total_time - start_time > held_limit_ms:
                    ticks = take_pending_ticks()
                    del state.active_notes[(channel, note)]
                    yield mido.Message("note_off", note=note, time=ticks, channel=channel), state
                    # stop after the first released note
                    return
    else:  # note token
        bin_index, note, velocity = utils.note_token_to_data(token)
        channel = utils.cfg.bin_channel_map[utils.cfg.bin_instrument_names[bin_index]]
        ticks = take_pending_ticks()
        key = (channel, note)
        if velocity > 0:
            if utils.cfg.decode_fix_repeated_notes and key in state.active_notes:
                # release the previous instance before re-triggering
                del state.active_notes[key]
                yield mido.Message("note_off", note=note, time=ticks, channel=channel), state
                ticks = 0
            state.active_notes[key] = state.total_time
            yield mido.Message("note_on", note=note, velocity=velocity, time=ticks, channel=channel), state
            return
        if key in state.active_notes:
            del state.active_notes[key]
        yield mido.Message("note_off", note=note, time=ticks, channel=channel), state
        return

    yield None, state
def str_to_midi_messages(utils: VocabUtils, data: str) -> Iterator[mido.Message]:
    """Yield the MIDI messages encoded by a space-separated token string.

    The decoding state returned with each step is fed back into the next
    call of ``token_to_midi_message``; steps without a message are skipped.
    """
    state = None
    for token in data.split(" "):
        for message, state in token_to_midi_message(utils, token, state):
            if message is None:
                continue
            yield message
def convert_str_to_midi(cfg: VocabConfig, data: str, meta_text: str = "Generated by MIDI-LLM-tokenizer") -> mido.MidiFile:
    """Build a single-track MIDI file from a token string.

    The track holds, in order: an optional text meta message, a set_tempo
    message (500000), the program changes for every channel, the decoded
    note messages, and an end_of_track marker.

    Args:
        cfg: vocabulary configuration.
        data: space-separated tokens; special ``<...>`` tokens may be left in.
        meta_text: text meta message to embed; skipped when empty.
    """
    utils = VocabUtils(cfg)
    tempo = 500000

    track = mido.MidiTrack()
    mid = mido.MidiFile()
    mid.tracks.append(track)

    if meta_text:
        track.append(mido.MetaMessage("text", text=meta_text, time=0))
    track.append(mido.MetaMessage("set_tempo", tempo=tempo, time=0))
    track.extend(generate_program_change_messages(cfg))
    # <start>/<end>/<pad> need no stripping here; the token decoder handles them.
    track.extend(str_to_midi_messages(utils, data))
    track.append(mido.MetaMessage("end_of_track", time=0))
    return mid