import argparse

import torch
from PIL import Image

from model import ImageToVideoModel
from de_en.tokenizer import VideoTokenizer


class ImageToVideoGenerator:
    """Loads a trained ImageToVideoModel and turns a single image into a short video."""

    def __init__(self, model_path, config):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Build the model from the config and load the trained weights.
        self.model = ImageToVideoModel(
            encoder_config=config['encoder'],
            decoder_config=config['decoder'],
            transformer_config=config['transformer']
        ).to(self.device)
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.eval()

        # The tokenizer handles image preprocessing and video encoding at the target resolution.
        self.tokenizer = VideoTokenizer(config['resolution'])
        self.config = config
    def generate_video(self, image_path, output_path, num_frames=24, speed_level=5):
        # Load the input image and encode it into a batched tensor on the model device.
        image = Image.open(image_path).convert('RGB')
        image_tensor = self.tokenizer.encode_image(image).unsqueeze(0).to(self.device)

        # Generate video frames; no_grad avoids tracking gradients during inference.
        with torch.no_grad():
            frames = self.model.generate(image_tensor, num_frames, speed_level)

        # Convert the frame tensor to a video file and save it.
        self.tokenizer.save_video(frames.squeeze(0).cpu(), output_path)
        return output_path
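

# Minimal programmatic usage sketch (assumes a trained checkpoint on disk and a
# config dict shaped like the one built in the __main__ block below; the file
# paths here are placeholders):
#
#   generator = ImageToVideoGenerator("image_to_video_model.pth", config)
#   generator.generate_video("photo.jpg", "clip.mp4", num_frames=24, speed_level=5)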


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", type=str, required=True, help="Input image path")
    parser.add_argument("--output", type=str, required=True, help="Output video path")
    parser.add_argument("--frames", type=int, default=24, help="Number of frames to generate")
    parser.add_argument("--speed", type=int, default=5, help="Generation speed level (0-9)")
    parser.add_argument("--model", type=str, default="image_to_video_model.pth", help="Model path")
    parser.add_argument("--resolution", type=int, default=128, help="Video resolution")
    args = parser.parse_args()
    # Model configuration; it must match the checkpoint's architecture for
    # load_state_dict to succeed.
    config = {
        'resolution': args.resolution,
        'encoder': {
            'in_channels': 3,
            'hidden_dims': [64, 128, 256, 512],
            'embed_dim': 512
        },
        'decoder': {
            'embed_dim': 512,
            'hidden_dims': [512, 256, 128, 64],
            'out_channels': 3
        },
        'transformer': {
            'd_model': 512,
            'nhead': 8,
            'num_encoder_layers': 3,
            'num_decoder_layers': 3,
            'dim_feedforward': 2048,
            'dropout': 0.1
        }
    }
    generator = ImageToVideoGenerator(args.model, config)
    generator.generate_video(
        args.image,
        args.output,
        num_frames=args.frames,
        speed_level=args.speed
    )
    print(f"Video generated and saved to {args.output}")