# Copyright 2022 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Evaluate the frame interpolation model from a tfrecord and store results.

This script runs the inference on examples in a tfrecord and generates images
and numeric results according to the gin config. For details, see the
run_evaluation() function below.

Usage example:
  python3 -m frame_interpolation.eval.eval_cli -- \
    --gin_config <path to eval_dataset.gin> \
    --base_folder <the root directory to all training sessions> \
    --label <the folder name of the training session>

or

  python3 -m frame_interpolation.eval.eval_cli -- \
    --gin_config <path to eval_dataset.gin> \
    --model_path <The filepath of the TF2 saved model>

The output is saved under the parent directory of the `model_path`:
<parent directory of model_path>/batch_eval/<eval name>.

The evaluation runs on a GPU by default. Pass `--mode cpu` to run on a CPU instead.
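
The gin config is expected to bind the `experiment` and `evaluation`
configurables defined in this module. A minimal sketch with illustrative
placeholder values (the valid metric names are those defined in
losses.losses):

  experiment.name = 'my_eval_dataset'
  evaluation.tfrecord = '<path to the eval tfrecord>'
  evaluation.max_examples = 100
  evaluation.metrics = ['l1', 'psnr', 'ssim']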
"""
import collections
import os
from typing import Any, Dict

from . import util
from absl import app
from absl import flags
from absl import logging
import gin.tf
from ..losses import losses
import numpy as np
import tensorflow as tf
from ..training import data_lib


_GIN_CONFIG = flags.DEFINE_string('gin_config', None, 'Gin config file.')
_LABEL = flags.DEFINE_string(
    'label', None, 'Descriptive label for the training session to eval.')
_BASE_FOLDER = flags.DEFINE_string('base_folder', None,
                                   'Root folder of training sessions.')
_MODEL_PATH = flags.DEFINE_string(
    name='model_path',
    default=None,
    help='The path of the TF2 saved model to use. If --model_path is '
    'specified directly, the --label and --base_folder flags are ignored.')
_OUTPUT_FRAMES = flags.DEFINE_boolean(
    name='output_frames',
    default=False,
    help='If true, saves the inputs, ground-truth and interpolated frames.')
_MODE = flags.DEFINE_enum('mode', 'gpu', ['cpu', 'gpu'],
                          'Device to run evaluations.')


@gin.configurable('experiment')
def _get_experiment_config(name) -> Dict[str, Any]:
  """Fetches the gin config."""
  return {
      'name': name,
  }


def _set_visible_devices():
  """Set the visible devices according to running mode."""
  mode_devices = tf.config.list_physical_devices(_MODE.value.upper())
  tf.config.set_visible_devices([], 'GPU')
  tf.config.set_visible_devices([], 'TPU')
  tf.config.set_visible_devices(mode_devices, _MODE.value.upper())
  return


@gin.configurable('evaluation')
def run_evaluation(model_path, tfrecord, output_dir, max_examples, metrics):
  """Runs the eval loop for examples in the tfrecord.

  The evaluation is run for the first 'max_examples' examples, and the
  resulting images are stored into the given output_dir. Any tensor that
  looks like an image is stored under its name -- this may include
  intermediate results, depending on what the model outputs.

  Additionally, numeric results are stored into a results.csv file in the same
  directory. This includes per-example metrics and the mean across the whole
  dataset.

  Args:
    model_path: Directory of the TF2 saved model.
    tfrecord: Path to the tfrecord containing the eval data.
    output_dir: Directory to store the results into.
    max_examples: Maximum examples to evaluate.
    metrics: The names of loss functions to use.
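
  Only `model_path` and `output_dir` are passed in explicitly from main(); the
  remaining arguments are expected to be bound via the `evaluation` gin
  configurable.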
  """
  model = tf.saved_model.load(model_path)

  # Store a 'readme.txt' that contains information on where the data came from.
  with tf.io.gfile.GFile(os.path.join(output_dir, 'readme.txt'), mode='w') as f:
    print('Results for:', file=f)
    print(f' model:   {model_path}', file=f)
    print(f' tfrecord: {tfrecord}', file=f)

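  # Write per-example metric values, plus the mean over the whole dataset, to
  # results.csv in the output directory.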
  with tf.io.gfile.GFile(
      os.path.join(output_dir, 'results.csv'), mode='w') as csv_file:
    test_losses = losses.test_losses(metrics, [1.0] * len(metrics))
    title_row = ['key'] + list(test_losses)
    print(', '.join(title_row), file=csv_file)

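    # Build a batch-size-1 eval dataset; it is keyed by the basename of the
    # output directory (the eval dataset name).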
    datasets = data_lib.create_eval_datasets(
        batch_size=1,
        files=[tfrecord],
        names=[os.path.basename(output_dir)],
        max_examples=max_examples)
    dataset = datasets[os.path.basename(output_dir)]

    all_losses = collections.defaultdict(list)
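    # Run inference on each example, optionally write image-like tensors to
    # disk, and accumulate per-example metric values.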
    for example in dataset:
      inputs = {
          'x0': example['x0'],
          'x1': example['x1'],
          'time': example['time'][..., tf.newaxis],
      }
      prediction = model(inputs, training=False)

      # Derive the key from the encoded mid-frame path.
      path = example['path'][0].numpy().decode('utf-8')
      key = path.rsplit('.', 1)[0].rsplit(os.sep)[-1]

      # If requested, combine inputs and outputs into a single dictionary so
      # that both can be written out as images below.
      combined = {**prediction, **example} if _OUTPUT_FRAMES.value else {}
      for name in combined:
        image = combined[name]
        if isinstance(image, tf.Tensor):
          # This saves any tensor that has a shape that can be interpreted
          # as an image, e.g. (1, H, W, C), where the batch dimension is always
          # 1, H and W are the image height and width, and C is either 1 or 3
          # (grayscale or color image).
          if len(image.shape) == 4 and (image.shape[-1] == 1 or
                                        image.shape[-1] == 3):
            util.write_image(
                os.path.join(output_dir, f'{key}_{name}.png'), image[0].numpy())

      # Evaluate losses if the dataset has ground truth 'y', otherwise just do
      # a visual eval.
      if 'y' in example:
        loss_values = []
        # Clip the interpolator output to the range [0, 1]. Clipping is done
        # only in the eval loop to get better metrics, not in the training
        # loop, so that gradients are not killed.
        prediction['image'] = tf.clip_by_value(prediction['image'], 0., 1.)
        for loss_name, (loss_value_fn, loss_weight_fn) in test_losses.items():
          loss_value = loss_value_fn(example, prediction) * loss_weight_fn(0)
          loss_values.append(loss_value.numpy())
          all_losses[loss_name].append(loss_value.numpy())
        print(f'{key}, {str(loss_values)[1:-1]}', file=csv_file)

    if all_losses:
      totals = [np.mean(all_losses[loss_name]) for loss_name in test_losses]
      print(f'mean, {str(totals)[1:-1]}', file=csv_file)
  if all_losses:
    totals_dict = {
        loss_name: np.mean(all_losses[loss_name]) for loss_name in test_losses
    }
    logging.info('mean, %s', totals_dict)


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  if _MODEL_PATH.value is not None:
    model_path = _MODEL_PATH.value
  else:
    model_path = os.path.join(_BASE_FOLDER.value, _LABEL.value, 'saved_model')

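  # Parse the eval gin config; bindings for configurables that are not
  # imported here are skipped.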
  gin.parse_config_files_and_bindings(
      config_files=[_GIN_CONFIG.value],
      bindings=None,
      skip_unknown=True)

  config = _get_experiment_config()  # pylint: disable=no-value-for-parameter
  eval_name = config['name']
  output_dir = os.path.join(
      os.path.dirname(model_path), 'batch_eval', eval_name)
  logging.info('Creating output_dir @ %s ...', output_dir)

  # Copy config file to <base_folder>/<label>/batch_eval/<eval_name>/config.gin.
  tf.io.gfile.makedirs(output_dir)
  tf.io.gfile.copy(
      _GIN_CONFIG.value, os.path.join(output_dir, 'config.gin'), overwrite=True)

  _set_visible_devices()
  logging.info('Evaluating %s on %s ...', eval_name, [
      el.name.split('/physical_device:')[-1]
      for el in tf.config.get_visible_devices()
  ])
  run_evaluation(model_path=model_path, output_dir=output_dir)  # pylint: disable=no-value-for-parameter

  logging.info('Done. Evaluations saved @ %s.', output_dir)

if __name__ == '__main__':
  app.run(main)