"""
File copied from
https://github.com/nicola-decao/diffmask/blob/master/diffmask/optim/lookahead.py
"""

import torch
import torch.optim as optim

from collections import defaultdict
from torch import Tensor
from torch.optim.optimizer import Optimizer
from typing import Callable, Iterable, Optional, Union


_params_type = Union[Iterable[Tensor], Iterable[dict]]


class Lookahead(Optimizer):
    """Lookahead optimizer: https://arxiv.org/abs/1907.08610"""

    # Optimizer.__init__ is intentionally not called; defaults, param_groups and
    # state are wired up manually so they can be shared with the base optimizer.
    # noinspection PyMissingConstructor
    def __init__(self, base_optimizer: Optimizer, alpha: float = 0.5, k: int = 6):
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"Invalid slow update rate: {alpha}")
        if k < 1:
            raise ValueError(f"Invalid lookahead steps: {k}")
        defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
        self.base_optimizer = base_optimizer
        self.param_groups = self.base_optimizer.param_groups
        self.defaults = base_optimizer.defaults
        self.defaults.update(defaults)
        self.state = defaultdict(dict)
        # manually add our defaults to the param groups
        for name, default in defaults.items():
            for group in self.param_groups:
                group.setdefault(name, default)

    def update_slow(self, group: dict):
        """Interpolate the slow weights toward the fast weights and copy them back."""
        for fast_p in group["params"]:
            if fast_p.grad is None:
                # Parameters that never received a gradient are left untouched.
                continue
            param_state = self.state[fast_p]
            if "slow_buffer" not in param_state:
                # Lazily initialize the slow weights from the current fast weights.
                param_state["slow_buffer"] = torch.empty_like(fast_p.data)
                param_state["slow_buffer"].copy_(fast_p.data)
            slow = param_state["slow_buffer"]
            # slow <- slow + alpha * (fast - slow)
            slow.add_(fast_p.data - slow, alpha=group["lookahead_alpha"])
            fast_p.data.copy_(slow)

    def sync_lookahead(self):
        """Apply the slow-weight update to every param group unconditionally."""
        for group in self.param_groups:
            self.update_slow(group)

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        loss = self.base_optimizer.step(closure)
        for group in self.param_groups:
            group["lookahead_step"] += 1
            # Every k fast steps, fold the fast weights into the slow weights.
            if group["lookahead_step"] % group["lookahead_k"] == 0:
                self.update_slow(group)
        return loss

    def state_dict(self) -> dict:
        fast_state_dict = self.base_optimizer.state_dict()
        slow_state = {
            # Replace tensor keys by their ids so the state dict does not hold
            # direct references to the parameters.
            (id(k) if isinstance(k, torch.Tensor) else k): v
            for k, v in self.state.items()
        }
        fast_state = fast_state_dict["state"]
        param_groups = fast_state_dict["param_groups"]
        return {
            "state": fast_state,
            "slow_state": slow_state,
            "param_groups": param_groups,
        }

    def load_state_dict(self, state_dict: dict):
        fast_state_dict = {
            "state": state_dict["state"],
            "param_groups": state_dict["param_groups"],
        }
        self.base_optimizer.load_state_dict(fast_state_dict)

        # Restore the slow state while sharing the param_groups reference with
        # base_optimizer. This is slightly redundant, but it keeps the code short.
        slow_state_new = False
        if "slow_state" not in state_dict:
            print("Loading state_dict from optimizer without Lookahead applied.")
            state_dict["slow_state"] = defaultdict(dict)
            slow_state_new = True
        slow_state_dict = {
            "state": state_dict["slow_state"],
            # param_groups are passed again only so Optimizer.load_state_dict
            # accepts the dict; the shared reference below is what matters.
            "param_groups": state_dict["param_groups"],
        }
        super(Lookahead, self).load_state_dict(slow_state_dict)
        self.param_groups = (
            self.base_optimizer.param_groups
        )  # make both ref same container
        if slow_state_new:
            # reapply defaults to catch missing lookahead specific ones
            for name, default in self.defaults.items():
                for group in self.param_groups:
                    group.setdefault(name, default)


def LookaheadAdam(
    params: _params_type,
    lr: float = 1e-3,
    betas: tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-08,
    weight_decay: float = 0,
    amsgrad: bool = False,
    lalpha: float = 0.5,
    k: int = 6,
):
    return Lookahead(
        optim.Adam(params, lr, betas, eps, weight_decay, amsgrad), lalpha, k
    )


def LookaheadRAdam(
    params: _params_type,
    lr: float = 1e-3,
    betas: tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-8,
    weight_decay: float = 0,
    lalpha: float = 0.5,
    k: int = 6,
):
    return Lookahead(optim.RAdam(params, lr, betas, eps, weight_decay), lalpha, k)


def LookaheadRMSprop(
    params: _params_type,
    lr: float = 1e-2,
    alpha: float = 0.99,
    eps: float = 1e-08,
    weight_decay: float = 0,
    momentum: float = 0,
    centered: bool = False,
    lalpha: float = 0.5,
    k: int = 6,
):
    return Lookahead(
        optim.RMSprop(params, lr, alpha, eps, weight_decay, momentum, centered),
        lalpha,
        k,
    )
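

# The sketch below is illustrative and not part of the upstream diffmask file:
# it wraps a toy linear model in LookaheadAdam, runs a few optimization steps,
# and synchronizes the slow weights before evaluation. The model, data, and
# hyperparameters are arbitrary assumptions made for the example.
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Linear(4, 1)
    optimizer = LookaheadAdam(model.parameters(), lr=1e-3, lalpha=0.5, k=6)

    for _ in range(12):
        x = torch.randn(8, 4)
        target = torch.randn(8, 1)
        loss = nn.functional.mse_loss(model(x), target)
        model.zero_grad()  # clear gradients on the model rather than on the wrapper
        loss.backward()
        optimizer.step()  # the slow weights are blended in on every k-th call

    optimizer.sync_lookahead()  # pull the slow weights into the model before eval

    # state_dict() bundles the fast optimizer state with the slow buffers.
    checkpoint = optimizer.state_dict()
    assert set(checkpoint) == {"state", "slow_state", "param_groups"}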