File size: 3,589 Bytes
b39f8cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
import torch.nn as nn
from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
)

import pandas as pd
from datasets import load_dataset, Dataset, DatasetDict


class MCQModel(nn.Module):
    """Multiple-choice classifier built on top of a causal language model.

    The backbone is loaded with ``AutoModelForCausalLM`` (with
    ``trust_remote_code=True``). The final-layer hidden state of the last real
    (non-padding) token of each sequence is fed into a linear head that
    produces one logit per answer option. The default is 4 options, for
    'A'/'B'/'C'/'D'.
    """

    def __init__(self, name_model, num_classes=4):
        """Load the backbone and build the classification head.

        Args:
            name_model: Identifier or path passed to
                ``AutoModelForCausalLM.from_pretrained``.
            num_classes: Number of answer options scored by the head.
        """
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            name_model,
            trust_remote_code=True,
            output_hidden_states=True,
        )
        config = self.model.config
        # OpenELM exposes its width as ``model_dim``; most other Hugging Face
        # architectures call it ``hidden_size``.
        hidden_dim = getattr(config, "model_dim", None)
        if hidden_dim is None:
            hidden_dim = config.hidden_size
        self.classifier = nn.Linear(hidden_dim, num_classes)

    @staticmethod
    def _last_token_index(attention_mask):
        """Return, for each row of a 2-D mask, the index of the last attended position.

        Works for both left- and right-padded batches. A row with no attended
        positions maps to index 0.
        """
        positions = torch.arange(attention_mask.shape[1], device=attention_mask.device)
        return (attention_mask.long() * positions).max(dim=1).values

    def forward(self, input_ids, attention_mask=None, labels=None,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=None,
                use_cache=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                cache_position=None):
        """Score each sequence in the batch against the answer classes.

        Arguments other than ``labels`` are forwarded to the backbone, with
        two exceptions: ``output_hidden_states`` and ``return_dict`` are
        always forced to ``True``, because the head reads
        ``outputs.hidden_states``. Both parameters are kept in the signature
        for compatibility.

        Args:
            labels: Optional class indices, one per row, used for
                cross-entropy loss.

        Returns:
            The backbone's output object. Its ``logits`` entry is replaced by
            the classifier logits of shape ``(batch, num_classes)``. When
            ``labels`` is given, a ``"loss"`` entry holds the cross-entropy
            loss.
        """
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            cache_position=cache_position,
            use_cache=use_cache,
            return_dict=True,
            output_hidden_states=True,
        )

        # hidden_states holds one tensor per layer; only the final layer is used.
        last_hidden = outputs.hidden_states[-1]  # (batch, seq, hidden)
        if attention_mask is not None and attention_mask.dim() == 2:
            # Pool the last *attended* token so right-padded rows do not read
            # a padding position.
            rows = torch.arange(last_hidden.shape[0], device=last_hidden.device)
            cols = self._last_token_index(attention_mask).to(last_hidden.device)
            pooled = last_hidden[rows, cols]
        else:
            pooled = last_hidden[:, -1, :]

        # Match the head's dtype in case the backbone runs in reduced precision.
        logits = self.classifier(pooled.to(self.classifier.weight.dtype))
        outputs.logits = logits

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # labels: (batch,), logits: (batch, num_classes)
            outputs["loss"] = loss_fct(logits, labels.to(logits.device))

        return outputs


class MyCustomConfig(PretrainedConfig):
    """Configuration object for ``MCQHFModel``.

    Adds a single field, ``name_model``: the identifier of the causal LM
    backbone to load. Every other keyword argument is handed to
    ``PretrainedConfig``.
    """

    model_type = "mcq_hf_model"

    def __init__(self, name_model="apple/OpenELM-450M-Instruct", **kwargs):
        # Follow the usual Hugging Face config pattern: set model-specific
        # fields first, then let the base class consume the generic kwargs.
        self.name_model = name_model
        super().__init__(**kwargs)


class MCQHFModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around ``MCQModel``.

    The backbone name is read from ``config.name_model``. ``forward``
    delegates every argument unchanged to the wrapped ``MCQModel``.
    """

    config_class = MyCustomConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = MCQModel(config.name_model)
        # NOTE(review): explicitly no LM head on the wrapper; presumably to
        # signal that this model does not produce vocabulary logits — confirm.
        self.lm_head = None

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
    ):
        """Forward all inputs to the wrapped ``MCQModel`` and return its result."""
        passthrough = dict(
            labels=labels,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            cache_position=cache_position,
            use_cache=use_cache,
            return_dict=return_dict,
            output_hidden_states=output_hidden_states,
        )
        return self.model(input_ids, **passthrough)