# Copyright (c) 2024 Julius Erbach, ETH Zurich
#
# This file is part of the var_post_samp project and is licensed under the MIT License.
# See the LICENSE file in the project root for more information.
"""
Usage:
    python run_image_inv.py --config <config.yaml>
"""
import os
import sys
import time
import csv
import yaml
import torch
import random
import click
import numpy as np
import tqdm
import datetime
import torchvision

from flair.helper_functions import parse_click_context
from flair.pipelines import model_loader
from flair.utils import data_utils
from flair import var_post_samp

dtype = torch.bfloat16

num_gpus = torch.cuda.device_count()
if num_gpus > 0:
    devices = [f"cuda:{i}" for i in range(num_gpus)]
    primary_device = devices[0]
    print(f"Using devices: {devices}")
    print(f"Primary device for operations: {primary_device}")
else:
    print("No CUDA devices found. Using CPU.")
    devices = ["cpu"]
    primary_device = "cpu"


# CLI wiring (option names assumed from the function parameters and the usage string above)
@click.command(context_settings={"ignore_unknown_options": True, "allow_extra_args": True})
@click.option("--config", "config_file_arg", type=click.Path(exists=True), required=True)
@click.option("--target_file", type=click.Path(exists=True), required=True)
@click.option("--result_folder", type=click.Path(), required=True)
@click.option("--mask_file", type=click.Path(exists=True), default=None)
@click.pass_context
def main(ctx, config_file_arg, target_file, result_folder, mask_file=None):
    """Main entry point for image inversion and sampling.

    The config YAML must provide either a caption_file (with per-image captions)
    or a single prompt used for all images, but not both.
    """
    with open(config_file_arg, "r") as f:
        config = yaml.safe_load(f)
    ctx = parse_click_context(ctx)
    config.update(ctx)

    # Read caption_file and prompt from config
    caption_file = config.get("caption_file", None)
    prompt = config.get("prompt", None)
    # Enforce mutually exclusive caption_file or prompt
    if (not caption_file and not prompt) or (caption_file and prompt):
        raise ValueError(
            "You must provide either 'caption_file' OR 'prompt' (not both) in the config file. "
            "See documentation."
        )

    # wandb removed, so config_dict is just a copy
    config_dict = dict(config)

    torch.manual_seed(config["seed"])
    np.random.seed(config["seed"])
    random.seed(config["seed"])

    # Use config values as-is (no to_absolute_path)
    caption_file = caption_file if caption_file else None
    guidance_img_iterator = data_utils.yield_images(
        target_file, size=config["resolution"]
    )

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    counter = 1
    name = f'results_{config["model"]}_{config["degradation"]["name"]}_resolution_{config["resolution"]}_noise_{config["degradation"]["kwargs"]["noise_std"]}_{timestamp}'
    candidate = name
    while os.path.exists(candidate):
        candidate = f"{name}_{counter}"
        counter += 1
    output_folders = data_utils.generate_output_structure(
        result_folder,
        [
            candidate,
            f'input_{config["degradation"]["name"]}_resolution_{config["resolution"]}_noise_{config["degradation"]["kwargs"]["noise_std"]}',
            f'target_{config["degradation"]["name"]}_resolution_{config["resolution"]}_noise_{config["degradation"]["kwargs"]["noise_std"]}',
        ],
    )
    config_out = os.path.join(os.path.split(output_folders[0])[0], "config.yaml")
    with open(config_out, "w") as f:
        yaml.safe_dump(config_dict, f)

    source_files = list(data_utils.find_files(target_file, ext="png"))
    num_images = len(source_files)
    print(f"Found {num_images} images.")

    # Load per-image captions, or fall back to the single prompt for all images
    if caption_file:
        captions = data_utils.load_captions_from_file(caption_file, user_prompt="")
        if not captions:
            sys.exit("Error: No captions were loaded from the provided caption file.")
        if len(captions) != num_images:
            print("Warning: Number of captions does not match number of images.")
        prompts_in_order = [captions.get(os.path.basename(f), "") for f in source_files]
    else:
        # Use the single prompt for all images
        prompts_in_order = [prompt for _ in range(num_images)]
    if any(p == "" for p in prompts_in_order):
        print("Warning: Some images have no caption or the prompt is empty.")
    config["prompt"] = prompts_in_order

    model, inp_kwargs = model_loader.load_model(config, device=devices)
    if mask_file and config["degradation"]["name"] == "Inpainting":
        config["degradation"]["kwargs"]["mask"] = mask_file
    posterior_model = var_post_samp.VariationalPosterior(model, config)

    guidance_img_iterator = data_utils.yield_images(
        target_file, size=config["resolution"]
    )
    for idx, guidance_img in tqdm.tqdm(enumerate(guidance_img_iterator), total=num_images):
        guidance_img = guidance_img.to(dtype).to(primary_device)
        # Apply the degradation operator to obtain the measurement y
        y = posterior_model.forward_operator(guidance_img)
        tic = time.time()
        with torch.no_grad():
            result_dict = posterior_model.forward(y, inp_kwargs[idx])
            x_hat = result_dict["x_hat"]
        toc = time.time()
        print(f"Runtime: {toc - tic:.2f} s")

        result_file = output_folders[0].format(idx)
        input_file = output_folders[1].format(idx)
        ground_truth_file = output_folders[2].format(idx)

        # Save the reconstruction, mapping from [-1, 1] to [0, 1]
        x_hat_pil = torchvision.transforms.ToPILImage()(
            x_hat.float()[0].clip(-1, 1) * 0.5 + 0.5
        )
        x_hat_pil.save(result_file)

        # Save the degraded input: upsample for SuperRes, otherwise pseudo-invert
        try:
            if config["degradation"]["name"] == "SuperRes":
                input_img = posterior_model.forward_operator.nn(y)
            else:
                input_img = posterior_model.forward_operator.pseudo_inv(y)
            input_img_pil = torchvision.transforms.ToPILImage()(
                input_img.float()[0].clip(-1, 1) * 0.5 + 0.5
            )
            input_img_pil.save(input_file)
        except Exception:
            print("Error in pseudo-inverse operation. Skipping input image save.")

        # Save the ground-truth guidance image
        guidance_img_pil = torchvision.transforms.ToPILImage()(
            guidance_img.float()[0] * 0.5 + 0.5
        )
        guidance_img_pil.save(ground_truth_file)

if __name__ == "__main__":
    main()
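
# Example invocation (illustrative only; the option names follow the function
# parameters above, and the paths are placeholders rather than files from the repo):
#
#   python run_image_inv.py --config config.yaml \
#       --target_file ./images --result_folder ./results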