Baraaqasem's picture
Upload 585 files
5d32408 verified
from typing import List, Optional, Union

from torch import _validate_compressed_sparse_indices
from torchvision.utils import save_image

from videogen_hub import MODEL_PATH
from with_mask_sample import *
class SEINEPipeline:
    """Image-to-video generation with the SEINE model.

    The constructor only loads the OmegaConf configuration and records the
    model paths. All networks (SEINE model, VAE, text encoder) are loaded from
    disk inside :meth:`infer_one_video`, i.e. again on every call.
    """

    # Where a tensor ``input_image`` is written so the file-based loader
    # (``get_input``) can read it. Relative to the current working directory.
    _TMP_INPUT_IMAGE_PATH = "src/videogen_hub/pipelines/seine/input_image.png"

    def __init__(self, seine_path: str = os.path.join(MODEL_PATH, "SEINE", "seine.pt"),
                 pretrained_model_path: str = os.path.join(MODEL_PATH, "SEINE", "stable-diffusion-v1-4"),
                 config_path: str = "src/videogen_hub/pipelines/seine/sample_i2v.yaml"):
        """
        Load the configuration file and set the paths of models.

        Args:
            seine_path: The path of the downloaded SEINE pretrained checkpoint.
            pretrained_model_path: The path of the downloaded Stable Diffusion
                pretrained model (its ``vae`` subfolder and text encoder are used).
            config_path: The path of the OmegaConf YAML configuration file
                (the default is relative to the current working directory).
        """
        self.config = OmegaConf.load(config_path)
        self.config.ckpt = seine_path
        self.config.pretrained_model_path = pretrained_model_path

    def infer_one_video(self, input_image,
                        text_prompt: Optional[Union[str, List[str]]] = None,
                        output_size: Optional[List[int]] = None,
                        num_frames: int = 16,
                        num_sampling_steps: int = 250,
                        seed: int = 42,
                        save_video: bool = False):
        """
        Generate a video conditioned on ``input_image`` and ``text_prompt``.

        Args:
            input_image: Either a path to an image file, or a tensor of shape
                (3, H, W). A tensor is written to ``_TMP_INPUT_IMAGE_PATH`` with
                ``save_image`` and that file is used as the conditioning image.
            text_prompt: The text prompt, as a string or a list whose first
                element is used. When empty/None, the prompt is derived from the
                input file name (extension dropped, ``_`` replaced by spaces).
            output_size: ``[height, width]`` of the generated video.
                Defaults to ``[240, 560]``.
            num_frames: Number of frames of the generated video. Defaults to 16.
            num_sampling_steps: Number of diffusion sampling steps. Defaults to 250.
            seed: Random seed passed to ``torch.manual_seed`` (0 is honoured;
                ``None`` skips seeding). Defaults to 42.
            save_video: If True, also write an 8 fps ``.mp4`` to
                ``config.save_path``. Defaults to False.

        Returns:
            The generated video as a ``torch.uint8`` tensor of shape
            (num_frames, channels, height, width), on the CPU.

        Raises:
            TypeError: If ``input_image`` is neither a path nor a tensor.
            ValueError: If a tensor ``input_image`` is not shaped (3, H, W), if
                no checkpoint path is configured, or if xformers is requested
                but unavailable.
        """
        if text_prompt is None:
            text_prompt = []
        if output_size is None:
            output_size = [240, 560]

        self.config.image_size = output_size
        self.config.num_frames = num_frames
        self.config.num_sampling_steps = num_sampling_steps
        self.config.seed = seed
        self.config.text_prompt = text_prompt
        # Always (re)set input_path so a tensor input is not silently replaced
        # by a stale path from the yaml or from a previous call.
        self.config.input_path = self._resolve_input_path(input_image)

        args = self.config
        if args.seed is not None:
            torch.manual_seed(args.seed)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if args.ckpt is None:
            raise ValueError("Please specify a checkpoint path via `seine_path`")

        # Latent resolution is the pixel resolution downsampled by 8.
        args.image_h = args.image_size[0]
        args.image_w = args.image_size[1]
        args.latent_h = args.image_h // 8
        args.latent_w = args.image_w // 8

        # Scope gradient tracking to this call instead of disabling it
        # globally for the whole process.
        with torch.no_grad():
            model, diffusion, vae, text_encoder = self._load_models(args, device)

            prompt = self._resolve_prompt(args.text_prompt, args.input_path)
            # File-name-safe variant of the prompt (used only for saving).
            prompt_base = prompt.replace(' ', '_').replace('/', '_')
            prompt = prompt + args.additional_prompt

            video_input, _ = get_input(args)  # f,c,h,w
            video_input = video_input.to(device).unsqueeze(0)  # b,f,c,h,w
            mask = mask_generation_before(args.mask_type, video_input.shape,
                                          video_input.dtype, device)  # b,f,c,h,w
            masked_video = video_input * (mask == 0)
            video_clip = auto_inpainting(args, video_input, masked_video, mask, prompt,
                                         vae, text_encoder, diffusion, model, device)

            # Map [-1, 1] -> [0, 255] with rounding, then f,c,h,w -> f,h,w,c.
            video_ = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255)
            video_ = video_.to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)

        if save_video:
            os.makedirs(args.save_path, exist_ok=True)
            save_video_path = os.path.join(args.save_path, prompt_base + '.mp4')
            torchvision.io.write_video(save_video_path, video_, fps=8)
            print(f'save in {save_video_path}')
        return video_.permute(0, 3, 1, 2)

    def _resolve_input_path(self, input_image) -> str:
        """Return a file path for the conditioning image, saving tensors to disk."""
        if isinstance(input_image, (str, os.PathLike)):
            return os.fspath(input_image)
        if not torch.is_tensor(input_image):
            raise TypeError(
                f"input_image must be a path or a torch.Tensor, got {type(input_image).__name__}")
        if input_image.dim() != 3 or input_image.shape[0] != 3:
            raise ValueError(
                f"input_image tensor must have shape (3, H, W), got {tuple(input_image.shape)}")
        save_image(input_image, self._TMP_INPUT_IMAGE_PATH)
        return self._TMP_INPUT_IMAGE_PATH

    @staticmethod
    def _resolve_prompt(text_prompt, input_path: str) -> str:
        """Pick the prompt string, falling back to the input file's base name."""
        if text_prompt is None:
            prompt = ""
        elif isinstance(text_prompt, str):
            prompt = text_prompt
        elif len(text_prompt) > 0:
            prompt = text_prompt[0]
        else:
            prompt = ""
        if not prompt:
            prompt = os.path.basename(input_path).split('.')[0].replace('_', ' ')
        return prompt

    def _load_models(self, args, device):
        """Build and load the SEINE model, diffusion, VAE and text encoder."""
        print('loading model')
        model = get_models(args).to(device)
        if args.enable_xformers_memory_efficient_attention:
            if not is_xformers_available():
                raise ValueError("xformers is not available. Make sure it is installed correctly")
            model.enable_xformers_memory_efficient_attention()
        # torch.load unpickles the file: only load trusted checkpoints.
        # The map_location lambda keeps tensors on the storage they were
        # deserialized to (CPU) before the model is moved.
        state_dict = torch.load(args.ckpt, map_location=lambda storage, loc: storage)['ema']
        model.load_state_dict(state_dict)
        print('loading succeed')
        model.eval()

        diffusion = create_diffusion(str(args.num_sampling_steps))
        vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device)
        text_encoder = TextEmbedder(args.pretrained_model_path).to(device)
        if args.use_fp16:
            print('Warnning: using half percision for inferencing!')
            vae.to(dtype=torch.float16)
            model.to(dtype=torch.float16)
            text_encoder.to(dtype=torch.float16)
        return model, diffusion, vae, text_encoder