# https://huggingface.co/docs/transformers/custom_models
#
# Ensemble several causal LMs by taking a coefficient-weighted average of
# their per-token log-probabilities.

from transformers import PreTrainedModel, GPTNeoXForCausalLM, AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.nn import CrossEntropyLoss, ModuleList
from torch.nn.functional import log_softmax

class CustomModel(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # `self.models` and `self.coeffs` are attached in `combine_models` below.

    def forward(self, *args, labels=None, **kwargs):
        loss = None
        logits = None
        # Combine the models in log-probability space: a coefficient-weighted
        # sum of log-softmaxed logits. `labels` is captured here so it is not
        # forwarded to the sub-models.
        for model, coeff in zip(self.models, self.coeffs):
            logp = log_softmax(model(*args, **kwargs).logits, dim=-1)
            logits = coeff * logp if logits is None else logits + coeff * logp
        # The rest is copied from modeling_llama.py:
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        return CausalLMOutputWithPast(loss=loss, logits=logits)
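
    # Note on the combination: with coefficients summing to 1,
    # sum_i c_i * log p_i = log prod_i p_i**c_i, i.e. a weighted geometric
    # mean of the models' distributions (left unnormalized here). A quick
    # self-contained check with two dummy distributions:
    #
    #   import torch
    #   p1 = torch.tensor([0.9, 0.1]); p2 = torch.tensor([0.2, 0.8])
    #   mixed = (0.5 * p1.log() + 0.5 * p2.log()).exp()
    #   assert torch.allclose(mixed, (p1 * p2).sqrt())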
        

    @classmethod
    def combine_models(cls, *args, coeffs=None, **kwargs):
        # Load each named checkpoint and wrap them all in one ensemble model.
        models = []
        for model in args:
            models.append(AutoModelForCausalLM.from_pretrained(model, **kwargs).eval())
        if coeffs is None:
            # Default to uniform weights over the models.
            coeffs = [1 / len(args)] * len(args)
        m = cls(models[0].config)
        m.models = ModuleList(models)
        m.coeffs = coeffs
        return m

# In the example there is also a config class, but we'll just reuse the one
# from GPTNeoX. The norm is to inherit from PreTrainedModel (as above); the
# commented-out shortcut below subclasses GPTNeoXForCausalLM directly instead:
# class CustomModel(GPTNeoXForCausalLM):
#     def __init__(self, config):
#         super().__init__(config)

#     def forward(self, *args, **kwargs):
#         # See https://huggingface.co/docs/transformers/main_classes/output
#         out = super().forward(*args, **kwargs)
#         out.logits = log_softmax(out.logits, dim=-1)
#         return out

#     @classmethod
#     def copy_from_neox(cls, *args, **kwargs):
#         m0 = GPTNeoXForCausalLM.from_pretrained(*args, **kwargs)
#         m1 = cls(m0.config).to(dtype=m0.dtype, device=m0.device)
#         m1.load_state_dict(m0.state_dict())
#         return m1

# Register the class so that, once saved/pushed along with its code, it can
# be loaded through AutoModelForCausalLM with trust_remote_code=True.
CustomModel.register_for_auto_class('AutoModelForCausalLM')
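
# A minimal usage sketch (the Pythia checkpoints are just an illustration;
# any causal LMs sharing a tokenizer and vocabulary should work):
#
#   model = CustomModel.combine_models(
#       "EleutherAI/pythia-70m",
#       "EleutherAI/pythia-160m",
#       coeffs=[0.5, 0.5],
#   )
#   tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
#   inputs = tokenizer("Hello world", return_tensors="pt")
#   out = model(**inputs, labels=inputs["input_ids"])
#   print(out.loss, out.logits.shape)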