import torch
from torch import nn
import torch.nn.functional as F
from transformers import LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from dataclasses import dataclass
from typing import Optional

@dataclass
class SelfCorrectiveLlamaOutput(CausalLMOutputWithPast):
    hallucination_logits: Optional[torch.FloatTensor] = None

class SelfCorrectiveLlama(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        
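        # Number of special (deletion) tokens appended after the original vocabulary.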
        self.num_new_tokens = 3
        self.original_vocab_size = config.vocab_size

        # Create a new, small embedding layer for only the special tokens
        self.new_token_embeddings = nn.Embedding(self.num_new_tokens, config.hidden_size)

        # Initialize new embeddings with the mean of the original ones
        with torch.no_grad():
            original_embeddings = self.model.embed_tokens.weight
            mean_embeddings = original_embeddings.mean(dim=0)
            self.new_token_embeddings.weight.data.copy_(
                mean_embeddings.unsqueeze(0).expand(self.num_new_tokens, -1)
            )

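        # SwiGLU-style MLP (gate/up/down projections) followed by a linear
        # classifier with (num_new_tokens + 1) outputs: class 0 means
        # "no hallucination"; classes 1..num_new_tokens map to the new deletion tokens.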
        intermediate_size = config.intermediate_size
        self.hallucination_gate_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
        self.hallucination_up_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
        self.hallucination_down_proj = nn.Linear(intermediate_size, config.hidden_size, bias=False)
        self.hallucination_detector = nn.Linear(config.hidden_size, self.num_new_tokens + 1)
    
    def forward(
        self, 
        input_ids, 
        attention_mask=None, 
        labels=None, 
        hallucination_labels=None, 
        **kwargs
    ):
        # 1. Manually construct the input embeddings
        clamped_input_ids = torch.clamp(input_ids, max=self.original_vocab_size - 1)
        inputs_embeds = self.model.embed_tokens(clamped_input_ids)

        # Overwrite the embeddings for our new special tokens
        special_token_mask = input_ids >= self.original_vocab_size
        if special_token_mask.any():
            special_ids = input_ids[special_token_mask] - self.original_vocab_size
            special_embeds = self.new_token_embeddings(special_ids)
            inputs_embeds[special_token_mask] = special_embeds

        # 2. Pass the constructed embeddings through the base transformer model
        kwargs["inputs_embeds"] = inputs_embeds
        transformer_outputs = self.model(
            attention_mask=attention_mask,
            **kwargs
        )
        last_hidden = transformer_outputs.last_hidden_state

        # 3. Calculate token logits by combining outputs from both heads
        # Main logits from the original, frozen lm_head
        main_logits = self.lm_head(last_hidden)

        # New-token logits from the small, trainable embedding layer
        # (its weights are reused as the output projection for the new tokens)
        new_logits = F.linear(last_hidden, self.new_token_embeddings.weight)

        # Concatenate to get logits over the full, expanded vocabulary
        logits = torch.cat([main_logits, new_logits], dim=-1)

        # 4. SwiGLU-based hallucination detector
        gate_output = self.hallucination_gate_proj(last_hidden)
        up_output = self.hallucination_up_proj(last_hidden)
        gated_hidden = F.silu(gate_output) * up_output
        detector_hidden = self.hallucination_down_proj(gated_hidden)

        # Hallucination logits
        all_hallucination_logits = self.hallucination_detector(detector_hidden)

        # 5. Modify the token logits conditionally.
        deletion_logits = all_hallucination_logits[..., 1:]  # skip the first class (no hallucination)
        deletion_tokens_boost = F.softplus(deletion_logits)

        # Conditionally add the deletion logits.
        if hallucination_labels is not None and labels is not None:
            # Training case:
            # Condition 1: The hallucination label is 0 (no hallucination)
            mask_no_hallucination = (hallucination_labels == 0)

            # Condition 2: The next token is one of the deletion tokens.
            # Check if the token ID is within the range of the last `num_new_tokens` in the vocab
            vocab_size = logits.shape[-1]
            mask_is_deletion_token = (labels >= (vocab_size - self.num_new_tokens)) & (labels < vocab_size)
            
            # Combine masks and create the tensor to add.
            combined_mask = (mask_no_hallucination | mask_is_deletion_token).unsqueeze(-1)
            to_add = torch.where(
                combined_mask,
                deletion_tokens_boost,
                torch.zeros_like(deletion_tokens_boost)
            )
        else:
            # Inference case: The hallucination detector's decision becomes a hard gate.
            hallucination_decision = torch.argmax(all_hallucination_logits, dim=-1)

            # Create a mask that is True only when a hallucination is detected (decision != 0)
            hallucination_present_mask = (hallucination_decision != 0).unsqueeze(-1)

            # Where the mask is True, use the softplus boost.
            # Where the mask is False, use a large negative value to suppress deletion.
            to_add = torch.where(
                hallucination_present_mask,
                deletion_tokens_boost,
                torch.full_like(deletion_tokens_boost, torch.finfo(deletion_tokens_boost.dtype).min) # Suppress if no hallucination
            )
        
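        # Apply the boost (or suppression) in-place to the new-token slice of the logits.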
        logits[:, :, -self.num_new_tokens:].add_(to_add)

        # 6. Return the custom output object
        return SelfCorrectiveLlamaOutput(
            loss=None, # Loss calculation is handled by the Trainer
            logits=logits,
            hallucination_logits=all_hallucination_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=None,
            attentions=transformer_outputs.attentions
        )
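

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the model definition).
# Assumptions: the checkpoint name and the three special-token strings below
# are placeholders; substitute whatever base model and tokens the training
# setup actually uses. The tokenizer must gain exactly `num_new_tokens` (3)
# special tokens so their IDs land directly after the original vocabulary.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    base_model = "meta-llama/Llama-3.2-1B"  # placeholder checkpoint name
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<del_1>", "<del_2>", "<del_3>"]}  # placeholder tokens
    )
    # Do NOT call model.resize_token_embeddings(): the class routes the new
    # token IDs through its own small embedding layer instead.

    # Note: loading a plain base checkpoint leaves the added modules untrained;
    # a fine-tuned checkpoint is needed for meaningful hallucination scores.
    model = SelfCorrectiveLlama.from_pretrained(base_model)
    model.eval()

    inputs = tokenizer("The capital of France is", return_tensors="pt")
    with torch.no_grad():
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )

    # logits cover the expanded vocabulary (original vocab + 3 new tokens);
    # hallucination_logits hold one (num_new_tokens + 1)-way score per position.
    print(outputs.logits.shape)
    print(outputs.hallucination_logits.shape)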