# foo4 / modeling_custom.py
# Uploaded by denizyuret-shallowai ("Upload model", commit 9c213b4)
# https://huggingface.co/docs/transformers/custom_models
from transformers import PreTrainedModel, GPTNeoXForCausalLM, AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.nn.functional import log_softmax
from torch.nn.modules.container import ModuleList
class CustomModel(PreTrainedModel):
    """Ensemble of causal language models combined in log-probability space.

    Each sub-model's logits are turned into log-probabilities with
    ``log_softmax`` and the results are summed, weighted by ``coeffs``.
    The weighted sum is returned in the ``logits`` field of the output.

    NOTE(review): ``models`` and ``coeffs`` are only attached by
    ``combine_models``; ``__init__`` alone does not create them, so an
    instance built directly from a config cannot run ``forward``.
    """

    def __init__(self, config):
        """Initialise the wrapper from ``config`` (no sub-models are created here)."""
        super().__init__(config)

    def forward(self, *args, labels=None, **kwargs):
        """Run every sub-model and return the weighted sum of their log-probabilities.

        Args:
            *args: Positional arguments forwarded unchanged to each sub-model.
            labels: Optional target token ids. When given, a next-token
                cross-entropy loss is computed from the combined scores.
            **kwargs: Keyword arguments forwarded unchanged to each sub-model.

        Returns:
            CausalLMOutputWithPast whose ``logits`` is the weighted sum of
            per-model log-probabilities and whose ``loss`` is ``None`` unless
            ``labels`` was provided.
        """
        loss = None
        logits = None
        for model, coeff in zip(self.models, self.coeffs):
            # Call the module itself (not ``.forward``) so registered hooks run.
            logp = log_softmax(model(*args, **kwargs).logits, dim=-1)
            logits = coeff * logp if logits is None else logits + coeff * logp

        # The rest copied from modeling_llama.py:
        if labels is not None:
            # Imported locally: the file's top-level imports do not provide it.
            from torch.nn import CrossEntropyLoss

            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens; use the actual last dimension of the scores
            # so the reshape cannot disagree with config.vocab_size.
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        return CausalLMOutputWithPast(loss=loss, logits=logits)

    @classmethod
    def combine_models(cls, *args, coeffs=None, **kwargs):
        """Build an ensemble from several pretrained causal LMs.

        Args:
            *args: Model identifiers, each passed to
                ``AutoModelForCausalLM.from_pretrained``.
            coeffs: One weight per model. When omitted (or empty), every
                model gets the uniform weight ``1 / len(args)``.
            **kwargs: Extra keyword arguments forwarded to every
                ``from_pretrained`` call.

        Returns:
            A ``CustomModel`` created from the first model's config, holding
            all loaded models (in eval mode) and their weights.

        Raises:
            ValueError: If no model is given, or if the number of
                coefficients differs from the number of models.
        """
        if not args:
            raise ValueError("combine_models requires at least one model")

        # Validate the weights before the (expensive) model loading below.
        if coeffs is None or len(coeffs) == 0:
            coeffs = [1 / len(args)] * len(args)
        elif len(coeffs) != len(args):
            # zip() in forward() would otherwise silently drop models/weights.
            raise ValueError(
                f"expected {len(args)} coefficients, got {len(coeffs)}"
            )

        models = [
            AutoModelForCausalLM.from_pretrained(name, **kwargs).eval()
            for name in args
        ]
        m = cls(models[0].config)
        m.models = ModuleList(models)
        m.coeffs = list(coeffs)
        return m
# In the docs example there is also a config class, but we'll just use the one from GPTNeoX.
# The norm is to inherit from PreTrainedModel, but the alternative below takes a shortcut:
# class CustomModel(GPTNeoXForCausalLM):
# def __init__(self, config):
# super().__init__(config)
# def forward(self, *args, **kwargs):
# # See https://huggingface.co/docs/transformers/main_classes/output
# out = super().forward(*args, **kwargs)
# out.logits = log_softmax(out.logits, dim=-1)
# return out
# @classmethod
# def copy_from_neox(cls, *args, **kwargs):
# m0 = GPTNeoXForCausalLM.from_pretrained(*args, **kwargs)
# m1 = cls(m0.config).to(dtype=m0.dtype, device=m0.device)
# m1.load_state_dict(m0.state_dict())
# return m1
# Register CustomModel with the 'AutoModelForCausalLM' auto class
# (see the custom-models guide linked at the top of this file).
CustomModel.register_for_auto_class('AutoModelForCausalLM')