Mituuu's picture
Upload 11 files
3f002dd verified
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List
from nlp4web_codebase.ir.data_loaders.dm import Document, Query, QRel
class Split(str, Enum):
    """Named dataset partitions.

    Mixing in ``str`` makes each member compare and hash equal to its raw
    value, so ``Split.dev == "dev"`` holds and either can be used as a key.
    """

    train = "train"  # training partition
    dev = "dev"  # development / validation partition
    test = "test"  # held-out evaluation partition
@dataclass
class IRDataset:
    """In-memory IR dataset: a corpus, a query set, and per-split relevance judgements.

    The class only stores the three collections and offers convenience views
    (size statistics, a nested qrels mapping, and per-split query filtering).
    """

    corpus: List[Document]  # every document in the collection
    queries: List[Query]  # every query, regardless of split
    split2qrels: Dict[Split, List[QRel]]  # relevance judgements grouped by split

    def get_stats(self) -> Dict[str, int]:
        """Return the sizes of the corpus, the query set and each qrels split.

        Keys are ``"|corpus|"``, ``"|queries|"`` and ``"|qrels-<split>|"`` for
        every split present in ``split2qrels`` (e.g. ``"|qrels-train|"``).
        """
        stats = {"|corpus|": len(self.corpus), "|queries|": len(self.queries)}
        for split, qrels in self.split2qrels.items():
            # Use the raw value rather than formatting the member directly:
            # since Python 3.11, f"{member}" on a (str, Enum) yields
            # "Split.train" instead of "train", which would change the keys.
            split_name = split.value if isinstance(split, Enum) else split
            stats[f"|qrels-{split_name}|"] = len(qrels)
        return stats

    def get_qrels_dict(self, split: Split) -> Dict[str, Dict[str, int]]:
        """Return the qrels of ``split`` as ``{query_id: {collection_id: relevance}}``.

        If the same (query_id, collection_id) pair occurs more than once, the
        last occurrence wins.

        Raises:
            KeyError: if ``split`` has no entry in ``split2qrels``.
        """
        qrels_dict: Dict[str, Dict[str, int]] = {}
        for qrel in self.split2qrels[split]:
            # setdefault returns the (possibly new) inner dict, so one lookup suffices.
            qrels_dict.setdefault(qrel.query_id, {})[qrel.collection_id] = qrel.relevance
        return qrels_dict

    def get_split_queries(self, split: Split) -> List[Query]:
        """Return the queries that have at least one qrel in ``split``.

        The result preserves the order of ``self.queries``.

        Raises:
            KeyError: if ``split`` has no entry in ``split2qrels``.
        """
        # Build the id set once so each membership test is O(1).
        qids = {qrel.query_id for qrel in self.split2qrels[split]}
        return [query for query in self.queries if query.query_id in qids]