|
from argparse import ArgumentParser, Namespace |
|
|
|
import torch |
|
|
|
from accelerate.utils import set_seed |
|
from utils.inference import ( |
|
V1InferenceLoop, |
|
BSRInferenceLoop, BFRInferenceLoop, BIDInferenceLoop, UnAlignedBFRInferenceLoop |
|
) |
|
|
|
|
|
def check_device(device: str) -> str:
    """Return a usable device name, falling back to "cpu" when needed.

    Only "cuda" and "mps" are probed; any other value (for example "cpu")
    is returned unchanged. The reason for a fallback and the final choice
    are printed to stdout.

    Args:
        device: Requested device name.

    Returns:
        ``device`` when the requested backend is available, otherwise "cpu".
    """
    if device == "cuda" and not torch.cuda.is_available():
        # Tell a CPU-only build apart from a missing GPU/driver, so the
        # message does not blame the PyTorch install when the build is fine.
        if not torch.backends.cuda.is_built():
            print("CUDA not available because the current PyTorch install was not "
                  "built with CUDA enabled.")
        else:
            print("CUDA not available because no usable CUDA device or driver "
                  "was found on this machine.")
        device = "cpu"
    elif device == "mps" and not torch.backends.mps.is_available():
        if not torch.backends.mps.is_built():
            print("MPS not available because the current PyTorch install was not "
                  "built with MPS enabled.")
        else:
            print("MPS not available because the current MacOS version is not 12.3+ "
                  "and/or you do not have an MPS-enabled device on this machine.")
        device = "cpu"

    print(f"using device {device}")
    return device
|
|
|
|
|
def parse_args(argv: "list[str] | None" = None) -> Namespace:
    """Parse the command-line options of this script.

    Args:
        argv: Argument strings to parse. When None (the default) argparse
            reads ``sys.argv[1:]``, so existing ``parse_args()`` callers
            behave exactly as before.

    Returns:
        Namespace with one attribute per option defined below. ``--task``,
        ``--upscale``, ``--input`` and ``--output`` are required; argparse
        exits with an error if any of them is missing or if a value is not
        among the declared choices.
    """
    parser = ArgumentParser()

    # Task / model selection.
    parser.add_argument("--task", type=str, required=True, choices=["sr", "dn", "fr", "fr_bg"])
    parser.add_argument("--upscale", type=float, required=True)
    parser.add_argument("--version", type=str, default="v2", choices=["v1", "v2"])

    # Sampling, tiling and prompt options.
    parser.add_argument("--steps", type=int, default=50)
    parser.add_argument("--better_start", action="store_true")
    parser.add_argument("--tiled", action="store_true")
    parser.add_argument("--tile_size", type=int, default=512)
    parser.add_argument("--tile_stride", type=int, default=256)
    parser.add_argument("--pos_prompt", type=str, default="")
    parser.add_argument("--neg_prompt", type=str, default="low quality, blurry, low-resolution, noisy, unsharp, weird textures")
    parser.add_argument("--cfg_scale", type=float, default=4.0)

    # Input options.
    parser.add_argument("--input", type=str, required=True)
    parser.add_argument("--n_samples", type=int, default=1)

    # Guidance options (all "g_*" flags; semantics are defined by the
    # inference loops that consume them, not by this parser).
    parser.add_argument("--guidance", action="store_true")
    parser.add_argument("--g_loss", type=str, default="w_mse", choices=["mse", "w_mse"])
    parser.add_argument("--g_scale", type=float, default=0.0)
    parser.add_argument("--g_start", type=int, default=1001)
    parser.add_argument("--g_stop", type=int, default=-1)
    parser.add_argument("--g_space", type=str, default="latent")
    parser.add_argument("--g_repeat", type=int, default=1)

    # Output option.
    parser.add_argument("--output", type=str, required=True)

    # Runtime options.
    parser.add_argument("--seed", type=int, default=231)
    parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda", "mps"])

    return parser.parse_args(argv)
|
|
|
|
|
def main():
    """Run the script: parse options, resolve the device, seed, run inference.

    The device stored on ``args`` is replaced by the result of
    ``check_device`` and ``set_seed`` is called with ``args.seed`` before
    any inference loop is constructed. Version "v1" always uses
    ``V1InferenceLoop``; otherwise the loop class is chosen by ``args.task``.
    """
    args = parse_args()
    args.device = check_device(args.device)
    set_seed(args.seed)

    if args.version == "v1":
        loop_cls = V1InferenceLoop
    else:
        # Keys match the choices accepted by the --task option.
        task_to_loop = {
            "sr": BSRInferenceLoop,
            "dn": BIDInferenceLoop,
            "fr": BFRInferenceLoop,
            "fr_bg": UnAlignedBFRInferenceLoop,
        }
        loop_cls = task_to_loop[args.task]

    loop_cls(args).run()
    print("done!")
|
|
|
|
|
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|