import os
import copy
import logging

import numpy as np
import torch
from torch.cuda.amp import autocast, GradScaler
from datasets import DatasetDict, Dataset
from transformers import AutoTokenizer, HfArgumentParser, set_seed
from IPython.display import clear_output
import wandb
import gradio as gr

from VitsModelSplit.vits_model2 import VitsModel, get_state_grad_loss
from VitsModelSplit.PosteriorDecoderModel import PosteriorDecoderModel
from VitsModelSplit.feature_extraction import VitsFeatureExtractor
from VitsModelSplit.Arguments import DataTrainingArguments, ModelArguments, VITSTrainingArguments
from VitsModelSplit.dataset_features_collector import FeaturesCollectionDataset
# import VitsModelSplit.monotonic_align as monotonic_align

# Module-level logger; the training helper below logs progress through it.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global counter incremented by the Gradio handler at the bottom of the file.
GK = 0

token = os.environ.get("key_")
tokenizer = AutoTokenizer.from_pretrained("wasmdashai/vits-ar-sa-huba", token=token)
model_vits = VitsModel.from_pretrained("wasmdashai/vits-ar-sa-huba", token=token)  # .to(device)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VitsModel.from_pretrained("facebook/mms-tts-eng").to(device)
# model1 = VitsModel.from_pretrained("/content/drive/MyDrive/vitsM/OneBatch/S6/MMMMM-dash-azd60").to("cuda")
# model = VitsModel.from_pretrained("/content/drive/MyDrive/vitsM/TO/sp3/core/vend").to("cuda")
# model = VitsModel.from_pretrained("/content/drive/MyDrive/vitsM/heppa/EndCore3/v0").to("cuda")
# model.discriminator = model1.discriminator
# model.duration_predictor = model1.duration_predictor
# model.setMfA(monotonic_align.maximum_path)
# tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-ara", cache_dir="./")

feature_extractor = VitsFeatureExtractor()

parser = HfArgumentParser((ModelArguments, DataTrainingArguments, VITSTrainingArguments))
json_file = os.path.abspath('VitsModelSplit/finetune_config_ara.json')
model_args, data_args, training_args = parser.parse_json_file(json_file=json_file)

sgl = get_state_grad_loss(mel=True,
                          # generator=False,
                          # discriminator=False,
                          duration=False)

training_args.num_train_epochs = 1000
training_args.fp16 = True
training_args.eval_steps = 300

# sgl = get_state_grad_loss(k1=True,  # generator=False,
#                           discriminator=False,
#                           duration=False
#                           )

# Batch entries that have to be moved to the GPU before a training step.
Lst = ['input_ids', 'attention_mask', 'waveform', 'labels', 'labels_attention_mask', 'mel_scaled_input_features']


def covert_cuda_batch(d):
    # return d
    for key in Lst:
        d[key] = d[key].cuda(non_blocking=True)
    # for key in d['text_encoder_output']:
    #     d['text_encoder_output'][key] = d['text_encoder_output'][key].cuda(non_blocking=True)
    # for key in d['posterior_encode_output']:
    #     d['posterior_encode_output'][key] = d['posterior_encode_output'][key].cuda(non_blocking=True)
    return d


def generator_loss(disc_outputs):
    # Least-squares GAN generator loss: push every discriminator score towards 1.
    total_loss = 0
    gen_losses = []
    for disc_output in disc_outputs:
        loss = torch.mean((1 - disc_output) ** 2)
        gen_losses.append(loss)
        total_loss += loss
    return total_loss, gen_losses


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    # Least-squares GAN discriminator loss: real scores towards 1, generated scores towards 0.
    loss = 0
    real_losses = 0
    generated_losses = 0
    for disc_real, disc_generated in zip(disc_real_outputs, disc_generated_outputs):
        real_loss = torch.mean((1 - disc_real) ** 2)
        generated_loss = torch.mean(disc_generated ** 2)
        loss += real_loss + generated_loss
        real_losses += real_loss
        generated_losses += generated_loss
    return loss, real_losses, generated_losses
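
# A minimal, hypothetical usage sketch of the two LSGAN helpers above (shapes are
# illustrative only; the real discriminator returns one score map per sub-discriminator):
#
#   real_scores = [torch.rand(4, 100), torch.rand(4, 50)]
#   fake_scores = [torch.rand(4, 100), torch.rand(4, 50)]
#   d_loss, d_real, d_fake = discriminator_loss(real_scores, fake_scores)
#   g_loss, per_scale = generator_loss(fake_scores)
#
# `d_loss` drives the discriminator update and `g_loss` is one term of the
# generator objective assembled in the training loop below.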


def feature_loss(feature_maps_real, feature_maps_generated):
    loss = 0
    for feature_map_real, feature_map_generated in zip(feature_maps_real, feature_maps_generated):
        for real, generated in zip(feature_map_real, feature_map_generated):
            real = real.detach()
            loss += torch.mean(torch.abs(real - generated))
    return loss * 2


def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
    """
    z_p, logs_q: [b, h, t_t]
    m_p, logs_p: [b, h, t_t]
    """
    z_p = z_p.float()
    logs_q = logs_q.float()
    m_p = m_p.float()
    logs_p = logs_p.float()
    z_mask = z_mask.float()

    kl = logs_p - logs_q - 0.5
    kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2. * logs_p)
    kl = torch.sum(kl * z_mask)
    l = kl / torch.sum(z_mask)
    return l

# .............................................
# def kl_loss(prior_latents, posterior_log_variance, prior_means, prior_log_variance, labels_mask):
#     kl = prior_log_variance - posterior_log_variance - 0.5
#     kl += 0.5 * ((prior_latents - prior_means) ** 2) * torch.exp(-2.0 * prior_log_variance)
#     kl = torch.sum(kl * labels_mask)
#     loss = kl / torch.sum(labels_mask)
#     return loss


def get_state_grad_loss(k1=True,
                        mel=True,
                        duration=True,
                        generator=True,
                        discriminator=True):
    return {'k1': k1, 'mel': mel, 'duration': duration, 'generator': generator, 'discriminator': discriminator}


def clip_grad_value_(parameters, clip_value, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    total_norm = total_norm ** (1. / norm_type)
    return total_norm
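
# Hypothetical sketch of how `clip_grad_value_` is used in this file: called with
# `clip_value=None` it leaves the gradients untouched and simply returns their global
# norm, which the training loop below accumulates as a health signal.
#
#   layer = torch.nn.Linear(8, 8)
#   layer(torch.randn(2, 8)).sum().backward()
#   grad_norm = clip_grad_value_(layer.parameters(), None)   # norm only, no clamping
#   clip_grad_value_(layer.parameters(), 1.0)                # clamp each grad to [-1, 1]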


def get_embed_speaker(self, speaker_id):
    if self.config.num_speakers > 1 and speaker_id is not None:
        if isinstance(speaker_id, int):
            speaker_id = torch.full(size=(1,), fill_value=speaker_id, device=self.device)
        elif isinstance(speaker_id, (list, tuple, np.ndarray)):
            speaker_id = torch.tensor(speaker_id, device=self.device)

        if not ((0 <= speaker_id).all() and (speaker_id < self.config.num_speakers).all()).item():
            raise ValueError(f"Set `speaker_id` in the range 0-{self.config.num_speakers - 1}.")

        return self.embed_speaker(speaker_id).unsqueeze(-1)
    else:
        return None


def get_data_loader(train_dataset_dirs, eval_dataset_dir, full_generation_dir, device):
    ctrain_datasets = []
    for dataset_dir, id_sp in train_dataset_dirs:
        train_dataset = FeaturesCollectionDataset(dataset_dir=os.path.join(dataset_dir, 'train'),
                                                  device=device)
        ctrain_datasets.append((train_dataset, id_sp))

    eval_dataset = None
    if training_args.do_eval:
        eval_dataset = FeaturesCollectionDataset(dataset_dir=eval_dataset_dir,
                                                 device=device)

    full_generation_dataset = FeaturesCollectionDataset(dataset_dir=full_generation_dir,
                                                        device=device)
    return ctrain_datasets, eval_dataset, full_generation_dataset


global_step = 0


def trainer_to_cuda(self,
                    ctrain_datasets=None,
                    eval_dataset=None,
                    full_generation_dataset=None,
                    feature_extractor=VitsFeatureExtractor(),
                    training_args=None,
                    full_generation_sample_index=0,
                    project_name="Posterior_Decoder_Finetuning",
                    wandbKey=None,  # pass a wandb API key here instead of hard-coding one in source
                    is_used_text_encoder=True,
                    is_used_posterior_encode=True,
                    dict_state_grad_loss=None,
                    nk=1,
                    path_save_model='./',
                    maf=None,
                    n_back_save_model=3000,
                    start_speeker=0,
                    end_speeker=1,
                    n_epoch=0,
                    ):
    # os.makedirs(training_args.output_dir, exist_ok=True)
    # logger = logging.getLogger(f"{__name__} Training")
    # log_level = training_args.get_process_log_level()
    # logger.setLevel(log_level)

    # wandb is required by the epoch/validation logging below.
    if wandbKey is not None:
        wandb.login(key=wandbKey)
    wandb.init(project=project_name, config=training_args.to_dict())

    if dict_state_grad_loss is None:
        dict_state_grad_loss = get_state_grad_loss()

    global global_step
    set_seed(training_args.seed)
    scaler = GradScaler(enabled=training_args.fp16)

    self.config.save_pretrained(training_args.output_dir)
    len_db = len(ctrain_datasets)
    self.full_generation_sample = full_generation_dataset[full_generation_sample_index]

    # init optimizer, lr_scheduler
    for disc in self.discriminator.discriminators:
        disc.apply_weight_norm()
    self.decoder.apply_weight_norm()
    # torch.nn.utils.weight_norm(self.decoder.conv_pre)
    # torch.nn.utils.weight_norm(self.decoder.conv_post)
    for flow in self.flow.flows:
        torch.nn.utils.weight_norm(flow.conv_pre)
        torch.nn.utils.weight_norm(flow.conv_post)

    # Detach the discriminator from `self` so the generator optimizer below only
    # sees generator parameters; it is restored after training.
    discriminator = self.discriminator
    self.discriminator = None

    optimizer = torch.optim.AdamW(
        self.parameters(),
        training_args.learning_rate,
        betas=[training_args.adam_beta1, training_args.adam_beta2],
        eps=training_args.adam_epsilon,
    )
    # hack to be able to train on multiple device
    disc_optimizer = torch.optim.AdamW(
        discriminator.parameters(),
        training_args.d_learning_rate,
        betas=[training_args.d_adam_beta1, training_args.d_adam_beta2],
        eps=training_args.adam_epsilon,
    )
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer, gamma=training_args.lr_decay, last_epoch=-1
    )
    disc_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        disc_optimizer, gamma=training_args.lr_decay, last_epoch=-1
    )

    logger.info("***** Running training *****")
    logger.info(f"  Num Epochs = {training_args.num_train_epochs}")

    # ....................... training loop ............................
    for epoch in range(training_args.num_train_epochs):
        train_losses_sum = 0
        loss_genall = 0       # running generator grad-norm for this epoch
        loss_des = 0          # running discriminator grad-norm for this epoch
        loss_durationsall = 0
        loss_melall = 0
        loss_klall = 0
        loss_fmapsall = 0
        lr_scheduler.step()
        disc_lr_scheduler.step()

        train_dataset, speaker_id = ctrain_datasets[epoch % len_db]
        print(f" Epoch = {int((epoch + n_epoch) / len_db)}, speaker_id DB = {speaker_id}")
        num_div_proc = max(1, int(len(train_dataset) / 10))  # avoid modulo-by-zero on tiny datasets
        print(' -process training : [', end='')

        for step, batch in enumerate(train_dataset):
            # if speaker_id is None:
            #     if step < 3: continue
            #     if step > 200: break
            batch = covert_cuda_batch(batch)
            displayloss = {}

            with autocast(enabled=training_args.fp16):
                speaker_embeddings = get_embed_speaker(self, batch["speaker_id"] if speaker_id is None else speaker_id)
                waveform, ids_slice, log_duration, prior_latents, posterior_log_variances, prior_means, prior_log_variances, labels_padding_mask = self.forward_train(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["labels"],
                    labels_attention_mask=batch["labels_attention_mask"],
                    text_encoder_output=None,
                    posterior_encode_output=None,
                    return_dict=True,
                    monotonic_alignment_function=maf,
                    speaker_embeddings=speaker_embeddings
                )
                mel_scaled_labels = batch["mel_scaled_input_features"]
                mel_scaled_target = self.slice_segments(mel_scaled_labels, ids_slice, self.segment_size)
                mel_scaled_generation = feature_extractor._torch_extract_fbank_features(waveform.squeeze(1))[1]

                target_waveform = batch["waveform"].transpose(1, 2)
                target_waveform = self.slice_segments(
                    target_waveform,
                    ids_slice * feature_extractor.hop_length,
                    self.config.segment_size
                )

                discriminator_target, fmaps_target = discriminator(target_waveform)
                discriminator_candidate, fmaps_candidate = discriminator(waveform.detach())

            with autocast(enabled=False):
                if dict_state_grad_loss['discriminator']:
                    loss_disc, loss_real_disc, loss_fake_disc = discriminator_loss(
                        discriminator_target, discriminator_candidate
                    )
                    dk = {
                        "step_loss_disc": loss_disc.detach().item(),
                        "step_loss_real_disc": loss_real_disc.detach().item(),
                        "step_loss_fake_disc": loss_fake_disc.detach().item(),
                    }
                    displayloss['dict_loss_discriminator'] = dk
                    loss_dd = loss_disc  # + loss_real_disc + loss_fake_disc
                    # loss_dd.backward()

                    disc_optimizer.zero_grad()
                    scaler.scale(loss_dd).backward()
                    scaler.unscale_(disc_optimizer)
                    grad_norm_d = clip_grad_value_(discriminator.parameters(), None)
                    scaler.step(disc_optimizer)
                    loss_des += grad_norm_d

            with autocast(enabled=training_args.fp16):
                # backpropagate
                discriminator_target, fmaps_target = discriminator(target_waveform)
                # This second pass must NOT detach the waveform, otherwise the adversarial
                # and feature-matching losses carry no gradient back to the generator.
                discriminator_candidate, fmaps_candidate = discriminator(waveform)

            with autocast(enabled=False):
                if dict_state_grad_loss['k1']:
                    loss_kl = kl_loss(
                        prior_latents,
                        posterior_log_variances,
                        prior_means,
                        prior_log_variances,
                        labels_padding_mask,
                    )
                    loss_kl = loss_kl * training_args.weight_kl
                    loss_klall += loss_kl.detach().item()
                    # if displayloss['loss_kl'] >= 0:
                    #     loss_kl.backward()

                if dict_state_grad_loss['mel']:
                    loss_mel = torch.nn.functional.l1_loss(mel_scaled_target, mel_scaled_generation) * training_args.weight_mel
                    loss_melall += loss_mel.detach().item()
                    # train_losses_sum = train_losses_sum + displayloss['loss_mel']
                    # if displayloss['loss_mel'] >= 0:
                    #     loss_mel.backward()

                if dict_state_grad_loss['duration']:
                    loss_duration = torch.sum(log_duration) * training_args.weight_duration
                    loss_durationsall += loss_duration.detach().item()
                    # if displayloss['loss_duration'] >= 0:
                    #     loss_duration.backward()

                if dict_state_grad_loss['generator']:
                    loss_fmaps = feature_loss(fmaps_target, fmaps_candidate)
                    loss_gen, losses_gen = generator_loss(discriminator_candidate)
                    loss_gen = loss_gen * training_args.weight_gen
                    displayloss['loss_gen'] = loss_gen.detach().item()
                    # loss_gen.backward(retain_graph=True)
                    loss_fmaps = loss_fmaps * training_args.weight_fmaps
                    displayloss['loss_fmaps'] = loss_fmaps.detach().item()
                    # loss_fmaps.backward(retain_graph=True)

                # The total below assumes every flag in dict_state_grad_loss is enabled;
                # otherwise some of these terms are undefined.
                total_generator_loss = (
                    loss_duration
                    + loss_mel
                    + loss_kl
                    + loss_fmaps
                    + loss_gen
                )
                # total_generator_loss.backward()

                optimizer.zero_grad()
                scaler.scale(total_generator_loss).backward()
                scaler.unscale_(optimizer)
                grad_norm_g = clip_grad_value_(self.parameters(), None)
                scaler.step(optimizer)
                scaler.update()
                loss_genall += grad_norm_g
                # optimizer.step()

            # print(f"TRAINING - batch {step}, waveform {(batch['waveform'].shape)}, lr {lr_scheduler.get_last_lr()[0]}... ")
            # print(f"display loss function enable :{displayloss}")

            global_step += 1
            if step % num_div_proc == 0:
                print('==', end='')

            # validation
            do_eval = training_args.do_eval and (global_step % training_args.eval_steps == 0)
            if do_eval:
                speaker_id_c = int(torch.randint(start_speeker, end_speeker, size=(1,))[0])
                logger.info("Running validation... ")
                eval_losses_sum = 0
                cc = 0
                for step, batch in enumerate(eval_dataset):
                    # only evaluate a few batches
                    if cc > 2:
                        break
                    cc += 1
                    with torch.no_grad():
                        model_outputs = self.forward(
                            input_ids=batch["input_ids"],
                            attention_mask=batch["attention_mask"],
                            labels=batch["labels"],
                            labels_attention_mask=batch["labels_attention_mask"],
                            speaker_id=batch["speaker_id"],
                            return_dict=True,
                        )
                    mel_scaled_labels = batch["mel_scaled_input_features"]
                    mel_scaled_target = self.slice_segments(mel_scaled_labels, model_outputs.ids_slice, self.segment_size)
                    mel_scaled_generation = feature_extractor._torch_extract_fbank_features(model_outputs.waveform.squeeze(1))[1]
                    loss_mel = torch.nn.functional.l1_loss(mel_scaled_target, mel_scaled_generation)
                    loss = loss_mel.detach().item()
                    eval_losses_sum += loss
                    print(f"VALIDATION - batch {step}, waveform {(batch['waveform'].shape)}, step_loss_mel {loss} ... ")

                with torch.no_grad():
                    full_generation_sample = self.full_generation_sample
                    full_generation = self.forward(
                        input_ids=full_generation_sample["input_ids"],
                        attention_mask=full_generation_sample["attention_mask"],
                        speaker_id=speaker_id_c
                    )
                    full_generation_waveform = full_generation.waveform.cpu().numpy()

                wandb.log({
                    "eval_losses": eval_losses_sum,
                    "full generations samples": [
                        wandb.Audio(w.reshape(-1), caption=f"Full generation sample {epoch}", sample_rate=16000)
                        for w in full_generation_waveform
                    ],
                })

        step += 1
        # wandb.log({"train_losses": loss_melall})
        wandb.log({"loss_gen": loss_genall / step})
        wandb.log({"loss_des": loss_des / step})
        wandb.log({"loss_duration": loss_durationsall / step})
        wandb.log({"loss_mel": loss_melall / step})
        wandb.log({f"loss_kl_db{speaker_id}": loss_klall / step})
        print(']', end='')
        # self.save_pretrained(path_save_model)

    self.discriminator = discriminator
    for disc in self.discriminator.discriminators:
        disc.remove_weight_norm()
    self.decoder.remove_weight_norm()
    # torch.nn.utils.remove_weight_norm(self.decoder.conv_pre)
    # torch.nn.utils.remove_weight_norm(self.decoder.conv_post)
    for flow in self.flow.flows:
        torch.nn.utils.remove_weight_norm(flow.conv_pre)
        torch.nn.utils.remove_weight_norm(flow.conv_post)

    self.save_pretrained(path_save_model)

    logger.info("Running final full generations samples... ")
    logger.info("***** Training / Inference Done *****")
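
# The training helpers above are not invoked by this Space. A hypothetical wiring
# (directory names and speaker ids are placeholders, not real paths) could look like:
#
#   ctrain_datasets, eval_dataset, full_generation_dataset = get_data_loader(
#       train_dataset_dirs=[("datasets/speaker0", 0)],
#       eval_dataset_dir="datasets/eval",
#       full_generation_dir="datasets/full_generation",
#       device=device,
#   )
#   trainer_to_cuda(model,
#                   ctrain_datasets=ctrain_datasets,
#                   eval_dataset=eval_dataset,
#                   full_generation_dataset=full_generation_dataset,
#                   training_args=training_args,
#                   dict_state_grad_loss=sgl,
#                   path_save_model="./checkpoints")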
") logger.info("***** Training / Inference Done *****") def modelspeech(texts): inputs = tokenizer(texts, return_tensors="pt")#.cuda() wav = model_vits(input_ids=inputs["input_ids"]).waveform#.detach() # display(Audio(wav, rate=model.config.sampling_rate)) return model_vits.config.sampling_rate,wav#remove_noise_nr(wav) def greet(text,id): global GK b=int(id) while True: GK+=1 texts=[text]*b out=modelspeech(texts) yield f"namber is {GK}" demo = gr.Interface(fn=greet, inputs=["text","text"], outputs="text") demo.launch()