import torch
import torch.nn as nn

from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
)

import pandas as pd
from datasets import load_dataset, Dataset, DatasetDict


class MCQModel(nn.Module):
    """Wraps a causal LM backbone and adds a 4-way classification head for MCQ answering."""

    def __init__(self, name_model):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            name_model,
            trust_remote_code=True,
            output_hidden_states=True,
        )
        # OpenELM exposes its hidden size as `model_dim`; most other HF configs
        # call it `hidden_size`, so adjust this if you swap the backbone.
        self.classifier = nn.Linear(self.model.config.model_dim, 4)

    def forward(self, input_ids, attention_mask=None, labels=None,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=None,
                use_cache=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                cache_position=None):

        print("INSIDE CUSTOM MODEL LABELS: ", labels)  # debug
        outputs = self.model(
            input_ids, attention_mask=attention_mask, position_ids=position_ids,
            past_key_values=past_key_values, inputs_embeds=inputs_embeds,
            output_attentions=output_attentions, cache_position=cache_position,
            use_cache=use_cache, return_dict=return_dict,
            output_hidden_states=output_hidden_states,
        )
        print("OUTPUT KEYS: ", outputs.keys())  # debug

        # Classify from the final layer's hidden state at the last position.
        # Note: with right-padded batches this position can be a pad token, so
        # you may want to index the last non-padded token instead.
        logits = self.classifier(outputs.hidden_states[-1][:, -1, :])
        outputs.logits = logits

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)
        outputs["loss"] = loss

        print("===================")  # debug
        print(loss)
        print("===================")

        return outputs


class MyCustomConfig(PretrainedConfig):
    model_type = "mcq_hf_model"

    def __init__(self, name_model="apple/OpenELM-450M-Instruct", **kwargs):
        super().__init__(**kwargs)
        self.name_model = name_model


class MCQHFModel(PreTrainedModel):
    """PreTrainedModel wrapper so the classifier works with the HF Trainer and Auto* APIs."""

    config_class = MyCustomConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = MCQModel(config.name_model)
        # No language-modelling head is used; kept as an explicit placeholder.
        self.lm_head = None

    def forward(self, input_ids, attention_mask=None, labels=None,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=None,
                use_cache=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                cache_position=None):

        return self.model(
            input_ids, labels=labels, attention_mask=attention_mask,
            position_ids=position_ids, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, output_attentions=output_attentions,
            cache_position=cache_position, use_cache=use_cache,
            return_dict=return_dict, output_hidden_states=output_hidden_states,
        )
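

# --- Illustrative usage: a minimal sketch, not part of the original script. ---
# Assumes the OpenELM checkpoint can be downloaded (trust_remote_code); the
# registration calls use the standard AutoConfig/AutoModel hooks, and the dummy
# batch below is made up purely for demonstration.
if __name__ == "__main__":
    # Register the custom config/model so AutoConfig/AutoModel can resolve the
    # "mcq_hf_model" model_type (e.g. after save_pretrained / from_pretrained).
    AutoConfig.register("mcq_hf_model", MyCustomConfig)
    AutoModel.register(MyCustomConfig, MCQHFModel)

    config = MyCustomConfig(name_model="apple/OpenELM-450M-Instruct")
    model = MCQHFModel(config)  # downloads the backbone weights

    # Dummy batch: random token ids and one gold answer index per example (0-3).
    # model.model.model is MCQHFModel -> MCQModel -> backbone, whose config
    # carries the vocabulary size.
    vocab_size = model.model.model.config.vocab_size
    input_ids = torch.randint(0, vocab_size, (2, 16))
    labels = torch.tensor([1, 3])

    outputs = model(input_ids=input_ids, labels=labels, return_dict=True)
    print(outputs.logits.shape)  # torch.Size([2, 4])
    print(outputs["loss"])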