import torch
import torch.nn as nn
from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
)
import pandas as pd
from datasets import load_dataset, Dataset, DatasetDict


class MCQModel(nn.Module):
    def __init__(self, name_model):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            name_model,
            trust_remote_code=True,
            output_hidden_states=True,
        )
        # OpenELM exposes its hidden size as `model_dim`; most other
        # architectures call this `hidden_size`.
        self.classifier = nn.Linear(
            self.model.config.model_dim, 4
        )  # 4 classes for 'A', 'B', 'C', 'D'

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
    ):
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            cache_position=cache_position,
            use_cache=use_cache,
            return_dict=return_dict,
            output_hidden_states=output_hidden_states,
        )

        # `hidden_states` is a tuple with the output of every transformer layer;
        # we only need the last layer's representation of the last token.
        logits = self.classifier(outputs.hidden_states[-1][:, -1, :])
        outputs.logits = logits

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # labels: [batch_size], logits: [batch_size, num_classes]
            loss = loss_fct(logits, labels)
            outputs["loss"] = loss

        return outputs


class MyCustomConfig(PretrainedConfig):
    model_type = "mcq_hf_model"

    def __init__(self, name_model="apple/OpenELM-450M-Instruct", **kwargs):
        super().__init__(**kwargs)
        self.name_model = name_model


class MCQHFModel(PreTrainedModel):
    config_class = MyCustomConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = MCQModel(config.name_model)
        self.lm_head = None

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
    ):
        return self.model(
            input_ids,
            labels=labels,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            cache_position=cache_position,
            use_cache=use_cache,
            return_dict=return_dict,
            output_hidden_states=output_hidden_states,
        )
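

# --- Usage sketch (illustrative, not part of the original script) ---
# A minimal sketch of how the wrapper could be exercised end to end, assuming
# the OpenELM checkpoint named above and the Llama-2 tokenizer that is commonly
# paired with OpenELM (the OpenELM repos ship no tokenizer). The prompt text
# and the gold-label index below are hypothetical examples.
if __name__ == "__main__":
    # Optional: register the custom classes so AutoConfig/AutoModel can
    # resolve the "mcq_hf_model" model_type.
    AutoConfig.register("mcq_hf_model", MyCustomConfig)
    AutoModel.register(MyCustomConfig, MCQHFModel)

    config = MyCustomConfig(name_model="apple/OpenELM-450M-Instruct")
    model = MCQHFModel(config)

    # Assumption: using the Llama-2 tokenizer; swap in any tokenizer you use.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

    prompt = (
        "Question: Which planet is known as the Red Planet?\n"
        "A) Venus  B) Mars  C) Jupiter  D) Saturn\nAnswer:"
    )
    inputs = tokenizer(prompt, return_tensors="pt")
    labels = torch.tensor([1])  # hypothetical gold label: index 1 -> 'B'

    with torch.no_grad():
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            labels=labels,
        )
    pred = outputs.logits.argmax(dim=-1).item()
    print("predicted option:", "ABCD"[pred], "| loss:", outputs["loss"].item())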
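

# --- Fine-tuning sketch (illustrative, not part of the original script) ---
# Since MCQHFModel returns a `loss` when `labels` are given, it can be dropped
# into the Hugging Face Trainer. This is only a minimal sketch under stated
# assumptions: the toy in-memory dataset, column names, hyperparameters, and
# tokenizer choice below are placeholders, not the original training setup.
from transformers import Trainer, TrainingArguments, DataCollatorWithPadding


def build_toy_dataset(tokenizer):
    # Hypothetical two-example MCQ dataset; real data would typically come
    # from load_dataset(...) instead.
    rows = {
        "text": [
            "Q: 2 + 2 = ?\nA) 3  B) 4  C) 5  D) 6\nAnswer:",
            "Q: Capital of France?\nA) Paris  B) Rome  C) Madrid  D) Berlin\nAnswer:",
        ],
        "labels": [1, 0],  # indices into ['A', 'B', 'C', 'D']
    }
    ds = Dataset.from_dict(rows)
    return ds.map(
        lambda ex: tokenizer(ex["text"], truncation=True),
        remove_columns=["text"],
    )


def train_sketch():
    config = MyCustomConfig()
    model = MCQHFModel(config)
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumption
    train_ds = build_toy_dataset(tokenizer)

    args = TrainingArguments(
        output_dir="mcq_out",
        per_device_train_batch_size=2,
        num_train_epochs=1,
        report_to="none",
    )
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        data_collator=DataCollatorWithPadding(tokenizer),
    )
    trainer.train()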