from dataclasses import dataclass
from enum import Enum
from typing import Dict, List

from nlp4web_codebase.ir.data_loaders.dm import Document, Query, QRel


class Split(str, Enum):
    """Names of the splits that qrels can be grouped under.

    Members are also ``str`` instances, so ``Split.train == "train"``.
    """

    train = "train"
    dev = "dev"
    test = "test"


@dataclass
class IRDataset:
    """An IR dataset: a document corpus, a query set and per-split qrels."""

    # All documents of the collection.
    corpus: List[Document]
    # All queries; a query belongs to a split if it has a qrel in that split.
    queries: List[Query]
    # Relevance judgements grouped by split.
    split2qrels: Dict[Split, List[QRel]]

    def get_stats(self) -> Dict[str, int]:
        """Return the sizes of the corpus, the query set and every qrels split.

        Keys are ``"|corpus|"``, ``"|queries|"`` and one ``"|qrels-<split>|"``
        entry per key of ``split2qrels`` (e.g. ``"|qrels-test|"``).
        """
        stats = {"|corpus|": len(self.corpus), "|queries|": len(self.queries)}
        for split, qrels in self.split2qrels.items():
            # Use the member's value explicitly: since Python 3.11 an f-string of a
            # ``(str, Enum)`` member renders as "Split.train" rather than "train".
            split_name = split.value if isinstance(split, Enum) else str(split)
            stats[f"|qrels-{split_name}|"] = len(qrels)
        return stats

    def get_qrels_dict(self, split: Split) -> Dict[str, Dict[str, int]]:
        """Return the qrels of ``split`` as ``{query_id: {collection_id: relevance}}``.

        If the same (query_id, collection_id) pair occurs more than once, the
        later qrel's relevance overwrites the earlier one.

        Raises:
            KeyError: if ``split`` is not a key of ``split2qrels``.
        """
        qrels_dict: Dict[str, Dict[str, int]] = {}
        for qrel in self.split2qrels[split]:
            qrels_dict.setdefault(qrel.query_id, {})[qrel.collection_id] = qrel.relevance
        return qrels_dict

    def get_split_queries(self, split: Split) -> List[Query]:
        """Return the queries that have at least one qrel in ``split``.

        The returned queries keep their order from ``self.queries``.

        Raises:
            KeyError: if ``split`` is not a key of ``split2qrels``.
        """
        # Build the id set once so each membership test is O(1).
        qids = {qrel.query_id for qrel in self.split2qrels[split]}
        return [query for query in self.queries if query.query_id in qids]