from dataclasses import dataclass
from enum import Enum
from typing import Dict, List
from nlp4web_codebase.ir.data_loaders.dm import Document, Query, QRel


class Split(str, Enum):
    """Names of the dataset splits used as keys of ``IRDataset.split2qrels``.

    Members mix in ``str``, so each one compares equal to its plain string
    value (e.g. ``Split.train == "train"``).

    NOTE: the result of ``format()`` / f-strings on a member depends on the
    Python version ("train" up to 3.10, "Split.train" from 3.11); use
    ``member.value`` when a stable plain string is required.
    """

    train = "train"
    dev = "dev"
    test = "test"


@dataclass
class IRDataset:
    """In-memory container for a retrieval dataset.

    Holds the document collection, the queries, and the relevance
    judgments (qrels) grouped by split.
    """

    corpus: List[Document]
    queries: List[Query]
    split2qrels: Dict[Split, List[QRel]]

    def get_stats(self) -> Dict[str, int]:
        """Return the sizes of the corpus, the query list and each qrels split.

        Returns:
            A dict with the keys ``"|corpus|"`` and ``"|queries|"`` plus one
            ``"|qrels-<split>|"`` entry per key of ``split2qrels``, where
            ``<split>`` is the split's plain string value (e.g. ``train``).
        """
        stats = {"|corpus|": len(self.corpus), "|queries|": len(self.queries)}
        for split, qrels in self.split2qrels.items():
            # Formatting a ``(str, Enum)`` member directly is version
            # dependent: it renders as "train" up to Python 3.10 but as
            # "Split.train" on 3.11+. Using ``.value`` keeps the key stable.
            split_name = split.value if isinstance(split, Enum) else split
            stats[f"|qrels-{split_name}|"] = len(qrels)
        return stats

    def get_qrels_dict(self, split: Split) -> Dict[str, Dict[str, int]]:
        """Index the qrels of ``split`` by query id and then by collection id.

        Args:
            split: Key of ``split2qrels`` whose judgments should be indexed.

        Returns:
            A nested dict ``{query_id: {collection_id: relevance}}``. If the
            same (query_id, collection_id) pair occurs more than once, the
            last occurrence wins.

        Raises:
            KeyError: If ``split`` is not a key of ``split2qrels``.
        """
        qrels_dict: Dict[str, Dict[str, int]] = {}
        for qrel in self.split2qrels[split]:
            qrels_dict.setdefault(qrel.query_id, {})[
                qrel.collection_id
            ] = qrel.relevance
        return qrels_dict

    def get_split_queries(self, split: Split) -> List[Query]:
        """Return the queries that have at least one qrel in ``split``.

        Args:
            split: Key of ``split2qrels`` that selects the judgments.

        Returns:
            The matching elements of ``self.queries``, in their original
            order.

        Raises:
            KeyError: If ``split`` is not a key of ``split2qrels``.
        """
        # Build the id set once so each membership test is O(1).
        qids = {qrel.query_id for qrel in self.split2qrels[split]}
        return [query for query in self.queries if query.query_id in qids]