Upload 32 files
- .gitattributes +1 -0
- LICENSE +21 -0
- benchmarks.z01 +3 -0
- benchmarks.zip +3 -0
- code/README.md +90 -0
- code/config.py +67 -0
- code/extract_clamp2.py +192 -0
- code/extract_m3.py +162 -0
- code/logs_clamp2_h_size_768_lr_5e-05_batch_128_scale_1_t_length_128_t_model_FacebookAI_xlm-roberta-base_t_dropout_True_m3_True.txt +500 -0
- code/logs_m3_p_size_64_p_length_512_t_layers_3_p_layers_12_h_size_768_lr_0.0001_batch_16_mask_0.45.txt +500 -0
- code/train_clamp2.py +356 -0
- code/train_m3.py +321 -0
- code/utils.py +483 -0
- environment.yml +204 -0
- music_classification/README.md +46 -0
- music_classification/config.py +26 -0
- music_classification/inference_cls.py +71 -0
- music_classification/train_cls.py +293 -0
- music_classification/utils.py +22 -0
- process_data/README.md +307 -0
- process_data/batch_abc2xml.py +58 -0
- process_data/batch_interleaved_abc.py +117 -0
- process_data/batch_midi2mtf.py +80 -0
- process_data/batch_mtf2midi.py +97 -0
- process_data/batch_xml2abc.py +57 -0
- process_data/gpt4_summarize.py +250 -0
- process_data/utils/abc2xml.py +0 -0
- process_data/utils/pyparsing.py +0 -0
- process_data/utils/xml2abc.py +1582 -0
- semantic_search/README.md +62 -0
- semantic_search/clamp2_score.py +57 -0
- semantic_search/semantic_search.py +51 -0
- semantic_search/semantic_search_metrics.py +78 -0
.gitattributes
CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 overview.jpg filter=lfs diff=lfs merge=lfs -text
+benchmarks.z01 filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Sander Wood

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
benchmarks.z01
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d88d8f54b46b5ea5ce2ba318eaf7393ef6aa63a1a7fa5247b4c4cdb9b917f6b0
size 20971520
benchmarks.zip
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6497fc7f335e8f0a5bbddd474d040a218ef13488e280f5ad0b32b51ca036e89f
size 14043301
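`benchmarks.z01` is the first segment of a split zip archive and `benchmarks.zip` its final segment; both are stored as Git LFS pointers with the three key-value lines above (`version`, `oid`, `size`). A minimal sketch of reading such a pointer (`parse_lfs_pointer` is a hypothetical helper, not part of this upload):

```python
def parse_lfs_pointer(path):
    """Parse a Git LFS pointer file into a dict of its key-value lines."""
    fields = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            if key:
                fields[key] = value
    return fields

# parse_lfs_pointer("benchmarks.zip")
# -> {'version': 'https://git-lfs.github.com/spec/v1',
#     'oid': 'sha256:6497...e89f', 'size': '14043301'}
```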
code/README.md
ADDED
@@ -0,0 +1,90 @@
# CLaMP 2 Codebase

## Overview
CLaMP 2 is a state-of-the-art multimodal music information retrieval system designed to work with 101 languages. This codebase includes scripts for training models, extracting features, and utility functions for processing music and text data. Below is a description of the scripts contained in the `code/` folder.

## Repository Structure
The `code/` folder contains the following scripts:

### 1. `config.py`
This script contains the training hyperparameters and file paths used in the `train_clamp2.py` and `train_m3.py` scripts. You can modify parameters such as learning rates, batch sizes, and file locations for training data.

### 2. `extract_clamp2.py`
This script utilizes the pre-trained CLaMP 2 model to extract representations of text (.txt) or music (.abc or .mtf) from a specified input folder and save the features to a target output folder in `.npy` format. The extracted features can be normalized for semantic search or retain temporal information for classification tasks.

**Usage:**
```bash
python extract_clamp2.py <input_dir> <output_dir> [--normalize]
```
- `input_dir`: Directory containing input data files.
- `output_dir`: Directory to save the output features.
- `--normalize`: (Optional) Normalize the extracted features. Normalization is not required for music classification tasks, but it is required for semantic search tasks.

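The saved `.npy` features can be inspected directly with NumPy. A minimal sketch (the path is illustrative; the 768-dimensional feature size follows `CLAMP2_HIDDEN_SIZE` in `config.py`):

```python
import numpy as np

# Load one feature file written by extract_clamp2.py (path is illustrative).
features = np.load("features/song.npy")

# With --normalize: shape (1, 768), one global vector per file.
# Without it: shape (1, seq_len, 768), one vector per token/patch position.
print(features.shape)
```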
### 3. `extract_m3.py`
This script employs the pre-trained M3 model to extract representations of interleaved ABC notation and MIDI Text Format (MTF) files from the specified input folder, saving the features to the target folder as `.npy` files.

**Usage:**
```bash
python extract_m3.py <input_dir> <output_dir>
```
- `input_dir`: Directory with input files (in .abc or .mtf format).
- `output_dir`: Directory to save extracted features.

### 4. `train_clamp2.py`
This script manages the training process for the CLaMP 2 model. It prepares training data from the path specified in the `TRAIN_JSONL` variable, which is defined in the `config.py` file. If `EVAL_JSONL` is provided in the configuration, it will be used for validation; otherwise, 1% of the training data is reserved for validation by default.

CLaMP 2 utilizes the multilingual text encoder `FacebookAI/xlm-roberta-base` for processing text data. Additionally, it employs the M3 model, pre-trained on both ABC and MIDI data, as the multimodal music encoder. If the pre-trained weights for M3 are available and the configuration variable `CLAMP2_LOAD_M3` is set to True, the training script will automatically load the M3 weights.

**Training Command:**
To start the training process, use the following command:

```bash
python -m torch.distributed.launch --nproc_per_node=<number_of_GPUs> --use_env train_clamp2.py
```

Replace `<number_of_GPUs>` with the number of GPUs you want to use for training.

**Input Data Format**
The input training data should be in JSONL format, where each line contains a single JSON object with the following structure. Fields that do not apply should be set to `None`:

```json
{
  "title": "Song Title",
  "composer": "Composer Name",
  "genres": ["Genre1", "Genre2"],
  "description": "Song description.",
  "lyrics": "Song lyrics.",
  "tags": ["tag1", "tag2"],
  "ensembles": ["Ensemble Name"],
  "instruments": ["Instrument1", "Instrument2"],
  "summary_en": "English summary.",
  "summary_nen": {
    "language": "Language Name",
    "summary": "Summary in specified language."
  },
  "filepaths": [
    "path/to/abc/file.abc",
    "path/to/mtf/file.mtf"
  ]
}
```

For obtaining the English and non-English summaries generated by GPT-4, refer to the `process_data/gpt4_summarize.py` script.

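As a concrete sketch, one such record could be appended to a training file like this (the file name and values are placeholders; Python `None` serializes to JSON `null`):

```python
import json

record = {
    "title": "Song Title",
    "composer": "Composer Name",
    "genres": ["Genre1", "Genre2"],
    "description": None,  # fields that do not apply are set to None
    "lyrics": None,
    "tags": ["tag1", "tag2"],
    "ensembles": None,
    "instruments": None,
    "summary_en": "English summary.",
    "summary_nen": {"language": "Language Name", "summary": "Summary in specified language."},
    "filepaths": ["path/to/abc/file.abc", "path/to/mtf/file.mtf"],
}

# One JSON object per line, as expected by train_clamp2.py.
with open("train.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(record, ensure_ascii=False) + "\n")
```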
### 5. `train_m3.py`
This script is dedicated to training the M3 model using interleaved ABC and MTF files. The directories for training and optional evaluation data should be specified in the `TRAIN_FOLDERS` and `EVAL_FOLDERS` variables, respectively.

**Training Command:**
To start the training process for the M3 model, use the following command:

```bash
python -m torch.distributed.launch --nproc_per_node=<number_of_GPUs> --use_env train_m3.py
```

Replace `<number_of_GPUs>` with the number of GPUs you want to use for training.

**Data Preparation:**
The data should be structured in interleaved ABC (.abc) and MTF (.mtf) formats. Please refer to the `process_data/` folder for instructions on how to prepare these formats.

### 6. `utils.py`
This utility script contains various classes for model definitions and functions used for training.
code/config.py
ADDED
@@ -0,0 +1,67 @@
EVAL_SPLIT = 0.01  # Fraction of training data used for evaluation
WANDB_KEY = "<your_wandb_key>"  # Set M3_WANDB_LOG/CLAMP2_WANDB_LOG = False if you have no API key for Weights and Biases logging

# -------------------- Configuration for M3 Training --------------------
TRAIN_FOLDERS = [
    "<path_to_training_data>"  # Directory containing training data
]

EVAL_FOLDERS = [
    ""  # (Optional) Directory containing evaluation data
]

PATCH_SIZE = 64  # Size of each patch
PATCH_LENGTH = 512  # Length of the patches
PATCH_NUM_LAYERS = 12  # Number of layers in the encoder
TOKEN_NUM_LAYERS = 3  # Number of layers in the decoder
M3_HIDDEN_SIZE = 768  # Size of the hidden layer

M3_NUM_EPOCH = 100  # Maximum number of epochs for training
M3_LEARNING_RATE = 1e-4  # Learning rate for the optimizer
M3_BATCH_SIZE = 16  # Batch size per GPU (single card) during training
M3_MASK_RATIO = 0.45  # Ratio of masked elements during training
M3_DETERMINISTIC = True  # Ensures deterministic results with random seeds
M3_WANDB_LOG = True  # Enable logging to Weights and Biases
M3_LOAD_CKPT = True  # Load model weights from a checkpoint if available

M3_WEIGHTS_PATH = (
    "weights_m3_p_size_" + str(PATCH_SIZE) +
    "_p_length_" + str(PATCH_LENGTH) +
    "_t_layers_" + str(TOKEN_NUM_LAYERS) +
    "_p_layers_" + str(PATCH_NUM_LAYERS) +
    "_h_size_" + str(M3_HIDDEN_SIZE) +
    "_lr_" + str(M3_LEARNING_RATE) +
    "_batch_" + str(M3_BATCH_SIZE) +
    "_mask_" + str(M3_MASK_RATIO) + ".pth"
)  # Path to store the model weights
M3_LOGS_PATH = M3_WEIGHTS_PATH.replace("weights", "logs").replace("pth", "txt")  # Path to save training logs

# -------------------- Configuration for CLaMP2 Training ----------------
TRAIN_JSONL = "<path_to_training_jsonl>"  # Path to the JSONL file with training data
EVAL_JSONL = ""  # (Optional) Path to the JSONL file with evaluation data

CLAMP2_HIDDEN_SIZE = 768  # Size of the hidden layer
TEXT_MODEL_NAME = "FacebookAI/xlm-roberta-base"  # Name of the pre-trained text model

CLAMP2_NUM_EPOCH = 100  # Maximum number of epochs for training
CLAMP2_LEARNING_RATE = 5e-5  # Learning rate for the optimizer
CLAMP2_BATCH_SIZE = 128  # Batch size per GPU (single card) during training
LOGIT_SCALE = 1  # Scaling factor for contrastive loss
MAX_TEXT_LENGTH = 128  # Maximum allowed length for text input
TEXT_DROPOUT = True  # Whether to apply dropout during text processing
CLAMP2_DETERMINISTIC = True  # Ensures deterministic results with random seeds
CLAMP2_LOAD_M3 = True  # Load weights from the M3 model
CLAMP2_WANDB_LOG = True  # Enable logging to Weights and Biases
CLAMP2_LOAD_CKPT = True  # Load weights from a checkpoint if available

CLAMP2_WEIGHTS_PATH = (
    "weights_clamp2_h_size_" + str(CLAMP2_HIDDEN_SIZE) +
    "_lr_" + str(CLAMP2_LEARNING_RATE) +
    "_batch_" + str(CLAMP2_BATCH_SIZE) +
    "_scale_" + str(LOGIT_SCALE) +
    "_t_length_" + str(MAX_TEXT_LENGTH) +
    "_t_model_" + TEXT_MODEL_NAME.replace("/", "_") +
    "_t_dropout_" + str(TEXT_DROPOUT) +
    "_m3_" + str(CLAMP2_LOAD_M3) + ".pth"
)  # Path to store CLaMP2 model weights
CLAMP2_LOGS_PATH = CLAMP2_WEIGHTS_PATH.replace("weights", "logs").replace("pth", "txt")  # Path to save training logs
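With the default hyperparameters above, the derived names resolve exactly to the two log files included in this upload. A quick standalone check (values inlined for clarity):

```python
# Reproduce the derived M3 file names from the defaults in config.py.
PATCH_SIZE, PATCH_LENGTH = 64, 512
TOKEN_NUM_LAYERS, PATCH_NUM_LAYERS = 3, 12
M3_HIDDEN_SIZE, M3_LEARNING_RATE, M3_BATCH_SIZE, M3_MASK_RATIO = 768, 1e-4, 16, 0.45

weights = ("weights_m3_p_size_" + str(PATCH_SIZE) +
           "_p_length_" + str(PATCH_LENGTH) +
           "_t_layers_" + str(TOKEN_NUM_LAYERS) +
           "_p_layers_" + str(PATCH_NUM_LAYERS) +
           "_h_size_" + str(M3_HIDDEN_SIZE) +
           "_lr_" + str(M3_LEARNING_RATE) +
           "_batch_" + str(M3_BATCH_SIZE) +
           "_mask_" + str(M3_MASK_RATIO) + ".pth")

print(weights.replace("weights", "logs").replace("pth", "txt"))
# logs_m3_p_size_64_p_length_512_t_layers_3_p_layers_12_h_size_768_lr_0.0001_batch_16_mask_0.45.txt
```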
code/extract_clamp2.py
ADDED
@@ -0,0 +1,192 @@
import os
import json
import random
import torch
import numpy as np
from tqdm import tqdm
from config import *
from utils import *
from samplings import *
from accelerate import Accelerator
from transformers import BertConfig, AutoTokenizer
import argparse

# Parse command-line arguments
parser = argparse.ArgumentParser(description="Feature extraction for CLaMP2.")
parser.add_argument("input_dir", type=str, help="Directory containing input data files.")
parser.add_argument("output_dir", type=str, help="Directory to save the output features.")
parser.add_argument("--normalize", action="store_true", help="Normalize the extracted features.")

args = parser.parse_args()

# Retrieve arguments
input_dir = args.input_dir
output_dir = args.output_dir
normalize = args.normalize

os.makedirs("logs", exist_ok=True)
for file in ["logs/files_extract_clamp2.json",
             "logs/files_shuffle_extract_clamp2.json",
             "logs/log_extract_clamp2.txt",
             "logs/pass_extract_clamp2.txt",
             "logs/skip_extract_clamp2.txt"]:
    if os.path.exists(file):
        os.remove(file)

files = []
for root, dirs, fs in os.walk(input_dir):
    for f in fs:
        if f.endswith(".txt") or f.endswith(".abc") or f.endswith(".mtf"):
            files.append(os.path.join(root, f))
print(f"Found {len(files)} files in total")
with open("logs/files_extract_clamp2.json", "w", encoding="utf-8") as f:
    json.dump(files, f)
random.shuffle(files)
with open("logs/files_shuffle_extract_clamp2.json", "w", encoding="utf-8") as f:
    json.dump(files, f)

accelerator = Accelerator()
device = accelerator.device
print("Using device:", device)
with open("logs/log_extract_clamp2.txt", "a", encoding="utf-8") as f:
    f.write("Using device: " + str(device) + "\n")

m3_config = BertConfig(vocab_size=1,
                       hidden_size=M3_HIDDEN_SIZE,
                       num_hidden_layers=PATCH_NUM_LAYERS,
                       num_attention_heads=M3_HIDDEN_SIZE//64,
                       intermediate_size=M3_HIDDEN_SIZE*4,
                       max_position_embeddings=PATCH_LENGTH)
model = CLaMP2Model(m3_config,
                    text_model_name=TEXT_MODEL_NAME,
                    hidden_size=CLAMP2_HIDDEN_SIZE,
                    load_m3=CLAMP2_LOAD_M3)
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
patchilizer = M3Patchilizer()

# print parameter number
print("Parameter Number: " + str(sum(p.numel() for p in model.parameters() if p.requires_grad)))

model.eval()
checkpoint = torch.load(CLAMP2_WEIGHTS_PATH, map_location='cpu', weights_only=True)
print(f"Successfully Loaded CLaMP 2 Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}")
model.load_state_dict(checkpoint['model'])

def extract_feature(filename, get_normalized=normalize):
    with open(filename, "r", encoding="utf-8") as f:
        item = f.read()

    if filename.endswith(".txt"):
        item = list(set(item.split("\n")))
        item = "\n".join(item)
        item = item.split("\n")
        item = [c for c in item if len(c) > 0]
        item = tokenizer.sep_token.join(item)
        input_data = tokenizer(item, return_tensors="pt")
        input_data = input_data['input_ids'].squeeze(0)
        max_input_length = MAX_TEXT_LENGTH
    else:
        input_data = patchilizer.encode(item, add_special_patches=True)
        input_data = torch.tensor(input_data)
        max_input_length = PATCH_LENGTH

    segment_list = []
    for i in range(0, len(input_data), max_input_length):
        segment_list.append(input_data[i:i+max_input_length])
    segment_list[-1] = input_data[-max_input_length:]

    last_hidden_states_list = []

    for input_segment in segment_list:
        input_masks = torch.tensor([1]*input_segment.size(0))
        if filename.endswith(".txt"):
            pad_indices = torch.ones(MAX_TEXT_LENGTH - input_segment.size(0)).long() * tokenizer.pad_token_id
        else:
            pad_indices = torch.ones((PATCH_LENGTH - input_segment.size(0), PATCH_SIZE)).long() * patchilizer.pad_token_id
        input_masks = torch.cat((input_masks, torch.zeros(max_input_length - input_segment.size(0))), 0)
        input_segment = torch.cat((input_segment, pad_indices), 0)

        if filename.endswith(".txt"):
            last_hidden_states = model.get_text_features(text_inputs=input_segment.unsqueeze(0).to(device),
                                                         text_masks=input_masks.unsqueeze(0).to(device),
                                                         get_normalized=get_normalized)
        else:
            last_hidden_states = model.get_music_features(music_inputs=input_segment.unsqueeze(0).to(device),
                                                          music_masks=input_masks.unsqueeze(0).to(device),
                                                          get_normalized=get_normalized)
        if not get_normalized:
            last_hidden_states = last_hidden_states[:, :input_masks.sum().long().item(), :]
        last_hidden_states_list.append(last_hidden_states)

    if not get_normalized:
        last_hidden_states_list = [last_hidden_states[0] for last_hidden_states in last_hidden_states_list]
        last_hidden_states_list[-1] = last_hidden_states_list[-1][-(len(input_data)%max_input_length):]
        last_hidden_states_list = torch.concat(last_hidden_states_list, 0)
    else:
        full_chunk_cnt = len(input_data) // max_input_length
        remain_chunk_len = len(input_data) % max_input_length
        if remain_chunk_len == 0:
            feature_weights = torch.tensor([max_input_length] * full_chunk_cnt, device=device).view(-1, 1)
        else:
            feature_weights = torch.tensor([max_input_length] * full_chunk_cnt + [remain_chunk_len], device=device).view(-1, 1)

        last_hidden_states_list = torch.concat(last_hidden_states_list, 0)
        last_hidden_states_list = last_hidden_states_list * feature_weights
        last_hidden_states_list = last_hidden_states_list.sum(dim=0) / feature_weights.sum()

    return last_hidden_states_list

def process_directory(input_dir, output_dir, files):
    print(f"Found {len(files)} files in total")
    with open("logs/log_extract_clamp2.txt", "a", encoding="utf-8") as f:
        f.write("Found " + str(len(files)) + " files in total\n")

    # calculate the number of files to process per GPU
    num_files_per_gpu = len(files) // accelerator.num_processes

    # calculate the start and end index for the current GPU
    start_idx = accelerator.process_index * num_files_per_gpu
    end_idx = start_idx + num_files_per_gpu
    if accelerator.process_index == accelerator.num_processes - 1:
        end_idx = len(files)

    files_to_process = files[start_idx:end_idx]

    # process the files
    for file in tqdm(files_to_process):
        output_subdir = output_dir + os.path.dirname(file)[len(input_dir):]
        try:
            os.makedirs(output_subdir, exist_ok=True)
        except Exception as e:
            print(output_subdir + " can not be created\n" + str(e))
            with open("logs/log_extract_clamp2.txt", "a") as f:
                f.write(output_subdir + " can not be created\n" + str(e) + "\n")

        output_file = os.path.join(output_subdir, os.path.splitext(os.path.basename(file))[0] + ".npy")

        if os.path.exists(output_file):
            print(f"Skipping {file}, output already exists")
            with open("logs/skip_extract_clamp2.txt", "a", encoding="utf-8") as f:
                f.write(file + "\n")
            continue

        try:
            with torch.no_grad():
                features = extract_feature(file).unsqueeze(0)
            np.save(output_file, features.detach().cpu().numpy())
            with open("logs/pass_extract_clamp2.txt", "a", encoding="utf-8") as f:
                f.write(file + "\n")
        except Exception as e:
            print(f"Failed to process {file}: {e}")
            with open("logs/log_extract_clamp2.txt", "a", encoding="utf-8") as f:
                f.write("Failed to process " + file + ": " + str(e) + "\n")

with open("logs/files_shuffle_extract_clamp2.json", "r", encoding="utf-8") as f:
    files = json.load(f)

# process the files
process_directory(input_dir, output_dir, files)

with open("logs/log_extract_clamp2.txt", "a", encoding="utf-8") as f:
    f.write("GPU ID: " + str(device) + "\n")
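When `--normalize` is set, `extract_clamp2.py` pools the per-segment global features with a length-weighted mean, so a short trailing segment contributes proportionally less than the full-length ones. A standalone sketch of that arithmetic on dummy tensors (sizes are illustrative):

```python
import torch

# Three segment features with hidden size 4 for illustration; the first two
# segments covered 128 tokens each, the trailing one only 100.
seg_feats = torch.randn(3, 4)
weights = torch.tensor([128.0, 128.0, 100.0]).view(-1, 1)

# Length-weighted mean, matching the feature_weights logic in the script.
pooled = (seg_feats * weights).sum(dim=0) / weights.sum()  # shape: (4,)
```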
code/extract_m3.py
ADDED
@@ -0,0 +1,162 @@
import os
import json
import random
import torch
import numpy as np
from tqdm import tqdm
from config import *
from utils import *
from samplings import *
from accelerate import Accelerator
from transformers import BertConfig, GPT2Config
import argparse

# Parse command-line arguments for input_dir and output_dir
parser = argparse.ArgumentParser(description="Process files to extract features.")
parser.add_argument("input_dir", type=str, help="Directory with input files")
parser.add_argument("output_dir", type=str, help="Directory to save extracted features")
args = parser.parse_args()

# Use args for input and output directories
input_dir = args.input_dir
output_dir = args.output_dir

# Create logs directory if it doesn't exist
os.makedirs("logs", exist_ok=True)

# Remove existing log files if present
for file in [
    "logs/files_extract_m3.json",
    "logs/files_shuffle_extract_m3.json",
    "logs/log_extract_m3.txt",
    "logs/pass_extract_m3.txt",
    "logs/skip_extract_m3.txt",
]:
    if os.path.exists(file):
        os.remove(file)

# Collect input files
files = []
for root, dirs, fs in os.walk(input_dir):
    for f in fs:
        if f.endswith(".abc") or f.endswith(".mtf"):
            files.append(os.path.join(root, f))

print(f"Found {len(files)} files in total")
with open("logs/files_extract_m3.json", "w", encoding="utf-8") as f:
    json.dump(files, f)

# Shuffle files and save the shuffled order
random.shuffle(files)
with open("logs/files_shuffle_extract_m3.json", "w", encoding="utf-8") as f:
    json.dump(files, f)

# Initialize accelerator and device
accelerator = Accelerator()
device = accelerator.device
print("Using device:", device)
with open("logs/log_extract_m3.txt", "a", encoding="utf-8") as f:
    f.write("Using device: " + str(device) + "\n")

# Model and configuration setup
patchilizer = M3Patchilizer()
encoder_config = BertConfig(
    vocab_size=1,
    hidden_size=M3_HIDDEN_SIZE,
    num_hidden_layers=PATCH_NUM_LAYERS,
    num_attention_heads=M3_HIDDEN_SIZE // 64,
    intermediate_size=M3_HIDDEN_SIZE * 4,
    max_position_embeddings=PATCH_LENGTH,
)
decoder_config = GPT2Config(
    vocab_size=128,
    n_positions=PATCH_SIZE,
    n_embd=M3_HIDDEN_SIZE,
    n_layer=TOKEN_NUM_LAYERS,
    n_head=M3_HIDDEN_SIZE // 64,
    n_inner=M3_HIDDEN_SIZE * 4,
)
model = M3Model(encoder_config, decoder_config).to(device)

# Print parameter count
print("Parameter Number: " + str(sum(p.numel() for p in model.parameters() if p.requires_grad)))

# Load model weights
model.eval()
checkpoint = torch.load(M3_WEIGHTS_PATH, map_location='cpu', weights_only=True)
print(f"Successfully Loaded Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}")
model.load_state_dict(checkpoint['model'])

def extract_feature(item):
    """Extracts features from input data."""
    target_patches = patchilizer.encode(item, add_special_patches=True)
    target_patches_list = [target_patches[i:i + PATCH_LENGTH] for i in range(0, len(target_patches), PATCH_LENGTH)]
    target_patches_list[-1] = target_patches[-PATCH_LENGTH:]

    last_hidden_states_list = []
    for input_patches in target_patches_list:
        input_masks = torch.tensor([1] * len(input_patches))
        input_patches = torch.tensor(input_patches)
        last_hidden_states = model.encoder(
            input_patches.unsqueeze(0).to(device), input_masks.unsqueeze(0).to(device)
        )["last_hidden_state"][0]
        last_hidden_states_list.append(last_hidden_states)

    # Handle the last segment padding correctly
    last_hidden_states_list[-1] = last_hidden_states_list[-1][-(len(target_patches) % PATCH_LENGTH):]
    return torch.concat(last_hidden_states_list, 0)

def process_directory(input_dir, output_dir, files):
    """Processes files in the input directory and saves features to the output directory."""
    print(f"Found {len(files)} files in total")
    with open("logs/log_extract_m3.txt", "a", encoding="utf-8") as f:
        f.write("Found " + str(len(files)) + " files in total\n")

    # Distribute files across processes for parallel processing
    num_files_per_gpu = len(files) // accelerator.num_processes
    start_idx = accelerator.process_index * num_files_per_gpu
    end_idx = start_idx + num_files_per_gpu if accelerator.process_index < accelerator.num_processes - 1 else len(files)
    files_to_process = files[start_idx:end_idx]

    # Process each file
    for file in tqdm(files_to_process):
        output_subdir = output_dir + os.path.dirname(file)[len(input_dir):]
        try:
            os.makedirs(output_subdir, exist_ok=True)
        except Exception as e:
            print(f"{output_subdir} cannot be created\n{e}")
            with open("logs/log_extract_m3.txt", "a") as f:
                f.write(f"{output_subdir} cannot be created\n{e}\n")

        output_file = os.path.join(output_subdir, os.path.splitext(os.path.basename(file))[0] + ".npy")

        if os.path.exists(output_file):
            print(f"Skipping {file}, output already exists")
            with open("logs/skip_extract_m3.txt", "a", encoding="utf-8") as f:
                f.write(file + "\n")
            continue

        try:
            with open(file, "r", encoding="utf-8") as f:
                item = f.read()
                if not item.startswith("ticks_per_beat"):
                    item = item.replace("L:1/8\n", "")
            with torch.no_grad():
                features = extract_feature(item).unsqueeze(0)
            np.save(output_file, features.detach().cpu().numpy())
            with open("logs/pass_extract_m3.txt", "a", encoding="utf-8") as f:
                f.write(file + "\n")
        except Exception as e:
            print(f"Failed to process {file}: {e}")
            with open("logs/log_extract_m3.txt", "a", encoding="utf-8") as f:
                f.write(f"Failed to process {file}: {e}\n")

# Load shuffled files list and start processing
with open("logs/files_shuffle_extract_m3.json", "r", encoding="utf-8") as f:
    files = json.load(f)

# Process the directory
process_directory(input_dir, output_dir, files)

with open("logs/log_extract_m3.txt", "a", encoding="utf-8") as f:
    f.write("GPU ID: " + str(device) + "\n")
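Both extractors segment long inputs the same way: fixed-length chunks, with the final chunk re-anchored to the last `PATCH_LENGTH` items and later trimmed back to its non-overlapping tail. A toy illustration of the index arithmetic (plain integers standing in for patches):

```python
PATCH_LENGTH = 4        # toy value; the real scripts use 512
data = list(range(10))  # ten "patches"

chunks = [data[i:i + PATCH_LENGTH] for i in range(0, len(data), PATCH_LENGTH)]
chunks[-1] = data[-PATCH_LENGTH:]  # re-anchor the last chunk: [6, 7, 8, 9]

# Keep only the part not covered by the previous chunk: [8, 9].
# When len(data) is an exact multiple, the -0 slice keeps the whole chunk.
tail = chunks[-1][-(len(data) % PATCH_LENGTH):]
```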
code/logs_clamp2_h_size_768_lr_5e-05_batch_128_scale_1_t_length_128_t_model_FacebookAI_xlm-roberta-base_t_dropout_True_m3_True.txt
ADDED
@@ -0,0 +1,500 @@
Epoch 1
train_loss: 4.135209383230448
eval_loss: 1.9609466393788655
time: Sun Sep 15 16:53:25 2024

Epoch 2
train_loss: 2.9014642586344244
eval_loss: 1.518966309229533
time: Sun Sep 15 18:43:11 2024

Epoch 3
train_loss: 2.5554833216573805
eval_loss: 1.2929850498835245
time: Sun Sep 15 20:35:14 2024

Epoch 4
train_loss: 2.3372960954320985
eval_loss: 1.1918509085973104
time: Sun Sep 15 22:23:17 2024

Epoch 5
train_loss: 2.183457492896627
eval_loss: 1.0725035190582275
time: Mon Sep 16 00:11:33 2024

Epoch 6
train_loss: 2.0555087922796216
eval_loss: 0.9798878033955892
time: Mon Sep 16 01:59:41 2024

Epoch 7
train_loss: 1.9556777615308922
eval_loss: 0.9242686867713928
time: Mon Sep 16 03:47:45 2024

Epoch 8
train_loss: 1.8689270445336208
eval_loss: 0.8476833502451578
time: Mon Sep 16 05:35:55 2024

Epoch 9
train_loss: 1.7932923043361795
eval_loss: 0.8043208519617716
time: Mon Sep 16 07:24:40 2024

Epoch 10
train_loss: 1.726651672090251
eval_loss: 0.7419429302215577
time: Mon Sep 16 09:12:24 2024

Epoch 11
train_loss: 1.6657210920084013
eval_loss: 0.7209376732508341
time: Mon Sep 16 11:01:02 2024

Epoch 12
train_loss: 1.6131336856822078
eval_loss: 0.6950737277666728
time: Mon Sep 16 12:50:07 2024

Epoch 13
train_loss: 1.5601606700647368
eval_loss: 0.652581254641215
time: Mon Sep 16 14:40:19 2024

Epoch 14
train_loss: 1.5198849670231784
eval_loss: 0.5868058403333029
time: Mon Sep 16 16:30:09 2024

Epoch 15
train_loss: 1.4755114693644882
eval_loss: 0.5867449243863424
time: Mon Sep 16 18:20:11 2024

Epoch 16
train_loss: 1.4336211351749464
eval_loss: 0.5479268968105316
time: Mon Sep 16 20:10:42 2024

Epoch 17
train_loss: 1.4039166571722088
eval_loss: 0.5280438164869944
time: Mon Sep 16 22:01:01 2024

Epoch 18
train_loss: 1.365380759842085
eval_loss: 0.5008598109086354
time: Mon Sep 16 23:51:36 2024

Epoch 19
train_loss: 1.332988672848894
eval_loss: 0.46479900081952413
time: Tue Sep 17 01:41:49 2024

Epoch 20
train_loss: 1.3014001791981087
eval_loss: 0.45230263272921245
time: Tue Sep 17 03:33:12 2024

Epoch 21
train_loss: 1.2688752540577755
eval_loss: 0.4297992348670959
time: Tue Sep 17 05:23:40 2024

Epoch 22
train_loss: 1.2425967695381415
eval_loss: 0.4219102164109548
time: Tue Sep 17 07:14:26 2024

Epoch 23
train_loss: 1.216824040410488
eval_loss: 0.40282649795214337
time: Tue Sep 17 09:05:53 2024

Epoch 24
train_loss: 1.1875996505747286
eval_loss: 0.36659018794695536
time: Tue Sep 17 10:56:46 2024

Epoch 25
train_loss: 1.1670776548255222
eval_loss: 0.36906688412030536
time: Tue Sep 17 12:47:39 2024

Epoch 26
train_loss: 1.1426405137974536
eval_loss: 0.3478178918361664
time: Tue Sep 17 14:38:34 2024

Epoch 27
train_loss: 1.1208335824466733
eval_loss: 0.33407697081565857
time: Tue Sep 17 16:28:45 2024

Epoch 28
train_loss: 1.0998876758880667
eval_loss: 0.33792892495791116
time: Tue Sep 17 18:19:59 2024

Epoch 29
train_loss: 1.0769698478377083
eval_loss: 0.3026650925477346
time: Tue Sep 17 20:11:16 2024

Epoch 30
train_loss: 1.0587592209657248
eval_loss: 0.2914476583401362
time: Tue Sep 17 22:03:30 2024

Epoch 31
train_loss: 1.0384011404245468
eval_loss: 0.27578969597816466
time: Tue Sep 17 23:55:15 2024

Epoch 32
train_loss: 1.0233595809527622
eval_loss: 0.2651842157046
time: Wed Sep 18 01:46:45 2024

Epoch 33
train_loss: 1.001824217418977
eval_loss: 0.2630385269721349
time: Wed Sep 18 03:39:08 2024

Epoch 34
train_loss: 0.9853754720520442
eval_loss: 0.25253995358943937
time: Wed Sep 18 05:30:33 2024

Epoch 35
train_loss: 0.9676362536067821
eval_loss: 0.24096360007921855
time: Wed Sep 18 07:22:09 2024

Epoch 36
train_loss: 0.9507065269691086
eval_loss: 0.2413844664891561
time: Wed Sep 18 09:12:59 2024

Epoch 37
train_loss: 0.9362979678186832
eval_loss: 0.23412639300028484
time: Wed Sep 18 11:04:09 2024

Epoch 38
train_loss: 0.9174621180856977
eval_loss: 0.21386308073997498
time: Wed Sep 18 12:54:52 2024

Epoch 39
train_loss: 0.9090870427650668
eval_loss: 0.19962686796983084
time: Wed Sep 18 14:45:52 2024

Epoch 40
train_loss: 0.8918763521271409
eval_loss: 0.20026112000147503
time: Wed Sep 18 16:36:37 2024

Epoch 41
train_loss: 0.8786202421428222
eval_loss: 0.18366556564966838
time: Wed Sep 18 18:27:31 2024

Epoch 42
train_loss: 0.8670675420604148
eval_loss: 0.17908457616964976
time: Wed Sep 18 20:18:16 2024

Epoch 43
train_loss: 0.8505593872931582
eval_loss: 0.17053016225496928
time: Wed Sep 18 22:10:39 2024

Epoch 44
train_loss: 0.8421949260766888
eval_loss: 0.17344878117243448
time: Thu Sep 19 00:02:24 2024

Epoch 45
train_loss: 0.8267569324702205
eval_loss: 0.1591893643140793
time: Thu Sep 19 01:53:48 2024

Epoch 46
train_loss: 0.8144617894466949
eval_loss: 0.15313500861326854
time: Thu Sep 19 03:44:58 2024

Epoch 47
train_loss: 0.8041844731303666
eval_loss: 0.14998503575722377
time: Thu Sep 19 05:36:50 2024

Epoch 48
train_loss: 0.7938160687423412
eval_loss: 0.1401842971642812
time: Thu Sep 19 07:28:21 2024

Epoch 49
train_loss: 0.7808867423096515
eval_loss: 0.1368137091398239
time: Thu Sep 19 09:20:09 2024

Epoch 50
train_loss: 0.7702171771933628
eval_loss: 0.13333487262328467
time: Thu Sep 19 11:12:37 2024

Epoch 51
train_loss: 0.7604444062967384
eval_loss: 0.13119754443566004
time: Thu Sep 19 13:04:26 2024

Epoch 52
train_loss: 0.7496546459894258
eval_loss: 0.1236343190073967
time: Thu Sep 19 14:55:53 2024

Epoch 53
train_loss: 0.7406523988345118
eval_loss: 0.12237562835216523
time: Thu Sep 19 16:47:51 2024

Epoch 54
train_loss: 0.7331518270251398
eval_loss: 0.11441469887892405
time: Thu Sep 19 18:38:48 2024

Epoch 55
train_loss: 0.7238280263746373
eval_loss: 0.10651812156041464
time: Thu Sep 19 20:29:18 2024

Epoch 56
train_loss: 0.7141688125488486
eval_loss: 0.10959143290917078
time: Thu Sep 19 22:19:28 2024

Epoch 57
train_loss: 0.7053173944645842
eval_loss: 0.10957898745934168
time: Fri Sep 20 00:10:06 2024

Epoch 58
train_loss: 0.6992166797548109
eval_loss: 0.09759224901596705
time: Fri Sep 20 02:01:02 2024

Epoch 59
train_loss: 0.6855367768623795
eval_loss: 0.10631066560745239
time: Fri Sep 20 03:51:25 2024

Epoch 60
train_loss: 0.6812366953699432
eval_loss: 0.08681503732999166
time: Fri Sep 20 05:41:32 2024

Epoch 61
train_loss: 0.6744320154854127
eval_loss: 0.08995070978999138
time: Fri Sep 20 07:32:33 2024

Epoch 62
train_loss: 0.6627048003782218
eval_loss: 0.08492780551314354
time: Fri Sep 20 09:22:52 2024

Epoch 63
train_loss: 0.6554694614403961
eval_loss: 0.09110054125388463
time: Fri Sep 20 11:15:14 2024

Epoch 64
train_loss: 0.6519363358224428
eval_loss: 0.08603844990332922
time: Fri Sep 20 13:05:45 2024

Epoch 65
train_loss: 0.6432196787488694
eval_loss: 0.07920929342508316
time: Fri Sep 20 14:56:27 2024

Epoch 66
train_loss: 0.6355774498505016
eval_loss: 0.08108622878789902
time: Fri Sep 20 16:47:00 2024

Epoch 67
train_loss: 0.628098195042665
eval_loss: 0.0835166151324908
time: Fri Sep 20 18:37:19 2024

Epoch 68
train_loss: 0.6229319736150211
eval_loss: 0.08126899500687917
time: Fri Sep 20 20:27:49 2024

Epoch 69
train_loss: 0.6162204064685376
eval_loss: 0.07405624414483707
time: Fri Sep 20 22:18:28 2024

Epoch 70
train_loss: 0.6093617768645045
eval_loss: 0.07916868552565574
time: Sat Sep 21 00:10:02 2024

Epoch 71
train_loss: 0.603765148576412
eval_loss: 0.07368899683157602
time: Sat Sep 21 02:00:29 2024

Epoch 72
train_loss: 0.5988557130088281
eval_loss: 0.06763924509286881
time: Sat Sep 21 03:51:46 2024

Epoch 73
train_loss: 0.590835969827209
eval_loss: 0.07139033873875936
time: Sat Sep 21 05:43:51 2024

Epoch 74
train_loss: 0.5864904869113879
eval_loss: 0.06859012718002001
time: Sat Sep 21 07:34:23 2024

Epoch 75
train_loss: 0.5819329118342274
eval_loss: 0.07611284777522087
time: Sat Sep 21 09:25:24 2024

Epoch 76
train_loss: 0.5750655913014898
eval_loss: 0.06813529431819916
time: Sat Sep 21 11:16:26 2024

Epoch 77
train_loss: 0.5703848759963817
eval_loss: 0.07192744488517443
time: Sat Sep 21 13:07:32 2024

Epoch 78
train_loss: 0.5666614368024667
eval_loss: 0.06931692684690158
time: Sat Sep 21 14:59:16 2024

Epoch 79
train_loss: 0.5610024514409998
eval_loss: 0.06487631574273109
time: Sat Sep 21 16:50:56 2024

Epoch 80
train_loss: 0.5552226794301296
eval_loss: 0.06034566586216291
time: Sat Sep 21 18:43:49 2024

Epoch 81
train_loss: 0.5512203840912394
eval_loss: 0.05962909683585167
time: Sat Sep 21 20:36:01 2024

Epoch 82
train_loss: 0.5477618443893468
eval_loss: 0.05546447386344274
time: Sat Sep 21 22:28:13 2024

Epoch 83
train_loss: 0.5428704522615506
eval_loss: 0.05013169844945272
time: Sun Sep 22 00:21:20 2024

Epoch 84
train_loss: 0.5396500316264258
eval_loss: 0.062498694161574046
time: Sun Sep 22 02:13:07 2024

Epoch 85
train_loss: 0.5349479554715307
eval_loss: 0.06073434228698413
time: Sun Sep 22 04:05:17 2024

Epoch 86
train_loss: 0.5292192482811466
eval_loss: 0.05734321524699529
time: Sun Sep 22 05:57:05 2024

Epoch 87
train_loss: 0.5249555090607058
eval_loss: 0.05274935985604922
time: Sun Sep 22 07:48:52 2024

Epoch 88
train_loss: 0.523276918144503
eval_loss: 0.05601314604282379
time: Sun Sep 22 09:41:05 2024

Epoch 89
train_loss: 0.5179934711230115
eval_loss: 0.057493301729361214
time: Sun Sep 22 11:33:47 2024

Epoch 90
train_loss: 0.5129834874146376
eval_loss: 0.05289425750573476
time: Sun Sep 22 13:25:54 2024

Epoch 91
train_loss: 0.5104886514866054
eval_loss: 0.0586332509915034
time: Sun Sep 22 15:18:13 2024

Epoch 92
train_loss: 0.5067275374282622
eval_loss: 0.0489634457975626
time: Sun Sep 22 17:10:39 2024

Epoch 93
train_loss: 0.5038576471461468
eval_loss: 0.05257208868861198
time: Sun Sep 22 19:04:46 2024

Epoch 94
train_loss: 0.5013840998762528
eval_loss: 0.05249967947602272
time: Sun Sep 22 20:57:55 2024

Epoch 95
train_loss: 0.4949465335763684
eval_loss: 0.048154672731955846
time: Sun Sep 22 22:50:30 2024

Epoch 96
train_loss: 0.4925781255166608
eval_loss: 0.052830965568621956
time: Mon Sep 23 00:43:13 2024

Epoch 97
train_loss: 0.4875780233282
eval_loss: 0.04684837857882182
time: Mon Sep 23 02:35:38 2024

Epoch 98
train_loss: 0.4858591078021573
eval_loss: 0.04507673804958661
time: Mon Sep 23 04:28:25 2024

Epoch 99
train_loss: 0.4804891498405977
eval_loss: 0.048148307204246524
time: Mon Sep 23 06:21:11 2024

Epoch 100
train_loss: 0.4782898508661265
eval_loss: 0.044317328557372096
time: Mon Sep 23 08:13:38 2024
code/logs_m3_p_size_64_p_length_512_t_layers_3_p_layers_12_h_size_768_lr_0.0001_batch_16_mask_0.45.txt
ADDED
@@ -0,0 +1,500 @@
Epoch 1
train_loss: 0.3055062843047765
eval_loss: 0.03727418192925584
time: Wed Aug 7 08:54:27 2024

Epoch 2
train_loss: 0.1718286834194018
eval_loss: 0.020085958587123625
time: Wed Aug 7 14:38:40 2024

Epoch 3
train_loss: 0.1476379437283353
eval_loss: 0.013794219702922342
time: Wed Aug 7 20:24:50 2024

Epoch 4
train_loss: 0.13554848474242498
eval_loss: 0.011902455668817844
time: Thu Aug 8 02:13:14 2024

Epoch 5
train_loss: 0.12781724531702496
eval_loss: 0.008929020740909163
time: Thu Aug 8 08:00:32 2024

Epoch 6
train_loss: 0.12264176163121285
eval_loss: 0.008453744098166877
time: Thu Aug 8 13:47:12 2024

Epoch 7
train_loss: 0.11872020949762974
eval_loss: 0.007165850172573819
time: Thu Aug 8 19:34:09 2024

Epoch 8
train_loss: 0.1153576639058103
eval_loss: 0.006601243383027142
time: Fri Aug 9 01:22:19 2024

Epoch 9
train_loss: 0.11312788720856465
eval_loss: 0.005973297544609645
time: Fri Aug 9 07:11:25 2024

Epoch 10
train_loss: 0.11096722687304313
eval_loss: 0.005796642405204946
time: Fri Aug 9 12:59:13 2024

Epoch 11
train_loss: 0.10913465011501206
eval_loss: 0.005249736892483845
time: Fri Aug 9 18:46:14 2024

Epoch 12
train_loss: 0.10780615577682347
eval_loss: 0.005115668955435858
time: Sat Aug 10 00:34:28 2024

Epoch 13
train_loss: 0.10650418949283817
eval_loss: 0.00475350690255028
time: Sat Aug 10 06:22:51 2024

Epoch 14
train_loss: 0.10524643352798381
eval_loss: 0.004583307054632575
time: Sat Aug 10 12:11:18 2024

Epoch 15
train_loss: 0.1041887047117438
eval_loss: 0.004289783886142609
time: Sat Aug 10 17:59:48 2024

Epoch 16
train_loss: 0.10343191375801945
eval_loss: 0.004421581262192111
time: Sat Aug 10 23:47:37 2024

Epoch 17
train_loss: 0.10256196519161385
eval_loss: 0.004104017818401634
time: Sun Aug 11 05:35:49 2024

Epoch 18
train_loss: 0.10170993055767087
eval_loss: 0.0039769458375234585
time: Sun Aug 11 11:23:39 2024

Epoch 19
train_loss: 0.1011880517951369
eval_loss: 0.0039005329324529833
time: Sun Aug 11 17:11:19 2024

Epoch 20
train_loss: 0.10030771156829077
eval_loss: 0.0036845325137673237
time: Sun Aug 11 22:59:49 2024

Epoch 21
train_loss: 0.09972109616302548
eval_loss: 0.0038503893043940205
time: Mon Aug 12 04:48:03 2024

Epoch 22
train_loss: 0.09932596696744844
eval_loss: 0.00370702411211194
time: Mon Aug 12 10:36:32 2024

Epoch 23
train_loss: 0.09888291950362459
eval_loss: 0.0034573812171313834
time: Mon Aug 12 16:24:55 2024

Epoch 24
train_loss: 0.09852503939581284
eval_loss: 0.003370235667697582
time: Mon Aug 12 22:12:14 2024

Epoch 25
train_loss: 0.09825884147004627
eval_loss: 0.00346387299475209
time: Tue Aug 13 04:00:42 2024

Epoch 26
train_loss: 0.09756856258879791
eval_loss: 0.0033276399650575615
time: Tue Aug 13 09:49:22 2024

Epoch 27
train_loss: 0.09730380131801182
eval_loss: 0.003326884365762399
time: Tue Aug 13 15:36:05 2024

Epoch 28
train_loss: 0.09687296288584166
eval_loss: 0.0034621171255395573
time: Tue Aug 13 21:23:24 2024

Epoch 29
train_loss: 0.09668537175198876
eval_loss: 0.003284947640647648
time: Wed Aug 14 03:10:21 2024

Epoch 30
train_loss: 0.09628572566572022
eval_loss: 0.003119471057549999
time: Wed Aug 14 08:56:45 2024

Epoch 31
train_loss: 0.09617123452549026
eval_loss: 0.003124797866062776
time: Wed Aug 14 14:43:12 2024

Epoch 32
train_loss: 0.09578377932399237
eval_loss: 0.0030736677601092537
time: Wed Aug 14 20:31:01 2024

Epoch 33
train_loss: 0.09558304869954821
eval_loss: 0.003178201471396451
time: Thu Aug 15 02:19:14 2024

Epoch 34
train_loss: 0.0952804450174092
eval_loss: 0.0030847328114775225
time: Thu Aug 15 08:06:29 2024

Epoch 35
train_loss: 0.09513826066486042
eval_loss: 0.00303873973446682
time: Thu Aug 15 13:52:17 2024

Epoch 36
train_loss: 0.09466769916316405
eval_loss: 0.0030122215467611258
time: Thu Aug 15 19:38:29 2024

Epoch 37
train_loss: 0.09465687754501316
eval_loss: 0.00289094522015785
time: Fri Aug 16 01:25:14 2024

Epoch 38
train_loss: 0.09435585222324992
eval_loss: 0.0030173959307773393
time: Fri Aug 16 07:11:56 2024

Epoch 39
train_loss: 0.09413478592045794
eval_loss: 0.002968058454507435
time: Fri Aug 16 12:59:16 2024

Epoch 40
train_loss: 0.09393180562734375
eval_loss: 0.0030673167865746948
time: Fri Aug 16 18:45:23 2024

Epoch 41
train_loss: 0.09365266143982799
eval_loss: 0.00287582161187937
time: Sat Aug 17 00:31:47 2024

Epoch 42
train_loss: 0.09359205519747489
eval_loss: 0.0027280030162997134
time: Sat Aug 17 06:18:32 2024

Epoch 43
train_loss: 0.09349238520961266
eval_loss: 0.0029261269570300787
time: Sat Aug 17 12:05:26 2024

Epoch 44
train_loss: 0.09324607778116949
eval_loss: 0.002691730654519444
time: Sat Aug 17 17:52:20 2024

Epoch 45
train_loss: 0.09310021795996155
eval_loss: 0.0028863806760858132
time: Sat Aug 17 23:38:57 2024

Epoch 46
train_loss: 0.09307358593283441
eval_loss: 0.002793597210717352
time: Sun Aug 18 05:25:42 2024

Epoch 47
train_loss: 0.09299390690766882
eval_loss: 0.0027052024821456098
time: Sun Aug 18 11:12:32 2024

Epoch 48
train_loss: 0.09253486422624911
eval_loss: 0.0027312307396534247
time: Sun Aug 18 16:59:16 2024

Epoch 49
train_loss: 0.09243107154309635
eval_loss: 0.002648197562936772
time: Sun Aug 18 22:46:41 2024

Epoch 50
train_loss: 0.09237845186490301
eval_loss: 0.0026844193827840284
time: Mon Aug 19 04:35:10 2024

Epoch 51
train_loss: 0.09231985249015236
eval_loss: 0.002708845011956738
time: Mon Aug 19 10:24:17 2024

Epoch 52
train_loss: 0.0922615721153286
eval_loss: 0.0035362059711223225
time: Mon Aug 19 16:11:39 2024

Epoch 53
train_loss: 0.09200190843071623
eval_loss: 0.0025848455890180064
time: Mon Aug 19 21:58:31 2024

Epoch 54
train_loss: 0.09200848002425245
eval_loss: 0.0026311414897881983
time: Tue Aug 20 03:45:36 2024

Epoch 55
train_loss: 0.09154813869071807
eval_loss: 0.0025586662145983823
time: Tue Aug 20 09:34:48 2024

Epoch 56
train_loss: 0.09162745474034129
eval_loss: 0.0026280648907143545
time: Tue Aug 20 15:23:23 2024

Epoch 57
train_loss: 0.09156280245772795
eval_loss: 0.002539119078534093
time: Tue Aug 20 21:11:25 2024

Epoch 58
train_loss: 0.09142590950099329
eval_loss: 0.0026369429265152866
time: Wed Aug 21 02:59:19 2024

Epoch 59
train_loss: 0.09139848643851392
eval_loss: 0.0024354966580356916
time: Wed Aug 21 08:46:23 2024

Epoch 60
train_loss: 0.09131192888740647
eval_loss: 0.0024594995301248277
time: Wed Aug 21 14:33:28 2024

Epoch 61
train_loss: 0.09122042933562911
eval_loss: 0.002616936316367883
time: Wed Aug 21 20:20:57 2024

Epoch 62
train_loss: 0.09109125168796305
eval_loss: 0.0025555431279884297
time: Thu Aug 22 02:08:45 2024

Epoch 63
train_loss: 0.09106527324403817
eval_loss: 0.0025145284593781213
time: Thu Aug 22 07:56:26 2024

Epoch 64
train_loss: 0.09095406525682191
eval_loss: 0.0025151555842959678
time: Thu Aug 22 13:45:57 2024

Epoch 65
train_loss: 0.09102793501718281
eval_loss: 0.0024135450126194563
time: Thu Aug 22 19:54:28 2024

Epoch 66
train_loss: 0.0908411063853937
eval_loss: 0.002460922076728368
time: Fri Aug 23 01:59:41 2024

Epoch 67
train_loss: 0.09070221083785855
eval_loss: 0.002453409551882543
time: Fri Aug 23 07:52:30 2024

Epoch 68
train_loss: 0.0906545008953897
|
338 |
+
eval_loss: 0.0024080786435031784
|
339 |
+
time: Fri Aug 23 13:41:28 2024
|
340 |
+
|
341 |
+
Epoch 69
|
342 |
+
train_loss: 0.0907353380525871
|
343 |
+
eval_loss: 0.0024573436347799147
|
344 |
+
time: Fri Aug 23 19:27:14 2024
|
345 |
+
|
346 |
+
Epoch 70
|
347 |
+
train_loss: 0.09040538104085095
|
348 |
+
eval_loss: 0.0023765437401249566
|
349 |
+
time: Sat Aug 24 01:14:45 2024
|
350 |
+
|
351 |
+
Epoch 71
|
352 |
+
train_loss: 0.09036114065518137
|
353 |
+
eval_loss: 0.0023877528348234226
|
354 |
+
time: Sat Aug 24 07:02:04 2024
|
355 |
+
|
356 |
+
Epoch 72
|
357 |
+
train_loss: 0.09037455027205546
|
358 |
+
eval_loss: 0.002315233082103814
|
359 |
+
time: Sat Aug 24 12:49:24 2024
|
360 |
+
|
361 |
+
Epoch 73
|
362 |
+
train_loss: 0.09026183628343257
|
363 |
+
eval_loss: 0.0024284060419643228
|
364 |
+
time: Sat Aug 24 18:35:36 2024
|
365 |
+
|
366 |
+
Epoch 74
|
367 |
+
train_loss: 0.09019025581511034
|
368 |
+
eval_loss: 0.002393116130206718
|
369 |
+
time: Sun Aug 25 00:21:29 2024
|
370 |
+
|
371 |
+
Epoch 75
|
372 |
+
train_loss: 0.089901714783446
|
373 |
+
eval_loss: 0.002298152916632467
|
374 |
+
time: Sun Aug 25 06:08:01 2024
|
375 |
+
|
376 |
+
Epoch 76
|
377 |
+
train_loss: 0.09018262871273484
|
378 |
+
eval_loss: 0.002273971366672482
|
379 |
+
time: Sun Aug 25 11:54:02 2024
|
380 |
+
|
381 |
+
Epoch 77
|
382 |
+
train_loss: 0.08998425874228
|
383 |
+
eval_loss: 0.002317420323379338
|
384 |
+
time: Sun Aug 25 17:44:05 2024
|
385 |
+
|
386 |
+
Epoch 78
|
387 |
+
train_loss: 0.08983653943919646
|
388 |
+
eval_loss: 0.0024391192159878743
|
389 |
+
time: Sun Aug 25 23:31:34 2024
|
390 |
+
|
391 |
+
Epoch 79
|
392 |
+
train_loss: 0.08981405456901183
|
393 |
+
eval_loss: 0.002319374949895317
|
394 |
+
time: Mon Aug 26 05:24:56 2024
|
395 |
+
|
396 |
+
Epoch 80
|
397 |
+
train_loss: 0.08974534569690559
|
398 |
+
eval_loss: 0.0023008979344151066
|
399 |
+
time: Mon Aug 26 11:28:33 2024
|
400 |
+
|
401 |
+
Epoch 81
|
402 |
+
train_loss: 0.08972110153310983
|
403 |
+
eval_loss: 0.002406696710865237
|
404 |
+
time: Mon Aug 26 17:33:30 2024
|
405 |
+
|
406 |
+
Epoch 82
|
407 |
+
train_loss: 0.0895689915361898
|
408 |
+
eval_loss: 0.002241936448434926
|
409 |
+
time: Mon Aug 26 23:39:15 2024
|
410 |
+
|
411 |
+
Epoch 83
|
412 |
+
train_loss: 0.08950625452328584
|
413 |
+
eval_loss: 0.002408353965493697
|
414 |
+
time: Tue Aug 27 05:37:59 2024
|
415 |
+
|
416 |
+
Epoch 84
|
417 |
+
train_loss: 0.08959725393084628
|
418 |
+
eval_loss: 0.0023435966142665455
|
419 |
+
time: Tue Aug 27 11:34:29 2024
|
420 |
+
|
421 |
+
Epoch 85
|
422 |
+
train_loss: 0.08970333726515986
|
423 |
+
eval_loss: 0.0023965956810233086
|
424 |
+
time: Tue Aug 27 17:27:31 2024
|
425 |
+
|
426 |
+
Epoch 86
|
427 |
+
train_loss: 0.08948115523227308
|
428 |
+
eval_loss: 0.002325803569256709
|
429 |
+
time: Tue Aug 27 23:19:43 2024
|
430 |
+
|
431 |
+
Epoch 87
|
432 |
+
train_loss: 0.08933937688654775
|
433 |
+
eval_loss: 0.0023552257988114647
|
434 |
+
time: Wed Aug 28 05:11:34 2024
|
435 |
+
|
436 |
+
Epoch 88
|
437 |
+
train_loss: 0.08938353908107184
|
438 |
+
eval_loss: 0.0024397599904794043
|
439 |
+
time: Wed Aug 28 11:01:23 2024
|
440 |
+
|
441 |
+
Epoch 89
|
442 |
+
train_loss: 0.08921640703096091
|
443 |
+
eval_loss: 0.002223708766084243
|
444 |
+
time: Wed Aug 28 16:50:21 2024
|
445 |
+
|
446 |
+
Epoch 90
|
447 |
+
train_loss: 0.08929300930090782
|
448 |
+
eval_loss: 0.0022849828316260303
|
449 |
+
time: Wed Aug 28 22:38:53 2024
|
450 |
+
|
451 |
+
Epoch 91
|
452 |
+
train_loss: 0.08910525214309825
|
453 |
+
eval_loss: 0.0022257193633186227
|
454 |
+
time: Thu Aug 29 04:35:16 2024
|
455 |
+
|
456 |
+
Epoch 92
|
457 |
+
train_loss: 0.08905495976636461
|
458 |
+
eval_loss: 0.0022299331251850137
|
459 |
+
time: Thu Aug 29 10:29:46 2024
|
460 |
+
|
461 |
+
Epoch 93
|
462 |
+
train_loss: 0.08890526102100955
|
463 |
+
eval_loss: 0.0022962711695463786
|
464 |
+
time: Thu Aug 29 16:23:49 2024
|
465 |
+
|
466 |
+
Epoch 94
|
467 |
+
train_loss: 0.08908289874104246
|
468 |
+
eval_loss: 0.002243622880820028
|
469 |
+
time: Thu Aug 29 22:15:42 2024
|
470 |
+
|
471 |
+
Epoch 95
|
472 |
+
train_loss: 0.08908785978677156
|
473 |
+
eval_loss: 0.0022457318524397784
|
474 |
+
time: Fri Aug 30 04:06:57 2024
|
475 |
+
|
476 |
+
Epoch 96
|
477 |
+
train_loss: 0.08888098475318565
|
478 |
+
eval_loss: 0.002224675611787346
|
479 |
+
time: Fri Aug 30 09:58:43 2024
|
480 |
+
|
481 |
+
Epoch 97
|
482 |
+
train_loss: 0.08888529259134526
|
483 |
+
eval_loss: 0.0021844924980664493
|
484 |
+
time: Fri Aug 30 15:50:16 2024
|
485 |
+
|
486 |
+
Epoch 98
|
487 |
+
train_loss: 0.08885388837534758
|
488 |
+
eval_loss: 0.0022109088076294288
|
489 |
+
time: Fri Aug 30 21:41:59 2024
|
490 |
+
|
491 |
+
Epoch 99
|
492 |
+
train_loss: 0.08873902663868657
|
493 |
+
eval_loss: 0.0022606451996653202
|
494 |
+
time: Sat Aug 31 03:34:46 2024
|
495 |
+
|
496 |
+
Epoch 100
|
497 |
+
train_loss: 0.08877080098666765
|
498 |
+
eval_loss: 0.002279470525367602
|
499 |
+
time: Sat Aug 31 09:28:38 2024
|
500 |
+
|
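
Note: a minimal sketch of how one might parse these per-epoch logs into records for plotting or model selection (the log path is a placeholder; the field format matches what train_clamp2.py below writes):

import re

def parse_training_log(path):
    """Parse an 'Epoch N / train_loss / eval_loss / time' log into dicts."""
    records, current = [], {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line.startswith("Epoch"):
                current = {"epoch": int(line.split()[1])}
            elif line.startswith(("train_loss:", "eval_loss:")):
                key, value = line.split(":", 1)
                current[key] = float(value)
            elif line.startswith("time:"):
                # keep the full asctime string, which itself contains colons
                current["time"] = line.split(":", 1)[1].strip()
                records.append(current)
    return records

# Example: find the epoch with the lowest eval loss.
# records = parse_training_log("logs_clamp2_... .txt")  # placeholder path
# best = min(records, key=lambda r: r["eval_loss"])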
code/train_clamp2.py
ADDED
@@ -0,0 +1,356 @@
import os
import json
import time
import wandb
import torch
import random
import numpy as np
from utils import *
from config import *
from tqdm import tqdm
from copy import deepcopy
import torch.distributed as dist
from torch.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import AutoTokenizer, BertConfig, get_constant_schedule_with_warmup

def list_files_in_json(json_path):
    file_list = []

    if os.path.exists(json_path):
        with open(json_path, 'r', encoding='utf-8') as f:
            for line in f:
                item = json.loads(line)
                file_list.append(item)

    return file_list

def collate_batch(batch):
    text_inputs, text_masks, music_inputs, music_masks = zip(*batch)

    text_inputs = torch.stack(text_inputs)
    text_masks = torch.stack(text_masks)
    music_inputs = torch.stack(music_inputs)
    music_masks = torch.stack(music_masks)

    return text_inputs, text_masks, music_inputs, music_masks

class TextMusicDataset(Dataset):
    def __init__(self, items, mode):
        print("The number of "+mode+" data: "+str(len(items)))
        self.items = items
        self.mode = mode
        if self.mode == 'train' or not EVAL_JSONL:
            self.datapath = os.path.dirname(TRAIN_JSONL)
        elif self.mode == 'eval':
            self.datapath = os.path.dirname(EVAL_JSONL)

    def text_dropout(self, item):
        if random.random() < 0.5:
            candidates = []
            for key in item.keys():
                if key not in ["summary_en", "summary_nen", "filepaths"]:
                    if item[key] is None:
                        continue
                    elif isinstance(item[key], str):
                        candidates.append(item[key])
                    elif isinstance(item[key], list):
                        candidates.extend(item[key])
            candidates = list(set(candidates))
            candidates = "\n".join(candidates)
            candidates = candidates.split("\n")
            selected_candidates = [c for c in candidates if len(c) > 0 and random.random() < 0.5]
            if len(selected_candidates) == 0:
                selected_candidates = candidates
            random.shuffle(selected_candidates)
            text = tokenizer.sep_token.join(selected_candidates)
        else:
            if random.random() < 0.5:
                text = random.choice(item["summary_en"])
            else:
                text = random.choice(item["summary_nen"])["summary"]

        return text

    def random_truncate(self, input_tensor, max_length):
        choices = ["head", "tail", "middle"]
        choice = random.choice(choices)
        if choice == "head" or self.mode == 'eval':
            input_tensor = input_tensor[:max_length]
        elif choice == "tail":
            input_tensor = input_tensor[-max_length:]
        elif choice == "middle":
            start = random.randint(1, input_tensor.size(0)-max_length)
            input_tensor = input_tensor[start:start+max_length]

        return input_tensor

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        item = self.items[idx]

        # randomly select text from the item
        if self.mode == 'train' and TEXT_DROPOUT:
            text = self.text_dropout(item)
        else:
            text = item["summary_en"][0]

        # tokenize text and build mask for text tokens
        text_inputs = tokenizer(text, return_tensors='pt')
        text_inputs = text_inputs['input_ids'].squeeze(0)
        if text_inputs.size(0) > MAX_TEXT_LENGTH:
            text_inputs = self.random_truncate(text_inputs, MAX_TEXT_LENGTH)
        text_masks = torch.ones(text_inputs.size(0))

        # load music file
        if self.mode == 'train':
            filepath = random.choice(item["filepaths"])
        else:
            if item["filepaths"][0].endswith(".abc"):
                filepath = item["filepaths"][0]
            else:
                filepath = item["filepaths"][1]
        filepath = self.datapath + '/' + filepath

        with open(filepath, "r", encoding="utf-8") as f:
            item = f.read().replace("L:1/8\n", "") if filepath.endswith(".abc") else f.read()

        # randomly remove instrument info from the music file
        if random.random() < 0.9 and self.mode == 'train':
            item = remove_instrument_info(item)

        # encode the music into patches
        music_inputs = patchilizer.encode(item, add_special_patches=True, truncate=True, random_truncate=(self.mode=="train"))
        music_inputs = torch.tensor(music_inputs)
        music_masks = torch.ones(music_inputs.size(0))

        # pad text inputs and masks
        pad_indices = torch.ones(MAX_TEXT_LENGTH - text_inputs.size(0)).long() * tokenizer.pad_token_id
        text_inputs = torch.cat((text_inputs, pad_indices), 0)
        text_masks = torch.cat((text_masks, torch.zeros(MAX_TEXT_LENGTH - text_masks.size(0))), 0)

        # pad music inputs and masks
        pad_indices = torch.ones((PATCH_LENGTH - music_inputs.size(0), PATCH_SIZE)).long() * patchilizer.pad_token_id
        music_inputs = torch.cat((music_inputs, pad_indices), 0)
        music_masks = torch.cat((music_masks, torch.zeros(PATCH_LENGTH - music_masks.size(0))), 0)

        return text_inputs, text_masks, music_inputs, music_masks

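Note: each line of TRAIN_JSONL/EVAL_JSONL is one JSON object. A hypothetical minimal item, inferred from the fields TextMusicDataset reads above; only summary_en, summary_nen and filepaths are required by the code, any other string or list field simply feeds text_dropout, and all concrete values (including the "language" key) are made up:

example_item = {
    "title": "Example Tune",                              # illustrative extra field, used by text_dropout
    "tags": ["folk", "G major"],                          # list fields are flattened into candidates
    "summary_en": ["An English summary of the piece."],
    "summary_nen": [{"language": "de",                    # hypothetical key; only "summary" is read
                     "summary": "Eine deutsche Zusammenfassung."}],
    "filepaths": ["abc/example.abc", "mtf/example.mtf"],  # resolved relative to the JSONL's folder
}
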
# call model with a batch of input
def process_one_batch(batch):
    text_inputs, text_masks, music_inputs, music_masks = batch

    loss = model(text_inputs,
                 text_masks,
                 music_inputs,
                 music_masks)

    # Average the loss across GPUs (reduce on rank 0, then broadcast)
    if world_size > 1:
        loss = loss.unsqueeze(0)
        dist.reduce(loss, dst=0)
        loss = loss / world_size
        dist.broadcast(loss, src=0)

    return loss.mean()

# do one epoch for training
def train_epoch(epoch):
    tqdm_train_set = tqdm(train_set)
    total_train_loss = 0
    iter_idx = 1
    model.train()
    train_steps = (epoch-1)*len(train_set)

    for batch in tqdm_train_set:
        with autocast(device_type='cuda'):
            loss = process_one_batch(batch)
        scaler.scale(loss).backward()
        total_train_loss += loss.item()
        scaler.step(optimizer)
        scaler.update()

        lr_scheduler.step()
        model.zero_grad(set_to_none=True)
        tqdm_train_set.set_postfix({str(global_rank)+'_train_loss': total_train_loss / iter_idx})
        train_steps += 1

        # Log the training loss to wandb
        if global_rank==0 and CLAMP2_WANDB_LOG:
            wandb.log({"train_loss": total_train_loss / iter_idx}, step=train_steps)

        iter_idx += 1

    return total_train_loss / (iter_idx-1)

# do one epoch for eval
def eval_epoch():
    tqdm_eval_set = tqdm(eval_set)
    total_eval_loss = 0
    iter_idx = 1
    model.eval()

    # Evaluate data for one epoch
    for batch in tqdm_eval_set:
        with torch.no_grad():
            loss = process_one_batch(batch)

        total_eval_loss += loss.item()
        tqdm_eval_set.set_postfix({str(global_rank)+'_eval_loss': total_eval_loss / iter_idx})
        iter_idx += 1

    return total_eval_loss / (iter_idx-1)

# train and eval
if __name__ == "__main__":

    # Set up distributed training
    world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else 0
    local_rank = int(os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else 0

    if world_size > 1:
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
        dist.init_process_group(backend='nccl')
    else:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    if CLAMP2_DETERMINISTIC:
        seed = 42 + global_rank
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    m3_config = BertConfig(vocab_size=1,
                           hidden_size=M3_HIDDEN_SIZE,
                           num_hidden_layers=PATCH_NUM_LAYERS,
                           num_attention_heads=M3_HIDDEN_SIZE//64,
                           intermediate_size=M3_HIDDEN_SIZE*4,
                           max_position_embeddings=PATCH_LENGTH)
    model = CLaMP2Model(m3_config,
                        global_rank,
                        world_size,
                        TEXT_MODEL_NAME,
                        CLAMP2_HIDDEN_SIZE,
                        CLAMP2_LOAD_M3)
    model = model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
    patchilizer = M3Patchilizer()

    # print parameter number
    print("Parameter Number: "+str(sum(p.numel() for p in model.parameters() if p.requires_grad)))

    if world_size > 1:
        model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)

    scaler = GradScaler()
    optimizer = torch.optim.AdamW(model.parameters(), lr=CLAMP2_LEARNING_RATE)

    if CLAMP2_WANDB_LOG and global_rank==0:
        # Initialize wandb
        if WANDB_KEY:
            wandb.login(key=WANDB_KEY)
        wandb.init(project="clamp2",
                   name=CLAMP2_WEIGHTS_PATH.replace("weights_", "").replace(".pth", ""))

    # load filenames under train and eval folder
    train_files = list_files_in_json(TRAIN_JSONL)
    eval_files = list_files_in_json(EVAL_JSONL)

    if len(eval_files)==0:
        train_files, eval_files = split_data(train_files)

    train_batch_nums = int(len(train_files) / CLAMP2_BATCH_SIZE)
    eval_batch_nums = int(len(eval_files) / CLAMP2_BATCH_SIZE)

    train_files = train_files[:train_batch_nums*CLAMP2_BATCH_SIZE]
    eval_files = eval_files[:eval_batch_nums*CLAMP2_BATCH_SIZE]

    train_set = TextMusicDataset(train_files, 'train')
    eval_set = TextMusicDataset(eval_files, 'eval')

    train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=global_rank)
    eval_sampler = DistributedSampler(eval_set, num_replicas=world_size, rank=global_rank)

    train_set = DataLoader(train_set, batch_size=CLAMP2_BATCH_SIZE, collate_fn=collate_batch, sampler=train_sampler, shuffle=(train_sampler is None))
    eval_set = DataLoader(eval_set, batch_size=CLAMP2_BATCH_SIZE, collate_fn=collate_batch, sampler=eval_sampler, shuffle=(eval_sampler is None))

    lr_scheduler = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=1000)

    if CLAMP2_LOAD_CKPT and os.path.exists(CLAMP2_WEIGHTS_PATH):
        # Load checkpoint to CPU
        checkpoint = torch.load(CLAMP2_WEIGHTS_PATH, map_location='cpu', weights_only=True)

        # Here, model is assumed to be on GPU
        # Load the state dict into a CPU clone first, then copy it back into the live model
        if torch.cuda.device_count() > 1:
            # With a DDP-wrapped model, load into model.module instead
            cpu_model = deepcopy(model.module)
            cpu_model.load_state_dict(checkpoint['model'])
            model.module.load_state_dict(cpu_model.state_dict())
        else:
            # Load to a CPU clone of the model, then load back
            cpu_model = deepcopy(model)
            cpu_model.load_state_dict(checkpoint['model'])
            model.load_state_dict(cpu_model.state_dict())
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_sched'])
        pre_epoch = checkpoint['epoch']
        best_epoch = checkpoint['best_epoch']
        min_eval_loss = checkpoint['min_eval_loss']
        print(f"Successfully Loaded Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}")
        checkpoint = None

    else:
        pre_epoch = 0
        best_epoch = 0
        min_eval_loss = float('inf')

        model = model.to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=CLAMP2_LEARNING_RATE)

    for epoch in range(1+pre_epoch, CLAMP2_NUM_EPOCH+1):
        train_sampler.set_epoch(epoch)
        eval_sampler.set_epoch(epoch)
        print('-' * 21 + "Epoch " + str(epoch) + '-' * 21)
        train_loss = train_epoch(epoch)
        eval_loss = eval_epoch()
        if global_rank==0:
            with open(CLAMP2_LOGS_PATH, 'a') as f:
                f.write("Epoch " + str(epoch) + "\ntrain_loss: " + str(train_loss) + "\neval_loss: " + str(eval_loss) + "\ntime: " + time.asctime(time.localtime(time.time())) + "\n\n")
            if eval_loss < min_eval_loss:
                best_epoch = epoch
                min_eval_loss = eval_loss
                checkpoint = {
                    'model': model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_sched': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'best_epoch': best_epoch,
                    'min_eval_loss': min_eval_loss
                }
                torch.save(checkpoint, CLAMP2_WEIGHTS_PATH)
            checkpoint = {
                'model': model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_sched': lr_scheduler.state_dict(),
                'epoch': epoch,
                'best_epoch': best_epoch,
                'min_eval_loss': min_eval_loss
            }
            torch.save(checkpoint, "latest_"+CLAMP2_WEIGHTS_PATH)

        if world_size > 1:
            dist.barrier()

    if global_rank==0:
        print("Best Eval Epoch : "+str(best_epoch))
        print("Min Eval Loss : "+str(min_eval_loss))

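Note: the script reads WORLD_SIZE, RANK and LOCAL_RANK from the environment, so a multi-GPU run can be launched with e.g. `torchrun --nproc_per_node=8 train_clamp2.py` (the GPU count is illustrative). The scalar logged on every rank is the cross-rank average computed in process_one_batch:

\[
\mathcal{L} \;=\; \frac{1}{W}\sum_{r=0}^{W-1} \mathcal{L}_r ,
\]

where \(W\) is the world size and \(\mathcal{L}_r\) the contrastive loss computed on rank \(r\).
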
code/train_m3.py
ADDED
@@ -0,0 +1,321 @@
import os
import gc
import time
import wandb
import torch
import random
import weakref
import numpy as np
from utils import *
from config import *
from tqdm import tqdm
from copy import deepcopy
import torch.distributed as dist
from torch.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import BertConfig, GPT2Config, get_constant_schedule_with_warmup

patchilizer = M3Patchilizer()

def clear_unused_tensors():
    gc.disable()  # Temporarily disable garbage collection
    try:
        # Get the set of tensor ids used by the model
        if hasattr(model, "module"):
            model_tensors = {id(p) for p in model.module.parameters()}
        else:
            model_tensors = {id(p) for p in model.parameters()}

        # Get the set of tensor ids used by the optimizer
        optimizer_tensors = {
            id(state)
            for state_dict in optimizer.state.values()
            for state in state_dict.values()
            if isinstance(state, torch.Tensor)  # Ensure only tensors are considered
        }

        # List of all CUDA tensors currently in memory
        tensors = [obj for obj in gc.get_objects() if isinstance(obj, torch.Tensor) and obj.is_cuda]

        # Create weak references to avoid interfering with garbage collection
        tensor_refs = [weakref.ref(tensor) for tensor in tensors]

        for tensor_ref in tensor_refs:
            tensor = tensor_ref()  # Dereference the weak reference
            if tensor is not None and id(tensor) not in model_tensors and id(tensor) not in optimizer_tensors:
                # Mark the tensor for deletion
                tensor.detach_()  # Detach from computation graph
                del tensor  # Delete the tensor reference
    except Exception:
        pass

    finally:
        gc.enable()  # Re-enable garbage collection
        gc.collect()  # Force a garbage collection
        torch.cuda.empty_cache()  # Clear the CUDA cache

def list_files_in_directory(directories, extensions=["abc", "mtf"]):
    file_list = []

    for directory in directories:
        for root, dirs, files in os.walk(directory):
            for file in files:
                if any(file.endswith(ext) for ext in extensions):
                    file_path = os.path.join(root, file)
                    file_list.append(file_path)

    return file_list

def collate_batch(batch):
    input_patches, input_masks, selected_indices, target_patches = zip(*batch)

    input_patches = torch.nn.utils.rnn.pad_sequence(input_patches, batch_first=True, padding_value=patchilizer.pad_token_id)
    input_masks = torch.nn.utils.rnn.pad_sequence(input_masks, batch_first=True, padding_value=0)
    selected_indices = torch.nn.utils.rnn.pad_sequence(selected_indices, batch_first=True, padding_value=0)
    target_patches = torch.nn.utils.rnn.pad_sequence(target_patches, batch_first=True, padding_value=patchilizer.pad_token_id)

    return input_patches, input_masks, selected_indices, target_patches

class M3Dataset(Dataset):
    def __init__(self, filenames, mode):
        print("The number of "+mode+" data: "+str(len(filenames)))
        self.filenames = filenames
        self.mode = mode

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        filename = self.filenames[idx]
        try:
            with open(filename, "r", encoding="utf-8") as f:
                item = f.read().replace("L:1/8\n", "") if filename.endswith(".abc") else f.read()
        except Exception as e:
            print(e)
            print("Failed to load: "+filename)
            item = ""

        target_patches = patchilizer.encode(item, add_special_patches=True, truncate=True, random_truncate=(self.mode=="train"))
        input_masks = torch.tensor([1]*len(target_patches))
        input_patches, selected_indices = mask_patches(target_patches, patchilizer, self.mode)
        input_patches = input_patches.reshape(-1)
        target_patches = torch.tensor(target_patches).reshape(-1)
        return input_patches, input_masks, selected_indices, target_patches

# call model with a batch of input
def process_one_batch(batch):
    input_patches, input_masks, selected_indices, target_patches = batch

    loss = model(input_patches,
                 input_masks,
                 selected_indices,
                 target_patches).loss

    # Average the loss across GPUs (reduce on rank 0, then broadcast)
    if world_size > 1:
        loss = loss.unsqueeze(0)
        dist.reduce(loss, dst=0)
        loss = loss / world_size
        dist.broadcast(loss, src=0)

    return loss.mean()

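Note: mask_patches (defined in code/utils.py below) selects ceil(M3_MASK_RATIO * n) of a piece's n patches, and only those positions feed the decoder loss. With the settings suggested by the M3 log filename above (mask ratio 0.45, patch length 512), a fully packed piece has

\[
n_{\text{selected}} \;=\; \lceil 0.45 \times 512 \rceil \;=\; 231
\]

selected patches; during training the selected patches of an item are all masked, all shuffled, or all left intact with probability 0.8 / 0.1 / 0.1, while evaluation always keeps them intact.
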
# do one epoch for training
def train_epoch(epoch):
    tqdm_train_set = tqdm(train_set)
    total_train_loss = 0
    iter_idx = 1
    model.train()
    train_steps = (epoch-1)*len(train_set)

    for batch in tqdm_train_set:
        with autocast(device_type='cuda'):
            loss = process_one_batch(batch)
        scaler.scale(loss).backward()
        total_train_loss += loss.item()
        scaler.step(optimizer)
        scaler.update()

        lr_scheduler.step()
        model.zero_grad(set_to_none=True)
        tqdm_train_set.set_postfix({str(global_rank)+'_train_loss': total_train_loss / iter_idx})
        train_steps += 1

        # Log the training loss to wandb
        if global_rank==0 and M3_WANDB_LOG:
            wandb.log({"train_loss": total_train_loss / iter_idx}, step=train_steps)

        iter_idx += 1
        if iter_idx % 1000 == 0:
            clear_unused_tensors()

    return total_train_loss / (iter_idx-1)

# do one epoch for eval
def eval_epoch():
    tqdm_eval_set = tqdm(eval_set)
    total_eval_loss = 0
    iter_idx = 1
    model.eval()

    # Evaluate data for one epoch
    for batch in tqdm_eval_set:
        with torch.no_grad():
            loss = process_one_batch(batch)

        total_eval_loss += loss.item()
        tqdm_eval_set.set_postfix({str(global_rank)+'_eval_loss': total_eval_loss / iter_idx})
        iter_idx += 1

    return total_eval_loss / (iter_idx-1)

# train and eval
if __name__ == "__main__":

    # Set up distributed training
    world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else 0
    local_rank = int(os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else 0

    if world_size > 1:
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
        dist.init_process_group(backend='nccl')
    else:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    if M3_DETERMINISTIC:
        seed = 42 + global_rank
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    encoder_config = BertConfig(vocab_size=1,
                                hidden_size=M3_HIDDEN_SIZE,
                                num_hidden_layers=PATCH_NUM_LAYERS,
                                num_attention_heads=M3_HIDDEN_SIZE//64,
                                intermediate_size=M3_HIDDEN_SIZE*4,
                                max_position_embeddings=PATCH_LENGTH)
    decoder_config = GPT2Config(vocab_size=128,
                                n_positions=PATCH_SIZE,
                                n_embd=M3_HIDDEN_SIZE,
                                n_layer=TOKEN_NUM_LAYERS,
                                n_head=M3_HIDDEN_SIZE//64,
                                n_inner=M3_HIDDEN_SIZE*4)
    model = M3Model(encoder_config, decoder_config)
    model = model.to(device)

    # print parameter number
    print("Parameter Number: "+str(sum(p.numel() for p in model.parameters() if p.requires_grad)))

    if world_size > 1:
        model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)

    scaler = GradScaler()
    optimizer = torch.optim.AdamW(model.parameters(), lr=M3_LEARNING_RATE)

    if M3_WANDB_LOG and global_rank==0:
        # Initialize wandb
        if WANDB_KEY:
            wandb.login(key=WANDB_KEY)
        wandb.init(project="m3",
                   name=M3_WEIGHTS_PATH.replace("weights_", "").replace(".pth", ""))

    # load filenames under train and eval folder
    train_files = list_files_in_directory(TRAIN_FOLDERS)
    eval_files = list_files_in_directory(EVAL_FOLDERS)

    if len(eval_files)==0:
        train_files, eval_files = split_data(train_files)

    train_batch_nums = int(len(train_files) / M3_BATCH_SIZE)
    eval_batch_nums = int(len(eval_files) / M3_BATCH_SIZE)

    train_files = train_files[:train_batch_nums*M3_BATCH_SIZE]
    eval_files = eval_files[:eval_batch_nums*M3_BATCH_SIZE]

    train_set = M3Dataset(train_files, 'train')
    eval_set = M3Dataset(eval_files, 'eval')

    train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=global_rank)
    eval_sampler = DistributedSampler(eval_set, num_replicas=world_size, rank=global_rank)

    train_set = DataLoader(train_set, batch_size=M3_BATCH_SIZE, collate_fn=collate_batch, sampler=train_sampler, shuffle=(train_sampler is None))
    eval_set = DataLoader(eval_set, batch_size=M3_BATCH_SIZE, collate_fn=collate_batch, sampler=eval_sampler, shuffle=(eval_sampler is None))

    lr_scheduler = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=1000)

    if M3_LOAD_CKPT and os.path.exists(M3_WEIGHTS_PATH):
        # Load checkpoint to CPU
        checkpoint = torch.load(M3_WEIGHTS_PATH, map_location='cpu', weights_only=True)

        # Here, model is assumed to be on GPU
        # Load the state dict into a CPU clone first, then copy it back into the live model
        if torch.cuda.device_count() > 1:
            # With a DDP-wrapped model, load into model.module instead
            cpu_model = deepcopy(model.module)
            cpu_model.load_state_dict(checkpoint['model'])
            model.module.load_state_dict(cpu_model.state_dict())
        else:
            # Load to a CPU clone of the model, then load back
            cpu_model = deepcopy(model)
            cpu_model.load_state_dict(checkpoint['model'])
            model.load_state_dict(cpu_model.state_dict())
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_sched'])
        pre_epoch = checkpoint['epoch']
        best_epoch = checkpoint['best_epoch']
        min_eval_loss = checkpoint['min_eval_loss']
        print(f"Successfully Loaded Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}")
        checkpoint = None

    else:
        pre_epoch = 0
        best_epoch = 0
        min_eval_loss = float('inf')

        model = model.to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=M3_LEARNING_RATE)

    for epoch in range(1+pre_epoch, M3_NUM_EPOCH+1):
        train_sampler.set_epoch(epoch)
        eval_sampler.set_epoch(epoch)
        print('-' * 21 + "Epoch " + str(epoch) + '-' * 21)
        train_loss = train_epoch(epoch)
        eval_loss = eval_epoch()
        if global_rank==0:
            with open(M3_LOGS_PATH, 'a') as f:
                f.write("Epoch " + str(epoch) + "\ntrain_loss: " + str(train_loss) + "\neval_loss: " + str(eval_loss) + "\ntime: " + time.asctime(time.localtime(time.time())) + "\n\n")
            if eval_loss < min_eval_loss:
                best_epoch = epoch
                min_eval_loss = eval_loss
                checkpoint = {
                    'model': model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_sched': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'best_epoch': best_epoch,
                    'min_eval_loss': min_eval_loss
                }
                torch.save(checkpoint, M3_WEIGHTS_PATH)
            checkpoint = {
                'model': model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_sched': lr_scheduler.state_dict(),
                'epoch': epoch,
                'best_epoch': best_epoch,
                'min_eval_loss': min_eval_loss
            }
            torch.save(checkpoint, "latest_"+M3_WEIGHTS_PATH)

        if world_size > 1:
            dist.barrier()

    if global_rank==0:
        print("Best Eval Epoch : "+str(best_epoch))
        print("Min Eval Loss : "+str(min_eval_loss))

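Note: a minimal sketch for inspecting a checkpoint written by this script (everything except the state dicts is a plain Python scalar, so weights_only loading works):

import torch
from config import M3_WEIGHTS_PATH

ckpt = torch.load(M3_WEIGHTS_PATH, map_location="cpu", weights_only=True)
print("epoch:", ckpt["epoch"], "| best epoch:", ckpt["best_epoch"],
      "| min eval loss:", ckpt["min_eval_loss"])
print(sum(v.numel() for v in ckpt["model"].values()), "parameters in the saved model state_dict")
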
code/utils.py
ADDED
@@ -0,0 +1,483 @@
import re
import os
import math
import torch
import random
from config import *
from unidecode import unidecode
from torch.nn import functional as F
from transformers import AutoModel, BertModel, GPT2LMHeadModel, PreTrainedModel, GPT2Config

try:
    import torch.distributed.nn
    from torch import distributed as dist

    has_distributed = True
except ImportError:
    has_distributed = False

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None

class ClipLoss(torch.nn.Module):

    def __init__(
            self,
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod

        # cache state
        self.prev_num_logits = 0
        self.labels = {}

    def gather_features(
            self,
            image_features,
            text_features,
            local_loss=False,
            gather_with_grad=False,
            rank=0,
            world_size=1,
            use_horovod=False
    ):
        assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
        if use_horovod:
            assert hvd is not None, 'Please install horovod'
            if gather_with_grad:
                all_image_features = hvd.allgather(image_features)
                all_text_features = hvd.allgather(text_features)
            else:
                with torch.no_grad():
                    all_image_features = hvd.allgather(image_features)
                    all_text_features = hvd.allgather(text_features)
                if not local_loss:
                    # ensure grads for local rank when all_* features don't have a gradient
                    gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
                    gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
                    gathered_image_features[rank] = image_features
                    gathered_text_features[rank] = text_features
                    all_image_features = torch.cat(gathered_image_features, dim=0)
                    all_text_features = torch.cat(gathered_text_features, dim=0)
        else:
            # We gather tensors from all gpus
            if gather_with_grad:
                all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
                all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
            else:
                gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
                gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
                dist.all_gather(gathered_image_features, image_features)
                dist.all_gather(gathered_text_features, text_features)
                if not local_loss:
                    # ensure grads for local rank when all_* features don't have a gradient
                    gathered_image_features[rank] = image_features
                    gathered_text_features[rank] = text_features
                all_image_features = torch.cat(gathered_image_features, dim=0)
                all_text_features = torch.cat(gathered_text_features, dim=0)

        return all_image_features, all_text_features

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        # calculated ground-truth and cache if enabled
        if self.prev_num_logits != num_logits or device not in self.labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)
            if self.world_size > 1 and self.local_loss:
                labels = labels + num_logits * self.rank
            if self.cache_labels:
                self.labels[device] = labels
                self.prev_num_logits = num_logits
        else:
            labels = self.labels[device]
        return labels

    def get_logits(self, image_features, text_features, logit_scale):
        if self.world_size > 1:
            all_image_features, all_text_features = self.gather_features(
                image_features, text_features,
                self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)

            if self.local_loss:
                logits_per_image = logit_scale * image_features @ all_text_features.T
                logits_per_text = logit_scale * text_features @ all_image_features.T
            else:
                logits_per_image = logit_scale * all_image_features @ all_text_features.T
                logits_per_text = logits_per_image.T
        else:
            logits_per_image = logit_scale * image_features @ text_features.T
            logits_per_text = logit_scale * text_features @ image_features.T

        return logits_per_image, logits_per_text

    def forward(self, image_features, text_features, logit_scale, output_dict=False):
        device = image_features.device
        logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)

        labels = self.get_ground_truth(device, logits_per_image.shape[0])

        total_loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
        ) / 2

        return {"contrastive_loss": total_loss} if output_dict else total_loss

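Note: in math form, ClipLoss.forward computes the symmetric contrastive (InfoNCE) objective over the N pairs in the (optionally all-gathered) batch, with logit scale \(s\):

\[
\mathcal{L} \;=\; \tfrac{1}{2}\Big(\mathrm{CE}\big(s\,T M^{\top},\, y\big) \;+\; \mathrm{CE}\big(s\,M T^{\top},\, y\big)\Big), \qquad y_k = k,
\]

where \(T\) and \(M\) are the stacked text and music embeddings, \(\mathrm{CE}\) is row-wise cross-entropy, and get_ground_truth supplies the diagonal labels \(y\) (offset by rank when local_loss is used).
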
class M3Patchilizer:
    def __init__(self):
        self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
        self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')'
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.mask_token_id = 3

    def split_bars(self, body):
        bars = re.split(self.regexPattern, ''.join(body))
        bars = list(filter(None, bars))  # remove empty strings
        if bars[0] in self.delimiters:
            bars[1] = bars[0] + bars[1]
            bars = bars[1:]
        bars = [bars[i * 2] + bars[i * 2 + 1] for i in range(len(bars) // 2)]
        return bars

    def bar2patch(self, bar, patch_size=PATCH_SIZE):
        patch = [self.bos_token_id] + [ord(c) for c in bar] + [self.eos_token_id]
        patch = patch[:patch_size]
        patch += [self.pad_token_id] * (patch_size - len(patch))
        return patch

    def patch2bar(self, patch):
        return ''.join(chr(idx) if idx > self.mask_token_id else '' for idx in patch)

    def encode(self,
               item,
               patch_size=PATCH_SIZE,
               add_special_patches=False,
               truncate=False,
               random_truncate=False):

        item = unidecode(item)
        lines = re.findall(r'.*?\n|.*$', item)
        lines = list(filter(None, lines))  # remove empty lines

        patches = []

        if lines[0].split(" ")[0] == "ticks_per_beat":
            patch = ""
            for line in lines:
                if patch.startswith(line.split(" ")[0]) and (len(patch) + len(" ".join(line.split(" ")[1:])) <= patch_size-2):
                    patch = patch[:-1] + "\t" + " ".join(line.split(" ")[1:])
                else:
                    if patch:
                        patches.append(patch)
                    patch = line
            if patch!="":
                patches.append(patch)
        else:
            for line in lines:
                if len(line) > 1 and ((line[0].isalpha() and line[1] == ':') or line.startswith('%%')):
                    patches.append(line)
                else:
                    bars = self.split_bars(line)
                    if bars:
                        bars[-1] += '\n'
                        patches.extend(bars)

        if add_special_patches:
            bos_patch = chr(self.bos_token_id) * patch_size
            eos_patch = chr(self.eos_token_id) * patch_size
            patches = [bos_patch] + patches + [eos_patch]

        if len(patches) > PATCH_LENGTH and truncate:
            choices = ["head", "tail", "middle"]
            choice = random.choice(choices)
            if choice=="head" or random_truncate==False:
                patches = patches[:PATCH_LENGTH]
            elif choice=="tail":
                patches = patches[-PATCH_LENGTH:]
            else:
                start = random.randint(1, len(patches)-PATCH_LENGTH)
                patches = patches[start:start+PATCH_LENGTH]

        patches = [self.bar2patch(patch) for patch in patches]

        return patches

    def decode(self, patches):
        return ''.join(self.patch2bar(patch) for patch in patches)

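Note: a quick round-trip sketch of the patch representation (assumes config.py is importable; the ABC snippet is made up). Headers become one patch each, bars are split on the delimiters above, and every patch is a PATCH_SIZE list of byte codes:

from utils import M3Patchilizer

patchilizer = M3Patchilizer()
abc = "X:1\nM:4/4\nK:C\nCDEF GABc | cBAG FEDC |]\n"

patches = patchilizer.encode(abc, add_special_patches=True)
print(len(patches), "patches of", len(patches[0]), "byte tokens each")

# decode() drops the special and pad tokens, so the original text is recovered
print(patchilizer.decode(patches))
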
class M3PatchEncoder(PreTrainedModel):
    def __init__(self, config):
        super(M3PatchEncoder, self).__init__(config)
        self.patch_embedding = torch.nn.Linear(PATCH_SIZE*128, M3_HIDDEN_SIZE)
        torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
        self.base = BertModel(config=config)
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.mask_token_id = 3

    def forward(self,
                input_patches,    # [batch_size, seq_length, hidden_size]
                input_masks):     # [batch_size, seq_length]
        # Transform input_patches into embeddings
        input_patches = torch.nn.functional.one_hot(input_patches, num_classes=128)
        input_patches = input_patches.reshape(len(input_patches), -1, PATCH_SIZE*128).type(torch.FloatTensor)
        input_patches = self.patch_embedding(input_patches.to(self.device))

        # Apply BERT model to input_patches and input_masks
        return self.base(inputs_embeds=input_patches, attention_mask=input_masks)

class M3TokenDecoder(PreTrainedModel):
    def __init__(self, config):
        super(M3TokenDecoder, self).__init__(config)
        self.base = GPT2LMHeadModel(config=config)
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.mask_token_id = 3

    def forward(self,
                patch_features,   # [batch_size, hidden_size]
                target_patches):  # [batch_size, seq_length]
        # get input embeddings
        inputs_embeds = torch.nn.functional.embedding(target_patches, self.base.transformer.wte.weight)

        # concatenate the encoded patches with the input embeddings
        inputs_embeds = torch.cat((patch_features.unsqueeze(1), inputs_embeds[:,1:,:]), dim=1)

        # preparing the labels for model training
        target_masks = target_patches == self.pad_token_id
        target_patches = target_patches.clone().masked_fill_(target_masks, -100)

        # get the attention mask
        target_masks = ~target_masks
        target_masks = target_masks.type(torch.int)

        return self.base(inputs_embeds=inputs_embeds,
                         attention_mask=target_masks,
                         labels=target_patches)

    def generate(self,
                 patch_feature,
                 tokens):
        # reshape the patch_feature and tokens
        patch_feature = patch_feature.reshape(1, 1, -1)
        tokens = tokens.reshape(1, -1)

        # get input embeddings
        tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)

        # concatenate the encoded patches with the input embeddings
        tokens = torch.cat((patch_feature, tokens[:,1:,:]), dim=1)

        # get the outputs from the model
        outputs = self.base(inputs_embeds=tokens)

        # get the probabilities of the next token
        probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)

        return probs.detach().cpu().numpy()

class M3Model(PreTrainedModel):
    def __init__(self, encoder_config, decoder_config):
        super(M3Model, self).__init__(encoder_config)
        self.encoder = M3PatchEncoder(encoder_config)
        self.decoder = M3TokenDecoder(decoder_config)
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.mask_token_id = 3

    def forward(self,
                input_patches,      # [batch_size, seq_length, hidden_size]
                input_masks,        # [batch_size, seq_length]
                selected_indices,   # [batch_size, seq_length]
                target_patches):    # [batch_size, seq_length, hidden_size]
        input_patches = input_patches.reshape(len(input_patches), -1, PATCH_SIZE).to(self.device)
        input_masks = input_masks.to(self.device)
        selected_indices = selected_indices.to(self.device)
        target_patches = target_patches.reshape(len(target_patches), -1, PATCH_SIZE).to(self.device)

        # Pass the input_patches and input_masks through the encoder
        outputs = self.encoder(input_patches, input_masks)["last_hidden_state"]

        # Use selected_indices to form target_patches
        target_patches = target_patches[selected_indices.bool()]
        patch_features = outputs[selected_indices.bool()]

        # Pass patch_features and target_patches through the decoder
        return self.decoder(patch_features, target_patches)

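Note: putting the three modules together, M3Model implements a masked-patch reconstruction objective. The bar-level encoder produces one feature \(h_j\) per patch; for every selected patch \(j\), the character-level GPT-2 decoder is conditioned on \(h_j\) (as its first input embedding) and trained with cross-entropy over that patch's byte tokens:

\[
\mathcal{L}_{\mathrm{M3}} \;=\; -\sum_{j \in \mathcal{M}} \sum_{k} \log p_\theta\big(x_{j,k} \mid x_{j,<k},\, h_j\big),
\]

where \(\mathcal{M}\) is the set of selected (corrupted) patches and \(x_{j,k}\) the \(k\)-th byte of patch \(j\); pad positions are excluded via the -100 labels set in M3TokenDecoder.forward.
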
class CLaMP2Model(PreTrainedModel):
    def __init__(self,
                 music_config,
                 global_rank=None,
                 world_size=None,
                 text_model_name=TEXT_MODEL_NAME,
                 hidden_size=CLAMP2_HIDDEN_SIZE,
                 load_m3=CLAMP2_LOAD_M3):
        super(CLaMP2Model, self).__init__(music_config)

        self.text_model = AutoModel.from_pretrained(text_model_name)  # Load the text model
        self.text_proj = torch.nn.Linear(self.text_model.config.hidden_size, hidden_size)  # Linear layer for text projections
        torch.nn.init.normal_(self.text_proj.weight, std=0.02)  # Initialize weights with normal distribution

        self.music_model = M3PatchEncoder(music_config)  # Initialize the music model
        self.music_proj = torch.nn.Linear(M3_HIDDEN_SIZE, hidden_size)  # Linear layer for music projections
        torch.nn.init.normal_(self.music_proj.weight, std=0.02)  # Initialize weights with normal distribution

        if global_rank is None or world_size is None:
            global_rank = 0
            world_size = 1

        self.loss_fn = ClipLoss(local_loss=False,
                                gather_with_grad=True,
                                cache_labels=False,
                                rank=global_rank,
                                world_size=world_size,
                                use_horovod=False)

        if load_m3 and os.path.exists(M3_WEIGHTS_PATH):
            checkpoint = torch.load(M3_WEIGHTS_PATH, map_location='cpu', weights_only=True)
            decoder_config = GPT2Config(vocab_size=128,
                                        n_positions=PATCH_SIZE,
                                        n_embd=M3_HIDDEN_SIZE,
                                        n_layer=TOKEN_NUM_LAYERS,
                                        n_head=M3_HIDDEN_SIZE//64,
                                        n_inner=M3_HIDDEN_SIZE*4)
            model = M3Model(music_config, decoder_config)
            model.load_state_dict(checkpoint['model'])
            self.music_model = model.encoder
            model = None
            print(f"Successfully Loaded M3 Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}")

    def avg_pooling(self, input_features, input_masks):
        input_masks = input_masks.unsqueeze(-1).to(self.device)  # add a dimension to match the feature dimension
        input_features = input_features * input_masks  # apply mask to input_features
        avg_pool = input_features.sum(dim=1) / input_masks.sum(dim=1)  # calculate average pooling

        return avg_pool

    def get_text_features(self,
                          text_inputs,
                          text_masks,
                          get_normalized=False):
        text_features = self.text_model(text_inputs.to(self.device),
                                        attention_mask=text_masks.to(self.device))['last_hidden_state']

        if get_normalized:
            text_features = self.avg_pooling(text_features, text_masks)
            text_features = self.text_proj(text_features)

        return text_features

    def get_music_features(self,
                           music_inputs,
                           music_masks,
                           get_normalized=False):
        music_features = self.music_model(music_inputs.to(self.device),
                                          music_masks.to(self.device))['last_hidden_state']

        if get_normalized:
            music_features = self.avg_pooling(music_features, music_masks)
            music_features = self.music_proj(music_features)

        return music_features

    def forward(self,
                text_inputs,    # [batch_size, seq_length]
                text_masks,     # [batch_size, seq_length]
                music_inputs,   # [batch_size, seq_length, hidden_size]
                music_masks):   # [batch_size, seq_length]
        # Compute the text features
        text_features = self.get_text_features(text_inputs, text_masks, get_normalized=True)

        # Compute the music features
        music_features = self.get_music_features(music_inputs, music_masks, get_normalized=True)

        return self.loss_fn(text_features,
                            music_features,
                            LOGIT_SCALE,
                            output_dict=False)

417 |
+
def split_data(data, eval_ratio=EVAL_SPLIT):
|
418 |
+
random.shuffle(data)
|
419 |
+
split_idx = int(len(data)*eval_ratio)
|
420 |
+
eval_set = data[:split_idx]
|
421 |
+
train_set = data[split_idx:]
|
422 |
+
return train_set, eval_set
|
423 |
+
|
424 |
+
def mask_patches(target_patches, patchilizer, mode):
|
425 |
+
indices = list(range(len(target_patches)))
|
426 |
+
random.shuffle(indices)
|
427 |
+
selected_indices = indices[:math.ceil(M3_MASK_RATIO*len(indices))]
|
428 |
+
sorted_indices = sorted(selected_indices)
|
429 |
+
input_patches = torch.tensor(target_patches)
|
430 |
+
|
431 |
+
if mode=="eval":
|
432 |
+
choice = "original"
|
433 |
+
else:
|
434 |
+
choice = random.choices(["mask", "shuffle", "original"], weights=[0.8, 0.1, 0.1])[0]
|
435 |
+
|
436 |
+
if choice=="mask":
|
437 |
+
input_patches[sorted_indices] = torch.tensor([patchilizer.mask_token_id]*PATCH_SIZE)
|
438 |
+
elif choice=="shuffle":
|
439 |
+
for idx in sorted_indices:
|
440 |
+
patch = input_patches[idx]
|
441 |
+
try:
|
442 |
+
index_eos = (patch == patchilizer.eos_token_id).nonzero().item()
|
443 |
+
except:
|
444 |
+
index_eos = len(patch)
|
445 |
+
|
446 |
+
indices = list(range(1, index_eos))
|
447 |
+
random.shuffle(indices)
|
448 |
+
indices = [0] + indices + list(range(index_eos, len(patch)))
|
449 |
+
input_patches[idx] = patch[indices]
|
450 |
+
|
451 |
+
selected_indices = torch.zeros(len(target_patches))
|
452 |
+
selected_indices[sorted_indices] = 1.
|
453 |
+
|
454 |
+
return input_patches, selected_indices
|
455 |
+
|
456 |
+
def remove_instrument_info(item):
|
457 |
+
# remove instrument information from symbolic music
|
458 |
+
lines = re.findall(r'.*?\n|.*$', item)
|
459 |
+
lines = list(filter(None, lines))
|
460 |
+
if lines[0].split(" ")[0] == "ticks_per_beat":
|
461 |
+
type = "mtf"
|
462 |
+
else:
|
463 |
+
type = "abc"
|
464 |
+
|
465 |
+
cleaned_lines = []
|
466 |
+
for line in lines:
|
467 |
+
if type=="abc" and line.startswith("V:"):
|
468 |
+
# find the position of " nm=" or " snm="
|
469 |
+
nm_pos = line.find(" nm=")
|
470 |
+
snm_pos = line.find(" snm=")
|
471 |
+
# keep the part before " nm=" or " snm="
|
472 |
+
if nm_pos != -1:
|
473 |
+
line = line[:nm_pos]
|
474 |
+
elif snm_pos != -1:
|
475 |
+
line = line[:snm_pos]
|
476 |
+
if nm_pos != -1 or snm_pos != -1:
|
477 |
+
line += "\n"
|
478 |
+
elif type=="mtf" and line.startswith("program_change"):
|
479 |
+
line = " ".join(line.split(" ")[:-1]) + " 0\n"
|
480 |
+
|
481 |
+
cleaned_lines.append(line)
|
482 |
+
|
483 |
+
return ''.join(cleaned_lines)
|
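As an aside on the masking routine above: `mask_patches` selects `M3_MASK_RATIO` of the patches and, for each call, applies one of mask/shuffle/original to all of them (weights 0.8/0.1/0.1). Below is a hypothetical smoke test, assuming the config constants of this file (`PATCH_SIZE`, `M3_MASK_RATIO`) are in scope and using a stand-in patchilizer that only exposes the two token ids the function reads:

```python
class DummyPatchilizer:   # hypothetical stand-in, not the real patchilizer
    mask_token_id = 3     # matches the ids hard-coded in M3Model
    eos_token_id = 2

# eight fake patches: bos (1), filler bytes, eos (2)
target = [[1] + [10] * (PATCH_SIZE - 2) + [2] for _ in range(8)]
inputs, selected = mask_patches(target, DummyPatchilizer(), mode="train")
print(inputs.shape)  # torch.Size([8, PATCH_SIZE])
print(selected)      # 1.0 at the positions chosen for corruption
```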
environment.yml
ADDED
@@ -0,0 +1,204 @@
name: clamp2
channels:
  - pytorch
  - nvidia
  - defaults
dependencies:
  - blas=1.0=mkl
  - brotli-python=1.0.9=py310hd77b12b_8
  - bzip2=1.0.8=h2bbff1b_6
  - ca-certificates=2024.7.2=haa95532_0
  - cuda-cccl=12.6.37=0
  - cuda-cccl_win-64=12.6.37=0
  - cuda-cudart=11.8.89=0
  - cuda-cudart-dev=11.8.89=0
  - cuda-cupti=11.8.87=0
  - cuda-libraries=11.8.0=0
  - cuda-libraries-dev=11.8.0=0
  - cuda-nvrtc=11.8.89=0
  - cuda-nvrtc-dev=11.8.89=0
  - cuda-nvtx=11.8.86=0
  - cuda-profiler-api=12.6.68=0
  - cuda-runtime=11.8.0=0
  - cuda-version=12.6=3
  - freetype=2.12.1=ha860e81_0
  - gmpy2=2.1.2=py310h7f96b67_0
  - intel-openmp=2023.1.0=h59b6b97_46320
  - jinja2=3.1.4=py310haa95532_0
  - jpeg=9e=h827c3e9_3
  - lcms2=2.12=h83e58a3_0
  - lerc=3.0=hd77b12b_0
  - libcublas=11.11.3.6=0
  - libcublas-dev=11.11.3.6=0
  - libcufft=10.9.0.58=0
  - libcufft-dev=10.9.0.58=0
  - libcurand=10.3.7.68=0
  - libcurand-dev=10.3.7.68=0
  - libcusolver=11.4.1.48=0
  - libcusolver-dev=11.4.1.48=0
  - libcusparse=11.7.5.86=0
  - libcusparse-dev=11.7.5.86=0
  - libdeflate=1.17=h2bbff1b_1
  - libffi=3.4.4=hd77b12b_1
  - libjpeg-turbo=2.0.0=h196d8e1_0
  - libnpp=11.8.0.86=0
  - libnpp-dev=11.8.0.86=0
  - libnvjpeg=11.9.0.86=0
  - libnvjpeg-dev=11.9.0.86=0
  - libpng=1.6.39=h8cc25b3_0
  - libtiff=4.5.1=hd77b12b_0
  - libuv=1.48.0=h827c3e9_0
  - libwebp-base=1.3.2=h2bbff1b_0
  - lz4-c=1.9.4=h2bbff1b_1
  - mkl=2023.1.0=h6b88ed4_46358
  - mkl-service=2.4.0=py310h2bbff1b_1
  - mkl_fft=1.3.8=py310h2bbff1b_0
  - mkl_random=1.2.4=py310h59b6b97_0
  - mpc=1.1.0=h7edee0f_1
  - mpfr=4.0.2=h62dcd97_1
  - mpir=3.0.0=hec2e145_1
  - mpmath=1.3.0=py310haa95532_0
  - networkx=3.3=py310haa95532_0
  - numpy=1.26.4=py310h055cbcc_0
  - numpy-base=1.26.4=py310h65a83cf_0
  - openjpeg=2.5.2=hae555c5_0
  - openssl=3.0.14=h827c3e9_0
  - pip=24.2=py310haa95532_0
  - pysocks=1.7.1=py310haa95532_0
  - python=3.10.14=he1021f5_1
  - pytorch=2.4.0=py3.10_cuda11.8_cudnn9_0
  - pytorch-cuda=11.8=h24eeafa_5
  - pytorch-mutex=1.0=cuda
  - pyyaml=6.0.1=py310h2bbff1b_0
  - requests=2.32.3=py310haa95532_0
  - setuptools=72.1.0=py310haa95532_0
  - sqlite=3.45.3=h2bbff1b_0
  - sympy=1.13.2=py310haa95532_0
  - tbb=2021.8.0=h59b6b97_0
  - tk=8.6.14=h0416ee5_0
  - typing_extensions=4.11.0=py310haa95532_0
  - tzdata=2024a=h04d1e81_0
  - vc=14.40=h2eaa2aa_0
  - vs2015_runtime=14.40.33807=h98bb1dd_0
  - wheel=0.43.0=py310haa95532_0
  - win_inet_pton=1.1.0=py310haa95532_0
  - xz=5.4.6=h8cc25b3_1
  - yaml=0.2.5=he774522_0
  - zlib=1.2.13=h8cc25b3_1
  - zstd=1.5.5=hd43e919_2
  - pip:
      - abctoolkit==0.0.4
      - accelerate==0.34.0
      - aiohappyeyeballs==2.4.0
      - aiohttp==3.10.5
      - aiosignal==1.3.1
      - annotated-types==0.7.0
      - anyio==4.6.2.post1
      - async-timeout==4.0.3
      - attrs==24.2.0
      - audioread==3.0.1
      - certifi==2023.7.22
      - cffi==1.17.0
      - chardet==5.2.0
      - charset-normalizer==3.2.0
      - click==8.1.7
      - colorama==0.4.6
      - coloredlogs==15.0.1
      - cycler==0.11.0
      - datasets==2.21.0
      - decorator==5.1.1
      - dill==0.3.8
      - distro==1.9.0
      - docker-pycreds==0.4.0
      - exceptiongroup==1.2.2
      - filelock==3.12.2
      - fonttools==4.38.0
      - frozenlist==1.4.1
      - fsspec==2024.6.1
      - gitdb==4.0.11
      - gitpython==3.1.43
      - h11==0.14.0
      - httpcore==1.0.6
      - httpx==0.27.2
      - huggingface-hub==0.24.6
      - humanfriendly==10.0
      - idna==3.4
      - importlib-metadata==6.7.0
      - jellyfish==1.0.0
      - jiter==0.6.1
      - joblib==1.3.2
      - jsonpickle==3.0.2
      - kiwisolver==1.4.4
      - langcodes==3.4.0
      - langdetect==1.0.9
      - langid==1.1.6
      - language-data==1.2.0
      - lazy-loader==0.4
      - levenshtein==0.25.1
      - librosa==0.10.1
      - llvmlite==0.43.0
      - lxml==5.3.0
      - marisa-trie==1.2.0
      - markupsafe==2.1.5
      - matplotlib==3.5.3
      - mido==1.3.0
      - more-itertools==9.1.0
      - msgpack==1.0.8
      - multidict==6.0.5
      - multiprocess==0.70.16
      - music21==7.3.3
      - nltk==3.8.1
      - numba==0.60.0
      - openai==1.51.2
      - optimum==1.21.4
      - packaging==23.1
      - pandas==1.3.5
      - pillow==9.5.0
      - platformdirs==4.2.2
      - pooch==1.8.2
      - portalocker==2.10.1
      - protobuf==5.28.0
      - psutil==6.0.0
      - pyarrow==17.0.0
      - pycparser==2.22
      - pydantic==2.9.2
      - pydantic-core==2.23.4
      - pydub==0.25.1
      - pyparsing==3.1.1
      - pyreadline3==3.4.1
      - python-dateutil==2.8.2
      - pytz==2023.3
      - pywin32==306
      - rapidfuzz==3.9.7
      - rarfile==4.1
      - regex==2023.8.8
      - sacrebleu==2.4.3
      - sacremoses==0.0.53
      - safetensors==0.4.4
      - samplings==0.1.7
      - scikit-learn==1.5.1
      - scipy==1.14.1
      - sentencepiece==0.2.0
      - sentry-sdk==2.13.0
      - setproctitle==1.3.3
      - six==1.16.0
      - smmap==5.0.1
      - sniffio==1.3.1
      - soundfile==0.12.1
      - soxr==0.5.0.post1
      - tabulate==0.9.0
      - threadpoolctl==3.5.0
      - tokenizers==0.19.1
      - torch==2.4.0
      - torchaudio==2.4.0
      - torchvision==0.19.0
      - tqdm==4.66.5
      - transformers==4.40.0
      - typing-extensions==4.12.2
      - unidecode==1.3.6
      - urllib3==2.0.4
      - wandb==0.17.8
      - webcolors==1.13
      - xxhash==3.5.0
      - yarl==1.9.7
      - zipp==3.15.0
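The environment above can be recreated with the standard conda CLI; note that the pinned builds (e.g., `pywin32`, `vc`, `vs2015_runtime`) target Windows, so recreating it on other platforms may require adjusted pins.

```bash
conda env create -f environment.yml
conda activate clamp2
```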
music_classification/README.md
ADDED
@@ -0,0 +1,46 @@
# Music Classification Codebase

## Overview
The linear probe is a lightweight classifier trained on top of feature representations for supervised learning tasks. This codebase includes scripts for training a linear classification model and for classifying new feature data. The features can be extracted with the M3 or CLaMP 2 models, provided the time dimension is preserved and the features are **not normalized**. Below is a description of the scripts contained in the `music_classification/` folder.

## Repository Structure
The `music_classification/` folder contains the following scripts:

### 1. `config.py`
This script defines the configuration for linear probe training and inference, specifying training data paths and parameters such as the learning rate, number of epochs, and hidden size.

### 2. `inference_cls.py`
This script classifies feature vectors using a pre-trained linear probe model.

#### JSON Output Format
The resulting JSON file contains a dictionary with the following structure:
```json
{
    "path/to/feature1.npy": "class_A",
    "path/to/feature2.npy": "class_B",
    "path/to/feature3.npy": "class_A"
}
```
- **Key**: The path to the input feature file (e.g., `feature1.npy`).
- **Value**: The predicted class label assigned by the linear probe model (e.g., `class_A`).

#### Usage
```bash
python inference_cls.py <feature_folder> <output_file>
```
- `feature_folder`: Directory containing input feature files (in `.npy` format).
- `output_file`: File path to save the classification results (in JSON format).

### 3. `train_cls.py`
This script trains the linear classification model.

#### Usage
```bash
python train_cls.py
```

### 4. `utils.py`
This utility script defines the architecture of the linear classification model.

## Naming Convention
All `.npy` files used in this codebase must follow the naming convention `label_filename.npy`, where the filename itself must not contain any underscores (`_`); an illustration of how the label is recovered follows below.
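A small sketch of that convention (the path below is made up): the training and inference scripts recover the label by splitting the basename on the first underscore.

```python
import os

def label_from_path(path):
    # everything before the first "_" in the basename is the class label,
    # which is why the rest of the filename must not contain underscores
    return os.path.basename(path).split('_')[0]

print(label_from_path("features/jazz_piece042.npy"))  # -> "jazz"
```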
music_classification/config.py
ADDED
@@ -0,0 +1,26 @@
# Configuration for linear probe training and classification
TRAIN_FOLDERS = [
    "<path_to_training_data>"  # Directory containing training data
]

EVAL_FOLDERS = [
    ""  # (Optional) Directory containing evaluation data
]

EVAL_SPLIT = 0.2  # Fraction of training data to use for evaluation

# Weights and Biases configuration
WANDB_KEY = "<your_wandb_key>"  # Set WANDB_LOG = False below if you have no API key for Weights and Biases logging

# Model Configuration
INPUT_HIDDEN_SIZE = 768  # Hidden size of the input features
HIDDEN_SIZE = 768  # Hidden size of the linear probe
NUM_EPOCHS = 1000  # Max number of epochs to train (early stopping can terminate earlier)
LEARNING_RATE = 1e-5  # Optimizer learning rate
BALANCED_TRAINING = False  # Set to True to balance labels in the training data
WANDB_LOG = False  # Set to True to log training metrics to WANDB

# Paths Configuration
last_folder_name = TRAIN_FOLDERS[-1].split('/')[-1]
WEIGHTS_PATH = f"weights-{last_folder_name}.pth"  # Weights file path
LOGS_PATH = f"logs-{last_folder_name}.txt"  # Log file path
music_classification/inference_cls.py
ADDED
@@ -0,0 +1,71 @@
import os
import json
import torch
import random
import numpy as np
from utils import *
from tqdm import tqdm
from samplings import *
import argparse

def list_files_in_directory(directories, extensions=["npy"]):
    file_list = []

    for directory in directories:
        for root, dirs, files in os.walk(directory):
            for file in files:
                if any(file.endswith(ext) for ext in extensions):
                    file_path = os.path.join(root, file)
                    file_list.append(file_path)

    return file_list

if __name__ == "__main__":
    # Setup argument parser
    parser = argparse.ArgumentParser(description="Feature extraction and classification with CLaMP2.")
    parser.add_argument("feature_folder", type=str, help="Directory containing input feature files.")
    parser.add_argument("output_file", type=str, help="File to save the classification results. (format: json)")

    # Parse arguments
    args = parser.parse_args()
    feature_folder = args.feature_folder
    output_file = args.output_file

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    checkpoint = torch.load(WEIGHTS_PATH, map_location='cpu')
    print(f"Successfully Loaded Checkpoint from Epoch {checkpoint['epoch']} with acc {checkpoint['max_eval_acc']}")
    label2idx = checkpoint['labels']
    idx2label = {idx: label for label, idx in label2idx.items()}  # Create reverse mapping
    model = LinearClassification(num_classes=len(label2idx))
    model = model.to(device)

    # print parameter number
    print("Parameter Number: "+str(sum(p.numel() for p in model.parameters() if p.requires_grad)))

    model.eval()
    model.load_state_dict(checkpoint['model'])

    # load filenames under the feature folder
    feature_files = list_files_in_directory([feature_folder])
    cls_results = {}

    for filepath in tqdm(feature_files):
        outputs = np.load(filepath)[0]
        outputs = torch.from_numpy(outputs).to(device)
        outputs = outputs.unsqueeze(0)
        cls_list = model(outputs)[0].tolist()
        max_prob = max(cls_list)
        cls_idx = cls_list.index(max_prob)
        cls_label = idx2label[cls_idx]
        cls_results[filepath] = cls_label

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(cls_results, f)
music_classification/train_cls.py
ADDED
@@ -0,0 +1,293 @@
import os
import time
import math
import wandb
import torch
import random
import numpy as np
from utils import *
from config import *
from tqdm import tqdm
from sklearn.metrics import f1_score
from torch.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from transformers import get_constant_schedule_with_warmup
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler

# Set up distributed training
world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else 0
local_rank = int(os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else 0

if world_size > 1:
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    dist.init_process_group(backend='nccl')
else:
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Set random seed
seed = 42 + global_rank
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

batch_size = 1

def collate_batch(input_tensors):
    input_tensors, labels = zip(*input_tensors)
    input_tensors = torch.stack(input_tensors, dim=0)
    labels = torch.stack(labels, dim=0)

    return input_tensors.to(device), labels.to(device)

def list_files_in_directory(directories):
    file_list = []

    for directory in directories:
        for root, dirs, files in os.walk(directory):
            for file in files:
                if file.endswith(".npy"):
                    file_path = os.path.join(root, file)
                    file_list.append(file_path)
    return file_list

class TensorDataset(Dataset):
    def __init__(self, filenames):
        print(f"Loading {len(filenames)} files for classification")
        self.filenames = []
        self.label2idx = {}

        for filename in tqdm(filenames):
            label = os.path.basename(filename).split('_')[0]

            self.filenames.append(filename)
            if label not in self.label2idx:
                self.label2idx[label] = len(self.label2idx)
        print(f"Found {len(self.label2idx)} classes")

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        filename = self.filenames[idx]
        label = os.path.basename(filename).split('_')[0]
        label = self.label2idx[label]

        # load numpy file
        data = np.load(filename)
        data = torch.from_numpy(data)[0]
        label = torch.tensor(label)

        return data, label

class BalancedTensorDataset(Dataset):
    def __init__(self, filenames):
        print(f"Loading {len(filenames)} files for classification")
        self.filenames = filenames
        self.label2idx = {}
        self.label2files = {}

        for filename in tqdm(filenames):
            label = os.path.basename(filename).split('_')[0]
            if label not in self.label2idx:
                self.label2idx[label] = len(self.label2idx)
            if label not in self.label2files:
                self.label2files[label] = []
            self.label2files[label].append(filename)
        print(f"Found {len(self.label2idx)} classes")

        self.min_samples = min(len(files) for files in self.label2files.values())

        self._update_epoch_filenames()

    def _update_epoch_filenames(self):
        self.epoch_filenames = []
        for label, files in self.label2files.items():
            sampled_files = random.sample(files, self.min_samples)
            self.epoch_filenames.extend(sampled_files)

        random.shuffle(self.epoch_filenames)

    def __len__(self):
        return len(self.epoch_filenames)

    def __getitem__(self, idx):
        filename = self.epoch_filenames[idx]
        label = os.path.basename(filename).split('_')[0]
        label = self.label2idx[label]

        data = np.load(filename)
        data = torch.from_numpy(data)[0]
        label = torch.tensor(label)

        return data, label

    def on_epoch_end(self):
        self._update_epoch_filenames()

# load filenames under train and eval folder
train_files = list_files_in_directory(TRAIN_FOLDERS)
eval_files = list_files_in_directory(EVAL_FOLDERS)

if len(eval_files) == 0:
    random.shuffle(train_files)
    eval_files = train_files[:math.ceil(len(train_files)*EVAL_SPLIT)]
    train_files = train_files[math.ceil(len(train_files)*EVAL_SPLIT):]
if BALANCED_TRAINING:
    train_set = BalancedTensorDataset(train_files)
else:
    train_set = TensorDataset(train_files)
eval_set = TensorDataset(eval_files)
eval_set.label2idx = train_set.label2idx

model = LinearClassification(num_classes=len(train_set.label2idx))
model = model.to(device)

# print parameter number
print("Parameter Number: "+str(sum(p.numel() for p in model.parameters() if p.requires_grad)))

if world_size > 1:
    model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)

scaler = GradScaler()
is_autocast = True
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
loss_fn = torch.nn.CrossEntropyLoss()

# call model with a batch of input
def process_one_batch(batch):
    input_tensors, labels = batch
    logits = model(input_tensors)
    loss = loss_fn(logits, labels)
    prediction = torch.argmax(logits, dim=1)
    acc_num = torch.sum(prediction==labels)

    return loss, acc_num, prediction, labels

# do one epoch for training
def train_epoch():
    tqdm_train_set = tqdm(train_set)
    total_train_loss = 0
    total_acc_num = 0
    iter_idx = 1
    model.train()

    for batch in tqdm_train_set:
        if is_autocast:
            with autocast(device_type='cuda'):
                loss, acc_num, prediction, labels = process_one_batch(batch)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss, acc_num, prediction, labels = process_one_batch(batch)
            loss.backward()
            optimizer.step()

        lr_scheduler.step()
        model.zero_grad(set_to_none=True)
        total_train_loss += loss.item()
        total_acc_num += acc_num.item()
        tqdm_train_set.set_postfix({str(global_rank)+'_train_acc': total_acc_num / (iter_idx*batch_size)})
        # Log the training accuracy to wandb
        if global_rank==0 and WANDB_LOG:
            wandb.log({"acc": total_acc_num / (iter_idx*batch_size)})

        iter_idx += 1

    if BALANCED_TRAINING:
        train_set.dataset.on_epoch_end()

    return total_acc_num / ((iter_idx-1)*batch_size)

# do one epoch for eval
def eval_epoch():
    tqdm_eval_set = tqdm(eval_set)
    total_eval_loss = 0
    total_acc_num = 0
    iter_idx = 1
    model.eval()

    all_predictions = []
    all_labels = []

    # Evaluate data for one epoch
    for batch in tqdm_eval_set:
        with torch.no_grad():
            loss, acc_num, prediction, labels = process_one_batch(batch)
            total_eval_loss += loss.item()
            total_acc_num += acc_num.item()

            # Accumulate predictions and labels
            all_predictions.extend(prediction.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

        tqdm_eval_set.set_postfix({str(global_rank)+'_eval_acc': total_acc_num / (iter_idx*batch_size)})
        iter_idx += 1

    # Compute F1 Macro
    f1_macro = f1_score(all_labels, all_predictions, average='macro')
    return total_acc_num / ((iter_idx - 1) * batch_size), f1_macro

# train and eval
if __name__ == "__main__":

    label2idx = train_set.label2idx
    max_eval_acc = 0
    train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=global_rank)
    eval_sampler = DistributedSampler(eval_set, num_replicas=world_size, rank=global_rank)

    train_set = DataLoader(train_set, batch_size=batch_size, collate_fn=collate_batch, sampler=train_sampler, shuffle=(train_sampler is None))
    eval_set = DataLoader(eval_set, batch_size=batch_size, collate_fn=collate_batch, sampler=eval_sampler, shuffle=(eval_sampler is None))

    lr_scheduler = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=len(train_set))

    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    if WANDB_LOG and global_rank==0:
        # Initialize wandb
        if WANDB_KEY:
            wandb.login(key=WANDB_KEY)
        wandb.init(project="linear",
                   name=WEIGHTS_PATH.replace("weights-", "").replace(".pth", ""))

    for epoch in range(1, NUM_EPOCHS+1):
        train_sampler.set_epoch(epoch)
        eval_sampler.set_epoch(epoch)
        print('-' * 21 + "Epoch " + str(epoch) + '-' * 21)
        train_acc = train_epoch()
        eval_acc, eval_f1_macro = eval_epoch()
        if global_rank==0:
            with open(LOGS_PATH,'a') as f:
                f.write("Epoch " + str(epoch) + "\ntrain_acc: " + str(train_acc) + "\neval_acc: " + str(eval_acc) + "\neval_f1_macro: " + str(eval_f1_macro) + "\ntime: " + time.asctime(time.localtime(time.time())) + "\n\n")
            if eval_acc > max_eval_acc:
                best_epoch = epoch
                max_eval_acc = eval_acc
                checkpoint = {
                    'model': model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_sched': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'best_epoch': best_epoch,
                    'max_eval_acc': max_eval_acc,
                    "labels": label2idx
                }
                torch.save(checkpoint, WEIGHTS_PATH)
                with open(LOGS_PATH,'a') as f:
                    f.write("Best Epoch so far!\n\n\n")

        if world_size > 1:
            dist.barrier()

    if global_rank==0:
        print("Best Eval Epoch : "+str(best_epoch))
        print("Max Eval Accuracy : "+str(max_eval_acc))
music_classification/utils.py
ADDED
@@ -0,0 +1,22 @@
import torch
from config import *

class LinearClassification(torch.nn.Module):
    def __init__(self, num_classes):
        super(LinearClassification, self).__init__()
        self.fc1 = torch.nn.Linear(INPUT_HIDDEN_SIZE, HIDDEN_SIZE)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(HIDDEN_SIZE, num_classes)
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x):
        # Apply the linear layer and ReLU to each time step
        x = self.fc1(x)  # x shape (B, L, H) -> (B, L, hidden_size)
        x = self.relu(x)

        # Average over the time steps (L dimension)
        x = x.mean(dim=1)  # Now x has shape (B, hidden_size)

        x = self.fc2(x)  # Now applying the final layer (B, hidden_size) -> (B, num_classes)
        x = self.softmax(x)
        return x
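A hypothetical shape check for the module above, assuming `INPUT_HIDDEN_SIZE = HIDDEN_SIZE = 768` as in `config.py` and continuing from the imports there:

```python
model = LinearClassification(num_classes=10)
x = torch.randn(4, 32, 768)  # (B, L, H): 4 feature sequences of 32 time steps
probs = model(x)             # shape (4, 10); each row sums to 1 after the softmax
```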
process_data/README.md
ADDED
@@ -0,0 +1,307 @@
# Data Processing Codebase

## Overview
This codebase contains scripts and utilities for converting between various musical data formats, including ABC notation, MusicXML, MIDI, and MTF (MIDI Text Format). It also includes a script that summarizes music metadata, represented as JSON files of textual information, using the OpenAI GPT-4 API. The GPT-4 model processes this metadata to generate concise summaries in multiple languages to boost multilingual MIR. These tools are designed to facilitate the transformation and manipulation of musical files, as well as to provide concise multilingual summaries of music metadata for use with CLaMP 2.

## About ABC notation
### Standard ABC Notation
ABC notation is a text-based sheet-music representation. Like stave notation, it is theory-oriented and ideal for presenting complex musical concepts to musicians for study and analysis. Standard ABC notation encodes each voice separately, which often results in corresponding bars being spaced far apart. This separation makes it difficult for models to accurately understand the interactions between voices that are meant to align musically.

Example standard ABC notation representation:
```
%%score { 1 | 2 }
L:1/8
Q:1/4=120
M:3/4
K:G
V:1 treble nm="Piano" snm="Pno."
V:2 bass
V:1
!mf!"^Allegro" d2 (GA Bc | d2) .G2 .G2 |]
V:2
[G,B,D]4 A,2 | B,6 |]
```

### Interleaved ABC Notation
In contrast, interleaved ABC notation effectively aligns multi-track music by integrating multiple voices of the same bar into a single line, ensuring that all parts remain synchronized. This format combines voices in-line and tags each bar with its corresponding voice (e.g., `[V:1]` for treble and `[V:2]` for bass). By directly aligning related bars, interleaved ABC notation enhances the model's understanding of how different voices interact within the same bar.

Below is the same piece in interleaved ABC notation, as consumed by M3, where each bar or header corresponds to a patch:
```
%%score { 1 | 2 }
L:1/8
Q:1/4=120
M:3/4
K:G
V:1 treble nm="Piano" snm="Pno."
V:2 bass
[V:1]!mf!"^Allegro" d2 (GA Bc|[V:2][G,B,D]4 A,2|
[V:1]d2) .G2 .G2|][V:2]B,6|]
```

## About MTF
### Raw MIDI Messages
MIDI (performance data) precisely encodes performance information related to timing and dynamics, making it suitable for music production and live performance. Raw MIDI messages contain essential musical instructions and metadata, extracted directly from a MIDI file. These include events like note on/off, tempo changes, key signatures, and control changes, which define how the music is performed. The [mido library](https://mido.readthedocs.io/) allows for reading these messages in their native format, as seen below. Each message can include multiple parameters, making the output comprehensive but sometimes redundant.

```
MetaMessage('time_signature', numerator=3, denominator=4, clocks_per_click=24, notated_32nd_notes_per_beat=8, time=0)
MetaMessage('key_signature', key='G', time=0)
MetaMessage('set_tempo', tempo=500000, time=0)
control_change channel=0 control=121 value=0 time=0
program_change channel=0 program=0 time=0
control_change channel=0 control=7 value=100 time=0
control_change channel=0 control=10 value=64 time=0
control_change channel=0 control=91 value=0 time=0
control_change channel=0 control=93 value=0 time=0
MetaMessage('midi_port', port=0, time=0)
note_on channel=0 note=74 velocity=80 time=0
MetaMessage('key_signature', key='G', time=0)
MetaMessage('midi_port', port=0, time=0)
note_on channel=0 note=55 velocity=80 time=0
note_on channel=0 note=59 velocity=80 time=0
note_on channel=0 note=62 velocity=80 time=0
note_on channel=0 note=74 velocity=0 time=455
note_on channel=0 note=67 velocity=80 time=25
note_on channel=0 note=67 velocity=0 time=239
note_on channel=0 note=69 velocity=80 time=1
note_on channel=0 note=55 velocity=0 time=191
note_on channel=0 note=59 velocity=0 time=0
note_on channel=0 note=62 velocity=0 time=0
note_on channel=0 note=69 velocity=0 time=48
note_on channel=0 note=71 velocity=80 time=1
note_on channel=0 note=57 velocity=80 time=0
note_on channel=0 note=71 velocity=0 time=239
note_on channel=0 note=72 velocity=80 time=1
note_on channel=0 note=57 velocity=0 time=215
note_on channel=0 note=72 velocity=0 time=24
note_on channel=0 note=74 velocity=80 time=1
note_on channel=0 note=59 velocity=80 time=0
note_on channel=0 note=74 velocity=0 time=455
note_on channel=0 note=67 velocity=80 time=25
note_on channel=0 note=67 velocity=0 time=239
note_on channel=0 note=67 velocity=80 time=241
note_on channel=0 note=67 velocity=0 time=239
note_on channel=0 note=59 velocity=0 time=168
MetaMessage('end_of_track', time=1)
```
### MIDI Text Format (MTF)
The MIDI Text Format (MTF) provides a structured, textual representation of MIDI data that preserves all original information without loss. Each MIDI message is accurately represented, allowing full reconstruction and ensuring no musical nuances are overlooked during conversion.

To generate MTF, the mido library reads raw MIDI messages from MIDI files. The output retains all essential information but can be lengthy and redundant. To simplify the representation, parameter values are read in a fixed order and separated by spaces. For example, the raw time signature message, which contains several parameters (numerator, denominator, clocks per click, notated 32nd notes per beat, and time), is represented in MTF as:

```
time_signature 3 4 24 8 0
```

Other messages, such as control changes and note events, follow a similar compact format while preserving all relevant musical details. This structured simplification improves computational performance and maintains precise control over musical elements, including timing and dynamics.

Example MTF representation:
```
ticks_per_beat 480
time_signature 3 4 24 8 0
key_signature G 0
set_tempo 500000 0
control_change 0 0 121 0
program_change 0 0 0
control_change 0 0 7 100
control_change 0 0 10 64
control_change 0 0 91 0
control_change 0 0 93 0
midi_port 0 0
note_on 0 0 74 80
key_signature G 0
midi_port 0 0
note_on 0 0 55 80
note_on 0 0 59 80
note_on 0 0 62 80
note_on 455 0 74 0
note_on 25 0 67 80
note_on 239 0 67 0
note_on 1 0 69 80
note_on 191 0 55 0
note_on 0 0 59 0
note_on 0 0 62 0
note_on 48 0 69 0
note_on 1 0 71 80
note_on 0 0 57 80
note_on 239 0 71 0
note_on 1 0 72 80
note_on 215 0 57 0
note_on 24 0 72 0
note_on 1 0 74 80
note_on 0 0 59 80
note_on 455 0 74 0
note_on 25 0 67 80
note_on 239 0 67 0
note_on 241 0 67 80
note_on 239 0 67 0
note_on 168 0 59 0
end_of_track 1
```
For simplicity, `ticks_per_beat`, though originally an attribute of MIDI objects in mido, is included as the first message at the beginning of the MTF representation.
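To make the mapping concrete, here is a minimal sketch of this flattening for a handful of message types; it is an illustration of the description above, not the repository's `batch_midi2mtf.py` (which covers every message type and the `m3_compatible` filtering). Note that in the examples, channel messages list `time` first while meta messages keep it last.

```python
import mido

def message_to_mtf(msg):
    if msg.type == "note_on":
        return f"note_on {msg.time} {msg.channel} {msg.note} {msg.velocity}"
    if msg.type == "control_change":
        return f"control_change {msg.time} {msg.channel} {msg.control} {msg.value}"
    if msg.type == "time_signature":
        return (f"time_signature {msg.numerator} {msg.denominator} "
                f"{msg.clocks_per_click} {msg.notated_32nd_notes_per_beat} {msg.time}")
    if msg.type == "set_tempo":
        return f"set_tempo {msg.tempo} {msg.time}"
    return None  # all other message types are omitted in this sketch

def midi_to_mtf(path):
    mid = mido.MidiFile(path)
    lines = [f"ticks_per_beat {mid.ticks_per_beat}"]  # emitted first, as noted above
    for track in mid.tracks:
        for msg in track:
            line = message_to_mtf(msg)
            if line is not None:
                lines.append(line)
    return "\n".join(lines)
```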

### M3-Encoded MTF
When processed using M3 encoding, consecutive messages of the same type that fit within a 64-character limit (the patch size of M3) are combined into a single line. Only the first message in each group specifies the type, with subsequent messages listing only the parameter values separated by tabs. This further simplifies the representation and improves processing efficiency.

Below is the same data optimized with M3 encoding, where each line corresponds to a patch:
```
ticks_per_beat 480
time_signature 3 4 24 8 0
key_signature G 0
set_tempo 500000 0
control_change 0 0 121 0
program_change 0 0 0
control_change 0 0 7 100\t0 0 10 64\t0 0 91 0\t0 0 93 0
midi_port 0 0
note_on 0 0 74 80
key_signature G 0
midi_port 0 0
note_on 0 0 55 80\t0 0 59 80\t0 0 62 80\t455 0 74 0\t25 0 67 80
note_on 239 0 67 0\t1 0 69 80\t191 0 55 0\t0 0 59 0\t0 0 62 0
note_on 48 0 69 0\t1 0 71 80\t0 0 57 80\t239 0 71 0\t1 0 72 80
note_on 215 0 57 0\t24 0 72 0\t1 0 74 80\t0 0 59 80\t455 0 74 0
note_on 25 0 67 80\t239 0 67 0\t241 0 67 80\t239 0 67 0\t168 0 59 0
end_of_track 1
```

By reducing redundancy, M3 encoding ensures improved computational performance while maintaining precise timing and musical control, making it an ideal choice for efficient MIDI processing.
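The grouping rule can be sketched as follows; this is an illustration of the description above rather than the repository's encoder, with the 64-character patch size hard-coded:

```python
PATCH_SIZE = 64  # the stated M3 patch size

def m3_encode(mtf_lines):
    encoded = []
    for line in mtf_lines:
        msg_type, _, params = line.partition(" ")
        last = encoded[-1] if encoded else ""
        # merge into the previous patch if the type matches and the
        # 64-character limit still holds after adding "\t" + values
        if last.split(" ", 1)[0] == msg_type and len(last) + 1 + len(params) <= PATCH_SIZE:
            encoded[-1] = last + "\t" + params  # values only, tab-separated
        else:
            encoded.append(line)  # start a new patch with the full message
    return encoded
```

Applied to the plain MTF example above, this reproduces merged lines such as `control_change 0 0 7 100\t0 0 10 64\t0 0 91 0\t0 0 93 0`.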

## Repository Structure
The `process_data/` folder includes the following scripts and utility files:

### 1. **Conversion Scripts**

#### `batch_abc2xml.py`
- **Purpose**: Converts ABC notation files into MusicXML format.
- **Input**: Directory of interleaved ABC files (modify the `input_dir` variable in the code).
- **Output**: MusicXML files saved in a newly created `_xml` directory.
- **Logging**: Errors are logged to `logs/abc2xml_error_log.txt`.

#### `batch_xml2abc.py`
- **Purpose**: Converts MusicXML files into standard ABC notation format.
- **Input**: Directory of MusicXML files (e.g., `.xml`, `.mxl`, `.musicxml`) (modify the `input_dir` variable in the code).
- **Output**: Standard ABC files saved in a newly created `_abc` directory.
- **Logging**: Errors are logged to `logs/xml2abc_error_log.txt`.

#### `batch_interleaved_abc.py`
- **Purpose**: Processes standard ABC notation files into interleaved ABC notation.
- **Input**: Directory of ABC files (modify the `input_dir` variable in the code).
- **Output**: Interleaved ABC files saved in a newly created `_interleaved` directory.
- **Logging**: Any processing errors are printed to the console.

#### `batch_midi2mtf.py`
- **Purpose**: Converts MIDI files into MIDI Text Format (MTF).
- **Input**: Directory of MIDI files (e.g., `.mid`, `.midi`) (modify the `input_dir` variable in the code).
- **Output**: MTF files saved in a newly created `_mtf` directory.
- **Logging**: Errors are logged to `logs/midi2mtf_error_log.txt`.
- **Note**: The script includes an `m3_compatible` variable, which is set to `True` by default. When `True`, the conversion omits messages whose parameters are strings or lists, eliminating potential natural language information. This ensures that the converted MTF files align with the data format used for training the M3 and CLaMP 2 pretrained weights.

#### `batch_mtf2midi.py`
- **Purpose**: Converts MTF files into MIDI format.
- **Input**: Directory of MTF files (modify the `input_dir` variable in the code).
- **Output**: MIDI files saved in a newly created `_midi` directory.
- **Logging**: Errors are logged to `logs/mtf2midi_error_log.txt`.

### 2. **Summarization Script**

#### `gpt4_summarize.py`
- **Purpose**: Utilizes the OpenAI GPT-4 API to generate concise summaries of music metadata in multiple languages. The script filters out any entries that lack sufficient musical information, ensuring that only meaningful summaries are produced.
- **Input**: Directory of JSON files containing music metadata (modify the `input_dir` variable in the code). For any missing metadata fields, the corresponding keys can be set to `None`. Each JSON file corresponds to a single musical composition and can be linked to both ABC notation and MTF formats. Here's an example of the required metadata format:

```json
{
    "title": "Hard Times Come Again No More",
    "composer": "Stephen Foster",
    "genres": ["Children's Music", "Folk"],
    "description": "\"Hard Times Come Again No More\" (sometimes referred to as \"Hard Times\") is an American parlor song written by Stephen Foster, reflecting themes of sorrow and hope.",
    "lyrics": "Let us pause in life's pleasures and count its many tears,\nWhile we all sup sorrow with the poor;\nThere's a song that will linger forever in our ears;\nOh! Hard times come again no more.\n\nChorus:\n'Tis the song, the sigh of the weary,\nHard Times, hard times, come again no more.\nMany days you have lingered around my cabin door;\nOh! Hard times come again no more.\n\nWhile we seek mirth and beauty and music light and gay,\nThere are frail forms fainting at the door;\nThough their voices are silent, their pleading looks will say\nOh! Hard times come again no more.\nChorus\n\nThere's a pale weeping maiden who toils her life away,\nWith a worn heart whose better days are o'er:\nThough her voice would be merry, 'tis sighing all the day,\nOh! Hard times come again no more.\nChorus\n\n'Tis a sigh that is wafted across the troubled wave,\n'Tis a wail that is heard upon the shore\n'Tis a dirge that is murmured around the lowly grave\nOh! Hard times come again no more.\nChorus",
    "tags": ["folk", "traditional", "bluegrass", "nostalgic", "heartfelt", "acoustic", "melancholic", "storytelling", "American roots", "resilience"],
    "ensembles": ["Folk Ensemble"],
    "instruments": ["Vocal", "Violin", "Tin whistle", "Guitar", "Banjo", "Tambourine"],
    "filepaths": [
        "abc/American_Music/Folk_Traditions/19th_Century/Stephen_Foster/Hard_Times_Come_Again_No_More.abc",
        "mtf/American_Music/Folk_Traditions/19th_Century/Stephen_Foster/Hard_Times_Come_Again_No_More.mtf"
    ]
}
```

- **Output**: JSON files containing structured summaries in both English and a randomly selected non-English language, chosen from a selection of 100 different non-English languages (in this case, Simplified Chinese). Here's an example of the expected output format:

```json
{
    "title": "Hard Times Come Again No More",
    "composer": "Stephen Foster",
    "genres": ["Children's Music", "Folk"],
    "description": "\"Hard Times Come Again No More\" (sometimes referred to as \"Hard Times\") is an American parlor song written by Stephen Foster, reflecting themes of sorrow and hope.",
    "lyrics": "Let us pause in life's pleasures and count its many tears,\nWhile we all sup sorrow with the poor;\nThere's a song that will linger forever in our ears;\nOh! Hard times come again no more.\n\nChorus:\n'Tis the song, the sigh of the weary,\nHard Times, hard times, come again no more.\nMany days you have lingered around my cabin door;\nOh! Hard times come again no more.\n\nWhile we seek mirth and beauty and music light and gay,\nThere are frail forms fainting at the door;\nThough their voices are silent, their pleading looks will say\nOh! Hard times come again no more.\nChorus\n\nThere's a pale weeping maiden who toils her life away,\nWith a worn heart whose better days are o'er:\nThough her voice would be merry, 'tis sighing all the day,\nOh! Hard times come again no more.\nChorus\n\n'Tis a sigh that is wafted across the troubled wave,\n'Tis a wail that is heard upon the shore\n'Tis a dirge that is murmured around the lowly grave\nOh! Hard times come again no more.\nChorus",
    "tags": ["folk", "traditional", "bluegrass", "nostalgic", "heartfelt", "acoustic", "melancholic", "storytelling", "American roots", "resilience"],
    "ensembles": ["Folk Ensemble"],
    "instruments": ["Vocal", "Violin", "Tin whistle", "Guitar", "Banjo", "Tambourine"],
    "summary_en": "\"Hard Times Come Again No More,\" composed by Stephen Foster, is a poignant American parlor song that explores themes of sorrow and hope. The lyrics reflect on the contrast between life's pleasures and its hardships, inviting listeners to acknowledge both joy and suffering. With a heartfelt chorus that repeats the line \"Hard times come again no more,\" the song resonates with nostalgia and resilience. It is often performed by folk ensembles and features a variety of instruments, including vocals, violin, guitar, and banjo, encapsulating the spirit of American roots music.",
    "summary_nen": {
        "language": "Chinese (Simplified)",
        "summary": "《艰难时光再无来临》是斯蒂芬·福斯特创作的一首感人至深的美国小歌厅歌曲,探讨了悲伤与希望的主题。歌词展现了生活的乐趣与艰辛之间的对比,邀请听众去感受快乐与痛苦的交织。歌曲中那句反复吟唱的"艰难时光再无来临"深情地表达了怀旧与坚韧。它常常由民谣乐队演奏,伴随着人声、小提琴、吉他和班卓琴等多种乐器,生动地展现了美国根源音乐的独特魅力。"
    },
    "filepaths": [
        "abc/American_Music/Folk_Traditions/19th_Century/Stephen_Foster/Hard_Times_Come_Again_No_More.abc",
        "mtf/American_Music/Folk_Traditions/19th_Century/Stephen_Foster/Hard_Times_Come_Again_No_More.mtf"
    ]
}
```

- **Logging**: Errors are logged to `logs/gpt4_summarize_error_log.txt`.

### 3. **Utilities**
- **`utils/`**: Contains utility files required for the conversion processes.

## Usage
To use the scripts, modify the `input_dir` variable in each script to point to the directory containing your input files, then run the script from the command line. Below are example commands for each script:

### Example Commands
```bash
# Modify the input_dir variable in the script before running
python batch_abc2xml.py
python batch_xml2abc.py
python batch_interleaved_abc.py
python batch_midi2mtf.py
python batch_mtf2midi.py
python gpt4_summarize.py
```

### Execution Order
To achieve specific conversions, follow the order below:

1. **To obtain interleaved ABC notation**:
   - First, run `batch_xml2abc.py` to convert MusicXML files to ABC notation.
   - Then, run `batch_interleaved_abc.py` to process the ABC files into interleaved ABC notation.

2. **To obtain MTF**:
   - Run `batch_midi2mtf.py` to convert MIDI files into MTF.

3. **To convert interleaved ABC back to XML**:
   - Run `batch_abc2xml.py` on the interleaved ABC files to convert them back to MusicXML format.

4. **To convert MTF back to MIDI**:
   - Run `batch_mtf2midi.py` to convert MTF files back to MIDI format.

5. **To summarize music metadata**:
   - Run `gpt4_summarize.py` to generate summaries for the music metadata files in JSON format. This assumes you have a directory of JSON files that includes a `filepaths` key, which connects to the corresponding interleaved ABC and MTF files.

### Parameters
To run the scripts, you need to configure the following parameters:

- **`input_dir`**: Set this to the directory containing the input files to be processed (ABC, MusicXML, MIDI, MTF, or JSON files); it is shared across all scripts.

In addition to **`input_dir`**, the following parameters are specific to certain scripts:

- **`m3_compatible`** (specific to `batch_midi2mtf.py`):
  - The default is `True`, which omits messages with parameters that are strings or lists to avoid including potential natural language information.
  - Setting this to `False` retains all MIDI messages, which is crucial for those planning to retrain models on custom datasets or needing precise MIDI reproduction.

For **`gpt4_summarize.py`**, you also need to configure these parameters (a minimal client sketch follows the list):

1. **`base_url`**: The base URL for the OpenAI API, used to initialize the client.
2. **`api_key`**: Your API key for authenticating requests, required for client initialization.
3. **`model`**: The GPT-4 model to use, specified when generating summaries.

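A hedged sketch of wiring these three parameters into the OpenAI client (the `openai>=1.x` interface pinned in `environment.yml`); the prompt and model name are placeholders, not the script's actual prompt:

```python
from openai import OpenAI

client = OpenAI(base_url="<base_url>", api_key="<api_key>")  # the first two parameters
response = client.chat.completions.create(
    model="<model>",  # the third parameter, e.g. a GPT-4 variant
    messages=[{"role": "user", "content": "Summarize this music metadata: ..."}],
)
print(response.choices[0].message.content)
```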
**Important**: When `m3_compatible` is set to `True`, converting back from MTF to MIDI using `batch_mtf2midi.py` may produce MIDI files that do not exactly match the originals. This discrepancy was unexpected; however, retraining both M3 and CLaMP 2 to address it would require approximately 6,000 H800 GPU hours. Considering that M3 and CLaMP 2 have already achieved state-of-the-art results on MIDI tasks, we have opted not to retrain. Therefore, if consistency with the original MIDI files is critical for your application, set `m3_compatible` to `False`.
process_data/batch_abc2xml.py
ADDED
@@ -0,0 +1,58 @@
1 |
+
input_dir = "<path_to_your_interleaved_abc_files>" # Replace with the path to your folder containing interleaved ABC (.abc) files
|
2 |
+
|
3 |
+
import os
|
4 |
+
import math
|
5 |
+
import random
|
6 |
+
import subprocess
|
7 |
+
from tqdm import tqdm
|
8 |
+
from multiprocessing import Pool
|
9 |
+
|
10 |
+
def convert_abc2xml(file_list):
|
11 |
+
cmd = 'cmd /u /c python utils/abc2xml.py '
|
12 |
+
for file in tqdm(file_list):
|
13 |
+
filename = file.split('/')[-1] # Extract file name
|
14 |
+
output_dir = file.split('/')[:-1] # Extract directory path
|
15 |
+
output_dir[0] = output_dir[0] + '_xml' # Create new output folder
|
16 |
+
output_dir = '/'.join(output_dir)
|
17 |
+
os.makedirs(output_dir, exist_ok=True)
|
18 |
+
|
19 |
+
try:
|
20 |
+
p = subprocess.Popen(cmd + '"' + file + '"', stdout=subprocess.PIPE, shell=True)
|
21 |
+
result = p.communicate()
|
22 |
+
output = result[0].decode('utf-8')
|
23 |
+
|
24 |
+
if output == '':
|
25 |
+
with open("logs/abc2xml_error_log.txt", "a", encoding="utf-8") as f:
|
26 |
+
f.write(file + '\n')
|
27 |
+
continue
|
28 |
+
else:
|
29 |
+
output_path = f"{output_dir}/" + ".".join(filename.split(".")[:-1]) + ".xml"
|
30 |
+
with open(output_path, 'w', encoding='utf-8') as f:
|
31 |
+
f.write(output)
|
32 |
+
except Exception as e:
|
33 |
+
with open("logs/abc2xml_error_log.txt", "a", encoding="utf-8") as f:
|
34 |
+
f.write(file + ' ' + str(e) + '\n')
|
35 |
+
pass
|
36 |
+
|
37 |
+
if __name__ == '__main__':
|
38 |
+
file_list = []
|
39 |
+
os.makedirs("logs", exist_ok=True)
|
40 |
+
|
41 |
+
# Traverse the specified folder for ABC files
|
42 |
+
for root, dirs, files in os.walk(input_dir):
|
43 |
+
for file in files:
|
44 |
+
if not file.endswith(".abc"):
|
45 |
+
continue
|
46 |
+
filename = os.path.join(root, file).replace("\\", "/")
|
47 |
+
file_list.append(filename)
|
48 |
+
|
49 |
+
# Prepare for multiprocessing
|
50 |
+
file_lists = []
|
51 |
+
random.shuffle(file_list)
|
52 |
+
for i in range(os.cpu_count()):
|
53 |
+
start_idx = int(math.floor(i * len(file_list) / os.cpu_count()))
|
54 |
+
end_idx = int(math.floor((i + 1) * len(file_list) / os.cpu_count()))
|
55 |
+
file_lists.append(file_list[start_idx:end_idx])
|
56 |
+
|
57 |
+
pool = Pool(processes=os.cpu_count())
|
58 |
+
pool.map(convert_abc2xml, file_lists)
|
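A detail worth noting in `convert_abc2xml` (shared by most of the batch converters below): the output folder is derived by appending a suffix to the *first* component of each file path, so the script is meant to be run from the directory that contains `input_dir`. The `cmd /u /c` prefix also assumes a Windows shell. A hypothetical walk-through of the path mapping:

```python
# Hypothetical illustration of the output-path logic in convert_abc2xml:
file = "dataset/subdir/tune.abc"   # a path collected by os.walk
parts = file.split('/')[:-1]       # ['dataset', 'subdir']
parts[0] = parts[0] + '_xml'       # ['dataset_xml', 'subdir']
output_dir = '/'.join(parts)       # 'dataset_xml/subdir'
print(output_dir)                  # tune.xml is written into this folder
```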
process_data/batch_interleaved_abc.py
ADDED
@@ -0,0 +1,117 @@
input_dir = "<path_to_your_abc_files>"  # Replace with the path to your folder containing standard ABC (.abc) files

import os
import re
import random
from multiprocessing import Pool
from tqdm import tqdm
from abctoolkit.utils import (
    find_all_abc,
    remove_information_field,
    remove_bar_no_annotations,
    Quote_re,
    Barlines,
    strip_empty_bars
)
from abctoolkit.rotate import rotate_abc
from abctoolkit.check import check_alignment_unrotated

def abc_pipeline(abc_path, input_dir, output_dir):
    """
    Converts standard ABC notation to interleaved ABC notation.
    """
    with open(abc_path, 'r', encoding='utf-8') as f:
        abc_lines = f.readlines()

    abc_lines = [line for line in abc_lines if line.strip() != '']
    abc_lines = remove_information_field(
        abc_lines=abc_lines,
        info_fields=['X:', 'T:', 'C:', 'W:', 'w:', 'Z:', '%%MIDI']
    )
    abc_lines = remove_bar_no_annotations(abc_lines)

    # Remove escaped quotes and clean up barlines inside quotes
    for i, line in enumerate(abc_lines):
        if not (re.search(r'^[A-Za-z]:', line) or line.startswith('%')):
            abc_lines[i] = line.replace(r'\"', '')
            quote_contents = re.findall(Quote_re, line)
            for quote_content in quote_contents:
                for barline in Barlines:
                    if barline in quote_content:
                        line = line.replace(quote_content, '')
                        abc_lines[i] = line

    try:
        stripped_abc_lines, bar_counts = strip_empty_bars(abc_lines)
    except Exception as e:
        print(abc_path, 'Error in stripping empty bars:', e)
        return

    if stripped_abc_lines is None:
        print(abc_path, 'Failed to strip')
        return

    # Check alignment
    _, bar_no_equal_flag, bar_dur_equal_flag = check_alignment_unrotated(stripped_abc_lines)
    if not bar_no_equal_flag:
        print(abc_path, 'Unequal bar number')
    if not bar_dur_equal_flag:
        print(abc_path, 'Unequal bar duration (unaligned)')

    # Construct the output path, maintaining input folder structure
    relative_path = os.path.relpath(abc_path, input_dir)  # Get relative path from input dir
    output_file_path = os.path.join(output_dir, relative_path)  # Recreate output path
    os.makedirs(os.path.dirname(output_file_path), exist_ok=True)  # Ensure output folder exists

    try:
        rotated_abc_lines = rotate_abc(stripped_abc_lines)
    except Exception as e:
        print(abc_path, 'Error in rotating:', e)
        return

    if rotated_abc_lines is None:
        print(abc_path, 'Failed to rotate')
        return

    with open(output_file_path, 'w', encoding='utf-8') as w:
        w.writelines(rotated_abc_lines)

def abc_pipeline_list(abc_path_list, input_dir, output_dir):
    for abc_path in tqdm(abc_path_list):
        try:
            abc_pipeline(abc_path, input_dir, output_dir)
        except Exception as e:
            print(abc_path, e)
            pass

def batch_abc_pipeline(input_dir):
    """
    Batch process all ABC files from `input_dir`, converting them to interleaved notation.
    """
    output_dir = input_dir + "_interleaved"
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)

    abc_path_list = []
    for abc_path in find_all_abc(input_dir):
        if os.path.getsize(abc_path) > 0:
            abc_path_list.append(abc_path)
    random.shuffle(abc_path_list)
    print(f"Found {len(abc_path_list)} ABC files.")

    num_cpus = os.cpu_count()
    split_lists = [[] for _ in range(num_cpus)]
    index = 0

    for abc_path in abc_path_list:
        split_lists[index].append(abc_path)
        index = (index + 1) % num_cpus

    pool = Pool(processes=num_cpus)
    pool.starmap(
        abc_pipeline_list,
        [(split, input_dir, output_dir) for split in split_lists]
    )

if __name__ == '__main__':
    batch_abc_pipeline(input_dir)
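Unlike the other batch scripts, which slice the shuffled file list into contiguous chunks, `batch_abc_pipeline` deals files round-robin across `os.cpu_count()` workers; after shuffling, both strategies balance load about equally. To process a single file without the pool, `abc_pipeline` can be called directly (the paths here are hypothetical):

```python
# Minimal single-file sketch, bypassing multiprocessing:
abc_pipeline("dataset/tune.abc", input_dir="dataset", output_dir="dataset_interleaved")
```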
process_data/batch_midi2mtf.py
ADDED
@@ -0,0 +1,80 @@
input_dir = "<path_to_your_midi_files>"  # Replace with the path to your folder containing MIDI (.midi, .mid) files
m3_compatible = True  # Set to True for M3 compatibility; set to False to retain all MIDI information during conversion.

import os
import math
import mido
import random
from tqdm import tqdm
from multiprocessing import Pool

def msg_to_str(msg):
    str_msg = ""
    for key, value in msg.dict().items():
        str_msg += " " + str(value)
    return str_msg.strip().encode('unicode_escape').decode('utf-8')

def load_midi(filename):
    # Load a MIDI file
    mid = mido.MidiFile(filename)
    msg_list = ["ticks_per_beat " + str(mid.ticks_per_beat)]

    # Traverse the MIDI file
    for msg in mid.merged_track:
        if m3_compatible:
            if msg.is_meta:
                if msg.type in ["text", "copyright", "track_name", "instrument_name",
                                "lyrics", "marker", "cue_marker", "device_name", "sequencer_specific"]:
                    continue
            else:
                if msg.type in ["sysex"]:
                    continue
        str_msg = msg_to_str(msg)
        msg_list.append(str_msg)

    return "\n".join(msg_list)

def convert_midi2mtf(file_list):
    for file in tqdm(file_list):
        filename = file.split('/')[-1]
        output_dir = file.split('/')[:-1]
        output_dir[0] = output_dir[0] + '_mtf'
        output_dir = '/'.join(output_dir)
        os.makedirs(output_dir, exist_ok=True)
        try:
            output = load_midi(file)

            if output == '':
                with open('logs/midi2mtf_error_log.txt', 'a', encoding='utf-8') as f:
                    f.write(file + '\n')
                continue
            else:
                with open(output_dir + "/" + ".".join(filename.split(".")[:-1]) + '.mtf', 'w', encoding='utf-8') as f:
                    f.write(output)
        except Exception as e:
            with open('logs/midi2mtf_error_log.txt', 'a', encoding='utf-8') as f:
                f.write(file + " " + str(e) + '\n')
            pass

if __name__ == '__main__':
    file_list = []
    os.makedirs("logs", exist_ok=True)

    # Traverse the specified folder for MIDI files
    for root, dirs, files in os.walk(input_dir):
        for file in files:
            if not file.endswith(".mid") and not file.endswith(".midi"):
                continue
            filename = os.path.join(root, file).replace("\\", "/")
            file_list.append(filename)

    # Prepare for multiprocessing
    file_lists = []
    random.shuffle(file_list)
    for i in range(os.cpu_count()):
        start_idx = int(math.floor(i * len(file_list) / os.cpu_count()))
        end_idx = int(math.floor((i + 1) * len(file_list) / os.cpu_count()))
        file_lists.append(file_list[start_idx:end_idx])

    pool = Pool(processes=os.cpu_count())
    pool.map(convert_midi2mtf, file_lists)
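To make the MTF text format concrete: the first line records `ticks_per_beat`, and every subsequent line is one MIDI message serialized as the space-joined values of mido's `Message.dict()`. A small sketch (without the `m3_compatible` filtering or the `unicode_escape` step that `load_midi` applies; the exact field order depends on the mido version, so the sample output is illustrative):

```python
import mido

# Build MTF-like text for one file (the path is hypothetical):
mid = mido.MidiFile("example.mid")
lines = ["ticks_per_beat " + str(mid.ticks_per_beat)]
for msg in mid.merged_track:
    # e.g. a note_on message might serialize as "note_on 0 60 64 480"
    lines.append(" ".join(str(v) for v in msg.dict().values()))
print("\n".join(lines[:5]))
```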
process_data/batch_mtf2midi.py
ADDED
@@ -0,0 +1,97 @@
input_dir = "<path_to_your_mtf_files>"  # Replace with the path to your folder containing MTF (.mtf) files

import os
import math
import mido
import random
from tqdm import tqdm
from multiprocessing import Pool

def str_to_msg(str_msg):
    type = str_msg.split(" ")[0]
    try:
        msg = mido.Message(type)
    except:
        msg = mido.MetaMessage(type)

    if type in ["text", "copyright", "track_name", "instrument_name",
                "lyrics", "marker", "cue_marker", "device_name"]:
        values = [type, " ".join(str_msg.split(" ")[1:-1]).encode('utf-8').decode('unicode_escape'), str_msg.split(" ")[-1]]
    elif "[" in str_msg or "(" in str_msg:
        is_bracket = "[" in str_msg
        left_idx = str_msg.index("[") if is_bracket else str_msg.index("(")
        right_idx = str_msg.index("]") if is_bracket else str_msg.index(")")
        list_str = [int(num) for num in str_msg[left_idx+1:right_idx].split(", ")]
        if not is_bracket:
            list_str = tuple(list_str)
        values = str_msg[:left_idx].split(" ") + [list_str] + str_msg[right_idx+1:].split(" ")
        values = [value for value in values if value != ""]
    else:
        values = str_msg.split(" ")

    if len(values) != 1:
        for idx, (key, content) in enumerate(msg.__dict__.items()):
            if key == "type":
                continue
            value = values[idx]
            if isinstance(content, int) or isinstance(content, float):
                float_value = float(value)
                value = float_value
                if value % 1 == 0:
                    value = int(value)
            setattr(msg, key, value)

    return msg

def convert_mtf2midi(file_list):
    for file in tqdm(file_list):
        filename = file.split('/')[-1]
        output_dir = file.split('/')[:-1]
        output_dir[0] = output_dir[0] + '_midi'
        output_dir = '/'.join(output_dir)
        os.makedirs(output_dir, exist_ok=True)
        try:
            with open(file, 'r', encoding='utf-8') as f:
                msg_list = f.read().splitlines()

            # Build a new MIDI file based on the MIDI messages
            new_mid = mido.MidiFile()
            new_mid.ticks_per_beat = int(msg_list[0].split(" ")[1])

            track = mido.MidiTrack()
            new_mid.tracks.append(track)

            for msg in msg_list[1:]:
                if "unknown_meta" in msg:
                    continue
                new_msg = str_to_msg(msg)
                track.append(new_msg)

            output_file_path = os.path.join(output_dir, os.path.basename(file).replace('.mtf', '.mid'))
            new_mid.save(output_file_path)
        except Exception as e:
            with open('logs/mtf2midi_error_log.txt', 'a', encoding='utf-8') as f:
                f.write(f"Error processing {file}: {str(e)}\n")

if __name__ == '__main__':
    file_list = []
    os.makedirs("logs", exist_ok=True)

    # Traverse the specified folder for MTF files
    for root, dirs, files in os.walk(input_dir):
        for file in files:
            if not file.endswith(".mtf"):
                continue
            filename = os.path.join(root, file).replace("\\", "/")
            file_list.append(filename)

    # Prepare for multiprocessing
    file_lists = []
    random.shuffle(file_list)
    for i in range(os.cpu_count()):
        start_idx = int(math.floor(i * len(file_list) / os.cpu_count()))
        end_idx = int(math.floor((i + 1) * len(file_list) / os.cpu_count()))
        file_lists.append(file_list[start_idx:end_idx])

    pool = Pool(processes=os.cpu_count())
    pool.map(convert_mtf2midi, file_lists)
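Note that `str_to_msg` assigns the parsed tokens positionally against `msg.__dict__`, so a round trip only works when the MTF file was produced by the same mido version that reads it back. A sanity check one might run inside this module (a sketch: `msg_to_str` is borrowed from `batch_midi2mtf.py`, and the field order is an assumption about the installed mido version):

```python
import mido

def msg_to_str(msg):
    # Same serialization as batch_midi2mtf.py: space-joined Message.dict() values.
    return " ".join(str(v) for v in msg.dict().values()).encode('unicode_escape').decode('utf-8')

original = mido.Message("note_on", note=60, velocity=64, time=480)
restored = str_to_msg(msg_to_str(original))  # str_to_msg as defined above
assert restored.note == 60 and restored.velocity == 64 and restored.time == 480
```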
process_data/batch_xml2abc.py
ADDED
@@ -0,0 +1,57 @@
input_dir = "<path_to_your_xml_files>"  # Replace with the path to your folder containing XML (.xml, .mxl, .musicxml) files

import os
import math
import random
import subprocess
from tqdm import tqdm
from multiprocessing import Pool

def convert_xml2abc(file_list):
    cmd = 'cmd /u /c python utils/xml2abc.py -d 8 -x '
    for file in tqdm(file_list):
        filename = file.split('/')[-1]
        output_dir = file.split('/')[:-1]
        output_dir[0] = output_dir[0] + '_abc'
        output_dir = '/'.join(output_dir)
        os.makedirs(output_dir, exist_ok=True)

        try:
            p = subprocess.Popen(cmd + '"' + file + '"', stdout=subprocess.PIPE, shell=True)
            result = p.communicate()
            output = result[0].decode('utf-8')

            if output == '':
                with open("logs/xml2abc_error_log.txt", "a", encoding="utf-8") as f:
                    f.write(file + '\n')
                continue
            else:
                with open(output_dir + '/' + ".".join(filename.split(".")[:-1]) + '.abc', 'w', encoding='utf-8') as f:
                    f.write(output)
        except Exception as e:
            with open("logs/xml2abc_error_log.txt", "a", encoding="utf-8") as f:
                f.write(file + ' ' + str(e) + '\n')
            pass

if __name__ == '__main__':
    file_list = []
    os.makedirs("logs", exist_ok=True)

    # Traverse the specified folder for XML/MXL files
    for root, dirs, files in os.walk(input_dir):
        for file in files:
            if not file.endswith((".mxl", ".xml", ".musicxml")):
                continue
            filename = os.path.join(root, file).replace("\\", "/")
            file_list.append(filename)

    # Prepare for multiprocessing
    file_lists = []
    random.shuffle(file_list)
    for i in range(os.cpu_count()):
        start_idx = int(math.floor(i * len(file_list) / os.cpu_count()))
        end_idx = int(math.floor((i + 1) * len(file_list) / os.cpu_count()))
        file_lists.append(file_list[start_idx:end_idx])

    pool = Pool(processes=os.cpu_count())
    pool.map(convert_xml2abc, file_lists)
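As with `batch_abc2xml.py`, the `cmd /u /c` wrapper assumes Windows; on a Unix-like system the converter can be invoked directly with the same `-d 8 -x` flags. A hypothetical single-file equivalent of what `convert_xml2abc` does per file:

```python
import subprocess

# Convert one MusicXML file to ABC on stdout (the path is hypothetical):
result = subprocess.run(
    ["python", "utils/xml2abc.py", "-d", "8", "-x", "score.xml"],
    capture_output=True, text=True,
)
abc_text = result.stdout  # empty output is treated as a conversion failure
```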
process_data/gpt4_summarize.py
ADDED
@@ -0,0 +1,250 @@
input_dir = "<path_to_your_metadata_json_files>"  # Replace with the path to your folder containing metadata (.json) files
base_url = "<your_base_url>"  # Replace with the base URL for the API
api_key = "<your_api_key>"  # Replace with your API key
model = "<your_model>"  # Replace with your model name

import os
import json
import random
from openai import OpenAI

# Initialize the OpenAI client
client = OpenAI(base_url=base_url, api_key=api_key)

def log_error(file_path, error_message):
    """Logs error messages to a specified log file."""
    os.makedirs("logs", exist_ok=True)
    with open("logs/gpt4_summarize_error_log.txt", 'a', encoding='utf-8') as log_file:
        log_file.write(f"Error processing {file_path}: {error_message}\n")

def process_json(metadata, language):
    """
    Processes the given metadata of a music piece using GPT-4 API.

    This function sends the metadata and target language to the GPT-4 model to generate
    a structured summary. The summary is provided in both English and the specified
    non-English language from the 'nen_language' field.

    If the provided metadata lacks sufficient music-related details, the function returns `None`.

    Parameters:
    - metadata (dict): A dictionary containing the metadata of the music piece.
    - language (str): The target non-English language for the summary.

    Returns:
    - str: A JSON-formatted string containing the English and non-English summaries,
      or `None` if there is insufficient information.
    """
    system = """Your task is to provide a concise, comprehensive, and coherent summary of the music piece using the provided metadata. Please write the summary in English first, and then write an equivalent summary in the specified non-English language from the "nen_language" field. Use this JSON format:
{
"summary_en": "Your English summary here.",
"summary_nen": {
"language": "Specified non-English language.",
"summary": "Your non-English summary here."
}
If there is not enough music-related information, return `None` instead.
}
"""
    user1 = """{
"title": "Brejeiro",
"composer": "Ernesto Nazareth",
"genres": ["Choro", "Classical", "Instrumental"],
"description": "\"Brejeiro\" is in A major and 2/4 time. A joyful melody begins at bar six, and a lively tango rhythm starts at bar fourteen. It has a D.C. al Fine at bar fifty-three and ends on two quarter notes in bar thirty-seven. The piece, with its vibrant melodies and rhythms, reflects celebration and carefreeness, embodying the spirit of Brazilian music.",
"tags": ["Brazilian", "Choro", "Piano"],
"ensembles": ["Solo Piano", "Small Ensemble"],
"instruments": ["Piano"],
"nen_language": "Japanese"
}
"""
    assistant1 = """{
"summary_en": "Brejeiro, composed by Ernesto Nazareth, is a lively choro piece in A major and 2/4 time. It features a joyful melody that begins at bar six and a vibrant tango rhythm introduced at bar fourteen. The piece includes a D.C. al Fine at bar fifty-three, concluding on two quarter notes in bar thirty-seven. With its themes of celebration and carefreeness, Brejeiro beautifully captures the essence of Brazilian music and is well-suited for solo piano and small ensembles.",
"summary_nen": {
"language": "Japanese",
"summary": "「ブレジェイロ」は、エルネスト・ナザレが作曲した活気あふれるショーロの作品で、イ長調の2/4拍子で書かれています。第6小節から始まる喜びに満ちたメロディーと、第14小節で導入される活気あるタンゴのリズムが特徴です。この曲には、第53小節でのD.C. al Fineが含まれ、また第37小節で二つの四分音符で締めくくられています。「ブレジェイロ」は、お祝いと無邪気さのテーマを持ち、ブラジル音楽の本質を美しく捉えており、ソロピアノや小編成のアンサンブルにぴったりの作品です。"
}
}
"""
    user2 = """{
"title": "Untitled",
"composer": "Unknown",
"description": "This is a good song.",
"nen_language": "Russian"
}
"""
    assistant2 = "None"
    filepaths = metadata.pop('filepaths')
    metadata = {k: v for k, v in metadata.items() if v is not None}

    metadata["nen_language"] = language
    metadata = json.dumps(metadata, ensure_ascii=False, indent=4)
    summaries = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user1},
            {"role": "assistant", "content": assistant1},
            {"role": "user", "content": user2},
            {"role": "assistant", "content": assistant2},
            {"role": "user", "content": metadata},
        ]
    ).choices[0].message.content

    if summaries == "None":
        raise ValueError("Received 'None' as summaries response")

    metadata = json.loads(metadata)
    summaries = json.loads(summaries)

    if metadata["nen_language"] == summaries["summary_nen"]["language"]:
        metadata.pop("nen_language")
        metadata["summary_en"] = summaries["summary_en"]
        metadata["summary_nen"] = summaries["summary_nen"]
        metadata["filepaths"] = filepaths
        return metadata
    else:
        raise ValueError("Language mismatch: nen_language does not match summary_nen language")

def process_files(input_dir):
    # Create output directory with _summarized suffix
    output_dir = input_dir + "_summarized"

    # Define available languages
    languages = """Afrikaans
Amharic
Arabic
Assamese
Azerbaijani
Belarusian
Bulgarian
Bengali
Bengali (Romanized)
Breton
Bosnian
Catalan
Czech
Welsh
Danish
German
Greek
Esperanto
Spanish
Estonian
Basque
Persian
Finnish
French
Western Frisian
Irish
Scottish Gaelic
Galician
Gujarati
Hausa
Hebrew
Hindi
Hindi (Romanized)
Croatian
Hungarian
Armenian
Indonesian
Icelandic
Italian
Japanese
Javanese
Georgian
Kazakh
Khmer
Kannada
Korean
Kurdish (Kurmanji)
Kyrgyz
Latin
Lao
Lithuanian
Latvian
Malagasy
Macedonian
Malayalam
Mongolian
Marathi
Malay
Burmese
Burmese (Romanized)
Nepali
Dutch
Norwegian
Oromo
Oriya
Punjabi
Polish
Pashto
Portuguese
Romanian
Russian
Sanskrit
Sindhi
Sinhala
Slovak
Slovenian
Somali
Albanian
Serbian
Sundanese
Swedish
Swahili
Tamil
Tamil (Romanized)
Telugu
Telugu (Romanized)
Thai
Filipino
Turkish
Uyghur
Ukrainian
Urdu
Urdu (Romanized)
Uzbek
Vietnamese
Xhosa
Yiddish
Chinese (Simplified)
Chinese (Traditional)
Cantonese"""
    languages = [language.strip() for language in languages.split("\n")]

    # Walk through the input directory
    for root, _, files in os.walk(input_dir):
        # Construct the corresponding path in the output folder
        relative_path = os.path.relpath(root, input_dir)
        output_path = os.path.join(output_dir, relative_path)

        # Create the output directory if it doesn't exist
        os.makedirs(output_path, exist_ok=True)

        for file in files:
            if file.endswith('.json'):
                input_file = os.path.join(root, file)
                output_file = os.path.join(output_path, file)

                try:
                    # Read the JSON file
                    with open(input_file, 'r', encoding='utf-8') as f:
                        metadata = json.load(f)

                    # Randomly select a language from the list of languages
                    language = random.choice(languages)

                    # Process the JSON data
                    processed_metadata = process_json(metadata, language)

                    # Write the processed JSON to the output file
                    with open(output_file, 'w', encoding='utf-8') as f:
                        json.dump(processed_metadata, f, indent=4, ensure_ascii=False)

                    print(f"Processed: {input_file} -> {output_file}")

                except Exception as e:
                    print(f"Failed to process {input_file}: {e}")
                    log_error(input_file, str(e))

if __name__ == "__main__":
    process_files(input_dir)
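For reference, a record that survives `process_json` keeps all of its original metadata fields and gains the two summary fields, with `filepaths` re-attached at the end. The values below are purely illustrative:

```python
# Illustrative shape of one output JSON written by process_files:
processed = {
    "title": "Brejeiro",                      # original metadata is preserved
    "composer": "Ernesto Nazareth",
    "summary_en": "An English summary generated by the model.",
    "summary_nen": {
        "language": "Japanese",
        "summary": "A summary in the randomly chosen non-English language.",
    },
    "filepaths": ["abc/brejeiro.abc", "mtf/brejeiro.mtf"],  # hypothetical paths
}
```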
process_data/utils/abc2xml.py
ADDED
The diff for this file is too large to render.
See raw diff
process_data/utils/pyparsing.py
ADDED
The diff for this file is too large to render.
See raw diff
process_data/utils/xml2abc.py
ADDED
@@ -0,0 +1,1582 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=latin-1
|
3 |
+
'''
|
4 |
+
Copyright (C) 2012-2018: W.G. Vree
|
5 |
+
Contributions: M. Tarenskeen, N. Liberg, Paul Villiger, Janus Meuris, Larry Myerscough,
|
6 |
+
Dick Jackson, Jan Wybren de Jong, Mark Zealey.
|
7 |
+
|
8 |
+
This program is free software; you can redistribute it and/or modify it under the terms of the
|
9 |
+
Lesser GNU General Public License as published by the Free Software Foundation;
|
10 |
+
|
11 |
+
This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
|
12 |
+
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
13 |
+
See the Lesser GNU General Public License for more details. <http://www.gnu.org/licenses/lgpl.html>.
|
14 |
+
'''
|
15 |
+
|
16 |
+
try: import xml.etree.cElementTree as E
|
17 |
+
except: import xml.etree.ElementTree as E
|
18 |
+
import os, sys, types, re, math
|
19 |
+
|
20 |
+
VERSION = 143
|
21 |
+
|
22 |
+
python3 = sys.version_info.major > 2
|
23 |
+
if python3:
|
24 |
+
tupletype = tuple
|
25 |
+
listtype = list
|
26 |
+
max_int = sys.maxsize
|
27 |
+
else:
|
28 |
+
tupletype = types.TupleType
|
29 |
+
listtype = types.ListType
|
30 |
+
max_int = sys.maxint
|
31 |
+
|
32 |
+
note_ornamentation_map = { # for notations/, modified from EasyABC
|
33 |
+
'ornaments/trill-mark': 'T',
|
34 |
+
'ornaments/mordent': 'M',
|
35 |
+
'ornaments/inverted-mordent': 'P',
|
36 |
+
'ornaments/turn': '!turn!',
|
37 |
+
'ornaments/inverted-turn': '!invertedturn!',
|
38 |
+
'technical/up-bow': 'u',
|
39 |
+
'technical/down-bow': 'v',
|
40 |
+
'technical/harmonic': '!open!',
|
41 |
+
'technical/open-string': '!open!',
|
42 |
+
'technical/stopped': '!plus!',
|
43 |
+
'technical/snap-pizzicato': '!snap!',
|
44 |
+
'technical/thumb-position': '!thumb!',
|
45 |
+
'articulations/accent': '!>!',
|
46 |
+
'articulations/strong-accent':'!^!',
|
47 |
+
'articulations/staccato': '.',
|
48 |
+
'articulations/staccatissimo':'!wedge!',
|
49 |
+
'articulations/scoop': '!slide!',
|
50 |
+
'fermata': '!fermata!',
|
51 |
+
'arpeggiate': '!arpeggio!',
|
52 |
+
'articulations/tenuto': '!tenuto!',
|
53 |
+
'articulations/staccatissimo':'!wedge!', # not sure whether this is the right translation
|
54 |
+
'articulations/spiccato': '!wedge!', # not sure whether this is the right translation
|
55 |
+
'articulations/breath-mark': '!breath!', # this may need to be tested to make sure it appears on the right side of the note
|
56 |
+
'articulations/detached-legato': '!tenuto!.',
|
57 |
+
}
|
58 |
+
|
59 |
+
dynamics_map = { # for direction/direction-type/dynamics/
|
60 |
+
'p': '!p!',
|
61 |
+
'pp': '!pp!',
|
62 |
+
'ppp': '!ppp!',
|
63 |
+
'pppp': '!pppp!',
|
64 |
+
'f': '!f!',
|
65 |
+
'ff': '!ff!',
|
66 |
+
'fff': '!fff!',
|
67 |
+
'ffff': '!ffff!',
|
68 |
+
'mp': '!mp!',
|
69 |
+
'mf': '!mf!',
|
70 |
+
'sfz': '!sfz!',
|
71 |
+
}
|
72 |
+
|
73 |
+
percSvg = '''%%beginsvg
|
74 |
+
<defs>
|
75 |
+
<text id="x" x="-3" y="0"></text>
|
76 |
+
<text id="x-" x="-3" y="0"></text>
|
77 |
+
<text id="x+" x="-3" y="0"></text>
|
78 |
+
<text id="normal" x="-3.7" y="0"></text>
|
79 |
+
<text id="normal-" x="-3.7" y="0"></text>
|
80 |
+
<text id="normal+" x="-3.7" y="0"></text>
|
81 |
+
<g id="circle-x"><text x="-3" y="0"></text><circle r="4" class="stroke"></circle></g>
|
82 |
+
<g id="circle-x-"><text x="-3" y="0"></text><circle r="4" class="stroke"></circle></g>
|
83 |
+
<path id="triangle" d="m-4 -3.2l4 6.4 4 -6.4z" class="stroke" style="stroke-width:1.4"></path>
|
84 |
+
<path id="triangle-" d="m-4 -3.2l4 6.4 4 -6.4z" class="stroke" style="stroke-width:1.4"></path>
|
85 |
+
<path id="triangle+" d="m-4 -3.2l4 6.4 4 -6.4z" class="stroke" style="fill:#000"></path>
|
86 |
+
<path id="square" d="m-3.5 3l0 -6.2 7.2 0 0 6.2z" class="stroke" style="stroke-width:1.4"></path>
|
87 |
+
<path id="square-" d="m-3.5 3l0 -6.2 7.2 0 0 6.2z" class="stroke" style="stroke-width:1.4"></path>
|
88 |
+
<path id="square+" d="m-3.5 3l0 -6.2 7.2 0 0 6.2z" class="stroke" style="fill:#000"></path>
|
89 |
+
<path id="diamond" d="m0 -3l4.2 3.2 -4.2 3.2 -4.2 -3.2z" class="stroke" style="stroke-width:1.4"></path>
|
90 |
+
<path id="diamond-" d="m0 -3l4.2 3.2 -4.2 3.2 -4.2 -3.2z" class="stroke" style="stroke-width:1.4"></path>
|
91 |
+
<path id="diamond+" d="m0 -3l4.2 3.2 -4.2 3.2 -4.2 -3.2z" class="stroke" style="fill:#000"></path>
|
92 |
+
</defs>
|
93 |
+
%%endsvg'''
|
94 |
+
|
95 |
+
tabSvg = '''%%beginsvg
|
96 |
+
<style type="text/css">
|
97 |
+
.bf {font-family:sans-serif; font-size:7px}
|
98 |
+
</style>
|
99 |
+
<defs>
|
100 |
+
<rect id="clr" x="-3" y="-1" width="6" height="5" fill="white"></rect>
|
101 |
+
<rect id="clr2" x="-3" y="-1" width="11" height="5" fill="white"></rect>'''
|
102 |
+
|
103 |
+
kopSvg = '<g id="kop%s" class="bf"><use xlink:href="#clr"></use><text x="-2" y="3">%s</text></g>\n'
|
104 |
+
kopSvg2 = '<g id="kop%s" class="bf"><use xlink:href="#clr2"></use><text x="-2" y="3">%s</text></g>\n'
|
105 |
+
|
106 |
+
def info (s, warn=1): sys.stderr.write ((warn and '-- ' or '') + s + '\n')
|
107 |
+
|
108 |
+
#-------------------
|
109 |
+
# data abstractions
|
110 |
+
#-------------------
|
111 |
+
class Measure:
|
112 |
+
def __init__ (s, p):
|
113 |
+
s.reset ()
|
114 |
+
s.ixp = p # part number
|
115 |
+
s.ixm = 0 # measure number
|
116 |
+
s.mdur = 0 # measure duration (nominal metre value in divisions)
|
117 |
+
s.divs = 0 # number of divisions per 1/4
|
118 |
+
s.mtr = 4,4 # meter
|
119 |
+
|
120 |
+
def reset (s): # reset each measure
|
121 |
+
s.attr = '' # measure signatures, tempo
|
122 |
+
s.lline = '' # left barline, but only holds ':' at start of repeat, otherwise empty
|
123 |
+
s.rline = '|' # right barline
|
124 |
+
s.lnum = '' # (left) volta number
|
125 |
+
|
126 |
+
class Note:
|
127 |
+
def __init__ (s, dur=0, n=None):
|
128 |
+
s.tijd = 0 # the time in XML division units
|
129 |
+
s.dur = dur # duration of a note in XML divisions
|
130 |
+
s.fact = None # time modification for tuplet notes (num, div)
|
131 |
+
s.tup = [''] # start(s) and/or stop(s) of tuplet
|
132 |
+
s.tupabc = '' # abc tuplet string to issue before note
|
133 |
+
s.beam = 0 # 1 = beamed
|
134 |
+
s.grace = 0 # 1 = grace note
|
135 |
+
s.before = [] # abc string that goes before the note/chord
|
136 |
+
s.after = '' # the same after the note/chord
|
137 |
+
s.ns = n and [n] or [] # notes in the chord
|
138 |
+
s.lyrs = {} # {number -> syllabe}
|
139 |
+
s.tab = None # (string number, fret number)
|
140 |
+
s.ntdec = '' # !string!, !courtesy!
|
141 |
+
|
142 |
+
class Elem:
|
143 |
+
def __init__ (s, string):
|
144 |
+
s.tijd = 0 # the time in XML division units
|
145 |
+
s.str = string # any abc string that is not a note
|
146 |
+
|
147 |
+
class Counter:
|
148 |
+
def inc (s, key, voice): s.counters [key][voice] = s.counters [key].get (voice, 0) + 1
|
149 |
+
def clear (s, vnums): # reset all counters
|
150 |
+
tups = list( zip (vnums.keys (), len (vnums) * [0]))
|
151 |
+
s.counters = {'note': dict (tups), 'nopr': dict (tups), 'nopt': dict (tups)}
|
152 |
+
def getv (s, key, voice): return s.counters[key][voice]
|
153 |
+
def prcnt (s, ip): # print summary of all non zero counters
|
154 |
+
for iv in s.counters ['note']:
|
155 |
+
if s.getv ('nopr', iv) != 0:
|
156 |
+
info ( 'part %d, voice %d has %d skipped non printable notes' % (ip, iv, s.getv ('nopr', iv)))
|
157 |
+
if s.getv ('nopt', iv) != 0:
|
158 |
+
info ( 'part %d, voice %d has %d notes without pitch' % (ip, iv, s.getv ('nopt', iv)))
|
159 |
+
if s.getv ('note', iv) == 0: # no real notes counted in this voice
|
160 |
+
info ( 'part %d, skipped empty voice %d' % (ip, iv))
|
161 |
+
|
162 |
+
class Music:
|
163 |
+
def __init__(s, options):
|
164 |
+
s.tijd = 0 # the current time
|
165 |
+
s.maxtime = 0 # maximum time in a measure
|
166 |
+
s.gMaten = [] # [voices,.. for all measures in a part]
|
167 |
+
s.gLyrics = [] # [{num: (abc_lyric_string, melis)},.. for all measures in a part]
|
168 |
+
s.vnums = {} # all used voice id's in a part (xml voice id's == numbers)
|
169 |
+
s.cnt = Counter () # global counter object
|
170 |
+
s.vceCnt = 1 # the global voice count over all parts
|
171 |
+
s.lastnote = None # the last real note record inserted in s.voices
|
172 |
+
s.bpl = options.b # the max number of bars per line when writing abc
|
173 |
+
s.cpl = options.n # the number of chars per line when writing abc
|
174 |
+
s.repbra = 0 # true if volta is used somewhere
|
175 |
+
s.nvlt = options.v # no volta on higher voice numbers
|
176 |
+
s.jscript = options.j # compatibility with javascript version
|
177 |
+
|
178 |
+
def initVoices (s, newPart=0):
|
179 |
+
s.vtimes, s.voices, s.lyrics = {}, {}, {}
|
180 |
+
for v in s.vnums:
|
181 |
+
s.vtimes [v] = 0 # {voice: the end time of the last item in each voice}
|
182 |
+
s.voices [v] = [] # {voice: [Note|Elem, ..]}
|
183 |
+
s.lyrics [v] = [] # {voice: [{num: syl}, ..]}
|
184 |
+
if newPart: s.cnt.clear (s.vnums) # clear counters once per part
|
185 |
+
|
186 |
+
def incTime (s, dt):
|
187 |
+
s.tijd += dt
|
188 |
+
if s.tijd < 0: s.tijd = 0 # erroneous <backup> element
|
189 |
+
if s.tijd > s.maxtime: s.maxtime = s.tijd
|
190 |
+
|
191 |
+
def appendElemCv (s, voices, elem):
|
192 |
+
for v in voices:
|
193 |
+
s.appendElem (v, elem) # insert element in all voices
|
194 |
+
|
195 |
+
def insertElem (s, v, elem): # insert at the start of voice v in the current measure
|
196 |
+
obj = Elem (elem)
|
197 |
+
obj.tijd = 0 # because voice is sorted later
|
198 |
+
s.voices [v].insert (0, obj)
|
199 |
+
|
200 |
+
def appendObj (s, v, obj, dur):
|
201 |
+
obj.tijd = s.tijd
|
202 |
+
s.voices [v].append (obj)
|
203 |
+
s.incTime (dur)
|
204 |
+
if s.tijd > s.vtimes[v]: s.vtimes[v] = s.tijd # don't update for inserted earlier items
|
205 |
+
|
206 |
+
def appendElem (s, v, elem, tel=0):
|
207 |
+
s.appendObj (v, Elem (elem), 0)
|
208 |
+
if tel: s.cnt.inc ('note', v) # count number of certain elements in each voice (in addition to notes)
|
209 |
+
|
210 |
+
def appendElemT (s, v, elem, tijd): # insert element at specified time
|
211 |
+
obj = Elem (elem)
|
212 |
+
obj.tijd = tijd
|
213 |
+
s.voices [v].append (obj)
|
214 |
+
|
215 |
+
def appendNote (s, v, note, noot):
|
216 |
+
note.ns.append (note.ntdec + noot)
|
217 |
+
s.appendObj (v, note, int (note.dur))
|
218 |
+
s.lastnote = note # remember last note/rest for later modifications (chord, grace)
|
219 |
+
if noot != 'z' and noot != 'x': # real notes and grace notes
|
220 |
+
s.cnt.inc ('note', v) # count number of real notes in each voice
|
221 |
+
if not note.grace: # for every real note
|
222 |
+
s.lyrics[v].append (note.lyrs) # even when it has no lyrics
|
223 |
+
|
224 |
+
def getLastRec (s, voice):
|
225 |
+
if s.gMaten: return s.gMaten[-1][voice][-1] # the last record in the last measure
|
226 |
+
return None # no previous records in the first measure
|
227 |
+
|
228 |
+
def getLastMelis (s, voice, num): # get melisma of last measure
|
229 |
+
if s.gLyrics:
|
230 |
+
lyrdict = s.gLyrics[-1][voice] # the previous lyrics dict in this voice
|
231 |
+
if num in lyrdict: return lyrdict[num][1] # lyrdict = num -> (lyric string, melisma)
|
232 |
+
return 0 # no previous lyrics in voice or line number
|
233 |
+
|
234 |
+
def addChord (s, note, noot): # careful: we assume that chord notes follow immediately
|
235 |
+
for d in note.before: # put all decorations before chord
|
236 |
+
if d not in s.lastnote.before:
|
237 |
+
s.lastnote.before += [d]
|
238 |
+
s.lastnote.ns.append (note.ntdec + noot)
|
239 |
+
|
240 |
+
def addBar (s, lbrk, m): # linebreak, measure data
|
241 |
+
if m.mdur and s.maxtime > m.mdur: info ('measure %d in part %d longer than metre' % (m.ixm+1, m.ixp+1))
|
242 |
+
s.tijd = s.maxtime # the time of the bar lines inserted here
|
243 |
+
for v in s.vnums:
|
244 |
+
if m.lline or m.lnum: # if left barline or left volta number
|
245 |
+
p = s.getLastRec (v) # get the previous barline record
|
246 |
+
if p: # in measure 1 no previous measure is available
|
247 |
+
x = p.str # p.str is the ABC barline string
|
248 |
+
if m.lline: # append begin of repeat, m.lline == ':'
|
249 |
+
x = (x + m.lline).replace (':|:','::').replace ('||','|')
|
250 |
+
if s.nvlt == 3: # add volta number only to lowest voice in part 0
|
251 |
+
if m.ixp + v == min (s.vnums): x += m.lnum
|
252 |
+
elif m.lnum: # new behaviour with I:repbra 0
|
253 |
+
x += m.lnum # add volta number(s) or text to all voices
|
254 |
+
s.repbra = 1 # signal occurrence of a volta
|
255 |
+
p.str = x # modify previous right barline
|
256 |
+
elif m.lline: # begin of new part and left repeat bar is required
|
257 |
+
s.insertElem (v, '|:')
|
258 |
+
if lbrk:
|
259 |
+
p = s.getLastRec (v) # get the previous barline record
|
260 |
+
if p: p.str += lbrk # insert linebreak char after the barlines+volta
|
261 |
+
if m.attr: # insert signatures at front of buffer
|
262 |
+
s.insertElem (v, '%s' % m.attr)
|
263 |
+
s.appendElem (v, ' %s' % m.rline) # insert current barline record at time maxtime
|
264 |
+
s.voices[v] = sortMeasure (s.voices[v], m) # make all times consistent
|
265 |
+
lyrs = s.lyrics[v] # [{number: sylabe}, .. for all notes]
|
266 |
+
lyrdict = {} # {number: (abc_lyric_string, melis)} for this voice
|
267 |
+
nums = [num for d in lyrs for num in d.keys ()] # the lyrics numbers in this measure
|
268 |
+
maxNums = max (nums + [0]) # the highest lyrics number in this measure
|
269 |
+
for i in range (maxNums, 0, -1):
|
270 |
+
xs = [syldict.get (i, '') for syldict in lyrs] # collect the syllabi with number i
|
271 |
+
melis = s.getLastMelis (v, i) # get melisma from last measure
|
272 |
+
lyrdict [i] = abcLyr (xs, melis)
|
273 |
+
s.lyrics[v] = lyrdict # {number: (abc_lyric_string, melis)} for this measure
|
274 |
+
mkBroken (s.voices[v])
|
275 |
+
s.gMaten.append (s.voices)
|
276 |
+
s.gLyrics.append (s.lyrics)
|
277 |
+
s.tijd = s.maxtime = 0
|
278 |
+
s.initVoices ()
|
279 |
+
|
280 |
+
def outVoices (s, divs, ip, isSib): # output all voices of part ip
|
281 |
+
vvmap = {} # xml voice number -> abc voice number (one part)
|
282 |
+
vnum_keys = list (s.vnums.keys ())
|
283 |
+
if s.jscript or isSib: vnum_keys.sort ()
|
284 |
+
lvc = min (vnum_keys or [1]) # lowest xml voice number of this part
|
285 |
+
for iv in vnum_keys:
|
286 |
+
if s.cnt.getv ('note', iv) == 0: # no real notes counted in this voice
|
287 |
+
continue # skip empty voices
|
288 |
+
if abcOut.denL: unitL = abcOut.denL # take the unit length from the -d option
|
289 |
+
else: unitL = compUnitLength (iv, s.gMaten, divs) # compute the best unit length for this voice
|
290 |
+
abcOut.cmpL.append (unitL) # remember for header output
|
291 |
+
vn, vl = [], {} # for voice iv: collect all notes to vn and all lyric lines to vl
|
292 |
+
for im in range (len (s.gMaten)):
|
293 |
+
measure = s.gMaten [im][iv]
|
294 |
+
vn.append (outVoice (measure, divs [im], im, ip, unitL))
|
295 |
+
checkMelismas (s.gLyrics, s.gMaten, im, iv)
|
296 |
+
for n, (lyrstr, melis) in s.gLyrics [im][iv].items ():
|
297 |
+
if n in vl:
|
298 |
+
while len (vl[n]) < im: vl[n].append ('') # fill in skipped measures
|
299 |
+
vl[n].append (lyrstr)
|
300 |
+
else:
|
301 |
+
vl[n] = im * [''] + [lyrstr] # must skip im measures
|
302 |
+
for n, lyrs in vl.items (): # fill up possibly empty lyric measures at the end
|
303 |
+
mis = len (vn) - len (lyrs)
|
304 |
+
lyrs += mis * ['']
|
305 |
+
abcOut.add ('V:%d' % s.vceCnt)
|
306 |
+
if s.repbra:
|
307 |
+
if s.nvlt == 1 and s.vceCnt > 1: abcOut.add ('I:repbra 0') # only volta on first voice
|
308 |
+
if s.nvlt == 2 and iv > lvc: abcOut.add ('I:repbra 0') # only volta on first voice of each part
|
309 |
+
if s.cpl > 0: s.bpl = 0 # option -n (max chars per line) overrules -b (max bars per line)
|
310 |
+
elif s.bpl == 0: s.cpl = 100 # the default: 100 chars per line
|
311 |
+
bn = 0 # count bars
|
312 |
+
while vn: # while still measures available
|
313 |
+
ib = 1
|
314 |
+
chunk = vn [0]
|
315 |
+
while ib < len (vn):
|
316 |
+
if s.cpl > 0 and len (chunk) + len (vn [ib]) >= s.cpl: break # line full (number of chars)
|
317 |
+
if s.bpl > 0 and ib >= s.bpl: break # line full (number of bars)
|
318 |
+
chunk += vn [ib]
|
319 |
+
ib += 1
|
320 |
+
bn += ib
|
321 |
+
abcOut.add (chunk + ' %%%d' % bn) # line with barnumer
|
322 |
+
del vn[:ib] # chop ib bars
|
323 |
+
lyrlines = sorted (vl.items ()) # order the numbered lyric lines for output
|
324 |
+
for n, lyrs in lyrlines:
|
325 |
+
abcOut.add ('w: ' + '|'.join (lyrs[:ib]) + '|')
|
326 |
+
del lyrs[:ib]
|
327 |
+
vvmap [iv] = s.vceCnt # xml voice number -> abc voice number
|
328 |
+
s.vceCnt += 1 # count voices over all parts
|
329 |
+
s.gMaten = [] # reset the follwing instance vars for each part
|
330 |
+
s.gLyrics = []
|
331 |
+
s.cnt.prcnt (ip+1) # print summary of skipped items in this part
|
332 |
+
return vvmap
|
333 |
+
|
334 |
+
class ABCoutput:
|
335 |
+
pagekeys = 'scale,pageheight,pagewidth,leftmargin,rightmargin,topmargin,botmargin'.split (',')
|
336 |
+
def __init__ (s, fnmext, pad, X, options):
|
337 |
+
s.fnmext = fnmext
|
338 |
+
s.outlist = [] # list of ABC strings
|
339 |
+
s.title = 'T:Title'
|
340 |
+
s.key = 'none'
|
341 |
+
s.clefs = {} # clefs for all abc-voices
|
342 |
+
s.mtr = 'none'
|
343 |
+
s.tempo = 0 # 0 -> no tempo field
|
344 |
+
s.tempo_units = (1,4) # note type of tempo direction
|
345 |
+
s.pad = pad # the output path or none
|
346 |
+
s.X = X + 1 # the abc tune number
|
347 |
+
s.denL = options.d # denominator of the unit length (L:) from -d option
|
348 |
+
s.volpan = int (options.m) # 0 -> no %%MIDI, 1 -> only program, 2 -> all %%MIDI
|
349 |
+
s.cmpL = [] # computed optimal unit length for all voices
|
350 |
+
s.jscript = options.j # compatibility with javascript version
|
351 |
+
s.tstep = options.t # translate percmap to voicemap
|
352 |
+
s.stemless = 0 # use U:s=!stemless!
|
353 |
+
s.shiftStem = options.s # shift note heads 3 units left
|
354 |
+
if pad:
|
355 |
+
_, base_name = os.path.split (fnmext)
|
356 |
+
s.outfile = open (os.path.join (pad, base_name), 'w')
|
357 |
+
else: s.outfile = sys.stdout
|
358 |
+
if s.jscript: s.X = 1 # always X:1 in javascript version
|
359 |
+
s.pageFmt = {}
|
360 |
+
for k in s.pagekeys: s.pageFmt [k] = None
|
361 |
+
if len (options.p) == 7:
|
362 |
+
for k, v in zip (s.pagekeys, options.p):
|
363 |
+
try: s.pageFmt [k] = float (v)
|
364 |
+
except: info ('illegal float %s for %s', (k, v)); continue
|
365 |
+
|
366 |
+
def add (s, str):
|
367 |
+
s.outlist.append (str + '\n') # collect all ABC output
|
368 |
+
|
369 |
+
def mkHeader (s, stfmap, partlist, midimap, vmpdct, koppen): # stfmap = [parts], part = [staves], stave = [voices]
|
370 |
+
accVce, accStf, staffs = [], [], stfmap[:] # staffs is consumed
|
371 |
+
for x in partlist: # collect partnames into accVce and staff groups into accStf
|
372 |
+
try: prgroupelem (x, ('', ''), '', stfmap, accVce, accStf)
|
373 |
+
except: info ('lousy musicxml: error in part-list')
|
374 |
+
staves = ' '.join (accStf)
|
375 |
+
clfnms = {}
|
376 |
+
for part, (partname, partabbrv) in zip (staffs, accVce):
|
377 |
+
if not part: continue # skip empty part
|
378 |
+
firstVoice = part[0][0] # the first voice number in this part
|
379 |
+
nm = partname.replace ('\n','\\n').replace ('.:','.').strip (':')
|
380 |
+
snm = partabbrv.replace ('\n','\\n').replace ('.:','.').strip (':')
|
381 |
+
clfnms [firstVoice] = (nm and 'nm="%s"' % nm or '') + (snm and ' snm="%s"' % snm or '')
|
382 |
+
hd = ['X:%d\n%s\n' % (s.X, s.title)]
|
383 |
+
for i, k in enumerate (s.pagekeys):
|
384 |
+
if s.jscript and k in ['pageheight','topmargin', 'botmargin']: continue
|
385 |
+
if s.pageFmt [k] != None: hd.append ('%%%%%s %.2f%s\n' % (k, s.pageFmt [k], i > 0 and 'cm' or ''))
|
386 |
+
if staves and len (accStf) > 1: hd.append ('%%score ' + staves + '\n')
|
387 |
+
tempo = s.tempo and 'Q:%d/%d=%s\n' % (s.tempo_units [0], s.tempo_units [1], s.tempo) or '' # default no tempo field
|
388 |
+
d = {} # determine the most frequently occurring unit length over all voices
|
389 |
+
for x in s.cmpL: d[x] = d.get (x, 0) + 1
|
390 |
+
if s.jscript: defLs = sorted (d.items (), key=lambda x: (-x[1], x[0])) # when tie (1) sort on key (0)
|
391 |
+
else: defLs = sorted (d.items (), key=lambda x: -x[1])
|
392 |
+
defL = s.denL and s.denL or defLs [0][0] # override default unit length with -d option
|
393 |
+
hd.append ('L:1/%d\n%sM:%s\n' % (defL, tempo, s.mtr))
|
394 |
+
hd.append ('K:%s\n' % s.key)
|
395 |
+
if s.stemless: hd.append ('U:s=!stemless!\n')
|
396 |
+
vxs = sorted (vmpdct.keys ())
|
397 |
+
for vx in vxs: hd.extend (vmpdct [vx])
|
398 |
+
s.dojef = 0 # translate percmap to voicemap
|
399 |
+
for vnum, clef in s.clefs.items ():
|
400 |
+
ch, prg, vol, pan = midimap [vnum-1][:4]
|
401 |
+
dmap = midimap [vnum - 1][4:] # map of abc percussion notes to midi notes
|
402 |
+
if dmap and 'perc' not in clef: clef = (clef + ' map=perc').strip ();
|
403 |
+
hd.append ('V:%d %s %s\n' % (vnum, clef, clfnms.get (vnum, '')))
|
404 |
+
if vnum in vmpdct:
|
405 |
+
hd.append ('%%%%voicemap tab%d\n' % vnum)
|
406 |
+
hd.append ('K:none\nM:none\n%%clef none\n%%staffscale 1.6\n%%flatbeams true\n%%stemdir down\n')
|
407 |
+
if 'perc' in clef: hd.append ('K:none\n'); # no key for a perc voice
|
408 |
+
if s.volpan > 1: # option -m 2 -> output all recognized midi commands when needed and present in xml
|
409 |
+
if ch > 0 and ch != vnum: hd.append ('%%%%MIDI channel %d\n' % ch)
|
410 |
+
if prg > 0: hd.append ('%%%%MIDI program %d\n' % (prg - 1))
|
411 |
+
if vol >= 0: hd.append ('%%%%MIDI control 7 %.0f\n' % vol) # volume == 0 is possible ...
|
412 |
+
if pan >= 0: hd.append ('%%%%MIDI control 10 %.0f\n' % pan)
|
413 |
+
elif s.volpan > 0: # default -> only output midi program command when present in xml
|
414 |
+
if dmap and ch > 0: hd.append ('%%%%MIDI channel %d\n' % ch) # also channel if percussion part
|
415 |
+
if prg > 0: hd.append ('%%%%MIDI program %d\n' % (prg - 1))
|
416 |
+
for abcNote, step, midiNote, notehead in dmap:
|
417 |
+
if not notehead: notehead = 'normal'
|
418 |
+
if abcMid (abcNote) != midiNote or abcNote != step:
|
419 |
+
if s.volpan > 0: hd.append ('%%%%MIDI drummap %s %s\n' % (abcNote, midiNote))
|
420 |
+
hd.append ('I:percmap %s %s %s %s\n' % (abcNote, step, midiNote, notehead))
|
421 |
+
s.dojef = s.tstep
|
422 |
+
if defL != s.cmpL [vnum-1]: # only if computed unit length different from header
|
423 |
+
hd.append ('L:1/%d\n' % s.cmpL [vnum-1])
|
424 |
+
s.outlist = hd + s.outlist
|
425 |
+
if koppen: # output SVG stuff needed for tablature
|
426 |
+
k1 = kopSvg.replace ('-2','-5') if s.shiftStem else kopSvg # shift note heads 3 units left
|
427 |
+
k2 = kopSvg2.replace ('-2','-5') if s.shiftStem else kopSvg2
|
428 |
+
tb = tabSvg.replace ('-3','-6') if s.shiftStem else tabSvg
|
429 |
+
ks = sorted (koppen.keys ()) # javascript compatibility
|
430 |
+
ks = [k2 % (k, k) if len (k) == 2 else k1 % (k, k) for k in ks]
|
431 |
+
tbs = [x.strip () + '\n' for x in tb.splitlines ()] # javascript compatibility; a list, so the concatenation below also works under Python 3
|
432 |
+
s.outlist = tbs + ks + ['</defs>\n%%endsvg\n'] + s.outlist
|
433 |
+
|
434 |
+
def writeall (s): # determine the required encoding of the entire ABC output
|
435 |
+
str = ''.join (s.outlist)
|
436 |
+
if s.dojef: str = perc2map (str)
|
437 |
+
if python3: s.outfile.write (str)
|
438 |
+
else: s.outfile.write (str.encode ('utf-8'))
|
439 |
+
if s.pad: s.outfile.close () # close each file with -o option
|
440 |
+
else: s.outfile.write ('\n') # add empty line between tunes on stdout
|
441 |
+
info ('%s written with %d voices' % (s.fnmext, len (s.clefs)), warn=0)
|
442 |
+
|
443 |
+
#----------------
|
444 |
+
# functions
|
445 |
+
#----------------
|
446 |
+
def abcLyr (xs, melis): # Convert list xs to abc lyrics.
|
447 |
+
if not ''.join (xs): return '', 0 # there are no lyrics in this measure
|
448 |
+
res = []
|
449 |
+
for x in xs: # xs has for every note a lyric syllable or an empty string
|
450 |
+
if x == '': # note without lyrics
|
451 |
+
if melis: x = '_' # set melisma
|
452 |
+
else: x = '*' # skip note
|
453 |
+
elif x.endswith ('_') and not x.endswith ('\_'): # start of new melisma
|
454 |
+
x = x.replace ('_', '') # remove and set melis boolean
|
455 |
+
melis = 1 # so next skips will become melisma
|
456 |
+
else: melis = 0 # melisma stops on first syllable
|
457 |
+
res.append (x)
|
458 |
+
return (' '.join (res), melis)
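# Worked example (invented input): abcLyr (['to', '', 'day_', ''], 0)
# returns ('to * day _', 1): the note without lyrics is skipped with '*',
# 'day_' starts a melisma, and the trailing empty entry becomes '_'.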
|
459 |
+
|
460 |
+
def simplify (a, b): # divide a and b by their greatest common divisor
|
461 |
+
x, y = a, b
|
462 |
+
while b: a, b = b, a % b
|
463 |
+
return x // a, y // a
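# For instance: simplify (6, 8) computes gcd == 2 via Euclid's algorithm
# and returns (3, 4); simplify (3, 4) returns (3, 4) unchanged.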
|
464 |
+
|
465 |
+
def abcdur (nx, divs, uL): # convert a musicXML duration to abc units with L:1/uL
|
466 |
+
if nx.dur == 0: return '' # when called for elements without duration
|
467 |
+
num, den = simplify (uL * nx.dur, divs * 4) # L=1/8 -> uL = 8 units
|
468 |
+
if nx.fact: # apply tuplet time modification
|
469 |
+
numfac, denfac = nx.fact
|
470 |
+
num, den = simplify (num * numfac, den * denfac)
|
471 |
+
if den > 64: # limit the denominator to a maximum of 64
|
472 |
+
x = float (num) / den; n = math.floor (x); # when just above an integer n
|
473 |
+
if x - n < 0.1 * x: num, den = n, 1; # round to n
|
474 |
+
num64 = 64. * num / den + 1.0e-15 # to get Python2 behaviour of round
|
475 |
+
num, den = simplify (int (round (num64)), 64)
|
476 |
+
if num == 1:
|
477 |
+
if den == 1: dabc = ''
|
478 |
+
elif den == 2: dabc = '/'
|
479 |
+
else: dabc = '/%d' % den
|
480 |
+
elif den == 1: dabc = '%d' % num
|
481 |
+
else: dabc = '%d/%d' % (num, den)
|
482 |
+
return dabc
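# Example with assumed values (not from the source): for divs == 24 (a
# quarter note spans 24 divisions) and uL == 8 (L:1/8), a quarter note
# without tuplet (nx.dur == 24) yields simplify (8*24, 24*4) == (2, 1) -> '2',
# while an eighth note (nx.dur == 12) yields (1, 1) -> '' (the unit length itself).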
|
483 |
+
|
484 |
+
def abcMid (note): # abc note -> midi pitch
|
485 |
+
r = re.search (r"([_^]*)([A-Ga-g])([',]*)", note)
|
486 |
+
if not r: return -1
|
487 |
+
acc, n, oct = r.groups ()
|
488 |
+
nUp = n.upper ()
|
489 |
+
p = 60 + [0,2,4,5,7,9,11]['CDEFGAB'.index (nUp)] + (12 if nUp != n else 0);
|
490 |
+
if acc: p += (1 if acc[0] == '^' else -1) * len (acc)
|
491 |
+
if oct: p += (12 if oct[0] == "'" else -12) * len (oct)
|
492 |
+
return p
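# Sample mappings for orientation: abcMid ('C') == 60 (middle C),
# abcMid ('c') == 72, abcMid ('_B,') == 58; unparsable input returns -1.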
|
493 |
+
|
494 |
+
def staffStep (ptc, o, clef, tstep):
|
495 |
+
ndif = 0
|
496 |
+
if 'stafflines=1' in clef: ndif += 4 # meaning of one line: E (xml) -> B (abc)
|
497 |
+
if not tstep and clef.startswith ('bass'): ndif += 12 # transpose bass -> treble (C3 -> A4)
|
498 |
+
if ndif: # diatonic transposition == addition modulo 7
|
499 |
+
nm7 = 'C,D,E,F,G,A,B'.split (',')
|
500 |
+
n = nm7.index (ptc) + ndif
|
501 |
+
ptc, o = nm7 [n % 7], o + n // 7
|
502 |
+
if o > 4: ptc = ptc.lower ()
|
503 |
+
if o > 5: ptc = ptc + (o-5) * "'"
|
504 |
+
if o < 4: ptc = ptc + (4-o) * ","
|
505 |
+
return ptc
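# Consistent with the bass-clef comment above: staffStep ('C', 3, 'bass', 0)
# shifts 12 diatonic steps and returns 'A', i.e. C3 on a bass staff is
# written as A4 in the treble-oriented abc output.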
|
506 |
+
|
507 |
+
def setKey (fifths, mode):
|
508 |
+
sharpness = ['Fb', 'Cb','Gb','Db','Ab','Eb','Bb','F','C','G','D','A', 'E', 'B', 'F#','C#','G#','D#','A#','E#','B#']
|
509 |
+
offTab = {'maj':8, 'ion':8, 'm':11, 'min':11, 'aeo':11, 'mix':9, 'dor':10, 'phr':12, 'lyd':7, 'loc':13, 'non':8}
|
510 |
+
mode = mode.lower ()[:3] # only first three chars, no case
|
511 |
+
key = sharpness [offTab [mode] + fifths] + (mode if offTab [mode] != 8 else '')
|
512 |
+
accs = ['F','C','G','D','A','E','B']
|
513 |
+
if fifths >= 0: msralts = dict (zip (accs[:fifths], fifths * [1]))
|
514 |
+
else: msralts = dict (zip (accs[fifths:], -fifths * [-1]))
|
515 |
+
return key, msralts
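# Two hand-checked examples: setKey (2, 'major') -> ('D', {'F': 1, 'C': 1})
# and setKey (-1, 'minor') -> ('Dmin', {'B': -1}).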
|
516 |
+
|
517 |
+
def insTup (ix, notes, fact): # read one nested tuplet
|
518 |
+
tupcnt = 0
|
519 |
+
nx = notes [ix]
|
520 |
+
if 'start' in nx.tup:
|
521 |
+
nx.tup.remove ('start') # do recursive calls when starts remain
|
522 |
+
tix = ix # index of first tuplet note
|
523 |
+
fn, fd = fact # xml time-mod of the higher level
|
524 |
+
fnum, fden = nx.fact # xml time-mod of the current level
|
525 |
+
tupfact = fnum//fn, fden//fd # abc time mod of this level
|
526 |
+
while ix < len (notes):
|
527 |
+
nx = notes [ix]
|
528 |
+
if isinstance (nx, Elem) or nx.grace:
|
529 |
+
ix += 1 # skip all non tuplet elements
|
530 |
+
continue
|
531 |
+
if 'start' in nx.tup: # more nested tuplets to start
|
532 |
+
ix, tupcntR = insTup (ix, notes, tupfact) # ix is on the stop note!
|
533 |
+
tupcnt += tupcntR
|
534 |
+
elif nx.fact:
|
535 |
+
tupcnt += 1 # count tuplet elements
|
536 |
+
if 'stop' in nx.tup:
|
537 |
+
nx.tup.remove ('stop')
|
538 |
+
break
|
539 |
+
if not nx.fact: # stop on first non tuplet note
|
540 |
+
ix = lastix # back to last tuplet note
|
541 |
+
break
|
542 |
+
lastix = ix
|
543 |
+
ix += 1
|
544 |
+
# put abc tuplet notation before the recursive ones
|
545 |
+
tup = (tupfact[0], tupfact[1], tupcnt)
|
546 |
+
if tup == (3, 2, 3): tupPrefix = '(3'
|
547 |
+
else: tupPrefix = '(%d:%d:%d' % tup
|
548 |
+
notes [tix].tupabc = tupPrefix + notes [tix].tupabc
|
549 |
+
return ix, tupcnt # ix is on the last tuplet note
|
550 |
+
|
551 |
+
def mkBroken (vs): # introduce broken rhythms (vs: one voice, one measure)
|
552 |
+
vs = [n for n in vs if isinstance (n, Note)]
|
553 |
+
i = 0
|
554 |
+
while i < len (vs) - 1:
|
555 |
+
n1, n2 = vs[i], vs[i+1] # scan all adjacent pairs
|
556 |
+
# skip if note in tuplet or has no duration or outside beam
|
557 |
+
if not n1.fact and not n2.fact and n1.dur > 0 and n2.beam:
|
558 |
+
if n1.dur * 3 == n2.dur:
|
559 |
+
n2.dur = (2 * n2.dur) // 3
|
560 |
+
n1.dur = n1.dur * 2
|
561 |
+
n1.after = '<' + n1.after
|
562 |
+
i += 1 # do not chain broken rhythms
|
563 |
+
elif n2.dur * 3 == n1.dur:
|
564 |
+
n1.dur = (2 * n1.dur) // 3
|
565 |
+
n2.dur = n2.dur * 2
|
566 |
+
n1.after = '>' + n1.after
|
567 |
+
i += 1 # do not chain broken rhythms
|
568 |
+
i += 1
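# Sketch of the effect (assuming 24 divisions per quarter): a sixteenth
# (dur 6) followed by a dotted eighth (dur 18) under one beam becomes two
# equal durations (12, 12) with '<' prepended to n1.after, the abc
# broken-rhythm shorthand for the same pair.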
|
569 |
+
|
570 |
+
def outVoice (measure, divs, im, ip, unitL): # note/elem objects of one measure in one voice
|
571 |
+
ix = 0
|
572 |
+
while ix < len (measure): # set all (nested) tuplet annotations
|
573 |
+
nx = measure [ix]
|
574 |
+
if isinstance (nx, Note) and nx.fact and not nx.grace:
|
575 |
+
ix, tupcnt = insTup (ix, measure, (1, 1)) # read one tuplet, insert annotation(s)
|
576 |
+
ix += 1
|
577 |
+
vs = []
|
578 |
+
for nx in measure:
|
579 |
+
if isinstance (nx, Note):
|
580 |
+
durstr = abcdur (nx, divs, unitL) # xml -> abc duration string
|
581 |
+
chord = len (nx.ns) > 1
|
582 |
+
cns = [nt[:-1] for nt in nx.ns if nt.endswith ('-')]
|
583 |
+
tie = ''
|
584 |
+
if chord and len (cns) == len (nx.ns): # all chord notes tied
|
585 |
+
nx.ns = cns # chord notes without tie
|
586 |
+
tie = '-' # one tie for whole chord
|
587 |
+
s = nx.tupabc + ''.join (nx.before)
|
588 |
+
if chord: s += '['
|
589 |
+
for nt in nx.ns: s += nt
|
590 |
+
if chord: s += ']' + tie
|
591 |
+
if s.endswith ('-'): s, tie = s[:-1], '-' # split off tie
|
592 |
+
s += durstr + tie # and put it back again
|
593 |
+
s += nx.after
|
594 |
+
nospace = nx.beam
|
595 |
+
else:
|
596 |
+
if isinstance (nx.str, listtype): nx.str = nx.str [0]
|
597 |
+
s = nx.str
|
598 |
+
nospace = 1
|
599 |
+
if nospace: vs.append (s)
|
600 |
+
else: vs.append (' ' + s)
|
601 |
+
vs = ''.join (vs) # ad hoc: remove multiple pedal directions
|
602 |
+
while vs.find ('!ped!!ped!') >= 0: vs = vs.replace ('!ped!!ped!','!ped!')
|
603 |
+
while vs.find ('!ped-up!!ped-up!') >= 0: vs = vs.replace ('!ped-up!!ped-up!','!ped-up!')
|
604 |
+
while vs.find ('!8va(!!8va)!') >= 0: vs = vs.replace ('!8va(!!8va)!','') # remove empty ottava's
|
605 |
+
return vs
|
606 |
+
|
607 |
+
def sortMeasure (voice, m):
|
608 |
+
voice.sort (key=lambda o: o.tijd) # sort on time
|
609 |
+
time = 0
|
610 |
+
v = []
|
611 |
+
rs = [] # holds rests in between notes
|
612 |
+
for i, nx in enumerate (voice): # establish sequentiality
|
613 |
+
if nx.tijd > time and chkbug (nx.tijd - time, m):
|
614 |
+
v.append (Note (nx.tijd - time, 'x')) # fill hole with invisible rest
|
615 |
+
rs.append (len (v) - 1)
|
616 |
+
if isinstance (nx, Elem):
|
617 |
+
if nx.tijd < time: nx.tijd = time # shift elems without duration to where they fit
|
618 |
+
v.append (nx)
|
619 |
+
time = nx.tijd
|
620 |
+
continue
|
621 |
+
if nx.tijd < time: # overlapping element
|
622 |
+
if nx.ns[0] == 'z': continue # discard overlapping rest
|
623 |
+
if v[-1].tijd <= nx.tijd: # we can do something
|
624 |
+
if v[-1].ns[0] == 'z': # shorten rest
|
625 |
+
v[-1].dur = nx.tijd - v[-1].tijd
|
626 |
+
if v[-1].dur == 0: del v[-1] # nothing left
|
627 |
+
info ('overlap in part %d, measure %d: rest shortened' % (m.ixp+1, m.ixm+1))
|
628 |
+
else: # make a chord of overlap
|
629 |
+
v[-1].ns += nx.ns
|
630 |
+
info ('overlap in part %d, measure %d: added chord' % (m.ixp+1, m.ixm+1))
|
631 |
+
nx.dur = (nx.tijd + nx.dur) - time # the remains
|
632 |
+
if nx.dur <= 0: continue # nothing left
|
633 |
+
nx.tijd = time # append remains
|
634 |
+
else: # give up
|
635 |
+
info ('overlapping notes in one voice! part %d, measure %d, note %s discarded' % (m.ixp+1, m.ixm+1, isinstance (nx, Note) and nx.ns or nx.str))
|
636 |
+
continue
|
637 |
+
v.append (nx)
|
638 |
+
if isinstance (nx, Note):
|
639 |
+
if nx.ns [0] in 'zx':
|
640 |
+
rs.append (len (v) - 1) # remember rests between notes
|
641 |
+
elif len (rs):
|
642 |
+
if nx.beam and not nx.grace: # copy beam into rests
|
643 |
+
for j in rs: v[j].beam = nx.beam
|
644 |
+
rs = [] # clear rests on each note
|
645 |
+
time = nx.tijd + nx.dur
|
646 |
+
# when a measure contains no elements and no forwards -> no incTime -> s.maxtime = 0 -> right barline
|
647 |
+
# is inserted at time == 0 (in addbar) and is only element in the voice when sortMeasure is called
|
648 |
+
if time == 0: info ('empty measure in part %d, measure %d, it should contain at least a rest to advance the time!' % (m.ixp+1, m.ixm+1))
|
649 |
+
return v
|
650 |
+
|
651 |
+
def getPartlist (ps): # correct part-list (from buggy xml-software)
|
652 |
+
xs = [] # the corrected part-list
|
653 |
+
e = [] # stack of opened part-groups
|
654 |
+
for x in list (ps): # insert missing stops, delete double starts
|
655 |
+
if x.tag == 'part-group':
|
656 |
+
num, type = x.get ('number'), x.get ('type')
|
657 |
+
if type == 'start':
|
658 |
+
if num in e: # missing stop: insert one
|
659 |
+
xs.append (E.Element ('part-group', number = num, type = 'stop'))
|
660 |
+
xs.append (x)
|
661 |
+
else: # normal start
|
662 |
+
xs.append (x)
|
663 |
+
e.append (num)
|
664 |
+
else:
|
665 |
+
if num in e: # normal stop
|
666 |
+
e.remove (num)
|
667 |
+
xs.append (x)
|
668 |
+
else: pass # double stop: skip it
|
669 |
+
else: xs.append (x)
|
670 |
+
for num in reversed (e): # fill missing stops at the end
|
671 |
+
xs.append (E.Element ('part-group', number = num, type = 'stop'))
|
672 |
+
return xs
|
673 |
+
|
674 |
+
def parseParts (xs, d, e): # -> [elems on current level], rest of xs
|
675 |
+
if not xs: return [],[]
|
676 |
+
x = xs.pop (0)
|
677 |
+
if x.tag == 'part-group':
|
678 |
+
num, type = x.get ('number'), x.get ('type')
|
679 |
+
if type == 'start': # go one level deeper
|
680 |
+
s = [x.findtext (n, '') for n in ['group-symbol','group-barline','group-name','group-abbreviation']]
|
681 |
+
d [num] = s # remember groupdata by group number
|
682 |
+
e.append (num) # make stack of open group numbers
|
683 |
+
elemsnext, rest1 = parseParts (xs, d, e) # parse one level deeper to next stop
|
684 |
+
elems, rest2 = parseParts (rest1, d, e) # parse the rest on this level
|
685 |
+
return [elemsnext] + elems, rest2
|
686 |
+
else: # stop: close level and return group-data
|
687 |
+
nums = e.pop () # last open group number in stack order
|
688 |
+
if xs and xs[0].get ('type') == 'stop': # two consecutive stops
|
689 |
+
if num != nums: # in the wrong order (temporary solution)
|
690 |
+
d[nums], d[num] = d[num], d[nums] # exchange values (only works for two stops!!!)
|
691 |
+
sym = d[num] # retrieve and return groupdata as last element of the group
|
692 |
+
return [sym], xs
|
693 |
+
else:
|
694 |
+
elems, rest = parseParts (xs, d, e) # parse remaining elements on current level
|
695 |
+
name = x.findtext ('part-name',''), x.findtext ('part-abbreviation','')
|
696 |
+
return [name] + elems, rest
|
697 |
+
|
698 |
+
def bracePart (part): # put a brace on multistaff part and group voices
|
699 |
+
if not part: return [] # empty part in the score
|
700 |
+
brace = []
|
701 |
+
for ivs in part:
|
702 |
+
if len (ivs) == 1: # stave with one voice
|
703 |
+
brace.append ('%s' % ivs[0])
|
704 |
+
else: # stave with multiple voices
|
705 |
+
brace += ['('] + ['%s' % iv for iv in ivs] + [')']
|
706 |
+
brace.append ('|')
|
707 |
+
del brace[-1] # no barline at the end
|
708 |
+
if len (part) > 1:
|
709 |
+
brace = ['{'] + brace + ['}']
|
710 |
+
return brace
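# For example (abc voice numbers assumed): bracePart ([[1], [2, 3]])
# returns ['{', '1', '|', '(', '2', '3', ')', '}']: one staff with voice 1,
# a barline, and a second staff whose voices 2 and 3 are grouped.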
|
711 |
+
|
712 |
+
def prgroupelem (x, gnm, bar, pmap, accVce, accStf): # collect partnames (accVce) and %%score map (accStf)
|
713 |
+
if type (x) == tupletype: # partname-tuple = (part-name, part-abbrev)
|
714 |
+
y = pmap.pop (0)
|
715 |
+
if gnm[0]: x = [n1 + ':' + n2 for n1, n2 in zip (gnm, x)] # put group-name before part-name
|
716 |
+
accVce.append (x)
|
717 |
+
accStf.extend (bracePart (y))
|
718 |
+
elif len (x) == 2 and type (x[0]) == tupletype: # misuse of group just to add extra name to stave
|
719 |
+
y = pmap.pop (0)
|
720 |
+
nms = [n1 + ':' + n2 for n1, n2 in zip (x[0], x[1][2:])] # x[0] = partname-tuple, x[1][2:] = groupname-tuple
|
721 |
+
accVce.append (nms)
|
722 |
+
accStf.extend (bracePart (y))
|
723 |
+
else:
|
724 |
+
prgrouplist (x, bar, pmap, accVce, accStf)
|
725 |
+
|
726 |
+
def prgrouplist (x, pbar, pmap, accVce, accStf): # collect partnames, scoremap for a part-group
|
727 |
+
sym, bar, gnm, gabbr = x[-1] # bracket symbol, continue barline, group name, group abbreviation
|
728 |
+
bar = bar == 'yes' or pbar # pbar -> the parent has bar
|
729 |
+
accStf.append (sym == 'brace' and '{' or '[')
|
730 |
+
for z in x[:-1]:
|
731 |
+
prgroupelem (z, (gnm, gabbr), bar, pmap, accVce, accStf)
|
732 |
+
if bar: accStf.append ('|')
|
733 |
+
if bar: del accStf [-1] # remove last one before close
|
734 |
+
accStf.append (sym == 'brace' and '}' or ']')
|
735 |
+
|
736 |
+
def compUnitLength (iv, maten, divs): # compute optimal unit length
|
737 |
+
uLmin, minLen = 0, max_int
|
738 |
+
for uL in [4,8,16]: # try 1/4, 1/8 and 1/16
|
739 |
+
vLen = 0 # total length of abc duration strings in this voice
|
740 |
+
for im, m in enumerate (maten): # all measures
|
741 |
+
for e in m[iv]: # all notes in voice iv
|
742 |
+
if isinstance (e, Elem) or e.dur == 0: continue # no real durations
|
743 |
+
vLen += len (abcdur (e, divs [im], uL)) # add len of duration string
|
744 |
+
if vLen < minLen: uLmin, minLen = uL, vLen # remember the smallest
|
745 |
+
return uLmin
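# Illustration (hypothetical voice): a voice of mostly eighth notes gives
# the shortest duration strings for uL == 8 (they become ''), so L:1/8 is
# chosen; a voice of quarters would select L:1/4 the same way.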
|
746 |
+
|
747 |
+
def doSyllable (syl):
|
748 |
+
txt = ''
|
749 |
+
for e in syl:
|
750 |
+
if e.tag == 'elision': txt += '~'
|
751 |
+
elif e.tag == 'text': # escape - and space characters
|
752 |
+
txt += (e.text or '').replace ('_', r'\_').replace ('-', r'\-').replace (' ', '~')
|
753 |
+
if not txt: return txt
|
754 |
+
if syl.findtext('syllabic') in ['begin', 'middle']: txt += '-'
|
755 |
+
if syl.find('extend') is not None: txt += '_'
|
756 |
+
return txt
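# Constructed example: a lyric with text 'lov', syllabic 'begin' and an
# extend child yields 'lov-_': '-' continues the word on the next note
# and '_' starts a melisma.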
|
757 |
+
|
758 |
+
def checkMelismas (lyrics, maten, im, iv):
|
759 |
+
if im == 0: return
|
760 |
+
maat = maten [im][iv] # notes of the current measure
|
761 |
+
curlyr = lyrics [im][iv] # lyrics dict of current measure
|
762 |
+
prvlyr = lyrics [im-1][iv] # lyrics dict of previous measure
|
763 |
+
for n, (lyrstr, melis) in prvlyr.items (): # all lyric numbers in the previous measure
|
764 |
+
if n not in curlyr and melis: # melisma required, but no lyrics present -> make one!
|
765 |
+
ms = getMelisma (maat) # get a melisma for the current measure
|
766 |
+
if ms: curlyr [n] = (ms, 0) # set melisma as the n-th lyrics of the current measure
|
767 |
+
|
768 |
+
def getMelisma (maat): # get melisma from notes in maat
|
769 |
+
ms = []
|
770 |
+
for note in maat: # every note should get an underscore
|
771 |
+
if not isinstance (note, Note): continue # skip Elem's
|
772 |
+
if note.grace: continue # skip grace notes
|
773 |
+
if note.ns [0] in 'zx': break # stop on first rest
|
774 |
+
ms.append ('_')
|
775 |
+
return ' '.join (ms)
|
776 |
+
|
777 |
+
def perc2map (abcIn):
|
778 |
+
fillmap = {'diamond':1, 'triangle':1, 'square':1, 'normal':1};
|
779 |
+
abc = [x.strip () for x in percSvg.splitlines ()] # a list, so the appends/extends below also work under Python 3
|
780 |
+
id='default'
|
781 |
+
maps = {'default': []};
|
782 |
+
dmaps = {'default': []}
|
783 |
+
r1 = re.compile (r'V:\s*(\S+)')
|
784 |
+
ls = abcIn.splitlines ()
|
785 |
+
for x in ls:
|
786 |
+
if 'I:percmap' in x:
|
787 |
+
noot, step, midi, kop = map (lambda x: x.strip (), x.split ()[1:])
|
788 |
+
if kop in fillmap: kop = kop + '+' + ',' + kop
|
789 |
+
x = '%%%%map perc%s %s print=%s midi=%s heads=%s' % (id, noot, step, midi, kop)
|
790 |
+
maps [id].append (x)
|
791 |
+
if '%%MIDI' in x: dmaps [id].append (x)
|
792 |
+
if 'V:' in x:
|
793 |
+
r = r1.match (x)
|
794 |
+
if r:
|
795 |
+
id = r.group (1);
|
796 |
+
if id not in maps: maps [id] = []; dmaps [id] = []
|
797 |
+
ids = sorted (maps.keys ())
|
798 |
+
for id in ids: abc += maps [id]
|
799 |
+
id='default'
|
800 |
+
for x in ls:
|
801 |
+
if 'I:percmap' in x: continue
|
802 |
+
if '%%MIDI' in x: continue
|
803 |
+
if 'V:' in x or 'K:' in x:
|
804 |
+
r = r1.match (x)
|
805 |
+
if r: id = r.group (1)
|
806 |
+
abc.append (x)
|
807 |
+
if id in dmaps and len (dmaps [id]) > 0: abc.extend (dmaps [id]); del dmaps [id]
|
808 |
+
if 'perc' in x and 'map=' not in x: x += ' map=perc';
|
809 |
+
if 'map=perc' in x and len (maps [id]) > 0: abc.append ('%%voicemap perc' + id);
|
810 |
+
if 'map=off' in x: abc.append ('%%voicemap');
|
811 |
+
else:
|
812 |
+
abc.append (x)
|
813 |
+
return '\n'.join (abc) + '\n'
|
814 |
+
|
815 |
+
def addoct (ptc, o): # xml staff step, xml octave number
|
816 |
+
p = ptc
|
817 |
+
if o > 4: p = ptc.lower ()
|
818 |
+
if o > 5: p = p + (o-5) * "'"
|
819 |
+
if o < 4: p = p + (4-o) * ","
|
820 |
+
return p # abc pitch == abc note without accidental
|
821 |
+
|
822 |
+
def chkbug (dt, m):
|
823 |
+
if dt > m.divs / 16: return 1 # duration should be > 1/64 note
|
824 |
+
info ('MuseScore bug: incorrect duration, smaller than 1/64! in measure %d, part %d' % (m.ixm, m.ixp))
|
825 |
+
return 0
|
826 |
+
|
827 |
+
#----------------
|
828 |
+
# parser
|
829 |
+
#----------------
|
830 |
+
class Parser:
|
831 |
+
note_alts = [ # 3 alternative notations of the same note for tablature mapping
|
832 |
+
[x.strip () for x in '=C, ^C, =D, ^D, =E, =F, ^F, =G, ^G, =A, ^A, =B'.split (',')],
|
833 |
+
[x.strip () for x in '^B, _D,^^C, _E, _F, ^E, _G,^^F, _A,^^G, _B, _C'.split (',')],
|
834 |
+
[x.strip () for x in '__D,^^B,__E,__F,^^D,__G,^^E,__A,_/A,__B,__C,^^A'.split (',')] ]
|
835 |
+
step_map = {'C':0,'D':2,'E':4,'F':5,'G':7,'A':9,'B':11}
|
836 |
+
def __init__ (s, options):
|
837 |
+
# unfold repeats, number of chars per line, credit filter level, volta option
|
838 |
+
s.slurBuf = {} # dict of open slurs keyed by slur number
|
839 |
+
s.dirStk = {} # {direction-type + number -> (type, voice | time)} dict for proper closing
|
840 |
+
s.ingrace = 0 # marks a sequence of grace notes
|
841 |
+
s.msc = Music (options) # global music data abstraction
|
842 |
+
s.unfold = options.u # turn unfolding repeats on
|
843 |
+
s.ctf = options.c # credit text filter level
|
844 |
+
s.gStfMap = [] # [[abc voice numbers] for all parts]
|
845 |
+
s.midiMap = [] # midi-settings for each abc voice, in order
|
846 |
+
s.drumInst = {} # inst_id -> midi pitch for channel 10 notes
|
847 |
+
s.drumNotes = {} # (xml voice, abc note) -> (midi note, note head)
|
848 |
+
s.instMid = [] # [{inst id -> midi-settings} for all parts]
|
849 |
+
s.midDflt = [-1,-1,-1,-91] # default midi settings for channel, program, volume, panning
|
850 |
+
s.msralts = {} # xml-notenames (without octave) with accidentals from the key
|
851 |
+
s.curalts = {} # abc-notenames (with voice number) with passing accidentals
|
852 |
+
s.stfMap = {} # xml staff number -> [xml voice number]
|
853 |
+
s.vce2stf = {} # xml voice number -> allocated staff number
|
854 |
+
s.clefMap = {} # xml staff number -> abc clef (for header only)
|
855 |
+
s.curClef = {} # xml staff number -> current abc clef
|
856 |
+
s.stemDir = {} # xml voice number -> current stem direction
|
857 |
+
s.clefOct = {} # xml staff number -> current clef-octave-change
|
858 |
+
s.curStf = {} # xml voice number -> current xml staff number
|
859 |
+
s.nolbrk = options.x; # generate no linebreaks ($)
|
860 |
+
s.jscript = options.j # compatibility with javascript version
|
861 |
+
s.ornaments = sorted (note_ornamentation_map.items ())
|
862 |
+
s.doPageFmt = len (options.p) == 1 # translate xml page format
|
863 |
+
s.tstep = options.t # clef determines step on staff (percussion)
|
864 |
+
s.dirtov1 = options.v1 # all directions to first voice of staff
|
865 |
+
s.ped = options.ped # render pedal directions
|
866 |
+
s.wstems = options.stm # translate stem elements
|
867 |
+
s.pedVce = None # voice for pedal directions
|
868 |
+
s.repeat_str = {} # staff number -> [measure number, repeat-text]
|
869 |
+
s.tabVceMap = {} # abc voice num -> [%%map ...] for tab voices
|
870 |
+
s.koppen = {} # noteheads needed for %%map
|
871 |
+
|
872 |
+
def matchSlur (s, type2, n, v2, note2, grace, stopgrace): # match slur number n in voice v2, add abc code to before/after
|
873 |
+
if type2 not in ['start', 'stop']: return # slur type continue has no abc equivalent
|
874 |
+
if n == None: n = '1'
|
875 |
+
if n in s.slurBuf:
|
876 |
+
type1, v1, note1, grace1 = s.slurBuf [n]
|
877 |
+
if type2 != type1: # slur complete, now check the voice
|
878 |
+
if v2 == v1: # begins and ends in the same voice: keep it
|
879 |
+
if type1 == 'start' and (not grace1 or not stopgrace): # normal slur: start before stop and no grace slur
|
880 |
+
note1.before = ['('] + note1.before # keep left-right order!
|
881 |
+
note2.after += ')'
|
882 |
+
# no else: don't bother with reversed stave spanning slurs
|
883 |
+
del s.slurBuf [n] # slur finished, remove from stack
|
884 |
+
else: # double definition, keep the last
|
885 |
+
info ('double slur numbers %s-%s in part %d, measure %d, voice %d note %s, first discarded' % (type2, n, s.msr.ixp+1, s.msr.ixm+1, v2, note2.ns))
|
886 |
+
s.slurBuf [n] = (type2, v2, note2, grace)
|
887 |
+
else: # unmatched slur, put in dict
|
888 |
+
s.slurBuf [n] = (type2, v2, note2, grace)
|
889 |
+
|
890 |
+
def doNotations (s, note, nttn, isTab):
|
891 |
+
for key, val in s.ornaments:
|
892 |
+
if nttn.find (key) != None: note.before += [val] # just concat all ornaments
|
893 |
+
trem = nttn.find ('ornaments/tremolo')
|
894 |
+
if trem != None:
|
895 |
+
type = trem.get ('type')
|
896 |
+
if type == 'single':
|
897 |
+
note.before.insert (0, '!%s!' % (int (trem.text) * '/'))
|
898 |
+
else:
|
899 |
+
note.fact = None # no time modification in ABC
|
900 |
+
if s.tstep: # abc2svg version
|
901 |
+
if type == 'stop': note.before.insert (0, '!trem%s!' % trem.text);
|
902 |
+
else: # abc2xml version
|
903 |
+
if type == 'start': note.before.insert (0, '!%s-!' % (int (trem.text) * '/'));
|
904 |
+
fingering = nttn.findall ('technical/fingering')
|
905 |
+
for finger in fingering: # handle multiple finger annotations
|
906 |
+
if not isTab: note.before += ['!%s!' % finger.text] # fingering goes before chord (addChord)
|
907 |
+
snaar = nttn.find ('technical/string')
|
908 |
+
if snaar != None and isTab:
|
909 |
+
if s.tstep:
|
910 |
+
fret = nttn.find ('technical/fret')
|
911 |
+
if fret != None: note.tab = (snaar.text, fret.text)
|
912 |
+
else:
|
913 |
+
deco = '!%s!' % snaar.text # no double string decos (bug in musescore)
|
914 |
+
if deco not in note.ntdec: note.ntdec += deco
|
915 |
+
wvln = nttn.find ('ornaments/wavy-line')
|
916 |
+
if wvln != None:
|
917 |
+
if wvln.get ('type') == 'start': note.before = ['!trill(!'] + note.before # keep left-right order!
|
918 |
+
elif wvln.get ('type') == 'stop': note.before = ['!trill)!'] + note.before
|
919 |
+
glis = nttn.find ('glissando')
|
920 |
+
if glis == None: glis = nttn.find ('slide') # treat slide as glissando
|
921 |
+
if glis != None:
|
922 |
+
lt = '~' if glis.get ('line-type') =='wavy' else '-'
|
923 |
+
if glis.get ('type') == 'start': note.before = ['!%s(!' % lt] + note.before # keep left-right order!
|
924 |
+
elif glis.get ('type') == 'stop': note.before = ['!%s)!' % lt] + note.before
|
925 |
+
|
926 |
+
def tabnote (s, alt, ptc, oct, v, ntrec):
|
927 |
+
p = s.step_map [ptc] + int (alt or '0') # p in -2 .. 13
|
928 |
+
if p > 11: oct += 1 # octave correction
|
929 |
+
if p < 0: oct -= 1
|
930 |
+
p = p % 12 # remap p into 0..11
|
931 |
+
snaar_nw, fret_nw = ntrec.tab # the computed/annotated allocation of nt
|
932 |
+
for i in range (4): # support same note on 4 strings
|
933 |
+
na = s.note_alts [i % 3] [p] # get alternative representation of same note
|
934 |
+
o = oct
|
935 |
+
if na in ['^B', '^^B']: o -= 1 # because in adjacent octave
|
936 |
+
if na in ['_C', '__C']: o += 1
|
937 |
+
if '/' in na or i == 3: o = 9 # emergency notation for 4th string case
|
938 |
+
nt = addoct (na, o)
|
939 |
+
snaar, fret = s.tabmap.get ((v, nt), ('', '')) # the current allocation of nt
|
940 |
+
if not snaar: break # note not yet allocated
|
941 |
+
if snaar_nw == snaar: return nt # use present allocation
|
942 |
+
if i == 3: # new allocation needed but none is free
|
943 |
+
fmt = 'rejected: voice %d note %3s string %s fret %2s remains: string %s fret %s'
|
944 |
+
info (fmt % (v, nt, snaar_nw, fret_nw, snaar, fret), 1)
|
945 |
+
ntrec.tab = (snaar, fret)
|
946 |
+
s.tabmap [v, nt] = ntrec.tab # for tablature map (voice, note) -> (string, fret)
|
947 |
+
return nt # ABC code always in key C (with midi pitch alterations)
|
948 |
+
|
949 |
+
def ntAbc (s, ptc, oct, note, v, ntrec, isTab): # pitch, octave -> abc notation
|
950 |
+
acc2alt = {'double-flat':-2,'flat-flat':-2,'flat':-1,'natural':0,'sharp':1,'sharp-sharp':2,'double-sharp':2}
|
951 |
+
oct += s.clefOct.get (s.curStf [v], 0) # minus clef-octave-change value
|
952 |
+
acc = note.findtext ('accidental') # should be the notated accidental
|
953 |
+
alt = note.findtext ('pitch/alter') # pitch alteration (midi)
|
954 |
+
if ntrec.tab: return s.tabnote (alt, ptc, oct, v, ntrec) # implies s.tstep is true (options.t was given)
|
955 |
+
elif isTab and s.tstep:
|
956 |
+
nt = ['__','_','','^','^^'][int (alt or '0') + 2] + addoct (ptc, oct)
|
957 |
+
info ('no string notation found for note %s in voice %d' % (nt, v), 1)
|
958 |
+
p = addoct (ptc, oct)
|
959 |
+
if alt == None and s.msralts.get (ptc, 0): alt = 0 # no alt but key implies alt -> natural!!
|
960 |
+
if alt == None and (p, v) in s.curalts: alt = 0 # no alt but previous note had one -> natural!!
|
961 |
+
if acc == None and alt == None: return p # no acc, no alt
|
962 |
+
elif acc != None:
|
963 |
+
alt = acc2alt [acc] # acc takes precedence over the pitch here!
|
964 |
+
else: # now see if we really must add an accidental
|
965 |
+
alt = int (float (alt))
|
966 |
+
if (p, v) in s.curalts: # the note in this voice has been altered before
|
967 |
+
if alt == s.curalts [(p, v)]: return p # alteration still the same
|
968 |
+
elif alt == s.msralts.get (ptc, 0): return p # alteration implied by the key
|
969 |
+
tieElms = note.findall ('tie') + note.findall ('notations/tied') # in xml we have separate notated ties and playback ties
|
970 |
+
if 'stop' in [e.get ('type') for e in tieElms]: return p # don't alter tied notes
|
971 |
+
info ('accidental %d added in part %d, measure %d, voice %d note %s' % (alt, s.msr.ixp+1, s.msr.ixm+1, v+1, p))
|
972 |
+
s.curalts [(p, v)] = alt
|
973 |
+
p = ['__','_','=','^','^^'][alt+2] + p # and finally ... prepend the accidental
|
974 |
+
return p
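# Worked case (constructed): in D major (msralts == {'F': 1, 'C': 1}) a
# note F without <alter> and without <accidental> gets alt == 0, which
# differs from the sharp implied by the key, so an explicit natural is
# emitted: the function returns '=F'.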
|
975 |
+
|
976 |
+
def doNote (s, n): # parse a musicXML note tag
|
977 |
+
note = Note ()
|
978 |
+
v = int (n.findtext ('voice', '1'))
|
979 |
+
if s.isSib: v += 100 * int (n.findtext ('staff', '1')) # repair bug in Sibelius
|
980 |
+
chord = n.find ('chord') != None
|
981 |
+
p = n.findtext ('pitch/step') or n.findtext ('unpitched/display-step')
|
982 |
+
o = n.findtext ('pitch/octave') or n.findtext ('unpitched/display-octave')
|
983 |
+
r = n.find ('rest')
|
984 |
+
numer = n.findtext ('time-modification/actual-notes')
|
985 |
+
if numer:
|
986 |
+
denom = n.findtext ('time-modification/normal-notes')
|
987 |
+
note.fact = (int (numer), int (denom))
|
988 |
+
note.tup = [x.get ('type') for x in n.findall ('notations/tuplet')]
|
989 |
+
dur = n.findtext ('duration')
|
990 |
+
grc = n.find ('grace')
|
991 |
+
note.grace = grc != None
|
992 |
+
note.before, note.after = [], '' # strings with ABC stuff that goes before or after a note/chord
|
993 |
+
if note.grace and not s.ingrace: # open a grace sequence
|
994 |
+
s.ingrace = 1
|
995 |
+
note.before = ['{']
|
996 |
+
if grc.get ('slash') == 'yes': note.before += ['/'] # acciaccatura
|
997 |
+
stopgrace = not note.grace and s.ingrace
|
998 |
+
if stopgrace: # close the grace sequence
|
999 |
+
s.ingrace = 0
|
1000 |
+
s.msc.lastnote.after += '}' # close grace on lastnote.after
|
1001 |
+
if dur == None or note.grace: dur = 0
|
1002 |
+
if r == None and n.get ('print-object') == 'no':
|
1003 |
+
if chord: return
|
1004 |
+
r = 1 # turn invisible notes (that advance the time) into invisible rests
|
1005 |
+
note.dur = int (dur)
|
1006 |
+
if r == None and (not p or not o): # not a rest and no pitch
|
1007 |
+
s.msc.cnt.inc ('nopt', v) # count unpitched notes
|
1008 |
+
o, p = 5,'E' # make it an E5 ??
|
1009 |
+
isTab = s.curClef and s.curClef.get (s.curStf [v], '').startswith ('tab')
|
1010 |
+
nttn = n.find ('notations') # add ornaments
|
1011 |
+
if nttn != None: s.doNotations (note, nttn, isTab)
|
1012 |
+
e = n.find ('stem') if r == None else None # no !stemless! before rest
|
1013 |
+
if e != None and e.text == 'none' and (not isTab or v in s.hasStems or s.tstep):
|
1014 |
+
note.before += ['s']; abcOut.stemless = 1;
|
1015 |
+
e = n.find ('accidental')
|
1016 |
+
if e != None and e.get ('parentheses') == 'yes': note.ntdec += '!courtesy!'
|
1017 |
+
if r != None: noot = 'x' if n.get ('print-object') == 'no' or isTab else 'z'
|
1018 |
+
else: noot = s.ntAbc (p, int (o), n, v, note, isTab)
|
1019 |
+
if n.find ('unpitched') != None:
|
1020 |
+
clef = s.curClef [s.curStf [v]] # the current clef for this voice
|
1021 |
+
step = staffStep (p, int (o), clef, s.tstep) # (clef independent) step value of note on the staff
|
1022 |
+
instr = n.find ('instrument')
|
1023 |
+
instId = instr.get ('id') if instr != None else 'dummyId'
|
1024 |
+
midi = s.drumInst.get (instId, abcMid (noot))
|
1025 |
+
nh = n.findtext ('notehead', '').replace (' ','-') # replace spaces in xml notehead names for percmap
|
1026 |
+
if nh == 'x': noot = '^' + noot.replace ('^','').replace ('_','')
|
1027 |
+
if nh in ['circle-x','diamond','triangle']: noot = '_' + noot.replace ('^','').replace ('_','')
|
1028 |
+
if nh and n.find ('notehead').get ('filled','') == 'yes': nh += '+'
|
1029 |
+
if nh and n.find ('notehead').get ('filled','') == 'no': nh += '-'
|
1030 |
+
s.drumNotes [(v, noot)] = (step, midi, nh) # keep data for percussion map
|
1031 |
+
tieElms = n.findall ('tie') + n.findall ('notations/tied') # in xml we have separate notated ties and playback ties
|
1032 |
+
if 'start' in [e.get ('type') for e in tieElms]: # n can have stop and start tie
|
1033 |
+
noot = noot + '-'
|
1034 |
+
note.beam = sum ([1 for b in n.findall('beam') if b.text in ['continue', 'end']]) + int (note.grace)
|
1035 |
+
lyrlast = 0; rsib = re.compile (r'^.*verse')
|
1036 |
+
for e in n.findall ('lyric'):
|
1037 |
+
lyrnum = int (rsib.sub ('', e.get ('number', '1'))) # also do Sibelius numbers
|
1038 |
+
if lyrnum == 0: lyrnum = lyrlast + 1 # and correct Sibelius bugs
|
1039 |
+
else: lyrlast = lyrnum
|
1040 |
+
note.lyrs [lyrnum] = doSyllable (e)
|
1041 |
+
stemdir = n.findtext ('stem')
|
1042 |
+
if s.wstems and (stemdir == 'up' or stemdir == 'down'):
|
1043 |
+
if stemdir != s.stemDir.get (v, ''):
|
1044 |
+
s.stemDir [v] = stemdir
|
1045 |
+
s.msc.appendElem (v, '[I:stemdir %s]' % stemdir)
|
1046 |
+
if chord: s.msc.addChord (note, noot)
|
1047 |
+
else:
|
1048 |
+
xmlstaff = int (n.findtext ('staff', '1'))
|
1049 |
+
if s.curStf [v] != xmlstaff: # the note should go to another staff
|
1050 |
+
dstaff = xmlstaff - s.curStf [v] # relative new staff number
|
1051 |
+
s.curStf [v] = xmlstaff # remember the new staff for this voice
|
1052 |
+
s.msc.appendElem (v, '[I:staff %+d]' % dstaff) # insert a move before the note
|
1053 |
+
s.msc.appendNote (v, note, noot)
|
1054 |
+
for slur in n.findall ('notations/slur'): # s.msc.lastnote points to the last real note/chord inserted above
|
1055 |
+
s.matchSlur (slur.get ('type'), slur.get ('number'), v, s.msc.lastnote, note.grace, stopgrace) # match slur definitions
|
1056 |
+
|
1057 |
+
def doAttr (s, e): # parse a musicXML attribute tag
|
1058 |
+
teken = {'C1':'alto1','C2':'alto2','C3':'alto','C4':'tenor','F4':'bass','F3':'bass3','G2':'treble','TAB':'tab','percussion':'perc'}
|
1059 |
+
dvstxt = e.findtext ('divisions')
|
1060 |
+
if dvstxt: s.msr.divs = int (dvstxt)
|
1061 |
+
steps = int (e.findtext ('transpose/chromatic', '0')) # for transposing instrument
|
1062 |
+
fifths = e.findtext ('key/fifths')
|
1063 |
+
first = s.msc.tijd == 0 and s.msr.ixm == 0 # first attributes in first measure
|
1064 |
+
if fifths:
|
1065 |
+
key, s.msralts = setKey (int (fifths), e.findtext ('key/mode','major'))
|
1066 |
+
if first and not steps and abcOut.key == 'none':
|
1067 |
+
abcOut.key = key # first measure -> header, if not transposing instrument or percussion part!
|
1068 |
+
elif key != abcOut.key or not first:
|
1069 |
+
s.msr.attr += '[K:%s]' % key # otherwise -> voice
|
1070 |
+
beats = e.findtext ('time/beats')
|
1071 |
+
if beats:
|
1072 |
+
unit = e.findtext ('time/beat-type')
|
1073 |
+
mtr = beats + '/' + unit
|
1074 |
+
if first: abcOut.mtr = mtr # first measure -> header
|
1075 |
+
else: s.msr.attr += '[M:%s]' % mtr # otherwise -> voice
|
1076 |
+
s.msr.mtr = int (beats), int (unit)
|
1077 |
+
s.msr.mdur = (s.msr.divs * s.msr.mtr[0] * 4) // s.msr.mtr[1] # duration of measure in xml-divisions
|
1078 |
+
for ms in e.findall('measure-style'):
|
1079 |
+
n = int (ms.get ('number', '1')) # staff number
|
1080 |
+
voices = s.stfMap [n] # all voices of staff n
|
1081 |
+
for mr in ms.findall('measure-repeat'):
|
1082 |
+
ty = mr.get('type')
|
1083 |
+
if ty == 'start': # remember start measure number and text for each staff
|
1084 |
+
s.repeat_str [n] = [s.msr.ixm, mr.text]
|
1085 |
+
for v in voices: # insert repeat into all voices, value will be overwritten at stop
|
1086 |
+
s.msc.insertElem (v, s.repeat_str [n])
|
1087 |
+
elif ty == 'stop': # calculate repeat measure count for this staff n
|
1088 |
+
start_ix, text_ = s.repeat_str [n]
|
1089 |
+
repeat_count = s.msr.ixm - start_ix
|
1090 |
+
if text_:
|
1091 |
+
mid_str = "%s " % text_
|
1092 |
+
repeat_count //= int (text_) # integer division, as in the Python 2 original
|
1093 |
+
else:
|
1094 |
+
mid_str = "" # overwrite repeat with final string
|
1095 |
+
s.repeat_str [n][0] = '[I:repeat %s%d]' % (mid_str, repeat_count)
|
1096 |
+
del s.repeat_str [n] # remove closed repeats
|
1097 |
+
toct = e.findtext ('transpose/octave-change', '')
|
1098 |
+
if toct: steps += 12 * int (toct) # extra transposition of toct octaves
|
1099 |
+
for clef in e.findall ('clef'): # a part can have multiple staves
|
1100 |
+
n = int (clef.get ('number', '1')) # local staff number for this clef
|
1101 |
+
sgn = clef.findtext ('sign')
|
1102 |
+
line = clef.findtext ('line', '') if sgn not in ['percussion','TAB'] else ''
|
1103 |
+
cs = teken.get (sgn + line, '')
|
1104 |
+
oct = clef.findtext ('clef-octave-change', '') or '0'
|
1105 |
+
if oct: cs += {-2:'-15', -1:'-8', 1:'+8', 2:'+15'}.get (int (oct), '')
|
1106 |
+
s.clefOct [n] = -int (oct); # xml playback pitch -> abc notation pitch
|
1107 |
+
if steps: cs += ' transpose=' + str (steps)
|
1108 |
+
stfdtl = e.find ('staff-details')
|
1109 |
+
if stfdtl and int (stfdtl.get ('number', '1')) == n:
|
1110 |
+
lines = stfdtl.findtext ('staff-lines')
|
1111 |
+
if lines:
|
1112 |
+
lns= '|||' if lines == '3' and sgn == 'TAB' else lines
|
1113 |
+
cs += ' stafflines=%s' % lns
|
1114 |
+
s.stafflines = int (lines) # remember for tab staves
|
1115 |
+
strings = stfdtl.findall ('staff-tuning')
|
1116 |
+
if strings:
|
1117 |
+
tuning = [st.findtext ('tuning-step') + st.findtext ('tuning-octave') for st in strings]
|
1118 |
+
cs += ' strings=%s' % ','.join (tuning)
|
1119 |
+
capo = stfdtl.findtext ('capo')
|
1120 |
+
if capo: cs += ' capo=%s' % capo
|
1121 |
+
s.curClef [n] = cs # keep track of current clef (for percmap)
|
1122 |
+
if first: s.clefMap [n] = cs # clef goes to header (where it is mapped to voices)
|
1123 |
+
else:
|
1124 |
+
voices = s.stfMap[n] # clef change to all voices of staff n
|
1125 |
+
for v in voices:
|
1126 |
+
if n != s.curStf [v]: # voice is not at its home staff n
|
1127 |
+
dstaff = n - s.curStf [v]
|
1128 |
+
s.curStf [v] = n # reset current staff at start of measure to home position
|
1129 |
+
s.msc.appendElem (v, '[I:staff %+d]' % dstaff)
|
1130 |
+
s.msc.appendElem (v, '[K:%s]' % cs)
|
1131 |
+
|
1132 |
+
def findVoice (s, i, es):
|
1133 |
+
stfnum = int (es[i].findtext ('staff',1)) # directions belong to a staff
|
1134 |
+
vs = s.stfMap [stfnum] # voices in this staff
|
1135 |
+
v1 = vs [0] if vs else 1 # directions to first voice of staff
|
1136 |
+
if s.dirtov1: return stfnum, v1, v1 # option --v1
|
1137 |
+
for e in es [i+1:]: # or to the voice of the next note
|
1138 |
+
if e.tag == 'note':
|
1139 |
+
v = int (e.findtext ('voice', '1'))
|
1140 |
+
if s.isSib: v += 100 * int (e.findtext ('staff', '1')) # repair bug in Sibelius
|
1141 |
+
stf = s.vce2stf [v] # use our own staff allocation
|
1142 |
+
return stf, v, v1 # voice of next note, first voice of staff
|
1143 |
+
if e.tag == 'backup': break
|
1144 |
+
return stfnum, v1, v1 # no note found, fall back to v1
|
1145 |
+
|
1146 |
+
def doDirection (s, e, i, es): # parse a musicXML direction tag
|
1147 |
+
def addDirection (x, vs, tijd, stfnum):
|
1148 |
+
if not x: return
|
1149 |
+
vs = s.stfMap [stfnum] if '!8v' in x else [vs] # ottava's go to all voices of staff
|
1150 |
+
for v in vs:
|
1151 |
+
if tijd != None: # insert at time of encounter
|
1152 |
+
s.msc.appendElemT (v, x.replace ('(',')').replace ('ped','ped-up'), tijd)
|
1153 |
+
else:
|
1154 |
+
s.msc.appendElem (v, x)
|
1155 |
+
def startStop (dtype, vs, stfnum=1):
|
1156 |
+
typmap = {'down':'!8va(!', 'up':'!8vb(!', 'crescendo':'!<(!', 'diminuendo':'!>(!', 'start':'!ped!'}
|
1157 |
+
type = t.get ('type', '')
|
1158 |
+
k = dtype + t.get ('number', '1') # key to match the closing direction
|
1159 |
+
if type in typmap: # opening the direction
|
1160 |
+
x = typmap [type]
|
1161 |
+
if k in s.dirStk: # closing direction already encountered
|
1162 |
+
stype, tijd = s.dirStk [k]; del s.dirStk [k]
|
1163 |
+
if stype == 'stop':
|
1164 |
+
addDirection (x, vs, tijd, stfnum)
|
1165 |
+
else:
|
1166 |
+
info ('%s direction %s has no stop in part %d, measure %d, voice %d' % (dtype, stype, s.msr.ixp+1, s.msr.ixm+1, vs+1))
|
1167 |
+
s.dirStk [k] = ((type , vs)) # remember voice and type for closing
|
1168 |
+
else:
|
1169 |
+
s.dirStk [k] = ((type , vs)) # remember voice and type for closing
|
1170 |
+
elif type == 'stop':
|
1171 |
+
if k in s.dirStk: # matching open direction found
|
1172 |
+
type, vs = s.dirStk [k]; del s.dirStk [k] # into the same voice
|
1173 |
+
if type == 'stop':
|
1174 |
+
info ('%s direction %s has double stop in part %d, measure %d, voice %d' % (dtype, type, s.msr.ixp+1, s.msr.ixm+1, vs+1))
|
1175 |
+
x = ''
|
1176 |
+
else:
|
1177 |
+
x = typmap [type].replace ('(',')').replace ('ped','ped-up')
|
1178 |
+
else: # closing direction found before opening
|
1179 |
+
s.dirStk [k] = ('stop', s.msc.tijd)
|
1180 |
+
x = '' # delay code generation until opening found
|
1181 |
+
else: raise ValueError ('wrong direction type')
|
1182 |
+
addDirection (x, vs, None, stfnum)
|
1183 |
+
tempo, wrdstxt = None, ''
|
1184 |
+
plcmnt = e.get ('placement')
|
1185 |
+
stf, vs, v1 = s.findVoice (i, es)
|
1186 |
+
jmp = '' # for jump sound elements: dacapo, dalsegno and family
|
1187 |
+
jmps = [('dacapo','D.C.'),('dalsegno','D.S.'),('tocoda','dacoda'),('fine','fine'),('coda','O'),('segno','S')]
|
1188 |
+
t = e.find ('sound') # there are many possible attributes for sound
|
1189 |
+
if t != None:
|
1190 |
+
minst = t.find ('midi-instrument')
|
1191 |
+
if minst:
|
1192 |
+
prg = t.findtext ('midi-instrument/midi-program')
|
1193 |
+
chn = t.findtext ('midi-instrument/midi-channel')
|
1194 |
+
vids = [v for v, id in s.vceInst.items () if id == minst.get ('id')]
|
1195 |
+
if vids: vs = vids [0] # direction for the identified voice, not the staff
|
1196 |
+
parm, inst = ('program', str (int (prg) - 1)) if prg else ('channel', chn)
|
1197 |
+
if inst and abcOut.volpan > 0: s.msc.appendElem (vs, '[I:MIDI= %s %s]' % (parm, inst))
|
1198 |
+
tempo = t.get ('tempo') # look for tempo attribute
|
1199 |
+
if tempo:
|
1200 |
+
tempo = '%.0f' % float (tempo) # hope it is a number and insert in voice 1
|
1201 |
+
tempo_units = (1,4) # always 1/4 for sound elements!
|
1202 |
+
for r, v in jmps:
|
1203 |
+
if t.get (r, ''): jmp = v; break
|
1204 |
+
dirtypes = e.findall ('direction-type')
|
1205 |
+
for dirtyp in dirtypes:
|
1206 |
+
units = { 'whole': (1,1), 'half': (1,2), 'quarter': (1,4), 'eighth': (1,8) }
|
1207 |
+
metr = dirtyp.find ('metronome')
|
1208 |
+
if metr != None:
|
1209 |
+
t = metr.findtext ('beat-unit', '')
|
1210 |
+
if t in units: tempo_units = units [t]
|
1211 |
+
else: tempo_units = units ['quarter']
|
1212 |
+
if metr.find ('beat-unit-dot') != None:
|
1213 |
+
tempo_units = simplify (tempo_units [0] * 3, tempo_units [1] * 2)
|
1214 |
+
tmpro = re.search (r'[.\d]+', metr.findtext ('per-minute')) # look for a number
|
1215 |
+
if tmpro: tempo = tmpro.group () # overwrites the value set by the sound element of this direction
|
1216 |
+
t = dirtyp.find ('wedge')
|
1217 |
+
if t != None: startStop ('wedge', vs)
|
1218 |
+
allwrds = dirtyp.findall ('words') # insert text annotations
|
1219 |
+
if not allwrds: allwrds = dirtyp.findall ('rehearsal') # treat rehearsal mark as text annotation
|
1220 |
+
for wrds in allwrds:
|
1221 |
+
if jmp: # ignore the words when a jump sound element is present in this direction
|
1222 |
+
s.msc.appendElem (vs, '!%s!' % jmp , 1) # to voice
|
1223 |
+
break
|
1224 |
+
plc = plcmnt == 'below' and '_' or '^'
|
1225 |
+
if float (wrds.get ('default-y', '0')) < 0: plc = '_'
|
1226 |
+
wrdstxt += (wrds.text or '').replace ('"','\\"').replace ('\n', '\\n')
|
1227 |
+
wrdstxt = wrdstxt.strip ()
|
1228 |
+
for key, val in dynamics_map.items ():
|
1229 |
+
if dirtyp.find ('dynamics/' + key) != None:
|
1230 |
+
s.msc.appendElem (vs, val, 1) # to voice
|
1231 |
+
if dirtyp.find ('coda') != None: s.msc.appendElem (vs, 'O', 1)
|
1232 |
+
if dirtyp.find ('segno') != None: s.msc.appendElem (vs, 'S', 1)
|
1233 |
+
t = dirtyp.find ('octave-shift')
|
1234 |
+
if t != None: startStop ('octave-shift', vs, stf) # assume size == 8 for the time being
|
1235 |
+
t = dirtyp.find ('pedal')
|
1236 |
+
if t != None and s.ped:
|
1237 |
+
if not s.pedVce: s.pedVce = vs
|
1238 |
+
startStop ('pedal', s.pedVce)
|
1239 |
+
if dirtyp.findtext ('other-direction') == 'diatonic fretting': s.diafret = 1;
|
1240 |
+
if tempo:
|
1241 |
+
tempo = '%.0f' % float (tempo) # hope it is a number and insert in voice 1
|
1242 |
+
if s.msc.tijd == 0 and s.msr.ixm == 0: # first measure -> header
|
1243 |
+
abcOut.tempo = tempo
|
1244 |
+
abcOut.tempo_units = tempo_units
|
1245 |
+
else:
|
1246 |
+
s.msc.appendElem (v1, '[Q:%d/%d=%s]' % (tempo_units [0], tempo_units [1], tempo)) # otherwise -> 1st voice
|
1247 |
+
if wrdstxt: s.msc.appendElem (vs, '"%s%s"' % (plc, wrdstxt), 1) # to voice, but after tempo
|
1248 |
+
|
1249 |
+
def doHarmony (s, e, i, es): # parse a musicXML harmony tag
|
1250 |
+
_, vt, _ = s.findVoice (i, es)
|
1251 |
+
short = {'major':'', 'minor':'m', 'augmented':'+', 'diminished':'dim', 'dominant':'7', 'half-diminished':'m7b5'}
|
1252 |
+
accmap = {'major':'maj', 'dominant':'', 'minor':'m', 'diminished':'dim', 'augmented':'+', 'suspended':'sus'}
|
1253 |
+
modmap = {'second':'2', 'fourth':'4', 'seventh':'7', 'sixth':'6', 'ninth':'9', '11th':'11', '13th':'13'}
|
1254 |
+
altmap = {'1':'#', '0':'', '-1':'b'}
|
1255 |
+
root = e.findtext ('root/root-step','')
|
1256 |
+
alt = altmap.get (e.findtext ('root/root-alter'), '')
|
1257 |
+
sus = ''
|
1258 |
+
kind = e.findtext ('kind', '')
|
1259 |
+
if kind in short: kind = short [kind]
|
1260 |
+
elif '-' in kind: # xml chord names: <triad name>-<modification>
|
1261 |
+
triad, mod = kind.split ('-')
|
1262 |
+
kind = accmap.get (triad, '') + modmap.get (mod, '')
|
1263 |
+
if kind.startswith ('sus'): kind, sus = '', kind # sus-suffix goes to the end
|
1264 |
+
elif kind == 'none': kind = e.find ('kind').get ('text','')
|
1265 |
+
degrees = e.findall ('degree')
|
1266 |
+
for d in degrees: # chord alterations
|
1267 |
+
kind += altmap.get (d.findtext ('degree-alter'),'') + d.findtext ('degree-value','')
|
1268 |
+
kind = kind.replace ('79','9').replace ('713','13').replace ('maj6','6')
|
1269 |
+
bass = e.findtext ('bass/bass-step','') + altmap.get (e.findtext ('bass/bass-alter'),'')
|
1270 |
+
s.msc.appendElem (vt, '"%s%s%s%s%s"' % (root, alt, kind, sus, bass and '/' + bass), 1)
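# Example translation (hypothetical xml): root-step D, root-alter -1,
# kind 'dominant' and bass-step A are emitted as the abc chord symbol
# "Db7/A"; degree alterations and sus suffixes are appended analogously.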
|
1271 |
+
|
1272 |
+
def doBarline (s, e): # 0 = no repeat, 1 = begin repeat, 2 = end repeat
|
1273 |
+
rep = e.find ('repeat')
|
1274 |
+
if rep != None: rep = rep.get ('direction')
|
1275 |
+
if s.unfold: # unfold repeat, don't translate barlines
|
1276 |
+
return rep and (rep == 'forward' and 1 or 2) or 0
|
1277 |
+
loc = e.get ('location', 'right') # right is the default
|
1278 |
+
if loc == 'right': # only change style for the right side
|
1279 |
+
style = e.findtext ('bar-style')
|
1280 |
+
if style == 'light-light': s.msr.rline = '||'
|
1281 |
+
elif style == 'light-heavy': s.msr.rline = '|]'
|
1282 |
+
if rep != None: # repeat found
|
1283 |
+
if rep == 'forward': s.msr.lline = ':'
|
1284 |
+
else: s.msr.rline = ':|' # override barline style
|
1285 |
+
end = e.find ('ending')
|
1286 |
+
if end != None:
|
1287 |
+
if end.get ('type') == 'start':
|
1288 |
+
n = end.get ('number', '1').replace ('.','').replace (' ','')
|
1289 |
+
try: list (map (int, n.split (','))) # should be a list of integers
|
1290 |
+
except: n = '"%s"' % n.strip () # illegal musicXML
|
1291 |
+
s.msr.lnum = n # assume a start is always at the beginning of a measure
|
1292 |
+
elif s.msr.rline == '|': # stop and discontinue the same in ABC ?
|
1293 |
+
s.msr.rline = '||' # to stop on a normal barline use || in ABC ?
|
1294 |
+
return 0
|
1295 |
+
|
1296 |
+
def doPrint (s, e): # print element, measure number -> insert a line break
|
1297 |
+
if e.get ('new-system') == 'yes' or e.get ('new-page') == 'yes':
|
1298 |
+
if not s.nolbrk: return '$' # a line break
|
1299 |
+
|
1300 |
+
def doPartList (s, e): # translate the start/stop-event-based xml-partlist into proper tree
|
1301 |
+
for sp in e.findall ('part-list/score-part'):
|
1302 |
+
midi = {}
|
1303 |
+
for m in sp.findall ('midi-instrument'):
|
1304 |
+
x = [m.findtext (p, s.midDflt [i]) for i,p in enumerate (['midi-channel','midi-program','volume','pan'])]
|
1305 |
+
pan = float (x[3])
|
1306 |
+
if pan >= -90 and pan <= 90: # would be better to map behind-pannings
|
1307 |
+
pan = (float (x[3]) + 90) / 180 * 127 # xml between -90 and +90
|
1308 |
+
midi [m.get ('id')] = [int (x[0]), int (x[1]), float (x[2]) * 1.27, pan] # volume 100 -> midi 127
|
1309 |
+
up = m.findtext ('midi-unpitched')
|
1310 |
+
if up: s.drumInst [m.get ('id')] = int (up) - 1 # store midi-pitch for channel 10 notes
|
1311 |
+
s.instMid.append (midi)
|
1312 |
+
ps = e.find ('part-list') # partlist = [groupelem]
|
1313 |
+
xs = getPartlist (ps) # groupelem = partname | grouplist
|
1314 |
+
partlist, _ = parseParts (xs, {}, []) # grouplist = [groupelem, ..., groupdata]
|
1315 |
+
return partlist # groupdata = [group-symbol, group-barline, group-name, group-abbrev]
|
1316 |
+
|
1317 |
+
def mkTitle (s, e):
|
1318 |
+
def filterCredits (y): # y == filter level, higher filters less
|
1319 |
+
cs = []
|
1320 |
+
for x in credits: # skip redundant credit lines
|
1321 |
+
if y < 6 and (x in title or x in mvttl): continue # sure skip
|
1322 |
+
if y < 5 and (x in composer or x in lyricist): continue # almost sure skip
|
1323 |
+
if y < 4 and ((title and title in x) or (mvttl and mvttl in x)): continue # may skip too much
|
1324 |
+
if y < 3 and ([1 for c in composer if c in x] or [1 for c in lyricist if c in x]): continue # skips too much
|
1325 |
+
if y < 2 and re.match (r'^[\d\W]*$', x): continue # line only contains numbers and punctuation
|
1326 |
+
cs.append (x)
|
1327 |
+
if y == 0 and (title + mvttl): cs = '' # default: only credit when no title set
|
1328 |
+
return cs
|
1329 |
+
title = e.findtext ('work/work-title', '').strip ()
|
1330 |
+
mvttl = e.findtext ('movement-title', '').strip ()
|
1331 |
+
composer, lyricist, credits = [], [], []
|
1332 |
+
for creator in e.findall ('identification/creator'):
|
1333 |
+
if creator.text:
|
1334 |
+
if creator.get ('type') == 'composer':
|
1335 |
+
composer += [line.strip () for line in creator.text.split ('\n')]
|
1336 |
+
elif creator.get ('type') in ('lyricist', 'transcriber'):
|
1337 |
+
lyricist += [line.strip () for line in creator.text.split ('\n')]
|
1338 |
+
for rights in e.findall ('identification/rights'):
|
1339 |
+
if rights.text:
|
1340 |
+
lyricist += [line.strip () for line in rights.text.split ('\n')]
|
1341 |
+
for credit in e.findall('credit'):
|
1342 |
+
cs = ''.join (e.text or '' for e in credit.findall('credit-words'))
|
1343 |
+
credits += [re.sub (r'\s*[\r\n]\s*', ' ', cs)]
|
1344 |
+
credits = filterCredits (s.ctf)
|
1345 |
+
if title: title = 'T:%s\n' % title.replace ('\n', '\nT:')
|
1346 |
+
if mvttl: title += 'T:%s\n' % mvttl.replace ('\n', '\nT:')
|
1347 |
+
if credits: title += '\n'.join (['T:%s' % c for c in credits]) + '\n'
|
1348 |
+
if composer: title += '\n'.join (['C:%s' % c for c in composer]) + '\n'
|
1349 |
+
if lyricist: title += '\n'.join (['Z:%s' % c for c in lyricist]) + '\n'
|
1350 |
+
if title: abcOut.title = title[:-1]
|
1351 |
+
s.isSib = 'Sibelius' in (e.findtext ('identification/encoding/software') or '')
|
1352 |
+
if s.isSib: info ('Sibelius MusicXML is unreliable')
|
1353 |
+
|
1354 |
+
def doDefaults (s, e):
|
1355 |
+
if not s.doPageFmt: return # return if -pf option absent
|
1356 |
+
d = e.find ('defaults');
|
1357 |
+
if d == None: return;
|
1358 |
+
mils = d.findtext ('scaling/millimeters') # mils == staff height (mm)
|
1359 |
+
tenths = d.findtext ('scaling/tenths') # staff height in tenths
|
1360 |
+
if not mils or not tenths: return
|
1361 |
+
xmlScale = float (mils) / float (tenths) / 10 # tenths -> mm
|
1362 |
+
space = 10 * xmlScale # space between staff lines == 10 tenths
|
1363 |
+
abcScale = space / 0.2117 # 0.2117 cm = 6pt = space between staff lines for scale = 1.0 in abcm2ps
|
1364 |
+
abcOut.pageFmt ['scale'] = abcScale
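# Numeric sketch (assumed values): millimeters == 7.0 and tenths == 40
# give xmlScale == 0.0175 cm per tenth, space == 0.175 cm and an abcm2ps
# scale of about 0.175 / 0.2117 ~= 0.83.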
|
1365 |
+
eks = 2 * ['page-layout/'] + 4 * ['page-layout/page-margins/']
|
1366 |
+
eks = [a+b for a,b in zip (eks, 'page-height,page-width,left-margin,right-margin,top-margin,bottom-margin'.split (','))]
|
1367 |
+
for i in range (6):
|
1368 |
+
v = d.findtext (eks [i])
|
1369 |
+
k = abcOut.pagekeys [i+1] # pagekeys [0] == scale already done, skip it
|
1370 |
+
if not abcOut.pageFmt [k] and v:
|
1371 |
+
try: abcOut.pageFmt [k] = float (v) * xmlScale # -> cm
|
1372 |
+
except: info ('illegal value %s for xml element %s' % (v, eks [i])); continue # just skip illegal values
|
1373 |
+
|
1374 |
+
def locStaffMap (s, part, maten): # map voice to staff with majority voting
|
1375 |
+
vmap = {} # {voice -> {staff -> n}} count occurrences of voice in staff
|
1376 |
+
s.vceInst = {} # {voice -> instrument id} for this part
|
1377 |
+
s.msc.vnums = {} # voice id's
|
1378 |
+
s.hasStems = {} # XML voice nums with at least one note with a stem (for tab key)
|
1379 |
+
s.stfMap, s.clefMap = {}, {} # staff -> [voices], staff -> clef
|
1380 |
+
ns = part.findall ('measure/note')
|
+        for n in ns:                            # count staff allocations for all notes
+            v = int (n.findtext ('voice', '1'))
+            if s.isSib: v += 100 * int (n.findtext ('staff', '1'))  # repair bug in Sibelius
+            s.msc.vnums [v] = 1                 # collect all used voice id's in this part
+            sn = int (n.findtext ('staff', '1'))
+            s.stfMap [sn] = []
+            if v not in vmap:
+                vmap [v] = {sn:1}
+            else:
+                d = vmap[v]                     # counter for voice v
+                d[sn] = d.get (sn, 0) + 1       # ++ number of allocations for staff sn
+            x = n.find ('instrument')
+            if x != None: s.vceInst [v] = x.get ('id')
+            x, noRest = n.findtext ('stem'), n.find ('rest') == None
+            if noRest and (not x or x != 'none'): s.hasStems [v] = 1  # XML voice v has at least one stem
+        vks = list (vmap.keys ())
+        if s.jscript or s.isSib: vks.sort ()
+        for v in vks:                           # choose staff with most allocations for each voice
+            xs = [(n, sn) for sn, n in vmap[v].items ()]
+            xs.sort ()
+            stf = xs[-1][1]                     # the winner: staff with most notes of voice v
+            s.stfMap [stf].append (v)
+            s.vce2stf [v] = stf                 # reverse map
+            s.curStf [v] = stf                  # current staff of XML voice v
+
+    def addStaffMap (s, vvmap):                 # vvmap: xml voice number -> global abc voice number
+        part = []                               # default: brace on staffs of one part
+        for stf, voices in sorted (s.stfMap.items ()):  # s.stfMap has xml staff and voice numbers
+            locmap = [vvmap [iv] for iv in voices if iv in vvmap]
+            nostem = [(iv not in s.hasStems) for iv in voices if iv in vvmap]  # same order as locmap
+            if locmap:                          # abc voice number of staff stf
+                part.append (locmap)
+                clef = s.clefMap.get (stf, 'treble')  # {xml staff number -> clef}
+                for i, iv in enumerate (locmap):
+                    clef_attr = ''
+                    if clef.startswith ('tab'):
+                        if nostem [i] and 'nostems' not in clef: clef_attr = ' nostems'
+                        if s.diafret and 'diafret' not in clef: clef_attr += ' diafret'  # for all voices in the part
+                    abcOut.clefs [iv] = clef + clef_attr  # add nostems when all notes of voice had no stem
+        s.gStfMap.append (part)
+
+    def addMidiMap (s, ip, vvmap):              # map abc voices to midi settings
+        instr = s.instMid [ip]                  # get the midi settings for this part
+        if instr.values (): defInstr = list(instr.values ())[0]  # default settings = first instrument
+        else: defInstr = s.midDflt              # no instruments defined
+        xs = []
+        for v, vabc in vvmap.items ():          # xml voice num, abc voice num
+            ks = sorted (s.drumNotes.items ())
+            ds = [(nt, step, midi, head) for (vd, nt), (step, midi, head) in ks if v == vd]  # map perc notes
+            id = s.vceInst.get (v, '')          # get the instrument-id for part with multiple instruments
+            if id in instr:                     # id is defined as midi-instrument in part-list
+                xs.append ((vabc, instr [id] + ds))  # get midi settings for id
+            else: xs.append ((vabc, defInstr + ds))  # only one instrument for this part
+        xs.sort ()                              # put abc voices in order
+        s.midiMap.extend ([midi for v, midi in xs])
+        snaarmap = ['E','G','B','d', 'f', 'a', "c'", "e'"]
+        diamap = '0,1-,1,1+,2,3,3,4,4,5,6,6+,7,8-,8,8+,9,10,10,11,11,12,13,13+,14'.split (',')
+        for k in sorted (s.tabmap.keys ()):     # add %%map's for all tab voices
+            v, noot = k
+            snaar, fret = s.tabmap [k]
+            if s.diafret: fret = diamap [int (fret)]
+            vabc = vvmap [v]
+            snaar = s.stafflines - int (snaar)
+            xs = s.tabVceMap.get (vabc, [])
+            xs.append ('%%%%map tab%d %s print=%s heads=kop%s\n' % (vabc, noot, snaarmap [snaar], fret))
+            s.tabVceMap [vabc] = xs
+            s.koppen [fret] = 1                 # collect noteheads for SVG defs
+
+    def parse (s, fobj):
+        vvmapAll = {}                           # collect xml->abc voice maps (vvmap) of all parts
+        e = E.parse (fobj)
+        s.mkTitle (e)
+        s.doDefaults (e)
+        partlist = s.doPartList (e)
+        parts = e.findall ('part')
+        for ip, p in enumerate (parts):
+            maten = p.findall ('measure')
+            s.locStaffMap (p, maten)            # {voice -> staff} for this part
+            s.drumNotes = {}                    # (xml voice, abc note) -> (midi note, note head)
+            s.clefOct = {}                      # xml staff number -> current clef-octave-change
+            s.curClef = {}                      # xml staff number -> current abc clef
+            s.stemDir = {}                      # xml voice number -> current stem direction
+            s.tabmap = {}                       # (xml voice, abc note) -> (string, fret)
+            s.diafret = 0                       # use diatonic fretting
+            s.stafflines = 5
+            s.msc.initVoices (newPart = 1)      # create all voices
+            aantalHerhaald = 0                  # keep track of number of repetitions
+            herhaalMaat = 0                     # target measure of the repetition
+            divisions = []                      # current value of <divisions> for each measure
+            s.msr = Measure (ip)                # various measure data
+            while s.msr.ixm < len (maten):
+                maat = maten [s.msr.ixm]
+                herhaal, lbrk = 0, ''
+                s.msr.reset ()
+                s.curalts = {}                  # passing accidentals are reset each measure
+                es = list (maat)
+                for i, e in enumerate (es):
+                    if e.tag == 'note': s.doNote (e)
+                    elif e.tag == 'attributes': s.doAttr (e)
+                    elif e.tag == 'direction': s.doDirection (e, i, es)
+                    elif e.tag == 'sound': s.doDirection (maat, i, es)  # sound element directly in measure!
+                    elif e.tag == 'harmony': s.doHarmony (e, i, es)
+                    elif e.tag == 'barline': herhaal = s.doBarline (e)
+                    elif e.tag == 'backup':
+                        dt = int (e.findtext ('duration'))
+                        if chkbug (dt, s.msr): s.msc.incTime (-dt)
+                    elif e.tag == 'forward':
+                        dt = int (e.findtext ('duration'))
+                        if chkbug (dt, s.msr): s.msc.incTime (dt)
+                    elif e.tag == 'print': lbrk = s.doPrint (e)
+                s.msc.addBar (lbrk, s.msr)
+                divisions.append (s.msr.divs)
+                if herhaal == 1:
+                    herhaalMaat = s.msr.ixm
+                    s.msr.ixm += 1
+                elif herhaal == 2:
+                    if aantalHerhaald < 1:      # jump
+                        s.msr.ixm = herhaalMaat
+                        aantalHerhaald += 1
+                    else:
+                        aantalHerhaald = 0      # reset
+                        s.msr.ixm += 1          # just continue
+                else: s.msr.ixm += 1            # on to the next measure
+            for rv in s.repeat_str.values ():   # close hanging measure-repeats without stop
+                rv [0] = '[I:repeat %s %d]' % (rv [1], 1)
+            vvmap = s.msc.outVoices (divisions, ip, s.isSib)
+            s.addStaffMap (vvmap)               # update global staff map
+            s.addMidiMap (ip, vvmap)
+            vvmapAll.update (vvmap)
+        if vvmapAll:                            # skip output if no part has any notes
+            abcOut.mkHeader (s.gStfMap, partlist, s.midiMap, s.tabVceMap, s.koppen)
+            abcOut.writeall ()
+        else: info ('nothing written, %s has no notes ...' % abcOut.fnmext)
+
+#----------------
+# Main Program
+#----------------
+if __name__ == '__main__':
+    from optparse import OptionParser
+    from glob import glob
+    from zipfile import ZipFile
+    ustr = '%prog [-h] [-u] [-m] [-c C] [-d D] [-n CPL] [-b BPL] [-o DIR] [-v V]\n'
+    ustr += '[-x] [-p PFMT] [-t] [-s] [-i] [--v1] [--noped] [--stems] <file1> [<file2> ...]'
+    parser = OptionParser (usage=ustr, version=str(VERSION))
+    parser.add_option ("-u", action="store_true", help="unfold simple repeats")
+    parser.add_option ("-m", action="store", help="0 -> no %%MIDI, 1 -> minimal %%MIDI, 2 -> all %%MIDI", default=0)
+    parser.add_option ("-c", action="store", type="int", help="set credit text filter to C", default=0, metavar='C')
+    parser.add_option ("-d", action="store", type="int", help="set L:1/D", default=0, metavar='D')
+    parser.add_option ("-n", action="store", type="int", help="CPL: max number of characters per line (default 100)", default=0, metavar='CPL')
+    parser.add_option ("-b", action="store", type="int", help="BPL: max number of bars per line", default=0, metavar='BPL')
+    parser.add_option ("-o", action="store", help="store abc files in DIR", default='', metavar='DIR')
+    parser.add_option ("-v", action="store", type="int", help="set volta typesetting behaviour to V", default=0, metavar='V')
+    parser.add_option ("-x", action="store_true", help="output no line breaks")
+    parser.add_option ("-p", action="store", help="pageformat PFMT (cm) = scale, pageheight, pagewidth, leftmargin, rightmargin, topmargin, botmargin", default='', metavar='PFMT')
+    parser.add_option ("-j", action="store_true", help="switch for compatibility with javascript version")
+    parser.add_option ("-t", action="store_true", help="translate perc- and tab-staff to ABC code with %%map, %%voicemap")
+    parser.add_option ("-s", action="store_true", help="shift note heads 3 units left in a tab staff")
+    parser.add_option ("--v1", action="store_true", help="start-stop directions always to first voice of staff")
+    parser.add_option ("--noped", action="store_false", help="skip all pedal directions", dest='ped', default=True)
+    parser.add_option ("--stems", action="store_true", help="translate stem directions", dest='stm', default=False)
+    parser.add_option ("-i", action="store_true", help="read xml file from standard input")
+    options, args = parser.parse_args ()
+    if options.n < 0: parser.error ('only values >= 0')
+    if options.b < 0: parser.error ('only values >= 0')
+    if options.d and options.d not in [2**n for n in range (10)]:
+        parser.error ('D should be one of %s' % ','.join ([str(2**n) for n in range (10)]))
+    options.p = options.p and options.p.split (',') or []  # ==> [] | [string]
+    if len (args) == 0 and not options.i: parser.error ('no input file given')
+    pad = options.o
+    if pad:
+        if not os.path.exists (pad): os.mkdir (pad)
+        if not os.path.isdir (pad): parser.error ('%s is not a directory' % pad)
+    fnmext_list = []
+    for i in args: fnmext_list += glob (i)
+    if options.i: fnmext_list = ['stdin.xml']
+    if not fnmext_list: parser.error ('none of the input files exist')
+    for X, fnmext in enumerate (fnmext_list):
+        fnm, ext = os.path.splitext (fnmext)
+        if ext.lower () not in ('.xml','.mxl','.musicxml'):
+            info ('skipped input file %s, it should have extension .xml or .mxl' % fnmext)
+            continue
+        if os.path.isdir (fnmext):
+            info ('skipped directory %s. Only files are accepted' % fnmext)
+            continue
+        if fnmext == 'stdin.xml':
+            fobj = sys.stdin
+        elif ext.lower () == '.mxl':            # extract .xml file from .mxl file
+            z = ZipFile(fnmext)
+            for n in z.namelist():              # assume there is always an xml file in a mxl archive !!
+                if (n[:4] != 'META') and (n[-4:].lower() == '.xml'):
+                    fobj = z.open (n)
+                    break                       # assume only one MusicXML file per archive
+        else:
+            fobj = open (fnmext, 'rb')          # open regular xml file
+
+        abcOut = ABCoutput (fnm + '.abc', pad, X, options)  # create global ABC output object
+        psr = Parser (options)                  # xml parser
+        try:
+            psr.parse (fobj)                    # parse file fobj and write abc to <fnm>.abc
+        except:
+            etype, value, traceback = sys.exc_info ()  # works in python 2 & 3
+            info ('** %s occurred: %s in %s' % (etype, value, fnmext), 0)
semantic_search/README.md
ADDED
@@ -0,0 +1,62 @@
+# Semantic Search Codebase
+
+## Overview
+CLaMP 2 is a state-of-the-art multimodal music information retrieval system designed to work with 101 languages. This codebase includes scripts for evaluating model performance, performing semantic searches, and calculating similarity metrics based on CLaMP2-extracted **normalized** feature vectors from music or text data. Below is a description of the scripts contained in the `semantic_search/` folder.
+
+## Repository Structure
+The `semantic_search/` folder contains the following scripts:
+
+### 1. `clamp2_score.py`
+This script calculates the cosine similarity between the average feature vectors extracted from two sets of `.npy` files, serving as a measure of similarity between the reference and test datasets.
+
+It can be used to validate the semantic similarity between generated music and the ground truth, providing an objective metric. Through empirical observation, we found that this metric aligns well with subjective judgments made by individuals with professional music expertise.
+
+**Usage:**
+```bash
+python clamp2_score.py <reference_folder> <test_folder>
+```
+- `reference_folder`: Path to the folder containing reference `.npy` files.
+- `test_folder`: Path to the folder containing test `.npy` files.
+
+**Functionality:**
+- Loads all `.npy` files from the specified folders.
+- Computes the average feature vector for each folder.
+- Calculates the cosine similarity between the two averaged vectors.
+- Outputs the similarity score rounded to four decimal places (see the sketch below).
+
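+In essence, the score is the cosine of the angle between the two mean vectors. A minimal NumPy sketch of the same computation (illustrative only, using random stand-in data; the actual implementation is `clamp2_score.py` below):
+
+```python
+import numpy as np
+
+# Stand-in for CLaMP2-extracted features; real inputs are loaded from .npy files
+ref = np.random.randn(4, 768)   # 4 reference feature vectors
+test = np.random.randn(4, 768)  # 4 test feature vectors
+
+avg_ref, avg_test = ref.mean(axis=0), test.mean(axis=0)
+score = np.dot(avg_ref, avg_test) / (np.linalg.norm(avg_ref) * np.linalg.norm(avg_test))
+print(round(float(score), 4))
+```
+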
27 |
+
### 2. `semantic_search.py`
|
28 |
+
This script performs semantic search by calculating the cosine similarity between a query feature and a set of features stored in `.npy` files.
|
29 |
+
|
30 |
+
**Usage:**
|
31 |
+
```bash
|
32 |
+
python semantic_search.py <query_file> <features_folder> [--top_k TOP_K]
|
33 |
+
```
|
34 |
+
- `query_file`: Path to the query feature file (e.g., `ballad.npy`).
|
35 |
+
- `features_folder`: Path to the folder containing feature files for comparison.
|
36 |
+
- `--top_k`: (Optional) Number of top similar items to display. Defaults to 10 if not specified.
|
37 |
+
|
38 |
+
**Functionality:**
|
39 |
+
- Loads a query feature from the specified file.
|
40 |
+
- Loads feature vectors from the given folder.
|
41 |
+
- Computes cosine similarity between the query feature and each loaded feature vector.
|
42 |
+
- Displays the top K most similar features along with their similarity scores.
|
43 |
+
|
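+The ranking itself reduces to a single batched cosine-similarity call; a sketch of the core step (illustrative only, with random stand-in features):
+
+```python
+import torch
+
+query = torch.randn(1, 768)                        # stand-in query feature
+keys = torch.randn(100, 768)                       # stand-in key features
+sims = torch.cosine_similarity(query, keys)        # one score per key
+order = torch.argsort(sims, descending=True)[:10]  # indices of the top 10 matches
+```
+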
+### 3. `semantic_search_metrics.py`
+This script calculates evaluation metrics for semantic search by comparing query features to reference features.
+
+**Usage:**
+```bash
+python semantic_search_metrics.py <query_folder> <reference_folder>
+```
+- `query_folder`: Path to the folder containing query features (in `.npy` format).
+- `reference_folder`: Path to the folder containing reference features (in `.npy` format).
+
+**Functionality:**
+- Loads query features from the specified folder.
+- Loads reference features from the given folder.
+- Computes the following metrics based on cosine similarity:
+  - **Mean Reciprocal Rank (MRR)**
+  - **Hit@1**
+  - **Hit@10**
+  - **Hit@100**
+- Outputs the calculated metrics to the console.
semantic_search/clamp2_score.py
ADDED
@@ -0,0 +1,57 @@
+import os
+import numpy as np
+import argparse
+
+def load_npy_files(folder_path):
+    """
+    Load all .npy files from a specified folder and return a list of numpy arrays.
+    """
+    npy_list = []
+    for file_name in os.listdir(folder_path):
+        if file_name.endswith('.npy'):
+            file_path = os.path.join(folder_path, file_name)
+            np_array = np.load(file_path)[0]
+            npy_list.append(np_array)
+    return npy_list
+
+def average_npy(npy_list):
+    """
+    Compute the average of a list of numpy arrays.
+    """
+    return np.mean(npy_list, axis=0)
+
+def cosine_similarity(vec1, vec2):
+    """
+    Compute cosine similarity between two numpy arrays.
+    """
+    dot_product = np.dot(vec1, vec2)
+
+    norm_vec1 = np.linalg.norm(vec1)
+    norm_vec2 = np.linalg.norm(vec2)
+
+    cosine_sim = dot_product / (norm_vec1 * norm_vec2)
+
+    return cosine_sim
+
+if __name__ == '__main__':
+    # Set up argument parsing for input folders
+    parser = argparse.ArgumentParser(description="Calculate cosine similarity between average feature vectors.")
+    parser.add_argument('reference', type=str, help='Path to the reference folder containing .npy files.')
+    parser.add_argument('test', type=str, help='Path to the test folder containing .npy files.')
+    args = parser.parse_args()
+
+    reference = args.reference
+    test = args.test
+    # Load .npy files
+    ref_npy = load_npy_files(reference)
+    test_npy = load_npy_files(test)
+
+    # Compute the average of each list of numpy arrays
+    avg_ref = average_npy(ref_npy)
+    avg_test = average_npy(test_npy)
+
+    # Compute the cosine similarity between the two averaged numpy arrays
+    similarity = cosine_similarity(avg_ref, avg_test)
+
+    # Output the cosine similarity rounded to four decimal places
+    print(f"Cosine similarity between '{reference}' and '{test}': {similarity:.4f}")
semantic_search/semantic_search.py
ADDED
@@ -0,0 +1,51 @@
+import os
+import torch
+import numpy as np
+import argparse
+
+def get_info(folder_path):
+    """
+    Load all .npy files from a specified folder and return a dictionary of features.
+    """
+    files = sorted(os.listdir(folder_path))
+    features = {}
+
+    for file in files:
+        if file.endswith(".npy"):
+            key = file.split(".")[0]
+            features[key] = np.load(os.path.join(folder_path, file))[0]
+
+    return features
+
+def main(query_file, features_folder, top_k=10):
+    # Load query feature from the specified file
+    query_feature = np.load(query_file)[0]  # Load directly from the query file
+    query_tensor = torch.tensor(query_feature).unsqueeze(dim=0)
+
+    # Load key features from the specified folder
+    key_features = get_info(features_folder)
+
+    # Prepare tensor for key features
+    key_feats_tensor = torch.tensor(np.array([key_features[k] for k in key_features.keys()]))
+
+    # Calculate cosine similarity
+    similarities = torch.cosine_similarity(query_tensor, key_feats_tensor)
+    ranked_indices = torch.argsort(similarities, descending=True)
+
+    # Get the keys for the features
+    keys = list(key_features.keys())
+
+    print(f"Top {top_k} similar items:")
+    for i in range(top_k):
+        print(keys[ranked_indices[i]], similarities[ranked_indices[i]].item())
+
+if __name__ == '__main__':
+    # Set up argument parsing for input paths
+    parser = argparse.ArgumentParser(description="Find top similar features based on cosine similarity.")
+    parser.add_argument('query_file', type=str, help='Path to the query feature file (e.g., ballad.npy).')
+    parser.add_argument('features_folder', type=str, help='Path to the folder containing feature files for comparison.')
+    parser.add_argument('--top_k', type=int, default=10, help='Number of top similar items to display (default: 10).')
+    args = parser.parse_args()
+
+    # Execute the main functionality
+    main(args.query_file, args.features_folder, args.top_k)
semantic_search/semantic_search_metrics.py
ADDED
@@ -0,0 +1,78 @@
+import os
+import torch
+import numpy as np
+import argparse
+
+def get_features(path):
+    """
+    Load and return feature data from .npy files in the given directory.
+    Each feature is stored in a dictionary with the filename (without extension) as the key.
+    """
+    files = sorted(os.listdir(path))
+    features = {}
+
+    for file in files:
+        if file.endswith(".npy"):
+            key = file.split(".")[0]
+            features[key] = np.load(os.path.join(path, file))[0]
+
+    return features
+
+def calculate_metrics(query_features, reference_features):
+    """
+    Calculate MRR, Hit@1, Hit@10, and Hit@100 metrics based on the similarity
+    between query and reference features.
+    """
+    common_keys = set(query_features.keys()) & set(reference_features.keys())
+    mrr, hit_1, hit_10, hit_100 = 0, 0, 0, 0
+
+    for idx, key in enumerate(common_keys):
+        # Convert query feature to tensor and add batch dimension
+        query_feat = torch.tensor(query_features[key]).unsqueeze(dim=0)
+
+        # Collect all reference features for common keys
+        ref_feats = torch.tensor(np.array([reference_features[k] for k in common_keys]))
+
+        # Compute cosine similarity between the query and all reference features
+        similarities = torch.cosine_similarity(query_feat, ref_feats)
+
+        # Create a list of (similarity, index) pairs
+        indexed_sims = list(enumerate(similarities.tolist()))
+
+        # Sort by similarity in descending order, with idx-based tie-breaking
+        sorted_indices = sorted(indexed_sims, key=lambda x: (x[1], x[0] == idx), reverse=True)
+
+        # Extract the sorted rank list
+        ranks = [x[0] for x in sorted_indices]
+
+        # Calculate MRR
+        mrr += 1 / (ranks.index(idx) + 1)
+
+        # Calculate Hit@1, Hit@10, Hit@100
+        if idx in ranks[:100]:
+            hit_100 += 1
+        if idx in ranks[:10]:
+            hit_10 += 1
+        if idx in ranks[:1]:
+            hit_1 += 1
+
+    # Compute the final metrics
+    total_keys = len(common_keys)
+    print(f"MRR: {round(mrr / total_keys, 4)}")
+    print(f"Hit@1: {round(hit_1 / total_keys, 4)}")
+    print(f"Hit@10: {round(hit_10 / total_keys, 4)}")
+    print(f"Hit@100: {round(hit_100 / total_keys, 4)}")
+
+if __name__ == '__main__':
+    # Set up argument parsing for input directories
+    parser = argparse.ArgumentParser(description="Calculate similarity metrics between query and reference features.")
+    parser.add_argument('query_folder', type=str, help='Path to the folder containing query features (.npy files).')
+    parser.add_argument('reference_folder', type=str, help='Path to the folder containing reference features (.npy files).')
+    args = parser.parse_args()
+
+    # Load features from the specified folders
+    query_features = get_features(args.query_folder)
+    reference_features = get_features(args.reference_folder)
+
+    # Calculate and print the metrics
+    calculate_metrics(query_features, reference_features)