# https://huggingface.co/docs/transformers/custom_models
from transformers import (
    PreTrainedModel,
    GPTNeoXForCausalLM,
    AutoModelForCausalLM,
    AutoTokenizer,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.nn import CrossEntropyLoss
from torch.nn.functional import log_softmax
from torch.nn.modules.container import ModuleList


class CustomModel(PreTrainedModel):
    """Causal LM that mixes several causal LMs in log-probability space.

    The sub-models and their mixing coefficients are not created by
    ``__init__``; they are attached as ``self.models`` / ``self.coeffs`` by
    :meth:`combine_models`, which is the intended way to build an instance.
    """

    def __init__(self, config):
        """Initialise the wrapper from ``config`` (sub-models are attached later)."""
        super().__init__(config)

    def forward(self, *args, labels=None, **kwargs):
        """Run every sub-model and return the weighted sum of their log-probs.

        Args:
            *args: Positional arguments forwarded unchanged to each sub-model.
            labels: Optional target token ids. When given, a next-token
                cross-entropy loss is computed on the combined scores.
                ``labels`` is NOT forwarded to the sub-models.
            **kwargs: Keyword arguments forwarded unchanged to each sub-model.

        Returns:
            CausalLMOutputWithPast whose ``logits`` field holds
            ``sum_i coeff_i * log_softmax(logits_i)`` and whose ``loss`` is
            ``None`` unless ``labels`` was supplied.
        """
        loss = None
        logits = None
        for model, coeff in zip(self.models, self.coeffs):
            # Combine in log-probability space, not raw-logit space.
            logp = log_softmax(model.forward(*args, **kwargs).logits, dim=-1)
            logits = coeff * logp if logits is None else logits + coeff * logp

        # The rest is adapted from modeling_llama.py:
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens. Use the real size of the last dimension
            # rather than config.vocab_size, which only describes the first
            # sub-model's config.
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        return CausalLMOutputWithPast(loss=loss, logits=logits)

    @classmethod
    def combine_models(cls, *args, coeffs=None, **kwargs):
        """Load several pretrained causal LMs and wrap them in one model.

        Args:
            *args: Model names or paths, each passed to
                ``AutoModelForCausalLM.from_pretrained``.
            coeffs: One mixing weight per model. ``None`` or an empty
                sequence means uniform weights ``1 / len(args)``.
            **kwargs: Extra keyword arguments forwarded to every
                ``from_pretrained`` call.

        Returns:
            A ``cls`` instance built from the first model's config, with the
            loaded models (in eval mode) in ``models`` and the weights in
            ``coeffs``.

        Raises:
            ValueError: If no model is given, or if ``coeffs`` is non-empty
                and its length differs from the number of models.
        """
        if not args:
            raise ValueError("combine_models requires at least one model")

        # `len(...) == 0` instead of truthiness so array-like coeffs work too.
        if coeffs is None or len(coeffs) == 0:
            coeffs = [1 / len(args)] * len(args)
        elif len(coeffs) != len(args):
            # forward() zips models with coeffs; a mismatch would silently
            # drop models or weights.
            raise ValueError(
                f"got {len(coeffs)} coeffs for {len(args)} models"
            )

        models = [
            AutoModelForCausalLM.from_pretrained(model, **kwargs).eval()
            for model in args
        ]

        m = cls(models[0].config)
        m.models = ModuleList(models)
        m.coeffs = coeffs
        return m


# In the example there is also a config class, but we'll just use the one
# from GPTNeoX. The norm is to inherit from PreTrainedModel, but we'll take
# a shortcut.
# class CustomModel(GPTNeoXForCausalLM):
#     def __init__(self, config):
#         super().__init__(config)
#
#     def forward(self, *args, **kwargs):
#         # See https://huggingface.co/docs/transformers/main_classes/output
#         out = super().forward(*args, **kwargs)
#         out.logits = log_softmax(out.logits, dim=-1)
#         return out
#
#     @classmethod
#     def copy_from_neox(cls, *args, **kwargs):
#         m0 = GPTNeoXForCausalLM.from_pretrained(*args, **kwargs)
#         m1 = cls(m0.config).to(dtype=m0.dtype, device=m0.device)
#         m1.load_state_dict(m0.state_dict())
#         return m1

CustomModel.register_for_auto_class('AutoModelForCausalLM')