tests.py Documentation
Overview
The tests.py script is designed to verify that a specific configuration, set of model weights, and dataset work correctly before proceeding with training, validation, or inference. Additional parameters can be passed either on the command line or via the base_args variable in the script.
Usage
To use tests.py, provide the desired arguments on the command line using the -- prefix. It is mandatory to specify the following arguments:
--model_type
--config_path
--start_check_point
--data_path
--valid_path
For example:
python tests.py --check_train \
--config_path config.yaml \
--model_type scnet \
--data_path /path/to/data \
--valid_path /path/to/valid
Alternatively, you can define default arguments in the base_args variable directly in the script.
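As a rough illustration, base_args can carry the same flags you would otherwise type on the command line. The snippet below is a minimal sketch, assuming base_args is a list of argument strings; the actual structure in tests.py may differ, and the paths are placeholders.
# Hypothetical default arguments inside tests.py (structure and paths are
# placeholders; adapt to how base_args is actually defined in the script).
base_args = [
    "--check_train",
    "--model_type", "scnet",
    "--config_path", "config.yaml",
    "--data_path", "/path/to/data",
    "--valid_path", "/path/to/valid",
]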
Arguments
The script accepts the following arguments:
--check_train: Check training functionality.
--check_valid: Check validation functionality.
--check_inference: Check inference functionality.
--device_ids: Specify device IDs for training or inference.
--model_type: Specify the type of model to use.
--start_check_point: Path to the checkpoint to start from.
--config_path: Path to the configuration file.
--data_path: Path to the training data.
--valid_path: Path to the validation data.
--results_path: Path to save training results.
--store_dir: Path to store validation or inference results.
--input_folder: Path to the input folder for inference.
--metrics: List of metrics to evaluate, provided as space-separated values.
--max_folders: Maximum number of folders to process.
--dataset_type: Dataset type. Must be one of 1, 2, 3, or 4. Default is 1.
--num_workers: Number of workers for the dataloader. Default is 0.
--pin_memory: Use pinned memory in the dataloader.
--seed: Random seed for reproducibility. Default is 0.
--use_multistft_loss: Use MultiSTFT loss from the auraloss package.
--use_mse_loss: Use mean squared error (MSE) loss.
--use_l1_loss: Use L1 loss.
--wandb_key: API key for Weights and Biases (wandb). Default is an empty string.
--pre_valid: Run validation before training.
--metric_for_scheduler: Metric used for the learning rate scheduler. Choices are sdr, l1_freq, si_sdr, neg_log_wmse, aura_stft, aura_mrstft, bleedless, or fullness. Default is sdr.
--train_lora: Enable training with LoRA.
--lora_checkpoint: Path to the initial LoRA weights checkpoint. Default is an empty string.
--extension: File extension for validation. Default is wav.
--use_tta: Enable test-time augmentation during inference. This triples runtime but improves prediction quality.
--extract_instrumental: Invert vocals to obtain the instrumental output if available.
--disable_detailed_pbar: Disable the detailed progress bar.
--force_cpu: Force use of the CPU, even if CUDA is available.
--flac_file: Output FLAC files instead of WAV.
--pcm_type: PCM type for FLAC files. Choices are PCM_16 or PCM_24. Default is PCM_24.
--draw_spectro: Generate spectrograms for the resulting stems; the value specifies how many seconds of the track to draw. Requires --store_dir to be set. Default is 0.
Example
To check training, validation, and inference with a specific configuration file, dataset, and checkpoint, run:
python tests/test.py \
--check_train \
--check_valid \
--check_inference \
--model_type scnet \
--config_path configs/config_musdb18_scnet_large_starrytong.yaml \
--start_check_point weights/model_scnet_ep_30_neg_log_wmse_-11.8688.ckpt \
--data_path datasets/moisesdb/train_tracks \
--valid_path datasets/moisesdb/valid \
--use_tta \
--use_mse_loss
This command validates the setup by:
Specifying scnet as the model type.
Loading the configuration from configs/config_musdb18_scnet_large_starrytong.yaml.
Using the dataset located at datasets/moisesdb/train_tracks for training.
Using datasets/moisesdb/valid for validation.
Starting from the checkpoint at weights/model_scnet_ep_30_neg_log_wmse_-11.8688.ckpt.
Enabling test-time augmentation and using MSE loss.
Additional Script: admin_test.py
The admin_test.py script provides a way to verify the functionality of all configurations and models without specifying model weights or datasets. By default, it performs validation and inference. The configurations and corresponding parameters can be modified via the MODEL_CONFIGS variable in the script.
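As a rough sketch, MODEL_CONFIGS might map each model type to its configuration file, as shown below; the actual structure in admin_test.py may differ, and every entry other than the scnet config mentioned above is a placeholder.
# Hypothetical shape of MODEL_CONFIGS in admin_test.py; the real variable may
# store additional per-model parameters. Non-scnet entries are placeholders.
MODEL_CONFIGS = {
    "scnet": "configs/config_musdb18_scnet_large_starrytong.yaml",
    "another_model": "configs/another_config.yaml",  # placeholder
}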
This script is useful for bulk testing and for ensuring that multiple configurations are correctly set up. It can help identify potential issues with configurations or models before proceeding to detailed testing with tests.py or full-scale training.