# General
import os
from os.path import join as opj
import datetime
import torch
from einops import rearrange, repeat

# Utilities (the helpers add_margin, resize_and_keep, center_crop,
# resize_to_fit, and pad_to_fit used below come from this wildcard import)
from videogen_hub.pipelines.streamingt2v.inference_utils import *
from modelscope.outputs import OutputKeys
import imageio
from PIL import Image
import numpy as np
import torch.nn.functional as F
import torchvision.transforms as transforms
from diffusers.utils import load_image

# PILToTensor keeps uint8 values in [0, 255] and returns a (C, H, W) tensor.
transform = transforms.Compose([transforms.PILToTensor()])

# Generate a short 256x256 clip with the ModelScope text-to-video model.
def ms_short_gen(prompt, ms_model, inference_generator, t=50, device="cuda"):
    frames = ms_model(
        prompt,
        num_inference_steps=t,
        generator=inference_generator,
        eta=1.0,
        height=256,
        width=256,
        latents=None,
    ).frames
    frames = torch.stack([torch.from_numpy(frame) for frame in frames])
    frames = frames.to(device).to(torch.float32)
    # Frames arrive channels-last; move channels to the second axis.
    return rearrange(frames[0], "F H W C -> F C H W")
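
# A minimal usage sketch (illustrative; assumes a ModelScope text-to-video
# pipeline is already loaded as `ms_model`):
#
#   generator = torch.Generator(device="cuda").manual_seed(33)
#   short_video = ms_short_gen("a cat playing piano", ms_model, generator)
#   # short_video: float32 tensor of shape (num_frames, 3, 256, 256)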

# Generate a short 16-frame clip with AnimateDiff, resized to 256x256.
def ad_short_gen(prompt, ad_model, inference_generator, t=25, device="cuda"):
    frames = ad_model(
        prompt,
        negative_prompt="bad quality, worse quality",
        num_frames=16,
        num_inference_steps=t,
        generator=inference_generator,
        guidance_scale=7.5,
    ).frames[0]
    frames = torch.stack([transform(frame) for frame in frames])
    frames = frames.to(device).to(torch.float32)
    frames = F.interpolate(frames, size=256)
    frames = frames / 255.0  # PILToTensor yields uint8 values; scale to [0, 1]
    return frames
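
# A minimal usage sketch (illustrative; assumes an AnimateDiff pipeline is
# already loaded as `ad_model`):
#
#   generator = torch.Generator(device="cuda").manual_seed(33)
#   short_video = ad_short_gen("a cat playing piano", ad_model, generator)
#   # short_video: float32 tensor of shape (16, 3, 256, 256), values in [0, 1]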

# Generate a single conditioning image from a prompt with SDXL.
def sdxl_image_gen(prompt, sdxl_model):
    return sdxl_model(prompt=prompt).images[0]

# Generate a short clip with Stable Video Diffusion (SVD), conditioned on an
# image. If no image is given, one is synthesized from the prompt with SDXL.
def svd_short_gen(
    image, prompt, svd_model, sdxl_model, inference_generator, t=25, device="cuda"
):
    if image is None or image == "":
        # No conditioning image: synthesize one, then pad it to SVD's native
        # 1024x576 input (224 px of black on the left and right).
        image = sdxl_image_gen(prompt, sdxl_model)
        image = image.resize((576, 576))
        image = add_margin(image, 0, 224, 0, 224, (0, 0, 0))
    elif isinstance(image, str):
        # Conditioning image given as a path or URL.
        image = load_image(image)
        image = resize_and_keep(image)
        image = center_crop(image)
        image = add_margin(image, 0, 224, 0, 224, (0, 0, 0))
    else:
        # Conditioning image given as a pixel array.
        image = Image.fromarray(np.uint8(image))
        image = resize_and_keep(image)
        image = center_crop(image)
        image = add_margin(image, 0, 224, 0, 224, (0, 0, 0))

    frames = svd_model(
        image, decode_chunk_size=8, generator=inference_generator
    ).frames[0]
    frames = torch.stack([transform(frame) for frame in frames])
    frames = frames.to(device).to(torch.float32)
    # Keep 16 frames, crop away the 224 px side margins, resize to 256x256.
    frames = frames[:16, :, :, 224:-224]
    frames = F.interpolate(frames, size=256)
    frames = frames / 255.0
    return frames
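
# A minimal usage sketch (illustrative; assumes SVD and SDXL pipelines are
# loaded as `svd_model` and `sdxl_model`; pass image=None to synthesize the
# conditioning frame from the prompt):
#
#   generator = torch.Generator(device="cuda").manual_seed(33)
#   short_video = svd_short_gen(None, "a red panda eating bamboo", svd_model,
#                               sdxl_model, generator)
#   # short_video: float32 tensor of shape (16, 3, 256, 256), values in [0, 1]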

# Run StreamingT2V's autoregressive long-video generation, conditioned on a
# short seed video. The predict_cfg dict below is consumed by the Lightning
# predict loop.
def stream_long_gen(
    prompt,
    short_video,
    n_autoreg_gen,
    negative_prompt,
    seed,
    t,
    image_guidance,
    result_file_stem,
    stream_cli,
    stream_model,
):
    trainer = stream_cli.trainer
    trainer.limit_predict_batches = 1
    trainer.predict_cfg = {
        "predict_dir": stream_cli.config["result_fol"].as_posix(),
        "result_file_stem": result_file_stem,
        "prompt": prompt,
        "video": short_video,
        "seed": seed,
        "num_inference_steps": t,
        "guidance_scale": image_guidance,
        "n_autoregressive_generations": n_autoreg_gen,
    }
    stream_model.inference_params.negative_prompt = negative_prompt
    trainer.predict(model=stream_model, datamodule=stream_cli.datamodule)
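
# A minimal usage sketch (illustrative; `stream_cli` and `stream_model` are the
# StreamingT2V LightningCLI and model objects, and `short_video` is a 16-frame
# tensor from one of the short-video generators above):
#
#   stream_long_gen(
#       prompt="a cat playing piano",
#       short_video=short_video,
#       n_autoreg_gen=4,
#       negative_prompt="",
#       seed=33,
#       t=50,
#       image_guidance=9.0,
#       result_file_stem="cat_piano",
#       stream_cli=stream_cli,
#       stream_model=stream_model,
#   )
#   # The predict step is expected to write its output under
#   # stream_cli.config["result_fol"], named after result_file_stem.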

# Enhance (upscale) a video with the ModelScope video-to-video model.
def video2video(prompt, video, where_to_log, cfg_v2v, model_v2v, square=True):
    downscale = cfg_v2v["downscale"]
    upscale_size = cfg_v2v["upscale_size"]
    pad = cfg_v2v["pad"]

    now = datetime.datetime.now()
    now = str(now.time()).replace(":", "_").replace(".", "_")
    name = prompt[:100].replace(" ", "_") + "_" + now
    enhanced_video_mp4 = opj(where_to_log, name + "_enhanced.mp4")

    video_frames = imageio.mimread(video)
    h, w, _ = video_frames[0].shape

    # Downscale video, then resize to fit the upscale size
    video = [
        Image.fromarray(frame).resize((w // downscale, h // downscale))
        for frame in video_frames
    ]
    video = [resize_to_fit(frame, upscale_size) for frame in video]
    if pad:
        video = [pad_to_fit(frame, upscale_size) for frame in video]
    # video = [np.array(frame) for frame in video]
    imageio.mimsave(opj(where_to_log, "temp_" + now + ".mp4"), video, fps=8)

    p_input = {
        "video_path": opj(where_to_log, "temp_" + now + ".mp4"),
        "text": prompt,
        "positive_prompt": prompt,
        "total_noise_levels": 600,
    }
    # The model writes the enhanced video directly to enhanced_video_mp4.
    model_v2v(p_input, output_video=enhanced_video_mp4)[OutputKeys.OUTPUT_VIDEO]

    # Remove the side padding again, leaving a square video.
    if square:
        video_frames = imageio.mimread(enhanced_video_mp4)
        video_frames_square = [frame[:, 280:-280, :] for frame in video_frames]
        imageio.mimsave(enhanced_video_mp4, video_frames_square)
    return enhanced_video_mp4
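
# A minimal usage sketch (illustrative; assumes a ModelScope video enhancement
# pipeline is loaded as `model_v2v`, and a config such as
# cfg_v2v = {"downscale": 1, "upscale_size": (1280, 720), "pad": True}):
#
#   enhanced_path = video2video("a cat playing piano", "out/cat_piano.mp4",
#                               "out", cfg_v2v, model_v2v)
#   # enhanced_path -> "out/a_cat_playing_piano_<timestamp>_enhanced.mp4"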

# The main entry point for video-to-video enhancement: the video is processed
# in overlapping chunks of frames.
def video2video_randomized(
    prompt,
    video,
    where_to_log,
    cfg_v2v,
    model_v2v,
    square=True,
    chunk_size=24,
    overlap_size=8,
    negative_prompt="",
):
    downscale = cfg_v2v["downscale"]
    upscale_size = cfg_v2v["upscale_size"]
    pad = cfg_v2v["pad"]

    now = str(datetime.datetime.now().time()).replace(":", "_").replace(".", "_")
    name = prompt[:100].replace(" ", "_") + "_" + now
    enhanced_video_mp4 = opj(where_to_log, name + "_enhanced.mp4")

    video_frames = imageio.mimread(video)
    h, w, _ = video_frames[0].shape
    # Each chunk advances by (chunk_size - overlap_size) frames; trim the video
    # so it divides evenly into overlapping chunks.
    n_chunks = (len(video_frames) - overlap_size) // (chunk_size - overlap_size)
    trim_length = n_chunks * (chunk_size - overlap_size) + overlap_size
    if trim_length < chunk_size:
        raise ValueError(
            f"Chunk size [{chunk_size}] cannot be larger than the number of frames "
            f"in the video [{len(video_frames)}]; please provide a smaller chunk size."
        )
    if trim_length < len(video_frames):
        print(
            f"Video cannot be processed evenly with chunk size {chunk_size} and "
            f"overlap size {overlap_size}; trimming it to length {trim_length}."
        )
        video_frames = video_frames[:trim_length]
    model_v2v.chunk_size = chunk_size
    model_v2v.overlap_size = overlap_size
    # Downscale video, then resize to fit the upscale size
    video = [
        Image.fromarray(frame).resize((w // downscale, h // downscale))
        for frame in video_frames
    ]
    video = [resize_to_fit(frame, upscale_size) for frame in video]
    if pad:
        video = [pad_to_fit(frame, upscale_size) for frame in video]
    video = list(map(np.array, video))
    # Write the preprocessed frames to a timestamped temp file so concurrent
    # runs do not clobber each other.
    imageio.mimsave(opj(where_to_log, "temp_" + now + ".mp4"), video, fps=8)

    p_input = {
        "video_path": opj(where_to_log, "temp_" + now + ".mp4"),
        "text": prompt,
        "positive_prompt": "",
        "negative_prompt": negative_prompt,
        "total_noise_levels": 600,
    }
    # The model writes the enhanced video directly to enhanced_video_mp4.
    model_v2v(p_input, output_video=enhanced_video_mp4)[OutputKeys.OUTPUT_VIDEO]
    return enhanced_video_mp4
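
# Worked example of the chunking arithmetic for the defaults (chunk_size=24,
# overlap_size=8): consecutive chunks advance by 24 - 8 = 16 frames, so an
# 80-frame video gives n_chunks = (80 - 8) // 16 = 4 and
# trim_length = 4 * 16 + 8 = 72, i.e. the last 8 frames are trimmed before
# enhancement.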