|
|
|
import argparse |
|
import binascii |
|
import os |
|
import os.path as osp |
|
|
|
import imageio |
|
import torch |
|
import torchvision |
|
|
|
# Names exported by a star-import of this module; rand_name is not listed and
# therefore stays out of the star-import namespace.
__all__ = ['cache_video', 'cache_image', 'str2bool']
|
|
|
|
|
def rand_name(length=8, suffix=''): |
|
name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') |
|
if suffix: |
|
if not suffix.startswith('.'): |
|
suffix = '.' + suffix |
|
name += suffix |
|
return name |
|
|
|
|
|
def cache_video(tensor,
                save_file=None,
                fps=30,
                suffix='.mp4',
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """Tile a batch of videos into one grid video and write it to disk.

    Every slice along dim 2 of ``tensor`` is turned into one frame with
    ``torchvision.utils.make_grid``; the frames are then encoded with
    imageio using the libx264 codec.

    Args:
        tensor: Video batch. Dim 2 is unbound as the frame axis, so this is
            presumably (batch, channels, frames, height, width) -- confirm
            against callers.
        save_file: Output path. When None, a random file name with
            ``suffix`` is generated under ``/tmp``.
        fps: Frame rate passed to the imageio writer.
        suffix: Extension for the auto-generated name; unused when
            ``save_file`` is given.
        nrow: Forwarded to ``make_grid``.
        normalize: Forwarded to ``make_grid``.
        value_range: Pair of bounds; the input is clamped to
            [min(value_range), max(value_range)] and the pair is forwarded
            to ``make_grid``.
        retry: Maximum number of attempts.

    Returns:
        The output path on success, or None when every attempt failed (the
        last error is printed).
    """
    if save_file is None:
        cache_file = osp.join('/tmp', rand_name(suffix=suffix))
    else:
        cache_file = save_file

    # The processed frames live in their own variable so the input tensor is
    # never overwritten: a retry after a failed write must not push already
    # processed data through the preprocessing a second time.
    frames = None
    error = None
    for _ in range(retry):
        try:
            if frames is None:
                clamped = tensor.clamp(min(value_range), max(value_range))
                grids = [
                    torchvision.utils.make_grid(
                        u,
                        nrow=nrow,
                        normalize=normalize,
                        value_range=value_range) for u in clamped.unbind(2)
                ]
                # Stack the per-frame grids on dim 1, then move the former
                # dim 0 to the end so iteration yields one frame at a time.
                video = torch.stack(grids, dim=1).permute(1, 2, 3, 0)
                frames = (video * 255).type(torch.uint8).cpu().numpy()

            writer = imageio.get_writer(
                cache_file, fps=fps, codec='libx264', quality=8)
            try:
                for frame in frames:
                    writer.append_data(frame)
            finally:
                # Release the writer even when encoding fails part-way.
                writer.close()
            return cache_file
        except Exception as e:
            # Best-effort: remember the failure and try again.
            error = e

    print(f'cache_video failed, error: {error}', flush=True)
    return None
|
|
|
|
|
def cache_image(tensor,
                save_file,
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """Clamp a tensor and save it as an image grid, retrying on failure.

    Args:
        tensor: Image tensor handed to ``torchvision.utils.save_image``
            after clamping.
        save_file: Destination passed unchanged to ``save_image``. The
            image format is presumably inferred from its extension by the
            backend -- see the torchvision documentation.
        nrow: Forwarded to ``save_image``.
        normalize: Forwarded to ``save_image``.
        value_range: Pair of bounds; the input is clamped to
            [min(value_range), max(value_range)] and the pair is forwarded
            to ``save_image``.
        retry: Maximum number of attempts.

    Returns:
        ``save_file`` on success, or None when every attempt failed (the
        last error is printed, mirroring ``cache_video``).
    """
    error = None
    for _ in range(retry):
        try:
            # Kept inside the try so that a bad input is handled as a failed
            # attempt rather than raised to the caller.
            clamped = tensor.clamp(min(value_range), max(value_range))
            torchvision.utils.save_image(
                clamped,
                save_file,
                nrow=nrow,
                normalize=normalize,
                value_range=value_range)
            return save_file
        except Exception as e:
            # Best-effort: remember the failure and try again.
            error = e

    print(f'cache_image failed, error: {error}', flush=True)
    return None
|
|
|
|
|
def str2bool(v):
    """
    Convert a string to a boolean.

    Supported true values: 'yes', 'true', 't', 'y', '1'
    Supported false values: 'no', 'false', 'f', 'n', '0'

    Matching is case-insensitive and ignores surrounding whitespace. Values
    that are not strings are compared through their ``str()`` form, so the
    integers 1 and 0 are accepted as well.

    Args:
        v (str): String to convert. A bool is returned unchanged.

    Returns:
        bool: Converted boolean value.

    Raises:
        argparse.ArgumentTypeError: If the value cannot be converted to boolean.
    """
    if isinstance(v, bool):
        return v
    # str() keeps non-string input (None, numbers) on the documented
    # ArgumentTypeError path instead of failing with AttributeError.
    token = str(v).strip().lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
|
|