import os
import kenlm
import sentencepiece as spm
from tokenizers import normalizers


class KenlmModel:
    """Scores text with a KenLM n-gram model over SentencePiece tokens,
    after normalizing digits, punctuation, and Unicode form."""

    def __init__(
        self,
        vocabulary_size: str,
        ngram: str,
        pruning: str,
        normalize_nfd: bool = True,
        normalize_numbers: bool = True,
        normalize_punctuation: bool = True,
    ):
        # Load the KenLM model and the matching SentencePiece tokenizer
        # from the "files" directory.
        self.model = kenlm.Model(os.path.join(
            "files", f"jomleh-sp-{vocabulary_size}-o{ngram}-prune{pruning}.probing"))
        self.tokenizer = spm.SentencePieceProcessor(os.path.join(
            "files", f"jomleh-sp-{vocabulary_size}.model"))

        # Build the normalization pipeline so that surface variation
        # (digit choice, punctuation, Unicode form) does not affect the
        # language-model score.
        norm_list = []
        if normalize_numbers:
            # Map every Persian digit to "۰" so numbers of equal length
            # share one representation.
            norm_list += [normalizers.Replace(digit, "۰")
                          for digit in "۱۲۳۴۵۶۷۸۹"]
        if normalize_punctuation:
            # Strip Latin and Persian sentence punctuation.
            norm_list += [normalizers.Replace(mark, "")
                          for mark in [".", "!", "؛", "،", "؟"]]
        if normalize_nfd:
            norm_list += [normalizers.NFD()]
        norm_list += [normalizers.Strip()]

        self.normalizer = normalizers.Sequence(norm_list)

    @classmethod
    def from_pretrained(
        cls,
        vocabulary_size: str,
        ngram: str,
        pruning: str,
    ):
        """Construct a model from the pretrained files named by these parameters."""
        return cls(vocabulary_size, ngram, pruning)

    def score(self, doc: str):
        """Return the total base-10 log probability of the document."""
        doc = self.normalizer.normalize_str(doc)
        doc = ' '.join(self.tokenizer.encode(doc, out_type=str))
        return self.model.score(doc)

    def perplexity(self, doc: str):
        """Return the per-token perplexity of the document."""
        doc = self.normalizer.normalize_str(doc)
        doc = ' '.join(self.tokenizer.encode(doc, out_type=str))
        log_score = self.model.score(doc)
        # Add one to the token count for the implicit </s> end-of-sentence
        # token that KenLM includes when scoring.
        length = len(doc.split()) + 1
        return round(10.0 ** (-log_score / length), 1)
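
# The perplexity method above inverts KenLM's summed base-10 log score.
# A minimal standalone sketch of that arithmetic, using a made-up log
# score and token count rather than a real model, so it runs without the
# kenlm library or any model files:
#
#     def perplexity_from_log10(log_score: float, num_tokens: int) -> float:
#         """Convert a total base-10 log probability into per-token perplexity.
#
#         Divides by the token count plus one (for the implicit </s> token)
#         and exponentiates, mirroring KenlmModel.perplexity.
#         """
#         length = num_tokens + 1
#         return round(10.0 ** (-log_score / length), 1)
#
#     # A hypothetical 9-token sentence with total log10 score -20 gives
#     # 10 ** (20 / 10) = 100.0.
#     perplexity_from_log10(-20.0, 9)  # → 100.0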