import torch
import torch.nn as nn
from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
)
import pandas as pd
from datasets import load_dataset, Dataset, DatasetDict


class MCQModel(nn.Module):
    def __init__(self, name_model):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            name_model,
            trust_remote_code=True,
            output_hidden_states=True,
        )
        # OpenELM configs expose the hidden size as `model_dim`; most other
        # HF configs call it `hidden_size`, so adjust this if the backbone changes.
        self.classifier = nn.Linear(
            self.model.config.model_dim, 4
        )  # 4 classes for 'A', 'B', 'C', 'D'

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
    ):
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
print("OUTPUT KEYS: " , outputs.keys())
# print(outputs.logits.shape)
# # print(outputs.hidden_states)
# print(outputs.hidden_states[0].shape)
# print(outputs.hidden_states[1].shape)
# print(len(outputs.hidden_states))
# hidden state is a tuple with all the hidden layer outputs from the attention,
# We are only interested in the last hidden layer and the last token
logits = self.classifier(outputs.hidden_states[-1][:, -1, :])
outputs.logits = logits
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # labels: [batch_size], logits: [batch_size, num_classes]
            loss = loss_fct(logits, labels)
        outputs["loss"] = loss
        return outputs
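

# A hypothetical preprocessing helper (not in the original file): the
# classifier head above expects integer class ids in `labels`, so a
# dataset's answer letters must be mapped to 0-3 before batching, e.g. via
# `Dataset.map`. The column names `answer` and `labels` are assumptions.
LETTER_TO_ID = {"A": 0, "B": 1, "C": 2, "D": 3}


def encode_labels(example):
    example["labels"] = LETTER_TO_ID[example["answer"]]
    return example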


class MyCustomConfig(PretrainedConfig):
    model_type = "mcq_hf_model"

    def __init__(self, name_model="apple/OpenELM-450M-Instruct", **kwargs):
        super().__init__(**kwargs)
        self.name_model = name_model


class MCQHFModel(PreTrainedModel):
    config_class = MyCustomConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = MCQModel(config.name_model)
        # Placeholder: no generation head here; the classifier lives in MCQModel.
        self.lm_head = None

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
    ):
        return self.model(
            input_ids,
            labels=labels,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
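

# --- Minimal usage sketch (added for illustration, not in the original) ---
# Registering the config/model with the Auto* factories lets checkpoints
# saved via `save_pretrained` be reloaded with `from_pretrained`. The
# tokenizer checkpoint is an assumption: OpenELM's model card pairs it with
# the Llama-2 tokenizer, which is gated on the Hub.
if __name__ == "__main__":
    AutoConfig.register("mcq_hf_model", MyCustomConfig)
    AutoModelForCausalLM.register(MyCustomConfig, MCQHFModel)

    config = MyCustomConfig(name_model="apple/OpenELM-450M-Instruct")
    model = MCQHFModel(config)
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

    batch = tokenizer(
        ["Q: 2 + 2 = ?\nA) 3  B) 4  C) 5  D) 6\nAnswer:"],
        return_tensors="pt",
    )
    labels = torch.tensor([1])  # 'B' is the correct option
    with torch.no_grad():
        out = model(**batch, labels=labels)
    print(out["loss"], out.logits.argmax(dim=-1))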