MathBite committed on
Commit
fa820b7
·
verified ·
1 Parent(s): 1d7496d

updated deletion token boost

Browse files
Files changed (1) hide show
  1. modeling.py +4 -3
modeling.py CHANGED
@@ -81,6 +81,7 @@ class SelfCorrectiveLlama(LlamaForCausalLM):
81
 
82
  # 5. Modify the token logits conditionally.
83
  deletion_logits = all_hallucination_logits[..., 1:] # skip the first token (no hallucination)
 
84
 
85
  # Conditionally add the deletion logits.
86
  if hallucination_labels is not None and labels is not None:
@@ -97,13 +98,13 @@ class SelfCorrectiveLlama(LlamaForCausalLM):
97
  combined_mask = (mask_no_hallucination | mask_is_deletion_token).unsqueeze(-1)
98
  to_add = torch.where(
99
  combined_mask,
100
- deletion_logits,
101
- torch.zeros_like(deletion_logits)
102
  )
103
  logits[:, :, -self.num_new_tokens:].add_(to_add)
104
  else:
105
  # Inference case: always add the deletion logits to the token logits
106
- logits[:, :, -self.num_new_tokens:].add_(deletion_logits)
107
 
108
  # 6. Return the custom output object
109
  return SelfCorrectiveLlamaOutput(
 
81
 
82
  # 5. Modify the token logits conditionally.
83
  deletion_logits = all_hallucination_logits[..., 1:] # skip the first token (no hallucination)
84
+ deletion_tokens_boost = F.softplus(deletion_logits)
85
 
86
  # Conditionally add the deletion logits.
87
  if hallucination_labels is not None and labels is not None:
 
98
  combined_mask = (mask_no_hallucination | mask_is_deletion_token).unsqueeze(-1)
99
  to_add = torch.where(
100
  combined_mask,
101
+ deletion_tokens_boost,
102
+ torch.zeros_like(deletion_tokens_boost)
103
  )
104
  logits[:, :, -self.num_new_tokens:].add_(to_add)
105
  else:
106
  # Inference case: always add the deletion logits to the token logits
107
+ logits[:, :, -self.num_new_tokens:].add_(deletion_tokens_boost)
108
 
109
  # 6. Return the custom output object
110
  return SelfCorrectiveLlamaOutput(