Files changed (1)
  1. inference.py +169 -0
inference.py ADDED
@@ -0,0 +1,169 @@
import os
import argparse

import torch
import numpy as np
from omegaconf import OmegaConf
from torchvision.transforms import v2
from diffusers.utils import load_image
from einops import rearrange
from safetensors.torch import load_file

from pipeline import CausalInferencePipeline
from wan.vae.wanx_vae import get_wanx_vae_wrapper
from demo_utils.vae_block3 import VAEDecoderWrapper
from utils.visualize import process_video
from utils.misc import set_seed
from utils.conditions import *
from utils.wan_wrapper import WanDiffusionWrapper

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, default="configs/inference_yaml/inference_universal.yaml", help="Path to the config file")
    parser.add_argument("--checkpoint_path", type=str, default="", help="Path to the checkpoint")
    parser.add_argument("--img_path", type=str, default="demo_images/universal/0000.png", help="Path to the image")
    parser.add_argument("--output_folder", type=str, default="outputs/", help="Output folder")
    parser.add_argument("--num_output_frames", type=int, default=150, help="Number of output latent frames")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--pretrained_model_path", type=str, default="Matrix-Game-2.0", help="Path to the VAE model folder")
    args = parser.parse_args()
    return args
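
# Note: --num_output_frames counts *latent* frames; the decoded RGB video has
# (num_output_frames - 1) * 4 + 1 frames (cf. num_frames in generate_videos).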

class InteractiveGameInference:
    def __init__(self, args):
        self.args = args
        self.device = torch.device("cuda")
        self.weight_dtype = torch.bfloat16

        self._init_config()
        self._init_models()

        # Resize to the model's 352x640 resolution and normalize to [-1, 1].
        self.frame_process = v2.Compose([
            v2.Resize(size=(352, 640), antialias=True),
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def _init_config(self):
        self.config = OmegaConf.load(self.args.config_path)

    def _init_models(self):
        # Causal diffusion generator.
        generator = WanDiffusionWrapper(
            **getattr(self.config, "model_kwargs", {}), is_causal=True)

        # Streaming VAE decoder: load only the decoder weights (plus conv2)
        # from the full Wan2.1 VAE checkpoint.
        current_vae_decoder = VAEDecoderWrapper()
        vae_state_dict = torch.load(os.path.join(self.args.pretrained_model_path, "Wan2.1_VAE.pth"), map_location="cpu")
        decoder_state_dict = {}
        for key, value in vae_state_dict.items():
            if 'decoder.' in key or 'conv2' in key:
                decoder_state_dict[key] = value
        current_vae_decoder.load_state_dict(decoder_state_dict)
        current_vae_decoder.to(self.device, torch.float16)
        current_vae_decoder.requires_grad_(False)
        current_vae_decoder.eval()
        current_vae_decoder.compile(mode="max-autotune-no-cudagraphs")

        pipeline = CausalInferencePipeline(self.config, generator=generator, vae_decoder=current_vae_decoder)
        if self.args.checkpoint_path:
            print("Loading Pretrained Model...")
            state_dict = load_file(self.args.checkpoint_path)
            pipeline.generator.load_state_dict(state_dict)

        self.pipeline = pipeline.to(device=self.device, dtype=self.weight_dtype)
        self.pipeline.vae_decoder.to(torch.float16)  # decoder stays in fp16

        # Full VAE wrapper (encoder plus CLIP) used to encode the conditioning image.
        vae = get_wanx_vae_wrapper(self.args.pretrained_model_path, torch.float16)
        vae.requires_grad_(False)
        vae.eval()
        self.vae = vae.to(self.device, self.weight_dtype)
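
    # _resizecrop only center-crops to the target aspect ratio (th:tw); the
    # actual resize to (352, 640) happens later in self.frame_process.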
    def _resizecrop(self, image, th, tw):
        w, h = image.size
        if h / w > th / tw:
            # Image is taller than the target aspect ratio: keep full width.
            new_w = int(w)
            new_h = int(new_w * th / tw)
        else:
            # Image is wider: keep full height.
            new_h = int(h)
            new_w = int(new_h * tw / th)
        left = (w - new_w) / 2
        top = (h - new_h) / 2
        right = (w + new_w) / 2
        bottom = (h + new_h) / 2
        image = image.crop((left, top, right, bottom))
        return image

    def generate_videos(self):
        mode = self.config.pop('mode')
        assert mode in ['universal', 'gta_drive', 'templerun']

        image = load_image(self.args.img_path)
        image = self._resizecrop(image, 352, 640)
        image = self.frame_process(image)[None, :, None, :, :].to(dtype=self.weight_dtype, device=self.device)

        # Encode the input image as the first latent; the remaining frames are
        # zero-padded before VAE encoding.
        padding_video = torch.zeros_like(image).repeat(1, 1, 4 * (self.args.num_output_frames - 1), 1, 1)
        img_cond = torch.concat([image, padding_video], dim=2)
        tiler_kwargs = {"tiled": True, "tile_size": [44, 80], "tile_stride": [23, 38]}
        img_cond = self.vae.encode(img_cond, device=self.device, **tiler_kwargs).to(self.device)

        # The mask marks the first latent frame as given, the rest as generated.
        mask_cond = torch.ones_like(img_cond)
        mask_cond[:, :, 1:] = 0
        cond_concat = torch.cat([mask_cond[:, :4], img_cond], dim=1)
        visual_context = self.vae.clip.encode_video(image)

        sampled_noise = torch.randn(
            [1, 16, self.args.num_output_frames, 44, 80], device=self.device, dtype=self.weight_dtype
        )
        # Number of RGB frames corresponding to the latent sequence.
        num_frames = (self.args.num_output_frames - 1) * 4 + 1

        conditional_dict = {
            "cond_concat": cond_concat.to(device=self.device, dtype=self.weight_dtype),
            "visual_context": visual_context.to(device=self.device, dtype=self.weight_dtype)
        }
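
        # Per-mode action conditioning: the Bench_actions_* helpers from
        # utils.conditions supply scripted benchmark action sequences of
        # length num_frames.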
        if mode == 'universal':
            cond_data = Bench_actions_universal(num_frames)
        elif mode == 'gta_drive':
            cond_data = Bench_actions_gta_drive(num_frames)
        else:
            cond_data = Bench_actions_templerun(num_frames)
        # Every mode provides keyboard actions (the overlay rendering below
        # reads keyboard_condition unconditionally); universal and gta_drive
        # additionally provide mouse actions.
        keyboard_condition = cond_data['keyboard_condition'].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype)
        conditional_dict['keyboard_cond'] = keyboard_condition
        if mode != 'templerun':
            mouse_condition = cond_data['mouse_condition'].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype)
            conditional_dict['mouse_cond'] = mouse_condition

        with torch.no_grad():
            videos = self.pipeline.inference(
                noise=sampled_noise,
                conditional_dict=conditional_dict,
                return_latents=False,
                mode=mode,
                profile=False
            )

        # Concatenate the streamed chunks and convert from [-1, 1] CHW tensors
        # to contiguous uint8 HWC frames.
        videos_tensor = torch.cat(videos, dim=1)
        videos = rearrange(videos_tensor, "B T C H W -> B T H W C")
        videos = ((videos.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)[0]
        video = np.ascontiguousarray(videos)

        # Render the action overlays (keyboard for all modes, plus mouse for
        # universal/gta_drive) and save the demo videos.
        mouse_icon = 'assets/images/mouse.png'
        if mode != 'templerun':
            config = (
                keyboard_condition[0].float().cpu().numpy(),
                mouse_condition[0].float().cpu().numpy()
            )
        else:
            config = (
                keyboard_condition[0].float().cpu().numpy()
            )
        process_video(video, os.path.join(self.args.output_folder, 'demo.mp4'), config, mouse_icon, mouse_scale=0.1, process_icon=False, mode=mode)
        process_video(video, os.path.join(self.args.output_folder, 'demo_icon.mp4'), config, mouse_icon, mouse_scale=0.1, process_icon=True, mode=mode)
        print("Done")

def main():
    """Main entry point for video generation."""
    args = parse_args()
    set_seed(args.seed)
    os.makedirs(args.output_folder, exist_ok=True)
    pipeline = InteractiveGameInference(args)
    pipeline.generate_videos()

if __name__ == "__main__":
    main()
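
For reference, an example invocation using the script's own defaults; the checkpoint path is a placeholder for wherever the distilled generator weights live locally:

    python inference.py \
        --config_path configs/inference_yaml/inference_universal.yaml \
        --pretrained_model_path Matrix-Game-2.0 \
        --checkpoint_path path/to/generator.safetensors \
        --img_path demo_images/universal/0000.png \
        --output_folder outputs/ \
        --seed 0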