File size: 4,061 Bytes
1dbaf53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import math
import json
import numpy as np
from PIL import Image
import requests
from io import BytesIO
import torch
from torchvision.transforms import Compose, Lambda, ToTensor
from torchvision.transforms.functional import to_pil_image


def load_json(file_path):
    """Load and return the JSON document stored at ``file_path``.

    The file is decoded as UTF-8 (the encoding RFC 8259 requires for JSON)
    rather than the platform's locale encoding, so non-ASCII content is read
    identically on every OS.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def load_jsonl(file_path):
    """Load a JSON Lines file and return its records as a list.

    Each non-blank line is parsed as one JSON value. Blank or
    whitespace-only lines (e.g. an extra trailing newline) are skipped rather
    than raising ``json.JSONDecodeError``. The file is decoded as UTF-8.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]
    
def save_json(data, file_path):
    """Serialize ``data`` as one JSON document into ``file_path``.

    Any existing file is overwritten. ``json.dump`` defaults apply, so the
    output is on a single line and non-ASCII characters are escaped.
    """
    with open(file_path, 'w') as handle:
        json.dump(obj=data, fp=handle)

def save_jsonl(data, file_path):
    """Write every item of ``data`` to ``file_path`` as a JSON Lines file.

    The file is overwritten; each record is serialized with ``json.dumps``
    and terminated by a newline.
    """
    with open(file_path, 'w') as out:
        out.writelines(f'{json.dumps(record)}\n' for record in data)

def split_list(lst, n):
    """Split a sequence into exactly ``n`` contiguous chunks.

    Every chunk holds ``ceil(len(lst) / n)`` items except the last non-empty
    one, which holds the remainder. When the items run out before ``n``
    chunks are filled (e.g. 9 items into 4 chunks gives three chunks of 3),
    the result is padded with empty slices so that ``len(result) == n``.
    Every index ``0 <= k < n`` is therefore valid. An empty input yields
    ``n`` empty chunks.

    Args:
        lst: Sliceable sequence (list, tuple, str, ...) to split.
        n: Number of chunks; must be a positive integer.

    Returns:
        A list of ``n`` slices of ``lst``.

    Raises:
        ValueError: If ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError(f'n must be a positive integer, got {n}')
    # Ceiling division, so at most n non-empty chunks are produced.
    chunk_size = math.ceil(len(lst) / n)
    # max(..., 1) keeps range() from receiving a zero step on empty input.
    step = max(chunk_size, 1)
    chunks = [lst[i:i + step] for i in range(0, len(lst), step)]
    # Pad with empty slices of the same type as lst up to exactly n chunks.
    chunks.extend(lst[len(lst):] for _ in range(n - len(chunks)))
    return chunks


def get_chunk(lst, n, k):
    """Return the ``k``-th of ``n`` contiguous chunks of ``lst``.

    Chunks come from ``split_list``. If ``split_list`` yields fewer than
    ``n`` chunks (the items ran out early), an index ``k`` in that gap
    returns an empty slice of ``lst`` instead of raising ``IndexError``.
    Workers assigned ``0 <= k < n`` therefore always get a result, possibly
    empty.
    """
    chunks = split_list(lst, n)
    if len(chunks) <= k < n:
        return lst[len(lst):]
    return chunks[k]


def load_image(image_file, timeout=30):
    """Load an image from a local path or an HTTP(S) URL as an RGB PIL image.

    Args:
        image_file: Local filesystem path, or a URL starting with
            ``http://`` or ``https://``.
        timeout: Seconds passed to ``requests.get`` for URLs, so an
            unresponsive server raises instead of blocking forever.

    Returns:
        A ``PIL.Image.Image`` in ``'RGB'`` mode.

    Raises:
        requests.HTTPError: If the URL responds with an HTTP error status,
            instead of passing the error page's body to PIL.
    """
    if image_file.startswith(('http://', 'https://')):
        response = requests.get(image_file, timeout=timeout)
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_file
    # The context manager closes the underlying file; convert() returns a new,
    # fully loaded image that does not depend on it.
    with Image.open(source) as image:
        return image.convert('RGB')


def load_frames(frame_names, num_frames=None):
    """Load frame images from a list of paths, optionally subsampling them.

    Paths are sorted lexicographically without modifying the caller's list.
    Frame files are therefore expected to be named so that string order
    equals temporal order, e.g. zero-padded ``0001.jpg``. TODO confirm
    against the dataset layout.

    Args:
        frame_names: Paths or URLs accepted by ``load_image``.
        num_frames: Number of frames to return. ``None`` loads every frame.
            When it differs from the number of paths, indices are spread
            evenly from the first to the last frame with ``np.linspace``;
            frames repeat if more are requested than exist.

    Returns:
        A list of RGB PIL images in sorted-path order.

    Raises:
        ValueError: If frames are requested but ``frame_names`` is empty.
    """
    names = sorted(frame_names)
    if num_frames is None or num_frames == len(names):
        frame_ids = range(len(names))
    else:
        if not names:
            raise ValueError('cannot sample frames from an empty frame list')
        frame_ids = np.linspace(0, len(names) - 1, num_frames, dtype=int).tolist()
    return [load_image(names[idx]) for idx in frame_ids]


def load_video_into_frames(
        video_path,
        video_decode_backend='opencv',
        num_frames=8,
        return_tensor=False,
):
    """Decode ``num_frames`` evenly spaced frames from a video.

    Args:
        video_path: Path to a video file (``'decord'`` / ``'opencv'``
            backends) or to a directory of frame images (``'frames'``).
        video_decode_backend: One of ``'decord'``, ``'opencv'``, ``'frames'``.
        num_frames: Number of frames to sample; indices are spread evenly
            from the first to the last frame with ``np.linspace``.
        return_tensor: If True, return one tensor laid out as
            ``(C, T, H, W)``; otherwise return a list of RGB PIL images.

    Returns:
        A ``(C, T, H, W)`` tensor or a list of ``T`` PIL images.
        NOTE(review): the tensor dtype differs by backend. ``'frames'`` goes
        through ``ToTensor`` (float in [0, 1]), while ``'decord'`` and
        ``'opencv'`` presumably keep raw uint8 pixel values. Confirm what
        downstream preprocessing expects.

    Raises:
        NameError: If ``video_decode_backend`` is not recognised. This type
            is kept for backward compatibility with existing callers.
        ValueError: If the OpenCV backend cannot open or read the video.
    """
    decoders = {
        'decord': _decode_video_decord,
        'opencv': _decode_video_opencv,
        'frames': _decode_video_frames,
    }
    decode = decoders.get(video_decode_backend)
    if decode is None:
        raise NameError(
            'video_decode_backend should specify in (decord, opencv, frames) '
            f'but got {video_decode_backend}'
        )
    return decode(video_path, num_frames, return_tensor)


def _decode_video_decord(video_path, num_frames, return_tensor):
    """Sample frames from a video file with decord using its torch bridge."""
    import decord
    from decord import VideoReader, cpu
    decord.bridge.set_bridge('torch')
    reader = VideoReader(video_path, ctx=cpu(0))
    frame_ids = np.linspace(0, len(reader) - 1, num_frames, dtype=int)
    video_data = reader.get_batch(frame_ids)  # (T, H, W, C)
    if return_tensor:
        return video_data.permute(3, 0, 1, 2)  # (T, H, W, C) -> (C, T, H, W)
    # to_pil_image interprets a tensor as (C, H, W), so each (H, W, C) frame
    # must be permuted first or it is treated as an image with H channels.
    return [to_pil_image(frame.permute(2, 0, 1)) for frame in video_data]


def _decode_video_frames(video_path, num_frames, return_tensor):
    """Sample frames from a directory containing one image file per frame."""
    frames = load_frames(
        [os.path.join(video_path, name) for name in os.listdir(video_path)],
        num_frames=num_frames,
    )
    if not return_tensor:
        return frames
    to_tensor = ToTensor()
    # (T, C, H, W) -> (C, T, H, W)
    return torch.stack([to_tensor(frame) for frame in frames]).permute(1, 0, 2, 3)


def _decode_video_opencv(video_path, num_frames, return_tensor):
    """Sample frames from a video file with OpenCV, seeking to each index."""
    import cv2
    capture = cv2.VideoCapture(video_path)
    try:
        if not capture.isOpened():
            raise ValueError(f'video error at {video_path}')
        duration = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_ids = np.linspace(0, duration - 1, num_frames, dtype=int)
        video_data = []
        for frame_idx in frame_ids:
            # Seek to the requested 0-based frame index before reading it.
            capture.set(cv2.CAP_PROP_POS_FRAMES, int(frame_idx))
            ok, frame = capture.read()
            if not ok:
                raise ValueError(f'video error at {video_path}')
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes BGR
            if return_tensor:
                video_data.append(torch.from_numpy(frame).permute(2, 0, 1))
            else:
                video_data.append(Image.fromarray(frame))
    finally:
        # Release the capture even when opening or a read fails.
        capture.release()
    if return_tensor:
        return torch.stack(video_data, dim=1)  # T x (C, H, W) -> (C, T, H, W)
    return video_data