|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
r"""A test script for mid frame interpolation from two input frames.

Usage example:
 python3 -m frame_interpolation.eval.interpolator_test \
   --frame1 <filepath of the first frame> \
   --frame2 <filepath of the second frame> \
   --model_path <The filepath of the TF2 saved model to use>

The output is saved to <the directory of the first input frame>/output_frame.png
unless an `--output_frame` filepath is provided, in which case that filepath is
used instead.
"""
|
import os |
|
from typing import Sequence |
|
|
|
from . import interpolator as interpolator_lib |
|
from . import util |
|
from absl import app |
|
from absl import flags |
|
import numpy as np |
|
|
|
|
|
# Sets the TF_CPP_MIN_LOG_LEVEL environment variable, which TensorFlow's C++
# logging reads (a value of '1' is documented to hide INFO messages).
# NOTE(review): this runs after the project imports at the top of the file; if
# those import TensorFlow, confirm the setting is still picked up.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'


# Command-line flags. Both input frames are mandatory; every other flag has a
# default.
_FRAME1 = flags.DEFINE_string(
    name='frame1',
    default=None,
    help='The filepath of the first input frame.',
    required=True)
_FRAME2 = flags.DEFINE_string(
    name='frame2',
    default=None,
    help='The filepath of the second input frame.',
    required=True)
_MODEL_PATH = flags.DEFINE_string(
    name='model_path',
    default=None,
    help='The path of the TF2 saved model to use.')
_OUTPUT_FRAME = flags.DEFINE_string(
    name='output_frame',
    default=None,
    help='The output filepath of the interpolated mid-frame.')
_ALIGN = flags.DEFINE_integer(
    name='align',
    default=64,
    help='If >1, pad the input size so it is evenly divisible by this value.')
# The block counts are documented as ">= 1"; lower_bound makes flag parsing
# reject zero or negative values up front instead of letting them reach the
# interpolator.
_BLOCK_HEIGHT = flags.DEFINE_integer(
    name='block_height',
    default=1,
    help='An int >= 1, number of patches along height, '
    'patch_height = height//block_height, should be evenly divisible.',
    lower_bound=1)
_BLOCK_WIDTH = flags.DEFINE_integer(
    name='block_width',
    default=1,
    help='An int >= 1, number of patches along width, '
    'patch_width = width//block_width, should be evenly divisible.',
    lower_bound=1)
|
|
|
|
|
def _run_interpolator() -> None:
    """Writes the interpolated mid frame for the two input frame filepaths.

    Reads the images named by `--frame1` and `--frame2`, runs an
    `interpolator_lib.Interpolator` configured from `--model_path`, `--align`,
    `--block_height` and `--block_width` on them with a time value of 0.5, and
    writes the first result to `--output_frame`. When `--output_frame` is unset
    or empty, the result is written to `output_frame.png` in the directory of
    `--frame1`.
    """
    interpolator = interpolator_lib.Interpolator(
        model_path=_MODEL_PATH.value,
        align=_ALIGN.value,
        block_shape=[_BLOCK_HEIGHT.value, _BLOCK_WIDTH.value])

    # The interpolator is called on batches, so give each image a leading
    # axis of size one.
    image_1 = util.read_image(_FRAME1.value)
    image_batch_1 = np.expand_dims(image_1, axis=0)

    image_2 = util.read_image(_FRAME2.value)
    image_batch_2 = np.expand_dims(image_2, axis=0)

    # One time value per batch entry; 0.5 requests the frame halfway between
    # the two inputs.
    batch_dt = np.full(shape=(1,), fill_value=0.5, dtype=np.float32)

    # Index 0 picks the single result out of the returned batch.
    mid_frame = interpolator(image_batch_1, image_batch_2, batch_dt)[0]

    mid_frame_filepath = _OUTPUT_FRAME.value
    if not mid_frame_filepath:
        # os.path.dirname() returns '' for a bare filename such as
        # 'frame1.png'. Without the '.' fallback the default path would be
        # '/output_frame.png', i.e. the filesystem root rather than the
        # directory the frame actually lives in.
        output_dir = os.path.dirname(_FRAME1.value) or '.'
        mid_frame_filepath = f'{output_dir}/output_frame.png'
    util.write_image(mid_frame_filepath, mid_frame)
|
|
|
|
|
def main(argv: Sequence[str]) -> None:
    """Runs the interpolation after rejecting stray positional arguments.

    Args:
      argv: Command-line arguments; anything beyond the first entry is
        treated as an error.

    Raises:
      app.UsageError: If `argv` holds more than one entry.
    """
    unexpected_args = argv[1:]
    if unexpected_args:
        raise app.UsageError('Too many command-line arguments.')

    _run_interpolator()
|
|
|
|
|
# Script entry point: hand `main` to absl's `app.run` only when this module is
# executed directly, not when it is imported.
if __name__ == '__main__':
    app.run(main)
|
|