import json
import os
from typing import Optional, Sequence, Tuple

from src.video_util import get_frame_count


class RerenderConfig:

    def __init__(self):
        ...

    def create_from_parameters(self,
                               input_path: str,
                               output_path: str,
                               prompt: str,
                               work_dir: Optional[str] = None,
                               key_subdir: str = 'keys',
                               frame_count: Optional[int] = None,
                               interval: int = 10,
                               crop: Sequence[int] = (0, 0, 0, 0),
                               sd_model: Optional[str] = None,
                               a_prompt: str = '',
                               n_prompt: str = '',
                               ddim_steps=20,
                               scale=7.5,
                               control_type: str = 'HED',
                               control_strength=1,
                               seed: int = -1,
                               image_resolution: int = 512,
                               x0_strength: float = -1,
                               style_update_freq: int = 10,
                               cross_period: Tuple[float, float] = (0, 1),
                               warp_period: Tuple[float, float] = (0, 0.1),
                               mask_period: Tuple[float, float] = (0.5, 0.8),
                               ada_period: Tuple[float, float] = (1.0, 1.0),
                               mask_strength: float = 0.5,
                               inner_strength: float = 0.9,
                               smooth_boundary: bool = True,
                               color_preserve: bool = True,
                               **kwargs):
        self.input_path = input_path
        self.output_path = output_path
        self.prompt = prompt
        self.work_dir = work_dir
        if work_dir is None:
            self.work_dir = os.path.dirname(output_path)
        self.key_dir = os.path.join(self.work_dir, key_subdir)
        self.first_dir = os.path.join(self.work_dir, 'first')

        # Split video into frames
        if not os.path.isfile(input_path):
            raise FileNotFoundError(f'Cannot find video file {input_path}')
        self.input_dir = os.path.join(self.work_dir, 'video')

        self.frame_count = frame_count
        if frame_count is None:
            self.frame_count = get_frame_count(self.input_path)
        self.interval = interval
        self.crop = crop
        self.sd_model = sd_model
        self.a_prompt = a_prompt
        self.n_prompt = n_prompt
        self.ddim_steps = ddim_steps
        self.scale = scale
        self.control_type = control_type
        if self.control_type == 'canny':
            self.canny_low = kwargs.get('canny_low', 100)
            self.canny_high = kwargs.get('canny_high', 200)
        else:
            self.canny_low = None
            self.canny_high = None
        self.control_strength = control_strength
        self.seed = seed
        self.image_resolution = image_resolution
        self.x0_strength = x0_strength
        self.style_update_freq = style_update_freq
        self.cross_period = cross_period
        self.mask_period = mask_period
        self.warp_period = warp_period
        self.ada_period = ada_period
        self.mask_strength = mask_strength
        self.inner_strength = inner_strength
        self.smooth_boundary = smooth_boundary
        self.color_preserve = color_preserve

        os.makedirs(self.input_dir, exist_ok=True)
        os.makedirs(self.work_dir, exist_ok=True)
        os.makedirs(self.key_dir, exist_ok=True)
        os.makedirs(self.first_dir, exist_ok=True)

    def create_from_path(self, cfg_path: str):
        with open(cfg_path, 'r') as fp:
            cfg = json.load(fp)
        kwargs = dict()

        def append_if_not_none(key):
            value = cfg.get(key, None)
            if value is not None:
                kwargs[key] = value

        kwargs['input_path'] = cfg['input']
        kwargs['output_path'] = cfg['output']
        kwargs['prompt'] = cfg['prompt']
        append_if_not_none('work_dir')
        append_if_not_none('key_subdir')
        append_if_not_none('frame_count')
        append_if_not_none('interval')
        append_if_not_none('crop')
        append_if_not_none('sd_model')
        append_if_not_none('a_prompt')
        append_if_not_none('n_prompt')
        append_if_not_none('ddim_steps')
        append_if_not_none('scale')
        append_if_not_none('control_type')
        if kwargs.get('control_type', '') == 'canny':
            append_if_not_none('canny_low')
            append_if_not_none('canny_high')
        append_if_not_none('control_strength')
        append_if_not_none('seed')
        append_if_not_none('image_resolution')
        append_if_not_none('x0_strength')
        append_if_not_none('style_update_freq')
        append_if_not_none('cross_period')
        append_if_not_none('warp_period')
        append_if_not_none('mask_period')
        append_if_not_none('ada_period')
        append_if_not_none('mask_strength')
        append_if_not_none('inner_strength')
        append_if_not_none('smooth_boundary')
        append_if_not_none('color_perserve')
        self.create_from_parameters(**kwargs)

    @property
    def use_warp(self):
        return self.warp_period[0] <= self.warp_period[1]

    @property
    def use_mask(self):
        return self.mask_period[0] <= self.mask_period[1]

    @property
    def use_ada(self):
        return self.ada_period[0] <= self.ada_period[1]