Baraaqasem's picture
Upload 585 files
5d32408 verified
# General
import os
from os.path import join as opj
import datetime
import torch
from einops import rearrange, repeat
# Utilities
from videogen_hub.pipelines.streamingt2v.inference_utils import *
from modelscope.outputs import OutputKeys
import imageio
from PIL import Image
import numpy as np
import torch.nn.functional as F
import torchvision.transforms as transforms
from diffusers.utils import load_image
# Shared PIL-image -> tensor converter, applied per frame in ad_short_gen and svd_short_gen.
transform = transforms.Compose(transforms=[transforms.PILToTensor()])
def ms_short_gen(prompt, ms_model, inference_generator, t=50, device="cuda"):
    """Generate a short 256x256 clip with ``ms_model`` and return it as a tensor.

    ``ms_model`` (presumably a ModelScope text-to-video pipeline) is called with
    ``t`` inference steps, ``eta=1.0`` and no initial latents. Every entry of the
    output's ``.frames`` is converted with ``torch.from_numpy`` and stacked. The
    stack is moved to ``device`` as float32, and the first entry is returned with
    its last axis moved to position 1.

    NOTE(review): assumes each ``.frames`` entry is a numpy array laid out as
    (frames, height, width, channels); confirm against the diffusers version in
    use. Pixel values are not rescaled here.
    """
    pipeline_output = ms_model(
        prompt,
        num_inference_steps=t,
        generator=inference_generator,
        eta=1.0,
        height=256,
        width=256,
        latents=None,
    )
    batch = torch.stack(list(map(torch.from_numpy, pipeline_output.frames)))
    batch = batch.to(device).to(torch.float32)
    first_clip = batch[0]
    # Channels-last -> channels-first: (F, H, W, C) -> (F, C, H, W).
    return rearrange(first_clip, "F H W C -> F C H W")
def ad_short_gen(prompt, ad_model, inference_generator, t=25, device="cuda"):
    """Generate a 16-frame clip with ``ad_model`` and return it as a float tensor.

    ``ad_model`` (presumably an AnimateDiff pipeline) is called with a fixed
    negative prompt, guidance scale 7.5 and ``t`` inference steps. The frames of
    the first output are converted with the module-level ``transform``, stacked,
    and moved to ``device`` as float32. They are then resized to 256x256 with
    ``F.interpolate`` (default mode) and divided by 255.0.
    """
    pil_frames = ad_model(
        prompt,
        negative_prompt="bad quality, worse quality",
        num_frames=16,
        num_inference_steps=t,
        generator=inference_generator,
        guidance_scale=7.5,
    ).frames[0]
    clip = torch.stack(list(map(transform, pil_frames)))
    clip = clip.to(device).to(torch.float32)
    clip = F.interpolate(clip, size=256)
    return clip / 255.0
def sdxl_image_gen(prompt, sdxl_model):
    """Run ``sdxl_model`` on ``prompt`` and return the first generated image.

    The model is called as ``sdxl_model(prompt=prompt)``; the result is the
    element at index 0 of the output's ``images`` sequence.
    """
    result = sdxl_model(prompt=prompt)
    generated = result.images
    return generated[0]
def svd_short_gen(
    image, prompt, svd_model, sdxl_model, inference_generator, t=25, device="cuda"
):
    """Generate a short clip from a conditioning image with ``svd_model``.

    The conditioning image is chosen as follows:

    * ``image`` is ``None`` or ``""``: an image is generated from ``prompt`` via
      ``sdxl_image_gen(prompt, sdxl_model)`` and resized to 576x576.
    * ``image`` is a ``str``: it is loaded with ``load_image``.
    * Otherwise: it is treated as array-like pixel data and wrapped with
      ``Image.fromarray(np.uint8(image))``.

    The last two cases are then passed through ``resize_and_keep`` and
    ``center_crop``. In every case ``add_margin(image, 0, 224, 0, 224, (0, 0, 0))``
    adds a black margin. It is presumably on the left/right, since 224 columns
    are cropped from each side of the width axis after generation.

    Args:
        image: ``None``/``""``, a path/URL string, or array-like pixel data.
        prompt: Text prompt; only used when no image is supplied.
        svd_model: Image-to-video pipeline (presumably Stable Video Diffusion)
            whose output's ``frames[0]`` is a sequence of PIL images.
        sdxl_model: Text-to-image pipeline used when no image is supplied.
        inference_generator: Random generator forwarded to ``svd_model``.
        t: Number of denoising steps, forwarded as ``num_inference_steps``.
        device: Device the returned tensor is moved to.

    Returns:
        float32 tensor on ``device`` with at most 16 frames, shaped
        (frames, C, 256, 256), with pixel values divided by 255.0.
    """
    # Only compare with "" when ``image`` really is a string: ``array == ""``
    # on a numpy array is elementwise, and its truth value is ambiguous.
    if image is None or (isinstance(image, str) and image == ""):
        image = sdxl_image_gen(prompt, sdxl_model)
        image = image.resize((576, 576))
    else:
        if isinstance(image, str):
            image = load_image(image)
        else:
            image = Image.fromarray(np.uint8(image))
        image = resize_and_keep(image)
        image = center_crop(image)
    # Black margin of 224 px; stripped again from the generated frames below.
    image = add_margin(image, 0, 224, 0, 224, (0, 0, 0))

    # ``t`` is forwarded so callers can control the step count. The default of 25
    # matches diffusers' StableVideoDiffusionPipeline default.
    frames = svd_model(
        image,
        num_inference_steps=t,
        decode_chunk_size=8,
        generator=inference_generator,
    ).frames[0]
    frames = torch.stack([transform(frame) for frame in frames])
    frames = frames.to(device).to(torch.float32)
    # Keep at most 16 frames and drop the 224-column margin on both sides of the width axis.
    frames = frames[:16, :, :, 224:-224]
    frames = F.interpolate(frames, size=256)
    return frames / 255.0
def stream_long_gen(
    prompt,
    short_video,
    n_autoreg_gen,
    negative_prompt,
    seed,
    t,
    image_guidance,
    result_file_stem,
    stream_cli,
    stream_model,
):
    """Run one autoregressive extension of ``short_video`` through ``stream_cli.trainer``.

    The function limits the trainer to a single predict batch and stores the
    generation settings in ``trainer.predict_cfg``. The predict directory comes
    from ``stream_cli.config["result_fol"]``. It then sets the model's negative
    prompt and calls ``trainer.predict`` with ``stream_cli.datamodule``.
    Nothing is returned. Presumably the prediction step writes its output under
    the predict directory using ``result_file_stem``; confirm in the model code.
    """
    runner = stream_cli.trainer
    runner.limit_predict_batches = 1

    output_dir = stream_cli.config["result_fol"].as_posix()
    runner.predict_cfg = dict(
        predict_dir=output_dir,
        result_file_stem=result_file_stem,
        prompt=prompt,
        video=short_video,
        seed=seed,
        num_inference_steps=t,
        guidance_scale=image_guidance,
        n_autoregressive_generations=n_autoreg_gen,
    )

    stream_model.inference_params.negative_prompt = negative_prompt
    runner.predict(model=stream_model, datamodule=stream_cli.datamodule)
def video2video(prompt, video, where_to_log, cfg_v2v, model_v2v, square=True):
    """Enhance ``video`` with ``model_v2v`` and return the path of the result.

    The input frames are read with ``imageio``. Each frame is downscaled by
    ``cfg_v2v["downscale"]`` and resized to fit ``cfg_v2v["upscale_size"]``.
    If ``cfg_v2v["pad"]`` is truthy, each frame is also padded to that size.
    The frames are written to a timestamped temporary MP4 (8 fps) in
    ``where_to_log``, which ``model_v2v`` reads. The model writes the enhanced
    video to ``<where_to_log>/<name>_enhanced.mp4``. That file is then re-read,
    280 columns are cropped from each side of every frame, and the file is
    overwritten.

    Args:
        prompt: Text prompt. Its first 100 characters (spaces -> "_") also form
            the output file name.
        video: Path of the input video.
        where_to_log: Directory for the temporary and enhanced videos.
        cfg_v2v: Mapping with keys "downscale", "upscale_size" and "pad".
        model_v2v: Video-to-video pipeline (presumably ModelScope) called as
            ``model_v2v(p_input, output_video=path)``.
        square: Currently unused; kept for interface compatibility.

    Returns:
        Path of the enhanced MP4.

    Raises:
        ValueError: If the input video contains no frames.
    """
    downscale = cfg_v2v["downscale"]
    upscale_size = cfg_v2v["upscale_size"]
    pad = cfg_v2v["pad"]

    timestamp = str(datetime.datetime.now().time()).replace(":", "_").replace(".", "_")
    name = prompt[:100].replace(" ", "_") + "_" + timestamp
    enhanced_video_mp4 = opj(where_to_log, name + "_enhanced.mp4")
    temp_video_mp4 = opj(where_to_log, "temp_" + timestamp + ".mp4")

    video_frames = imageio.mimread(video)
    if not video_frames:
        raise ValueError(f"Input video [{video}] contains no frames")
    h, w, _ = video_frames[0].shape

    # Downscale video, then resize to fit the upscale size
    video = [
        Image.fromarray(frame).resize((w // downscale, h // downscale))
        for frame in video_frames
    ]
    video = [resize_to_fit(frame, upscale_size) for frame in video]
    if pad:
        video = [pad_to_fit(frame, upscale_size) for frame in video]
    # Hand imageio numpy arrays rather than PIL images.
    video = [np.array(frame) for frame in video]
    imageio.mimsave(temp_video_mp4, video, fps=8)

    p_input = {
        "video_path": temp_video_mp4,
        "text": prompt,
        "positive_prompt": prompt,
        "total_noise_levels": 600,
    }
    # The model writes its result to ``enhanced_video_mp4``; the returned value is not used.
    model_v2v(p_input, output_video=enhanced_video_mp4)[OutputKeys.OUTPUT_VIDEO]

    # Remove padding: crop 280 columns from each side of every frame.
    # NOTE(review): this is hard-coded and is applied even when ``pad`` is false.
    # It presumably matches the padding pad_to_fit adds for the usual
    # upscale_size; confirm before changing cfg_v2v.
    video_frames = imageio.mimread(enhanced_video_mp4)
    video_frames_square = [frame[:, 280:-280, :] for frame in video_frames]
    imageio.mimsave(enhanced_video_mp4, video_frames_square)
    return enhanced_video_mp4
# The main functionality for video to video
def video2video_randomized(
    prompt,
    video,
    where_to_log,
    cfg_v2v,
    model_v2v,
    square=True,
    chunk_size=24,
    overlap_size=8,
    negative_prompt="",
):
    """Enhance ``video`` chunk-wise with ``model_v2v`` and return the output path.

    The video is processed in chunks of ``chunk_size`` frames that overlap by
    ``overlap_size`` frames. Frames beyond the longest prefix that splits evenly
    into such chunks are trimmed, with a printed notice. Both sizes are stored
    on ``model_v2v`` as ``chunk_size`` / ``overlap_size`` attributes.

    Frames are downscaled by ``cfg_v2v["downscale"]`` and resized to fit
    ``cfg_v2v["upscale_size"]``. When ``cfg_v2v["pad"]`` is truthy they are also
    padded to that size. They are then written to ``<where_to_log>/temp.mp4``
    at 8 fps, and the model writes its result to
    ``<where_to_log>/<name>_enhanced.mp4``.

    Args:
        prompt: Text prompt. Its first 100 characters (spaces -> "_") also form
            the output file name.
        video: Path of the input video.
        where_to_log: Directory for the temporary and enhanced videos.
        cfg_v2v: Mapping with keys "downscale", "upscale_size" and "pad".
        model_v2v: Video-to-video pipeline called as
            ``model_v2v(p_input, output_video=path)``.
        square: Currently unused; kept for interface compatibility.
        chunk_size: Frames per chunk.
        overlap_size: Frames shared by consecutive chunks; must be < chunk_size.
        negative_prompt: Negative prompt forwarded to the model.

    Returns:
        Path of the enhanced MP4.

    Raises:
        ValueError: If ``overlap_size >= chunk_size``, if the video has no
            frames, or if it is shorter than one chunk.
    """
    # Each chunk must advance by at least one frame; otherwise the chunk count
    # below divides by zero or by a negative stride.
    if overlap_size >= chunk_size:
        raise ValueError(
            f"Overlap size [{overlap_size}] must be smaller than chunk size [{chunk_size}]"
        )

    downscale = cfg_v2v["downscale"]
    upscale_size = cfg_v2v["upscale_size"]
    pad = cfg_v2v["pad"]
    now = datetime.datetime.now()
    name = (
        prompt[:100].replace(" ", "_")
        + "_"
        + str(now.time()).replace(":", "_").replace(".", "_")
    )
    enhanced_video_mp4 = opj(where_to_log, name + "_enhanced.mp4")

    video_frames = imageio.mimread(video)
    if not video_frames:
        raise ValueError(f"Input video [{video}] contains no frames")
    h, w, _ = video_frames[0].shape

    # Longest prefix made of n_chunks chunks that each add ``stride`` new frames.
    stride = chunk_size - overlap_size
    n_chunks = (len(video_frames) - overlap_size) // stride
    trim_length = n_chunks * stride + overlap_size
    if trim_length < chunk_size:
        raise ValueError(
            f"Chunk size [{chunk_size}] cannot be larger than the number of frames in the video [{len(video_frames)}], please provide smaller chunk size"
        )
    if trim_length < len(video_frames):
        print(
            f"Video cannot be processed with chunk size {chunk_size} and overlap size {overlap_size}, "
            f"trimming it to length {trim_length} to be able to process it"
        )
        video_frames = video_frames[:trim_length]

    model_v2v.chunk_size = chunk_size
    model_v2v.overlap_size = overlap_size

    # Downscale video, then resize to fit the upscale size
    video = [
        Image.fromarray(frame).resize((w // downscale, h // downscale))
        for frame in video_frames
    ]
    video = [resize_to_fit(frame, upscale_size) for frame in video]
    if pad:
        video = [pad_to_fit(frame, upscale_size) for frame in video]
    video = list(map(np.array, video))
    imageio.mimsave(opj(where_to_log, "temp.mp4"), video, fps=8)

    p_input = {
        "video_path": opj(where_to_log, "temp.mp4"),
        "text": prompt,
        "positive_prompt": "",
        "negative_prompt": negative_prompt,
        "total_noise_levels": 600,
    }
    # The model writes its result to ``enhanced_video_mp4``; the returned value is not used.
    model_v2v(p_input, output_video=enhanced_video_mp4)[OutputKeys.OUTPUT_VIDEO]
    return enhanced_video_mp4