import os
import math
import json
import numpy as np
from PIL import Image
import requests
from io import BytesIO
import torch
from torchvision.transforms import ToTensor
from torchvision.transforms.functional import to_pil_image


def load_json(file_path):
    with open(file_path, 'r') as f:
        return json.load(f)


def load_jsonl(file_path):
    with open(file_path, 'r') as f:
        return [json.loads(line) for line in f]


def save_json(data, file_path):
    with open(file_path, 'w') as f:
        json.dump(data, f)


def save_jsonl(data, file_path):
    with open(file_path, 'w') as f:
        for d in data:
            f.write(json.dumps(d) + '\n')


def split_list(lst, n):
    """Split a list into n (roughly) equal-sized chunks."""
    chunk_size = math.ceil(len(lst) / n)  # ceiling division, so no items are dropped
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    """Return the k-th of the n chunks produced by split_list."""
    chunks = split_list(lst, n)
    return chunks[k]
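

# Illustrative example (comment only, not part of the original module):
# split_list/get_chunk are convenient for sharding work across n workers, e.g.
#   split_list(list(range(10)), 3) -> [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
#   get_chunk(list(range(10)), 3, 1) -> [4, 5, 6, 7]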


def load_image(image_file):
    if image_file.startswith('http://') or image_file.startswith('https://'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image


def load_frames(frame_names, num_frames=None):
    frame_names.sort()
    # Uniformly subsample when a specific frame count is requested
    if num_frames is not None and len(frame_names) != num_frames:
        duration = len(frame_names)
        frame_id_array = np.linspace(0, duration - 1, num_frames, dtype=int)
        frame_id_list = frame_id_array.tolist()
    else:
        # Otherwise keep every frame
        frame_id_list = range(len(frame_names))
    results = []
    for frame_idx in frame_id_list:
        frame_name = frame_names[frame_idx]
        results.append(load_image(frame_name))
    return results
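

# Sampling note: np.linspace(0, duration - 1, num_frames, dtype=int) picks
# (near-)evenly spaced frame indices that include the first and last frame,
# e.g. duration=10 with num_frames=4 yields [0, 3, 6, 9].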


def load_video_into_frames(
        video_path,
        video_decode_backend='opencv',
        num_frames=8,
        return_tensor=False,
):
    if video_decode_backend == 'decord':
        import decord
        from decord import VideoReader, cpu
        decord.bridge.set_bridge('torch')
        decord_vr = VideoReader(video_path, ctx=cpu(0))
        duration = len(decord_vr)
        frame_id_list = np.linspace(0, duration - 1, num_frames, dtype=int)
        video_data = decord_vr.get_batch(frame_id_list)  # (T, H, W, C) uint8 tensor
        if return_tensor:
            video_data = video_data.permute(3, 0, 1, 2)  # (T, H, W, C) -> (C, T, H, W)
        else:
            # to_pil_image expects channels-first tensors, so permute each frame
            video_data = [to_pil_image(f.permute(2, 0, 1)) for f in video_data]
    elif video_decode_backend == 'frames':
        # video_path is a directory of per-frame image files
        frames = load_frames([os.path.join(video_path, imname)
                              for imname in os.listdir(video_path)],
                             num_frames=num_frames)
        video_data = frames
        if return_tensor:
            to_tensor = ToTensor()
            video_data = torch.stack([to_tensor(f) for f in frames]).permute(1, 0, 2, 3)  # (T, C, H, W) -> (C, T, H, W)
    elif video_decode_backend == 'opencv':
        import cv2
        cv2_vr = cv2.VideoCapture(video_path)
        duration = int(cv2_vr.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_id_list = np.linspace(0, duration - 1, num_frames, dtype=int)
        video_data = []
        for frame_idx in frame_id_list:
            cv2_vr.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)  # seek to the sampled frame
            ret, frame = cv2_vr.read()
            if not ret:
                raise ValueError(f'video error at {video_path}')
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes BGR
            if return_tensor:
                video_data.append(torch.from_numpy(frame).permute(2, 0, 1))  # (H, W, C) -> (C, H, W)
            else:
                video_data.append(Image.fromarray(frame))
        cv2_vr.release()
        if return_tensor:
            video_data = torch.stack(video_data, dim=1)  # stack frames on dim 1 -> (C, T, H, W)
    else:
        raise ValueError(f'video_decode_backend should be one of (decord, frames, opencv) but got {video_decode_backend}')
    return video_data
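

if __name__ == '__main__':
    # Minimal usage sketch; 'sample.mp4' is an assumed stand-in for any local
    # video file and is not part of this module. Decodes 8 evenly spaced frames
    # with the default OpenCV backend, as PIL images and as a tensor.
    frames = load_video_into_frames('sample.mp4', num_frames=8)
    print(len(frames), frames[0].size)  # 8 PIL images, each (width, height)
    clip = load_video_into_frames('sample.mp4', num_frames=8, return_tensor=True)
    print(clip.shape)  # torch.Size([3, 8, H, W])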