File size: 1,092 Bytes
3835a42
 
 
 
 
 
b731827
3835a42
d544db4
3835a42
 
b731827
 
 
3835a42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import os
from typing import Dict, List, Any
from long_coref.coref.prediction import CorefPredictor
from long_coref.coref.utils import ArchiveContent
from allennlp.common.params import Params

# Name of the checkpoint directory; PreTrainedPipeline joins it onto its
# `path` argument to locate the archive dir, "weights.th" and "config.json".
CHECKPOINT = "coref-spanbert-large-2021.03.10"

class PreTrainedPipeline:
    """Coreference pipeline wrapping a ``CorefPredictor``.

    The predictor is built once in ``__init__`` from the extracted archive
    found in the ``CHECKPOINT`` directory. Calling the instance splits the
    request text into paragraphs on blank lines and runs the predictor on
    them.
    """

    def __init__(self, path=""):
        """Build the predictor from the checkpoint directory under *path*.

        Args:
            path: Directory containing the ``CHECKPOINT`` folder. With the
                default ``""`` the checkpoint path is relative, so it is
                resolved against the current working directory.
        """
        # Every archive file lives under the same directory; build it once.
        archive_dir = os.path.join(path, CHECKPOINT)
        archive_content = ArchiveContent(
            archive_dir=archive_dir,
            weight_path=os.path.join(archive_dir, "weights.th"),
            config=Params.from_file(os.path.join(archive_dir, "config.json")),
        )
        self.predictor = CorefPredictor.from_extracted_archive(archive_content)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run the predictor on the text carried by a request payload.

        Args:
            data: Request payload. Only the ``"inputs"`` key is read; it must
                hold the text as a ``str``, with paragraphs separated by a
                blank line (``"\\n\\n"``). Other keys are ignored, and
                ``data`` itself is left unmodified.

        Returns:
            Whatever ``to_dict()`` returns for the prediction; it will be
            serialized and returned to the client.
            NOTE(review): the annotation says ``List[Dict[str, Any]]`` but
            ``to_dict()`` presumably yields a single dict -- confirm.

        Raises:
            TypeError: If ``data["inputs"]`` is missing or is not a ``str``.
        """
        # Use .get() rather than .pop(): the caller's dict must not be mutated.
        inputs = data.get("inputs")
        if not isinstance(inputs, str):
            # Fail early with a clear message instead of an AttributeError
            # from calling .split() on a dict/None/list.
            raise TypeError(
                f'data["inputs"] must be a str, got {type(inputs).__name__}'
            )
        prediction = self.predictor.resolve_paragraphs(inputs.split("\n\n"))
        return prediction.to_dict()