import data
import torch
import torch.nn as nn
import torch.optim as optim
import os
from torchvision import transforms
from models import imagebind_model
from models.imagebind_model import ModalityType
from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer
from extract.getim import load_image
from image2vidimg import cobtwoten, cobtwoten256
# os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
device = torch.device("cuda")
#import matplotlib.pyplot as plt
#import torch.nn.functional as F
#import cv2
torch.cuda.empty_cache()
transform = transforms.Compose([
    transforms.ToTensor(),  # convert a numpy array or PIL.Image to a (C, H, W) tensor and scale to [0, 1]
])
unloader = transforms.ToPILImage()
# def imshow(tensor, title=None):
#
# tensor=tensor.permute(1,2,0)
# print(tensor.shape)
# cv2.imshow('image:', tensor.cpu().numpy())
# # block so the image window is not closed immediately
# cv2.waitKey(0)
# # plt.imshow(img_pil)
# # if title is not None:
# # plt.title(title)
# # plt.pause(0.001) # pause a bit so that plots are updated
def imagebind_out(audio_paths, model):
    # Load the audio clips and embed them with the (frozen) ImageBind model
    inputs = {
        ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device),
    }
    with torch.no_grad():
        embeddings = model(inputs)
    return embeddings
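# Minimal usage sketch (commented out; hypothetical path, `model_imageb` is instantiated further below):
#
#   emb = imagebind_out(["./extract/audio/example.wav"], model_imageb)
#   print(emb[ModalityType.AUDIO].shape)   # expected: torch.Size([1, 1024])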
# Expand a (B, 1, 1024) ImageBind audio embedding into a (B, 344, 1024) pseudo text-embedding
class encode_audio(nn.Module):
    def __init__(self):
        super().__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # self.link1 = nn.Linear(1024, 768)
        self.link2 = nn.Linear(1024, 343)
        # self.link3 = nn.Linear(1024, 768)

    def forward(self, embeddings):
        l1 = embeddings                              # (B, 1, 1024)
        l2 = self.link2(embeddings)                  # (B, 1, 343)
        # l3 = self.link3(embeddings)
        l3 = torch.matmul(l2.transpose(1, 2), l1)    # (B, 343, 1024) outer product
        return torch.cat([l1, l3], dim=1)            # (B, 344, 1024)
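# Quick shape check (commented out; uses only torch and a random tensor):
#
#   _enc = encode_audio()
#   _out = _enc(torch.randn(2, 1, 1024))
#   assert _out.shape == (2, 344, 1024)   # matches max_text_len=344 and text_embed_dim=1024 below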
# os.listdir() returns the file names in the target directory as a list
def getAllFiles(targetDir):
    listFiles = os.listdir(targetDir)
    return listFiles
# U-Nets for the two-stage Imagen video cascade
unet1 = Unet3D(max_text_len=344,text_embed_dim=1024,dim = 64, dim_mults = (1, 2, 4, 8)).to(device)
unet2 = Unet3D(max_text_len=344,text_embed_dim=1024,dim = 128, dim_mults = (1, 2, 4, 8)).to(device)
# unet3 = Unet3D(dim = 256, dim_mults = (1, 2, 4, 8)).cuda()
#unet1 = NullUnet() # add a placeholder "null" unet for the base unet
imagen = ElucidatedImagen(
text_embed_dim=1024,
unets = (unet1, unet2),
image_sizes = (64, 128),
random_crop_sizes = (None, 64),
temporal_downsample_factor = (2, 1), # in this example, the first unet would receive the video temporally downsampled by 2x
num_sample_steps = 10,
cond_drop_prob = 0.1,
sigma_min = 0.002, # min noise level
sigma_max = (80, 160), # max noise level, double the max noise level for upsampler
sigma_data = 0.5, # standard deviation of data distribution
rho = 7, # controls the sampling schedule
P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training
P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training
    S_churn = 80,               # parameters for stochastic sampling - depends on dataset, Table 5 in the paper
S_tmin = 0.05,
S_tmax = 50,
S_noise = 1.003,
).to(device)
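# Note on the cascade above: unet1 generates 64x64 clips (temporally downsampled 2x per
# temporal_downsample_factor), and unet2 upsamples them to 128x128 at full frame rate.
# The audio embedding produced by encode_audio is passed in as `text_embeds`, which is why
# text_embed_dim=1024 and max_text_len=344 mirror its (B, 344, 1024) output.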
trainer = ImagenTrainer(imagen)
# trainer.to(device)
# trainer.load("./checkpoint.pt")
trainer = trainer.to(device)
# Instantiate the pretrained ImageBind encoder (frozen, eval mode; used only to embed audio)
# device_ids = [0, 1]
model_imageb = imagebind_model.imagebind_huge(pretrained=True)
model_imageb=model_imageb.to(device)
model_imageb.eval()
# model_imageb=model_imageb.cuda(device=device_ids)
# model_imageb.to(device)
epo = 31      # number of training epochs
p = 1         # number of (audio, image) pairs processed per training step
files = getAllFiles("./extract/audio")
outloss = 0
model1 = encode_audio().to(device)
# model1.load_state_dict(torch.load("wlc.pt").state_dict())
optimizer = optim.Adam(model1.parameters(), lr=1e-5,
                       betas=(0.9, 0.999), eps=1e-08, weight_decay=0., amsgrad=True)
# model1.eval()
model1.train()
torch.cuda.empty_cache()
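# Assumption in the loop below: every clip ./extract/audio/<stem>.wav has a matching key
# frame ./extract/image/<stem>.jpg. A quick sanity check (commented out, hypothetical):
#
#   for f in files:
#       stem, _ = os.path.splitext(f)
#       assert os.path.exists("./extract/image/" + stem + ".jpg"), stem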
for k in range(epo):
for nm in range(0, len(files) + 1 - p, p):
#for i in (1, 2):
file_ext0 = os.path.splitext(files[nm])
front0, ext0 = file_ext0
audio_pat=[]
audio_pat.append("./extract/audio/" + str(front0) + ".wav")
# fcontents = load_image("./extract/image/0.jpg", transform=None, shape=[256, 128])
fcontent = cobtwoten("./extract/image/" + str(front0) + ".jpg")
# print(fcontent.shape)
#fcontent = load_image("./extract/image/" + str(front0) + ".jpg", transform, shape=[256, 256])
for ni in range(1,p):
file_ext = os.path.splitext(files[nm+ni])
front, ext = file_ext
# content = load_image("./extract/image/" + str(front) + ".jpg", transform, shape=[256, 256])
content = cobtwoten("./extract/image/" + str(front) + ".jpg")
fcontent = torch.cat((fcontent, content), -5)
audio_pat.append("./extract/audio/" + str(front) + ".wav")
# imageb=torch.LongTensor(imageb_out["audio"])
imageb_out = imagebind_out(audio_pat,model_imageb)
fmusic = model1(imageb_out["audio"].unsqueeze(1))#(5,1,1024)->(5,344,1024)
# fmusic = model1(imageb_out["audio"].unsqueeze(1).cuda())#(5,1,1024)->(5,344,1024)
# print(fmusic)
# print(fmusic.shape)
fmusic=fmusic.to(device)
fcontent=fcontent.to(device)
        # The trainer call computes the diffusion loss (and, in imagen-pytorch's ImagenTrainer,
        # also backpropagates it); trainer.update() steps the second unet, while the separate
        # Adam step updates the audio projection (model1) via the gradients that flowed into text_embeds.
        loss = trainer(fcontent, text_embeds=fmusic, unet_number=2, ignore_time=False, max_batch_size=p)
        trainer.update(unet_number=2)
        optimizer.step()
        # print(optimizer.state)
        optimizer.zero_grad()
        print(loss)
        outloss = outloss + loss
        # print("unet" + str(i) + " " + str(loss))
    print("epoch" + str(k) + " loss: " + str(outloss))
    outloss = 0
if k % 3 == 2:
torch.save(model1, "wlc.pt")
trainer.save('./checkpoint.pt')
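# Inference sketch after training (commented out; assumes the installed imagen-pytorch
# version exposes trainer.sample(..., video_frames=...) for Unet3D cascades, as in its
# README, and "example.wav" is a hypothetical conditioning clip):
#
#   emb = imagebind_out(["./extract/audio/example.wav"], model_imageb)
#   cond = model1(emb["audio"].unsqueeze(1)).to(device)              # (1, 344, 1024)
#   videos = trainer.sample(text_embeds=cond, video_frames=10,
#                           cond_scale=3., stop_at_unet_number=2)
#   # expected to be roughly (1, 3, 10, 128, 128) at the second (128px) stage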
# text_list=["A dog.", "A car", "A bird"]
# image_paths=[".assets/dog_image.jpg", ".assets/car_image.jpg", ".assets/bird_image.jpg"]
#
#
#
#
# print(
# "Vision x Text: ",
# torch.softmax(embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1),
# )
# print(
# "Audio x Text: ",
# torch.softmax(embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.TEXT].T, dim=-1),
# )
# print(
# "Vision x Audio: ",
# torch.softmax(embeddings[ModalityType.VISION] @ embeddings[ModalityType.AUDIO].T, dim=-1),
# )
#
#
#
# print(embeddings['audio'].shape)
# print(embeddings[ModalityType.AUDIO].shape)
# print(embeddings[ModalityType.VISION].shape)
# Expected output:
#
# Vision x Text:
# tensor([[9.9761e-01, 2.3694e-03, 1.8612e-05],
# [3.3836e-05, 9.9994e-01, 2.4118e-05],
# [4.7997e-05, 1.3496e-02, 9.8646e-01]])
#
# Audio x Text:
# tensor([[1., 0., 0.],
# [0., 1., 0.],
# [0., 0., 1.]])
#
# Vision x Audio:
# tensor([[0.8070, 0.1088, 0.0842],
# [0.1036, 0.7884, 0.1079],
# [0.0018, 0.0022, 0.9960]])