|
import torch |
|
from PIL import Image |
|
import numpy as np |
|
from model import ImageToVideoModel |
|
from de_en.tokenizer import VideoTokenizer |
|
import argparse |
|
import os |
|
|
|
class ImageToVideoGenerator:
    """Generate a short video clip from a single still image.

    Wraps a pretrained ``ImageToVideoModel`` plus a ``VideoTokenizer`` for
    image encoding and video serialization, and picks CUDA when available.
    """

    def __init__(self, model_path, config):
        """Build the model, load its weights, and prepare the tokenizer.

        Args:
            model_path: Path to a saved ``state_dict`` checkpoint (.pth).
            config: Dict with ``'encoder'``, ``'decoder'``, ``'transformer'``
                sub-configs and an integer ``'resolution'`` (see the CLI entry
                point in this file for the expected shape).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = ImageToVideoModel(
            encoder_config=config['encoder'],
            decoder_config=config['decoder'],
            transformer_config=config['transformer']
        ).to(self.device)

        # NOTE(review): torch.load is pickle-based; if checkpoints can come
        # from untrusted sources, pass weights_only=True (torch >= 1.13).
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        # Inference-only usage: freeze dropout/batch-norm behavior.
        self.model.eval()

        self.tokenizer = VideoTokenizer(config['resolution'])
        self.config = config

    def generate_video(self, image_path, output_path, num_frames=24, speed_level=5):
        """Generate a video from one image and write it to disk.

        Args:
            image_path: Path of the conditioning input image.
            output_path: Destination path for the rendered video.
            num_frames: Number of frames to generate.
            speed_level: Generation speed level forwarded to the model
                (CLI documents the range as 0-9).

        Returns:
            The ``output_path`` that was written.
        """
        image = Image.open(image_path).convert('RGB')
        # Add a batch dimension: (C, H, W) -> (1, C, H, W) -- assumes
        # encode_image returns a single-image tensor; TODO confirm.
        image_tensor = self.tokenizer.encode_image(image).unsqueeze(0).to(self.device)

        # Pure inference: disable autograd so no computation graph is kept.
        # (The original omitted this, wasting memory and time.)
        with torch.no_grad():
            frames = self.model.generate(image_tensor, num_frames, speed_level)

        # Drop the batch dimension and move to CPU before serialization.
        self.tokenizer.save_video(frames.squeeze(0).cpu(), output_path)

        return output_path
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate a video from a single input image."
    )
    parser.add_argument("--image", type=str, required=True, help="Input image path")
    parser.add_argument("--output", type=str, required=True, help="Output video path")
    parser.add_argument("--frames", type=int, default=24, help="Number of frames to generate")
    # The help text promised 0-9 but nothing enforced it; `choices` makes
    # argparse reject out-of-range values with a clear error message.
    parser.add_argument("--speed", type=int, default=5, choices=range(10),
                        metavar="[0-9]", help="Generation speed level (0-9)")
    parser.add_argument("--model", type=str, default="image_to_video_model.pth", help="Model path")
    parser.add_argument("--resolution", type=int, default=128, help="Video resolution")
    args = parser.parse_args()

    # Fail fast with a readable CLI error instead of a PIL traceback after
    # the (potentially slow) model load.
    if not os.path.isfile(args.image):
        parser.error(f"input image not found: {args.image}")

    # Model hyperparameters; encoder/decoder dims mirror each other and the
    # transformer operates on the shared 512-d embedding.
    config = {
        'resolution': args.resolution,
        'encoder': {
            'in_channels': 3,
            'hidden_dims': [64, 128, 256, 512],
            'embed_dim': 512
        },
        'decoder': {
            'embed_dim': 512,
            'hidden_dims': [512, 256, 128, 64],
            'out_channels': 3
        },
        'transformer': {
            'd_model': 512,
            'nhead': 8,
            'num_encoder_layers': 3,
            'num_decoder_layers': 3,
            'dim_feedforward': 2048,
            'dropout': 0.1
        }
    }

    generator = ImageToVideoGenerator(args.model, config)
    generator.generate_video(
        args.image,
        args.output,
        num_frames=args.frames,
        speed_level=args.speed
    )

    print(f"Video generated and saved to {args.output}")