# from utils.args import parse_args
import logging
import os
import argparse
from pathlib import Path
from PIL import Image

import numpy as np
import torch
from tqdm.auto import tqdm
from diffusers.utils import check_min_version

from pipeline import LotusGPipeline, LotusDPipeline
from utils.image_utils import colorize_depth_map
from utils.seed_all import seed_all

from contextlib import nullcontext
import cv2
from tqdm import tqdm  # 添加这一行以导入 tqdm

# Version guard from diffusers.utils: meant to stop the script up front when the
# installed diffusers is older than 0.28.0.dev0, which this code was written against.
check_min_version('0.28.0.dev0')

def infer_pipe(pipe, image_input, task_name, seed, device):
    """Run a single-step Lotus prediction on one image file.

    Args:
        pipe: A loaded Lotus pipeline already moved to ``device``.
        image_input: Anything ``PIL.Image.open`` accepts (path or file object).
        task_name: ``'depth'`` averages the prediction's channels into a
            single grayscale plane; any other value keeps the channels as-is.
        seed: Optional integer seed for the pipeline's generator; ``None``
            leaves sampling unseeded.
        device: Device the input tensor, task embedding and generator use.

    Returns:
        PIL.Image.Image: the prediction scaled to 8-bit (grayscale for depth).
    """
    generator = None if seed is None else torch.Generator(device=device).manual_seed(seed)

    # Autocast is skipped when MPS is available; otherwise use the pipe's device type.
    if torch.backends.mps.is_available():
        autocast_ctx = nullcontext()
    else:
        autocast_ctx = torch.autocast(pipe.device.type)

    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad(), autocast_ctx:
        # Scope the file handle; convert() fully loads the pixels before closing.
        with Image.open(image_input) as img:
            rgb = np.array(img.convert('RGB')).astype(np.float16)
        # HxWxC in [0, 255] -> 1xCxHxW in [-1, 1].
        test_image = torch.tensor(rgb).permute(2, 0, 1).unsqueeze(0)
        test_image = (test_image / 127.5 - 1.0).to(device)

        # Fixed task embedding: [sin(1), sin(0), cos(1), cos(0)], shape (1, 4).
        task_emb = torch.tensor([1, 0]).float().unsqueeze(0).to(device)
        task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)

        pred = pipe(
            rgb_in=test_image,
            prompt='',
            num_inference_steps=1,
            generator=generator,
            output_type='np',
            timesteps=[999],
            task_emb=task_emb,
        ).images[0]

    if task_name == 'depth':
        # Collapse the channels into one plane -> grayscale output.
        output_npy = pred.mean(axis=-1)
    else:
        output_npy = pred
    # Clip before the uint8 cast: out-of-range floats would otherwise wrap
    # (or be undefined for negatives). A 2-D uint8 array becomes mode 'L'.
    output_color = Image.fromarray((np.clip(output_npy, 0.0, 1.0) * 255).astype(np.uint8))

    return output_color

def lotus_video(input_video, task_name, seed, device):
    """Run the Lotus generative (G) model on every frame of a video.

    Args:
        input_video: Source accepted by ``cv2.VideoCapture`` (e.g. a file path).
        task_name: ``'depth'`` selects the depth checkpoint and grayscale
            output; any other value selects the normal checkpoint.
        seed: Optional integer seed; a single generator is shared by all frames.
        device: Device the pipeline, inputs and generator are placed on.

    Returns:
        list[PIL.Image.Image]: one 8-bit prediction per decoded frame, in order.

    Raises:
        ValueError: if OpenCV cannot open ``input_video``.
    """
    # Decode the video first so a bad input fails before the model is loaded.
    cap = cv2.VideoCapture(input_video)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Could not open video: {input_video}")
    frames = []
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes to BGR; the model is fed RGB elsewhere (PIL .convert('RGB')).
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        cap.release()
    logging.info(f"There are {len(frames)} frames in the video.")

    if task_name == 'depth':
        model_g = 'jingheya/lotus-depth-g-v1-0'
    else:
        model_g = 'jingheya/lotus-normal-g-v1-0'

    dtype = torch.float16
    pipe_g = LotusGPipeline.from_pretrained(
        model_g,
        torch_dtype=dtype,
    )
    pipe_g.to(device)
    pipe_g.set_progress_bar_config(disable=True)
    logging.info(f"Successfully loading pipeline from {model_g}.")

    generator = None if seed is None else torch.Generator(device=device).manual_seed(seed)

    # Fixed task embedding: [sin(1), sin(0), cos(1), cos(0)], shape (1, 4).
    task_emb = torch.tensor([1, 0]).float().unsqueeze(0).to(device)
    task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)

    # Autocast is skipped when MPS is available; chosen once for the whole loop.
    if torch.backends.mps.is_available():
        autocast_ctx = nullcontext()
    else:
        autocast_ctx = torch.autocast(pipe_g.device.type)

    output_g = []
    with torch.no_grad(), autocast_ctx:
        for frame in tqdm(frames, desc="Processing frames"):
            # HxWxC in [0, 255] -> 1xCxHxW in [-1, 1].
            test_image = torch.tensor(frame.astype(np.float16)).permute(2, 0, 1).unsqueeze(0)
            test_image = (test_image / 127.5 - 1.0).to(device)

            pred_g = pipe_g(
                rgb_in=test_image,
                prompt='',
                num_inference_steps=1,
                generator=generator,
                output_type='np',
                timesteps=[999],
                task_emb=task_emb,
            ).images[0]

            # Depth: average channels into a grayscale plane.
            output_npy = pred_g.mean(axis=-1) if task_name == 'depth' else pred_g
            # Clip before the uint8 cast so out-of-range values cannot wrap.
            output_g.append(
                Image.fromarray((np.clip(output_npy, 0.0, 1.0) * 255).astype(np.uint8))
            )

    return output_g

def lotus(image_input, task_name, seed, device):
    """Predict depth or surface normals for one image with the Lotus G model.

    The checkpoint is picked from ``task_name`` (``'depth'`` -> depth model,
    anything else -> normal model). It is loaded in float16 and moved to
    ``device``, and the image is processed by ``infer_pipe``.

    Returns:
        The image produced by ``infer_pipe`` for ``image_input``.
    """
    model_g = (
        'jingheya/lotus-depth-g-v1-0'
        if task_name == 'depth'
        else 'jingheya/lotus-normal-g-v1-0'
    )

    pipe_g = LotusGPipeline.from_pretrained(model_g, torch_dtype=torch.float16)
    pipe_g.to(device)
    pipe_g.set_progress_bar_config(disable=True)
    logging.info(f"Successfully loading pipeline from {model_g}.")

    return infer_pipe(pipe_g, image_input, task_name, seed, device)

def parse_args():
    """Define the command-line interface of the inference script and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding the model settings
        (pretrained_model_name_or_path, prediction_type, timestep, mode,
        task_name, enable_xformers_memory_efficient_attention) and the
        inference settings (seed, output_dir, input_dir, half_precision).
    """
    parser = argparse.ArgumentParser(description="Run Lotus...")

    # ---- model settings ----
    parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
                        help="pretrained model path from hugging face or local dir")
    parser.add_argument("--prediction_type", type=str, default="sample",
                        help="The used prediction_type. ")
    parser.add_argument("--timestep", type=int, default=999)
    # Expected values: "regression" (default) or "generation".
    parser.add_argument("--mode", type=str, default="regression",
                        help="Whether to use the generation or regression pipeline.")
    # Expected values: "depth" (default) or "normal".
    parser.add_argument("--task_name", type=str, default="depth")
    parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true",
                        help="Whether or not to use xformers.")

    # ---- inference settings ----
    parser.add_argument("--seed", type=int, default=None, help="Random seed.")
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory.")
    parser.add_argument("--input_dir", type=str, required=True, help="Input directory.")
    parser.add_argument("--half_precision", action="store_true",
                        help="Run with half-precision (16-bit float), might lead to suboptimal result.")

    return parser.parse_args()

def main():
    """Batch inference entry point: run Lotus over every image under --input_dir.

    Every ``*.png`` / ``*.jpg`` found recursively (sorted) is run through the
    selected pipeline. An 8-bit visualization is written to
    ``<output_dir>/<task_name>_vis/<stem>.png`` and the raw prediction array
    to ``<output_dir>/<task_name>/<stem>.npy``.

    Raises:
        ValueError: if ``--mode`` is neither 'generation' nor 'regression'.
    """
    logging.basicConfig(level=logging.INFO)
    logging.info(f"Run inference...")

    args = parse_args()

    # -------------------- Preparation --------------------
    # Random seed
    if args.seed is not None:
        seed_all(args.seed)

    # Output directories
    os.makedirs(args.output_dir, exist_ok=True)
    logging.info(f"Output dir = {args.output_dir}")

    output_dir_color = os.path.join(args.output_dir, f'{args.task_name}_vis')
    output_dir_npy = os.path.join(args.output_dir, f'{args.task_name}')
    os.makedirs(output_dir_color, exist_ok=True)
    os.makedirs(output_dir_npy, exist_ok=True)

    # Full precision by default; --half_precision opts into fp16.
    if args.half_precision:
        dtype = torch.float16
        logging.info(f"Running with half precision ({dtype}).")
    else:
        dtype = torch.float32

    # -------------------- Device --------------------
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
        logging.warning("CUDA is not available. Running on CPU will be slow.")
    logging.info(f"Device = {device}")

    # -------------------- Data --------------------
    root_dir = Path(args.input_dir)
    test_images = sorted(list(root_dir.rglob('*.png')) + list(root_dir.rglob('*.jpg')))
    print('==> There are', len(test_images), 'images for validation.')

    # -------------------- Model --------------------
    if args.mode == 'generation':
        pipeline_cls = LotusGPipeline
    elif args.mode == 'regression':
        pipeline_cls = LotusDPipeline
    else:
        raise ValueError(f'Invalid mode: {args.mode}')
    pipeline = pipeline_cls.from_pretrained(
        args.pretrained_model_name_or_path,
        torch_dtype=dtype,
    )
    logging.info(f"Successfully loading pipeline from {args.pretrained_model_name_or_path}.")

    pipeline = pipeline.to(device)
    pipeline.set_progress_bar_config(disable=True)

    if args.enable_xformers_memory_efficient_attention:
        pipeline.enable_xformers_memory_efficient_attention()

    generator = None if args.seed is None else torch.Generator(device=device).manual_seed(args.seed)

    # Task embedding is identical for every image: [sin(1), sin(0), cos(1), cos(0)], shape (1, 4).
    task_emb = torch.tensor([1, 0]).float().unsqueeze(0).to(device)
    task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)

    # -------------------- Inference and saving --------------------
    with torch.no_grad():
        for image_path in tqdm(test_images):
            # Preprocess: HxWxC in [0, 255] -> 1xCxHxW in [-1, 1], in the
            # pipeline's dtype so fp32 weights are not fed an fp16 input.
            with Image.open(image_path) as img:
                rgb = np.array(img.convert('RGB'))
            test_image = torch.tensor(rgb, dtype=dtype).permute(2, 0, 1).unsqueeze(0)
            test_image = (test_image / 127.5 - 1.0).to(device)

            pred = pipeline(
                rgb_in=test_image,
                prompt='',
                num_inference_steps=1,
                generator=generator,
                output_type='np',
                timesteps=[args.timestep],
                task_emb=task_emb,
            ).images[0]

            # Post-process the prediction. The stem equals the basename minus the
            # 4-char '.png'/'.jpg' suffix matched by the globs above.
            save_file_name = image_path.stem
            if args.task_name == 'depth':
                # Average the channels into one grayscale plane.
                output_npy = pred.mean(axis=-1)
            else:
                output_npy = pred
            # Clip before the uint8 cast so out-of-range values cannot wrap;
            # the raw (unclipped) prediction is what gets saved to .npy.
            output_color = Image.fromarray((np.clip(output_npy, 0.0, 1.0) * 255).astype(np.uint8))

            output_color.save(os.path.join(output_dir_color, f'{save_file_name}.png'))
            np.save(os.path.join(output_dir_npy, f'{save_file_name}.npy'), output_npy)

    print('==> Inference is done. \n==> Results saved to:', args.output_dir)

# Script entry point: run batch inference only when executed directly, not on import
# (the other functions in this module can be imported without side effects).
if __name__ == '__main__':
    main()