import argparse
import os
import sys
sys.path.append('../')

import yaml
import torch
from tqdm import tqdm

from data.data import get_audiotext_dataloader


def validation_losses(model, data_config, clap_config, tokenizer, batch_size, autocast, cast_dtype, device_id, verbose=True):
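    """Compute the average validation loss on each validation split returned by
    get_audiotext_dataloader. Returns a dict mapping dataset name -> loss."""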
    model.eval()

    def get_val_loss(validloader):
        loss_sum = 0.0
        for idx, batch in enumerate(tqdm(validloader)):
            audio_clips = batch["audio_clips"].to(device_id, dtype=cast_dtype, non_blocking=True)
            audio_embed_mask = batch["audio_embed_mask"].to(device_id, dtype=cast_dtype, non_blocking=True)
            # token ids must stay an integer tensor; casting them to a float
            # dtype would break the embedding lookup
            input_ids = batch["input_ids"].to(device_id, non_blocking=True)
            attention_mask = batch["attention_mask"].to(device_id, dtype=cast_dtype, non_blocking=True)

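            # build LM labels: ignore padding, the first (BOS) position, and
            # the <audio> placeholder tokens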
            labels = input_ids.clone()
            labels[labels == tokenizer.pad_token_id] = -100
            labels[:, :1] = -100
            labels[labels == tokenizer.encode("<audio>")[-1]] = -100

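            # mask out prompt tokens: positions are ignored until a sep token
            # is seen, and masking resumes after each <|endofchunk|>, so the
            # loss is only computed on the answer spans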
            for i in range(labels.shape[0]):
                shouldmask = True
                for j in range(labels.shape[1]):
                    if shouldmask and (labels[i][j] != tokenizer.eos_token_id):
                        masked_value = -100
                    else:
                        masked_value = labels[i][j]
                    if labels[i][j] == tokenizer.sep_token_id:
                        shouldmask = False
                    elif labels[i][j] == endofchunk_token_id:
                        shouldmask = True
                    labels[i][j] = masked_value

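                # if the sequence was truncated mid-answer, also mask the
                # dangling partial answer at the end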
                if labels[i][-1] not in [-100, tokenizer.eos_token_id, tokenizer.pad_token_id, endofchunk_token_id]:
                    for j in range(labels.shape[1] - 1, -1, -1):
                        if labels[i][j] not in [-100, tokenizer.eos_token_id, endofchunk_token_id]:
                            labels[i][j] = -100
                        else:
                            break

            labels = labels.to(device_id)

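            # validation forward pass: no gradients are needed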
            with torch.no_grad(), autocast():
                output = model(
                    audio_x=audio_clips,
                    audio_x_mask=audio_embed_mask,
                    lang_x=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                )

            valid_loss_no_multiplier = output.loss.item()
            loss_sum += valid_loss_no_multiplier

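        # output.loss is already a per-token mean, so dividing by batch_size
        # as well yields a consistently scaled per-sample average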
        return loss_sum / ((idx + 1) * batch_size)

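    # resolve the ids of the special tokens used for label masking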
    media_token_id = tokenizer("<audio>", add_special_tokens=False)["input_ids"][-1]
    assert media_token_id == tokenizer.encode("<audio>")[-1]
    endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)["input_ids"][-1]

    valid_losses = {}
    all_valid_AudioTextDataInfo = get_audiotext_dataloader(data_config, clap_config, tokenizer, batch_size, split='val')
    for valid_dataset_name in all_valid_AudioTextDataInfo:
        if verbose:
            print('computing validation loss on {}'.format(valid_dataset_name))
        validloader = all_valid_AudioTextDataInfo[valid_dataset_name].dataloader
        valid_losses[valid_dataset_name] = get_val_loss(validloader)
        if verbose:
            print('validation loss on {} is {:.3f}'.format(valid_dataset_name, valid_losses[valid_dataset_name]))

    model.train()
    return valid_losses


if __name__ == "__main__":
    from src.factory import create_model_and_transforms
    from train_utils import Dict2Class, get_autocast, get_cast_dtype

    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default='../configs/config.yaml', help='yaml config path')
    parsed_args = parser.parse_args()

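    # load the yaml config and unpack its sections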
    config = yaml.load(open(parsed_args.config), Loader=yaml.FullLoader)
    data_config = config['data_config']
    model_config = config['model_config']
    clap_config = config['clap_config']
    args = Dict2Class(config['train_config'])

    os.environ["TOKENIZERS_PARALLELISM"] = "false"  # disable the tokenizer parallelism warning

    model, tokenizer = create_model_and_transforms(
        **model_config,
        clap_config=clap_config,
        use_local_files=args.offline,
        gradient_checkpointing=args.gradient_checkpointing,
        freeze_lm_embeddings=args.freeze_lm_embeddings,
    )

    device_id = 0
    model = model.to(device_id)

    autocast = get_autocast(
        args.precision, cache_enabled=(not args.fsdp)
    )  # if fsdp, disable cache to save memory
    cast_dtype = get_cast_dtype(args.precision)

    valid_losses = validation_losses(
        model,
        data_config,
        clap_config,
        tokenizer,
        args.batch_size,
        autocast,
        cast_dtype,
        device_id,
        verbose=True,
    )
    print(valid_losses)