# General
from os.path import join as opj
import datetime
import torch
from einops import rearrange

# Utilities
from videogen_hub.pipelines.streamingt2v.inference_utils import (
    add_margin,
    center_crop,
    pad_to_fit,
    resize_and_keep,
    resize_to_fit,
)

from modelscope.outputs import OutputKeys
import imageio
from PIL import Image
import numpy as np

import torch.nn.functional as F
import torchvision.transforms as transforms
from diffusers.utils import load_image

# Convert PIL images to uint8 tensors of shape (C, H, W) without rescaling.
transform = transforms.Compose([transforms.PILToTensor()])


def ms_short_gen(prompt, ms_model, inference_generator, t=50, device="cuda"):
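    """Generate a short 256x256 text-to-video clip with ModelScope.

    Returns the frames as a float tensor of shape (F, C, H, W).
    """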
    frames = ms_model(
        prompt,
        num_inference_steps=t,
        generator=inference_generator,
        eta=1.0,
        height=256,
        width=256,
        latents=None,
    ).frames
    frames = torch.stack([torch.from_numpy(frame) for frame in frames])
    frames = frames.to(device).to(torch.float32)
    # Frames come back as (F, H, W, C); reorder to (F, C, H, W).
    return rearrange(frames[0], "F H W C -> F C H W")


def ad_short_gen(prompt, ad_model, inference_generator, t=25, device="cuda"):
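    """Generate a 16-frame clip with AnimateDiff, resized to 256x256 and scaled to [0, 1]."""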
    frames = ad_model(
        prompt,
        negative_prompt="bad quality, worse quality",
        num_frames=16,
        num_inference_steps=t,
        generator=inference_generator,
        guidance_scale=7.5,
    ).frames[0]
    frames = torch.stack([transform(frame) for frame in frames])
    frames = frames.to(device).to(torch.float32)
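    # Resize frames to 256x256 and rescale pixel values from [0, 255] to [0, 1].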
    frames = F.interpolate(frames, size=256)
    frames = frames / 255.0
    return frames


def sdxl_image_gen(prompt, sdxl_model):
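    """Generate a single image from a text prompt with SDXL."""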
    image = sdxl_model(prompt=prompt).images[0]
    return image


def svd_short_gen(
    image, prompt, svd_model, sdxl_model, inference_generator, t=25, device="cuda"
):
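    """Generate a short image-conditioned clip with Stable Video Diffusion.

    If no image is provided, one is first generated from the prompt with SDXL.
    """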
    if image is None or (isinstance(image, str) and image == ""):
        # No conditioning image given: generate one from the prompt with SDXL,
        # then pad the 576x576 image to 1024x576 with black side margins.
        image = sdxl_image_gen(prompt, sdxl_model)
        image = image.resize((576, 576))
        image = add_margin(image, 0, 224, 0, 224, (0, 0, 0))
    elif isinstance(image, str):
        # Image given as a path or URL: load, crop to square, and pad.
        image = load_image(image)
        image = resize_and_keep(image)
        image = center_crop(image)
        image = add_margin(image, 0, 224, 0, 224, (0, 0, 0))
    else:
        # Image given as an array: convert to PIL, crop to square, and pad.
        image = Image.fromarray(np.uint8(image))
        image = resize_and_keep(image)
        image = center_crop(image)
        image = add_margin(image, 0, 224, 0, 224, (0, 0, 0))

    frames = svd_model(
        image, decode_chunk_size=8, generator=inference_generator
    ).frames[0]
    frames = torch.stack([transform(frame) for frame in frames])
    frames = frames.to(device).to(torch.float32)
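    # Keep the first 16 frames and crop the 224 px side margins, restoring a square frame.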
    frames = frames[:16, :, :, 224:-224]
    frames = F.interpolate(frames, size=256)
    frames = frames / 255.0
    return frames


def stream_long_gen(
    prompt,
    short_video,
    n_autoreg_gen,
    negative_prompt,
    seed,
    t,
    image_guidance,
    result_file_stem,
    stream_cli,
    stream_model,
):
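    """Autoregressively extend the given short video into a long video with StreamingT2V."""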
    trainer = stream_cli.trainer
    # Run a single prediction batch; generation settings are passed via predict_cfg.
    trainer.limit_predict_batches = 1

    trainer.predict_cfg = {
        "predict_dir": stream_cli.config["result_fol"].as_posix(),
        "result_file_stem": result_file_stem,
        "prompt": prompt,
        "video": short_video,
        "seed": seed,
        "num_inference_steps": t,
        "guidance_scale": image_guidance,
        "n_autoregressive_generations": n_autoreg_gen,
    }
    stream_model.inference_params.negative_prompt = negative_prompt
    trainer.predict(model=stream_model, datamodule=stream_cli.datamodule)


def video2video(prompt, video, where_to_log, cfg_v2v, model_v2v, square=True):
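    """Enhance a video with the v2v model: downscale, resize (and optionally pad)
    to the upscale size, then run the enhancement pass."""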
    downscale = cfg_v2v["downscale"]
    upscale_size = cfg_v2v["upscale_size"]
    pad = cfg_v2v["pad"]

    # Build a unique output name from the prompt and the current time.
    now = datetime.datetime.now()
    now = str(now.time()).replace(":", "_").replace(".", "_")
    name = prompt[:100].replace(" ", "_") + "_" + now
    enhanced_video_mp4 = opj(where_to_log, name + "_enhanced.mp4")

    video_frames = imageio.mimread(video)
    h, w, _ = video_frames[0].shape

    # Downscale video, then resize to fit the upscale size
    video = [
        Image.fromarray(frame).resize((w // downscale, h // downscale))
        for frame in video_frames
    ]
    video = [resize_to_fit(frame, upscale_size) for frame in video]

    if pad:
        video = [pad_to_fit(frame, upscale_size) for frame in video]
    video = [np.array(frame) for frame in video]

    imageio.mimsave(opj(where_to_log, "temp_" + now + ".mp4"), video, fps=8)

    p_input = {
        "video_path": opj(where_to_log, "temp_" + now + ".mp4"),
        "text": prompt,
        "positive_prompt": prompt,
        "total_noise_levels": 600,
    }
    model_v2v(p_input, output_video=enhanced_video_mp4)[OutputKeys.OUTPUT_VIDEO]

    # Remove the side padding (280 px per side) from the enhanced frames
    video_frames = imageio.mimread(enhanced_video_mp4)
    video_frames_square = []
    for frame in video_frames:
        frame = frame[:, 280:-280, :]
        video_frames_square.append(frame)
    imageio.mimsave(enhanced_video_mp4, video_frames_square, fps=8)  # keep the clip's 8 fps

    return enhanced_video_mp4


# Main entry point for video-to-video enhancement, processed in overlapping chunks
def video2video_randomized(
    prompt,
    video,
    where_to_log,
    cfg_v2v,
    model_v2v,
    square=True,
    chunk_size=24,
    overlap_size=8,
    negative_prompt="",
):
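    """Enhance a video with the v2v model in overlapping chunks, blending the
    overlaps across chunk boundaries (randomized blending)."""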
    downscale = cfg_v2v["downscale"]
    upscale_size = cfg_v2v["upscale_size"]
    pad = cfg_v2v["pad"]

    now = datetime.datetime.now()
    timestamp = str(now.time()).replace(":", "_").replace(".", "_")
    name = prompt[:100].replace(" ", "_") + "_" + timestamp
    enhanced_video_mp4 = opj(where_to_log, name + "_enhanced.mp4")

    video_frames = imageio.mimread(video)
    h, w, _ = video_frames[0].shape

    # Each chunk after the first covers (chunk_size - overlap_size) new frames;
    # frames beyond the last full chunk are trimmed off below.
    n_chunks = (len(video_frames) - overlap_size) // (chunk_size - overlap_size)
    trim_length = n_chunks * (chunk_size - overlap_size) + overlap_size
    if trim_length < chunk_size:
        raise ValueError(
            f"Chunk size [{chunk_size}] cannot be larger than the number of frames in the video [{len(video_frames)}], please provide smaller chunk size"
        )
    if trim_length < len(video_frames):
        print(
            f"Video cannot be processed with chunk size {chunk_size} and overlap size {overlap_size}, "
            f"trimming it to length {trim_length} to be able to process it"
        )
        video_frames = video_frames[:trim_length]

    model_v2v.chunk_size = chunk_size
    model_v2v.overlap_size = overlap_size

    # Downscale video, then resize to fit the upscale size
    video = [
        Image.fromarray(frame).resize((w // downscale, h // downscale))
        for frame in video_frames
    ]
    video = [resize_to_fit(frame, upscale_size) for frame in video]

    if pad:
        video = [pad_to_fit(frame, upscale_size) for frame in video]

    video = list(map(np.array, video))

    # Write the intermediate clip with a timestamped name so concurrent runs don't collide.
    temp_video_mp4 = opj(where_to_log, "temp_" + timestamp + ".mp4")
    imageio.mimsave(temp_video_mp4, video, fps=8)

    p_input = {
        "video_path": opj(where_to_log, "temp.mp4"),
        "text": prompt,
        "positive_prompt": "",
        "negative_prompt": negative_prompt,
        "total_noise_levels": 600,
    }

    model_v2v(p_input, output_video=enhanced_video_mp4)[OutputKeys.OUTPUT_VIDEO]

    return enhanced_video_mp4