sander-wood committed
Commit 3c428bc · verified · 1 Parent(s): 5c1bee3

Upload 32 files
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
  overview.jpg filter=lfs diff=lfs merge=lfs -text
+ benchmarks.z01 filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Sander Wood
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
benchmarks.z01 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d88d8f54b46b5ea5ce2ba318eaf7393ef6aa63a1a7fa5247b4c4cdb9b917f6b0
+ size 20971520
benchmarks.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6497fc7f335e8f0a5bbddd474d040a218ef13488e280f5ad0b32b51ca036e89f
+ size 14043301
code/README.md ADDED
@@ -0,0 +1,90 @@
+ # CLaMP 2 Codebase
+
+ ## Overview
+ CLaMP 2 is a state-of-the-art multimodal music information retrieval system that supports 101 languages. This codebase includes scripts for training models, extracting features, and utility functions for processing music and text data. Below is a description of the scripts contained in the `code/` folder.
+
+ ## Repository Structure
+ The `code/` folder contains the following scripts:
+
+ ### 1. `config.py`
+ This script contains the training hyperparameters and file paths used in the `train_clamp2.py` and `train_m3.py` scripts. You can modify parameters such as learning rates, batch sizes, and file locations for training data.
+
+ ### 2. `extract_clamp2.py`
+ This script utilizes the pre-trained CLaMP 2 model to extract representations of text (.txt) or music (.abc or .mtf) files from a specified input folder and save the features to a target output folder in `.npy` format. The extracted features can either be normalized for semantic search or retain temporal information for classification tasks.
+
+ **Usage:**
+ ```bash
+ python extract_clamp2.py <input_dir> <output_dir> [--normalize]
+ ```
+ - `input_dir`: Directory containing input data files.
+ - `output_dir`: Directory to save the output features.
+ - `--normalize`: (Optional) Normalize the extracted features. Normalization is not required for music classification tasks, but it is required for semantic search tasks (see the sketch below).
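+
+ For example, assuming `--normalize` yields unit-length feature vectors, cosine similarity for semantic search reduces to a plain dot product. A minimal sketch (file names are placeholders):
+
+ ```python
+ import numpy as np
+
+ # Each saved .npy file holds one feature vector with a leading batch axis: shape (1, 768)
+ query = np.load("query_features/query.npy")[0]
+ keys = [np.load(f"music_features/piece_{i}.npy")[0] for i in range(3)]
+
+ # With normalized features, the dot product is the cosine similarity
+ scores = [float(np.dot(query, k)) for k in keys]
+ best = int(np.argmax(scores))
+ print(f"Best match: piece_{best} (score {scores[best]:.3f})")
+ ```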
+
+ ### 3. `extract_m3.py`
+ This script employs the pre-trained M3 model to extract representations of interleaved ABC notation or MIDI Text Format (MTF) files from the specified input folder, saving the features to the target folder as `.npy` files.
+
+ **Usage:**
+ ```bash
+ python extract_m3.py <input_dir> <output_dir>
+ ```
+ - `input_dir`: Directory with input files (in .abc or .mtf format).
+ - `output_dir`: Directory to save extracted features.
+
+ ### 4. `train_clamp2.py`
+ This script manages the training process for the CLaMP 2 model. It prepares training data from the path specified in the `TRAIN_JSONL` variable, which is defined in `config.py`. If `EVAL_JSONL` is provided in the configuration, it is used for validation; otherwise, 1% of the training data is reserved for validation by default.
+
+ CLaMP 2 uses the multilingual text encoder `FacebookAI/xlm-roberta-base` to process text data. Additionally, it employs the M3 model, pre-trained on both ABC and MIDI data, as the multimodal music encoder. If the pre-trained M3 weights are available and the configuration variable `CLAMP2_LOAD_M3` is set to `True`, the training script automatically loads them.
+
+ **Training Command:**
+ To start the training process, use the following command:
+
+ ```bash
+ python -m torch.distributed.launch --nproc_per_node=<number_of_GPUs> --use_env train_clamp2.py
+ ```
+
+ Replace `<number_of_GPUs>` with the number of GPUs you want to use for training.
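+
+ For instance, on a machine with 4 GPUs:
+
+ ```bash
+ python -m torch.distributed.launch --nproc_per_node=4 --use_env train_clamp2.py
+ ```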
+
+ **Input Data Format:**
+ The input training data should be in JSONL format, where each line contains a single JSON object with the following structure. Fields that do not apply should be set to `null`:
+
+ ```json
+ {
+   "title": "Song Title",
+   "composer": "Composer Name",
+   "genres": ["Genre1", "Genre2"],
+   "description": "Song description.",
+   "lyrics": "Song lyrics.",
+   "tags": ["tag1", "tag2"],
+   "ensembles": ["Ensemble Name"],
+   "instruments": ["Instrument1", "Instrument2"],
+   "summary_en": "English summary.",
+   "summary_nen": {
+     "language": "Language Name",
+     "summary": "Summary in specified language."
+   },
+   "filepaths": [
+     "path/to/abc/file.abc",
+     "path/to/mtf/file.mtf"
+   ]
+ }
+ ```
+
+ For obtaining the English and non-English summaries generated by GPT-4, refer to the `process_data/gpt4_summarize.py` script.
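+
+ A minimal sketch of writing one training record in this format (all values are placeholders; unused fields become JSON `null` via Python's `None`):
+
+ ```python
+ import json
+
+ record = {
+     "title": "Song Title", "composer": "Composer Name",
+     "genres": ["Genre1"], "description": None, "lyrics": None,
+     "tags": None, "ensembles": None, "instruments": ["Instrument1"],
+     "summary_en": "English summary.",
+     "summary_nen": {"language": "Language Name", "summary": "Summary."},
+     "filepaths": ["path/to/abc/file.abc", "path/to/mtf/file.mtf"],
+ }
+
+ # One JSON object per line, as expected by train_clamp2.py
+ with open("train.jsonl", "a", encoding="utf-8") as f:
+     f.write(json.dumps(record, ensure_ascii=False) + "\n")
+ ```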
+
+ ### 5. `train_m3.py`
+ This script is dedicated to training the M3 model on interleaved ABC and MTF files. The directories for training and optional evaluation data should be specified in the `TRAIN_FOLDERS` and `EVAL_FOLDERS` variables, respectively.
+
+ **Training Command:**
+ To start the training process for the M3 model, use the following command:
+
+ ```bash
+ python -m torch.distributed.launch --nproc_per_node=<number_of_GPUs> --use_env train_m3.py
+ ```
+
+ Replace `<number_of_GPUs>` with the number of GPUs you want to use for training.
+
+ **Data Preparation:**
+ The data should be structured in interleaved ABC (.abc) and MTF (.mtf) formats; one possible layout is sketched below. Please refer to the `process_data/` folder for instructions on how to prepare these formats.
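+
+ A hypothetical `TRAIN_FOLDERS` layout (directory and file names are placeholders; the extraction scripts walk input folders recursively for `.abc` and `.mtf` files, and a similar layout is presumably expected here):
+
+ ```
+ training_data/
+ ├── abc/
+ │   ├── piece_0001.abc
+ │   └── piece_0002.abc
+ └── mtf/
+     ├── piece_0001.mtf
+     └── piece_0002.mtf
+ ```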
+
+ ### 6. `utils.py`
+ This utility script contains the model class definitions and helper functions used for training.
code/config.py ADDED
@@ -0,0 +1,67 @@
+ EVAL_SPLIT = 0.01  # Fraction of training data used for evaluation
+ WANDB_KEY = "<your_wandb_key>"  # Set M3_WANDB_LOG/CLAMP2_WANDB_LOG = False if you have no API key for Weights and Biases logging
+
+ # -------------------- Configuration for M3 Training --------------------
+ TRAIN_FOLDERS = [
+     "<path_to_training_data>"  # Directory containing training data
+ ]
+
+ EVAL_FOLDERS = [
+     ""  # (Optional) Directory containing evaluation data
+ ]
+
+ PATCH_SIZE = 64  # Size of each patch
+ PATCH_LENGTH = 512  # Length of the patches
+ PATCH_NUM_LAYERS = 12  # Number of layers in the encoder
+ TOKEN_NUM_LAYERS = 3  # Number of layers in the decoder
+ M3_HIDDEN_SIZE = 768  # Size of the hidden layer
+
+ M3_NUM_EPOCH = 100  # Maximum number of epochs for training
+ M3_LEARNING_RATE = 1e-4  # Learning rate for the optimizer
+ M3_BATCH_SIZE = 16  # Batch size per GPU (single card) during training
+ M3_MASK_RATIO = 0.45  # Ratio of masked elements during training
+ M3_DETERMINISTIC = True  # Ensures deterministic results with random seeds
+ M3_WANDB_LOG = True  # Enable logging to Weights and Biases
+ M3_LOAD_CKPT = True  # Load model weights from a checkpoint if available
+
+ M3_WEIGHTS_PATH = (
+     "weights_m3_p_size_" + str(PATCH_SIZE) +
+     "_p_length_" + str(PATCH_LENGTH) +
+     "_t_layers_" + str(TOKEN_NUM_LAYERS) +
+     "_p_layers_" + str(PATCH_NUM_LAYERS) +
+     "_h_size_" + str(M3_HIDDEN_SIZE) +
+     "_lr_" + str(M3_LEARNING_RATE) +
+     "_batch_" + str(M3_BATCH_SIZE) +
+     "_mask_" + str(M3_MASK_RATIO) + ".pth"
+ )  # Path to store the model weights
+ M3_LOGS_PATH = M3_WEIGHTS_PATH.replace("weights", "logs").replace("pth", "txt")  # Path to save training logs
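+
+ # With the defaults above, M3_WEIGHTS_PATH resolves to
+ # "weights_m3_p_size_64_p_length_512_t_layers_3_p_layers_12_h_size_768_lr_0.0001_batch_16_mask_0.45.pth",
+ # and M3_LOGS_PATH to the corresponding "logs_...txt" file shipped in this commit.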
+
+ # -------------------- Configuration for CLaMP2 Training ----------------
+ TRAIN_JSONL = "<path_to_training_jsonl>"  # Path to the JSONL file with training data
+ EVAL_JSONL = ""  # (Optional) Path to the JSONL file with evaluation data
+
+ CLAMP2_HIDDEN_SIZE = 768  # Size of the hidden layer
+ TEXT_MODEL_NAME = "FacebookAI/xlm-roberta-base"  # Name of the pre-trained text model
+
+ CLAMP2_NUM_EPOCH = 100  # Maximum number of epochs for training
+ CLAMP2_LEARNING_RATE = 5e-5  # Learning rate for the optimizer
+ CLAMP2_BATCH_SIZE = 128  # Batch size per GPU (single card) during training
+ LOGIT_SCALE = 1  # Scaling factor for contrastive loss
+ MAX_TEXT_LENGTH = 128  # Maximum allowed length for text input
+ TEXT_DROPOUT = True  # Whether to apply dropout during text processing
+ CLAMP2_DETERMINISTIC = True  # Ensures deterministic results with random seeds
+ CLAMP2_LOAD_M3 = True  # Load weights from the M3 model
+ CLAMP2_WANDB_LOG = True  # Enable logging to Weights and Biases
+ CLAMP2_LOAD_CKPT = True  # Load weights from a checkpoint if available
+
+ CLAMP2_WEIGHTS_PATH = (
+     "weights_clamp2_h_size_" + str(CLAMP2_HIDDEN_SIZE) +
+     "_lr_" + str(CLAMP2_LEARNING_RATE) +
+     "_batch_" + str(CLAMP2_BATCH_SIZE) +
+     "_scale_" + str(LOGIT_SCALE) +
+     "_t_length_" + str(MAX_TEXT_LENGTH) +
+     "_t_model_" + TEXT_MODEL_NAME.replace("/", "_") +
+     "_t_dropout_" + str(TEXT_DROPOUT) +
+     "_m3_" + str(CLAMP2_LOAD_M3) + ".pth"
+ )  # Path to store CLaMP2 model weights
+ CLAMP2_LOGS_PATH = CLAMP2_WEIGHTS_PATH.replace("weights", "logs").replace("pth", "txt")  # Path to save training logs
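+
+ # With the defaults above, CLAMP2_WEIGHTS_PATH resolves to
+ # "weights_clamp2_h_size_768_lr_5e-05_batch_128_scale_1_t_length_128_t_model_FacebookAI_xlm-roberta-base_t_dropout_True_m3_True.pth"
+ # (str(5e-5) == "5e-05"), matching the CLaMP 2 log file shipped in this commit.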
code/extract_clamp2.py ADDED
@@ -0,0 +1,192 @@
+ import os
+ import json
+ import random
+ import torch
+ import numpy as np
+ from tqdm import tqdm
+ from config import *
+ from utils import *
+ from samplings import *
+ from accelerate import Accelerator
+ from transformers import BertConfig, AutoTokenizer
+ import argparse
+
+ # Parse command-line arguments
+ parser = argparse.ArgumentParser(description="Feature extraction for CLaMP2.")
+ parser.add_argument("input_dir", type=str, help="Directory containing input data files.")
+ parser.add_argument("output_dir", type=str, help="Directory to save the output features.")
+ parser.add_argument("--normalize", action="store_true", help="Normalize the extracted features.")
+
+ args = parser.parse_args()
+
+ # Retrieve arguments
+ input_dir = args.input_dir
+ output_dir = args.output_dir
+ normalize = args.normalize
+
+ # Remove log files left over from a previous run
+ os.makedirs("logs", exist_ok=True)
+ for file in ["logs/files_extract_clamp2.json",
+              "logs/files_shuffle_extract_clamp2.json",
+              "logs/log_extract_clamp2.txt",
+              "logs/pass_extract_clamp2.txt",
+              "logs/skip_extract_clamp2.txt"]:
+     if os.path.exists(file):
+         os.remove(file)
+
+ # Collect input files and save the (shuffled) file list
+ files = []
+ for root, dirs, fs in os.walk(input_dir):
+     for f in fs:
+         if f.endswith(".txt") or f.endswith(".abc") or f.endswith(".mtf"):
+             files.append(os.path.join(root, f))
+ print(f"Found {len(files)} files in total")
+ with open("logs/files_extract_clamp2.json", "w", encoding="utf-8") as f:
+     json.dump(files, f)
+ random.shuffle(files)
+ with open("logs/files_shuffle_extract_clamp2.json", "w", encoding="utf-8") as f:
+     json.dump(files, f)
+
+ accelerator = Accelerator()
+ device = accelerator.device
+ print("Using device:", device)
+ with open("logs/log_extract_clamp2.txt", "a", encoding="utf-8") as f:
+     f.write("Using device: " + str(device) + "\n")
+
+ m3_config = BertConfig(vocab_size=1,
+                        hidden_size=M3_HIDDEN_SIZE,
+                        num_hidden_layers=PATCH_NUM_LAYERS,
+                        num_attention_heads=M3_HIDDEN_SIZE//64,
+                        intermediate_size=M3_HIDDEN_SIZE*4,
+                        max_position_embeddings=PATCH_LENGTH)
+ model = CLaMP2Model(m3_config,
+                     text_model_name=TEXT_MODEL_NAME,
+                     hidden_size=CLAMP2_HIDDEN_SIZE,
+                     load_m3=CLAMP2_LOAD_M3)
+ model = model.to(device)
+ tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
+ patchilizer = M3Patchilizer()
+
+ # Print parameter count
+ print("Parameter Number: " + str(sum(p.numel() for p in model.parameters() if p.requires_grad)))
+
+ model.eval()
+ checkpoint = torch.load(CLAMP2_WEIGHTS_PATH, map_location='cpu', weights_only=True)
+ print(f"Successfully Loaded CLaMP 2 Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}")
+ model.load_state_dict(checkpoint['model'])
+
+ def extract_feature(filename, get_normalized=normalize):
+     with open(filename, "r", encoding="utf-8") as f:
+         item = f.read()
+
+     if filename.endswith(".txt"):
+         # Deduplicate non-empty lines and re-join them with the tokenizer's SEP token
+         item = list(set(item.split("\n")))
+         item = "\n".join(item)
+         item = item.split("\n")
+         item = [c for c in item if len(c) > 0]
+         item = tokenizer.sep_token.join(item)
+         input_data = tokenizer(item, return_tensors="pt")
+         input_data = input_data['input_ids'].squeeze(0)
+         max_input_length = MAX_TEXT_LENGTH
+     else:
+         input_data = patchilizer.encode(item, add_special_patches=True)
+         input_data = torch.tensor(input_data)
+         max_input_length = PATCH_LENGTH
+
+     # Split the input into max_input_length segments; the last segment is
+     # re-anchored to the tail so that it is always full-length
+     segment_list = []
+     for i in range(0, len(input_data), max_input_length):
+         segment_list.append(input_data[i:i+max_input_length])
+     segment_list[-1] = input_data[-max_input_length:]
+
+     last_hidden_states_list = []
+
+     for input_segment in segment_list:
+         input_masks = torch.tensor([1]*input_segment.size(0))
+         if filename.endswith(".txt"):
+             pad_indices = torch.ones(MAX_TEXT_LENGTH - input_segment.size(0)).long() * tokenizer.pad_token_id
+         else:
+             pad_indices = torch.ones((PATCH_LENGTH - input_segment.size(0), PATCH_SIZE)).long() * patchilizer.pad_token_id
+         input_masks = torch.cat((input_masks, torch.zeros(max_input_length - input_segment.size(0))), 0)
+         input_segment = torch.cat((input_segment, pad_indices), 0)
+
+         if filename.endswith(".txt"):
+             last_hidden_states = model.get_text_features(text_inputs=input_segment.unsqueeze(0).to(device),
+                                                          text_masks=input_masks.unsqueeze(0).to(device),
+                                                          get_normalized=get_normalized)
+         else:
+             last_hidden_states = model.get_music_features(music_inputs=input_segment.unsqueeze(0).to(device),
+                                                           music_masks=input_masks.unsqueeze(0).to(device),
+                                                           get_normalized=get_normalized)
+         if not get_normalized:
+             last_hidden_states = last_hidden_states[:, :input_masks.sum().long().item(), :]
+         last_hidden_states_list.append(last_hidden_states)
+
+     if not get_normalized:
+         # Keep per-token states: trim the overlap of the tail segment, then concatenate
+         last_hidden_states_list = [last_hidden_states[0] for last_hidden_states in last_hidden_states_list]
+         last_hidden_states_list[-1] = last_hidden_states_list[-1][-(len(input_data)%max_input_length):]
+         last_hidden_states_list = torch.concat(last_hidden_states_list, 0)
+     else:
+         # Average the per-segment global features, weighted by segment length
+         full_chunk_cnt = len(input_data) // max_input_length
+         remain_chunk_len = len(input_data) % max_input_length
+         if remain_chunk_len == 0:
+             feature_weights = torch.tensor([max_input_length] * full_chunk_cnt, device=device).view(-1, 1)
+         else:
+             feature_weights = torch.tensor([max_input_length] * full_chunk_cnt + [remain_chunk_len], device=device).view(-1, 1)
+
+         last_hidden_states_list = torch.concat(last_hidden_states_list, 0)
+         last_hidden_states_list = last_hidden_states_list * feature_weights
+         last_hidden_states_list = last_hidden_states_list.sum(dim=0) / feature_weights.sum()
+
+     return last_hidden_states_list
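+
+ # Shape note (assuming get_text_features/get_music_features return one global
+ # vector per segment when normalized): with get_normalized=True this returns a
+ # pooled vector of shape (CLAMP2_HIDDEN_SIZE,); with get_normalized=False it
+ # returns per-token states of shape (len(input_data), CLAMP2_HIDDEN_SIZE).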
+
+ def process_directory(input_dir, output_dir, files):
+     print(f"Found {len(files)} files in total")
+     with open("logs/log_extract_clamp2.txt", "a", encoding="utf-8") as f:
+         f.write("Found " + str(len(files)) + " files in total\n")
+
+     # calculate the number of files to process per GPU
+     num_files_per_gpu = len(files) // accelerator.num_processes
+
+     # calculate the start and end index for the current GPU
+     start_idx = accelerator.process_index * num_files_per_gpu
+     end_idx = start_idx + num_files_per_gpu
+     if accelerator.process_index == accelerator.num_processes - 1:
+         end_idx = len(files)
+
+     files_to_process = files[start_idx:end_idx]
+
+     # process the files
+     for file in tqdm(files_to_process):
+         output_subdir = output_dir + os.path.dirname(file)[len(input_dir):]
+         try:
+             os.makedirs(output_subdir, exist_ok=True)
+         except Exception as e:
+             print(output_subdir + " cannot be created\n" + str(e))
+             with open("logs/log_extract_clamp2.txt", "a") as f:
+                 f.write(output_subdir + " cannot be created\n" + str(e) + "\n")
+
+         output_file = os.path.join(output_subdir, os.path.splitext(os.path.basename(file))[0] + ".npy")
+
+         if os.path.exists(output_file):
+             print(f"Skipping {file}, output already exists")
+             with open("logs/skip_extract_clamp2.txt", "a", encoding="utf-8") as f:
+                 f.write(file + "\n")
+             continue
+
+         try:
+             with torch.no_grad():
+                 features = extract_feature(file).unsqueeze(0)
+             np.save(output_file, features.detach().cpu().numpy())
+             with open("logs/pass_extract_clamp2.txt", "a", encoding="utf-8") as f:
+                 f.write(file + "\n")
+         except Exception as e:
+             print(f"Failed to process {file}: {e}")
+             with open("logs/log_extract_clamp2.txt", "a", encoding="utf-8") as f:
+                 f.write("Failed to process " + file + ": " + str(e) + "\n")
+
+ with open("logs/files_shuffle_extract_clamp2.json", "r", encoding="utf-8") as f:
+     files = json.load(f)
+
+ # process the files
+ process_directory(input_dir, output_dir, files)
+
+ with open("logs/log_extract_clamp2.txt", "a", encoding="utf-8") as f:
+     f.write("GPU ID: " + str(device) + "\n")
code/extract_m3.py ADDED
@@ -0,0 +1,162 @@
+ import os
+ import json
+ import random
+ import torch
+ import numpy as np
+ from tqdm import tqdm
+ from config import *
+ from utils import *
+ from samplings import *
+ from accelerate import Accelerator
+ from transformers import BertConfig, GPT2Config
+ import argparse
+
+ # Parse command-line arguments for input_dir and output_dir
+ parser = argparse.ArgumentParser(description="Process files to extract features.")
+ parser.add_argument("input_dir", type=str, help="Directory with input files")
+ parser.add_argument("output_dir", type=str, help="Directory to save extracted features")
+ args = parser.parse_args()
+
+ # Use args for input and output directories
+ input_dir = args.input_dir
+ output_dir = args.output_dir
+
+ # Create logs directory if it doesn't exist
+ os.makedirs("logs", exist_ok=True)
+
+ # Remove existing log files if present
+ for file in [
+     "logs/files_extract_m3.json",
+     "logs/files_shuffle_extract_m3.json",
+     "logs/log_extract_m3.txt",
+     "logs/pass_extract_m3.txt",
+     "logs/skip_extract_m3.txt",
+ ]:
+     if os.path.exists(file):
+         os.remove(file)
+
+ # Collect input files
+ files = []
+ for root, dirs, fs in os.walk(input_dir):
+     for f in fs:
+         if f.endswith(".abc") or f.endswith(".mtf"):
+             files.append(os.path.join(root, f))
+
+ print(f"Found {len(files)} files in total")
+ with open("logs/files_extract_m3.json", "w", encoding="utf-8") as f:
+     json.dump(files, f)
+
+ # Shuffle files and save the shuffled order
+ random.shuffle(files)
+ with open("logs/files_shuffle_extract_m3.json", "w", encoding="utf-8") as f:
+     json.dump(files, f)
+
+ # Initialize accelerator and device
+ accelerator = Accelerator()
+ device = accelerator.device
+ print("Using device:", device)
+ with open("logs/log_extract_m3.txt", "a", encoding="utf-8") as f:
+     f.write("Using device: " + str(device) + "\n")
+
+ # Model and configuration setup
+ patchilizer = M3Patchilizer()
+ encoder_config = BertConfig(
+     vocab_size=1,
+     hidden_size=M3_HIDDEN_SIZE,
+     num_hidden_layers=PATCH_NUM_LAYERS,
+     num_attention_heads=M3_HIDDEN_SIZE // 64,
+     intermediate_size=M3_HIDDEN_SIZE * 4,
+     max_position_embeddings=PATCH_LENGTH,
+ )
+ decoder_config = GPT2Config(
+     vocab_size=128,
+     n_positions=PATCH_SIZE,
+     n_embd=M3_HIDDEN_SIZE,
+     n_layer=TOKEN_NUM_LAYERS,
+     n_head=M3_HIDDEN_SIZE // 64,
+     n_inner=M3_HIDDEN_SIZE * 4,
+ )
+ model = M3Model(encoder_config, decoder_config).to(device)
+
+ # Print parameter count
+ print("Parameter Number: " + str(sum(p.numel() for p in model.parameters() if p.requires_grad)))
+
+ # Load model weights
+ model.eval()
+ checkpoint = torch.load(M3_WEIGHTS_PATH, map_location='cpu', weights_only=True)
+ print(f"Successfully Loaded Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}")
+ model.load_state_dict(checkpoint['model'])
+
+ def extract_feature(item):
+     """Extracts features from input data."""
+     target_patches = patchilizer.encode(item, add_special_patches=True)
+     # Split into PATCH_LENGTH-sized segments; the last segment is re-anchored
+     # to the tail so that it is always full-length
+     target_patches_list = [target_patches[i:i + PATCH_LENGTH] for i in range(0, len(target_patches), PATCH_LENGTH)]
+     target_patches_list[-1] = target_patches[-PATCH_LENGTH:]
+
+     last_hidden_states_list = []
+     for input_patches in target_patches_list:
+         input_masks = torch.tensor([1] * len(input_patches))
+         input_patches = torch.tensor(input_patches)
+         last_hidden_states = model.encoder(
+             input_patches.unsqueeze(0).to(device), input_masks.unsqueeze(0).to(device)
+         )["last_hidden_state"][0]
+         last_hidden_states_list.append(last_hidden_states)
+
+     # Trim the overlap of the tail segment, then concatenate
+     last_hidden_states_list[-1] = last_hidden_states_list[-1][-(len(target_patches) % PATCH_LENGTH):]
+     return torch.concat(last_hidden_states_list, 0)
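+
+ # Shape note: extract_feature returns per-patch encoder states of shape
+ # (len(target_patches), M3_HIDDEN_SIZE); process_directory below saves them
+ # with a leading batch dimension of 1.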
+
+ def process_directory(input_dir, output_dir, files):
+     """Processes files in the input directory and saves features to the output directory."""
+     print(f"Found {len(files)} files in total")
+     with open("logs/log_extract_m3.txt", "a", encoding="utf-8") as f:
+         f.write("Found " + str(len(files)) + " files in total\n")
+
+     # Distribute files across processes for parallel processing
+     num_files_per_gpu = len(files) // accelerator.num_processes
+     start_idx = accelerator.process_index * num_files_per_gpu
+     end_idx = start_idx + num_files_per_gpu if accelerator.process_index < accelerator.num_processes - 1 else len(files)
+     files_to_process = files[start_idx:end_idx]
+
+     # Process each file
+     for file in tqdm(files_to_process):
+         output_subdir = output_dir + os.path.dirname(file)[len(input_dir):]
+         try:
+             os.makedirs(output_subdir, exist_ok=True)
+         except Exception as e:
+             print(f"{output_subdir} cannot be created\n{e}")
+             with open("logs/log_extract_m3.txt", "a") as f:
+                 f.write(f"{output_subdir} cannot be created\n{e}\n")
+
+         output_file = os.path.join(output_subdir, os.path.splitext(os.path.basename(file))[0] + ".npy")
+
+         if os.path.exists(output_file):
+             print(f"Skipping {file}, output already exists")
+             with open("logs/skip_extract_m3.txt", "a", encoding="utf-8") as f:
+                 f.write(file + "\n")
+             continue
+
+         try:
+             with open(file, "r", encoding="utf-8") as f:
+                 item = f.read()
+                 if not item.startswith("ticks_per_beat"):
+                     item = item.replace("L:1/8\n", "")
+             with torch.no_grad():
+                 features = extract_feature(item).unsqueeze(0)
+             np.save(output_file, features.detach().cpu().numpy())
+             with open("logs/pass_extract_m3.txt", "a", encoding="utf-8") as f:
+                 f.write(file + "\n")
+         except Exception as e:
+             print(f"Failed to process {file}: {e}")
+             with open("logs/log_extract_m3.txt", "a", encoding="utf-8") as f:
+                 f.write(f"Failed to process {file}: {e}\n")
+
+ # Load shuffled files list and start processing
+ with open("logs/files_shuffle_extract_m3.json", "r", encoding="utf-8") as f:
+     files = json.load(f)
+
+ # Process the directory
+ process_directory(input_dir, output_dir, files)
+
+ with open("logs/log_extract_m3.txt", "a", encoding="utf-8") as f:
+     f.write("GPU ID: " + str(device) + "\n")
code/logs_clamp2_h_size_768_lr_5e-05_batch_128_scale_1_t_length_128_t_model_FacebookAI_xlm-roberta-base_t_dropout_True_m3_True.txt ADDED
@@ -0,0 +1,500 @@
+ Epoch 1
+ train_loss: 4.135209383230448
+ eval_loss: 1.9609466393788655
+ time: Sun Sep 15 16:53:25 2024
+
+ Epoch 2
+ train_loss: 2.9014642586344244
+ eval_loss: 1.518966309229533
+ time: Sun Sep 15 18:43:11 2024
+
+ Epoch 3
+ train_loss: 2.5554833216573805
+ eval_loss: 1.2929850498835245
+ time: Sun Sep 15 20:35:14 2024
+
+ Epoch 4
+ train_loss: 2.3372960954320985
+ eval_loss: 1.1918509085973104
+ time: Sun Sep 15 22:23:17 2024
+
+ Epoch 5
+ train_loss: 2.183457492896627
+ eval_loss: 1.0725035190582275
+ time: Mon Sep 16 00:11:33 2024
+
+ Epoch 6
+ train_loss: 2.0555087922796216
+ eval_loss: 0.9798878033955892
+ time: Mon Sep 16 01:59:41 2024
+
+ Epoch 7
+ train_loss: 1.9556777615308922
+ eval_loss: 0.9242686867713928
+ time: Mon Sep 16 03:47:45 2024
+
+ Epoch 8
+ train_loss: 1.8689270445336208
+ eval_loss: 0.8476833502451578
+ time: Mon Sep 16 05:35:55 2024
+
+ Epoch 9
+ train_loss: 1.7932923043361795
+ eval_loss: 0.8043208519617716
+ time: Mon Sep 16 07:24:40 2024
+
+ Epoch 10
+ train_loss: 1.726651672090251
+ eval_loss: 0.7419429302215577
+ time: Mon Sep 16 09:12:24 2024
+
+ Epoch 11
+ train_loss: 1.6657210920084013
+ eval_loss: 0.7209376732508341
+ time: Mon Sep 16 11:01:02 2024
+
+ Epoch 12
+ train_loss: 1.6131336856822078
+ eval_loss: 0.6950737277666728
+ time: Mon Sep 16 12:50:07 2024
+
+ Epoch 13
+ train_loss: 1.5601606700647368
+ eval_loss: 0.652581254641215
+ time: Mon Sep 16 14:40:19 2024
+
+ Epoch 14
+ train_loss: 1.5198849670231784
+ eval_loss: 0.5868058403333029
+ time: Mon Sep 16 16:30:09 2024
+
+ Epoch 15
+ train_loss: 1.4755114693644882
+ eval_loss: 0.5867449243863424
+ time: Mon Sep 16 18:20:11 2024
+
+ Epoch 16
+ train_loss: 1.4336211351749464
+ eval_loss: 0.5479268968105316
+ time: Mon Sep 16 20:10:42 2024
+
+ Epoch 17
+ train_loss: 1.4039166571722088
+ eval_loss: 0.5280438164869944
+ time: Mon Sep 16 22:01:01 2024
+
+ Epoch 18
+ train_loss: 1.365380759842085
+ eval_loss: 0.5008598109086354
+ time: Mon Sep 16 23:51:36 2024
+
+ Epoch 19
+ train_loss: 1.332988672848894
+ eval_loss: 0.46479900081952413
+ time: Tue Sep 17 01:41:49 2024
+
+ Epoch 20
+ train_loss: 1.3014001791981087
+ eval_loss: 0.45230263272921245
+ time: Tue Sep 17 03:33:12 2024
+
+ Epoch 21
+ train_loss: 1.2688752540577755
+ eval_loss: 0.4297992348670959
+ time: Tue Sep 17 05:23:40 2024
+
+ Epoch 22
+ train_loss: 1.2425967695381415
+ eval_loss: 0.4219102164109548
+ time: Tue Sep 17 07:14:26 2024
+
+ Epoch 23
+ train_loss: 1.216824040410488
+ eval_loss: 0.40282649795214337
+ time: Tue Sep 17 09:05:53 2024
+
+ Epoch 24
+ train_loss: 1.1875996505747286
+ eval_loss: 0.36659018794695536
+ time: Tue Sep 17 10:56:46 2024
+
+ Epoch 25
+ train_loss: 1.1670776548255222
+ eval_loss: 0.36906688412030536
+ time: Tue Sep 17 12:47:39 2024
+
+ Epoch 26
+ train_loss: 1.1426405137974536
+ eval_loss: 0.3478178918361664
+ time: Tue Sep 17 14:38:34 2024
+
+ Epoch 27
+ train_loss: 1.1208335824466733
+ eval_loss: 0.33407697081565857
+ time: Tue Sep 17 16:28:45 2024
+
+ Epoch 28
+ train_loss: 1.0998876758880667
+ eval_loss: 0.33792892495791116
+ time: Tue Sep 17 18:19:59 2024
+
+ Epoch 29
+ train_loss: 1.0769698478377083
+ eval_loss: 0.3026650925477346
+ time: Tue Sep 17 20:11:16 2024
+
+ Epoch 30
+ train_loss: 1.0587592209657248
+ eval_loss: 0.2914476583401362
+ time: Tue Sep 17 22:03:30 2024
+
+ Epoch 31
+ train_loss: 1.0384011404245468
+ eval_loss: 0.27578969597816466
+ time: Tue Sep 17 23:55:15 2024
+
+ Epoch 32
+ train_loss: 1.0233595809527622
+ eval_loss: 0.2651842157046
+ time: Wed Sep 18 01:46:45 2024
+
+ Epoch 33
+ train_loss: 1.001824217418977
+ eval_loss: 0.2630385269721349
+ time: Wed Sep 18 03:39:08 2024
+
+ Epoch 34
+ train_loss: 0.9853754720520442
+ eval_loss: 0.25253995358943937
+ time: Wed Sep 18 05:30:33 2024
+
+ Epoch 35
+ train_loss: 0.9676362536067821
+ eval_loss: 0.24096360007921855
+ time: Wed Sep 18 07:22:09 2024
+
+ Epoch 36
+ train_loss: 0.9507065269691086
+ eval_loss: 0.2413844664891561
+ time: Wed Sep 18 09:12:59 2024
+
+ Epoch 37
+ train_loss: 0.9362979678186832
+ eval_loss: 0.23412639300028484
+ time: Wed Sep 18 11:04:09 2024
+
+ Epoch 38
+ train_loss: 0.9174621180856977
+ eval_loss: 0.21386308073997498
+ time: Wed Sep 18 12:54:52 2024
+
+ Epoch 39
+ train_loss: 0.9090870427650668
+ eval_loss: 0.19962686796983084
+ time: Wed Sep 18 14:45:52 2024
+
+ Epoch 40
+ train_loss: 0.8918763521271409
+ eval_loss: 0.20026112000147503
+ time: Wed Sep 18 16:36:37 2024
+
+ Epoch 41
+ train_loss: 0.8786202421428222
+ eval_loss: 0.18366556564966838
+ time: Wed Sep 18 18:27:31 2024
+
+ Epoch 42
+ train_loss: 0.8670675420604148
+ eval_loss: 0.17908457616964976
+ time: Wed Sep 18 20:18:16 2024
+
+ Epoch 43
+ train_loss: 0.8505593872931582
+ eval_loss: 0.17053016225496928
+ time: Wed Sep 18 22:10:39 2024
+
+ Epoch 44
+ train_loss: 0.8421949260766888
+ eval_loss: 0.17344878117243448
+ time: Thu Sep 19 00:02:24 2024
+
+ Epoch 45
+ train_loss: 0.8267569324702205
+ eval_loss: 0.1591893643140793
+ time: Thu Sep 19 01:53:48 2024
+
+ Epoch 46
+ train_loss: 0.8144617894466949
+ eval_loss: 0.15313500861326854
+ time: Thu Sep 19 03:44:58 2024
+
+ Epoch 47
+ train_loss: 0.8041844731303666
+ eval_loss: 0.14998503575722377
+ time: Thu Sep 19 05:36:50 2024
+
+ Epoch 48
+ train_loss: 0.7938160687423412
+ eval_loss: 0.1401842971642812
+ time: Thu Sep 19 07:28:21 2024
+
+ Epoch 49
+ train_loss: 0.7808867423096515
+ eval_loss: 0.1368137091398239
+ time: Thu Sep 19 09:20:09 2024
+
+ Epoch 50
+ train_loss: 0.7702171771933628
+ eval_loss: 0.13333487262328467
+ time: Thu Sep 19 11:12:37 2024
+
+ Epoch 51
+ train_loss: 0.7604444062967384
+ eval_loss: 0.13119754443566004
+ time: Thu Sep 19 13:04:26 2024
+
+ Epoch 52
+ train_loss: 0.7496546459894258
+ eval_loss: 0.1236343190073967
+ time: Thu Sep 19 14:55:53 2024
+
+ Epoch 53
+ train_loss: 0.7406523988345118
+ eval_loss: 0.12237562835216523
+ time: Thu Sep 19 16:47:51 2024
+
+ Epoch 54
+ train_loss: 0.7331518270251398
+ eval_loss: 0.11441469887892405
+ time: Thu Sep 19 18:38:48 2024
+
+ Epoch 55
+ train_loss: 0.7238280263746373
+ eval_loss: 0.10651812156041464
+ time: Thu Sep 19 20:29:18 2024
+
+ Epoch 56
+ train_loss: 0.7141688125488486
+ eval_loss: 0.10959143290917078
+ time: Thu Sep 19 22:19:28 2024
+
+ Epoch 57
+ train_loss: 0.7053173944645842
+ eval_loss: 0.10957898745934168
+ time: Fri Sep 20 00:10:06 2024
+
+ Epoch 58
+ train_loss: 0.6992166797548109
+ eval_loss: 0.09759224901596705
+ time: Fri Sep 20 02:01:02 2024
+
+ Epoch 59
+ train_loss: 0.6855367768623795
+ eval_loss: 0.10631066560745239
+ time: Fri Sep 20 03:51:25 2024
+
+ Epoch 60
+ train_loss: 0.6812366953699432
+ eval_loss: 0.08681503732999166
+ time: Fri Sep 20 05:41:32 2024
+
+ Epoch 61
+ train_loss: 0.6744320154854127
+ eval_loss: 0.08995070978999138
+ time: Fri Sep 20 07:32:33 2024
+
+ Epoch 62
+ train_loss: 0.6627048003782218
+ eval_loss: 0.08492780551314354
+ time: Fri Sep 20 09:22:52 2024
+
+ Epoch 63
+ train_loss: 0.6554694614403961
+ eval_loss: 0.09110054125388463
+ time: Fri Sep 20 11:15:14 2024
+
+ Epoch 64
+ train_loss: 0.6519363358224428
+ eval_loss: 0.08603844990332922
+ time: Fri Sep 20 13:05:45 2024
+
+ Epoch 65
+ train_loss: 0.6432196787488694
+ eval_loss: 0.07920929342508316
+ time: Fri Sep 20 14:56:27 2024
+
+ Epoch 66
+ train_loss: 0.6355774498505016
+ eval_loss: 0.08108622878789902
+ time: Fri Sep 20 16:47:00 2024
+
+ Epoch 67
+ train_loss: 0.628098195042665
+ eval_loss: 0.0835166151324908
+ time: Fri Sep 20 18:37:19 2024
+
+ Epoch 68
+ train_loss: 0.6229319736150211
+ eval_loss: 0.08126899500687917
+ time: Fri Sep 20 20:27:49 2024
+
+ Epoch 69
+ train_loss: 0.6162204064685376
+ eval_loss: 0.07405624414483707
+ time: Fri Sep 20 22:18:28 2024
+
+ Epoch 70
+ train_loss: 0.6093617768645045
+ eval_loss: 0.07916868552565574
+ time: Sat Sep 21 00:10:02 2024
+
+ Epoch 71
+ train_loss: 0.603765148576412
+ eval_loss: 0.07368899683157602
+ time: Sat Sep 21 02:00:29 2024
+
+ Epoch 72
+ train_loss: 0.5988557130088281
+ eval_loss: 0.06763924509286881
+ time: Sat Sep 21 03:51:46 2024
+
+ Epoch 73
+ train_loss: 0.590835969827209
+ eval_loss: 0.07139033873875936
+ time: Sat Sep 21 05:43:51 2024
+
+ Epoch 74
+ train_loss: 0.5864904869113879
+ eval_loss: 0.06859012718002001
+ time: Sat Sep 21 07:34:23 2024
+
+ Epoch 75
+ train_loss: 0.5819329118342274
+ eval_loss: 0.07611284777522087
+ time: Sat Sep 21 09:25:24 2024
+
+ Epoch 76
+ train_loss: 0.5750655913014898
+ eval_loss: 0.06813529431819916
+ time: Sat Sep 21 11:16:26 2024
+
+ Epoch 77
+ train_loss: 0.5703848759963817
+ eval_loss: 0.07192744488517443
+ time: Sat Sep 21 13:07:32 2024
+
+ Epoch 78
+ train_loss: 0.5666614368024667
+ eval_loss: 0.06931692684690158
+ time: Sat Sep 21 14:59:16 2024
+
+ Epoch 79
+ train_loss: 0.5610024514409998
+ eval_loss: 0.06487631574273109
+ time: Sat Sep 21 16:50:56 2024
+
+ Epoch 80
+ train_loss: 0.5552226794301296
+ eval_loss: 0.06034566586216291
+ time: Sat Sep 21 18:43:49 2024
+
+ Epoch 81
+ train_loss: 0.5512203840912394
+ eval_loss: 0.05962909683585167
+ time: Sat Sep 21 20:36:01 2024
+
+ Epoch 82
+ train_loss: 0.5477618443893468
+ eval_loss: 0.05546447386344274
+ time: Sat Sep 21 22:28:13 2024
+
+ Epoch 83
+ train_loss: 0.5428704522615506
+ eval_loss: 0.05013169844945272
+ time: Sun Sep 22 00:21:20 2024
+
+ Epoch 84
+ train_loss: 0.5396500316264258
+ eval_loss: 0.062498694161574046
+ time: Sun Sep 22 02:13:07 2024
+
+ Epoch 85
+ train_loss: 0.5349479554715307
+ eval_loss: 0.06073434228698413
+ time: Sun Sep 22 04:05:17 2024
+
+ Epoch 86
+ train_loss: 0.5292192482811466
+ eval_loss: 0.05734321524699529
+ time: Sun Sep 22 05:57:05 2024
+
+ Epoch 87
+ train_loss: 0.5249555090607058
+ eval_loss: 0.05274935985604922
+ time: Sun Sep 22 07:48:52 2024
+
+ Epoch 88
+ train_loss: 0.523276918144503
+ eval_loss: 0.05601314604282379
+ time: Sun Sep 22 09:41:05 2024
+
+ Epoch 89
+ train_loss: 0.5179934711230115
+ eval_loss: 0.057493301729361214
+ time: Sun Sep 22 11:33:47 2024
+
+ Epoch 90
+ train_loss: 0.5129834874146376
+ eval_loss: 0.05289425750573476
+ time: Sun Sep 22 13:25:54 2024
+
+ Epoch 91
+ train_loss: 0.5104886514866054
+ eval_loss: 0.0586332509915034
+ time: Sun Sep 22 15:18:13 2024
+
+ Epoch 92
+ train_loss: 0.5067275374282622
+ eval_loss: 0.0489634457975626
+ time: Sun Sep 22 17:10:39 2024
+
+ Epoch 93
+ train_loss: 0.5038576471461468
+ eval_loss: 0.05257208868861198
+ time: Sun Sep 22 19:04:46 2024
+
+ Epoch 94
+ train_loss: 0.5013840998762528
+ eval_loss: 0.05249967947602272
+ time: Sun Sep 22 20:57:55 2024
+
+ Epoch 95
+ train_loss: 0.4949465335763684
+ eval_loss: 0.048154672731955846
+ time: Sun Sep 22 22:50:30 2024
+
+ Epoch 96
+ train_loss: 0.4925781255166608
+ eval_loss: 0.052830965568621956
+ time: Mon Sep 23 00:43:13 2024
+
+ Epoch 97
+ train_loss: 0.4875780233282
+ eval_loss: 0.04684837857882182
+ time: Mon Sep 23 02:35:38 2024
+
+ Epoch 98
+ train_loss: 0.4858591078021573
+ eval_loss: 0.04507673804958661
+ time: Mon Sep 23 04:28:25 2024
+
+ Epoch 99
+ train_loss: 0.4804891498405977
+ eval_loss: 0.048148307204246524
+ time: Mon Sep 23 06:21:11 2024
+
+ Epoch 100
+ train_loss: 0.4782898508661265
+ eval_loss: 0.044317328557372096
+ time: Mon Sep 23 08:13:38 2024
+
code/logs_m3_p_size_64_p_length_512_t_layers_3_p_layers_12_h_size_768_lr_0.0001_batch_16_mask_0.45.txt ADDED
@@ -0,0 +1,500 @@
+ Epoch 1
+ train_loss: 0.3055062843047765
+ eval_loss: 0.03727418192925584
+ time: Wed Aug 7 08:54:27 2024
+
+ Epoch 2
+ train_loss: 0.1718286834194018
+ eval_loss: 0.020085958587123625
+ time: Wed Aug 7 14:38:40 2024
+
+ Epoch 3
+ train_loss: 0.1476379437283353
+ eval_loss: 0.013794219702922342
+ time: Wed Aug 7 20:24:50 2024
+
+ Epoch 4
+ train_loss: 0.13554848474242498
+ eval_loss: 0.011902455668817844
+ time: Thu Aug 8 02:13:14 2024
+
+ Epoch 5
+ train_loss: 0.12781724531702496
+ eval_loss: 0.008929020740909163
+ time: Thu Aug 8 08:00:32 2024
+
+ Epoch 6
+ train_loss: 0.12264176163121285
+ eval_loss: 0.008453744098166877
+ time: Thu Aug 8 13:47:12 2024
+
+ Epoch 7
+ train_loss: 0.11872020949762974
+ eval_loss: 0.007165850172573819
+ time: Thu Aug 8 19:34:09 2024
+
+ Epoch 8
+ train_loss: 0.1153576639058103
+ eval_loss: 0.006601243383027142
+ time: Fri Aug 9 01:22:19 2024
+
+ Epoch 9
+ train_loss: 0.11312788720856465
+ eval_loss: 0.005973297544609645
+ time: Fri Aug 9 07:11:25 2024
+
+ Epoch 10
+ train_loss: 0.11096722687304313
+ eval_loss: 0.005796642405204946
+ time: Fri Aug 9 12:59:13 2024
+
+ Epoch 11
+ train_loss: 0.10913465011501206
+ eval_loss: 0.005249736892483845
+ time: Fri Aug 9 18:46:14 2024
+
+ Epoch 12
+ train_loss: 0.10780615577682347
+ eval_loss: 0.005115668955435858
+ time: Sat Aug 10 00:34:28 2024
+
+ Epoch 13
+ train_loss: 0.10650418949283817
+ eval_loss: 0.00475350690255028
+ time: Sat Aug 10 06:22:51 2024
+
+ Epoch 14
+ train_loss: 0.10524643352798381
+ eval_loss: 0.004583307054632575
+ time: Sat Aug 10 12:11:18 2024
+
+ Epoch 15
+ train_loss: 0.1041887047117438
+ eval_loss: 0.004289783886142609
+ time: Sat Aug 10 17:59:48 2024
+
+ Epoch 16
+ train_loss: 0.10343191375801945
+ eval_loss: 0.004421581262192111
+ time: Sat Aug 10 23:47:37 2024
+
+ Epoch 17
+ train_loss: 0.10256196519161385
+ eval_loss: 0.004104017818401634
+ time: Sun Aug 11 05:35:49 2024
+
+ Epoch 18
+ train_loss: 0.10170993055767087
+ eval_loss: 0.0039769458375234585
+ time: Sun Aug 11 11:23:39 2024
+
+ Epoch 19
+ train_loss: 0.1011880517951369
+ eval_loss: 0.0039005329324529833
+ time: Sun Aug 11 17:11:19 2024
+
+ Epoch 20
+ train_loss: 0.10030771156829077
+ eval_loss: 0.0036845325137673237
+ time: Sun Aug 11 22:59:49 2024
+
+ Epoch 21
+ train_loss: 0.09972109616302548
+ eval_loss: 0.0038503893043940205
+ time: Mon Aug 12 04:48:03 2024
+
+ Epoch 22
+ train_loss: 0.09932596696744844
+ eval_loss: 0.00370702411211194
+ time: Mon Aug 12 10:36:32 2024
+
+ Epoch 23
+ train_loss: 0.09888291950362459
+ eval_loss: 0.0034573812171313834
+ time: Mon Aug 12 16:24:55 2024
+
+ Epoch 24
+ train_loss: 0.09852503939581284
+ eval_loss: 0.003370235667697582
+ time: Mon Aug 12 22:12:14 2024
+
+ Epoch 25
+ train_loss: 0.09825884147004627
+ eval_loss: 0.00346387299475209
+ time: Tue Aug 13 04:00:42 2024
+
+ Epoch 26
+ train_loss: 0.09756856258879791
+ eval_loss: 0.0033276399650575615
+ time: Tue Aug 13 09:49:22 2024
+
+ Epoch 27
+ train_loss: 0.09730380131801182
+ eval_loss: 0.003326884365762399
+ time: Tue Aug 13 15:36:05 2024
+
+ Epoch 28
+ train_loss: 0.09687296288584166
+ eval_loss: 0.0034621171255395573
+ time: Tue Aug 13 21:23:24 2024
+
+ Epoch 29
+ train_loss: 0.09668537175198876
+ eval_loss: 0.003284947640647648
+ time: Wed Aug 14 03:10:21 2024
+
+ Epoch 30
+ train_loss: 0.09628572566572022
+ eval_loss: 0.003119471057549999
+ time: Wed Aug 14 08:56:45 2024
+
+ Epoch 31
+ train_loss: 0.09617123452549026
+ eval_loss: 0.003124797866062776
+ time: Wed Aug 14 14:43:12 2024
+
+ Epoch 32
+ train_loss: 0.09578377932399237
+ eval_loss: 0.0030736677601092537
+ time: Wed Aug 14 20:31:01 2024
+
+ Epoch 33
+ train_loss: 0.09558304869954821
+ eval_loss: 0.003178201471396451
+ time: Thu Aug 15 02:19:14 2024
+
+ Epoch 34
+ train_loss: 0.0952804450174092
+ eval_loss: 0.0030847328114775225
+ time: Thu Aug 15 08:06:29 2024
+
+ Epoch 35
+ train_loss: 0.09513826066486042
+ eval_loss: 0.00303873973446682
+ time: Thu Aug 15 13:52:17 2024
+
+ Epoch 36
+ train_loss: 0.09466769916316405
+ eval_loss: 0.0030122215467611258
+ time: Thu Aug 15 19:38:29 2024
+
+ Epoch 37
+ train_loss: 0.09465687754501316
+ eval_loss: 0.00289094522015785
+ time: Fri Aug 16 01:25:14 2024
+
+ Epoch 38
+ train_loss: 0.09435585222324992
+ eval_loss: 0.0030173959307773393
+ time: Fri Aug 16 07:11:56 2024
+
+ Epoch 39
+ train_loss: 0.09413478592045794
+ eval_loss: 0.002968058454507435
+ time: Fri Aug 16 12:59:16 2024
+
+ Epoch 40
+ train_loss: 0.09393180562734375
+ eval_loss: 0.0030673167865746948
+ time: Fri Aug 16 18:45:23 2024
+
+ Epoch 41
+ train_loss: 0.09365266143982799
+ eval_loss: 0.00287582161187937
+ time: Sat Aug 17 00:31:47 2024
+
+ Epoch 42
+ train_loss: 0.09359205519747489
+ eval_loss: 0.0027280030162997134
+ time: Sat Aug 17 06:18:32 2024
+
+ Epoch 43
+ train_loss: 0.09349238520961266
+ eval_loss: 0.0029261269570300787
+ time: Sat Aug 17 12:05:26 2024
+
+ Epoch 44
+ train_loss: 0.09324607778116949
+ eval_loss: 0.002691730654519444
+ time: Sat Aug 17 17:52:20 2024
+
+ Epoch 45
+ train_loss: 0.09310021795996155
+ eval_loss: 0.0028863806760858132
+ time: Sat Aug 17 23:38:57 2024
+
+ Epoch 46
+ train_loss: 0.09307358593283441
+ eval_loss: 0.002793597210717352
+ time: Sun Aug 18 05:25:42 2024
+
+ Epoch 47
+ train_loss: 0.09299390690766882
+ eval_loss: 0.0027052024821456098
+ time: Sun Aug 18 11:12:32 2024
+
+ Epoch 48
+ train_loss: 0.09253486422624911
+ eval_loss: 0.0027312307396534247
+ time: Sun Aug 18 16:59:16 2024
+
+ Epoch 49
+ train_loss: 0.09243107154309635
+ eval_loss: 0.002648197562936772
+ time: Sun Aug 18 22:46:41 2024
+
+ Epoch 50
+ train_loss: 0.09237845186490301
+ eval_loss: 0.0026844193827840284
+ time: Mon Aug 19 04:35:10 2024
+
+ Epoch 51
+ train_loss: 0.09231985249015236
+ eval_loss: 0.002708845011956738
+ time: Mon Aug 19 10:24:17 2024
+
+ Epoch 52
+ train_loss: 0.0922615721153286
+ eval_loss: 0.0035362059711223225
+ time: Mon Aug 19 16:11:39 2024
+
+ Epoch 53
+ train_loss: 0.09200190843071623
+ eval_loss: 0.0025848455890180064
+ time: Mon Aug 19 21:58:31 2024
+
+ Epoch 54
+ train_loss: 0.09200848002425245
+ eval_loss: 0.0026311414897881983
+ time: Tue Aug 20 03:45:36 2024
+
+ Epoch 55
+ train_loss: 0.09154813869071807
+ eval_loss: 0.0025586662145983823
+ time: Tue Aug 20 09:34:48 2024
+
+ Epoch 56
+ train_loss: 0.09162745474034129
+ eval_loss: 0.0026280648907143545
+ time: Tue Aug 20 15:23:23 2024
+
+ Epoch 57
+ train_loss: 0.09156280245772795
+ eval_loss: 0.002539119078534093
+ time: Tue Aug 20 21:11:25 2024
+
+ Epoch 58
+ train_loss: 0.09142590950099329
+ eval_loss: 0.0026369429265152866
+ time: Wed Aug 21 02:59:19 2024
+
+ Epoch 59
+ train_loss: 0.09139848643851392
+ eval_loss: 0.0024354966580356916
+ time: Wed Aug 21 08:46:23 2024
+
+ Epoch 60
+ train_loss: 0.09131192888740647
+ eval_loss: 0.0024594995301248277
+ time: Wed Aug 21 14:33:28 2024
+
+ Epoch 61
+ train_loss: 0.09122042933562911
+ eval_loss: 0.002616936316367883
+ time: Wed Aug 21 20:20:57 2024
+
+ Epoch 62
+ train_loss: 0.09109125168796305
+ eval_loss: 0.0025555431279884297
+ time: Thu Aug 22 02:08:45 2024
+
+ Epoch 63
+ train_loss: 0.09106527324403817
+ eval_loss: 0.0025145284593781213
+ time: Thu Aug 22 07:56:26 2024
+
+ Epoch 64
+ train_loss: 0.09095406525682191
+ eval_loss: 0.0025151555842959678
+ time: Thu Aug 22 13:45:57 2024
+
+ Epoch 65
+ train_loss: 0.09102793501718281
+ eval_loss: 0.0024135450126194563
+ time: Thu Aug 22 19:54:28 2024
+
+ Epoch 66
+ train_loss: 0.0908411063853937
+ eval_loss: 0.002460922076728368
+ time: Fri Aug 23 01:59:41 2024
+
+ Epoch 67
+ train_loss: 0.09070221083785855
+ eval_loss: 0.002453409551882543
+ time: Fri Aug 23 07:52:30 2024
+
+ Epoch 68
+ train_loss: 0.0906545008953897
+ eval_loss: 0.0024080786435031784
+ time: Fri Aug 23 13:41:28 2024
+
+ Epoch 69
+ train_loss: 0.0907353380525871
+ eval_loss: 0.0024573436347799147
+ time: Fri Aug 23 19:27:14 2024
+
+ Epoch 70
+ train_loss: 0.09040538104085095
+ eval_loss: 0.0023765437401249566
+ time: Sat Aug 24 01:14:45 2024
+
+ Epoch 71
+ train_loss: 0.09036114065518137
+ eval_loss: 0.0023877528348234226
+ time: Sat Aug 24 07:02:04 2024
+
+ Epoch 72
+ train_loss: 0.09037455027205546
+ eval_loss: 0.002315233082103814
+ time: Sat Aug 24 12:49:24 2024
+
+ Epoch 73
+ train_loss: 0.09026183628343257
+ eval_loss: 0.0024284060419643228
+ time: Sat Aug 24 18:35:36 2024
+
+ Epoch 74
+ train_loss: 0.09019025581511034
+ eval_loss: 0.002393116130206718
+ time: Sun Aug 25 00:21:29 2024
+
+ Epoch 75
+ train_loss: 0.089901714783446
+ eval_loss: 0.002298152916632467
+ time: Sun Aug 25 06:08:01 2024
+
+ Epoch 76
+ train_loss: 0.09018262871273484
+ eval_loss: 0.002273971366672482
+ time: Sun Aug 25 11:54:02 2024
+
+ Epoch 77
+ train_loss: 0.08998425874228
+ eval_loss: 0.002317420323379338
+ time: Sun Aug 25 17:44:05 2024
+
+ Epoch 78
+ train_loss: 0.08983653943919646
+ eval_loss: 0.0024391192159878743
+ time: Sun Aug 25 23:31:34 2024
+
+ Epoch 79
+ train_loss: 0.08981405456901183
+ eval_loss: 0.002319374949895317
+ time: Mon Aug 26 05:24:56 2024
+
+ Epoch 80
+ train_loss: 0.08974534569690559
+ eval_loss: 0.0023008979344151066
+ time: Mon Aug 26 11:28:33 2024
+
+ Epoch 81
+ train_loss: 0.08972110153310983
+ eval_loss: 0.002406696710865237
+ time: Mon Aug 26 17:33:30 2024
+
+ Epoch 82
+ train_loss: 0.0895689915361898
+ eval_loss: 0.002241936448434926
+ time: Mon Aug 26 23:39:15 2024
+
+ Epoch 83
+ train_loss: 0.08950625452328584
+ eval_loss: 0.002408353965493697
+ time: Tue Aug 27 05:37:59 2024
+
+ Epoch 84
+ train_loss: 0.08959725393084628
+ eval_loss: 0.0023435966142665455
+ time: Tue Aug 27 11:34:29 2024
+
+ Epoch 85
+ train_loss: 0.08970333726515986
+ eval_loss: 0.0023965956810233086
+ time: Tue Aug 27 17:27:31 2024
+
+ Epoch 86
+ train_loss: 0.08948115523227308
+ eval_loss: 0.002325803569256709
+ time: Tue Aug 27 23:19:43 2024
+
+ Epoch 87
+ train_loss: 0.08933937688654775
+ eval_loss: 0.0023552257988114647
+ time: Wed Aug 28 05:11:34 2024
+
+ Epoch 88
+ train_loss: 0.08938353908107184
+ eval_loss: 0.0024397599904794043
+ time: Wed Aug 28 11:01:23 2024
+
+ Epoch 89
+ train_loss: 0.08921640703096091
+ eval_loss: 0.002223708766084243
+ time: Wed Aug 28 16:50:21 2024
+
+ Epoch 90
+ train_loss: 0.08929300930090782
+ eval_loss: 0.0022849828316260303
+ time: Wed Aug 28 22:38:53 2024
+
+ Epoch 91
+ train_loss: 0.08910525214309825
+ eval_loss: 0.0022257193633186227
+ time: Thu Aug 29 04:35:16 2024
+
+ Epoch 92
+ train_loss: 0.08905495976636461
+ eval_loss: 0.0022299331251850137
+ time: Thu Aug 29 10:29:46 2024
+
+ Epoch 93
+ train_loss: 0.08890526102100955
+ eval_loss: 0.0022962711695463786
+ time: Thu Aug 29 16:23:49 2024
+
+ Epoch 94
+ train_loss: 0.08908289874104246
+ eval_loss: 0.002243622880820028
+ time: Thu Aug 29 22:15:42 2024
+
+ Epoch 95
+ train_loss: 0.08908785978677156
+ eval_loss: 0.0022457318524397784
+ time: Fri Aug 30 04:06:57 2024
+
+ Epoch 96
+ train_loss: 0.08888098475318565
+ eval_loss: 0.002224675611787346
+ time: Fri Aug 30 09:58:43 2024
+
+ Epoch 97
+ train_loss: 0.08888529259134526
+ eval_loss: 0.0021844924980664493
+ time: Fri Aug 30 15:50:16 2024
+
+ Epoch 98
+ train_loss: 0.08885388837534758
+ eval_loss: 0.0022109088076294288
+ time: Fri Aug 30 21:41:59 2024
+
+ Epoch 99
+ train_loss: 0.08873902663868657
+ eval_loss: 0.0022606451996653202
+ time: Sat Aug 31 03:34:46 2024
+
+ Epoch 100
+ train_loss: 0.08877080098666765
+ eval_loss: 0.002279470525367602
+ time: Sat Aug 31 09:28:38 2024
+
code/train_clamp2.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import os
+ import json
+ import time
+ import wandb
+ import torch
+ import random
+ import numpy as np
+ from utils import *
+ from config import *
+ from tqdm import tqdm
+ from copy import deepcopy
+ import torch.distributed as dist
+ from torch.amp import autocast, GradScaler
+ from torch.utils.data import Dataset, DataLoader
+ from torch.utils.data.distributed import DistributedSampler
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ from transformers import AutoTokenizer, BertConfig, get_constant_schedule_with_warmup
+
+ def list_files_in_json(json_path):
+     file_list = []
+
+     if os.path.exists(json_path):
+         with open(json_path, 'r', encoding='utf-8') as f:
+             for line in f:
+                 item = json.loads(line)
+                 file_list.append(item)
+
+     return file_list
+
+ def collate_batch(batch):
+     text_inputs, text_masks, music_inputs, music_masks = zip(*batch)
+
+     text_inputs = torch.stack(text_inputs)
+     text_masks = torch.stack(text_masks)
+     music_inputs = torch.stack(music_inputs)
+     music_masks = torch.stack(music_masks)
+
+     return text_inputs, text_masks, music_inputs, music_masks
+
+ class TextMusicDataset(Dataset):
+     def __init__(self, items, mode):
+         print("The number of "+mode+" data: "+str(len(items)))
+         self.items = items
+         self.mode = mode
+         if self.mode == 'train' or not EVAL_JSONL:
+             self.datapath = os.path.dirname(TRAIN_JSONL)
+         elif self.mode == 'eval':
+             self.datapath = os.path.dirname(EVAL_JSONL)
+
+     def text_dropout(self, item):
+         if random.random() < 0.5:
+             candidates = []
+             for key in item.keys():
+                 if key not in ["summary_en", "summary_nen", "filepaths"]:
+                     if item[key] is None:
+                         continue
+                     elif isinstance(item[key], str):
+                         candidates.append(item[key])
+                     elif isinstance(item[key], list):
+                         candidates.extend(item[key])
+             candidates = list(set(candidates))
+             candidates = "\n".join(candidates)
+             candidates = candidates.split("\n")
+             selected_candidates = [c for c in candidates if len(c) > 0 and random.random() < 0.5]
+             if len(selected_candidates) == 0:
+                 selected_candidates = candidates
+             random.shuffle(selected_candidates)
+             text = tokenizer.sep_token.join(selected_candidates)
+         else:
+             if random.random() < 0.5:
+                 text = random.choice(item["summary_en"])
+             else:
+                 text = random.choice(item["summary_nen"])["summary"]
+
+         return text
+
+     def random_truncate(self, input_tensor, max_length):
+         choices = ["head", "tail", "middle"]
+         choice = random.choice(choices)
+         if choice == "head" or self.mode == 'eval':
+             input_tensor = input_tensor[:max_length]
+         elif choice == "tail":
+             input_tensor = input_tensor[-max_length:]
+         elif choice == "middle":
+             start = random.randint(1, input_tensor.size(0)-max_length)
+             input_tensor = input_tensor[start:start+max_length]
+
+         return input_tensor
+
+     def __len__(self):
+         return len(self.items)
+
+     def __getitem__(self, idx):
+         item = self.items[idx]
+
+         # randomly select text from the item
+         if self.mode == 'train' and TEXT_DROPOUT:
+             text = self.text_dropout(item)
+         else:
+             text = item["summary_en"][0]
+
+         # tokenize text and build mask for text tokens
+         text_inputs = tokenizer(text, return_tensors='pt')
+         text_inputs = text_inputs['input_ids'].squeeze(0)
+         if text_inputs.size(0) > MAX_TEXT_LENGTH:
+             text_inputs = self.random_truncate(text_inputs, MAX_TEXT_LENGTH)
+         text_masks = torch.ones(text_inputs.size(0))
+
+         # load music file
+         if self.mode == 'train':
+             filepath = random.choice(item["filepaths"])
+         else:
+             if item["filepaths"][0].endswith(".abc"):
+                 filepath = item["filepaths"][0]
+             else:
+                 filepath = item["filepaths"][1]
+         filepath = self.datapath + '/' + filepath
+
+         with open(filepath, "r", encoding="utf-8") as f:
+             item = f.read().replace("L:1/8\n", "") if filepath.endswith(".abc") else f.read()
+
+         # randomly remove instrument info from the music file
+         if random.random() < 0.9 and self.mode == 'train':
+             item = remove_instrument_info(item)
+
+         # mask music inputs
+         music_inputs = patchilizer.encode(item, add_special_patches=True, truncate=True, random_truncate=(self.mode=="train"))
+         music_inputs = torch.tensor(music_inputs)
+         music_masks = torch.ones(music_inputs.size(0))
+
+         # pad text inputs and masks
+         pad_indices = torch.ones(MAX_TEXT_LENGTH - text_inputs.size(0)).long() * tokenizer.pad_token_id
+         text_inputs = torch.cat((text_inputs, pad_indices), 0)
+         text_masks = torch.cat((text_masks, torch.zeros(MAX_TEXT_LENGTH - text_masks.size(0))), 0)
+
+         # pad music inputs and masks
+         pad_indices = torch.ones((PATCH_LENGTH - music_inputs.size(0), PATCH_SIZE)).long() * patchilizer.pad_token_id
+         music_inputs = torch.cat((music_inputs, pad_indices), 0)
+         music_masks = torch.cat((music_masks, torch.zeros(PATCH_LENGTH - music_masks.size(0))), 0)
+
+         return text_inputs, text_masks, music_inputs, music_masks
+
+ # call model with a batch of input
+ def process_one_batch(batch):
+     text_inputs, text_masks, music_inputs, music_masks = batch
+
+     loss = model(text_inputs,
+                  text_masks,
+                  music_inputs,
+                  music_masks)
+
+     # Reduce the loss on GPU 0
+     if world_size > 1:
+         loss = loss.unsqueeze(0)
+         dist.reduce(loss, dst=0)
+         loss = loss / world_size
+         dist.broadcast(loss, src=0)
+
+     return loss.mean()
+
+ # do one epoch for training
+ def train_epoch(epoch):
+     tqdm_train_set = tqdm(train_set)
+     total_train_loss = 0
+     iter_idx = 1
+     model.train()
+     train_steps = (epoch-1)*len(train_set)
+
+     for batch in tqdm_train_set:
+         with autocast(device_type='cuda'):
+             loss = process_one_batch(batch)
+         scaler.scale(loss).backward()
+         total_train_loss += loss.item()
+         scaler.step(optimizer)
+         scaler.update()
+
+         lr_scheduler.step()
+         model.zero_grad(set_to_none=True)
+         tqdm_train_set.set_postfix({str(global_rank)+'_train_loss': total_train_loss / iter_idx})
+         train_steps += 1
+
+         # Log the training loss to wandb
+         if global_rank==0 and CLAMP2_WANDB_LOG:
+             wandb.log({"train_loss": total_train_loss / iter_idx}, step=train_steps)
+
+         iter_idx += 1
+
+     return total_train_loss / (iter_idx-1)
+
+ # do one epoch for eval
+ def eval_epoch():
+     tqdm_eval_set = tqdm(eval_set)
+     total_eval_loss = 0
+     iter_idx = 1
+     model.eval()
+
+     # Evaluate data for one epoch
+     for batch in tqdm_eval_set:
+         with torch.no_grad():
+             loss = process_one_batch(batch)
+
+         total_eval_loss += loss.item()
+         tqdm_eval_set.set_postfix({str(global_rank)+'_eval_loss': total_eval_loss / iter_idx})
+         iter_idx += 1
+
+     return total_eval_loss / (iter_idx-1)
+
+ # train and eval
+ if __name__ == "__main__":
+
+     # Set up distributed training
+     world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
+     global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else 0
+     local_rank = int(os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else 0
+
+     if world_size > 1:
+         torch.cuda.set_device(local_rank)
+         device = torch.device("cuda", local_rank)
+         dist.init_process_group(backend='nccl')
+     else:
+         device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+     if CLAMP2_DETERMINISTIC:
+         seed = 42 + global_rank
+         random.seed(seed)
+         np.random.seed(seed)
+         torch.manual_seed(seed)
+         torch.cuda.manual_seed_all(seed)
+         torch.backends.cudnn.deterministic = True
+         torch.backends.cudnn.benchmark = False
+
+     m3_config = BertConfig(vocab_size=1,
+                            hidden_size=M3_HIDDEN_SIZE,
+                            num_hidden_layers=PATCH_NUM_LAYERS,
+                            num_attention_heads=M3_HIDDEN_SIZE//64,
+                            intermediate_size=M3_HIDDEN_SIZE*4,
+                            max_position_embeddings=PATCH_LENGTH)
+     model = CLaMP2Model(m3_config,
+                         global_rank,
+                         world_size,
+                         TEXT_MODEL_NAME,
+                         CLAMP2_HIDDEN_SIZE,
+                         CLAMP2_LOAD_M3)
+     model = model.to(device)
+     tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
+     patchilizer = M3Patchilizer()
+
+     # print parameter number
+     print("Parameter Number: "+str(sum(p.numel() for p in model.parameters() if p.requires_grad)))
+
+     if world_size > 1:
+         model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
+
+     scaler = GradScaler()
+     optimizer = torch.optim.AdamW(model.parameters(), lr=CLAMP2_LEARNING_RATE)
+
+     if CLAMP2_WANDB_LOG and global_rank==0:
+         # Initialize wandb
+         if WANDB_KEY:
+             wandb.login(key=WANDB_KEY)
+         wandb.init(project="clamp2",
+                    name=CLAMP2_WEIGHTS_PATH.replace("weights_", "").replace(".pth", ""))
+
+     # load filenames under train and eval folder
+     train_files = list_files_in_json(TRAIN_JSONL)
+     eval_files = list_files_in_json(EVAL_JSONL)
+
+     if len(eval_files)==0:
+         train_files, eval_files = split_data(train_files)
+
+     train_batch_nums = int(len(train_files) / CLAMP2_BATCH_SIZE)
+     eval_batch_nums = int(len(eval_files) / CLAMP2_BATCH_SIZE)
+
+     train_files = train_files[:train_batch_nums*CLAMP2_BATCH_SIZE]
+     eval_files = eval_files[:eval_batch_nums*CLAMP2_BATCH_SIZE]
+
+     train_set = TextMusicDataset(train_files, 'train')
+     eval_set = TextMusicDataset(eval_files, 'eval')
+
+     train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=global_rank)
+     eval_sampler = DistributedSampler(eval_set, num_replicas=world_size, rank=global_rank)
+
+     train_set = DataLoader(train_set, batch_size=CLAMP2_BATCH_SIZE, collate_fn=collate_batch, sampler=train_sampler, shuffle=(train_sampler is None))
+     eval_set = DataLoader(eval_set, batch_size=CLAMP2_BATCH_SIZE, collate_fn=collate_batch, sampler=eval_sampler, shuffle=(eval_sampler is None))
+
+     lr_scheduler = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=1000)
+
+     if CLAMP2_LOAD_CKPT and os.path.exists(CLAMP2_WEIGHTS_PATH):
+         # Load checkpoint to CPU
+         checkpoint = torch.load(CLAMP2_WEIGHTS_PATH, map_location='cpu', weights_only=True)
+
+         # Here, model is assumed to be on GPU
+         # Load the state dict into a CPU clone of the model first, then copy it back
+         if hasattr(model, "module"):
+             # For a DDP-wrapped model, load into model.module instead
+             cpu_model = deepcopy(model.module)
+             cpu_model.load_state_dict(checkpoint['model'])
+             model.module.load_state_dict(cpu_model.state_dict())
+         else:
+             cpu_model = deepcopy(model)
+             cpu_model.load_state_dict(checkpoint['model'])
+             model.load_state_dict(cpu_model.state_dict())
+         optimizer.load_state_dict(checkpoint['optimizer'])
+         lr_scheduler.load_state_dict(checkpoint['lr_sched'])
+         pre_epoch = checkpoint['epoch']
+         best_epoch = checkpoint['best_epoch']
+         min_eval_loss = checkpoint['min_eval_loss']
+         print(f"Successfully Loaded Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}")
+         checkpoint = None
+
+     else:
+         pre_epoch = 0
+         best_epoch = 0
+         min_eval_loss = float('inf')
+
+         model = model.to(device)
+         optimizer = torch.optim.AdamW(model.parameters(), lr=CLAMP2_LEARNING_RATE)
+
+     for epoch in range(1+pre_epoch, CLAMP2_NUM_EPOCH+1):
+         train_sampler.set_epoch(epoch)
+         eval_sampler.set_epoch(epoch)
+         print('-' * 21 + "Epoch " + str(epoch) + '-' * 21)
+         train_loss = train_epoch(epoch)
+         eval_loss = eval_epoch()
+         if global_rank==0:
+             with open(CLAMP2_LOGS_PATH, 'a') as f:
+                 f.write("Epoch " + str(epoch) + "\ntrain_loss: " + str(train_loss) + "\neval_loss: " + str(eval_loss) + "\ntime: " + time.asctime(time.localtime(time.time())) + "\n\n")
+             if eval_loss < min_eval_loss:
+                 best_epoch = epoch
+                 min_eval_loss = eval_loss
+                 checkpoint = {
+                     'model': model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
+                     'optimizer': optimizer.state_dict(),
+                     'lr_sched': lr_scheduler.state_dict(),
+                     'epoch': epoch,
+                     'best_epoch': best_epoch,
+                     'min_eval_loss': min_eval_loss
+                 }
+                 torch.save(checkpoint, CLAMP2_WEIGHTS_PATH)
+             checkpoint = {
+                 'model': model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
+                 'optimizer': optimizer.state_dict(),
+                 'lr_sched': lr_scheduler.state_dict(),
+                 'epoch': epoch,
+                 'best_epoch': best_epoch,
+                 'min_eval_loss': min_eval_loss
+             }
+             torch.save(checkpoint, "latest_"+CLAMP2_WEIGHTS_PATH)
+
+         if world_size > 1:
+             dist.barrier()
+
+     if global_rank==0:
+         print("Best Eval Epoch : "+str(best_epoch))
+         print("Min Eval Loss : "+str(min_eval_loss))
code/train_m3.py ADDED
@@ -0,0 +1,321 @@
+ import os
+ import gc
+ import time
+ import wandb
+ import torch
+ import random
+ import weakref
+ import numpy as np
+ from utils import *
+ from config import *
+ from tqdm import tqdm
+ from copy import deepcopy
+ import torch.distributed as dist
+ from torch.amp import autocast, GradScaler
+ from torch.utils.data import Dataset, DataLoader
+ from torch.utils.data.distributed import DistributedSampler
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ from transformers import BertConfig, GPT2Config, get_constant_schedule_with_warmup
+
+ patchilizer = M3Patchilizer()
+
+ def clear_unused_tensors():
+     gc.disable()  # Temporarily disable garbage collection
+     try:
+         # Get the set of tensor ids used by the model
+         if hasattr(model, "module"):
+             model_tensors = {id(p) for p in model.module.parameters()}
+         else:
+             model_tensors = {id(p) for p in model.parameters()}
+
+         # Get the set of tensor ids used by the optimizer
+         optimizer_tensors = {
+             id(state)
+             for state_dict in optimizer.state.values()
+             for state in state_dict.values()
+             if isinstance(state, torch.Tensor)  # Ensure only tensors are considered
+         }
+
+         # List of all CUDA tensors currently in memory
+         tensors = [obj for obj in gc.get_objects() if isinstance(obj, torch.Tensor) and obj.is_cuda]
+
+         # Create weak references to avoid interfering with garbage collection
+         tensor_refs = [weakref.ref(tensor) for tensor in tensors]
+
+         for tensor_ref in tensor_refs:
+             tensor = tensor_ref()  # Dereference the weak reference
+             if tensor is not None and id(tensor) not in model_tensors and id(tensor) not in optimizer_tensors:
+                 # Mark the tensor for deletion
+                 tensor.detach_()  # Detach from computation graph
+                 del tensor  # Delete the tensor reference
+     except Exception:
+         pass
+
+     finally:
+         gc.enable()  # Re-enable garbage collection
+         gc.collect()  # Force a garbage collection
+         torch.cuda.empty_cache()  # Clear the CUDA cache
+
+ def list_files_in_directory(directories, extensions=["abc", "mtf"]):
+     file_list = []
+
+     for directory in directories:
+         for root, dirs, files in os.walk(directory):
+             for file in files:
+                 if any(file.endswith(ext) for ext in extensions):
+                     file_path = os.path.join(root, file)
+                     file_list.append(file_path)
+
+     return file_list
+
+ def collate_batch(batch):
+     input_patches, input_masks, selected_indices, target_patches = zip(*batch)
+
+     input_patches = torch.nn.utils.rnn.pad_sequence(input_patches, batch_first=True, padding_value=patchilizer.pad_token_id)
+     input_masks = torch.nn.utils.rnn.pad_sequence(input_masks, batch_first=True, padding_value=0)
+     selected_indices = torch.nn.utils.rnn.pad_sequence(selected_indices, batch_first=True, padding_value=0)
+     target_patches = torch.nn.utils.rnn.pad_sequence(target_patches, batch_first=True, padding_value=patchilizer.pad_token_id)
+
+     return input_patches, input_masks, selected_indices, target_patches
+
+ class M3Dataset(Dataset):
+     def __init__(self, filenames, mode):
+         print("The number of "+mode+" data: "+str(len(filenames)))
+         self.filenames = filenames
+         self.mode = mode
+
+     def __len__(self):
+         return len(self.filenames)
+
+     def __getitem__(self, idx):
+         filename = self.filenames[idx]
+         try:
+             with open(filename, "r", encoding="utf-8") as f:
+                 item = f.read().replace("L:1/8\n", "") if filename.endswith(".abc") else f.read()
+         except Exception as e:
+             print(e)
+             print("Failed to load: "+filename)
+             item = ""
+
+         target_patches = patchilizer.encode(item, add_special_patches=True, truncate=True, random_truncate=(self.mode=="train"))
+         input_masks = torch.tensor([1]*len(target_patches))
+         input_patches, selected_indices = mask_patches(target_patches, patchilizer, self.mode)
+         input_patches = input_patches.reshape(-1)
+         target_patches = torch.tensor(target_patches).reshape(-1)
+         return input_patches, input_masks, selected_indices, target_patches
+
+ # call model with a batch of input
+ def process_one_batch(batch):
+     input_patches, input_masks, selected_indices, target_patches = batch
+
+     loss = model(input_patches,
+                  input_masks,
+                  selected_indices,
+                  target_patches).loss
+
+     # Reduce the loss on GPU 0
+     if world_size > 1:
+         loss = loss.unsqueeze(0)
+         dist.reduce(loss, dst=0)
+         loss = loss / world_size
+         dist.broadcast(loss, src=0)
+
+     return loss.mean()
+
+ # do one epoch for training
+ def train_epoch(epoch):
+     tqdm_train_set = tqdm(train_set)
+     total_train_loss = 0
+     iter_idx = 1
+     model.train()
+     train_steps = (epoch-1)*len(train_set)
+
+     for batch in tqdm_train_set:
+         with autocast(device_type='cuda'):
+             loss = process_one_batch(batch)
+         scaler.scale(loss).backward()
+         total_train_loss += loss.item()
+         scaler.step(optimizer)
+         scaler.update()
+
+         lr_scheduler.step()
+         model.zero_grad(set_to_none=True)
+         tqdm_train_set.set_postfix({str(global_rank)+'_train_loss': total_train_loss / iter_idx})
+         train_steps += 1
+
+         # Log the training loss to wandb
+         if global_rank==0 and M3_WANDB_LOG:
+             wandb.log({"train_loss": total_train_loss / iter_idx}, step=train_steps)
+
+         iter_idx += 1
+         if iter_idx % 1000 == 0:
+             clear_unused_tensors()
+
+     return total_train_loss / (iter_idx-1)
+
+ # do one epoch for eval
+ def eval_epoch():
+     tqdm_eval_set = tqdm(eval_set)
+     total_eval_loss = 0
+     iter_idx = 1
+     model.eval()
+
+     # Evaluate data for one epoch
+     for batch in tqdm_eval_set:
+         with torch.no_grad():
+             loss = process_one_batch(batch)
+
+         total_eval_loss += loss.item()
+         tqdm_eval_set.set_postfix({str(global_rank)+'_eval_loss': total_eval_loss / iter_idx})
+         iter_idx += 1
+
+     return total_eval_loss / (iter_idx-1)
+
+ # train and eval
+ if __name__ == "__main__":
+
+     # Set up distributed training
+     world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
+     global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else 0
+     local_rank = int(os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else 0
+
+     if world_size > 1:
+         torch.cuda.set_device(local_rank)
+         device = torch.device("cuda", local_rank)
+         dist.init_process_group(backend='nccl')
+     else:
+         device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+     if M3_DETERMINISTIC:
+         seed = 42 + global_rank
+         random.seed(seed)
+         np.random.seed(seed)
+         torch.manual_seed(seed)
+         torch.cuda.manual_seed_all(seed)
+         torch.backends.cudnn.deterministic = True
+         torch.backends.cudnn.benchmark = False
+
+     encoder_config = BertConfig(vocab_size=1,
+                                 hidden_size=M3_HIDDEN_SIZE,
+                                 num_hidden_layers=PATCH_NUM_LAYERS,
+                                 num_attention_heads=M3_HIDDEN_SIZE//64,
+                                 intermediate_size=M3_HIDDEN_SIZE*4,
+                                 max_position_embeddings=PATCH_LENGTH)
+     decoder_config = GPT2Config(vocab_size=128,
+                                 n_positions=PATCH_SIZE,
+                                 n_embd=M3_HIDDEN_SIZE,
+                                 n_layer=TOKEN_NUM_LAYERS,
+                                 n_head=M3_HIDDEN_SIZE//64,
+                                 n_inner=M3_HIDDEN_SIZE*4)
+     model = M3Model(encoder_config, decoder_config)
+     model = model.to(device)
+
+     # print parameter number
+     print("Parameter Number: "+str(sum(p.numel() for p in model.parameters() if p.requires_grad)))
+
+     if world_size > 1:
+         model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
+
+     scaler = GradScaler()
+     optimizer = torch.optim.AdamW(model.parameters(), lr=M3_LEARNING_RATE)
+
+     if M3_WANDB_LOG and global_rank==0:
+         # Initialize wandb
+         if WANDB_KEY:
+             wandb.login(key=WANDB_KEY)
+         wandb.init(project="m3",
+                    name=M3_WEIGHTS_PATH.replace("weights_", "").replace(".pth", ""))
+
+     # load filenames under train and eval folder
+     train_files = list_files_in_directory(TRAIN_FOLDERS)
+     eval_files = list_files_in_directory(EVAL_FOLDERS)
+
+     if len(eval_files)==0:
+         train_files, eval_files = split_data(train_files)
+
+     train_batch_nums = int(len(train_files) / M3_BATCH_SIZE)
+     eval_batch_nums = int(len(eval_files) / M3_BATCH_SIZE)
+
+     train_files = train_files[:train_batch_nums*M3_BATCH_SIZE]
+     eval_files = eval_files[:eval_batch_nums*M3_BATCH_SIZE]
+
+     train_set = M3Dataset(train_files, 'train')
+     eval_set = M3Dataset(eval_files, 'eval')
+
+     train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=global_rank)
+     eval_sampler = DistributedSampler(eval_set, num_replicas=world_size, rank=global_rank)
+
+     train_set = DataLoader(train_set, batch_size=M3_BATCH_SIZE, collate_fn=collate_batch, sampler=train_sampler, shuffle=(train_sampler is None))
+     eval_set = DataLoader(eval_set, batch_size=M3_BATCH_SIZE, collate_fn=collate_batch, sampler=eval_sampler, shuffle=(eval_sampler is None))
+
+     lr_scheduler = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=1000)
+
+     if M3_LOAD_CKPT and os.path.exists(M3_WEIGHTS_PATH):
+         # Load checkpoint to CPU
+         checkpoint = torch.load(M3_WEIGHTS_PATH, map_location='cpu', weights_only=True)
+
+         # Here, model is assumed to be on GPU
+         # Load the state dict into a CPU clone of the model first, then copy it back
+         if hasattr(model, "module"):
+             # For a DDP-wrapped model, load into model.module instead
+             cpu_model = deepcopy(model.module)
+             cpu_model.load_state_dict(checkpoint['model'])
+             model.module.load_state_dict(cpu_model.state_dict())
+         else:
+             cpu_model = deepcopy(model)
+             cpu_model.load_state_dict(checkpoint['model'])
+             model.load_state_dict(cpu_model.state_dict())
+         optimizer.load_state_dict(checkpoint['optimizer'])
+         lr_scheduler.load_state_dict(checkpoint['lr_sched'])
+         pre_epoch = checkpoint['epoch']
+         best_epoch = checkpoint['best_epoch']
+         min_eval_loss = checkpoint['min_eval_loss']
+         print(f"Successfully Loaded Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}")
+         checkpoint = None
+
+     else:
+         pre_epoch = 0
+         best_epoch = 0
+         min_eval_loss = float('inf')
+
+         model = model.to(device)
+         optimizer = torch.optim.AdamW(model.parameters(), lr=M3_LEARNING_RATE)
+
+     for epoch in range(1+pre_epoch, M3_NUM_EPOCH+1):
+         train_sampler.set_epoch(epoch)
+         eval_sampler.set_epoch(epoch)
+         print('-' * 21 + "Epoch " + str(epoch) + '-' * 21)
+         train_loss = train_epoch(epoch)
+         eval_loss = eval_epoch()
+         if global_rank==0:
+             with open(M3_LOGS_PATH, 'a') as f:
+                 f.write("Epoch " + str(epoch) + "\ntrain_loss: " + str(train_loss) + "\neval_loss: " + str(eval_loss) + "\ntime: " + time.asctime(time.localtime(time.time())) + "\n\n")
+             if eval_loss < min_eval_loss:
+                 best_epoch = epoch
+                 min_eval_loss = eval_loss
+                 checkpoint = {
+                     'model': model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
+                     'optimizer': optimizer.state_dict(),
+                     'lr_sched': lr_scheduler.state_dict(),
+                     'epoch': epoch,
+                     'best_epoch': best_epoch,
+                     'min_eval_loss': min_eval_loss
+                 }
+                 torch.save(checkpoint, M3_WEIGHTS_PATH)
+             checkpoint = {
+                 'model': model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
+                 'optimizer': optimizer.state_dict(),
+                 'lr_sched': lr_scheduler.state_dict(),
+                 'epoch': epoch,
+                 'best_epoch': best_epoch,
+                 'min_eval_loss': min_eval_loss
+             }
+             torch.save(checkpoint, "latest_"+M3_WEIGHTS_PATH)
+
+         if world_size > 1:
+             dist.barrier()
+
+     if global_rank==0:
+         print("Best Eval Epoch : "+str(best_epoch))
+         print("Min Eval Loss : "+str(min_eval_loss))
code/utils.py ADDED
@@ -0,0 +1,483 @@
+ import re
+ import os
+ import math
+ import torch
+ import random
+ from config import *
+ from unidecode import unidecode
+ from torch.nn import functional as F
+ from transformers import AutoModel, BertModel, GPT2LMHeadModel, PreTrainedModel, GPT2Config
+
+ try:
+     import torch.distributed.nn
+     from torch import distributed as dist
+
+     has_distributed = True
+ except ImportError:
+     has_distributed = False
+
+ try:
+     import horovod.torch as hvd
+ except ImportError:
+     hvd = None
+
+ class ClipLoss(torch.nn.Module):
+
+     def __init__(
+             self,
+             local_loss=False,
+             gather_with_grad=False,
+             cache_labels=False,
+             rank=0,
+             world_size=1,
+             use_horovod=False,
+     ):
+         super().__init__()
+         self.local_loss = local_loss
+         self.gather_with_grad = gather_with_grad
+         self.cache_labels = cache_labels
+         self.rank = rank
+         self.world_size = world_size
+         self.use_horovod = use_horovod
+
+         # cache state
+         self.prev_num_logits = 0
+         self.labels = {}
+
+     def gather_features(
+             self,
+             image_features,
+             text_features,
+             local_loss=False,
+             gather_with_grad=False,
+             rank=0,
+             world_size=1,
+             use_horovod=False
+     ):
+         assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
+         if use_horovod:
+             assert hvd is not None, 'Please install horovod'
+             if gather_with_grad:
+                 all_image_features = hvd.allgather(image_features)
+                 all_text_features = hvd.allgather(text_features)
+             else:
+                 with torch.no_grad():
+                     all_image_features = hvd.allgather(image_features)
+                     all_text_features = hvd.allgather(text_features)
+                 if not local_loss:
+                     # ensure grads for local rank when all_* features don't have a gradient
+                     gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
+                     gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
+                     gathered_image_features[rank] = image_features
+                     gathered_text_features[rank] = text_features
+                     all_image_features = torch.cat(gathered_image_features, dim=0)
+                     all_text_features = torch.cat(gathered_text_features, dim=0)
+         else:
+             # We gather tensors from all gpus
+             if gather_with_grad:
+                 all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
+                 all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
+             else:
+                 gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
+                 gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
+                 dist.all_gather(gathered_image_features, image_features)
+                 dist.all_gather(gathered_text_features, text_features)
+                 if not local_loss:
+                     # ensure grads for local rank when all_* features don't have a gradient
+                     gathered_image_features[rank] = image_features
+                     gathered_text_features[rank] = text_features
+                 all_image_features = torch.cat(gathered_image_features, dim=0)
+                 all_text_features = torch.cat(gathered_text_features, dim=0)
+
+         return all_image_features, all_text_features
+
+     def get_ground_truth(self, device, num_logits) -> torch.Tensor:
+         # calculate ground-truth and cache if enabled
+         if self.prev_num_logits != num_logits or device not in self.labels:
+             labels = torch.arange(num_logits, device=device, dtype=torch.long)
+             if self.world_size > 1 and self.local_loss:
+                 labels = labels + num_logits * self.rank
+             if self.cache_labels:
+                 self.labels[device] = labels
+                 self.prev_num_logits = num_logits
+         else:
+             labels = self.labels[device]
+         return labels
+
+     def get_logits(self, image_features, text_features, logit_scale):
+         if self.world_size > 1:
+             all_image_features, all_text_features = self.gather_features(
+                 image_features, text_features,
+                 self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
+
+             if self.local_loss:
+                 logits_per_image = logit_scale * image_features @ all_text_features.T
+                 logits_per_text = logit_scale * text_features @ all_image_features.T
+             else:
+                 logits_per_image = logit_scale * all_image_features @ all_text_features.T
+                 logits_per_text = logits_per_image.T
+         else:
+             logits_per_image = logit_scale * image_features @ text_features.T
+             logits_per_text = logit_scale * text_features @ image_features.T
+
+         return logits_per_image, logits_per_text
+
+     def forward(self, image_features, text_features, logit_scale, output_dict=False):
+         device = image_features.device
+         logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
+
+         labels = self.get_ground_truth(device, logits_per_image.shape[0])
+
+         total_loss = (
+             F.cross_entropy(logits_per_image, labels) +
+             F.cross_entropy(logits_per_text, labels)
+         ) / 2
+
+         return {"contrastive_loss": total_loss} if output_dict else total_loss
+
+ class M3Patchilizer:
+     def __init__(self):
+         self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
+         self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')'
+         self.pad_token_id = 0
+         self.bos_token_id = 1
+         self.eos_token_id = 2
+         self.mask_token_id = 3
+
+     def split_bars(self, body):
+         bars = re.split(self.regexPattern, ''.join(body))
+         bars = list(filter(None, bars))  # remove empty strings
+         if bars[0] in self.delimiters:
+             bars[1] = bars[0] + bars[1]
+             bars = bars[1:]
+         bars = [bars[i * 2] + bars[i * 2 + 1] for i in range(len(bars) // 2)]
+         return bars
+
+     def bar2patch(self, bar, patch_size=PATCH_SIZE):
+         patch = [self.bos_token_id] + [ord(c) for c in bar] + [self.eos_token_id]
+         patch = patch[:patch_size]
+         patch += [self.pad_token_id] * (patch_size - len(patch))
+         return patch
+
+     def patch2bar(self, patch):
+         return ''.join(chr(idx) if idx > self.mask_token_id else '' for idx in patch)
+
+     def encode(self,
+                item,
+                patch_size=PATCH_SIZE,
+                add_special_patches=False,
+                truncate=False,
+                random_truncate=False):
+
+         item = unidecode(item)
+         lines = re.findall(r'.*?\n|.*$', item)
+         lines = list(filter(None, lines))  # remove empty lines
+
+         patches = []
+
+         if lines[0].split(" ")[0] == "ticks_per_beat":
+             patch = ""
+             for line in lines:
+                 if patch.startswith(line.split(" ")[0]) and (len(patch) + len(" ".join(line.split(" ")[1:])) <= patch_size-2):
+                     patch = patch[:-1] + "\t" + " ".join(line.split(" ")[1:])
+                 else:
+                     if patch:
+                         patches.append(patch)
+                     patch = line
+             if patch != "":
+                 patches.append(patch)
+         else:
+             for line in lines:
+                 if len(line) > 1 and ((line[0].isalpha() and line[1] == ':') or line.startswith('%%')):
+                     patches.append(line)
+                 else:
+                     bars = self.split_bars(line)
+                     if bars:
+                         bars[-1] += '\n'
+                         patches.extend(bars)
+
+         if add_special_patches:
+             bos_patch = chr(self.bos_token_id) * patch_size
+             eos_patch = chr(self.eos_token_id) * patch_size
+             patches = [bos_patch] + patches + [eos_patch]
+
+         if len(patches) > PATCH_LENGTH and truncate:
+             choices = ["head", "tail", "middle"]
+             choice = random.choice(choices)
+             if choice == "head" or not random_truncate:
+                 patches = patches[:PATCH_LENGTH]
+             elif choice == "tail":
+                 patches = patches[-PATCH_LENGTH:]
+             else:
+                 start = random.randint(1, len(patches)-PATCH_LENGTH)
+                 patches = patches[start:start+PATCH_LENGTH]
+
+         patches = [self.bar2patch(patch) for patch in patches]
+
+         return patches
+
+     def decode(self, patches):
+         return ''.join(self.patch2bar(patch) for patch in patches)
+
+ class M3PatchEncoder(PreTrainedModel):
+     def __init__(self, config):
+         super(M3PatchEncoder, self).__init__(config)
+         self.patch_embedding = torch.nn.Linear(PATCH_SIZE*128, M3_HIDDEN_SIZE)
+         torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
+         self.base = BertModel(config=config)
+         self.pad_token_id = 0
+         self.bos_token_id = 1
+         self.eos_token_id = 2
+         self.mask_token_id = 3
+
+     def forward(self,
+                 input_patches,  # [batch_size, seq_length, hidden_size]
+                 input_masks):   # [batch_size, seq_length]
+         # Transform input_patches into embeddings
+         input_patches = torch.nn.functional.one_hot(input_patches, num_classes=128)
+         input_patches = input_patches.reshape(len(input_patches), -1, PATCH_SIZE*128).type(torch.FloatTensor)
+         input_patches = self.patch_embedding(input_patches.to(self.device))
+
+         # Apply BERT model to input_patches and input_masks
+         return self.base(inputs_embeds=input_patches, attention_mask=input_masks)
+
+ class M3TokenDecoder(PreTrainedModel):
+     def __init__(self, config):
+         super(M3TokenDecoder, self).__init__(config)
+         self.base = GPT2LMHeadModel(config=config)
+         self.pad_token_id = 0
+         self.bos_token_id = 1
+         self.eos_token_id = 2
+         self.mask_token_id = 3
+
+     def forward(self,
+                 patch_features,   # [batch_size, hidden_size]
+                 target_patches):  # [batch_size, seq_length]
+         # get input embeddings
+         inputs_embeds = torch.nn.functional.embedding(target_patches, self.base.transformer.wte.weight)
+
+         # concatenate the encoded patches with the input embeddings
+         inputs_embeds = torch.cat((patch_features.unsqueeze(1), inputs_embeds[:, 1:, :]), dim=1)
+
+         # preparing the labels for model training
+         target_masks = target_patches == self.pad_token_id
+         target_patches = target_patches.clone().masked_fill_(target_masks, -100)
+
+         # get the attention mask
+         target_masks = ~target_masks
+         target_masks = target_masks.type(torch.int)
+
+         return self.base(inputs_embeds=inputs_embeds,
+                          attention_mask=target_masks,
+                          labels=target_patches)
+
+     def generate(self,
+                  patch_feature,
+                  tokens):
+         # reshape the patch_feature and tokens
+         patch_feature = patch_feature.reshape(1, 1, -1)
+         tokens = tokens.reshape(1, -1)
+
+         # get input embeddings
+         tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)
+
+         # concatenate the encoded patches with the input embeddings
+         tokens = torch.cat((patch_feature, tokens[:, 1:, :]), dim=1)
+
+         # get the outputs from the model
+         outputs = self.base(inputs_embeds=tokens)
+
+         # get the probabilities of the next token
+         probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)
+
+         return probs.detach().cpu().numpy()
+
+ class M3Model(PreTrainedModel):
+     def __init__(self, encoder_config, decoder_config):
+         super(M3Model, self).__init__(encoder_config)
+         self.encoder = M3PatchEncoder(encoder_config)
+         self.decoder = M3TokenDecoder(decoder_config)
+         self.pad_token_id = 0
+         self.bos_token_id = 1
+         self.eos_token_id = 2
+         self.mask_token_id = 3
+
+     def forward(self,
+                 input_patches,     # [batch_size, seq_length, hidden_size]
+                 input_masks,       # [batch_size, seq_length]
+                 selected_indices,  # [batch_size, seq_length]
+                 target_patches):   # [batch_size, seq_length, hidden_size]
+         input_patches = input_patches.reshape(len(input_patches), -1, PATCH_SIZE).to(self.device)
+         input_masks = input_masks.to(self.device)
+         selected_indices = selected_indices.to(self.device)
+         target_patches = target_patches.reshape(len(target_patches), -1, PATCH_SIZE).to(self.device)
+
+         # Pass the input_patches and input_masks through the encoder
+         outputs = self.encoder(input_patches, input_masks)["last_hidden_state"]
+
+         # Use selected_indices to form target_patches
+         target_patches = target_patches[selected_indices.bool()]
+         patch_features = outputs[selected_indices.bool()]
+
+         # Pass patch_features and target_patches through the decoder
+         return self.decoder(patch_features, target_patches)
+
+ class CLaMP2Model(PreTrainedModel):
+     def __init__(self,
+                  music_config,
+                  global_rank=None,
+                  world_size=None,
+                  text_model_name=TEXT_MODEL_NAME,
+                  hidden_size=CLAMP2_HIDDEN_SIZE,
+                  load_m3=CLAMP2_LOAD_M3):
+         super(CLaMP2Model, self).__init__(music_config)
+
+         self.text_model = AutoModel.from_pretrained(text_model_name)  # Load the text model
+         self.text_proj = torch.nn.Linear(self.text_model.config.hidden_size, hidden_size)  # Linear layer for text projections
+         torch.nn.init.normal_(self.text_proj.weight, std=0.02)  # Initialize weights with normal distribution
+
+         self.music_model = M3PatchEncoder(music_config)  # Initialize the music model
+         self.music_proj = torch.nn.Linear(M3_HIDDEN_SIZE, hidden_size)  # Linear layer for music projections
+         torch.nn.init.normal_(self.music_proj.weight, std=0.02)  # Initialize weights with normal distribution
+
+         if global_rank is None or world_size is None:
+             global_rank = 0
+             world_size = 1
+
+         self.loss_fn = ClipLoss(local_loss=False,
+                                 gather_with_grad=True,
+                                 cache_labels=False,
+                                 rank=global_rank,
+                                 world_size=world_size,
+                                 use_horovod=False)
+
+         if load_m3 and os.path.exists(M3_WEIGHTS_PATH):
+             checkpoint = torch.load(M3_WEIGHTS_PATH, map_location='cpu', weights_only=True)
+             decoder_config = GPT2Config(vocab_size=128,
+                                         n_positions=PATCH_SIZE,
+                                         n_embd=M3_HIDDEN_SIZE,
+                                         n_layer=TOKEN_NUM_LAYERS,
+                                         n_head=M3_HIDDEN_SIZE//64,
+                                         n_inner=M3_HIDDEN_SIZE*4)
+             model = M3Model(music_config, decoder_config)
+             model.load_state_dict(checkpoint['model'])
+             self.music_model = model.encoder
+             model = None
+             print(f"Successfully Loaded M3 Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}")
+
+     def avg_pooling(self, input_features, input_masks):
+         input_masks = input_masks.unsqueeze(-1).to(self.device)  # add a dimension to match the feature dimension
+         input_features = input_features * input_masks  # apply mask to input_features
+         avg_pool = input_features.sum(dim=1) / input_masks.sum(dim=1)  # calculate average pooling
+
+         return avg_pool
+
+     def get_text_features(self,
+                           text_inputs,
+                           text_masks,
+                           get_normalized=False):
+         text_features = self.text_model(text_inputs.to(self.device),
+                                         attention_mask=text_masks.to(self.device))['last_hidden_state']
+
+         if get_normalized:
+             text_features = self.avg_pooling(text_features, text_masks)
+             text_features = self.text_proj(text_features)
+
+         return text_features
+
+     def get_music_features(self,
+                            music_inputs,
+                            music_masks,
+                            get_normalized=False):
+         music_features = self.music_model(music_inputs.to(self.device),
+                                           music_masks.to(self.device))['last_hidden_state']
+
+         if get_normalized:
+             music_features = self.avg_pooling(music_features, music_masks)
+             music_features = self.music_proj(music_features)
+
+         return music_features
+
+     def forward(self,
+                 text_inputs,    # [batch_size, seq_length]
+                 text_masks,     # [batch_size, seq_length]
+                 music_inputs,   # [batch_size, seq_length, hidden_size]
+                 music_masks):   # [batch_size, seq_length]
+         # Compute the text features
+         text_features = self.get_text_features(text_inputs, text_masks, get_normalized=True)
+
+         # Compute the music features
+         music_features = self.get_music_features(music_inputs, music_masks, get_normalized=True)
+
+         return self.loss_fn(text_features,
+                             music_features,
+                             LOGIT_SCALE,
+                             output_dict=False)
+
+ def split_data(data, eval_ratio=EVAL_SPLIT):
+     random.shuffle(data)
+     split_idx = int(len(data)*eval_ratio)
+     eval_set = data[:split_idx]
+     train_set = data[split_idx:]
+     return train_set, eval_set
+
+ def mask_patches(target_patches, patchilizer, mode):
+     indices = list(range(len(target_patches)))
+     random.shuffle(indices)
+     selected_indices = indices[:math.ceil(M3_MASK_RATIO*len(indices))]
+     sorted_indices = sorted(selected_indices)
+     input_patches = torch.tensor(target_patches)
+
+     if mode == "eval":
+         choice = "original"
+     else:
+         choice = random.choices(["mask", "shuffle", "original"], weights=[0.8, 0.1, 0.1])[0]
+
+     if choice == "mask":
+         input_patches[sorted_indices] = torch.tensor([patchilizer.mask_token_id]*PATCH_SIZE)
+     elif choice == "shuffle":
+         for idx in sorted_indices:
+             patch = input_patches[idx]
+             try:
+                 index_eos = (patch == patchilizer.eos_token_id).nonzero().item()
+             except Exception:
+                 index_eos = len(patch)
+
+             indices = list(range(1, index_eos))
+             random.shuffle(indices)
+             indices = [0] + indices + list(range(index_eos, len(patch)))
+             input_patches[idx] = patch[indices]
+
+     selected_indices = torch.zeros(len(target_patches))
+     selected_indices[sorted_indices] = 1.
+
+     return input_patches, selected_indices
+
+ def remove_instrument_info(item):
+     # remove instrument information from symbolic music
+     lines = re.findall(r'.*?\n|.*$', item)
+     lines = list(filter(None, lines))
+     if lines[0].split(" ")[0] == "ticks_per_beat":
+         type = "mtf"
+     else:
+         type = "abc"
+
+     cleaned_lines = []
+     for line in lines:
+         if type == "abc" and line.startswith("V:"):
+             # find the position of " nm=" or " snm="
+             nm_pos = line.find(" nm=")
+             snm_pos = line.find(" snm=")
+             # keep the part before " nm=" or " snm="
+             if nm_pos != -1:
+                 line = line[:nm_pos]
+             elif snm_pos != -1:
+                 line = line[:snm_pos]
+             if nm_pos != -1 or snm_pos != -1:
+                 line += "\n"
+         elif type == "mtf" and line.startswith("program_change"):
+             line = " ".join(line.split(" ")[:-1]) + " 0\n"
+
+         cleaned_lines.append(line)
+
+     return ''.join(cleaned_lines)
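A minimal round-trip sketch for `M3Patchilizer`, assuming it is run from the `code/` folder so that `utils.py` and `config.py` are importable (the ABC snippet is illustrative):

```python
from utils import M3Patchilizer

patchilizer = M3Patchilizer()
abc = "X:1\nM:4/4\nK:C\nCDEF GABc | cBAG FEDC |]\n"

patches = patchilizer.encode(abc, add_special_patches=True)
print(len(patches))                 # 7 here: BOS + 3 header lines + 2 bars + EOS
print(patchilizer.decode(patches))  # special patches decode to '', so the text round-trips
```

Note that each bar is stored as raw character codes (`ord`/`chr`), so any bar longer than `PATCH_SIZE - 2` characters is truncated and would not round-trip exactly.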
environment.yml ADDED
@@ -0,0 +1,204 @@
+ name: clamp2
+ channels:
+   - pytorch
+   - nvidia
+   - defaults
+ dependencies:
+   - blas=1.0=mkl
+   - brotli-python=1.0.9=py310hd77b12b_8
+   - bzip2=1.0.8=h2bbff1b_6
+   - ca-certificates=2024.7.2=haa95532_0
+   - cuda-cccl=12.6.37=0
+   - cuda-cccl_win-64=12.6.37=0
+   - cuda-cudart=11.8.89=0
+   - cuda-cudart-dev=11.8.89=0
+   - cuda-cupti=11.8.87=0
+   - cuda-libraries=11.8.0=0
+   - cuda-libraries-dev=11.8.0=0
+   - cuda-nvrtc=11.8.89=0
+   - cuda-nvrtc-dev=11.8.89=0
+   - cuda-nvtx=11.8.86=0
+   - cuda-profiler-api=12.6.68=0
+   - cuda-runtime=11.8.0=0
+   - cuda-version=12.6=3
+   - freetype=2.12.1=ha860e81_0
+   - gmpy2=2.1.2=py310h7f96b67_0
+   - intel-openmp=2023.1.0=h59b6b97_46320
+   - jinja2=3.1.4=py310haa95532_0
+   - jpeg=9e=h827c3e9_3
+   - lcms2=2.12=h83e58a3_0
+   - lerc=3.0=hd77b12b_0
+   - libcublas=11.11.3.6=0
+   - libcublas-dev=11.11.3.6=0
+   - libcufft=10.9.0.58=0
+   - libcufft-dev=10.9.0.58=0
+   - libcurand=10.3.7.68=0
+   - libcurand-dev=10.3.7.68=0
+   - libcusolver=11.4.1.48=0
+   - libcusolver-dev=11.4.1.48=0
+   - libcusparse=11.7.5.86=0
+   - libcusparse-dev=11.7.5.86=0
+   - libdeflate=1.17=h2bbff1b_1
+   - libffi=3.4.4=hd77b12b_1
+   - libjpeg-turbo=2.0.0=h196d8e1_0
+   - libnpp=11.8.0.86=0
+   - libnpp-dev=11.8.0.86=0
+   - libnvjpeg=11.9.0.86=0
+   - libnvjpeg-dev=11.9.0.86=0
+   - libpng=1.6.39=h8cc25b3_0
+   - libtiff=4.5.1=hd77b12b_0
+   - libuv=1.48.0=h827c3e9_0
+   - libwebp-base=1.3.2=h2bbff1b_0
+   - lz4-c=1.9.4=h2bbff1b_1
+   - mkl=2023.1.0=h6b88ed4_46358
+   - mkl-service=2.4.0=py310h2bbff1b_1
+   - mkl_fft=1.3.8=py310h2bbff1b_0
+   - mkl_random=1.2.4=py310h59b6b97_0
+   - mpc=1.1.0=h7edee0f_1
+   - mpfr=4.0.2=h62dcd97_1
+   - mpir=3.0.0=hec2e145_1
+   - mpmath=1.3.0=py310haa95532_0
+   - networkx=3.3=py310haa95532_0
+   - numpy=1.26.4=py310h055cbcc_0
+   - numpy-base=1.26.4=py310h65a83cf_0
+   - openjpeg=2.5.2=hae555c5_0
+   - openssl=3.0.14=h827c3e9_0
+   - pip=24.2=py310haa95532_0
+   - pysocks=1.7.1=py310haa95532_0
+   - python=3.10.14=he1021f5_1
+   - pytorch=2.4.0=py3.10_cuda11.8_cudnn9_0
+   - pytorch-cuda=11.8=h24eeafa_5
+   - pytorch-mutex=1.0=cuda
+   - pyyaml=6.0.1=py310h2bbff1b_0
+   - requests=2.32.3=py310haa95532_0
+   - setuptools=72.1.0=py310haa95532_0
+   - sqlite=3.45.3=h2bbff1b_0
+   - sympy=1.13.2=py310haa95532_0
+   - tbb=2021.8.0=h59b6b97_0
+   - tk=8.6.14=h0416ee5_0
+   - typing_extensions=4.11.0=py310haa95532_0
+   - tzdata=2024a=h04d1e81_0
+   - vc=14.40=h2eaa2aa_0
+   - vs2015_runtime=14.40.33807=h98bb1dd_0
+   - wheel=0.43.0=py310haa95532_0
+   - win_inet_pton=1.1.0=py310haa95532_0
+   - xz=5.4.6=h8cc25b3_1
+   - yaml=0.2.5=he774522_0
+   - zlib=1.2.13=h8cc25b3_1
+   - zstd=1.5.5=hd43e919_2
+   - pip:
+     - abctoolkit==0.0.4
+     - accelerate==0.34.0
+     - aiohappyeyeballs==2.4.0
+     - aiohttp==3.10.5
+     - aiosignal==1.3.1
+     - annotated-types==0.7.0
+     - anyio==4.6.2.post1
+     - async-timeout==4.0.3
+     - attrs==24.2.0
+     - audioread==3.0.1
+     - certifi==2023.7.22
+     - cffi==1.17.0
+     - chardet==5.2.0
+     - charset-normalizer==3.2.0
+     - click==8.1.7
+     - colorama==0.4.6
+     - coloredlogs==15.0.1
+     - cycler==0.11.0
+     - datasets==2.21.0
+     - decorator==5.1.1
+     - dill==0.3.8
+     - distro==1.9.0
+     - docker-pycreds==0.4.0
+     - exceptiongroup==1.2.2
+     - filelock==3.12.2
+     - fonttools==4.38.0
+     - frozenlist==1.4.1
+     - fsspec==2024.6.1
+     - gitdb==4.0.11
+     - gitpython==3.1.43
+     - h11==0.14.0
+     - httpcore==1.0.6
+     - httpx==0.27.2
+     - huggingface-hub==0.24.6
+     - humanfriendly==10.0
+     - idna==3.4
+     - importlib-metadata==6.7.0
+     - jellyfish==1.0.0
+     - jiter==0.6.1
+     - joblib==1.3.2
+     - jsonpickle==3.0.2
+     - kiwisolver==1.4.4
+     - langcodes==3.4.0
+     - langdetect==1.0.9
+     - langid==1.1.6
+     - language-data==1.2.0
+     - lazy-loader==0.4
+     - levenshtein==0.25.1
+     - librosa==0.10.1
+     - llvmlite==0.43.0
+     - lxml==5.3.0
+     - marisa-trie==1.2.0
+     - markupsafe==2.1.5
+     - matplotlib==3.5.3
+     - mido==1.3.0
+     - more-itertools==9.1.0
+     - msgpack==1.0.8
+     - multidict==6.0.5
+     - multiprocess==0.70.16
+     - music21==7.3.3
+     - nltk==3.8.1
+     - numba==0.60.0
+     - openai==1.51.2
+     - optimum==1.21.4
+     - packaging==23.1
+     - pandas==1.3.5
+     - pillow==9.5.0
+     - platformdirs==4.2.2
+     - pooch==1.8.2
+     - portalocker==2.10.1
+     - protobuf==5.28.0
+     - psutil==6.0.0
+     - pyarrow==17.0.0
+     - pycparser==2.22
+     - pydantic==2.9.2
+     - pydantic-core==2.23.4
+     - pydub==0.25.1
+     - pyparsing==3.1.1
+     - pyreadline3==3.4.1
+     - python-dateutil==2.8.2
+     - pytz==2023.3
+     - pywin32==306
+     - rapidfuzz==3.9.7
+     - rarfile==4.1
+     - regex==2023.8.8
+     - sacrebleu==2.4.3
+     - sacremoses==0.0.53
+     - safetensors==0.4.4
+     - samplings==0.1.7
+     - scikit-learn==1.5.1
+     - scipy==1.14.1
+     - sentencepiece==0.2.0
+     - sentry-sdk==2.13.0
+     - setproctitle==1.3.3
+     - six==1.16.0
+     - smmap==5.0.1
+     - sniffio==1.3.1
+     - soundfile==0.12.1
+     - soxr==0.5.0.post1
+     - tabulate==0.9.0
+     - threadpoolctl==3.5.0
+     - tokenizers==0.19.1
+     - torch==2.4.0
+     - torchaudio==2.4.0
+     - torchvision==0.19.0
+     - tqdm==4.66.5
+     - transformers==4.40.0
+     - typing-extensions==4.12.2
+     - unidecode==1.3.6
+     - urllib3==2.0.4
+     - wandb==0.17.8
+     - webcolors==1.13
+     - xxhash==3.5.0
+     - yarl==1.9.7
+     - zipp==3.15.0
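Note: the pinned conda builds above (e.g. `vs2015_runtime`, `pywin32`, `win_inet_pton`) target Windows (win-64) with CUDA 11.8, so `conda env create -f environment.yml` is expected to work as-is only on a Windows machine. On other platforms, installing the `pip:` section into a Python 3.10 environment with PyTorch 2.4.0 is a reasonable starting point.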
music_classification/README.md ADDED
@@ -0,0 +1,46 @@
+ # Music Classification Codebase
+
+ ## Overview
+ A linear probe is a lightweight classifier trained on top of frozen feature representations for supervised learning tasks. This codebase includes scripts for training a linear classification model and for classifying new feature data. The features can be extracted from the M3 or CLaMP 2 models; they must preserve the time dimension and must **not** be normalized. Below is a description of the scripts contained in the `music_classification/` folder.
+
+ ## Repository Structure
+ The `music_classification/` folder contains the following scripts:
+
+ ### 1. `config.py`
+ This script defines configurations for linear probe training and inference, specifying training data paths and parameters like learning rate, number of epochs, and hidden size.
+
+ ### 2. `inference_cls.py`
+ This script enables the classification of feature vectors using a pre-trained linear probe model.
+
+ #### JSON Output Format
+ The resulting JSON file contains a dictionary with the following structure:
+ ```json
+ {
+     "path/to/feature1.npy": "class_A",
+     "path/to/feature2.npy": "class_B",
+     "path/to/feature3.npy": "class_A"
+ }
+ ```
+ - **Key**: The path to the input feature file (e.g., `feature1.npy`).
+ - **Value**: The predicted class label assigned by the linear probe model (e.g., `class_A`).
+
+ #### Usage
+ ```bash
+ python inference_cls.py <feature_folder> <output_file>
+ ```
+ - `feature_folder`: Directory containing input feature files (in `.npy` format).
+ - `output_file`: File path to save the classification results (in JSON format).
+
+ ### 3. `train_cls.py`
+ This script is designed for training the linear classification model.
+
+ #### Usage
+ ```bash
+ python train_cls.py
+ ```
+
+ ### 4. `utils.py`
+ The utility script defines the architecture of the linear classification model.
+
+ ## Naming Convention
+ All `.npy` files used in this codebase must follow the naming convention `label_filename.npy`, where the filename should not contain any underscores (`_`); everything before the first underscore is read back as the class label, as shown in the sketch below.
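A minimal sketch of the parsing rule, mirroring the label extraction in `train_cls.py` (the file path here is hypothetical):

```python
import os

path = "features/jazz_track001.npy"           # hypothetical feature file
label = os.path.basename(path).split('_')[0]  # everything before the first "_"
print(label)                                  # -> "jazz"
```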
music_classification/config.py ADDED
@@ -0,0 +1,26 @@
+ # Configuration for generative modelling and classification
+ TRAIN_FOLDERS = [
+     "<path_to_training_data>"  # Directory containing training data
+ ]
+
+ EVAL_FOLDERS = [
+     ""  # (Optional) Directory containing evaluation data
+ ]
+
+ EVAL_SPLIT = 0.2  # Fraction of training data to use for evaluation
+
+ # Weights and Biases configuration
+ WANDB_KEY = "<your_wandb_key>"  # Set WANDB_LOG=False below if you have no API key for Weights and Biases logging
+
+ # Model Configuration
+ INPUT_HIDDEN_SIZE = 768  # Input hidden size
+ HIDDEN_SIZE = 768  # Model hidden size
+ NUM_EPOCHS = 1000  # Max number of epochs to train (early stopping can terminate earlier)
+ LEARNING_RATE = 1e-5  # Optimizer learning rate
+ BALANCED_TRAINING = False  # Set to True to balance labels in training data
+ WANDB_LOG = False  # Set to True to log training metrics to WANDB
+
+ # Paths Configuration
+ last_folder_name = TRAIN_FOLDERS[-1].split('/')[-1]
+ WEIGHTS_PATH = f"weights-{last_folder_name}.pth"  # Weights file path
+ LOGS_PATH = f"logs-{last_folder_name}.txt"  # Log file path
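To illustrate how the derived paths behave, a quick check (the folder name is hypothetical):

```python
TRAIN_FOLDERS = ["features/genre"]                   # hypothetical training folder
last_folder_name = TRAIN_FOLDERS[-1].split('/')[-1]  # -> "genre"
print(f"weights-{last_folder_name}.pth")             # -> weights-genre.pth
print(f"logs-{last_folder_name}.txt")                # -> logs-genre.txt
```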
music_classification/inference_cls.py ADDED
@@ -0,0 +1,71 @@
+ import os
+ import json
+ import torch
+ import random
+ import argparse
+ import numpy as np
+ from utils import *
+ from tqdm import tqdm
+ from samplings import *
+
+ def list_files_in_directory(directories, extensions=["npy"]):
+     file_list = []
+
+     for directory in directories:
+         for root, dirs, files in os.walk(directory):
+             for file in files:
+                 if any(file.endswith(ext) for ext in extensions):
+                     file_path = os.path.join(root, file)
+                     file_list.append(file_path)
+
+     return file_list
+
+ if __name__ == "__main__":
+     # Setup argument parser
+     parser = argparse.ArgumentParser(description="Feature extraction and classification with CLaMP2.")
+     parser.add_argument("feature_folder", type=str, help="Directory containing input feature files.")
+     parser.add_argument("output_file", type=str, help="File to save the classification results. (format: json)")
+
+     # Parse arguments
+     args = parser.parse_args()
+     feature_folder = args.feature_folder
+     output_file = args.output_file
+
+     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+     seed = 42
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+     checkpoint = torch.load(WEIGHTS_PATH, map_location='cpu')
+     print(f"Successfully Loaded Checkpoint from Epoch {checkpoint['epoch']} with acc {checkpoint['max_eval_acc']}")
+     label2idx = checkpoint['labels']
+     idx2label = {idx: label for label, idx in label2idx.items()}  # Create reverse mapping
+     model = LinearClassification(num_classes=len(label2idx))
+     model = model.to(device)
+
+     # print parameter number
+     print("Parameter Number: "+str(sum(p.numel() for p in model.parameters() if p.requires_grad)))
+
+     model.eval()
+     model.load_state_dict(checkpoint['model'])
+
+     # load filenames under the feature folder
+     feature_files = list_files_in_directory([feature_folder])
+     cls_results = {}
+
+     for filepath in tqdm(feature_files):
+         outputs = np.load(filepath)[0]
+         outputs = torch.from_numpy(outputs).to(device)
+         outputs = outputs.unsqueeze(0)
+         cls_list = model(outputs)[0].tolist()
+         max_prob = max(cls_list)
+         cls_idx = cls_list.index(max_prob)
+         cls_label = idx2label[cls_idx]
+         cls_results[filepath] = cls_label
+
+     with open(output_file, "w", encoding="utf-8") as f:
+         json.dump(cls_results, f)
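Since the loop above indexes `np.load(filepath)[0]`, each `.npy` file is expected to carry a leading batch-like axis. A hypothetical stand-in feature file for a smoke test; the shape `(1, time_steps, 768)` is an assumption based on `INPUT_HIDDEN_SIZE = 768` in `config.py` and the un-normalized, time-preserving features described in the README:

```python
import numpy as np

# Hypothetical feature: (batch, time_steps, hidden); real shapes come from
# the extraction scripts and may differ.
feat = np.random.randn(1, 12, 768).astype(np.float32)
np.save("classA_demo001.npy", feat)  # name follows the label_filename.npy convention
```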
music_classification/train_cls.py ADDED
@@ -0,0 +1,293 @@
+ import os
+ import time
+ import math
+ import wandb
+ import torch
+ import random
+ import numpy as np
+ from utils import *
+ from config import *
+ from tqdm import tqdm
+ from sklearn.metrics import f1_score
+ from torch.amp import autocast, GradScaler
+ from torch.utils.data import Dataset, DataLoader
+ from transformers import get_constant_schedule_with_warmup
+ import torch.distributed as dist
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ from torch.utils.data.distributed import DistributedSampler
+
+ # Set up distributed training
+ world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
+ global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else 0
+ local_rank = int(os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else 0
+
+ if world_size > 1:
+     torch.cuda.set_device(local_rank)
+     device = torch.device("cuda", local_rank)
+     dist.init_process_group(backend='nccl')
+ else:
+     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+ # Set random seed (offset by rank so each process shuffles differently)
+ seed = 42 + global_rank
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+ batch_size = 1
+
+ def collate_batch(input_tensors):
+     input_tensors, labels = zip(*input_tensors)
+     input_tensors = torch.stack(input_tensors, dim=0)
+     labels = torch.stack(labels, dim=0)
+
+     return input_tensors.to(device), labels.to(device)
+
+ def list_files_in_directory(directories):
+     file_list = []
+
+     for directory in directories:
+         for root, dirs, files in os.walk(directory):
+             for file in files:
+                 if file.endswith(".npy"):
+                     file_path = os.path.join(root, file)
+                     file_list.append(file_path)
+     return file_list
+
+ class TensorDataset(Dataset):
+     def __init__(self, filenames):
+         print(f"Loading {len(filenames)} files for classification")
+         self.filenames = []
+         self.label2idx = {}
+
+         for filename in tqdm(filenames):
+             label = os.path.basename(filename).split('_')[0]
+
+             self.filenames.append(filename)
+             if label not in self.label2idx:
+                 self.label2idx[label] = len(self.label2idx)
+         print(f"Found {len(self.label2idx)} classes")
+
+     def __len__(self):
+         return len(self.filenames)
+
+     def __getitem__(self, idx):
+         filename = self.filenames[idx]
+         label = os.path.basename(filename).split('_')[0]
+         label = self.label2idx[label]
+
+         # Load the numpy feature file
+         data = np.load(filename)
+         data = torch.from_numpy(data)[0]
+         label = torch.tensor(label)
+
+         return data, label
+
+ class BalancedTensorDataset(Dataset):
+     def __init__(self, filenames):
+         print(f"Loading {len(filenames)} files for classification")
+         self.filenames = filenames
+         self.label2idx = {}
+         self.label2files = {}
+
+         for filename in tqdm(filenames):
+             label = os.path.basename(filename).split('_')[0]
+             if label not in self.label2idx:
+                 self.label2idx[label] = len(self.label2idx)
+             if label not in self.label2files:
+                 self.label2files[label] = []
+             self.label2files[label].append(filename)
+         print(f"Found {len(self.label2idx)} classes")
+
+         self.min_samples = min(len(files) for files in self.label2files.values())
+
+         self._update_epoch_filenames()
+
+     def _update_epoch_filenames(self):
+         # Resample an equal number of files per label for the coming epoch
+         self.epoch_filenames = []
+         for label, files in self.label2files.items():
+             sampled_files = random.sample(files, self.min_samples)
+             self.epoch_filenames.extend(sampled_files)
+
+         random.shuffle(self.epoch_filenames)
+
+     def __len__(self):
+         return len(self.epoch_filenames)
+
+     def __getitem__(self, idx):
+         filename = self.epoch_filenames[idx]
+         label = os.path.basename(filename).split('_')[0]
+         label = self.label2idx[label]
+
+         data = np.load(filename)
+         data = torch.from_numpy(data)[0]
+         label = torch.tensor(label)
+
+         return data, label
+
+     def on_epoch_end(self):
+         self._update_epoch_filenames()
+
+ # Load filenames under the train and eval folders
+ train_files = list_files_in_directory(TRAIN_FOLDERS)
+ eval_files = list_files_in_directory(EVAL_FOLDERS)
+
+ if len(eval_files) == 0:
+     random.shuffle(train_files)
+     eval_files = train_files[:math.ceil(len(train_files)*EVAL_SPLIT)]
+     train_files = train_files[math.ceil(len(train_files)*EVAL_SPLIT):]
+ if BALANCED_TRAINING:
+     train_set = BalancedTensorDataset(train_files)
+ else:
+     train_set = TensorDataset(train_files)
+ eval_set = TensorDataset(eval_files)
+ eval_set.label2idx = train_set.label2idx
+
+ model = LinearClassification(num_classes=len(train_set.label2idx))
+ model = model.to(device)
+
+ # Print the number of trainable parameters
+ print("Parameter Number: " + str(sum(p.numel() for p in model.parameters() if p.requires_grad)))
+
+ if world_size > 1:
+     model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
+
+ scaler = GradScaler()
+ is_autocast = True
+ optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
+ loss_fn = torch.nn.CrossEntropyLoss()
+
+ # Call the model with a batch of input
+ def process_one_batch(batch):
+     input_tensors, labels = batch
+     logits = model(input_tensors)
+     loss = loss_fn(logits, labels)
+     prediction = torch.argmax(logits, dim=1)
+     acc_num = torch.sum(prediction == labels)
+
+     return loss, acc_num, prediction, labels
+
+ # Do one epoch of training
+ def train_epoch():
+     tqdm_train_set = tqdm(train_set)
+     total_train_loss = 0
+     total_acc_num = 0
+     iter_idx = 1
+     model.train()
+
+     for batch in tqdm_train_set:
+         if is_autocast:
+             with autocast(device_type='cuda'):
+                 loss, acc_num, prediction, labels = process_one_batch(batch)
+             scaler.scale(loss).backward()
+             scaler.step(optimizer)
+             scaler.update()
+         else:
+             loss, acc_num, prediction, labels = process_one_batch(batch)
+             loss.backward()
+             optimizer.step()
+
+         lr_scheduler.step()
+         model.zero_grad(set_to_none=True)
+         total_train_loss += loss.item()
+         total_acc_num += acc_num.item()
+         tqdm_train_set.set_postfix({str(global_rank)+'_train_acc': total_acc_num / (iter_idx*batch_size)})
+         # Log the training accuracy to wandb
+         if global_rank == 0 and WANDB_LOG:
+             wandb.log({"acc": total_acc_num / (iter_idx*batch_size)})
+
+         iter_idx += 1
+
+     if BALANCED_TRAINING:
+         train_set.dataset.on_epoch_end()
+
+     return total_acc_num / ((iter_idx-1)*batch_size)
+
+ # Do one epoch of evaluation
+ def eval_epoch():
+     tqdm_eval_set = tqdm(eval_set)
+     total_eval_loss = 0
+     total_acc_num = 0
+     iter_idx = 1
+     model.eval()
+
+     all_predictions = []
+     all_labels = []
+
+     # Evaluate data for one epoch
+     for batch in tqdm_eval_set:
+         with torch.no_grad():
+             loss, acc_num, prediction, labels = process_one_batch(batch)
+             total_eval_loss += loss.item()
+             total_acc_num += acc_num.item()
+
+             # Accumulate predictions and labels
+             all_predictions.extend(prediction.cpu().numpy())
+             all_labels.extend(labels.cpu().numpy())
+
+         tqdm_eval_set.set_postfix({str(global_rank)+'_eval_acc': total_acc_num / (iter_idx*batch_size)})
+         iter_idx += 1
+
+     # Compute macro F1
+     f1_macro = f1_score(all_labels, all_predictions, average='macro')
+     return total_acc_num / ((iter_idx - 1) * batch_size), f1_macro
+
+ # Train and eval
+ if __name__ == "__main__":
+
+     label2idx = train_set.label2idx
+     max_eval_acc = 0
+     best_epoch = 0
+     train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=global_rank)
+     eval_sampler = DistributedSampler(eval_set, num_replicas=world_size, rank=global_rank)
+
+     train_set = DataLoader(train_set, batch_size=batch_size, collate_fn=collate_batch, sampler=train_sampler, shuffle=(train_sampler is None))
+     eval_set = DataLoader(eval_set, batch_size=batch_size, collate_fn=collate_batch, sampler=eval_sampler, shuffle=(eval_sampler is None))
+
+     # The scheduler must wrap the same optimizer instance used in the training loop
+     lr_scheduler = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=len(train_set))
+
+     if WANDB_LOG and global_rank == 0:
+         # Initialize wandb
+         if WANDB_KEY:
+             wandb.login(key=WANDB_KEY)
+         wandb.init(project="linear",
+                    name=WEIGHTS_PATH.replace("weights-", "").replace(".pth", ""))
+
+     for epoch in range(1, NUM_EPOCHS+1):
+         train_sampler.set_epoch(epoch)
+         eval_sampler.set_epoch(epoch)
+         print('-' * 21 + "Epoch " + str(epoch) + '-' * 21)
+         train_acc = train_epoch()
+         eval_acc, eval_f1_macro = eval_epoch()
+         if global_rank == 0:
+             with open(LOGS_PATH, 'a') as f:
+                 f.write("Epoch " + str(epoch) + "\ntrain_acc: " + str(train_acc) + "\neval_acc: " + str(eval_acc) + "\neval_f1_macro: " + str(eval_f1_macro) + "\ntime: " + time.asctime(time.localtime(time.time())) + "\n\n")
+             if eval_acc > max_eval_acc:
+                 best_epoch = epoch
+                 max_eval_acc = eval_acc
+                 checkpoint = {
+                     'model': model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
+                     'optimizer': optimizer.state_dict(),
+                     'lr_sched': lr_scheduler.state_dict(),
+                     'epoch': epoch,
+                     'best_epoch': best_epoch,
+                     'max_eval_acc': max_eval_acc,
+                     "labels": label2idx
+                 }
+                 torch.save(checkpoint, WEIGHTS_PATH)
+                 with open(LOGS_PATH, 'a') as f:
+                     f.write("Best Epoch so far!\n\n\n")
+
+         if world_size > 1:
+             dist.barrier()
+
+     if global_rank == 0:
+         print("Best Eval Epoch : " + str(best_epoch))
+         print("Max Eval Accuracy : " + str(max_eval_acc))
music_classification/utils.py ADDED
@@ -0,0 +1,22 @@
+ import torch
+ from config import *
+
+ class LinearClassification(torch.nn.Module):
+     def __init__(self, num_classes):
+         super(LinearClassification, self).__init__()
+         self.fc1 = torch.nn.Linear(INPUT_HIDDEN_SIZE, HIDDEN_SIZE)
+         self.relu = torch.nn.ReLU()
+         self.fc2 = torch.nn.Linear(HIDDEN_SIZE, num_classes)
+
+     def forward(self, x):
+         # Apply the linear layer and ReLU to each time step
+         x = self.fc1(x)  # x shape (B, L, H) -> (B, L, hidden_size)
+         x = self.relu(x)
+
+         # Average over the time steps (L dimension)
+         x = x.mean(dim=1)  # Now x has shape (B, hidden_size)
+
+         # Return raw logits: CrossEntropyLoss in train_cls.py expects
+         # unnormalized scores, and inference_cls.py takes the argmax.
+         x = self.fc2(x)  # (B, hidden_size) -> (B, num_classes)
+         return x
process_data/README.md ADDED
@@ -0,0 +1,307 @@
+ # Data Processing Codebase
+
+ ## Overview
+ This codebase contains scripts and utilities for converting between various musical data formats, including ABC notation, MusicXML, MIDI, and MTF (MIDI Text Format). It also includes a script that summarizes music metadata, represented as JSON files of textual information, using the OpenAI GPT-4 API: the model turns this metadata into concise summaries in multiple languages to boost multilingual MIR. Together, these tools facilitate the transformation and manipulation of musical files and provide concise multilingual summaries of music metadata for use with CLaMP 2.
+
+ ## About ABC notation
+ ### Standard ABC Notation
+ ABC notation is a text-based representation of sheet music. Like stave notation, it is theory-oriented, which makes it ideal for presenting complex musical concepts to musicians for study and analysis. Standard ABC notation encodes each voice separately, which often places corresponding bars far apart. This separation makes it difficult for models to accurately understand the interactions between voices that are meant to align musically.
+
+ Example standard ABC notation representation:
+ ```
+ %%score { 1 | 2 }
+ L:1/8
+ Q:1/4=120
+ M:3/4
+ K:G
+ V:1 treble nm="Piano" snm="Pno."
+ V:2 bass
+ V:1
+ !mf!"^Allegro" d2 (GA Bc | d2) .G2 .G2 |]
+ V:2
+ [G,B,D]4 A,2 | B,6 |]
+ ```
+
+ ### Interleaved ABC Notation
+ In contrast, interleaved ABC notation effectively aligns multi-track music by integrating multiple voices of the same bar into a single line, ensuring that all parts remain synchronized. This format combines voices in-line and tags each bar with its corresponding voice (e.g., `[V:1]` for treble and `[V:2]` for bass). By directly aligning related bars, interleaved ABC notation enhances the model’s understanding of how different voices interact within the same bar.
+
+ Below is the same piece in interleaved ABC notation, as used for M3 encoding, where each header or bar corresponds to a patch:
+ ```
+ %%score { 1 | 2 }
+ L:1/8
+ Q:1/4=120
+ M:3/4
+ K:G
+ V:1 treble nm="Piano" snm="Pno."
+ V:2 bass
+ [V:1]!mf!"^Allegro" d2 (GA Bc|[V:2][G,B,D]4 A,2|
+ [V:1]d2) .G2 .G2|][V:2]B,6|]
+ ```
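+
+ To make the core idea concrete, here is a minimal, hypothetical sketch of bar-level interleaving. It is not the conversion used in this repository (`batch_interleaved_abc.py` below relies on the `abctoolkit` package, which also handles headers, repeats, and alignment checks); it only illustrates how per-voice bars are tagged and merged onto shared lines:
+
+ ```python
+ # Rough sketch: assumes each voice body is one line of bars separated by '|'.
+ # Final barlines such as '|]' and inline fields are ignored here; the real
+ # pipeline (batch_interleaved_abc.py, via abctoolkit) handles those cases.
+ def interleave_voices(voice_bodies):
+     bars_per_voice = {
+         voice: [bar for bar in body.split("|") if bar.strip()]
+         for voice, body in voice_bodies.items()
+     }
+     num_bars = min(len(bars) for bars in bars_per_voice.values())
+     lines = []
+     for i in range(num_bars):
+         # Tag each voice's i-th bar and join all voices of that bar on one line.
+         lines.append("".join(f"[V:{v}]{bars_per_voice[v][i]}|" for v in voice_bodies))
+     return lines
+
+ print("\n".join(interleave_voices({
+     "1": 'd2 (GA Bc|d2) .G2 .G2|',
+     "2": '[G,B,D]4 A,2|B,6|',
+ })))
+ # [V:1]d2 (GA Bc|[V:2][G,B,D]4 A,2|
+ # [V:1]d2) .G2 .G2|[V:2]B,6|
+ ```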
+
+ ## About MTF
+ ### Raw MIDI Messages
+ MIDI (performance data) precisely encodes performance information related to timing and dynamics, making it well suited to music production and live performance. Raw MIDI messages contain essential musical instructions and metadata, extracted directly from a MIDI file. These include events like note on/off, tempo changes, key signatures, and control changes, which define how the music is performed. The [mido library](https://mido.readthedocs.io/) allows for reading these messages in their native format, as seen below. Each message can include multiple parameters, making the output comprehensive but sometimes redundant.
+
+ ```
+ MetaMessage('time_signature', numerator=3, denominator=4, clocks_per_click=24, notated_32nd_notes_per_beat=8, time=0)
+ MetaMessage('key_signature', key='G', time=0)
+ MetaMessage('set_tempo', tempo=500000, time=0)
+ control_change channel=0 control=121 value=0 time=0
+ program_change channel=0 program=0 time=0
+ control_change channel=0 control=7 value=100 time=0
+ control_change channel=0 control=10 value=64 time=0
+ control_change channel=0 control=91 value=0 time=0
+ control_change channel=0 control=93 value=0 time=0
+ MetaMessage('midi_port', port=0, time=0)
+ note_on channel=0 note=74 velocity=80 time=0
+ MetaMessage('key_signature', key='G', time=0)
+ MetaMessage('midi_port', port=0, time=0)
+ note_on channel=0 note=55 velocity=80 time=0
+ note_on channel=0 note=59 velocity=80 time=0
+ note_on channel=0 note=62 velocity=80 time=0
+ note_on channel=0 note=74 velocity=0 time=455
+ note_on channel=0 note=67 velocity=80 time=25
+ note_on channel=0 note=67 velocity=0 time=239
+ note_on channel=0 note=69 velocity=80 time=1
+ note_on channel=0 note=55 velocity=0 time=191
+ note_on channel=0 note=59 velocity=0 time=0
+ note_on channel=0 note=62 velocity=0 time=0
+ note_on channel=0 note=69 velocity=0 time=48
+ note_on channel=0 note=71 velocity=80 time=1
+ note_on channel=0 note=57 velocity=80 time=0
+ note_on channel=0 note=71 velocity=0 time=239
+ note_on channel=0 note=72 velocity=80 time=1
+ note_on channel=0 note=57 velocity=0 time=215
+ note_on channel=0 note=72 velocity=0 time=24
+ note_on channel=0 note=74 velocity=80 time=1
+ note_on channel=0 note=59 velocity=80 time=0
+ note_on channel=0 note=74 velocity=0 time=455
+ note_on channel=0 note=67 velocity=80 time=25
+ note_on channel=0 note=67 velocity=0 time=239
+ note_on channel=0 note=67 velocity=80 time=241
+ note_on channel=0 note=67 velocity=0 time=239
+ note_on channel=0 note=59 velocity=0 time=168
+ MetaMessage('end_of_track', time=1)
+ ```
+ ### MIDI Text Format (MTF)
+ The MIDI Text Format (MTF) provides a structured, textual representation of MIDI data that preserves all original information without loss. Each MIDI message is represented exactly, allowing full reconstruction, so no musical nuances are lost during conversion.
+
+ To generate MTF, the mido library reads raw MIDI messages from MIDI files. The raw output retains all essential information but can be lengthy and redundant. To simplify the representation, parameter values are read in a fixed order and separated by spaces. For example, the raw time signature message, which contains several parameters (numerator, denominator, clocks per click, notated 32nd notes per beat, and time), is represented in MTF as:
+
+ ```
+ time_signature 3 4 24 8 0
+ ```
+
+ Other messages, such as control changes and note events, follow a similar compact format while preserving all relevant musical details. This structured simplification improves computational performance while maintaining precise control over musical elements, including timing and dynamics.
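+
+ This mapping is exactly what the `msg_to_str` helper in `batch_midi2mtf.py` (included later in this repository) implements. A minimal sketch of the idea, omitting the escaping of non-ASCII text that the full script also performs:
+
+ ```python
+ import mido
+
+ def msg_to_str(msg):
+     # msg.dict() yields the message type followed by its parameter values
+     # in a fixed order; joining them with spaces gives one MTF line.
+     return " ".join(str(value) for value in msg.dict().values())
+
+ msg = mido.MetaMessage('time_signature', numerator=3, denominator=4,
+                        clocks_per_click=24, notated_32nd_notes_per_beat=8,
+                        time=0)
+ print(msg_to_str(msg))  # time_signature 3 4 24 8 0
+ ```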
+
+ Example MTF representation:
+ ```
+ ticks_per_beat 480
+ time_signature 3 4 24 8 0
+ key_signature G 0
+ set_tempo 500000 0
+ control_change 0 0 121 0
+ program_change 0 0 0
+ control_change 0 0 7 100
+ control_change 0 0 10 64
+ control_change 0 0 91 0
+ control_change 0 0 93 0
+ midi_port 0 0
+ note_on 0 0 74 80
+ key_signature G 0
+ midi_port 0 0
+ note_on 0 0 55 80
+ note_on 0 0 59 80
+ note_on 0 0 62 80
+ note_on 455 0 74 0
+ note_on 25 0 67 80
+ note_on 239 0 67 0
+ note_on 1 0 69 80
+ note_on 191 0 55 0
+ note_on 0 0 59 0
+ note_on 0 0 62 0
+ note_on 48 0 69 0
+ note_on 1 0 71 80
+ note_on 0 0 57 80
+ note_on 239 0 71 0
+ note_on 1 0 72 80
+ note_on 215 0 57 0
+ note_on 24 0 72 0
+ note_on 1 0 74 80
+ note_on 0 0 59 80
+ note_on 455 0 74 0
+ note_on 25 0 67 80
+ note_on 239 0 67 0
+ note_on 241 0 67 80
+ note_on 239 0 67 0
+ note_on 168 0 59 0
+ end_of_track 1
+ ```
+ For simplicity, `ticks_per_beat`, though originally an attribute of the MIDI object in mido rather than a message, is included as the first line of the MTF representation.
+
+ ### M3-Encoded MTF
+ When processed using M3 encoding, consecutive messages of the same type that fit within a 64-character limit (the patch size of M3) are combined into a single line. Only the first message in each group specifies the type, with subsequent messages listing only their parameter values, separated by tabs. This further simplifies the representation and improves processing efficiency.
+
+ Below is the same data optimized with M3 encoding, where each line corresponds to a patch:
+ ```
+ ticks_per_beat 480
+ time_signature 3 4 24 8 0
+ key_signature G 0
+ set_tempo 500000 0
+ control_change 0 0 121 0
+ program_change 0 0 0
+ control_change 0 0 7 100\t0 0 10 64\t0 0 91 0\t0 0 93 0
+ midi_port 0 0
+ note_on 0 0 74 80
+ key_signature G 0
+ midi_port 0 0
+ note_on 0 0 55 80\t0 0 59 80\t0 0 62 80\t455 0 74 0\t25 0 67 80
+ note_on 239 0 67 0\t1 0 69 80\t191 0 55 0\t0 0 59 0\t0 0 62 0
+ note_on 48 0 69 0\t1 0 71 80\t0 0 57 80\t239 0 71 0\t1 0 72 80
+ note_on 215 0 57 0\t24 0 72 0\t1 0 74 80\t0 0 59 80\t455 0 74 0
+ note_on 25 0 67 80\t239 0 67 0\t241 0 67 80\t239 0 67 0\t168 0 59 0
+ end_of_track 1
+ ```
+
+ By reducing redundancy, M3 encoding improves computational performance while maintaining precise timing and musical control, making it a good fit for efficient MIDI processing.
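+
+ A minimal sketch of this grouping rule, operating on plain MTF lines; exactly how the 64-character limit is counted (type prefix, separating tabs) is an assumption made for illustration:
+
+ ```python
+ PATCH_SIZE = 64  # M3 patch size, per the text above
+
+ def m3_group(mtf_lines):
+     # Merge consecutive messages of the same type into one line while the
+     # merged line still fits in a patch; messages after the first keep only
+     # their parameter values, joined with tab characters.
+     grouped = []
+     for line in mtf_lines:
+         msg_type, _, params = line.partition(" ")
+         if (grouped
+                 and grouped[-1].split(" ", 1)[0] == msg_type
+                 and len(grouped[-1]) + 1 + len(params) <= PATCH_SIZE):
+             grouped[-1] += "\t" + params
+         else:
+             grouped.append(line)
+     return grouped
+
+ print(m3_group(["control_change 0 0 7 100", "control_change 0 0 10 64",
+                 "control_change 0 0 91 0", "control_change 0 0 93 0"]))
+ # ['control_change 0 0 7 100\t0 0 10 64\t0 0 91 0\t0 0 93 0']
+ ```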
+
+ ## Repository Structure
+ The `process_data/` folder includes the following scripts and utility files:
+
+ ### 1. **Conversion Scripts**
+
+ #### `batch_abc2xml.py`
+ - **Purpose**: Converts ABC notation files into MusicXML format.
+ - **Input**: Directory of interleaved ABC files (modify the `input_dir` variable in the code).
+ - **Output**: MusicXML files saved in a newly created `_xml` directory.
+ - **Logging**: Errors are logged to `logs/abc2xml_error_log.txt`.
+
+ #### `batch_xml2abc.py`
+ - **Purpose**: Converts MusicXML files into standard ABC notation.
+ - **Input**: Directory of MusicXML files (e.g., `.xml`, `.mxl`, `.musicxml`) (modify the `input_dir` variable in the code).
+ - **Output**: Standard ABC files saved in a newly created `_abc` directory.
+ - **Logging**: Errors are logged to `logs/xml2abc_error_log.txt`.
+
+ #### `batch_interleaved_abc.py`
+ - **Purpose**: Processes standard ABC notation files into interleaved ABC notation.
+ - **Input**: Directory of ABC files (modify the `input_dir` variable in the code).
+ - **Output**: Interleaved ABC files saved in a newly created `_interleaved` directory.
+ - **Logging**: Any processing errors are printed to the console.
+
+ #### `batch_midi2mtf.py`
+ - **Purpose**: Converts MIDI files into MIDI Text Format (MTF).
+ - **Input**: Directory of MIDI files (e.g., `.mid`, `.midi`) (modify the `input_dir` variable in the code).
+ - **Output**: MTF files saved in a newly created `_mtf` directory.
+ - **Logging**: Errors are logged to `logs/midi2mtf_error_log.txt`.
+ - **Note**: The script includes an `m3_compatible` variable, which is set to `True` by default. When `True`, the conversion omits messages whose parameters are strings or lists, eliminating potential natural language information. This ensures that the converted MTF files align with the data format used for training the M3 and CLaMP 2 pretrained weights.
+
+ #### `batch_mtf2midi.py`
+ - **Purpose**: Converts MTF files into MIDI format.
+ - **Input**: Directory of MTF files (modify the `input_dir` variable in the code).
+ - **Output**: MIDI files saved in a newly created `_midi` directory.
+ - **Logging**: Errors are logged to `logs/mtf2midi_error_log.txt`.
+
+ ### 2. **Summarization Script**
+
+ #### `gpt4_summarize.py`
+ - **Purpose**: Uses the OpenAI GPT-4 API to generate concise summaries of music metadata in multiple languages. The script filters out entries that lack sufficient musical information, so that only meaningful summaries are produced.
+ - **Input**: Directory of JSON files containing music metadata (modify the `input_dir` variable in the code). Any missing metadata field can have its key set to `None`. Each JSON file corresponds to a single musical composition and can be linked to both ABC notation and MTF formats. Here is an example of the required metadata format:
+
+ ```json
+ {
+     "title": "Hard Times Come Again No More",
+     "composer": "Stephen Foster",
+     "genres": ["Children's Music", "Folk"],
+     "description": "\"Hard Times Come Again No More\" (sometimes referred to as \"Hard Times\") is an American parlor song written by Stephen Foster, reflecting themes of sorrow and hope.",
+     "lyrics": "Let us pause in life's pleasures and count its many tears,\nWhile we all sup sorrow with the poor;\nThere's a song that will linger forever in our ears;\nOh! Hard times come again no more.\n\nChorus:\n'Tis the song, the sigh of the weary,\nHard Times, hard times, come again no more.\nMany days you have lingered around my cabin door;\nOh! Hard times come again no more.\n\nWhile we seek mirth and beauty and music light and gay,\nThere are frail forms fainting at the door;\nThough their voices are silent, their pleading looks will say\nOh! Hard times come again no more.\nChorus\n\nThere's a pale weeping maiden who toils her life away,\nWith a worn heart whose better days are o'er:\nThough her voice would be merry, 'tis sighing all the day,\nOh! Hard times come again no more.\nChorus\n\n'Tis a sigh that is wafted across the troubled wave,\n'Tis a wail that is heard upon the shore\n'Tis a dirge that is murmured around the lowly grave\nOh! Hard times come again no more.\nChorus",
+     "tags": ["folk", "traditional", "bluegrass", "nostalgic", "heartfelt", "acoustic", "melancholic", "storytelling", "American roots", "resilience"],
+     "ensembles": ["Folk Ensemble"],
+     "instruments": ["Vocal", "Violin", "Tin whistle", "Guitar", "Banjo", "Tambourine"],
+     "filepaths": [
+         "abc/American_Music/Folk_Traditions/19th_Century/Stephen_Foster/Hard_Times_Come_Again_No_More.abc",
+         "mtf/American_Music/Folk_Traditions/19th_Century/Stephen_Foster/Hard_Times_Come_Again_No_More.mtf"
+     ]
+ }
+ ```
+
+ - **Output**: JSON files containing structured summaries in both English and a randomly selected non-English language, chosen from a list of 100 non-English languages (in this case, Simplified Chinese). Here is an example of the expected output format:
+
+ ```json
+ {
+     "title": "Hard Times Come Again No More",
+     "composer": "Stephen Foster",
+     "genres": ["Children's Music", "Folk"],
+     "description": "\"Hard Times Come Again No More\" (sometimes referred to as \"Hard Times\") is an American parlor song written by Stephen Foster, reflecting themes of sorrow and hope.",
+     "lyrics": "Let us pause in life's pleasures and count its many tears,\nWhile we all sup sorrow with the poor;\nThere's a song that will linger forever in our ears;\nOh! Hard times come again no more.\n\nChorus:\n'Tis the song, the sigh of the weary,\nHard Times, hard times, come again no more.\nMany days you have lingered around my cabin door;\nOh! Hard times come again no more.\n\nWhile we seek mirth and beauty and music light and gay,\nThere are frail forms fainting at the door;\nThough their voices are silent, their pleading looks will say\nOh! Hard times come again no more.\nChorus\n\nThere's a pale weeping maiden who toils her life away,\nWith a worn heart whose better days are o'er:\nThough her voice would be merry, 'tis sighing all the day,\nOh! Hard times come again no more.\nChorus\n\n'Tis a sigh that is wafted across the troubled wave,\n'Tis a wail that is heard upon the shore\n'Tis a dirge that is murmured around the lowly grave\nOh! Hard times come again no more.\nChorus",
+     "tags": ["folk", "traditional", "bluegrass", "nostalgic", "heartfelt", "acoustic", "melancholic", "storytelling", "American roots", "resilience"],
+     "ensembles": ["Folk Ensemble"],
+     "instruments": ["Vocal", "Violin", "Tin whistle", "Guitar", "Banjo", "Tambourine"],
+     "summary_en": "\"Hard Times Come Again No More,\" composed by Stephen Foster, is a poignant American parlor song that explores themes of sorrow and hope. The lyrics reflect on the contrast between life's pleasures and its hardships, inviting listeners to acknowledge both joy and suffering. With a heartfelt chorus that repeats the line \"Hard times come again no more,\" the song resonates with nostalgia and resilience. It is often performed by folk ensembles and features a variety of instruments, including vocals, violin, guitar, and banjo, encapsulating the spirit of American roots music.",
+     "summary_nen": {
+         "language": "Chinese (Simplified)",
+         "summary": "《艰难时光再无来临》是斯蒂芬·福斯特创作的一首感人至深的美国小歌厅歌曲,探讨了悲伤与希望的主题。歌词展现了生活的乐趣与艰辛之间的对比,邀请听众去感受快乐与痛苦的交织。歌曲中那句反复吟唱的“艰难时光再无来临”深情地表达了怀旧与坚韧。它常常由民谣乐队演奏,伴随着人声、小提琴、吉他和班卓琴等多种乐器,生动地展现了美国根源音乐的独特魅力。"
+     },
+     "filepaths": [
+         "abc/American_Music/Folk_Traditions/19th_Century/Stephen_Foster/Hard_Times_Come_Again_No_More.abc",
+         "mtf/American_Music/Folk_Traditions/19th_Century/Stephen_Foster/Hard_Times_Come_Again_No_More.mtf"
+     ]
+ }
+ ```
+
+ - **Logging**: Errors are logged to `logs/gpt4_summarize_error_log.txt`.
+
+ ### 3. **Utilities**
+ - **`utils/`**: Contains utility files required by the conversion scripts.
+
+ ## Usage
+ To use the scripts, modify the `input_dir` variable in each script to point to the directory containing your input files, then run the script from the command line. Below are example commands for each script:
+
+ ### Example Commands
+ ```bash
+ # Modify the input_dir variable in the script before running
+ python batch_abc2xml.py
+ python batch_xml2abc.py
+ python batch_interleaved_abc.py
+ python batch_midi2mtf.py
+ python batch_mtf2midi.py
+ python gpt4_summarize.py
+ ```
+
+ ### Execution Order
+ To achieve specific conversions, follow the order below:
+
+ 1. **To obtain interleaved ABC notation**:
+    - First, run `batch_xml2abc.py` to convert MusicXML files to ABC notation.
+    - Then, run `batch_interleaved_abc.py` to process the ABC files into interleaved ABC notation.
+
+ 2. **To obtain MTF**:
+    - Run `batch_midi2mtf.py` to convert MIDI files into MTF.
+
+ 3. **To convert interleaved ABC back to XML**:
+    - Run `batch_abc2xml.py` on the interleaved ABC files to convert them back to MusicXML format.
+
+ 4. **To convert MTF back to MIDI**:
+    - Run `batch_mtf2midi.py` to convert MTF files back to MIDI format.
+
+ 5. **To summarize music metadata**:
+    - Run `gpt4_summarize.py` to generate summaries for the music metadata files in JSON format. This assumes you have a directory of JSON files that include a `filepaths` key, which links each entry to the corresponding interleaved ABC and MTF files.
+
+ ### Parameters
+ To run the scripts, you need to configure the following parameters:
+
+ - **`input_dir`**: The directory containing the input files to be processed (ABC, MusicXML, MIDI, MTF, or JSON files); it is used by all scripts.
+
+ In addition to **`input_dir`**, the following parameters are specific to certain scripts:
+
+ - **`m3_compatible`** (specific to `batch_midi2mtf.py`):
+   - The default is `True`, which omits messages with string or list parameters to avoid including potential natural language information.
+   - Setting this to `False` retains all MIDI messages, which is crucial if you plan to retrain the models on custom datasets or need precise MIDI reproduction.
+
+ For **`gpt4_summarize.py`**, you also need to configure these parameters:
+
+ 1. **`base_url`**: The base URL for the OpenAI API, used to initialize the client.
+ 2. **`api_key`**: Your API key for authenticating requests, required for client initialization.
+ 3. **`model`**: The GPT-4 model to use when generating summaries.
+
+ **Important**: When `m3_compatible` is set to `True`, converting back from MTF to MIDI with `batch_mtf2midi.py` may produce MIDI files that do not exactly match the originals. This discrepancy was unexpected; however, retraining both M3 and CLaMP 2 to address it would require approximately 6,000 H800 GPU hours. Since M3 and CLaMP 2 have already achieved state-of-the-art results on MIDI tasks, we have opted not to retrain them. If exact consistency with the original MIDI files is critical for your application, set `m3_compatible` to `False`.
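+
+ If exact reproduction matters, one quick way to spot-check a round trip is to compare the merged message streams of the source file and the reconstructed file with mido; a minimal sketch (the file paths are placeholders):
+
+ ```python
+ import mido
+
+ def midi_messages_equal(path_a, path_b):
+     # Compare two MIDI files message by message over their merged tracks.
+     mid_a, mid_b = mido.MidiFile(path_a), mido.MidiFile(path_b)
+     if mid_a.ticks_per_beat != mid_b.ticks_per_beat:
+         return False
+     msgs_a, msgs_b = list(mid_a.merged_track), list(mid_b.merged_track)
+     return len(msgs_a) == len(msgs_b) and all(
+         a == b for a, b in zip(msgs_a, msgs_b))
+
+ print(midi_messages_equal("original.mid", "roundtrip.mid"))
+ ```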
process_data/batch_abc2xml.py ADDED
@@ -0,0 +1,58 @@
+ input_dir = "<path_to_your_interleaved_abc_files>"  # Replace with the path to your folder containing interleaved ABC (.abc) files
+
+ import os
+ import math
+ import random
+ import subprocess
+ from tqdm import tqdm
+ from multiprocessing import Pool
+
+ def convert_abc2xml(file_list):
+     cmd = 'cmd /u /c python utils/abc2xml.py '  # Windows shell invocation; on Unix-like systems, drop the 'cmd /u /c' prefix
+     for file in tqdm(file_list):
+         filename = file.split('/')[-1]  # Extract file name
+         output_dir = file.split('/')[:-1]  # Extract directory path
+         output_dir[0] = output_dir[0] + '_xml'  # Create new output folder
+         output_dir = '/'.join(output_dir)
+         os.makedirs(output_dir, exist_ok=True)
+
+         try:
+             p = subprocess.Popen(cmd + '"' + file + '"', stdout=subprocess.PIPE, shell=True)
+             result = p.communicate()
+             output = result[0].decode('utf-8')
+
+             if output == '':
+                 with open("logs/abc2xml_error_log.txt", "a", encoding="utf-8") as f:
+                     f.write(file + '\n')
+                 continue
+             else:
+                 output_path = f"{output_dir}/" + ".".join(filename.split(".")[:-1]) + ".xml"
+                 with open(output_path, 'w', encoding='utf-8') as f:
+                     f.write(output)
+         except Exception as e:
+             with open("logs/abc2xml_error_log.txt", "a", encoding="utf-8") as f:
+                 f.write(file + ' ' + str(e) + '\n')
+
+ if __name__ == '__main__':
+     file_list = []
+     os.makedirs("logs", exist_ok=True)
+
+     # Traverse the specified folder for ABC files
+     for root, dirs, files in os.walk(input_dir):
+         for file in files:
+             if not file.endswith(".abc"):
+                 continue
+             filename = os.path.join(root, file).replace("\\", "/")
+             file_list.append(filename)
+
+     # Prepare for multiprocessing: one contiguous slice of files per CPU core
+     file_lists = []
+     random.shuffle(file_list)
+     for i in range(os.cpu_count()):
+         start_idx = int(math.floor(i * len(file_list) / os.cpu_count()))
+         end_idx = int(math.floor((i + 1) * len(file_list) / os.cpu_count()))
+         file_lists.append(file_list[start_idx:end_idx])
+
+     pool = Pool(processes=os.cpu_count())
+     pool.map(convert_abc2xml, file_lists)
process_data/batch_interleaved_abc.py ADDED
@@ -0,0 +1,117 @@
+ input_dir = "<path_to_your_abc_files>"  # Replace with the path to your folder containing standard ABC (.abc) files
+
+ import os
+ import re
+ import random
+ from multiprocessing import Pool
+ from tqdm import tqdm
+ from abctoolkit.utils import (
+     find_all_abc,
+     remove_information_field,
+     remove_bar_no_annotations,
+     Quote_re,
+     Barlines,
+     strip_empty_bars
+ )
+ from abctoolkit.rotate import rotate_abc
+ from abctoolkit.check import check_alignment_unrotated
+
+ def abc_pipeline(abc_path, input_dir, output_dir):
+     """
+     Converts standard ABC notation to interleaved ABC notation.
+     """
+     with open(abc_path, 'r', encoding='utf-8') as f:
+         abc_lines = f.readlines()
+
+     abc_lines = [line for line in abc_lines if line.strip() != '']
+     abc_lines = remove_information_field(
+         abc_lines=abc_lines,
+         info_fields=['X:', 'T:', 'C:', 'W:', 'w:', 'Z:', '%%MIDI']
+     )
+     abc_lines = remove_bar_no_annotations(abc_lines)
+
+     # Remove escaped quotes and clean up barlines inside quotes
+     for i, line in enumerate(abc_lines):
+         if not (re.search(r'^[A-Za-z]:', line) or line.startswith('%')):
+             abc_lines[i] = line.replace(r'\"', '')
+             quote_contents = re.findall(Quote_re, line)
+             for quote_content in quote_contents:
+                 for barline in Barlines:
+                     if barline in quote_content:
+                         line = line.replace(quote_content, '')
+                         abc_lines[i] = line
+
+     try:
+         stripped_abc_lines, bar_counts = strip_empty_bars(abc_lines)
+     except Exception as e:
+         print(abc_path, 'Error in stripping empty bars:', e)
+         return
+
+     if stripped_abc_lines is None:
+         print(abc_path, 'Failed to strip')
+         return
+
+     # Check alignment
+     _, bar_no_equal_flag, bar_dur_equal_flag = check_alignment_unrotated(stripped_abc_lines)
+     if not bar_no_equal_flag:
+         print(abc_path, 'Unequal bar number')
+     if not bar_dur_equal_flag:
+         print(abc_path, 'Unequal bar duration (unaligned)')
+
+     # Construct the output path, maintaining the input folder structure
+     relative_path = os.path.relpath(abc_path, input_dir)  # Get relative path from input dir
+     output_file_path = os.path.join(output_dir, relative_path)  # Recreate output path
+     os.makedirs(os.path.dirname(output_file_path), exist_ok=True)  # Ensure output folder exists
+
+     try:
+         rotated_abc_lines = rotate_abc(stripped_abc_lines)
+     except Exception as e:
+         print(abc_path, 'Error in rotating:', e)
+         return
+
+     if rotated_abc_lines is None:
+         print(abc_path, 'Failed to rotate')
+         return
+
+     with open(output_file_path, 'w', encoding='utf-8') as w:
+         w.writelines(rotated_abc_lines)
+
+ def abc_pipeline_list(abc_path_list, input_dir, output_dir):
+     for abc_path in tqdm(abc_path_list):
+         try:
+             abc_pipeline(abc_path, input_dir, output_dir)
+         except Exception as e:
+             print(abc_path, e)
+
+ def batch_abc_pipeline(input_dir):
+     """
+     Batch process all ABC files from `input_dir`, converting them to interleaved notation.
+     """
+     output_dir = input_dir + "_interleaved"
+     os.makedirs(output_dir, exist_ok=True)
+
+     abc_path_list = []
+     for abc_path in find_all_abc(input_dir):
+         if os.path.getsize(abc_path) > 0:
+             abc_path_list.append(abc_path)
+     random.shuffle(abc_path_list)
+     print(f"Found {len(abc_path_list)} ABC files.")
+
+     # Deal the files round-robin across one list per CPU core
+     num_cpus = os.cpu_count()
+     split_lists = [[] for _ in range(num_cpus)]
+     index = 0
+
+     for abc_path in abc_path_list:
+         split_lists[index].append(abc_path)
+         index = (index + 1) % num_cpus
+
+     pool = Pool(processes=num_cpus)
+     pool.starmap(
+         abc_pipeline_list,
+         [(split, input_dir, output_dir) for split in split_lists]
+     )
+
+ if __name__ == '__main__':
+     batch_abc_pipeline(input_dir)
process_data/batch_midi2mtf.py ADDED
@@ -0,0 +1,80 @@
+ input_dir = "<path_to_your_midi_files>"  # Replace with the path to your folder containing MIDI (.midi, .mid) files
+ m3_compatible = True  # Set to True for M3 compatibility; set to False to retain all MIDI information during conversion.
+
+ import os
+ import math
+ import mido
+ import random
+ from tqdm import tqdm
+ from multiprocessing import Pool
+
+ def msg_to_str(msg):
+     # Join the message's parameter values (type first) with spaces;
+     # escape non-ASCII characters so each message stays on one line
+     str_msg = ""
+     for key, value in msg.dict().items():
+         str_msg += " " + str(value)
+     return str_msg.strip().encode('unicode_escape').decode('utf-8')
+
+ def load_midi(filename):
+     # Load a MIDI file
+     mid = mido.MidiFile(filename)
+     msg_list = ["ticks_per_beat " + str(mid.ticks_per_beat)]
+
+     # Traverse the MIDI file
+     for msg in mid.merged_track:
+         if m3_compatible:
+             # Skip messages whose parameters are free text or raw data,
+             # so no natural language leaks into the MTF representation
+             if msg.is_meta:
+                 if msg.type in ["text", "copyright", "track_name", "instrument_name",
+                                 "lyrics", "marker", "cue_marker", "device_name", "sequencer_specific"]:
+                     continue
+             else:
+                 if msg.type in ["sysex"]:
+                     continue
+         str_msg = msg_to_str(msg)
+         msg_list.append(str_msg)
+
+     return "\n".join(msg_list)
+
+ def convert_midi2mtf(file_list):
+     for file in tqdm(file_list):
+         filename = file.split('/')[-1]
+         output_dir = file.split('/')[:-1]
+         output_dir[0] = output_dir[0] + '_mtf'
+         output_dir = '/'.join(output_dir)
+         os.makedirs(output_dir, exist_ok=True)
+         try:
+             output = load_midi(file)
+
+             if output == '':
+                 with open('logs/midi2mtf_error_log.txt', 'a', encoding='utf-8') as f:
+                     f.write(file + '\n')
+                 continue
+             else:
+                 with open(output_dir + "/" + ".".join(filename.split(".")[:-1]) + '.mtf', 'w', encoding='utf-8') as f:
+                     f.write(output)
+         except Exception as e:
+             with open('logs/midi2mtf_error_log.txt', 'a', encoding='utf-8') as f:
+                 f.write(file + " " + str(e) + '\n')
+
+ if __name__ == '__main__':
+     file_list = []
+     os.makedirs("logs", exist_ok=True)
+
+     # Traverse the specified folder for MIDI files
+     for root, dirs, files in os.walk(input_dir):
+         for file in files:
+             if not file.endswith(".mid") and not file.endswith(".midi"):
+                 continue
+             filename = os.path.join(root, file).replace("\\", "/")
+             file_list.append(filename)
+
+     # Prepare for multiprocessing
+     file_lists = []
+     random.shuffle(file_list)
+     for i in range(os.cpu_count()):
+         start_idx = int(math.floor(i * len(file_list) / os.cpu_count()))
+         end_idx = int(math.floor((i + 1) * len(file_list) / os.cpu_count()))
+         file_lists.append(file_list[start_idx:end_idx])
+
+     pool = Pool(processes=os.cpu_count())
+     pool.map(convert_midi2mtf, file_lists)
process_data/batch_mtf2midi.py ADDED
@@ -0,0 +1,97 @@
+ input_dir = "<path_to_your_mtf_files>"  # Replace with the path to your folder containing MTF (.mtf) files
+
+ import os
+ import math
+ import mido
+ import random
+ from tqdm import tqdm
+ from multiprocessing import Pool
+
+ def str_to_msg(str_msg):
+     msg_type = str_msg.split(" ")[0]
+     try:
+         msg = mido.Message(msg_type)
+     except Exception:
+         msg = mido.MetaMessage(msg_type)
+
+     if msg_type in ["text", "copyright", "track_name", "instrument_name",
+                     "lyrics", "marker", "cue_marker", "device_name"]:
+         # Text payloads may contain spaces: everything between the type and
+         # the trailing time field is the (escaped) text itself
+         values = [msg_type, " ".join(str_msg.split(" ")[1:-1]).encode('utf-8').decode('unicode_escape'), str_msg.split(" ")[-1]]
+     elif "[" in str_msg or "(" in str_msg:
+         # Reconstruct list/tuple parameters (e.g. sysex data) from their bracketed form
+         is_bracket = "[" in str_msg
+         left_idx = str_msg.index("[") if is_bracket else str_msg.index("(")
+         right_idx = str_msg.index("]") if is_bracket else str_msg.index(")")
+         list_str = [int(num) for num in str_msg[left_idx+1:right_idx].split(", ")]
+         if not is_bracket:
+             list_str = tuple(list_str)
+         values = str_msg[:left_idx].split(" ") + [list_str] + str_msg[right_idx+1:].split(" ")
+         values = [value for value in values if value != ""]
+     else:
+         values = str_msg.split(" ")
+
+     if len(values) != 1:
+         for idx, (key, content) in enumerate(msg.__dict__.items()):
+             if key == "type":
+                 continue
+             value = values[idx]
+             if isinstance(content, int) or isinstance(content, float):
+                 # Numeric fields: parse as float, then downcast to int when whole
+                 float_value = float(value)
+                 value = float_value
+                 if value % 1 == 0:
+                     value = int(value)
+             setattr(msg, key, value)
+
+     return msg
+
+ def convert_mtf2midi(file_list):
+     for file in tqdm(file_list):
+         filename = file.split('/')[-1]
+         output_dir = file.split('/')[:-1]
+         output_dir[0] = output_dir[0] + '_midi'
+         output_dir = '/'.join(output_dir)
+         os.makedirs(output_dir, exist_ok=True)
+         try:
+             with open(file, 'r', encoding='utf-8') as f:
+                 msg_list = f.read().splitlines()
+
+             # Build a new MIDI file based on the MIDI messages
+             new_mid = mido.MidiFile()
+             new_mid.ticks_per_beat = int(msg_list[0].split(" ")[1])
+
+             track = mido.MidiTrack()
+             new_mid.tracks.append(track)
+
+             for msg in msg_list[1:]:
+                 if "unknown_meta" in msg:
+                     continue
+                 new_msg = str_to_msg(msg)
+                 track.append(new_msg)
+
+             output_file_path = os.path.join(output_dir, os.path.basename(file).replace('.mtf', '.mid'))
+             new_mid.save(output_file_path)
+         except Exception as e:
+             with open('logs/mtf2midi_error_log.txt', 'a', encoding='utf-8') as f:
+                 f.write(f"Error processing {file}: {str(e)}\n")
+
+ if __name__ == '__main__':
+     file_list = []
+     os.makedirs("logs", exist_ok=True)
+
+     # Traverse the specified folder for MTF files
+     for root, dirs, files in os.walk(input_dir):
+         for file in files:
+             if not file.endswith(".mtf"):
+                 continue
+             filename = os.path.join(root, file).replace("\\", "/")
+             file_list.append(filename)
+
+     # Prepare for multiprocessing
+     file_lists = []
+     random.shuffle(file_list)
+     for i in range(os.cpu_count()):
+         start_idx = int(math.floor(i * len(file_list) / os.cpu_count()))
+         end_idx = int(math.floor((i + 1) * len(file_list) / os.cpu_count()))
+         file_lists.append(file_list[start_idx:end_idx])
+
+     pool = Pool(processes=os.cpu_count())
+     pool.map(convert_mtf2midi, file_lists)
process_data/batch_xml2abc.py ADDED
@@ -0,0 +1,57 @@
+ input_dir = "<path_to_your_xml_files>"  # Replace with the path to your folder containing XML (.xml, .mxl, .musicxml) files
+
+ import os
+ import math
+ import random
+ import subprocess
+ from tqdm import tqdm
+ from multiprocessing import Pool
+
+ def convert_xml2abc(file_list):
+     cmd = 'cmd /u /c python utils/xml2abc.py -d 8 -x '  # Windows shell invocation; on Unix-like systems, drop the 'cmd /u /c' prefix
+     for file in tqdm(file_list):
+         filename = file.split('/')[-1]
+         output_dir = file.split('/')[:-1]
+         output_dir[0] = output_dir[0] + '_abc'
+         output_dir = '/'.join(output_dir)
+         os.makedirs(output_dir, exist_ok=True)
+
+         try:
+             p = subprocess.Popen(cmd + '"' + file + '"', stdout=subprocess.PIPE, shell=True)
+             result = p.communicate()
+             output = result[0].decode('utf-8')
+
+             if output == '':
+                 with open("logs/xml2abc_error_log.txt", "a", encoding="utf-8") as f:
+                     f.write(file + '\n')
+                 continue
+             else:
+                 with open(output_dir + '/' + ".".join(filename.split(".")[:-1]) + '.abc', 'w', encoding='utf-8') as f:
+                     f.write(output)
+         except Exception as e:
+             with open("logs/xml2abc_error_log.txt", "a", encoding="utf-8") as f:
+                 f.write(file + ' ' + str(e) + '\n')
+
+ if __name__ == '__main__':
+     file_list = []
+     os.makedirs("logs", exist_ok=True)
+
+     # Traverse the specified folder for XML/MXL files
+     for root, dirs, files in os.walk(input_dir):
+         for file in files:
+             if not file.endswith((".mxl", ".xml", ".musicxml")):
+                 continue
+             filename = os.path.join(root, file).replace("\\", "/")
+             file_list.append(filename)
+
+     # Prepare for multiprocessing
+     file_lists = []
+     random.shuffle(file_list)
+     for i in range(os.cpu_count()):
+         start_idx = int(math.floor(i * len(file_list) / os.cpu_count()))
+         end_idx = int(math.floor((i + 1) * len(file_list) / os.cpu_count()))
+         file_lists.append(file_list[start_idx:end_idx])
+
+     pool = Pool(processes=os.cpu_count())
+     pool.map(convert_xml2abc, file_lists)
process_data/gpt4_summarize.py ADDED
@@ -0,0 +1,250 @@
+ input_dir = "<path_to_your_metadata_json_files>"  # Replace with the path to your folder containing metadata (.json) files
+ base_url = "<your_base_url>"  # Replace with the base URL for the API
+ api_key = "<your_api_key>"  # Replace with your API key
+ model = "<your_model>"  # Replace with your model name
+
+ import os
+ import json
+ import random
+ from openai import OpenAI
+
+ # Initialize the OpenAI client
+ client = OpenAI(base_url=base_url, api_key=api_key)
+
+ def log_error(file_path, error_message):
+     """Logs error messages to a specified log file."""
+     os.makedirs("logs", exist_ok=True)
+     with open("logs/gpt4_summarize_error_log.txt", 'a', encoding='utf-8') as log_file:
+         log_file.write(f"Error processing {file_path}: {error_message}\n")
+
+ def process_json(metadata, language):
+     """
+     Processes the given metadata of a music piece using the GPT-4 API.
+
+     This function sends the metadata and target language to the GPT-4 model to generate
+     a structured summary. The summary is provided in both English and the specified
+     non-English language from the 'nen_language' field.
+
+     If the provided metadata lacks sufficient music-related details, the model returns
+     `None` and this function raises a ValueError (caught and logged by the caller).
+
+     Parameters:
+     - metadata (dict): A dictionary containing the metadata of the music piece.
+     - language (str): The target non-English language for the summary.
+
+     Returns:
+     - dict: The metadata extended with the English and non-English summaries.
+     """
+     system = """Your task is to provide a concise, comprehensive, and coherent summary of the music piece using the provided metadata. Please write the summary in English first, and then write an equivalent summary in the specified non-English language from the "nen_language" field. Use this JSON format:
+ {
+     "summary_en": "Your English summary here.",
+     "summary_nen": {
+         "language": "Specified non-English language.",
+         "summary": "Your non-English summary here."
+     }
+ }
+ If there is not enough music-related information, return `None` instead.
+ """
+     user1 = """{
+     "title": "Brejeiro",
+     "composer": "Ernesto Nazareth",
+     "genres": ["Choro", "Classical", "Instrumental"],
+     "description": "\"Brejeiro\" is in A major and 2/4 time. A joyful melody begins at bar six, and a lively tango rhythm starts at bar fourteen. It has a D.C. al Fine at bar fifty-three and ends on two quarter notes in bar thirty-seven. The piece, with its vibrant melodies and rhythms, reflects celebration and carefreeness, embodying the spirit of Brazilian music.",
+     "tags": ["Brazilian", "Choro", "Piano"],
+     "ensembles": ["Solo Piano", "Small Ensemble"],
+     "instruments": ["Piano"],
+     "nen_language": "Japanese"
+ }
+ """
+     assistant1 = """{
+     "summary_en": "Brejeiro, composed by Ernesto Nazareth, is a lively choro piece in A major and 2/4 time. It features a joyful melody that begins at bar six and a vibrant tango rhythm introduced at bar fourteen. The piece includes a D.C. al Fine at bar fifty-three, concluding on two quarter notes in bar thirty-seven. With its themes of celebration and carefreeness, Brejeiro beautifully captures the essence of Brazilian music and is well-suited for solo piano and small ensembles.",
+     "summary_nen": {
+         "language": "Japanese",
+         "summary": "「ブレジェイロ」は、エルネスト・ナザレが作曲した活気あふれるショーロの作品で、イ長調の2/4拍子で書かれています。第6小節から始まる喜びに満ちたメロディーと、第14小節で導入される活気あるタンゴのリズムが特徴です。この曲には、第53小節でのD.C. al Fineが含まれ、また第37小節で二つの四分音符で締めくくられています。「ブレジェイロ」は、お祝いと無邪気さのテーマを持ち、ブラジル音楽の本質を美しく捉えており、ソロピアノや小編成のアンサンブルにぴったりの作品です。"
+     }
+ }
+ """
+     user2 = """{
+     "title": "Untitled",
+     "composer": "Unknown",
+     "description": "This is a good song.",
+     "nen_language": "Russian"
+ }
+ """
+     assistant2 = "None"
+     filepaths = metadata.pop('filepaths')
+     metadata = {k: v for k, v in metadata.items() if v is not None}
+
+     metadata["nen_language"] = language
+     metadata = json.dumps(metadata, ensure_ascii=False, indent=4)
+     summaries = client.chat.completions.create(
+         model=model,
+         messages=[
+             {"role": "system", "content": system},
+             {"role": "user", "content": user1},
+             {"role": "assistant", "content": assistant1},
+             {"role": "user", "content": user2},
+             {"role": "assistant", "content": assistant2},
+             {"role": "user", "content": metadata},
+         ]
+     ).choices[0].message.content
+
+     if summaries == "None":
+         raise ValueError("Received 'None' as summaries response")
+
+     metadata = json.loads(metadata)
+     summaries = json.loads(summaries)
+
+     if metadata["nen_language"] == summaries["summary_nen"]["language"]:
+         metadata.pop("nen_language")
+         metadata["summary_en"] = summaries["summary_en"]
+         metadata["summary_nen"] = summaries["summary_nen"]
+         metadata["filepaths"] = filepaths
+         return metadata
+     else:
+         raise ValueError("Language mismatch: nen_language does not match summary_nen language")
+
+ def process_files(input_dir):
+     # Create output directory with the _summarized suffix
+     output_dir = input_dir + "_summarized"
+
+     # Define available languages
+     languages = """Afrikaans
+ Amharic
+ Arabic
+ Assamese
+ Azerbaijani
+ Belarusian
+ Bulgarian
+ Bengali
+ Bengali (Romanized)
+ Breton
+ Bosnian
+ Catalan
+ Czech
+ Welsh
+ Danish
+ German
+ Greek
+ Esperanto
+ Spanish
+ Estonian
+ Basque
+ Persian
+ Finnish
+ French
+ Western Frisian
+ Irish
+ Scottish Gaelic
+ Galician
+ Gujarati
+ Hausa
+ Hebrew
+ Hindi
+ Hindi (Romanized)
+ Croatian
+ Hungarian
+ Armenian
+ Indonesian
+ Icelandic
+ Italian
+ Japanese
+ Javanese
+ Georgian
+ Kazakh
+ Khmer
+ Kannada
+ Korean
+ Kurdish (Kurmanji)
+ Kyrgyz
+ Latin
+ Lao
+ Lithuanian
+ Latvian
+ Malagasy
+ Macedonian
+ Malayalam
+ Mongolian
+ Marathi
+ Malay
+ Burmese
+ Burmese (Romanized)
+ Nepali
+ Dutch
+ Norwegian
+ Oromo
+ Oriya
+ Punjabi
+ Polish
+ Pashto
+ Portuguese
+ Romanian
+ Russian
+ Sanskrit
+ Sindhi
+ Sinhala
+ Slovak
+ Slovenian
+ Somali
+ Albanian
+ Serbian
+ Sundanese
+ Swedish
+ Swahili
+ Tamil
+ Tamil (Romanized)
+ Telugu
+ Telugu (Romanized)
+ Thai
+ Filipino
+ Turkish
+ Uyghur
+ Ukrainian
+ Urdu
+ Urdu (Romanized)
+ Uzbek
+ Vietnamese
+ Xhosa
+ Yiddish
+ Chinese (Simplified)
+ Chinese (Traditional)
+ Cantonese"""
+     languages = [language.strip() for language in languages.split("\n")]
+
+     # Walk through the input directory
+     for root, _, files in os.walk(input_dir):
+         # Construct the corresponding path in the output folder
+         relative_path = os.path.relpath(root, input_dir)
+         output_path = os.path.join(output_dir, relative_path)
+
+         # Create the output directory if it doesn't exist
+         os.makedirs(output_path, exist_ok=True)
+
+         for file in files:
+             if file.endswith('.json'):
+                 input_file = os.path.join(root, file)
+                 output_file = os.path.join(output_path, file)
+
+                 try:
+                     # Read the JSON file
+                     with open(input_file, 'r', encoding='utf-8') as f:
+                         metadata = json.load(f)
+
+                     # Randomly select a language from the list
+                     language = random.choice(languages)
+
+                     # Process the JSON data
+                     processed_metadata = process_json(metadata, language)
+
+                     # Write the processed JSON to the output file
+                     with open(output_file, 'w', encoding='utf-8') as f:
+                         json.dump(processed_metadata, f, indent=4, ensure_ascii=False)
+
+                     print(f"Processed: {input_file} -> {output_file}")
+
+                 except Exception as e:
+                     print(f"Failed to process {input_file}: {e}")
+                     log_error(input_file, str(e))
+
+ if __name__ == "__main__":
+     process_files(input_dir)
process_data/utils/abc2xml.py ADDED
The diff for this file is too large to render. See raw diff
 
process_data/utils/pyparsing.py ADDED
The diff for this file is too large to render. See raw diff
 
process_data/utils/xml2abc.py ADDED
@@ -0,0 +1,1582 @@
+ #!/usr/bin/env python
+ # coding=latin-1
+ '''
+ Copyright (C) 2012-2018: W.G. Vree
+ Contributions: M. Tarenskeen, N. Liberg, Paul Villiger, Janus Meuris, Larry Myerscough,
+ Dick Jackson, Jan Wybren de Jong, Mark Zealey.
+
+ This program is free software; you can redistribute it and/or modify it under the terms of the
+ Lesser GNU General Public License as published by the Free Software Foundation;
+
+ This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
+ without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+ See the Lesser GNU General Public License for more details. <http://www.gnu.org/licenses/lgpl.html>.
+ '''
+
+ try: import xml.etree.cElementTree as E
+ except: import xml.etree.ElementTree as E
+ import os, sys, types, re, math
+
+ VERSION = 143
+
+ python3 = sys.version_info.major > 2
+ if python3:
+     tupletype = tuple
+     listtype = list
+     max_int = sys.maxsize
+ else:
+     tupletype = types.TupleType
+     listtype = types.ListType
+     max_int = sys.maxint
+
+ note_ornamentation_map = { # for notations/, modified from EasyABC
+     'ornaments/trill-mark': 'T',
+     'ornaments/mordent': 'M',
+     'ornaments/inverted-mordent': 'P',
+     'ornaments/turn': '!turn!',
+     'ornaments/inverted-turn': '!invertedturn!',
+     'technical/up-bow': 'u',
+     'technical/down-bow': 'v',
+     'technical/harmonic': '!open!',
+     'technical/open-string': '!open!',
+     'technical/stopped': '!plus!',
+     'technical/snap-pizzicato': '!snap!',
+     'technical/thumb-position': '!thumb!',
+     'articulations/accent': '!>!',
+     'articulations/strong-accent':'!^!',
+     'articulations/staccato': '.',
+     'articulations/staccatissimo':'!wedge!',
+     'articulations/scoop': '!slide!',
+     'fermata': '!fermata!',
+     'arpeggiate': '!arpeggio!',
+     'articulations/tenuto': '!tenuto!',
+     'articulations/staccatissimo':'!wedge!', # not sure whether this is the right translation
+     'articulations/spiccato': '!wedge!', # not sure whether this is the right translation
+     'articulations/breath-mark': '!breath!', # this may need to be tested to make sure it appears on the right side of the note
+     'articulations/detached-legato': '!tenuto!.',
+ }
+
+ dynamics_map = { # for direction/direction-type/dynamics/
+     'p': '!p!',
+     'pp': '!pp!',
+     'ppp': '!ppp!',
+     'pppp': '!pppp!',
+     'f': '!f!',
+     'ff': '!ff!',
+     'fff': '!fff!',
+     'ffff': '!ffff!',
+     'mp': '!mp!',
+     'mf': '!mf!',
+     'sfz': '!sfz!',
+ }
+
+ percSvg = '''%%beginsvg
+ <defs>
+ <text id="x" x="-3" y="0">&#xe263;</text>
+ <text id="x-" x="-3" y="0">&#xe263;</text>
+ <text id="x+" x="-3" y="0">&#xe263;</text>
+ <text id="normal" x="-3.7" y="0">&#xe0a3;</text>
+ <text id="normal-" x="-3.7" y="0">&#xe0a3;</text>
+ <text id="normal+" x="-3.7" y="0">&#xe0a4;</text>
+ <g id="circle-x"><text x="-3" y="0">&#xe263;</text><circle r="4" class="stroke"></circle></g>
+ <g id="circle-x-"><text x="-3" y="0">&#xe263;</text><circle r="4" class="stroke"></circle></g>
+ <path id="triangle" d="m-4 -3.2l4 6.4 4 -6.4z" class="stroke" style="stroke-width:1.4"></path>
+ <path id="triangle-" d="m-4 -3.2l4 6.4 4 -6.4z" class="stroke" style="stroke-width:1.4"></path>
+ <path id="triangle+" d="m-4 -3.2l4 6.4 4 -6.4z" class="stroke" style="fill:#000"></path>
+ <path id="square" d="m-3.5 3l0 -6.2 7.2 0 0 6.2z" class="stroke" style="stroke-width:1.4"></path>
+ <path id="square-" d="m-3.5 3l0 -6.2 7.2 0 0 6.2z" class="stroke" style="stroke-width:1.4"></path>
+ <path id="square+" d="m-3.5 3l0 -6.2 7.2 0 0 6.2z" class="stroke" style="fill:#000"></path>
+ <path id="diamond" d="m0 -3l4.2 3.2 -4.2 3.2 -4.2 -3.2z" class="stroke" style="stroke-width:1.4"></path>
+ <path id="diamond-" d="m0 -3l4.2 3.2 -4.2 3.2 -4.2 -3.2z" class="stroke" style="stroke-width:1.4"></path>
+ <path id="diamond+" d="m0 -3l4.2 3.2 -4.2 3.2 -4.2 -3.2z" class="stroke" style="fill:#000"></path>
+ </defs>
+ %%endsvg'''
+
+ tabSvg = '''%%beginsvg
+ <style type="text/css">
+ .bf {font-family:sans-serif; font-size:7px}
+ </style>
+ <defs>
+ <rect id="clr" x="-3" y="-1" width="6" height="5" fill="white"></rect>
+ <rect id="clr2" x="-3" y="-1" width="11" height="5" fill="white"></rect>'''
+
+ kopSvg = '<g id="kop%s" class="bf"><use xlink:href="#clr"></use><text x="-2" y="3">%s</text></g>\n'
+ kopSvg2 = '<g id="kop%s" class="bf"><use xlink:href="#clr2"></use><text x="-2" y="3">%s</text></g>\n'
+
+ def info (s, warn=1): sys.stderr.write ((warn and '-- ' or '') + s + '\n')
+
+ #-------------------
+ # data abstractions
+ #-------------------
+ class Measure:
+     def __init__ (s, p):
+         s.reset ()
+         s.ixp = p # part number
+         s.ixm = 0 # measure number
+         s.mdur = 0 # measure duration (nominal metre value in divisions)
+         s.divs = 0 # number of divisions per 1/4
+         s.mtr = 4,4 # meter
+
+     def reset (s): # reset each measure
+         s.attr = '' # measure signatures, tempo
+         s.lline = '' # left barline, but only holds ':' at start of repeat, otherwise empty
+         s.rline = '|' # right barline
+         s.lnum = '' # (left) volta number
+
+ class Note:
+     def __init__ (s, dur=0, n=None):
+         s.tijd = 0 # the time in XML division units
+         s.dur = dur # duration of a note in XML divisions
+         s.fact = None # time modification for tuplet notes (num, div)
+         s.tup = [''] # start(s) and/or stop(s) of tuplet
+         s.tupabc = '' # abc tuplet string to issue before note
+         s.beam = 0 # 1 = beamed
+         s.grace = 0 # 1 = grace note
+         s.before = [] # abc string that goes before the note/chord
+         s.after = '' # the same after the note/chord
+         s.ns = n and [n] or [] # notes in the chord
+         s.lyrs = {} # {number -> syllabe}
+         s.tab = None # (string number, fret number)
+         s.ntdec = '' # !string!, !courtesy!
+
+ class Elem:
+     def __init__ (s, string):
+         s.tijd = 0 # the time in XML division units
+         s.str = string # any abc string that is not a note
+
+ class Counter:
+     def inc (s, key, voice): s.counters [key][voice] = s.counters [key].get (voice, 0) + 1
+     def clear (s, vnums): # reset all counters
+         tups = list( zip (vnums.keys (), len (vnums) * [0]))
+         s.counters = {'note': dict (tups), 'nopr': dict (tups), 'nopt': dict (tups)}
+     def getv (s, key, voice): return s.counters[key][voice]
+     def prcnt (s, ip): # print summary of all non zero counters
+         for iv in s.counters ['note']:
+             if s.getv ('nopr', iv) != 0:
+                 info ( 'part %d, voice %d has %d skipped non printable notes' % (ip, iv, s.getv ('nopr', iv)))
+             if s.getv ('nopt', iv) != 0:
+                 info ( 'part %d, voice %d has %d notes without pitch' % (ip, iv, s.getv ('nopt', iv)))
+             if s.getv ('note', iv) == 0: # no real notes counted in this voice
+                 info ( 'part %d, skipped empty voice %d' % (ip, iv))
+
+ class Music:
+     def __init__(s, options):
+         s.tijd = 0 # the current time
+         s.maxtime = 0 # maximum time in a measure
+         s.gMaten = [] # [voices,.. for all measures in a part]
+         s.gLyrics = [] # [{num: (abc_lyric_string, melis)},.. for all measures in a part]
+         s.vnums = {} # all used voice id's in a part (xml voice id's == numbers)
+         s.cnt = Counter () # global counter object
+         s.vceCnt = 1 # the global voice count over all parts
+         s.lastnote = None # the last real note record inserted in s.voices
+         s.bpl = options.b # the max number of bars per line when writing abc
+         s.cpl = options.n # the number of chars per line when writing abc
+         s.repbra = 0 # true if volta is used somewhere
+         s.nvlt = options.v # no volta on higher voice numbers
+         s.jscript = options.j # compatibility with javascript version
+
+     def initVoices (s, newPart=0):
+         s.vtimes, s.voices, s.lyrics = {}, {}, {}
+         for v in s.vnums:
+             s.vtimes [v] = 0 # {voice: the end time of the last item in each voice}
+             s.voices [v] = [] # {voice: [Note|Elem, ..]}
+             s.lyrics [v] = [] # {voice: [{num: syl}, ..]}
+         if newPart: s.cnt.clear (s.vnums) # clear counters once per part
+
+     def incTime (s, dt):
+         s.tijd += dt
+         if s.tijd < 0: s.tijd = 0 # erroneous <backup> element
+         if s.tijd > s.maxtime: s.maxtime = s.tijd
+
+     def appendElemCv (s, voices, elem):
+         for v in voices:
+             s.appendElem (v, elem) # insert element in all voices
+
+     def insertElem (s, v, elem): # insert at the start of voice v in the current measure
+         obj = Elem (elem)
+         obj.tijd = 0 # because voice is sorted later
+         s.voices [v].insert (0, obj)
+
+     def appendObj (s, v, obj, dur):
+         obj.tijd = s.tijd
+         s.voices [v].append (obj)
+         s.incTime (dur)
+         if s.tijd > s.vtimes[v]: s.vtimes[v] = s.tijd # don't update for inserted earlier items
+
+     def appendElem (s, v, elem, tel=0):
+         s.appendObj (v, Elem (elem), 0)
+         if tel: s.cnt.inc ('note', v) # count number of certain elements in each voice (in addition to notes)
+
+     def appendElemT (s, v, elem, tijd): # insert element at specified time
+         obj = Elem (elem)
+         obj.tijd = tijd
+         s.voices [v].append (obj)
+
+     def appendNote (s, v, note, noot):
+         note.ns.append (note.ntdec + noot)
+         s.appendObj (v, note, int (note.dur))
+         s.lastnote = note # remember last note/rest for later modifications (chord, grace)
+         if noot != 'z' and noot != 'x': # real notes and grace notes
+             s.cnt.inc ('note', v) # count number of real notes in each voice
+             if not note.grace: # for every real note
+                 s.lyrics[v].append (note.lyrs) # even when it has no lyrics
+
+     def getLastRec (s, voice):
+         if s.gMaten: return s.gMaten[-1][voice][-1] # the last record in the last measure
+         return None # no previous records in the first measure
+
+     def getLastMelis (s, voice, num): # get melisma of last measure
+         if s.gLyrics:
+             lyrdict = s.gLyrics[-1][voice] # the previous lyrics dict in this voice
+             if num in lyrdict: return lyrdict[num][1] # lyrdict = num -> (lyric string, melisma)
+         return 0 # no previous lyrics in voice or line number
+
+     def addChord (s, note, noot): # careful: we assume that chord notes follow immediately
+         for d in note.before: # put all decorations before chord
+             if d not in s.lastnote.before:
+                 s.lastnote.before += [d]
+         s.lastnote.ns.append (note.ntdec + noot)
+
+     def addBar (s, lbrk, m): # linebreak, measure data
+         if m.mdur and s.maxtime > m.mdur: info ('measure %d in part %d longer than metre' % (m.ixm+1, m.ixp+1))
+         s.tijd = s.maxtime # the time of the bar lines inserted here
+         for v in s.vnums:
+             if m.lline or m.lnum: # if left barline or left volta number
+                 p = s.getLastRec (v) # get the previous barline record
+                 if p: # in measure 1 no previous measure is available
+                     x = p.str # p.str is the ABC barline string
+                     if m.lline: # append begin of repeat, m.lline == ':'
+                         x = (x + m.lline).replace (':|:','::').replace ('||','|')
+                     if s.nvlt == 3: # add volta number only to lowest voice in part 0
+                         if m.ixp + v == min (s.vnums): x += m.lnum
+                     elif m.lnum: # new behaviour with I:repbra 0
+                         x += m.lnum # add volta number(s) or text to all voices
+                         s.repbra = 1 # signal occurrence of a volta
+                     p.str = x # modify previous right barline
+                 elif m.lline: # begin of new part and left repeat bar is required
+                     s.insertElem (v, '|:')
+             if lbrk:
+                 p = s.getLastRec (v) # get the previous barline record
+                 if p: p.str += lbrk # insert linebreak char after the barlines+volta
+             if m.attr: # insert signatures at front of buffer
+                 s.insertElem (v, '%s' % m.attr)
+             s.appendElem (v, ' %s' % m.rline) # insert current barline record at time maxtime
+             s.voices[v] = sortMeasure (s.voices[v], m) # make all times consistent
+             lyrs = s.lyrics[v] # [{number: sylabe}, .. for all notes]
+             lyrdict = {} # {number: (abc_lyric_string, melis)} for this voice
+             nums = [num for d in lyrs for num in d.keys ()] # the lyrics numbers in this measure
+             maxNums = max (nums + [0]) # the highest lyrics number in this measure
+             for i in range (maxNums, 0, -1):
+                 xs = [syldict.get (i, '') for syldict in lyrs] # collect the syllabi with number i
+                 melis = s.getLastMelis (v, i) # get melisma from last measure
+                 lyrdict [i] = abcLyr (xs, melis)
+             s.lyrics[v] = lyrdict # {number: (abc_lyric_string, melis)} for this measure
+             mkBroken (s.voices[v])
+         s.gMaten.append (s.voices)
+         s.gLyrics.append (s.lyrics)
+         s.tijd = s.maxtime = 0
+         s.initVoices ()
+
+     def outVoices (s, divs, ip, isSib): # output all voices of part ip
+         vvmap = {} # xml voice number -> abc voice number (one part)
+         vnum_keys = list (s.vnums.keys ())
+         if s.jscript or isSib: vnum_keys.sort ()
+         lvc = min (vnum_keys or [1]) # lowest xml voice number of this part
+         for iv in vnum_keys:
+             if s.cnt.getv ('note', iv) == 0: # no real notes counted in this voice
+                 continue # skip empty voices
+             if abcOut.denL: unitL = abcOut.denL # take the unit length from the -d option
+             else: unitL = compUnitLength (iv, s.gMaten, divs) # compute the best unit length for this voice
+             abcOut.cmpL.append (unitL) # remember for header output
+             vn, vl = [], {} # for voice iv: collect all notes to vn and all lyric lines to vl
+             for im in range (len (s.gMaten)):
+                 measure = s.gMaten [im][iv]
+                 vn.append (outVoice (measure, divs [im], im, ip, unitL))
+                 checkMelismas (s.gLyrics, s.gMaten, im, iv)
+                 for n, (lyrstr, melis) in s.gLyrics [im][iv].items ():
+                     if n in vl:
+                         while len (vl[n]) < im: vl[n].append ('') # fill in skipped measures
+                         vl[n].append (lyrstr)
+                     else:
+                         vl[n] = im * [''] + [lyrstr] # must skip im measures
+             for n, lyrs in vl.items (): # fill up possibly empty lyric measures at the end
+                 mis = len (vn) - len (lyrs)
+                 lyrs += mis * ['']
+             abcOut.add ('V:%d' % s.vceCnt)
+             if s.repbra:
+                 if s.nvlt == 1 and s.vceCnt > 1: abcOut.add ('I:repbra 0') # only volta on first voice
+                 if s.nvlt == 2 and iv > lvc: abcOut.add ('I:repbra 0') # only volta on first voice of each part
+             if s.cpl > 0: s.bpl = 0 # option -n (max chars per line) overrules -b (max bars per line)
+             elif s.bpl == 0: s.cpl = 100 # the default: 100 chars per line
+             bn = 0 # count bars
+             while vn: # while still measures available
+                 ib = 1
+                 chunk = vn [0]
+                 while ib < len (vn):
+                     if s.cpl > 0 and len (chunk) + len (vn [ib]) >= s.cpl: break # line full (number of chars)
+                     if s.bpl > 0 and ib >= s.bpl: break # line full (number of bars)
+                     chunk += vn [ib]
+                     ib += 1
+                 bn += ib
+                 abcOut.add (chunk + ' %%%d' % bn) # line with barnumer
+                 del vn[:ib] # chop ib bars
+                 lyrlines = sorted (vl.items ()) # order the numbered lyric lines for output
+                 for n, lyrs in lyrlines:
+                     abcOut.add ('w: ' + '|'.join (lyrs[:ib]) + '|')
+                     del lyrs[:ib]
+             vvmap [iv] = s.vceCnt # xml voice number -> abc voice number
+             s.vceCnt += 1 # count voices over all parts
+         s.gMaten = [] # reset the follwing instance vars for each part
+         s.gLyrics = []
+         s.cnt.prcnt (ip+1) # print summary of skipped items in this part
+         return vvmap
+
+ class ABCoutput:
+     pagekeys = 'scale,pageheight,pagewidth,leftmargin,rightmargin,topmargin,botmargin'.split (',')
+     def __init__ (s, fnmext, pad, X, options):
+         s.fnmext = fnmext
+         s.outlist = [] # list of ABC strings
+         s.title = 'T:Title'
+         s.key = 'none'
+         s.clefs = {} # clefs for all abc-voices
+         s.mtr = 'none'
+         s.tempo = 0 # 0 -> no tempo field
+         s.tempo_units = (1,4) # note type of tempo direction
+         s.pad = pad # the output path or none
+         s.X = X + 1 # the abc tune number
+         s.denL = options.d # denominator of the unit length (L:) from -d option
+         s.volpan = int (options.m) # 0 -> no %%MIDI, 1 -> only program, 2 -> all %%MIDI
+         s.cmpL = [] # computed optimal unit length for all voices
+         s.jscript = options.j # compatibility with javascript version
+         s.tstep = options.t # translate percmap to voicemap
+         s.stemless = 0 # use U:s=!stemless!
+         s.shiftStem = options.s # shift note heads 3 units left
+         if pad:
+             _, base_name = os.path.split (fnmext)
+             s.outfile = open (os.path.join (pad, base_name), 'w')
+         else: s.outfile = sys.stdout
+         if s.jscript: s.X = 1 # always X:1 in javascript version
+         s.pageFmt = {}
+         for k in s.pagekeys: s.pageFmt [k] = None
+         if len (options.p) == 7:
+             for k, v in zip (s.pagekeys, options.p):
+                 try: s.pageFmt [k] = float (v)
+                 except: info ('illegal float %s for %s', (k, v)); continue
+
+     def add (s, str):
+         s.outlist.append (str + '\n') # collect all ABC output
+
+     def mkHeader (s, stfmap, partlist, midimap, vmpdct, koppen): # stfmap = [parts], part = [staves], stave = [voices]
+         accVce, accStf, staffs = [], [], stfmap[:] # staffs is consumed
+         for x in partlist: # collect partnames into accVce and staff groups into accStf
+             try: prgroupelem (x, ('', ''), '', stfmap, accVce, accStf)
+             except: info ('lousy musicxml: error in part-list')
+         staves = ' '.join (accStf)
+         clfnms = {}
+         for part, (partname, partabbrv) in zip (staffs, accVce):
+             if not part: continue # skip empty part
+             firstVoice = part[0][0] # the first voice number in this part
+             nm = partname.replace ('\n','\\n').replace ('.:','.').strip (':')
+             snm = partabbrv.replace ('\n','\\n').replace ('.:','.').strip (':')
+             clfnms [firstVoice] = (nm and 'nm="%s"' % nm or '') + (snm and ' snm="%s"' % snm or '')
+         hd = ['X:%d\n%s\n' % (s.X, s.title)]
+         for i, k in enumerate (s.pagekeys):
+             if s.jscript and k in ['pageheight','topmargin', 'botmargin']: continue
+             if s.pageFmt [k] != None: hd.append ('%%%%%s %.2f%s\n' % (k, s.pageFmt [k], i > 0 and 'cm' or ''))
+         if staves and len (accStf) > 1: hd.append ('%%score ' + staves + '\n')
+         tempo = s.tempo and 'Q:%d/%d=%s\n' % (s.tempo_units [0], s.tempo_units [1], s.tempo) or '' # default no tempo field
+         d = {} # determine the most frequently occurring unit length over all voices
+         for x in s.cmpL: d[x] = d.get (x, 0) + 1
+         if s.jscript: defLs = sorted (d.items (), key=lambda x: (-x[1], x[0])) # when tie (1) sort on key (0)
+         else: defLs = sorted (d.items (), key=lambda x: -x[1])
+         defL = s.denL and s.denL or defLs [0][0] # override default unit length with -d option
+         hd.append ('L:1/%d\n%sM:%s\n' % (defL, tempo, s.mtr))
+         hd.append ('K:%s\n' % s.key)
+         if s.stemless: hd.append ('U:s=!stemless!\n')
+         vxs = sorted (vmpdct.keys ())
+         for vx in vxs: hd.extend (vmpdct [vx])
+         s.dojef = 0 # translate percmap to voicemap
+         for vnum, clef in s.clefs.items ():
+             ch, prg, vol, pan = midimap [vnum-1][:4]
+             dmap = midimap [vnum - 1][4:] # map of abc percussion notes to midi notes
+             if dmap and 'perc' not in clef: clef = (clef + ' map=perc').strip ();
+             hd.append ('V:%d %s %s\n' % (vnum, clef, clfnms.get (vnum, '')))
+             if vnum in vmpdct:
+                 hd.append ('%%%%voicemap tab%d\n' % vnum)
+                 hd.append ('K:none\nM:none\n%%clef none\n%%staffscale 1.6\n%%flatbeams true\n%%stemdir down\n')
+             if 'perc' in clef: hd.append ('K:none\n'); # no key for a perc voice
+             if s.volpan > 1: # option -m 2 -> output all recognized midi commands when needed and present in xml
+                 if ch > 0 and ch != vnum: hd.append ('%%%%MIDI channel %d\n' % ch)
+                 if prg > 0: hd.append ('%%%%MIDI program %d\n' % (prg - 1))
+                 if vol >= 0: hd.append ('%%%%MIDI control 7 %.0f\n' % vol) # volume == 0 is possible ...
+                 if pan >= 0: hd.append ('%%%%MIDI control 10 %.0f\n' % pan)
+             elif s.volpan > 0: # default -> only output midi program command when present in xml
+                 if dmap and ch > 0: hd.append ('%%%%MIDI channel %d\n' % ch) # also channel if percussion part
+                 if prg > 0: hd.append ('%%%%MIDI program %d\n' % (prg - 1))
+             for abcNote, step, midiNote, notehead in dmap:
+                 if not notehead: notehead = 'normal'
+                 if abcMid (abcNote) != midiNote or abcNote != step:
+                     if s.volpan > 0: hd.append ('%%%%MIDI drummap %s %s\n' % (abcNote, midiNote))
+                     hd.append ('I:percmap %s %s %s %s\n' % (abcNote, step, midiNote, notehead))
+                     s.dojef = s.tstep
+             if defL != s.cmpL [vnum-1]: # only if computed unit length different from header
+                 hd.append ('L:1/%d\n' % s.cmpL [vnum-1])
+         s.outlist = hd + s.outlist
+         if koppen: # output SVG stuff needed for tablature
+             k1 = kopSvg.replace ('-2','-5') if s.shiftStem else kopSvg # shift note heads 3 units left
+             k2 = kopSvg2.replace ('-2','-5') if s.shiftStem else kopSvg2
+             tb = tabSvg.replace ('-3','-6') if s.shiftStem else tabSvg
+             ks = sorted (koppen.keys ()) # javascript compatibility
+             ks = [k2 % (k, k) if len (k) == 2 else k1 % (k, k) for k in ks]
+             tbs = map (lambda x: x.strip () + '\n', tb.splitlines ()) # javascript compatibility
+             s.outlist = tbs + ks + ['</defs>\n%%endsvg\n'] + s.outlist
+
+     def writeall (s): # determine the required encoding of the entire ABC output
+         str = ''.join (s.outlist)
+         if s.dojef: str = perc2map (str)
+         if python3: s.outfile.write (str)
+         else: s.outfile.write (str.encode ('utf-8'))
+         if s.pad: s.outfile.close () # close each file with -o option
+         else: s.outfile.write ('\n') # add empty line between tunes on stdout
+         info ('%s written with %d voices' % (s.fnmext, len (s.clefs)), warn=0)
+
443
+ #----------------
444
+ # functions
445
+ #----------------
446
+ def abcLyr (xs, melis): # Convert list xs to abc lyrics.
447
+ if not ''.join (xs): return '', 0 # there is no lyrics in this measure
448
+ res = []
449
+ for x in xs: # xs has for every note a lyrics syllabe or an empty string
450
+ if x == '': # note without lyrics
451
+ if melis: x = '_' # set melisma
452
+ else: x = '*' # skip note
453
+ elif x.endswith ('_') and not x.endswith ('\_'): # start of new melisma
454
+ x = x.replace ('_', '') # remove and set melis boolean
455
+ melis = 1 # so next skips will become melisma
456
+ else: melis = 0 # melisma stops on first syllable
457
+ res.append (x)
458
+ return (' '.join (res), melis)
459
+
460
+ def simplify (a, b): # divide a and b by their greatest common divisor
461
+ x, y = a, b
462
+ while b: a, b = b, a % b
463
+ return x // a, y // a
464
+
465
+ def abcdur (nx, divs, uL): # convert an musicXML duration d to abc units with L:1/uL
466
+ if nx.dur == 0: return '' # when called for elements without duration
467
+ num, den = simplify (uL * nx.dur, divs * 4) # L=1/8 -> uL = 8 units
468
+ if nx.fact: # apply tuplet time modification
469
+ numfac, denfac = nx.fact
470
+ num, den = simplify (num * numfac, den * denfac)
471
+ if den > 64: # limit the denominator to a maximum of 64
472
+ x = float (num) / den; n = math.floor (x); # when just above an integer n
473
+ if x - n < 0.1 * x: num, den = n, 1; # round to n
474
+ num64 = 64. * num / den + 1.0e-15 # to get Python2 behaviour of round
475
+ num, den = simplify (int (round (num64)), 64)
476
+ if num == 1:
477
+ if den == 1: dabc = ''
478
+ elif den == 2: dabc = '/'
479
+ else: dabc = '/%d' % den
480
+ elif den == 1: dabc = '%d' % num
481
+ else: dabc = '%d/%d' % (num, den)
482
+ return dabc
483
+
484
+ def abcMid (note): # abc note -> midi pitch
485
+ r = re.search (r"([_^]*)([A-Ga-g])([',]*)", note)
486
+ if not r: return -1
487
+ acc, n, oct = r.groups ()
488
+ nUp = n.upper ()
489
+ p = 60 + [0,2,4,5,7,9,11]['CDEFGAB'.index (nUp)] + (12 if nUp != n else 0);
490
+ if acc: p += (1 if acc[0] == '^' else -1) * len (acc)
491
+ if oct: p += (12 if oct[0] == "'" else -12) * len (oct)
492
+ return p
493
+
494
+ def staffStep (ptc, o, clef, tstep):
495
+ ndif = 0
496
+ if 'stafflines=1' in clef: ndif += 4 # meaning of one line: E (xml) -> B (abc)
497
+ if not tstep and clef.startswith ('bass'): ndif += 12 # transpose bass -> treble (C3 -> A4)
498
+ if ndif: # diatonic transposition == addition modulo 7
499
+ nm7 = 'C,D,E,F,G,A,B'.split (',')
500
+ n = nm7.index (ptc) + ndif
501
+ ptc, o = nm7 [n % 7], o + n // 7
502
+ if o > 4: ptc = ptc.lower ()
503
+ if o > 5: ptc = ptc + (o-5) * "'"
504
+ if o < 4: ptc = ptc + (4-o) * ","
505
+ return ptc
506
+
507
+ def setKey (fifths, mode):
508
+ sharpness = ['Fb', 'Cb','Gb','Db','Ab','Eb','Bb','F','C','G','D','A', 'E', 'B', 'F#','C#','G#','D#','A#','E#','B#']
509
+ offTab = {'maj':8, 'ion':8, 'm':11, 'min':11, 'aeo':11, 'mix':9, 'dor':10, 'phr':12, 'lyd':7, 'loc':13, 'non':8}
510
+ mode = mode.lower ()[:3] # only first three chars, no case
511
+ key = sharpness [offTab [mode] + fifths] + (mode if offTab [mode] != 8 else '')
512
+ accs = ['F','C','G','D','A','E','B']
513
+ if fifths >= 0: msralts = dict (zip (accs[:fifths], fifths * [1]))
514
+ else: msralts = dict (zip (accs[fifths:], -fifths * [-1]))
515
+ return key, msralts
516
+
517
+ def insTup (ix, notes, fact): # read one nested tuplet
518
+ tupcnt = 0
519
+ nx = notes [ix]
520
+ if 'start' in nx.tup:
521
+ nx.tup.remove ('start') # do recursive calls when starts remain
522
+ tix = ix # index of first tuplet note
523
+ fn, fd = fact # xml time-mod of the higher level
524
+ fnum, fden = nx.fact # xml time-mod of the current level
525
+ tupfact = fnum//fn, fden//fd # abc time mod of this level
526
+ while ix < len (notes):
527
+ nx = notes [ix]
528
+ if isinstance (nx, Elem) or nx.grace:
529
+ ix += 1 # skip all non tuplet elements
530
+ continue
531
+ if 'start' in nx.tup: # more nested tuplets to start
532
+ ix, tupcntR = insTup (ix, notes, tupfact) # ix is on the stop note!
533
+ tupcnt += tupcntR
534
+ elif nx.fact:
535
+ tupcnt += 1 # count tuplet elements
536
+ if 'stop' in nx.tup:
537
+ nx.tup.remove ('stop')
538
+ break
539
+ if not nx.fact: # stop on first non tuplet note
540
+ ix = lastix # back to last tuplet note
541
+ break
542
+ lastix = ix
543
+ ix += 1
544
+ # put abc tuplet notation before the recursive ones
545
+ tup = (tupfact[0], tupfact[1], tupcnt)
546
+ if tup == (3, 2, 3): tupPrefix = '(3'
547
+ else: tupPrefix = '(%d:%d:%d' % tup
548
+ notes [tix].tupabc = tupPrefix + notes [tix].tupabc
549
+ return ix, tupcnt # ix is on the last tuplet note
550
+
551
+ def mkBroken (vs): # introduce broken rhythms (vs: one voice, one measure)
552
+ vs = [n for n in vs if isinstance (n, Note)]
553
+ i = 0
554
+ while i < len (vs) - 1:
555
+ n1, n2 = vs[i], vs[i+1] # scan all adjacent pairs
556
+ # skip if note in tuplet or has no duration or outside beam
557
+ if not n1.fact and not n2.fact and n1.dur > 0 and n2.beam:
558
+ if n1.dur * 3 == n2.dur:
559
+ n2.dur = (2 * n2.dur) // 3
560
+ n1.dur = n1.dur * 2
561
+ n1.after = '<' + n1.after
562
+ i += 1 # do not chain broken rhythms
563
+ elif n2.dur * 3 == n1.dur:
564
+ n1.dur = (2 * n1.dur) // 3
565
+ n2.dur = n2.dur * 2
566
+ n1.after = '>' + n1.after
567
+ i += 1 # do not chain broken rhythms
568
+ i += 1
569
+
570
+ def outVoice (measure, divs, im, ip, unitL): # note/elem objects of one measure in one voice
571
+ ix = 0
572
+ while ix < len (measure): # set all (nested) tuplet annotations
573
+ nx = measure [ix]
574
+ if isinstance (nx, Note) and nx.fact and not nx.grace:
575
+ ix, tupcnt = insTup (ix, measure, (1, 1)) # read one tuplet, insert annotation(s)
576
+ ix += 1
577
+ vs = []
578
+ for nx in measure:
579
+ if isinstance (nx, Note):
580
+ durstr = abcdur (nx, divs, unitL) # xml -> abc duration string
581
+ chord = len (nx.ns) > 1
582
+ cns = [nt[:-1] for nt in nx.ns if nt.endswith ('-')]
583
+ tie = ''
584
+ if chord and len (cns) == len (nx.ns): # all chord notes tied
585
+ nx.ns = cns # chord notes without tie
586
+ tie = '-' # one tie for whole chord
587
+ s = nx.tupabc + ''.join (nx.before)
588
+ if chord: s += '['
589
+ for nt in nx.ns: s += nt
590
+ if chord: s += ']' + tie
591
+ if s.endswith ('-'): s, tie = s[:-1], '-' # split off tie
592
+ s += durstr + tie # and put it back again
593
+ s += nx.after
594
+ nospace = nx.beam
595
+ else:
596
+ if isinstance (nx.str, listtype): nx.str = nx.str [0]
597
+ s = nx.str
598
+ nospace = 1
599
+ if nospace: vs.append (s)
600
+ else: vs.append (' ' + s)
601
+ vs = ''.join (vs) # ad hoc: remove multiple pedal directions
602
+ while vs.find ('!ped!!ped!') >= 0: vs = vs.replace ('!ped!!ped!','!ped!')
603
+ while vs.find ('!ped-up!!ped-up!') >= 0: vs = vs.replace ('!ped-up!!ped-up!','!ped-up!')
604
+ while vs.find ('!8va(!!8va)!') >= 0: vs = vs.replace ('!8va(!!8va)!','') # remove empty ottava's
605
+ return vs
606
+
607
+ def sortMeasure (voice, m):
608
+ voice.sort (key=lambda o: o.tijd) # sort on time
609
+ time = 0
610
+ v = []
611
+ rs = [] # holds rests in between notes
612
+ for i, nx in enumerate (voice): # establish sequentiality
613
+ if nx.tijd > time and chkbug (nx.tijd - time, m):
614
+ v.append (Note (nx.tijd - time, 'x')) # fill hole with invisble rest
615
+ rs.append (len (v) - 1)
616
+ if isinstance (nx, Elem):
617
+ if nx.tijd < time: nx.tijd = time # shift elems without duration to where they fit
618
+ v.append (nx)
619
+ time = nx.tijd
620
+ continue
621
+ if nx.tijd < time: # overlapping element
622
+ if nx.ns[0] == 'z': continue # discard overlapping rest
623
+ if v[-1].tijd <= nx.tijd: # we can do something
624
+ if v[-1].ns[0] == 'z': # shorten rest
625
+ v[-1].dur = nx.tijd - v[-1].tijd
626
+ if v[-1].dur == 0: del v[-1] # nothing left
627
+ info ('overlap in part %d, measure %d: rest shortened' % (m.ixp+1, m.ixm+1))
628
+ else: # make a chord of overlap
629
+ v[-1].ns += nx.ns
630
+ info ('overlap in part %d, measure %d: added chord' % (m.ixp+1, m.ixm+1))
631
+ nx.dur = (nx.tijd + nx.dur) - time # the remains
632
+ if nx.dur <= 0: continue # nothing left
633
+ nx.tijd = time # append remains
634
+ else: # give up
635
+ info ('overlapping notes in one voice! part %d, measure %d, note %s discarded' % (m.ixp+1, m.ixm+1, isinstance (nx, Note) and nx.ns or nx.str))
636
+ continue
637
+ v.append (nx)
638
+ if isinstance (nx, Note):
639
+ if nx.ns [0] in 'zx':
640
+ rs.append (len (v) - 1) # remember rests between notes
641
+ elif len (rs):
642
+ if nx.beam and not nx.grace: # copy beam into rests
643
+ for j in rs: v[j].beam = nx.beam
644
+ rs = [] # clear rests on each note
645
+ time = nx.tijd + nx.dur
646
+ # when a measure contains no elements and no forwards -> no incTime -> s.maxtime = 0 -> right barline
647
+ # is inserted at time == 0 (in addbar) and is only element in the voice when sortMeasure is called
648
+ if time == 0: info ('empty measure in part %d, measure %d, it should contain at least a rest to advance the time!' % (m.ixp+1, m.ixm+1))
649
+ return v
650
+
651
+ def getPartlist (ps): # correct part-list (from buggy xml-software)
652
+ xs = [] # the corrected part-list
653
+ e = [] # stack of opened part-groups
654
+ for x in list (ps): # insert missing stops, delete double starts
655
+ if x.tag == 'part-group':
656
+ num, type = x.get ('number'), x.get ('type')
657
+ if type == 'start':
658
+ if num in e: # missing stop: insert one
659
+ xs.append (E.Element ('part-group', number = num, type = 'stop'))
660
+ xs.append (x)
661
+ else: # normal start
662
+ xs.append (x)
663
+ e.append (num)
664
+ else:
665
+ if num in e: # normal stop
666
+ e.remove (num)
667
+ xs.append (x)
668
+ else: pass # double stop: skip it
669
+ else: xs.append (x)
670
+ for num in reversed (e): # fill missing stops at the end
671
+ xs.append (E.Element ('part-group', number = num, type = 'stop'))
672
+ return xs
673
+
674
+ def parseParts (xs, d, e): # -> [elems on current level], rest of xs
675
+ if not xs: return [],[]
676
+ x = xs.pop (0)
677
+ if x.tag == 'part-group':
678
+ num, type = x.get ('number'), x.get ('type')
679
+ if type == 'start': # go one level deeper
680
+ s = [x.findtext (n, '') for n in ['group-symbol','group-barline','group-name','group-abbreviation']]
681
+ d [num] = s # remember groupdata by group number
682
+ e.append (num) # make stack of open group numbers
683
+ elemsnext, rest1 = parseParts (xs, d, e) # parse one level deeper to next stop
684
+ elems, rest2 = parseParts (rest1, d, e) # parse the rest on this level
685
+ return [elemsnext] + elems, rest2
686
+ else: # stop: close level and return group-data
687
+ nums = e.pop () # last open group number in stack order
688
+ if xs and xs[0].get ('type') == 'stop': # two consequetive stops
689
+ if num != nums: # in the wrong order (tempory solution)
690
+ d[nums], d[num] = d[num], d[nums] # exchange values (only works for two stops!!!)
691
+ sym = d[num] # retrieve an return groupdata as last element of the group
692
+ return [sym], xs
693
+ else:
694
+ elems, rest = parseParts (xs, d, e) # parse remaining elements on current level
695
+ name = x.findtext ('part-name',''), x.findtext ('part-abbreviation','')
696
+ return [name] + elems, rest
697
+
698
+ def bracePart (part): # put a brace on multistaff part and group voices
699
+ if not part: return [] # empty part in the score
700
+ brace = []
701
+ for ivs in part:
702
+ if len (ivs) == 1: # stave with one voice
703
+ brace.append ('%s' % ivs[0])
704
+ else: # stave with multiple voices
705
+ brace += ['('] + ['%s' % iv for iv in ivs] + [')']
706
+ brace.append ('|')
707
+ del brace[-1] # no barline at the end
708
+ if len (part) > 1:
709
+ brace = ['{'] + brace + ['}']
710
+ return brace
711
+
712
+ def prgroupelem (x, gnm, bar, pmap, accVce, accStf): # collect partnames (accVce) and %%score map (accStf)
713
+ if type (x) == tupletype: # partname-tuple = (part-name, part-abbrev)
714
+ y = pmap.pop (0)
715
+ if gnm[0]: x = [n1 + ':' + n2 for n1, n2 in zip (gnm, x)] # put group-name before part-name
716
+ accVce.append (x)
717
+ accStf.extend (bracePart (y))
718
+ elif len (x) == 2 and type (x[0]) == tupletype: # misuse of group just to add extra name to stave
719
+ y = pmap.pop (0)
720
+ nms = [n1 + ':' + n2 for n1, n2 in zip (x[0], x[1][2:])] # x[0] = partname-tuple, x[1][2:] = groupname-tuple
721
+ accVce.append (nms)
722
+ accStf.extend (bracePart (y))
723
+ else:
724
+ prgrouplist (x, bar, pmap, accVce, accStf)
725
+
726
+ def prgrouplist (x, pbar, pmap, accVce, accStf): # collect partnames, scoremap for a part-group
727
+ sym, bar, gnm, gabbr = x[-1] # bracket symbol, continue barline, group-name-tuple
728
+ bar = bar == 'yes' or pbar # pbar -> the parent has bar
729
+ accStf.append (sym == 'brace' and '{' or '[')
730
+ for z in x[:-1]:
731
+ prgroupelem (z, (gnm, gabbr), bar, pmap, accVce, accStf)
732
+ if bar: accStf.append ('|')
733
+ if bar: del accStf [-1] # remove last one before close
734
+ accStf.append (sym == 'brace' and '}' or ']')
735
+
736
+ def compUnitLength (iv, maten, divs): # compute optimal unit length
737
+ uLmin, minLen = 0, max_int
738
+ for uL in [4,8,16]: # try 1/4, 1/8 and 1/16
739
+ vLen = 0 # total length of abc duration strings in this voice
740
+ for im, m in enumerate (maten): # all measures
741
+ for e in m[iv]: # all notes in voice iv
742
+ if isinstance (e, Elem) or e.dur == 0: continue # no real durations
743
+ vLen += len (abcdur (e, divs [im], uL)) # add len of duration string
744
+ if vLen < minLen: uLmin, minLen = uL, vLen # remember the smallest
745
+ return uLmin
746
+
747
+ def doSyllable (syl):
748
+ txt = ''
749
+ for e in syl:
750
+ if e.tag == 'elision': txt += '~'
751
+ elif e.tag == 'text': # escape - and space characters
752
+ txt += (e.text or '').replace ('_','\_').replace('-', r'\-').replace(' ', '~')
753
+ if not txt: return txt
754
+ if syl.findtext('syllabic') in ['begin', 'middle']: txt += '-'
755
+ if syl.find('extend') is not None: txt += '_'
756
+ return txt
757
+
758
+ def checkMelismas (lyrics, maten, im, iv):
759
+ if im == 0: return
760
+ maat = maten [im][iv] # notes of the current measure
761
+ curlyr = lyrics [im][iv] # lyrics dict of current measure
762
+ prvlyr = lyrics [im-1][iv] # lyrics dict of previous measure
763
+ for n, (lyrstr, melis) in prvlyr.items (): # all lyric numbers in the previous measure
764
+ if n not in curlyr and melis: # melisma required, but no lyrics present -> make one!
765
+ ms = getMelisma (maat) # get a melisma for the current measure
766
+ if ms: curlyr [n] = (ms, 0) # set melisma as the n-th lyrics of the current measure
767
+
768
+ def getMelisma (maat): # get melisma from notes in maat
769
+ ms = []
770
+ for note in maat: # every note should get an underscore
771
+ if not isinstance (note, Note): continue # skip Elem's
772
+ if note.grace: continue # skip grace notes
773
+ if note.ns [0] in 'zx': break # stop on first rest
774
+ ms.append ('_')
775
+ return ' '.join (ms)
776
+
777
+ def perc2map (abcIn):
778
+ fillmap = {'diamond':1, 'triangle':1, 'square':1, 'normal':1};
779
+ abc = map (lambda x: x.strip (), percSvg.splitlines ())
780
+ id='default'
781
+ maps = {'default': []};
782
+ dmaps = {'default': []}
783
+ r1 = re.compile (r'V:\s*(\S+)')
784
+ ls = abcIn.splitlines ()
785
+ for x in ls:
786
+ if 'I:percmap' in x:
787
+ noot, step, midi, kop = map (lambda x: x.strip (), x.split ()[1:])
788
+ if kop in fillmap: kop = kop + '+' + ',' + kop
789
+ x = '%%%%map perc%s %s print=%s midi=%s heads=%s' % (id, noot, step, midi, kop)
790
+ maps [id].append (x)
791
+ if '%%MIDI' in x: dmaps [id].append (x)
792
+ if 'V:' in x:
793
+ r = r1.match (x)
794
+ if r:
795
+ id = r.group (1);
796
+ if id not in maps: maps [id] = []; dmaps [id] = []
797
+ ids = sorted (maps.keys ())
798
+ for id in ids: abc += maps [id]
799
+ id='default'
800
+ for x in ls:
801
+ if 'I:percmap' in x: continue
802
+ if '%%MIDI' in x: continue
803
+ if 'V:' in x or 'K:' in x:
804
+ r = r1.match (x)
805
+ if r: id = r.group (1)
806
+ abc.append (x)
807
+ if id in dmaps and len (dmaps [id]) > 0: abc.extend (dmaps [id]); del dmaps [id]
808
+ if 'perc' in x and 'map=' not in x: x += ' map=perc';
809
+ if 'map=perc' in x and len (maps [id]) > 0: abc.append ('%%voicemap perc' + id);
810
+ if 'map=off' in x: abc.append ('%%voicemap');
811
+ else:
812
+ abc.append (x)
813
+ return '\n'.join (abc) + '\n'
814
+
815
+ def addoct (ptc, o): # xml staff step, xml octave number
816
+ p = ptc
817
+ if o > 4: p = ptc.lower ()
818
+ if o > 5: p = p + (o-5) * "'"
819
+ if o < 4: p = p + (4-o) * ","
820
+ return p # abc pitch == abc note without accidental
821
+
822
+ def chkbug (dt, m):
823
+ if dt > m.divs / 16: return 1 # duration should be > 1/64 note
824
+ info ('MuseScore bug: incorrect duration, smaller then 1/64! in measure %d, part %d' % (m.ixm, m.ixp))
825
+ return 0
826
+
827
+ #----------------
828
+ # parser
829
+ #----------------
830
+ class Parser:
831
+ note_alts = [ # 3 alternative notations of the same note for tablature mapping
832
+ [x.strip () for x in '=C, ^C, =D, ^D, =E, =F, ^F, =G, ^G, =A, ^A, =B'.split (',')],
833
+ [x.strip () for x in '^B, _D,^^C, _E, _F, ^E, _G,^^F, _A,^^G, _B, _C'.split (',')],
834
+ [x.strip () for x in '__D,^^B,__E,__F,^^D,__G,^^E,__A,_/A,__B,__C,^^A'.split (',')] ]
835
+ step_map = {'C':0,'D':2,'E':4,'F':5,'G':7,'A':9,'B':11}
836
+ def __init__ (s, options):
837
+ # unfold repeats, number of chars per line, credit filter level, volta option
838
+ s.slurBuf = {} # dict of open slurs keyed by slur number
839
+ s.dirStk = {} # {direction-type + number -> (type, voice | time)} dict for proper closing
840
+ s.ingrace = 0 # marks a sequence of grace notes
841
+ s.msc = Music (options) # global music data abstraction
842
+ s.unfold = options.u # turn unfolding repeats on
843
+ s.ctf = options.c # credit text filter level
844
+ s.gStfMap = [] # [[abc voice numbers] for all parts]
845
+ s.midiMap = [] # midi-settings for each abc voice, in order
846
+ s.drumInst = {} # inst_id -> midi pitch for channel 10 notes
847
+ s.drumNotes = {} # (xml voice, abc note) -> (midi note, note head)
848
+ s.instMid = [] # [{inst id -> midi-settings} for all parts]
849
+ s.midDflt = [-1,-1,-1,-91] # default midi settings for channel, program, volume, panning
850
+ s.msralts = {} # xml-notenames (without octave) with accidentals from the key
851
+ s.curalts = {} # abc-notenames (with voice number) with passing accidentals
852
+ s.stfMap = {} # xml staff number -> [xml voice number]
853
+ s.vce2stf = {} # xml voice number -> allocated staff number
854
+ s.clefMap = {} # xml staff number -> abc clef (for header only)
855
+ s.curClef = {} # xml staff number -> current abc clef
856
+ s.stemDir = {} # xml voice number -> current stem direction
857
+ s.clefOct = {} # xml staff number -> current clef-octave-change
858
+ s.curStf = {} # xml voice number -> current xml staff number
859
+ s.nolbrk = options.x; # generate no linebreaks ($)
860
+ s.jscript = options.j # compatibility with javascript version
861
+ s.ornaments = sorted (note_ornamentation_map.items ())
862
+ s.doPageFmt = len (options.p) == 1 # translate xml page format
863
+ s.tstep = options.t # clef determines step on staff (percussion)
864
+ s.dirtov1 = options.v1 # all directions to first voice of staff
865
+ s.ped = options.ped # render pedal directions
866
+ s.wstems = options.stm # translate stem elements
867
+ s.pedVce = None # voice for pedal directions
868
+ s.repeat_str = {} # staff number -> [measure number, repeat-text]
869
+ s.tabVceMap = {} # abc voice num -> [%%map ...] for tab voices
870
+ s.koppen = {} # noteheads needed for %%map
871
+
872
+ def matchSlur (s, type2, n, v2, note2, grace, stopgrace): # match slur number n in voice v2, add abc code to before/after
873
+ if type2 not in ['start', 'stop']: return # slur type continue has no abc equivalent
874
+ if n == None: n = '1'
875
+ if n in s.slurBuf:
876
+ type1, v1, note1, grace1 = s.slurBuf [n]
877
+ if type2 != type1: # slur complete, now check the voice
878
+ if v2 == v1: # begins and ends in the same voice: keep it
879
+ if type1 == 'start' and (not grace1 or not stopgrace): # normal slur: start before stop and no grace slur
880
+ note1.before = ['('] + note1.before # keep left-right order!
881
+ note2.after += ')'
882
+ # no else: don't bother with reversed stave spanning slurs
883
+ del s.slurBuf [n] # slur finished, remove from stack
884
+ else: # double definition, keep the last
885
+ info ('double slur numbers %s-%s in part %d, measure %d, voice %d note %s, first discarded' % (type2, n, s.msr.ixp+1, s.msr.ixm+1, v2, note2.ns))
886
+ s.slurBuf [n] = (type2, v2, note2, grace)
887
+ else: # unmatched slur, put in dict
888
+ s.slurBuf [n] = (type2, v2, note2, grace)
889
+
890
+ def doNotations (s, note, nttn, isTab):
891
+ for key, val in s.ornaments:
892
+ if nttn.find (key) != None: note.before += [val] # just concat all ornaments
893
+ trem = nttn.find ('ornaments/tremolo')
894
+ if trem != None:
895
+ type = trem.get ('type')
896
+ if type == 'single':
897
+ note.before.insert (0, '!%s!' % (int (trem.text) * '/'))
898
+ else:
899
+ note.fact = None # no time modification in ABC
900
+ if s.tstep: # abc2svg version
901
+ if type == 'stop': note.before.insert (0, '!trem%s!' % trem.text);
902
+ else: # abc2xml version
903
+ if type == 'start': note.before.insert (0, '!%s-!' % (int (trem.text) * '/'));
904
+ fingering = nttn.findall ('technical/fingering')
905
+ for finger in fingering: # handle multiple finger annotations
906
+ if not isTab: note.before += ['!%s!' % finger.text] # fingering goes before chord (addChord)
907
+ snaar = nttn.find ('technical/string')
908
+ if snaar != None and isTab:
909
+ if s.tstep:
910
+ fret = nttn.find ('technical/fret')
911
+ if fret != None: note.tab = (snaar.text, fret.text)
912
+ else:
913
+ deco = '!%s!' % snaar.text # no double string decos (bug in musescore)
914
+ if deco not in note.ntdec: note.ntdec += deco
915
+ wvln = nttn.find ('ornaments/wavy-line')
916
+ if wvln != None:
917
+ if wvln.get ('type') == 'start': note.before = ['!trill(!'] + note.before # keep left-right order!
918
+ elif wvln.get ('type') == 'stop': note.before = ['!trill)!'] + note.before
919
+ glis = nttn.find ('glissando')
920
+ if glis == None: glis = nttn.find ('slide') # treat slide as glissando
921
+ if glis != None:
922
+ lt = '~' if glis.get ('line-type') =='wavy' else '-'
923
+ if glis.get ('type') == 'start': note.before = ['!%s(!' % lt] + note.before # keep left-right order!
924
+ elif glis.get ('type') == 'stop': note.before = ['!%s)!' % lt] + note.before
925
+
926
+ def tabnote (s, alt, ptc, oct, v, ntrec):
927
+ p = s.step_map [ptc] + int (alt or '0') # p in -2 .. 13
928
+ if p > 11: oct += 1 # octave correction
929
+ if p < 0: oct -= 1
930
+ p = p % 12 # remap p into 0..11
931
+ snaar_nw, fret_nw = ntrec.tab # the computed/annotated allocation of nt
932
+ for i in range (4): # support same note on 4 strings
933
+ na = s.note_alts [i % 3] [p] # get alternative representation of same note
934
+ o = oct
935
+ if na in ['^B', '^^B']: o -= 1 # because in adjacent octave
936
+ if na in ['_C', '__C']: o += 1
937
+ if '/' in na or i == 3: o = 9 # emergency notation for 4th string case
938
+ nt = addoct (na, o)
939
+ snaar, fret = s.tabmap.get ((v, nt), ('', '')) # the current allocation of nt
940
+ if not snaar: break # note not yet allocated
941
+ if snaar_nw == snaar: return nt # use present allocation
942
+ if i == 3: # new allocaion needed but none is free
943
+ fmt = 'rejected: voice %d note %3s string %s fret %2s remains: string %s fret %s'
944
+ info (fmt % (v, nt, snaar_nw, fret_nw, snaar, fret), 1)
945
+ ntrec.tab = (snaar, fret)
946
+ s.tabmap [v, nt] = ntrec.tab # for tablature map (voice, note) -> (string, fret)
947
+ return nt # ABC code always in key C (with midi pitch alterations)
948
+
949
+ def ntAbc (s, ptc, oct, note, v, ntrec, isTab): # pitch, octave -> abc notation
950
+ acc2alt = {'double-flat':-2,'flat-flat':-2,'flat':-1,'natural':0,'sharp':1,'sharp-sharp':2,'double-sharp':2}
951
+ oct += s.clefOct.get (s.curStf [v], 0) # minus clef-octave-change value
952
+ acc = note.findtext ('accidental') # should be the notated accidental
953
+ alt = note.findtext ('pitch/alter') # pitch alteration (midi)
954
+ if ntrec.tab: return s.tabnote (alt, ptc, oct, v, ntrec) # implies s.tstep is true (options.t was given)
955
+ elif isTab and s.tstep:
956
+ nt = ['__','_','','^','^^'][int (alt or '0') + 2] + addoct (ptc, oct)
957
+ info ('no string notation found for note %s in voice %d' % (nt, v), 1)
958
+ p = addoct (ptc, oct)
959
+ if alt == None and s.msralts.get (ptc, 0): alt = 0 # no alt but key implies alt -> natural!!
960
+ if alt == None and (p, v) in s.curalts: alt = 0 # no alt but previous note had one -> natural!!
961
+ if acc == None and alt == None: return p # no acc, no alt
962
+ elif acc != None:
963
+ alt = acc2alt [acc] # acc takes precedence over the pitch here!
964
+ else: # now see if we really must add an accidental
965
+ alt = int (float (alt))
966
+ if (p, v) in s.curalts: # the note in this voice has been altered before
967
+ if alt == s.curalts [(p, v)]: return p # alteration still the same
968
+ elif alt == s.msralts.get (ptc, 0): return p # alteration implied by the key
969
+ tieElms = note.findall ('tie') + note.findall ('notations/tied') # in xml we have separate notated ties and playback ties
970
+ if 'stop' in [e.get ('type') for e in tieElms]: return p # don't alter tied notes
971
+ info ('accidental %d added in part %d, measure %d, voice %d note %s' % (alt, s.msr.ixp+1, s.msr.ixm+1, v+1, p))
972
+ s.curalts [(p, v)] = alt
973
+ p = ['__','_','=','^','^^'][alt+2] + p # and finally ... prepend the accidental
974
+ return p
975
+
976
+ def doNote (s, n): # parse a musicXML note tag
977
+ note = Note ()
978
+ v = int (n.findtext ('voice', '1'))
979
+ if s.isSib: v += 100 * int (n.findtext ('staff', '1')) # repair bug in Sibelius
980
+ chord = n.find ('chord') != None
981
+ p = n.findtext ('pitch/step') or n.findtext ('unpitched/display-step')
982
+ o = n.findtext ('pitch/octave') or n.findtext ('unpitched/display-octave')
983
+ r = n.find ('rest')
984
+ numer = n.findtext ('time-modification/actual-notes')
985
+ if numer:
986
+ denom = n.findtext ('time-modification/normal-notes')
987
+ note.fact = (int (numer), int (denom))
988
+ note.tup = [x.get ('type') for x in n.findall ('notations/tuplet')]
989
+ dur = n.findtext ('duration')
990
+ grc = n.find ('grace')
991
+ note.grace = grc != None
992
+ note.before, note.after = [], '' # strings with ABC stuff that goes before or after a note/chord
993
+ if note.grace and not s.ingrace: # open a grace sequence
994
+ s.ingrace = 1
995
+ note.before = ['{']
996
+ if grc.get ('slash') == 'yes': note.before += ['/'] # acciaccatura
997
+ stopgrace = not note.grace and s.ingrace
998
+ if stopgrace: # close the grace sequence
999
+ s.ingrace = 0
1000
+ s.msc.lastnote.after += '}' # close grace on lastenote.after
1001
+ if dur == None or note.grace: dur = 0
1002
+ if r == None and n.get ('print-object') == 'no':
1003
+ if chord: return
1004
+ r = 1 # turn invisible notes (that advance the time) into invisible rests
1005
+ note.dur = int (dur)
1006
+ if r == None and (not p or not o): # not a rest and no pitch
1007
+ s.msc.cnt.inc ('nopt', v) # count unpitched notes
1008
+ o, p = 5,'E' # make it an E5 ??
1009
+ isTab = s.curClef and s.curClef.get (s.curStf [v], '').startswith ('tab')
1010
+ nttn = n.find ('notations') # add ornaments
1011
+ if nttn != None: s.doNotations (note, nttn, isTab)
1012
+ e = n.find ('stem') if r == None else None # no !stemless! before rest
1013
+ if e != None and e.text == 'none' and (not isTab or v in s.hasStems or s.tstep):
1014
+ note.before += ['s']; abcOut.stemless = 1;
1015
+ e = n.find ('accidental')
1016
+ if e != None and e.get ('parentheses') == 'yes': note.ntdec += '!courtesy!'
1017
+ if r != None: noot = 'x' if n.get ('print-object') == 'no' or isTab else 'z'
1018
+ else: noot = s.ntAbc (p, int (o), n, v, note, isTab)
1019
+ if n.find ('unpitched') != None:
1020
+ clef = s.curClef [s.curStf [v]] # the current clef for this voice
1021
+ step = staffStep (p, int (o), clef, s.tstep) # (clef independent) step value of note on the staff
1022
+ instr = n.find ('instrument')
1023
+ instId = instr.get ('id') if instr != None else 'dummyId'
1024
+ midi = s.drumInst.get (instId, abcMid (noot))
1025
+ nh = n.findtext ('notehead', '').replace (' ','-') # replace spaces in xml notehead names for percmap
1026
+ if nh == 'x': noot = '^' + noot.replace ('^','').replace ('_','')
1027
+ if nh in ['circle-x','diamond','triangle']: noot = '_' + noot.replace ('^','').replace ('_','')
1028
+ if nh and n.find ('notehead').get ('filled','') == 'yes': nh += '+'
1029
+ if nh and n.find ('notehead').get ('filled','') == 'no': nh += '-'
1030
+ s.drumNotes [(v, noot)] = (step, midi, nh) # keep data for percussion map
1031
+ tieElms = n.findall ('tie') + n.findall ('notations/tied') # in xml we have separate notated ties and playback ties
1032
+ if 'start' in [e.get ('type') for e in tieElms]: # n can have stop and start tie
1033
+ noot = noot + '-'
1034
+ note.beam = sum ([1 for b in n.findall('beam') if b.text in ['continue', 'end']]) + int (note.grace)
1035
+ lyrlast = 0; rsib = re.compile (r'^.*verse')
1036
+ for e in n.findall ('lyric'):
1037
+ lyrnum = int (rsib.sub ('', e.get ('number', '1'))) # also do Sibelius numbers
1038
+ if lyrnum == 0: lyrnum = lyrlast + 1 # and correct Sibelius bugs
1039
+ else: lyrlast = lyrnum
1040
+ note.lyrs [lyrnum] = doSyllable (e)
1041
+ stemdir = n.findtext ('stem')
1042
+ if s.wstems and (stemdir == 'up' or stemdir == 'down'):
1043
+ if stemdir != s.stemDir.get (v, ''):
1044
+ s.stemDir [v] = stemdir
1045
+ s.msc.appendElem (v, '[I:stemdir %s]' % stemdir)
1046
+ if chord: s.msc.addChord (note, noot)
1047
+ else:
1048
+ xmlstaff = int (n.findtext ('staff', '1'))
1049
+ if s.curStf [v] != xmlstaff: # the note should go to another staff
1050
+ dstaff = xmlstaff - s.curStf [v] # relative new staff number
1051
+ s.curStf [v] = xmlstaff # remember the new staff for this voice
1052
+ s.msc.appendElem (v, '[I:staff %+d]' % dstaff) # insert a move before the note
1053
+ s.msc.appendNote (v, note, noot)
1054
+ for slur in n.findall ('notations/slur'): # s.msc.lastnote points to the last real note/chord inserted above
1055
+ s.matchSlur (slur.get ('type'), slur.get ('number'), v, s.msc.lastnote, note.grace, stopgrace) # match slur definitions
1056
+
1057
+ def doAttr (s, e): # parse a musicXML attribute tag
+ teken = {'C1':'alto1','C2':'alto2','C3':'alto','C4':'tenor','F4':'bass','F3':'bass3','G2':'treble','TAB':'tab','percussion':'perc'}
+ dvstxt = e.findtext ('divisions')
+ if dvstxt: s.msr.divs = int (dvstxt)
+ steps = int (e.findtext ('transpose/chromatic', '0')) # for transposing instrument
+ fifths = e.findtext ('key/fifths')
+ first = s.msc.tijd == 0 and s.msr.ixm == 0 # first attributes in first measure
+ if fifths:
+ key, s.msralts = setKey (int (fifths), e.findtext ('key/mode','major'))
+ if first and not steps and abcOut.key == 'none':
+ abcOut.key = key # first measure -> header, if not transposing instrument or percussion part!
+ elif key != abcOut.key or not first:
+ s.msr.attr += '[K:%s]' % key # otherwise -> voice
+ beats = e.findtext ('time/beats')
+ if beats:
+ unit = e.findtext ('time/beat-type')
+ mtr = beats + '/' + unit
+ if first: abcOut.mtr = mtr # first measure -> header
+ else: s.msr.attr += '[M:%s]' % mtr # otherwise -> voice
+ s.msr.mtr = int (beats), int (unit)
+ s.msr.mdur = (s.msr.divs * s.msr.mtr[0] * 4) // s.msr.mtr[1] # duration of measure in xml-divisions
+ for ms in e.findall('measure-style'):
+ n = int (ms.get ('number', '1')) # staff number
+ voices = s.stfMap [n] # all voices of staff n
+ for mr in ms.findall('measure-repeat'):
+ ty = mr.get('type')
+ if ty == 'start': # remember start measure number and text for each staff
+ s.repeat_str [n] = [s.msr.ixm, mr.text]
+ for v in voices: # insert repeat into all voices, value will be overwritten at stop
+ s.msc.insertElem (v, s.repeat_str [n])
+ elif ty == 'stop': # calculate repeat measure count for this staff n
+ start_ix, text_ = s.repeat_str [n]
+ repeat_count = s.msr.ixm - start_ix
+ if text_:
+ mid_str = "%s " % text_
+ repeat_count /= int (text_)
+ else:
+ mid_str = "" # overwrite repeat with final string
+ s.repeat_str [n][0] = '[I:repeat %s%d]' % (mid_str, repeat_count)
+ del s.repeat_str [n] # remove closed repeats
+ toct = e.findtext ('transpose/octave-change', '')
+ if toct: steps += 12 * int (toct) # extra transposition of toct octaves
+ for clef in e.findall ('clef'): # a part can have multiple staves
+ n = int (clef.get ('number', '1')) # local staff number for this clef
+ sgn = clef.findtext ('sign')
+ line = clef.findtext ('line', '') if sgn not in ['percussion','TAB'] else ''
+ cs = teken.get (sgn + line, '')
+ oct = clef.findtext ('clef-octave-change', '') or '0'
+ if oct: cs += {-2:'-15', -1:'-8', 1:'+8', 2:'+15'}.get (int (oct), '')
+ s.clefOct [n] = -int (oct) # xml playback pitch -> abc notation pitch
+ if steps: cs += ' transpose=' + str (steps)
+ stfdtl = e.find ('staff-details')
+ if stfdtl != None and int (stfdtl.get ('number', '1')) == n:
+ lines = stfdtl.findtext ('staff-lines')
+ if lines:
+ lns = '|||' if lines == '3' and sgn == 'TAB' else lines
+ cs += ' stafflines=%s' % lns
+ s.stafflines = int (lines) # remember for tab staves
+ strings = stfdtl.findall ('staff-tuning')
+ if strings:
+ tuning = [st.findtext ('tuning-step') + st.findtext ('tuning-octave') for st in strings]
+ cs += ' strings=%s' % ','.join (tuning)
+ capo = stfdtl.findtext ('capo')
+ if capo: cs += ' capo=%s' % capo
+ s.curClef [n] = cs # keep track of current clef (for percmap)
+ if first: s.clefMap [n] = cs # clef goes to header (where it is mapped to voices)
+ else:
+ voices = s.stfMap[n] # clef change to all voices of staff n
+ for v in voices:
+ if n != s.curStf [v]: # voice is not at its home staff n
+ dstaff = n - s.curStf [v]
+ s.curStf [v] = n # reset current staff at start of measure to home position
+ s.msc.appendElem (v, '[I:staff %+d]' % dstaff)
+ s.msc.appendElem (v, '[K:%s]' % cs)
+
1132
+ def findVoice (s, i, es):
+ stfnum = int (es[i].findtext ('staff',1)) # directions belong to a staff
+ vs = s.stfMap [stfnum] # voices in this staff
+ v1 = vs [0] if vs else 1 # directions to first voice of staff
+ if s.dirtov1: return stfnum, v1, v1 # option --v1
+ for e in es [i+1:]: # or to the voice of the next note
+ if e.tag == 'note':
+ v = int (e.findtext ('voice', '1'))
+ if s.isSib: v += 100 * int (e.findtext ('staff', '1')) # repair bug in Sibelius
+ stf = s.vce2stf [v] # use our own staff allocation
+ return stf, v, v1 # voice of next note, first voice of staff
+ if e.tag == 'backup': break
+ return stfnum, v1, v1 # no note found, fall back to v1
+
1146
+ def doDirection (s, e, i, es): # parse a musicXML direction tag
+ def addDirection (x, vs, tijd, stfnum):
+ if not x: return
+ vs = s.stfMap [stfnum] if '!8v' in x else [vs] # ottava's go to all voices of staff
+ for v in vs:
+ if tijd != None: # insert at time of encounter
+ s.msc.appendElemT (v, x.replace ('(',')').replace ('ped','ped-up'), tijd)
+ else:
+ s.msc.appendElem (v, x)
+ def startStop (dtype, vs, stfnum=1):
+ typmap = {'down':'!8va(!', 'up':'!8vb(!', 'crescendo':'!<(!', 'diminuendo':'!>(!', 'start':'!ped!'}
+ type = t.get ('type', '')
+ k = dtype + t.get ('number', '1') # key to match the closing direction
+ if type in typmap: # opening the direction
+ x = typmap [type]
+ if k in s.dirStk: # closing direction already encountered
+ stype, tijd = s.dirStk [k]; del s.dirStk [k]
+ if stype == 'stop':
+ addDirection (x, vs, tijd, stfnum)
+ else:
+ info ('%s direction %s has no stop in part %d, measure %d, voice %d' % (dtype, stype, s.msr.ixp+1, s.msr.ixm+1, vs+1))
+ s.dirStk [k] = ((type, vs)) # remember voice and type for closing
+ else:
+ s.dirStk [k] = ((type, vs)) # remember voice and type for closing
+ elif type == 'stop':
+ if k in s.dirStk: # matching open direction found
+ type, vs = s.dirStk [k]; del s.dirStk [k] # into the same voice
+ if type == 'stop':
+ info ('%s direction %s has double stop in part %d, measure %d, voice %d' % (dtype, type, s.msr.ixp+1, s.msr.ixm+1, vs+1))
+ x = ''
+ else:
+ x = typmap [type].replace ('(',')').replace ('ped','ped-up')
+ else: # closing direction found before opening
+ s.dirStk [k] = ('stop', s.msc.tijd)
+ x = '' # delay code generation until opening found
+ else: raise ValueError ('wrong direction type')
+ addDirection (x, vs, None, stfnum)
+ tempo, wrdstxt = None, ''
+ plcmnt = e.get ('placement')
+ stf, vs, v1 = s.findVoice (i, es)
+ jmp = '' # for jump sound elements: dacapo, dalsegno and family
+ jmps = [('dacapo','D.C.'),('dalsegno','D.S.'),('tocoda','dacoda'),('fine','fine'),('coda','O'),('segno','S')]
+ t = e.find ('sound') # there are many possible attributes for sound
+ if t != None:
+ minst = t.find ('midi-instrument')
+ if minst != None:
+ prg = t.findtext ('midi-instrument/midi-program')
+ chn = t.findtext ('midi-instrument/midi-channel')
+ vids = [v for v, id in s.vceInst.items () if id == minst.get ('id')]
+ if vids: vs = vids [0] # direction for the identified voice, not the staff
+ parm, inst = ('program', str (int (prg) - 1)) if prg else ('channel', chn)
+ if inst and abcOut.volpan > 0: s.msc.appendElem (vs, '[I:MIDI= %s %s]' % (parm, inst))
+ tempo = t.get ('tempo') # look for tempo attribute
+ if tempo:
+ tempo = '%.0f' % float (tempo) # hope it is a number and insert in voice 1
+ tempo_units = (1,4) # always 1/4 for sound elements!
+ for r, v in jmps:
+ if t.get (r, ''): jmp = v; break
+ dirtypes = e.findall ('direction-type')
+ for dirtyp in dirtypes:
+ units = { 'whole': (1,1), 'half': (1,2), 'quarter': (1,4), 'eighth': (1,8) }
+ metr = dirtyp.find ('metronome')
+ if metr != None:
+ t = metr.findtext ('beat-unit', '')
+ if t in units: tempo_units = units [t]
+ else: tempo_units = units ['quarter']
+ if metr.find ('beat-unit-dot') != None:
+ tempo_units = simplify (tempo_units [0] * 3, tempo_units [1] * 2)
+ tmpro = re.search (r'[.\d]+', metr.findtext ('per-minute')) # look for a number
+ if tmpro: tempo = tmpro.group () # overwrites the value set by the sound element of this direction
+ t = dirtyp.find ('wedge')
+ if t != None: startStop ('wedge', vs)
+ allwrds = dirtyp.findall ('words') # insert text annotations
+ if not allwrds: allwrds = dirtyp.findall ('rehearsal') # treat rehearsal mark as text annotation
+ for wrds in allwrds:
+ if jmp: # ignore the words when a jump sound element is present in this direction
+ s.msc.appendElem (vs, '!%s!' % jmp, 1) # to voice
+ break
+ plc = plcmnt == 'below' and '_' or '^'
+ if float (wrds.get ('default-y', '0')) < 0: plc = '_'
+ wrdstxt += (wrds.text or '').replace ('"','\\"').replace ('\n', '\\n')
+ wrdstxt = wrdstxt.strip ()
+ for key, val in dynamics_map.items ():
+ if dirtyp.find ('dynamics/' + key) != None:
+ s.msc.appendElem (vs, val, 1) # to voice
+ if dirtyp.find ('coda') != None: s.msc.appendElem (vs, 'O', 1)
+ if dirtyp.find ('segno') != None: s.msc.appendElem (vs, 'S', 1)
+ t = dirtyp.find ('octave-shift')
+ if t != None: startStop ('octave-shift', vs, stf) # assume size == 8 for the time being
+ t = dirtyp.find ('pedal')
+ if t != None and s.ped:
+ if not s.pedVce: s.pedVce = vs
+ startStop ('pedal', s.pedVce)
+ if dirtyp.findtext ('other-direction') == 'diatonic fretting': s.diafret = 1
+ if tempo:
+ tempo = '%.0f' % float (tempo) # hope it is a number and insert in voice 1
+ if s.msc.tijd == 0 and s.msr.ixm == 0: # first measure -> header
+ abcOut.tempo = tempo
+ abcOut.tempo_units = tempo_units
+ else:
+ s.msc.appendElem (v1, '[Q:%d/%d=%s]' % (tempo_units [0], tempo_units [1], tempo)) # otherwise -> 1st voice
+ if wrdstxt: s.msc.appendElem (vs, '"%s%s"' % (plc, wrdstxt), 1) # to voice, but after tempo
+
1249
+ def doHarmony (s, e, i, es): # parse a musicXML harmony tag
+ _, vt, _ = s.findVoice (i, es)
+ short = {'major':'', 'minor':'m', 'augmented':'+', 'diminished':'dim', 'dominant':'7', 'half-diminished':'m7b5'}
+ accmap = {'major':'maj', 'dominant':'', 'minor':'m', 'diminished':'dim', 'augmented':'+', 'suspended':'sus'}
+ modmap = {'second':'2', 'fourth':'4', 'seventh':'7', 'sixth':'6', 'ninth':'9', '11th':'11', '13th':'13'}
+ altmap = {'1':'#', '0':'', '-1':'b'}
+ root = e.findtext ('root/root-step','')
+ alt = altmap.get (e.findtext ('root/root-alter'), '')
+ sus = ''
+ kind = e.findtext ('kind', '')
+ if kind in short: kind = short [kind]
+ elif '-' in kind: # xml chord names: <triad name>-<modification>
+ triad, mod = kind.split ('-')
+ kind = accmap.get (triad, '') + modmap.get (mod, '')
+ if kind.startswith ('sus'): kind, sus = '', kind # sus-suffix goes to the end
+ elif kind == 'none': kind = e.find ('kind').get ('text','')
+ degrees = e.findall ('degree')
+ for d in degrees: # chord alterations
+ kind += altmap.get (d.findtext ('degree-alter'),'') + d.findtext ('degree-value','')
+ kind = kind.replace ('79','9').replace ('713','13').replace ('maj6','6')
+ bass = e.findtext ('bass/bass-step','') + altmap.get (e.findtext ('bass/bass-alter'),'')
+ s.msc.appendElem (vt, '"%s%s%s%s%s"' % (root, alt, kind, sus, bass and '/' + bass), 1) # e.g. root C, kind m7b5, bass Eb -> "Cm7b5/Eb"
+
1272
+ def doBarline (s, e): # 0 = no repeat, 1 = begin repeat, 2 = end repeat
+ rep = e.find ('repeat')
+ if rep != None: rep = rep.get ('direction')
+ if s.unfold: # unfold repeat, don't translate barlines
+ return rep and (rep == 'forward' and 1 or 2) or 0
+ loc = e.get ('location', 'right') # right is the default
+ if loc == 'right': # only change style for the right side
+ style = e.findtext ('bar-style')
+ if style == 'light-light': s.msr.rline = '||'
+ elif style == 'light-heavy': s.msr.rline = '|]'
+ if rep != None: # repeat found
+ if rep == 'forward': s.msr.lline = ':'
+ else: s.msr.rline = ':|' # override barline style
+ end = e.find ('ending')
+ if end != None:
+ if end.get ('type') == 'start':
+ n = end.get ('number', '1').replace ('.','').replace (' ','')
+ try: list (map (int, n.split (','))) # should be a list of integers
+ except: n = '"%s"' % n.strip () # illegal musicXML
+ s.msr.lnum = n # assume a start is always at the beginning of a measure
+ elif s.msr.rline == '|': # stop and discontinue the same in ABC ?
+ s.msr.rline = '||' # to stop on a normal barline use || in ABC ?
+ return 0
+
+ def doPrint (s, e): # print element, measure number -> insert a line break
+ if e.get ('new-system') == 'yes' or e.get ('new-page') == 'yes':
+ if not s.nolbrk: return '$' # a line break
+
1300
+ def doPartList (s, e): # translate the start/stop-event-based xml-partlist into proper tree
+ for sp in e.findall ('part-list/score-part'):
+ midi = {}
+ for m in sp.findall ('midi-instrument'):
+ x = [m.findtext (p, s.midDflt [i]) for i,p in enumerate (['midi-channel','midi-program','volume','pan'])]
+ pan = float (x[3])
+ if pan >= -90 and pan <= 90: # would be better to map behind-pannings
+ pan = (float (x[3]) + 90) / 180 * 127 # xml between -90 and +90
+ midi [m.get ('id')] = [int (x[0]), int (x[1]), float (x[2]) * 1.27, pan] # volume 100 -> midi 127
+ up = m.findtext ('midi-unpitched')
+ if up: s.drumInst [m.get ('id')] = int (up) - 1 # store midi-pitch for channel 10 notes
+ s.instMid.append (midi)
+ ps = e.find ('part-list') # partlist = [groupelem]
+ xs = getPartlist (ps) # groupelem = partname | grouplist
+ partlist, _ = parseParts (xs, {}, []) # grouplist = [groupelem, ..., groupdata]
+ return partlist # groupdata = [group-symbol, group-barline, group-name, group-abbrev]
+
1317
+ def mkTitle (s, e):
+ def filterCredits (y): # y == filter level, higher filters less
+ cs = []
+ for x in credits: # skip redundant credit lines
+ if y < 6 and (x in title or x in mvttl): continue # sure skip
+ if y < 5 and (x in composer or x in lyricist): continue # almost sure skip
+ if y < 4 and ((title and title in x) or (mvttl and mvttl in x)): continue # may skip too much
+ if y < 3 and ([1 for c in composer if c in x] or [1 for c in lyricist if c in x]): continue # skips too much
+ if y < 2 and re.match (r'^[\d\W]*$', x): continue # line only contains numbers and punctuation
+ cs.append (x)
+ if y == 0 and (title + mvttl): cs = '' # default: only credit when no title set
+ return cs
+ title = e.findtext ('work/work-title', '').strip ()
+ mvttl = e.findtext ('movement-title', '').strip ()
+ composer, lyricist, credits = [], [], []
+ for creator in e.findall ('identification/creator'):
+ if creator.text:
+ if creator.get ('type') == 'composer':
+ composer += [line.strip () for line in creator.text.split ('\n')]
+ elif creator.get ('type') in ('lyricist', 'transcriber'):
+ lyricist += [line.strip () for line in creator.text.split ('\n')]
+ for rights in e.findall ('identification/rights'):
+ if rights.text:
+ lyricist += [line.strip () for line in rights.text.split ('\n')]
+ for credit in e.findall('credit'):
+ cs = ''.join (e.text or '' for e in credit.findall('credit-words'))
+ credits += [re.sub (r'\s*[\r\n]\s*', ' ', cs)]
+ credits = filterCredits (s.ctf)
+ if title: title = 'T:%s\n' % title.replace ('\n', '\nT:')
+ if mvttl: title += 'T:%s\n' % mvttl.replace ('\n', '\nT:')
+ if credits: title += '\n'.join (['T:%s' % c for c in credits]) + '\n'
+ if composer: title += '\n'.join (['C:%s' % c for c in composer]) + '\n'
+ if lyricist: title += '\n'.join (['Z:%s' % c for c in lyricist]) + '\n'
+ if title: abcOut.title = title[:-1]
+ s.isSib = 'Sibelius' in (e.findtext ('identification/encoding/software') or '')
+ if s.isSib: info ('Sibelius MusicXML is unreliable')
+
1354
+ def doDefaults (s, e):
+ if not s.doPageFmt: return # return if -pf option absent
+ d = e.find ('defaults')
+ if d == None: return
+ mils = d.findtext ('scaling/millimeters') # mils == staff height (mm)
+ tenths = d.findtext ('scaling/tenths') # staff height in tenths
+ if not mils or not tenths: return
+ xmlScale = float (mils) / float (tenths) / 10 # tenths -> mm
+ space = 10 * xmlScale # space between staff lines == 10 tenths
+ abcScale = space / 0.2117 # 0.2117 cm = 6pt = space between staff lines for scale = 1.0 in abcm2ps
+ abcOut.pageFmt ['scale'] = abcScale
+ eks = 2 * ['page-layout/'] + 4 * ['page-layout/page-margins/']
+ eks = [a+b for a,b in zip (eks, 'page-height,page-width,left-margin,right-margin,top-margin,bottom-margin'.split (','))]
+ for i in range (6):
+ v = d.findtext (eks [i])
+ k = abcOut.pagekeys [i+1] # pagekeys [0] == scale already done, skip it
+ if not abcOut.pageFmt [k] and v:
+ try: abcOut.pageFmt [k] = float (v) * xmlScale # -> cm
+ except: info ('illegal value %s for xml element %s' % (v, eks [i])); continue # just skip illegal values
+
1374
+ def locStaffMap (s, part, maten): # map voice to staff with majority voting
+ vmap = {} # {voice -> {staff -> n}} count occurrences of voice in staff
+ s.vceInst = {} # {voice -> instrument id} for this part
+ s.msc.vnums = {} # voice id's
+ s.hasStems = {} # XML voice nums with at least one note with a stem (for tab key)
+ s.stfMap, s.clefMap = {}, {} # staff -> [voices], staff -> clef
+ ns = part.findall ('measure/note')
+ for n in ns: # count staff allocations for all notes
+ v = int (n.findtext ('voice', '1'))
+ if s.isSib: v += 100 * int (n.findtext ('staff', '1')) # repair bug in Sibelius
+ s.msc.vnums [v] = 1 # collect all used voice id's in this part
+ sn = int (n.findtext ('staff', '1'))
+ s.stfMap [sn] = []
+ if v not in vmap:
+ vmap [v] = {sn:1}
+ else:
+ d = vmap[v] # counter for voice v
+ d[sn] = d.get (sn, 0) + 1 # ++ number of allocations for staff sn
+ x = n.find ('instrument')
+ if x != None: s.vceInst [v] = x.get ('id')
+ x, noRest = n.findtext ('stem'), n.find ('rest') == None
+ if noRest and (not x or x != 'none'): s.hasStems [v] = 1 # XML voice v has at least one stem
+ vks = list (vmap.keys ())
+ if s.jscript or s.isSib: vks.sort ()
+ for v in vks: # choose staff with most allocations for each voice
+ xs = [(n, sn) for sn, n in vmap[v].items ()]
+ xs.sort ()
+ stf = xs[-1][1] # the winner: staff with most notes of voice v
+ s.stfMap [stf].append (v)
+ s.vce2stf [v] = stf # reverse map
+ s.curStf [v] = stf # current staff of XML voice v
+
1406
+ def addStaffMap (s, vvmap): # vvmap: xml voice number -> global abc voice number
+ part = [] # default: brace on staffs of one part
+ for stf, voices in sorted (s.stfMap.items ()): # s.stfMap has xml staff and voice numbers
+ locmap = [vvmap [iv] for iv in voices if iv in vvmap]
+ nostem = [(iv not in s.hasStems) for iv in voices if iv in vvmap] # same order as locmap
+ if locmap: # abc voice number of staff stf
+ part.append (locmap)
+ clef = s.clefMap.get (stf, 'treble') # {xml staff number -> clef}
+ for i, iv in enumerate (locmap):
+ clef_attr = ''
+ if clef.startswith ('tab'):
+ if nostem [i] and 'nostems' not in clef: clef_attr = ' nostems'
+ if s.diafret and 'diafret' not in clef: clef_attr += ' diafret' # for all voices in the part
+ abcOut.clefs [iv] = clef + clef_attr # add nostems when all notes of voice had no stem
+ s.gStfMap.append (part)
+
1422
+ def addMidiMap (s, ip, vvmap): # map abc voices to midi settings
+ instr = s.instMid [ip] # get the midi settings for this part
+ if instr.values (): defInstr = list (instr.values ())[0] # default settings = first instrument
+ else: defInstr = s.midDflt # no instruments defined
+ xs = []
+ for v, vabc in vvmap.items (): # xml voice num, abc voice num
+ ks = sorted (s.drumNotes.items ())
+ ds = [(nt, step, midi, head) for (vd, nt), (step, midi, head) in ks if v == vd] # map perc notes
+ id = s.vceInst.get (v, '') # get the instrument-id for part with multiple instruments
+ if id in instr: # id is defined as midi-instrument in part-list
+ xs.append ((vabc, instr [id] + ds)) # get midi settings for id
+ else: xs.append ((vabc, defInstr + ds)) # only one instrument for this part
+ xs.sort () # put abc voices in order
+ s.midiMap.extend ([midi for v, midi in xs])
+ snaarmap = ['E','G','B','d', 'f', 'a', "c'", "e'"]
+ diamap = '0,1-,1,1+,2,3,3,4,4,5,6,6+,7,8-,8,8+,9,10,10,11,11,12,13,13+,14'.split (',')
+ for k in sorted (s.tabmap.keys ()): # add %%map's for all tab voices
+ v, noot = k
+ snaar, fret = s.tabmap [k]
+ if s.diafret: fret = diamap [int (fret)]
+ vabc = vvmap [v]
+ snaar = s.stafflines - int (snaar)
+ xs = s.tabVceMap.get (vabc, [])
+ xs.append ('%%%%map tab%d %s print=%s heads=kop%s\n' % (vabc, noot, snaarmap [snaar], fret))
+ s.tabVceMap [vabc] = xs
+ s.koppen [fret] = 1 # collect noteheads for SVG defs
+
1449
+ def parse (s, fobj):
+ vvmapAll = {} # collect xml->abc voice maps (vvmap) of all parts
+ e = E.parse (fobj)
+ s.mkTitle (e)
+ s.doDefaults (e)
+ partlist = s.doPartList (e)
+ parts = e.findall ('part')
+ for ip, p in enumerate (parts):
+ maten = p.findall ('measure')
+ s.locStaffMap (p, maten) # {voice -> staff} for this part
+ s.drumNotes = {} # (xml voice, abc note) -> (midi note, note head)
+ s.clefOct = {} # xml staff number -> current clef-octave-change
+ s.curClef = {} # xml staff number -> current abc clef
+ s.stemDir = {} # xml voice number -> current stem direction
+ s.tabmap = {} # (xml voice, abc note) -> (string, fret)
+ s.diafret = 0 # use diatonic fretting
+ s.stafflines = 5
+ s.msc.initVoices (newPart = 1) # create all voices
+ aantalHerhaald = 0 # keep track of number of repetitions
+ herhaalMaat = 0 # target measure of the repetition
+ divisions = [] # current value of <divisions> for each measure
+ s.msr = Measure (ip) # various measure data
+ while s.msr.ixm < len (maten):
+ maat = maten [s.msr.ixm]
+ herhaal, lbrk = 0, ''
+ s.msr.reset ()
+ s.curalts = {} # passing accidentals are reset each measure
+ es = list (maat)
+ for i, e in enumerate (es):
+ if e.tag == 'note': s.doNote (e)
+ elif e.tag == 'attributes': s.doAttr (e)
+ elif e.tag == 'direction': s.doDirection (e, i, es)
+ elif e.tag == 'sound': s.doDirection (maat, i, es) # sound element directly in measure!
+ elif e.tag == 'harmony': s.doHarmony (e, i, es)
+ elif e.tag == 'barline': herhaal = s.doBarline (e)
+ elif e.tag == 'backup':
+ dt = int (e.findtext ('duration'))
+ if chkbug (dt, s.msr): s.msc.incTime (-dt)
+ elif e.tag == 'forward':
+ dt = int (e.findtext ('duration'))
+ if chkbug (dt, s.msr): s.msc.incTime (dt)
+ elif e.tag == 'print': lbrk = s.doPrint (e)
+ s.msc.addBar (lbrk, s.msr)
+ divisions.append (s.msr.divs)
+ if herhaal == 1: # begin repeat: remember the jump target
+ herhaalMaat = s.msr.ixm
+ s.msr.ixm += 1
+ elif herhaal == 2: # end repeat
+ if aantalHerhaald < 1: # jump back once
+ s.msr.ixm = herhaalMaat
+ aantalHerhaald += 1
+ else:
+ aantalHerhaald = 0 # reset
+ s.msr.ixm += 1 # just continue
+ else: s.msr.ixm += 1 # on to the next measure
+ for rv in s.repeat_str.values (): # close hanging measure-repeats without stop
+ rv [0] = '[I:repeat %s %d]' % (rv [1], 1)
+ vvmap = s.msc.outVoices (divisions, ip, s.isSib)
+ s.addStaffMap (vvmap) # update global staff map
+ s.addMidiMap (ip, vvmap)
+ vvmapAll.update (vvmap)
+ if vvmapAll: # skip output if no part has any notes
+ abcOut.mkHeader (s.gStfMap, partlist, s.midiMap, s.tabVceMap, s.koppen)
+ abcOut.writeall ()
+ else: info ('nothing written, %s has no notes ...' % abcOut.fnmext)
+
1515
+ #----------------
+ # Main Program
+ #----------------
+ if __name__ == '__main__':
+ from optparse import OptionParser
+ from glob import glob
+ from zipfile import ZipFile
+ ustr = '%prog [-h] [-u] [-m] [-c C] [-d D] [-n CPL] [-b BPL] [-o DIR] [-v V]\n'
+ ustr += '[-x] [-p PFMT] [-t] [-s] [-i] [--v1] [--noped] [--stems] <file1> [<file2> ...]'
+ parser = OptionParser (usage=ustr, version=str (VERSION))
+ parser.add_option ("-u", action="store_true", help="unfold simple repeats")
+ parser.add_option ("-m", action="store", help="0 -> no %%MIDI, 1 -> minimal %%MIDI, 2 -> all %%MIDI", default=0)
+ parser.add_option ("-c", action="store", type="int", help="set credit text filter to C", default=0, metavar='C')
+ parser.add_option ("-d", action="store", type="int", help="set L:1/D", default=0, metavar='D')
+ parser.add_option ("-n", action="store", type="int", help="CPL: max number of characters per line (default 100)", default=0, metavar='CPL')
+ parser.add_option ("-b", action="store", type="int", help="BPL: max number of bars per line", default=0, metavar='BPL')
+ parser.add_option ("-o", action="store", help="store abc files in DIR", default='', metavar='DIR')
+ parser.add_option ("-v", action="store", type="int", help="set volta typesetting behaviour to V", default=0, metavar='V')
+ parser.add_option ("-x", action="store_true", help="output no line breaks")
+ parser.add_option ("-p", action="store", help="pageformat PFMT (cm) = scale, pageheight, pagewidth, leftmargin, rightmargin, topmargin, botmargin", default='', metavar='PFMT')
+ parser.add_option ("-j", action="store_true", help="switch for compatibility with javascript version")
+ parser.add_option ("-t", action="store_true", help="translate perc- and tab-staff to ABC code with %%map, %%voicemap")
+ parser.add_option ("-s", action="store_true", help="shift note heads 3 units left in a tab staff")
+ parser.add_option ("--v1", action="store_true", help="start-stop directions always to first voice of staff")
+ parser.add_option ("--noped", action="store_false", help="skip all pedal directions", dest='ped', default=True)
+ parser.add_option ("--stems", action="store_true", help="translate stem directions", dest='stm', default=False)
+ parser.add_option ("-i", action="store_true", help="read xml file from standard input")
+ options, args = parser.parse_args ()
+ if options.n < 0: parser.error ('only values >= 0')
+ if options.b < 0: parser.error ('only values >= 0')
+ if options.d and options.d not in [2**n for n in range (10)]:
+ parser.error ('D should be one of %s' % ','.join ([str (2**n) for n in range (10)]))
+ options.p = options.p and options.p.split (',') or [] # ==> [] | [string]
+ if len (args) == 0 and not options.i: parser.error ('no input file given')
+ pad = options.o
+ if pad:
+ if not os.path.exists (pad): os.mkdir (pad)
+ if not os.path.isdir (pad): parser.error ('%s is not a directory' % pad)
+ fnmext_list = []
+ for i in args: fnmext_list += glob (i)
+ if options.i: fnmext_list = ['stdin.xml']
+ if not fnmext_list: parser.error ('none of the input files exist')
+ for X, fnmext in enumerate (fnmext_list):
+ fnm, ext = os.path.splitext (fnmext)
+ if ext.lower () not in ('.xml', '.mxl', '.musicxml'):
+ info ('skipped input file %s, it should have extension .xml, .mxl or .musicxml' % fnmext)
+ continue
+ if os.path.isdir (fnmext):
+ info ('skipped directory %s. Only files are accepted' % fnmext)
+ continue
+ if fnmext == 'stdin.xml':
+ fobj = sys.stdin
+ elif ext.lower () == '.mxl': # extract .xml file from .mxl file
+ z = ZipFile (fnmext)
+ for n in z.namelist (): # assume there is always an xml file in a mxl archive !!
+ if (n[:4] != 'META') and (n[-4:].lower () == '.xml'):
+ fobj = z.open (n)
+ break # assume only one MusicXML file per archive
+ else:
+ fobj = open (fnmext, 'rb') # open regular xml file
+
+ abcOut = ABCoutput (fnm + '.abc', pad, X, options) # create global ABC output object
+ psr = Parser (options) # xml parser
+ try:
+ psr.parse (fobj) # parse file fobj and write abc to <fnm>.abc
+ except:
+ etype, value, traceback = sys.exc_info () # works in python 2 & 3
+ info ('** %s occurred: %s in %s' % (etype, value, fnmext), 0)
semantic_search/README.md ADDED
@@ -0,0 +1,62 @@
+ # Semantic Search Codebase
+
+ ## Overview
+ CLaMP 2 is a state-of-the-art multimodal music information retrieval system designed to work with 101 languages. This codebase includes scripts for evaluating model performance, performing semantic searches, and calculating similarity metrics based on CLaMP2-extracted **normalized** feature vectors from music or text data. Below is a description of the scripts contained in the `semantic_search/` folder.
+
+ ## Repository Structure
+ The `semantic_search/` folder contains the following scripts:
+
+ ### 1. `clamp2_score.py`
+ This script calculates the cosine similarity between the average feature vectors extracted from two sets of `.npy` files, serving as a measure of similarity between the reference and test datasets.
+
+ It can be used to validate the semantic similarity between generated music and ground truth, providing an objective metric. Through empirical observation, we found that this metric aligns well with subjective judgments made by individuals with professional music expertise.
+
+ **Usage:**
+ ```bash
+ python clamp2_score.py <reference_folder> <test_folder>
+ ```
+ - `reference_folder`: Path to the folder containing reference `.npy` files.
+ - `test_folder`: Path to the folder containing test `.npy` files.
+
+ **Functionality:**
+ - Loads all `.npy` files from the specified folders.
+ - Computes the average feature vector for each folder.
+ - Calculates the cosine similarity between the two averaged vectors.
+ - Outputs the similarity score rounded to four decimal places.
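+
+ For example, assuming two hypothetical folders `features/reference` and `features/generated`, each holding normalized features produced by `extract_clamp2.py`:
+
+ ```bash
+ python clamp2_score.py features/reference features/generated
+ ```
+
+ The script prints a single score in [-1, 1]; higher values indicate that the two sets are, on average, semantically closer.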
+
+ ### 2. `semantic_search.py`
+ This script performs semantic search by calculating the cosine similarity between a query feature and a set of features stored in `.npy` files.
+
+ **Usage:**
+ ```bash
+ python semantic_search.py <query_file> <features_folder> [--top_k TOP_K]
+ ```
+ - `query_file`: Path to the query feature file (e.g., `ballad.npy`).
+ - `features_folder`: Path to the folder containing feature files for comparison.
+ - `--top_k`: (Optional) Number of top similar items to display. Defaults to 10 if not specified.
+
+ **Functionality:**
+ - Loads a query feature from the specified file.
+ - Loads feature vectors from the given folder.
+ - Computes cosine similarity between the query feature and each loaded feature vector.
+ - Displays the top K most similar features along with their similarity scores.
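+
+ For example, to list the five pieces whose features are closest to the query file `ballad.npy` (the folder name below is illustrative):
+
+ ```bash
+ python semantic_search.py ballad.npy features/music --top_k 5
+ ```
+
+ Each output line pairs a file name (without the `.npy` extension) with its cosine similarity to the query.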
43
+
44
+ ### 3. `semantic_search_metrics.py`
45
+ This script calculates evaluation metrics for semantic search by comparing query features to reference features.
46
+
47
+ **Usage:**
48
+ ```bash
49
+ python semantic_search_metrics.py <query_folder> <reference_folder>
50
+ ```
51
+ - `query_folder`: Path to the folder containing query features (in `.npy` format).
52
+ - `reference_folder`: Path to the folder containing reference features (in `.npy` format).
53
+
54
+ **Functionality:**
55
+ - Loads query features from the specified folder.
56
+ - Loads reference features from the given folder.
57
+ - Computes the following metrics based on cosine similarity:
58
+ - **Mean Reciprocal Rank (MRR)**
59
+ - **Hit@1**
60
+ - **Hit@10**
61
+ - **Hit@100**
62
+ - Outputs the calculated metrics to the console.
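+
+ As a worked example: queries and references are matched by file name, so for each query the correct reference is the one sharing its name. If that reference ranks 3rd among all references by cosine similarity, the query contributes a reciprocal rank of 1/3 ≈ 0.333 to the MRR and counts toward Hit@10 and Hit@100, but not Hit@1. The reported numbers are these contributions averaged over all queries.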
semantic_search/clamp2_score.py ADDED
@@ -0,0 +1,57 @@
+ import os
+ import numpy as np
+ import argparse
+
+ def load_npy_files(folder_path):
+ """
+ Load all .npy files from a specified folder and return a list of numpy arrays.
+ """
+ npy_list = []
+ for file_name in os.listdir(folder_path):
+ if file_name.endswith('.npy'):
+ file_path = os.path.join(folder_path, file_name)
+ np_array = np.load(file_path)[0]
+ npy_list.append(np_array)
+ return npy_list
+
+ def average_npy(npy_list):
+ """
+ Compute the average of a list of numpy arrays.
+ """
+ return np.mean(npy_list, axis=0)
+
+ def cosine_similarity(vec1, vec2):
+ """
+ Compute cosine similarity between two numpy arrays.
+ """
+ dot_product = np.dot(vec1, vec2)
+
+ norm_vec1 = np.linalg.norm(vec1)
+ norm_vec2 = np.linalg.norm(vec2)
+
+ cosine_sim = dot_product / (norm_vec1 * norm_vec2)
+
+ return cosine_sim
+
+ if __name__ == '__main__':
+ # Set up argument parsing for input folders
+ parser = argparse.ArgumentParser(description="Calculate cosine similarity between average feature vectors.")
+ parser.add_argument('reference', type=str, help='Path to the reference folder containing .npy files.')
+ parser.add_argument('test', type=str, help='Path to the test folder containing .npy files.')
+ args = parser.parse_args()
+
+ reference = args.reference
+ test = args.test
+ # Load .npy files
+ ref_npy = load_npy_files(reference)
+ test_npy = load_npy_files(test)
+
+ # Compute the average of each list of numpy arrays
+ avg_ref = average_npy(ref_npy)
+ avg_test = average_npy(test_npy)
+
+ # Compute the cosine similarity between the two averaged numpy arrays
+ similarity = cosine_similarity(avg_ref, avg_test)
+
+ # Output the cosine similarity rounded to four decimal places
+ print(f"Cosine similarity between '{reference}' and '{test}': {similarity:.4f}")
semantic_search/semantic_search.py ADDED
@@ -0,0 +1,51 @@
+ import os
+ import torch
+ import numpy as np
+ import argparse
+
+ def get_info(folder_path):
+ """
+ Load all .npy files from a specified folder and return a dictionary of features.
+ """
+ files = sorted(os.listdir(folder_path))
+ features = {}
+
+ for file in files:
+ if file.endswith(".npy"):
+ key = file.split(".")[0]
+ features[key] = np.load(os.path.join(folder_path, file))[0]
+
+ return features
+
+ def main(query_file, features_folder, top_k=10):
+ # Load query feature from the specified file
+ query_feature = np.load(query_file)[0] # Load directly from the query file
+ query_tensor = torch.tensor(query_feature).unsqueeze(dim=0)
+
+ # Load key features from the specified folder
+ key_features = get_info(features_folder)
+
+ # Prepare tensor for key features
+ key_feats_tensor = torch.tensor(np.array([key_features[k] for k in key_features.keys()]))
+
+ # Calculate cosine similarity
+ similarities = torch.cosine_similarity(query_tensor, key_feats_tensor)
+ ranked_indices = torch.argsort(similarities, descending=True)
+
+ # Get the keys for the features
+ keys = list(key_features.keys())
+
+ # Never ask for more items than there are features
+ top_k = min(top_k, len(keys))
+
+ print(f"Top {top_k} similar items:")
+ for i in range(top_k):
+ print(keys[ranked_indices[i]], similarities[ranked_indices[i]].item())
+
+ if __name__ == '__main__':
+ # Set up argument parsing for input paths
+ parser = argparse.ArgumentParser(description="Find top similar features based on cosine similarity.")
+ parser.add_argument('query_file', type=str, help='Path to the query feature file (e.g., ballad.npy).')
+ parser.add_argument('features_folder', type=str, help='Path to the folder containing feature files for comparison.')
+ parser.add_argument('--top_k', type=int, default=10, help='Number of top similar items to display (default: 10).')
+ args = parser.parse_args()
+
+ # Execute the main functionality
+ main(args.query_file, args.features_folder, args.top_k)
semantic_search/semantic_search_metrics.py ADDED
@@ -0,0 +1,78 @@
+ import os
+ import torch
+ import numpy as np
+ import argparse
+
+ def get_features(path):
+ """
+ Load and return feature data from .npy files in the given directory.
+ Each feature is stored in a dictionary with the filename (without extension) as the key.
+ """
+ files = sorted(os.listdir(path))
+ features = {}
+
+ for file in files:
+ if file.endswith(".npy"):
+ key = file.split(".")[0]
+ features[key] = np.load(os.path.join(path, file))[0]
+
+ return features
+
+ def calculate_metrics(query_features, reference_features):
+ """
+ Calculate MRR, Hit@1, Hit@10, and Hit@100 metrics based on the similarity
+ between query and reference features.
+ """
+ # Sort the shared keys so that the ranking order is deterministic
+ common_keys = sorted(set(query_features.keys()) & set(reference_features.keys()))
+ mrr, hit_1, hit_10, hit_100 = 0, 0, 0, 0
+
+ # Build the reference matrix once; row i corresponds to common_keys[i]
+ ref_feats = torch.tensor(np.array([reference_features[k] for k in common_keys]))
+
+ for idx, key in enumerate(common_keys):
+ # Convert query feature to tensor and add batch dimension
+ query_feat = torch.tensor(query_features[key]).unsqueeze(dim=0)
+
+ # Compute cosine similarity between the query and all reference features
+ similarities = torch.cosine_similarity(query_feat, ref_feats)
+
+ # Create a list of (similarity, index) pairs
+ indexed_sims = list(enumerate(similarities.tolist()))
+
+ # Sort by similarity in descending order; ties are broken in favor of the correct item (x[0] == idx)
+ sorted_indices = sorted(indexed_sims, key=lambda x: (x[1], x[0] == idx), reverse=True)
+
+ # Extract the sorted rank list
+ ranks = [x[0] for x in sorted_indices]
+
+ # Accumulate the reciprocal rank of the correct reference
+ mrr += 1 / (ranks.index(idx) + 1)
+
+ # Accumulate Hit@1, Hit@10, Hit@100
+ if idx in ranks[:100]:
+ hit_100 += 1
+ if idx in ranks[:10]:
+ hit_10 += 1
+ if idx in ranks[:1]:
+ hit_1 += 1
+
+ # Compute the final metrics
+ total_keys = len(common_keys)
+ print(f"MRR: {round(mrr / total_keys, 4)}")
+ print(f"Hit@1: {round(hit_1 / total_keys, 4)}")
+ print(f"Hit@10: {round(hit_10 / total_keys, 4)}")
+ print(f"Hit@100: {round(hit_100 / total_keys, 4)}")
+
+ if __name__ == '__main__':
+ # Set up argument parsing for input directories
+ parser = argparse.ArgumentParser(description="Calculate similarity metrics between query and reference features.")
+ parser.add_argument('query_folder', type=str, help='Path to the folder containing query features (.npy files).')
+ parser.add_argument('reference_folder', type=str, help='Path to the folder containing reference features (.npy files).')
+ args = parser.parse_args()
+
+ # Load features from the specified folders
+ query_features = get_features(args.query_folder)
+ reference_features = get_features(args.reference_folder)
+
+ # Calculate and print the metrics
+ calculate_metrics(query_features, reference_features)