# (c) Meta Platforms, Inc. and affiliates.
import glob
import os
import subprocess

import imageio
import torch
import torchvision

from SAFMNPP import SAFMNPP


def main(input_path, output_path, video_name, model, *,
         frames_root='/frams', num_chunks=100, device=None):
    """Super-resolve one video with ``model`` and re-encode it with FFmpeg.

    Script for testing video super resolution models; this demo uses
    SAFMNPP. Please replace the model prediction section with your own model.

    The input video is decoded in full and passed through ``model`` in at
    most ``num_chunks`` temporal chunks. Every output frame is written as a
    PNG under ``frames_root/<video stem>/``. The PNGs are then encoded into
    ``output_path/video_name`` with libx264. They are deleted only if FFmpeg
    succeeds.

    Args:
        input_path: Directory containing ``video_name``.
        output_path: Directory the upscaled video is written to.
        video_name: File name of the video inside ``input_path``.
        model: Callable receiving a float32 tensor of shape (1, T, C, H, W)
            with values in [0, 1]. Its output is assumed to be
            (1, T, C, H', W') -- TODO confirm for models other than SAFMNPP.
        frames_root: Directory under which per-video PNG frames are staged.
        num_chunks: Number of pieces the clip is split into along time
            (``Tensor.chunk`` semantics: at most this many pieces).
        device: Device the input chunks are moved to. Defaults to cuda:0.

    Returns:
        None. Does nothing if the output video already exists.
    """
    video_path = os.path.join(output_path, video_name)
    if os.path.exists(video_path):
        return

    # splitext instead of video_name[:-4]: extensions that are not exactly
    # four characters long (e.g. ".webm") still yield a clean directory name.
    tmp_path = os.path.join(frames_root, os.path.splitext(video_name)[0])
    os.makedirs(tmp_path, exist_ok=True)

    if device is None:
        device = torch.device('cuda', 0)

    # read_video returns (video frames as THWC uint8, audio, info dict).
    # pts_unit='sec' avoids the 'pts' deprecation warning; for a whole-clip
    # read both units return the same frames.
    video_frames, _, info = torchvision.io.read_video(
        os.path.join(input_path, video_name), pts_unit='sec')
    frames_tchw = video_frames.permute(0, 3, 1, 2)  # THWC -> TCHW
    frames_tchw = torchvision.transforms.functional.convert_image_dtype(
        frames_tchw, torch.float32)
    input_data = frames_tchw.unsqueeze(0)  # (1, T, C, H, W)

    # ==========Replace the model loading and prediction in this section========
    print(f'total frames: {input_data.size(1)}')
    frame_idx = 0
    with torch.no_grad():
        # Tensor.chunk yields at most `num_chunks` pieces along time (not
        # pieces of `num_chunks` frames); presumably to bound GPU memory per
        # forward pass.
        for chunk in input_data.chunk(num_chunks, dim=1):
            output = model(chunk.to(device)).cpu()
            # Clamp to [0, 1] and quantise the whole chunk to uint8 at once.
            output = torchvision.transforms.functional.convert_image_dtype(
                output.clamp(0, 1), torch.uint8)
            # (1, T, C, H, W) -> T frames of HWC, as imageio expects.
            for frame in output.squeeze(0).permute(0, 2, 3, 1).unbind(dim=0):
                frame_file = os.path.join(tmp_path, f'{frame_idx:08d}.png')
                if not os.path.exists(frame_file):
                    imageio.imwrite(frame_file, frame.numpy())
                    print('save frames : ', frame_file)
                else:
                    print('exist frame : ', frame_file)
                frame_idx += 1

    fps = info['video_fps']
    # Argument list with shell=False: video names come from os.listdir and
    # may contain spaces or shell metacharacters.
    cmd = [
        'ffmpeg', '-r', str(fps),
        '-i', os.path.join(tmp_path, '%08d.png'),
        '-c:v', 'libx264', '-crf', '12', '-preset', 'veryfast',
        video_path,
    ]
    try:
        subprocess.run(cmd, check=True)
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        # FileNotFoundError means ffmpeg is not on PATH; report it and move
        # on to the next video instead of aborting the whole batch.
        print(f"An error occurred while trying to run FFmpeg: {e}")
        return
    print("Video created successfully.")

    # Delete the staged frame images (left in place when FFmpeg fails; the
    # write loop above skips frames that already exist).
    for frame_filename in glob.glob(
            os.path.join(glob.escape(tmp_path), '*.png')):
        os.remove(frame_filename)
        print(f"Deleted {frame_filename}")


if __name__ == '__main__':
    device = torch.device('cuda', 0)
    model = SAFMNPP(upscaling_factor=4).to(device)
    model_path = 'light_safmnpp.pth'
    # NOTE(review): torch.load unpickles the checkpoint -- only load trusted
    # files (recent torch versions also accept weights_only=True).
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['params'], strict=True)
    # Put layers such as dropout / batch norm (if any) into inference mode.
    model.eval()

    input_path = 'ValidationSet-1080p/bitstreams'
    output_path = 'Video_Output_4X'
    os.makedirs(output_path, exist_ok=True)

    for video_name in sorted(os.listdir(input_path)):
        main(input_path, output_path, video_name, model, device=device)
        print("Done", video_name)