Update modeling_quiet.py
Browse files- modeling_quiet.py +16 -24
modeling_quiet.py
CHANGED
|
@@ -1098,18 +1098,27 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
| 1098 |
self.remove_negative_rewards = True
|
| 1099 |
self.post_init()
|
| 1100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1101 |
def _generate_thoughts(self, hidden_states, max_length):
|
| 1102 |
batch_size = hidden_states.size(0)
|
| 1103 |
thought_ids = torch.zeros((batch_size, self.config.max_thoughts, max_length), dtype=torch.long, device=hidden_states.device)
|
| 1104 |
thought_embeddings = []
|
| 1105 |
|
| 1106 |
-
# Create an instance of QuietForCausalLM using the current model's configuration
|
| 1107 |
-
causal_lm_model = QuietForCausalLM(self.config)
|
| 1108 |
-
causal_lm_model.eval() # Set the model to evaluation mode
|
| 1109 |
-
|
| 1110 |
for i in range(self.config.max_thoughts):
|
| 1111 |
thought_input_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=hidden_states.device)
|
| 1112 |
-
thought_outputs =
|
| 1113 |
input_ids=thought_input_ids,
|
| 1114 |
max_length=max_length,
|
| 1115 |
do_sample=True,
|
|
@@ -1124,21 +1133,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
| 1124 |
thought_embeddings = torch.stack(thought_embeddings, dim=1)
|
| 1125 |
return thought_ids, thought_embeddings
|
| 1126 |
|
| 1127 |
-
|
| 1128 |
-
|
| 1129 |
-
def calculate_policy_loss(self, thoughts, rewards):
|
| 1130 |
-
thought_log_probs = []
|
| 1131 |
-
for thought in thoughts:
|
| 1132 |
-
thought_log_prob = self.lm_head(thought).log_softmax(dim=-1)
|
| 1133 |
-
thought_log_probs.append(thought_log_prob)
|
| 1134 |
-
|
| 1135 |
-
thought_log_probs = torch.stack(thought_log_probs, dim=1) # (batch_size, num_thoughts, seq_length, vocab_size)
|
| 1136 |
-
thought_probs = torch.exp(thought_log_probs)
|
| 1137 |
-
|
| 1138 |
-
policy_loss = -torch.mean(thought_log_probs * rewards.unsqueeze(-1).unsqueeze(-1))
|
| 1139 |
-
|
| 1140 |
-
return policy_loss
|
| 1141 |
-
|
| 1142 |
def get_input_embeddings(self):
|
| 1143 |
return self.model.embed_tokens
|
| 1144 |
|
|
@@ -1214,13 +1208,11 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
| 1214 |
use_cache=use_cache,
|
| 1215 |
output_attentions=output_attentions,
|
| 1216 |
output_hidden_states=output_hidden_states,
|
| 1217 |
-
return_dict=True,
|
| 1218 |
)
|
| 1219 |
-
|
| 1220 |
hidden_states = outputs.last_hidden_state
|
| 1221 |
logits = self.lm_head(hidden_states)
|
| 1222 |
|
| 1223 |
-
|
| 1224 |
thought_ids, thought_embeddings = self._generate_thoughts(hidden_states, max_length=self.config.thought_length)
|
| 1225 |
thought_hidden_states = self.model(inputs_embeds=thought_embeddings).last_hidden_state
|
| 1226 |
|
|
@@ -1230,7 +1222,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
| 1230 |
# Mix base and thought logits
|
| 1231 |
mixed_logits = logits.unsqueeze(1) + self.mixing_head(thought_logits)
|
| 1232 |
mixed_logits = mixed_logits.view(-1, mixed_logits.size(-1))
|
| 1233 |
-
|
| 1234 |
loss = None
|
| 1235 |
if labels is not None:
|
| 1236 |
# Shift so that tokens < n predict n
|
|
|
|
| 1098 |
self.remove_negative_rewards = True
|
| 1099 |
self.post_init()
|
| 1100 |
|
| 1101 |
+
def calculate_policy_loss(self, thoughts, rewards):
|
| 1102 |
+
thought_log_probs = []
|
| 1103 |
+
for thought in thoughts:
|
| 1104 |
+
thought_log_prob = self.lm_head(thought).log_softmax(dim=-1)
|
| 1105 |
+
thought_log_probs.append(thought_log_prob)
|
| 1106 |
+
|
| 1107 |
+
thought_log_probs = torch.stack(thought_log_probs, dim=1) # (batch_size, num_thoughts, seq_length, vocab_size)
|
| 1108 |
+
thought_probs = torch.exp(thought_log_probs)
|
| 1109 |
+
|
| 1110 |
+
policy_loss = -torch.mean(thought_log_probs * rewards.unsqueeze(-1).unsqueeze(-1))
|
| 1111 |
+
|
| 1112 |
+
return policy_loss
|
| 1113 |
+
|
| 1114 |
def _generate_thoughts(self, hidden_states, max_length):
|
| 1115 |
batch_size = hidden_states.size(0)
|
| 1116 |
thought_ids = torch.zeros((batch_size, self.config.max_thoughts, max_length), dtype=torch.long, device=hidden_states.device)
|
| 1117 |
thought_embeddings = []
|
| 1118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1119 |
for i in range(self.config.max_thoughts):
|
| 1120 |
thought_input_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=hidden_states.device)
|
| 1121 |
+
thought_outputs = self.generate(
|
| 1122 |
input_ids=thought_input_ids,
|
| 1123 |
max_length=max_length,
|
| 1124 |
do_sample=True,
|
|
|
|
| 1133 |
thought_embeddings = torch.stack(thought_embeddings, dim=1)
|
| 1134 |
return thought_ids, thought_embeddings
|
| 1135 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1136 |
def get_input_embeddings(self):
|
| 1137 |
return self.model.embed_tokens
|
| 1138 |
|
|
|
|
| 1208 |
use_cache=use_cache,
|
| 1209 |
output_attentions=output_attentions,
|
| 1210 |
output_hidden_states=output_hidden_states,
|
| 1211 |
+
return_dict=True,
|
| 1212 |
)
|
|
|
|
| 1213 |
hidden_states = outputs.last_hidden_state
|
| 1214 |
logits = self.lm_head(hidden_states)
|
| 1215 |
|
|
|
|
| 1216 |
thought_ids, thought_embeddings = self._generate_thoughts(hidden_states, max_length=self.config.thought_length)
|
| 1217 |
thought_hidden_states = self.model(inputs_embeds=thought_embeddings).last_hidden_state
|
| 1218 |
|
|
|
|
| 1222 |
# Mix base and thought logits
|
| 1223 |
mixed_logits = logits.unsqueeze(1) + self.mixing_head(thought_logits)
|
| 1224 |
mixed_logits = mixed_logits.view(-1, mixed_logits.size(-1))
|
| 1225 |
+
|
| 1226 |
loss = None
|
| 1227 |
if labels is not None:
|
| 1228 |
# Shift so that tokens < n predict n
|