import os

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms

import data
from models import imagebind_model
from models.imagebind_model import ModalityType
from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer
from extract.getim import load_image
from image2vidimg import cobtwoten, cobtwoten256

# os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
device = torch.device("cuda")

# import matplotlib.pyplot as plt
# import torch.nn.functional as F
# import cv2

torch.cuda.empty_cache()

transform = transforms.Compose([
    transforms.ToTensor(),  # convert a numpy array or PIL.Image to a (C, H, W) tensor and scale it to [0, 1.0] by dividing by 255
])  # ImageNet mean and variance

unloader = transforms.ToPILImage()

# def imshow(tensor, title=None):
#     # tensor = tensor.permute(1, 2, 0)
#     print(tensor.shape)
#     cv2.imshow('image:', tensor.cpu().numpy())
#     # keep the window from closing
#     cv2.waitKey(0)
#     # plt.imshow(img_pil)
#     # if title is not None:
#     #     plt.title(title)
#     # plt.pause(0.001)  # pause a bit so that plots are updated


def imagebind_out(audio_paths, model):
    """Embed a list of audio files with the (frozen) ImageBind model."""
    inputs = {
        ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device),
    }
    with torch.no_grad():
        embeddings = model(inputs)
    return embeddings


class encode_audio(nn.Module):
    """Expand a (B, 1, 1024) ImageBind audio embedding into a (B, 344, 1024)
    sequence that Imagen consumes as text-embedding conditioning."""

    def __init__(self):
        super().__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # self.link1 = nn.Linear(1024, 768)
        self.link2 = nn.Linear(1024, 343)
        # self.link3 = nn.Linear(1024, 768)

    def forward(self, embeddings):
        l1 = embeddings                             # (B, 1, 1024)
        l2 = self.link2(embeddings)                 # (B, 1, 343)
        # l3 = self.link3(embeddings)
        l3 = torch.matmul(l2.transpose(1, 2), l1)   # (B, 343, 1024)
        return torch.cat([l1, l3], dim=1)           # (B, 344, 1024)


def getAllFiles(targetDir):
    # os.listdir() returns the entries of the directory as a list of names
    return os.listdir(targetDir)


# U-Nets for Imagen
unet1 = Unet3D(max_text_len=344, text_embed_dim=1024, dim=64, dim_mults=(1, 2, 4, 8)).to(device)
unet2 = Unet3D(max_text_len=344, text_embed_dim=1024, dim=128, dim_mults=(1, 2, 4, 8)).to(device)
# unet3 = Unet3D(dim=256, dim_mults=(1, 2, 4, 8)).cuda()
# unet1 = NullUnet()  # add a placeholder "null" unet for the base unet

imagen = ElucidatedImagen(
    text_embed_dim=1024,
    unets=(unet1, unet2),
    image_sizes=(64, 128),
    random_crop_sizes=(None, 64),
    temporal_downsample_factor=(2, 1),  # the first unet receives the video temporally downsampled by 2x
    num_sample_steps=10,
    cond_drop_prob=0.1,
    sigma_min=0.002,      # min noise level
    sigma_max=(80, 160),  # max noise level, doubled for the upsampler
    sigma_data=0.5,       # standard deviation of the data distribution
    rho=7,                # controls the sampling schedule
    P_mean=-1.2,          # mean of the log-normal distribution from which noise is drawn for training
    P_std=1.2,            # standard deviation of the log-normal distribution from which noise is drawn for training
    S_churn=80,           # parameters for stochastic sampling - depend on the dataset, Table 5 in the paper
    S_tmin=0.05,
    S_tmax=50,
    S_noise=1.003,
).to(device)

trainer = ImagenTrainer(imagen)
# trainer.load("./checkpoint.pt")
trainer = trainer.to(device)

# Instantiate the pretrained ImageBind audio encoder
# device_ids = [0, 1]
model_imageb = imagebind_model.imagebind_huge(pretrained=True)
model_imageb = model_imageb.to(device)
model_imageb.eval()
# model_imageb = model_imageb.cuda(device=device_ids)

epo = 31  # number of epochs
p = 1     # number of audio/video pairs per training step
files = getAllFiles("./extract/audio")
outloss = 0

model1 = encode_audio().to(device)
# model1.load_state_dict(torch.load("wlc.pt").state_dict())
optimizer = optim.Adam(model1.parameters(), lr=1e-5, betas=(0.9, 0.999), eps=1e-08,
                       weight_decay=0., amsgrad=True)
# model1.eval()
model1.train()

torch.cuda.empty_cache()

for k in range(epo):
    for nm in range(0, len(files) + 1 - p, p):
        # for i in (1, 2):
        front0, ext0 = os.path.splitext(files[nm])
        audio_pat = []
        audio_pat.append("./extract/audio/" + str(front0) + ".wav")
        # fcontents = load_image("./extract/image/0.jpg", transform=None, shape=[256, 128])
        fcontent = cobtwoten("./extract/image/" + str(front0) + ".jpg")
        # print(fcontent.shape)
        # fcontent = load_image("./extract/image/" + str(front0) + ".jpg", transform, shape=[256, 256])

        # gather the remaining items of the batch
        for ni in range(1, p):
            front, ext = os.path.splitext(files[nm + ni])
            # content = load_image("./extract/image/" + str(front) + ".jpg", transform, shape=[256, 256])
            content = cobtwoten("./extract/image/" + str(front) + ".jpg")
            fcontent = torch.cat((fcontent, content), -5)
            audio_pat.append("./extract/audio/" + str(front) + ".wav")

        # imageb = torch.LongTensor(imageb_out["audio"])
        imageb_out = imagebind_out(audio_pat, model_imageb)
        fmusic = model1(imageb_out["audio"].unsqueeze(1))  # (p, 1, 1024) -> (p, 344, 1024)
        # fmusic = model1(imageb_out["audio"].unsqueeze(1).cuda())
        # print(fmusic)
        # print(fmusic.shape)
        fmusic = fmusic.to(device)
        fcontent = fcontent.to(device)

        loss = trainer(fcontent, text_embeds=fmusic, unet_number=2, ignore_time=False, max_batch_size=p)
        trainer.update(unet_number=2)
        optimizer.step()
        # print(optimizer.state)
        optimizer.zero_grad()
        print(loss)
        outloss = outloss + loss
        # print("unet" + str(i) + " " + str(loss))

    print("epoch " + str(k) + " loss: " + str(outloss))
    outloss = 0

    if k % 3 == 2:
        torch.save(model1, "wlc.pt")
        trainer.save('./checkpoint.pt')

# text_list = ["A dog.", "A car", "A bird"]
# image_paths = [".assets/dog_image.jpg", ".assets/car_image.jpg", ".assets/bird_image.jpg"]
#
# print(
#     "Vision x Text: ",
#     torch.softmax(embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1),
# )
# print(
#     "Audio x Text: ",
#     torch.softmax(embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.TEXT].T, dim=-1),
# )
# print(
#     "Vision x Audio: ",
#     torch.softmax(embeddings[ModalityType.VISION] @ embeddings[ModalityType.AUDIO].T, dim=-1),
# )
#
# print(embeddings['audio'].shape)
# print(embeddings[ModalityType.AUDIO].shape)
# print(embeddings[ModalityType.VISION].shape)
#
# Expected output:
#
# Vision x Text:
# tensor([[9.9761e-01, 2.3694e-03, 1.8612e-05],
#         [3.3836e-05, 9.9994e-01, 2.4118e-05],
#         [4.7997e-05, 1.3496e-02, 9.8646e-01]])
#
# Audio x Text:
# tensor([[1., 0., 0.],
#         [0., 1., 0.],
#         [0., 0., 1.]])
#
# Vision x Audio:
# tensor([[0.8070, 0.1088, 0.0842],
#         [0.1036, 0.7884, 0.1079],
#         [0.0018, 0.0022, 0.9960]])