import os

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms

import data  # ImageBind's data loading / preprocessing module
from models import imagebind_model
from models.imagebind_model import ModalityType
from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer
from extract.getim import load_image
from image2vidimg import cobtwoten, cobtwoten256

# os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
device = torch.device("cuda")

torch.cuda.empty_cache()
transform = transforms.Compose([
    transforms.ToTensor(),  # convert a numpy array / PIL image to a (C, H, W) tensor scaled to [0, 1]
])
unloader = transforms.ToPILImage()
def imagebind_out(audio_paths, model):
    # Load the audio clips and run ImageBind, returning its embedding dict
    inputs = {
        ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device),
    }
    with torch.no_grad():
        embeddings = model(inputs)
    return embeddings
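# Usage sketch (kept as a comment; the path is illustrative and `model_imageb`,
# the pretrained ImageBind model instantiated further below, is assumed):
#   emb = imagebind_out(["./extract/audio/0.wav"], model_imageb)
#   emb[ModalityType.AUDIO]  # tensor of shape (num_clips, 1024)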
class encode_audio(nn.Module):
    """Expand a single ImageBind audio embedding (B, 1, 1024) into a
    pseudo text-embedding sequence (B, 344, 1024) for Imagen conditioning."""
    def __init__(self):
        super().__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.link2 = nn.Linear(1024, 343)

    def forward(self, embeddings):
        l1 = embeddings                             # (B, 1, 1024)
        l2 = self.link2(embeddings)                 # (B, 1, 343)
        l3 = torch.matmul(l2.transpose(1, 2), l1)   # (B, 343, 1024)
        return torch.cat([l1, l3], dim=1)           # (B, 344, 1024)
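# Quick shape sanity check for encode_audio: an arbitrary batch of 2
# ImageBind-sized embeddings should map from (2, 1, 1024) to (2, 344, 1024).
# Runs on CPU in a fraction of a second and is safe to delete.
with torch.no_grad():
    _probe = encode_audio()
    assert _probe(torch.randn(2, 1, 1024)).shape == (2, 344, 1024)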
# os.listdir() returns the names of the entries in the target directory as a list
def getAllFiles(targetDir):
    listFiles = os.listdir(targetDir)
    return listFiles
# unets for imagen
unet1 = Unet3D(max_text_len = 344, text_embed_dim = 1024, dim = 64, dim_mults = (1, 2, 4, 8)).to(device)
unet2 = Unet3D(max_text_len = 344, text_embed_dim = 1024, dim = 128, dim_mults = (1, 2, 4, 8)).to(device)

imagen = ElucidatedImagen(
    text_embed_dim = 1024,
    unets = (unet1, unet2),
    image_sizes = (64, 128),
    random_crop_sizes = (None, 64),
    temporal_downsample_factor = (2, 1),  # the first unet receives the video temporally downsampled by 2x
    num_sample_steps = 10,
    cond_drop_prob = 0.1,
    sigma_min = 0.002,      # min noise level
    sigma_max = (80, 160),  # max noise level, doubled for the upsampler
    sigma_data = 0.5,       # standard deviation of data distribution
    rho = 7,                # controls the sampling schedule
    P_mean = -1.2,          # mean of log-normal distribution from which noise is drawn for training
    P_std = 1.2,            # standard deviation of log-normal distribution from which noise is drawn for training
    S_churn = 80,           # parameters for stochastic sampling - depends on dataset, Table 5 in the paper
    S_tmin = 0.05,
    S_tmax = 50,
    S_noise = 1.003,
).to(device)
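# The cascade is conditioned through `text_embeds`: max_text_len = 344 and
# text_embed_dim = 1024 match the (B, 344, 1024) output of encode_audio, so the
# audio embeddings stand in for text conditioning. The prints below only report
# the parameter count of each stage for visibility and are safe to remove.
print("unet1 parameters:", sum(t.numel() for t in unet1.parameters()))
print("unet2 parameters:", sum(t.numel() for t in unet2.parameters()))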
trainer = ImagenTrainer(imagen)
# trainer.load("./checkpoint.pt")  # uncomment to resume from a saved checkpoint
trainer = trainer.to(device)

# Instantiate the pretrained ImageBind encoder; it stays frozen (eval mode)
# and is only used to embed the audio clips.
model_imageb = imagebind_model.imagebind_huge(pretrained=True)
model_imageb = model_imageb.to(device)
model_imageb.eval()
epo = 31   # number of training epochs
p = 1      # clips per training step (batch size)
files = getAllFiles("./extract/audio")
outloss = 0
model1 = encode_audio().to(device)
# model1.load_state_dict(torch.load("wlc.pt").state_dict())  # uncomment to resume the audio encoder
optimizer = optim.Adam(model1.parameters(), lr=1e-5,
                       betas=(0.9, 0.999), eps=1e-08, weight_decay=0., amsgrad=True)
model1.train()
torch.cuda.empty_cache()
for k in range(epo):
    for nm in range(0, len(files) + 1 - p, p):
        # Gather the audio paths and the matching video tensor for this batch
        front0, ext0 = os.path.splitext(files[nm])
        audio_pat = []
        audio_pat.append("./extract/audio/" + str(front0) + ".wav")
        fcontent = cobtwoten("./extract/image/" + str(front0) + ".jpg")
        for ni in range(1, p):
            front, ext = os.path.splitext(files[nm + ni])
            content = cobtwoten("./extract/image/" + str(front) + ".jpg")
            fcontent = torch.cat((fcontent, content), -5)  # concatenate along the leading (batch) dimension
            audio_pat.append("./extract/audio/" + str(front) + ".wav")
        # Audio -> ImageBind embedding -> pseudo text embeddings for Imagen conditioning
        imageb_out = imagebind_out(audio_pat, model_imageb)
        fmusic = model1(imageb_out["audio"].unsqueeze(1))  # (p, 1, 1024) -> (p, 344, 1024)
        fmusic = fmusic.to(device)
        fcontent = fcontent.to(device)
        # The trainer backpropagates internally, so gradients also reach model1 through text_embeds
        loss = trainer(fcontent, text_embeds=fmusic, unet_number=2, ignore_time=False, max_batch_size=p)
        trainer.update(unet_number=2)
        optimizer.step()
        optimizer.zero_grad()
        print(loss)
        outloss = outloss + loss
    print("epoch " + str(k) + " loss: " + str(outloss))
    outloss = 0
    if k % 3 == 2:
        torch.save(model1, "wlc.pt")
        trainer.save('./checkpoint.pt')
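# Inference sketch (commented out so it does not run during training; assumes the
# checkpoints saved above and an example clip ./extract/audio/0.wav):
#   model1 = torch.load("wlc.pt").to(device).eval()
#   trainer.load("./checkpoint.pt")
#   emb = imagebind_out(["./extract/audio/0.wav"], model_imageb)
#   cond = model1(emb[ModalityType.AUDIO].unsqueeze(1))   # (1, 344, 1024)
#   videos = trainer.sample(text_embeds=cond, video_frames=10, cond_scale=3.)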