import torch
from transformers import AutoTokenizer, RobertaModel

class EmbeddingModel(torch.nn.Module):
    """
    A basic wrapper around a Hugging Face transformer model.
    Takes a string as input and produces an embedding vector of size d.
    """

    # Maps a model-family key from the config to its transformers model class
    model_classes = {'roberta': RobertaModel}

    def __init__(self, config, **kwargs):
        super().__init__()

        self.model_class = self.model_classes.get(config.get("model_class").lower())
        self.model_name = config.get("model_name")
        self.pooling = config.get("pooling")
        self.normalize = config.get("normalize")
        self.d = config.get("d")
        self.prompt = config.get("prompt")
        self.add_upper = config.get("add_upper")
        self.upper_case = config.get("upper_case")

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        try:
            self.transformer = self.model_class.from_pretrained(self.model_name)
        except OSError:
            # Fall back to converting TensorFlow weights if no PyTorch checkpoint exists
            self.transformer = self.model_class.from_pretrained(self.model_name, from_tf=True)

        self.dropout = torch.nn.Dropout(0.5)

        if self.d:
            # Project embedding to a lower dimension
            # Initialization based on random projection LSH (preserves approximate cosine distances)
            self.projection = torch.nn.Linear(self.transformer.config.hidden_size, self.d)
            torch.nn.init.normal_(self.projection.weight)
            torch.nn.init.constant_(self.projection.bias, 0)

        self.to(config.get("device"))

    def to(self, device):
        super().to(device)
        self.device = device
        return self

    def encode(self, strings):
        if self.prompt is not None:
            strings = [self.prompt + s for s in strings]
        if self.add_upper:
            # Append an upper-cased copy of the string after a separator token
            strings = [s + ' </s> ' + s.upper() for s in strings]
        if self.upper_case:
            # Upper-case the strings themselves
            strings = [s.upper() for s in strings]

        try:
            encoded = self.tokenizer(strings, padding=True, truncation=True)
        except Exception:
            # Print the offending batch before re-raising to aid debugging
            print(strings)
            raise
        input_ids = torch.tensor(encoded['input_ids']).long()
        attention_mask = torch.tensor(encoded['attention_mask'])

        return input_ids, attention_mask

    def forward(self, strings):

        # Tokenization does not require gradients
        with torch.no_grad():
            input_ids, attention_mask = self.encode(strings)

            input_ids = input_ids.to(device=self.device)
            attention_mask = attention_mask.to(device=self.device)

        batch_out = self.transformer(input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     return_dict=True)

        if self.pooling == 'pooler':
            v = batch_out['pooler_output']
        elif self.pooling == 'mean':
            h = batch_out['last_hidden_state']

            # Compute mean of unmasked token vectors
            h = h*attention_mask[:,:,None]
            v = h.sum(dim=1)/attention_mask.sum(dim=1)[:,None]
        else:
            raise ValueError(f"Unrecognized pooling option: {self.pooling}")

        if self.d:
            v = self.projection(v)

        if self.normalize:
            # Scale each embedding to unit L2 norm
            v = v/torch.sqrt((v**2).sum(dim=1)[:,None])

        return v

    def config_optimizer(self, transformer_lr=1e-5, projection_lr=1e-4):

        parameters = list(self.named_parameters())
        grouped_parameters = [
                {
                    'params': [param for name, param in parameters if name.startswith('transformer') and name.endswith('bias')],
                    'weight_decay': 0.0,
                    'lr': transformer_lr,
                    },
                {
                    'params': [param for name, param in parameters if name.startswith('transformer') and not name.endswith('bias')],
                    'weight_decay': 0.0,
                    'lr': transformer_lr,
                    },
                {
                    'params': [param for name, param in parameters if name.startswith('projection')],
                    'weight_decay': 0.0,
                    'lr': projection_lr,
                    },
                ]

        # Drop groups with a learning rate of 0 so those parameters are not updated
        grouped_parameters = [p for p in grouped_parameters if p['lr']]

        optimizer = torch.optim.AdamW(grouped_parameters)

        return optimizer
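

if __name__ == '__main__':
    # Minimal usage sketch: builds a config with the keys read in __init__ and embeds
    # a small batch of strings. The model name, dimension, and device below are
    # illustrative assumptions, not values prescribed by this module.
    config = {
        'model_class': 'roberta',
        'model_name': 'roberta-base',
        'pooling': 'mean',
        'normalize': True,
        'd': 128,
        'prompt': None,
        'add_upper': False,
        'upper_case': False,
        'device': 'cpu',
    }

    model = EmbeddingModel(config)
    optimizer = model.config_optimizer(transformer_lr=1e-5, projection_lr=1e-4)

    with torch.no_grad():
        vectors = model(['Acme Corp.', 'ACME Corporation'])
    print(vectors.shape)  # torch.Size([2, 128]) when d=128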