from __future__ import annotations

from abc import ABC
from typing import Sequence, TypeVar

import attr
import torch
from attr import define

from esm.tokenization import (
    TokenizerCollectionProtocol,
    get_model_tokenizers,
)
from esm.utils import encoding
from esm.utils.constants.models import ESM3_OPEN_SMALL
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.types import (
    FunctionAnnotation,
    PathLike,
    PathOrBuffer,
)


## Basic Types
@define
class ESMProtein:
    """Raw (untokenized) representation of a protein.

    Every track is optional. ``coordinates`` holds atom37 positions: it is
    built from ``ProteinChain.atom37_positions`` and handed back to
    ``ProteinChain.from_atom37`` when converting to a structure.
    """

    # Tracks
    sequence: str | None = None
    secondary_structure: str | None = None
    sasa: list[float | str | None] | None = None
    function_annotations: list[FunctionAnnotation] | None = None
    coordinates: torch.Tensor | None = None
    # Metrics
    plddt: torch.Tensor | None = None
    ptm: torch.Tensor | None = None

    def __len__(self) -> int:
        """Return the protein length, taken from the first populated track.

        Tracks are consulted in order: sequence, secondary structure, SASA,
        then the first dimension of ``coordinates``.

        Raises:
            ValueError: If none of those tracks is populated.
        """
        if self.sequence is not None:
            return len(self.sequence)
        elif self.secondary_structure is not None:
            return len(self.secondary_structure)
        elif self.sasa is not None:
            return len(self.sasa)
        elif self.coordinates is not None:
            return self.coordinates.size(0)
        else:
            raise ValueError("No track to determine length from.")

    @classmethod
    def from_pdb(
        cls,
        path: PathOrBuffer,
        chain_id: str = "detect",
        id: str | None = None,
        is_predicted: bool = False,
        with_annotations: bool = False,
    ) -> ESMProtein:
        """Load a single chain from a PDB file or buffer.

        Args:
            path: PDB file path or readable buffer.
            chain_id: Chain to load; forwarded to ``ProteinChain.from_pdb``.
            id: Optional identifier forwarded to ``ProteinChain.from_pdb``.
            is_predicted: Forwarded to ``ProteinChain.from_pdb``.
            with_annotations: If True, also compute DSSP secondary structure
                and SASA (see ``from_protein_chain``). Defaults to False,
                which matches the previous behavior of this method.
        """
        protein_chain = ProteinChain.from_pdb(
            path=path, chain_id=chain_id, id=id, is_predicted=is_predicted
        )
        return cls.from_protein_chain(
            protein_chain, with_annotations=with_annotations
        )

    @classmethod
    def from_protein_chain(
        cls, protein_chain: ProteinChain, with_annotations: bool = False
    ) -> ESMProtein:
        """Build an instance from a ``ProteinChain``.

        ``sequence`` and ``coordinates`` (from ``atom37_positions``) are always
        copied over. ``function_annotations``, ``plddt`` and ``ptm`` are left
        unset.

        Args:
            protein_chain: Source chain.
            with_annotations: If True, also populate ``secondary_structure``
                (via ``protein_chain.dssp()``) and ``sasa`` (via
                ``protein_chain.sasa()``).
        """
        # By default, we don't annotate with DSSP / SASA, which are expensive.
        # If mkdssp is installed, we can annotate with a flag.
        secondary_structure = None
        sasa = None
        if with_annotations:
            # NOTE(review): .tolist() yields a list, while the field is
            # annotated `str | None` -- confirm the intended type.
            secondary_structure = protein_chain.dssp().tolist()
            sasa = protein_chain.sasa().tolist()

        # Use `cls` so subclasses get instances of their own type back.
        return cls(
            sequence=protein_chain.sequence,
            secondary_structure=secondary_structure,
            sasa=sasa,
            function_annotations=None,
            coordinates=torch.tensor(protein_chain.atom37_positions),
        )

    def to_pdb(self, pdb_path: PathLike) -> None:
        """Write this protein to ``pdb_path`` via ``to_protein_chain``.

        Raises:
            ValueError: If ``coordinates`` is not set.
        """
        protein_chain = self.to_protein_chain()
        protein_chain.to_pdb(pdb_path)

    def to_pdb_string(self) -> str:
        """Return this protein rendered as PDB text via ``to_protein_chain``.

        Raises:
            ValueError: If ``coordinates`` is not set.
        """
        protein_chain = self.to_protein_chain()
        return protein_chain.to_pdb_string()

    def to_protein_chain(self) -> ProteinChain:
        """Convert to a ``ProteinChain`` using the atom37 coordinates.

        ``sequence`` is forwarded as-is and ``plddt`` (when set) is passed as
        the chain's ``confidence``.

        Raises:
            ValueError: If ``coordinates`` is not set.
        """
        if self.coordinates is None:
            raise ValueError("Coordinates are required to convert to a ProteinChain.")
        # detach() so tensors that require grad (e.g. straight from a model
        # forward) can be converted; matches how `plddt` is handled below.
        protein_chain = ProteinChain.from_atom37(
            atom37_positions=self.coordinates.detach().cpu().numpy(),
            id=None,
            sequence=self.sequence,
            chain_id=None,
            entity_id=None,
            residue_index=None,
            insertion_code=None,
            confidence=None if self.plddt is None else self.plddt.detach().cpu().numpy(),
        )
        return protein_chain


@define
class ESMProteinTensor:
    """Tokenized representation of a protein, one tensor per track.

    Every track is optional. ``device`` and ``to`` both operate over the same
    set of tracks: the attrs fields declared on ``ESMProteinTensor``.
    """

    sequence: torch.Tensor | None = None
    structure: torch.Tensor | None = None
    secondary_structure: torch.Tensor | None = None
    sasa: torch.Tensor | None = None
    function: torch.Tensor | None = None
    residue_annotations: torch.Tensor | None = None
    coordinates: torch.Tensor | None = None

    def __len__(self) -> int:
        """Return ``size(0)`` of the first populated track.

        Tracks are consulted in order: sequence, structure, secondary
        structure, SASA, coordinates. ``function`` and ``residue_annotations``
        are not consulted.

        Raises:
            ValueError: If none of those tracks is populated.
        """
        if self.sequence is not None:
            return self.sequence.size(0)
        elif self.structure is not None:
            return self.structure.size(0)
        elif self.secondary_structure is not None:
            return self.secondary_structure.size(0)
        elif self.sasa is not None:
            return self.sasa.size(0)
        elif self.coordinates is not None:
            return self.coordinates.size(0)
        else:
            raise ValueError("No track to determine length from.")

    @property
    def device(self) -> str | torch.device:
        """Device shared by all populated tracks.

        Raises:
            ValueError: If two populated tracks live on different devices, or
                if no track is populated.
        """
        device_: torch.device | None = None

        for field in attr.fields(ESMProteinTensor):
            current_track: torch.Tensor | None = getattr(self, field.name)
            if current_track is None:
                continue
            if device_ is not None and device_ != current_track.device:
                raise ValueError(f"Inconsistent devices for track {field.name}.")
            device_ = current_track.device

        if device_ is None:
            raise ValueError("No track to determine device from.")

        return device_

    def to(self, device: str | torch.device | None) -> ESMProteinTensor:
        """Move every populated track to ``device`` **in place**.

        Unlike ``torch.Tensor.to``, this mutates ``self`` and returns it (for
        chaining). ``device=None`` is a no-op.
        """
        if device is None:
            return self

        device = torch.device(device)

        # Same track list as `device`, so the two cannot drift apart.
        for field in attr.fields(ESMProteinTensor):
            current_track = getattr(self, field.name)
            if current_track is not None:
                setattr(self, field.name, current_track.to(device))

        return self

    @classmethod
    def empty(
        cls,
        length: int,
        tokenizers: TokenizerCollectionProtocol | None = None,
        device: torch.device | str = "cpu",
    ) -> ESMProteinTensor:
        """Build a tensor whose tracks hold the ``encoding`` default tokens.

        Every track except ``coordinates`` is filled using the
        ``encoding.get_default_*_tokens`` helpers for ``length``, then moved
        to ``device``. ``coordinates`` is left as None.

        Args:
            length: Protein length passed to the default-token helpers.
            tokenizers: Tokenizers to use; defaults to the ones for
                ``ESM3_OPEN_SMALL``.
            device: Device for the resulting tensors.
        """
        if tokenizers is None:
            tokenizers = get_model_tokenizers(ESM3_OPEN_SMALL)

        # Use `cls` so subclasses get instances of their own type back.
        return cls(
            sequence=encoding.get_default_sequence_tokens(
                length, tokenizers.sequence
            ).to(device),
            structure=encoding.get_default_structure_tokens(
                length, tokenizers.structure
            ).to(device),
            secondary_structure=encoding.get_default_secondary_structure_tokens(
                length, tokenizers.secondary_structure
            ).to(device),
            sasa=encoding.get_default_sasa_tokens(length, tokenizers.sasa).to(device),
            function=encoding.get_default_function_tokens(
                length, tokenizers.function
            ).to(device),
            residue_annotations=encoding.get_default_residue_annotation_tokens(
                length, tokenizers.residue_annotations
            ).to(device),
        )


## High Level Endpoint Types
@define
class GenerationConfig:
    """Settings consumed by ``ESM3InferenceClient.generate``."""

    # Track that generate() should fill out completely.
    track: str = ""
    # Token ids excluded from sampling. A bare `[]` default would be one list
    # object shared (and mutable) across every instance, so build a fresh
    # list per instance instead.
    invalid_ids: Sequence[int] = attr.field(factory=list)
    schedule: str = "cosine"
    num_steps: int = 8
    temperature: float = 1.0
    top_p: float = 1.0
    condition_on_coordinates_only: bool = True


## Low Level Endpoint Types
@define
class SamplingTrackConfig:
    """Per-track sampling settings used inside ``SamplingConfig``."""

    temperature: float = 1.0
    top_p: float = 1.0
    only_sample_masked_tokens: bool = True
    # Token ids excluded from sampling. Use a per-instance factory: a bare
    # `[]` default would be shared by every SamplingTrackConfig.
    invalid_ids: Sequence[int] = attr.field(factory=list)
    topk_logprobs: int = 0


@define
class SamplingConfig:
    """Sampling options for ``forward_and_sample``.

    Each per-track entry is optional (None by default); the two flags select
    which embeddings are returned alongside the samples.
    """

    # Per-track sampling settings.
    sequence: SamplingTrackConfig | None = attr.field(default=None)
    structure: SamplingTrackConfig | None = attr.field(default=None)
    secondary_structure: SamplingTrackConfig | None = attr.field(default=None)
    sasa: SamplingTrackConfig | None = attr.field(default=None)
    function: SamplingTrackConfig | None = attr.field(default=None)

    # Embedding outputs.
    return_per_residue_embeddings: bool = attr.field(default=False)
    return_mean_embedding: bool = attr.field(default=False)


@define
class ReturnLogitsConfig:
    """Per-track switches selecting which logits a forward pass returns.

    Every switch is off by default.
    """

    sequence: bool = attr.field(default=False)
    structure: bool = attr.field(default=False)
    secondary_structure: bool = attr.field(default=False)
    sasa: bool = attr.field(default=False)
    function: bool = attr.field(default=False)
    residue_annotations: bool = attr.field(default=False)


@define
class ForwardConfig:
    """Options for ``ESM3InferenceClient._forward``."""

    # A per-instance factory: a single `ReturnLogitsConfig()` default would be
    # one mutable object shared by every ForwardConfig, so toggling a flag on
    # one config would silently change all the others.
    return_logits: ReturnLogitsConfig = attr.field(factory=ReturnLogitsConfig)
    return_embeddings: bool = False


@define
class ForwardTrackData:
    """Container holding one optional tensor per track.

    Used for per-track logits and per-track sampling statistics.
    """

    sequence: torch.Tensor | None = attr.field(default=None)
    structure: torch.Tensor | None = attr.field(default=None)
    secondary_structure: torch.Tensor | None = attr.field(default=None)
    sasa: torch.Tensor | None = attr.field(default=None)
    function: torch.Tensor | None = attr.field(default=None)


@define
class ForwardOutput:
    """Result of a raw model forward pass."""

    logits: ForwardTrackData | None = attr.field(default=None)
    embeddings: torch.Tensor | None = attr.field(default=None)

    # Residue annotations are multi-hot, so their logits are kept apart from
    # `logits`: each entry is an independent Bernoulli rather than one
    # categorical distribution, which means a softmax over the last dimension
    # would be _wrong_.
    residue_annotation_logits: torch.Tensor | None = attr.field(default=None)


@define
class ForwardAndSampleOutput(ForwardOutput):
    """Result of ``forward_and_sample``: forward outputs plus sampled tokens."""

    # A per-instance factory: a single shared `ESMProteinTensor()` default
    # would be mutated for every output at once, e.g. by
    # `ESMProteinTensor.to`, which reassigns its tracks in place.
    protein_tensor: ESMProteinTensor = attr.field(factory=ESMProteinTensor)

    entropy: ForwardTrackData | None = None
    # Probability of sampled token
    prob: ForwardTrackData | None = None
    logprob: ForwardTrackData | None = None
    # Top probability at this position
    top_prob: ForwardTrackData | None = None
    topk_logprob: ForwardTrackData | None = None
    # Which tokens correspond to top probability
    topk_tokens: ForwardTrackData | None = None

    per_residue_embedding: torch.Tensor | None = None
    mean_embedding: torch.Tensor | None = None


# Either protein representation. Used by ESM3InferenceClient.generate so that
# its return type matches the type of its `input` argument.
ProteinType = TypeVar(
    "ProteinType",
    bound=ESMProteinTensor | ESMProtein,
)


class ESM3InferenceClient(ABC):
    """Interface for running ESM3 inference.

    Every method on this base class raises ``NotImplementedError``; concrete
    clients override the operations they support.
    """

    def generate(self, input: ProteinType, config: GenerationConfig) -> ProteinType:
        """Iteratively sample until the requested track is fully generated.

        This is the easiest and most flexible way to run ESM3. Generate
        iteratively samples tokens and provides an output with the track
        specified in ``config`` completely filled out, according to the
        GenerationConfig provided. It is a local function wrapping calls for
        encode -> iterative sampling -> decode. If an ESMProteinTensor is
        provided, encode and decode are skipped.
        """
        raise NotImplementedError

    def encode(self, input: ESMProtein) -> ESMProteinTensor:
        """Tokenize a raw ESMProtein into an ESMProteinTensor.

        Implementations run the structure token encoder and also deal with
        PDB => atom37 conversion.
        """
        raise NotImplementedError

    def decode(self, input: ESMProteinTensor) -> ESMProtein:
        """Inverse of ``encode``.

        Implementations run a structure token decoder to output coordinates.
        """
        raise NotImplementedError

    def _forward(
        self, input: ESMProteinTensor, config: ForwardConfig = ForwardConfig()
    ) -> ForwardOutput:
        """Run a raw model forward pass.

        Our API generally discourages raw forwards, because sending logits
        can be prohibitively expensive. Please use ``forward_and_sample``
        instead.
        """
        raise NotImplementedError

    def forward_and_sample(
        self, input: ESMProteinTensor, sampling_configuration: SamplingConfig
    ) -> ForwardAndSampleOutput:
        """Run a single model forward and sample tokens per ``SamplingConfig``.

        This is the entry point for power users of ESM3. It is intended to
        enable high-throughput inference, as well as arbitrary
        chain-of-thought invocations of ESM3.
        """
        raise NotImplementedError