Spaces:
Running
Running
| # Code adapted from https://github.com/deepmind/dmvr/tree/master and later modified. | |
| """Python script to generate TFRecords of SequenceExample from raw videos.""" | |
| import contextlib | |
| import math | |
| import os | |
| import cv2 | |
| from typing import Dict, Optional, Sequence | |
| import moviepy.editor | |
| from absl import app | |
| from absl import flags | |
| import ffmpeg | |
| import numpy as np | |
| import pandas as pd | |
| import tensorflow as tf | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
| flags.DEFINE_string("csv_path", "fakeavceleb_1k.csv", "Input csv") | |
| flags.DEFINE_string("output_path", "fakeavceleb_tfrec", "Tfrecords output path.") | |
| flags.DEFINE_string("video_root_path", "./", | |
| "Root directory containing the raw videos.") | |
| flags.DEFINE_integer( | |
| "num_shards", 4, "Number of shards to output, -1 means" | |
| "it will automatically adapt to the sqrt(num_examples).") | |
| flags.DEFINE_bool("decode_audio", False, "Whether or not to decode the audio") | |
| flags.DEFINE_bool("shuffle_csv", False, "Whether or not to shuffle the csv.") | |
| FLAGS = flags.FLAGS | |
| _JPEG_HEADER = b"\xff\xd8" | |
| def _close_on_exit(writers): | |
| """Call close on all writers on exit.""" | |
| try: | |
| yield writers | |
| finally: | |
| for writer in writers: | |
| writer.close() | |
| def add_float_list(key: str, values: Sequence[float], | |
| sequence: tf.train.SequenceExample): | |
| sequence.feature_lists.feature_list[key].feature.add( | |
| ).float_list.value[:] = values | |
| def add_bytes_list(key: str, values: Sequence[bytes], | |
| sequence: tf.train.SequenceExample): | |
| sequence.feature_lists.feature_list[key].feature.add().bytes_list.value[:] = values | |
| def add_int_list(key: str, values: Sequence[int], | |
| sequence: tf.train.SequenceExample): | |
| sequence.feature_lists.feature_list[key].feature.add().int64_list.value[:] = values | |
| def set_context_int_list(key: str, value: Sequence[int], | |
| sequence: tf.train.SequenceExample): | |
| sequence.context.feature[key].int64_list.value[:] = value | |
| def set_context_bytes(key: str, value: bytes, | |
| sequence: tf.train.SequenceExample): | |
| sequence.context.feature[key].bytes_list.value[:] = (value,) | |
| def set_context_bytes_list(key: str, value: Sequence[bytes], | |
| sequence: tf.train.SequenceExample): | |
| sequence.context.feature[key].bytes_list.value[:] = value | |
| def set_context_float(key: str, value: float, | |
| sequence: tf.train.SequenceExample): | |
| sequence.context.feature[key].float_list.value[:] = (value,) | |
| def set_context_int(key: str, value: int, sequence: tf.train.SequenceExample): | |
| sequence.context.feature[key].int64_list.value[:] = (value,) | |
| def extract_frames(video_path, fps = 10, min_resize = 256): | |
| '''Load n number of frames from a video''' | |
| v_cap = cv2.VideoCapture(video_path) | |
| v_len = int(v_cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |
| if fps is None: | |
| sample = np.arange(0, v_len) | |
| else: | |
| sample = np.linspace(0, v_len - 1, fps).astype(int) | |
| frames = [] | |
| for j in range(v_len): | |
| success = v_cap.grab() | |
| if j in sample: | |
| success, frame = v_cap.retrieve() | |
| if not success: | |
| continue | |
| frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |
| frame = cv2.resize(frame, (min_resize, min_resize)) | |
| frames.append(frame) | |
| v_cap.release() | |
| frame_np = np.stack(frames) | |
| return frame_np.tobytes() | |
| def extract_audio(video_path: str, | |
| sampling_rate: int = 16_000): | |
| """Extract raw mono audio float list from video_path with ffmpeg.""" | |
| video = moviepy.editor.VideoFileClip(video_path) | |
| audio = video.audio.to_soundarray() | |
| #Load first channel. | |
| audio = audio[:, 0] | |
| return np.array(audio) | |
| # Each feature is coerced into a tf.train.Example-compatible type with the | |
| # add_*/set_context_* helpers above, then assembled into a | |
| # tf.train.SequenceExample message. | |
| def serialize_example(video_path: str, label_name: str, label_map: Optional[Dict[str, int]] = None): | |
| # Initiate the sequence example. | |
| seq_example = tf.train.SequenceExample() | |
| imgs_encoded = extract_frames(video_path, fps = 10) | |
| audio = extract_audio(video_path) | |
| set_context_bytes(f'image/encoded', imgs_encoded, seq_example) | |
| set_context_bytes("video_path", video_path.encode(), seq_example) | |
| set_context_bytes("WAVEFORM/feature/floats", audio.tobytes(), seq_example) | |
| set_context_int("clip/label/index", label_map[label_name], seq_example) | |
| set_context_bytes("clip/label/text", label_name.encode(), seq_example) | |
| return seq_example | |
| def main(argv): | |
| del argv | |
| # reads the input csv. | |
| input_csv = pd.read_csv(FLAGS.csv_path) | |
| if FLAGS.num_shards == -1: | |
| num_shards = int(math.sqrt(len(input_csv))) | |
| else: | |
| num_shards = FLAGS.num_shards | |
| # Set up the TFRecordWriters. | |
| basename = os.path.splitext(os.path.basename(FLAGS.csv_path))[0] | |
| shard_names = [ | |
| os.path.join(FLAGS.output_path, f"{basename}-{i:05d}-of-{num_shards:05d}") | |
| for i in range(num_shards) | |
| ] | |
| writers = [tf.io.TFRecordWriter(shard_name) for shard_name in shard_names] | |
| if "label" in input_csv: | |
| unique_labels = list(set(input_csv["label"].values)) | |
| l_map = {unique_labels[i]: i for i in range(len(unique_labels))} | |
| else: | |
| l_map = None | |
| if FLAGS.shuffle_csv: | |
| input_csv = input_csv.sample(frac=1) | |
| with _close_on_exit(writers) as writers: | |
| row_count = 0 | |
| for row in input_csv.itertuples(): | |
| index = row[0] | |
| v = row[1] | |
| if os.name == 'posix': | |
| v = v.str.replace('\\', '/') | |
| l = row[2] | |
| row_count += 1 | |
| print("Processing example %d of %d (%d%%) \r" %(row_count, len(input_csv), row_count * 100 / len(input_csv)), end="") | |
| seq_ex = serialize_example(video_path = v, label_name = l,label_map = l_map) | |
| writers[index % len(writers)].write(seq_ex.SerializeToString()) | |
| if __name__ == "__main__": | |
| app.run(main) | |