|
|
|
|
|
|
|
""" |
|
# Get the per-token log probabilities for the completions for the model and the reference model |
|
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): |
|
        # We add 1 to `logits_to_keep` because the last logit of the sequence is later excluded
|
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits |
|
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred |
|
|
|
input_ids = input_ids[:, -logits_to_keep:] |
|
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. |
|
# See https://github.com/huggingface/trl/issues/2770 |
|
logits = logits[:, -logits_to_keep:] |
|
return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens |
|
|
|
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): |
|
if return_outputs: |
|
raise ValueError("The GRPOTrainer does not support returning outputs") |
|
# Compute the per-token log probabilities for the model |
|
|
|
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] |
|
completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] |
|
input_ids = torch.cat([prompt_ids, completion_ids], dim=1) |
|
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) |
|
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens |
|
|
|
per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) |
|
|
|
# Compute the KL divergence between the model and the reference model |
|
ref_per_token_logps = inputs["ref_per_token_logps"] |
|
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 |
|
|
|
# x - x.detach() allows for preserving gradients from x |
|
advantages = inputs["advantages"] |
|
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) |
|
per_token_loss = -(per_token_loss - self.beta * per_token_kl) |
|
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() |
|
|
|
# Log the metrics |
|
completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() |
|
self._metrics["completion_length"].append(completion_length) |
|
|
|
mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() |
|
self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item()) |
|
|
|
return loss |
|
""" |
|
|
|
|
|
import torch |
|
import triton |
|
import triton.language as tl |
|
|
|
from fla.ops.utils.op import exp, log |
|
from fla.utils import input_guard |
|
|
|
|
|
@triton.autotune( |
|
[triton.Config({'BLOCK_SIZE': BLOCK_SIZE}, num_warps=NUM_WARPS, num_stages=NUM_STAGES) |
|
for BLOCK_SIZE in [1024, 2048, 4096, 8192] |
|
for NUM_WARPS in [8, 16, 32] |
|
for NUM_STAGES in [1, 2, 4] |
|
], key=['B', 'N'] |
|
) |
|
@triton.jit |
|
def grpo_fwd_kernel( |
|
logits_ptr, |
|
ref_logp_ptr, |
|
input_ids_ptr, |
|
advantages_ptr, |
|
completion_mask_ptr, |
|
loss_ptr, |
|
lse_ptr, |
|
beta, |
|
save_kl: tl.constexpr, |
|
B, |
|
M, |
|
N, |
|
L, |
|
start_idx, |
|
BLOCK_SIZE: tl.constexpr |
|
): |
|
row_idx = tl.program_id(0) |
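    # one program per completion token: row_idx = batch_index * L + token_position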
|
|
|
off_b = row_idx // L |
|
N = tl.cast(N, tl.int64) |
|
|
|
loss_ptr += row_idx |
|
|
|
completion_mask_ptr += row_idx |
|
not_skip = tl.load(completion_mask_ptr).to(tl.int1) |
|
if not_skip == 1: |
|
ref_logp_ptr += row_idx |
|
lse_ptr += row_idx |
|
advantages_ptr += off_b |
|
logits_ptr += N * (row_idx + off_b) |
|
input_ids_ptr += row_idx + (off_b+1) * start_idx |
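        # logits has L+1 positions per sequence, hence the off_b extra rows skipped per batch;
        # the completion tokens in input_ids start after start_idx prompt tokens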
|
base_cols = tl.arange(0, BLOCK_SIZE) |
|
|
|
m_i = -float("inf") |
|
l_i = 0.0 |
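        # streaming logsumexp over the vocabulary, one BLOCK_SIZE tile at a time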
|
for start_n in tl.range(0, N, BLOCK_SIZE): |
|
cols = start_n + base_cols |
|
mask = cols < N |
|
logits = tl.load(logits_ptr+cols, mask=mask, other=-float('inf')).to(tl.float32) |
|
m_ij = tl.max(logits) |
|
new_m_i = tl.maximum(m_i, m_ij) |
|
l_i = l_i * exp(m_i - new_m_i) + tl.sum(exp(logits - new_m_i)) |
|
m_i = new_m_i |
|
lse = log(l_i) + m_i |
|
|
|
idx = tl.load(input_ids_ptr) |
|
x = tl.load(logits_ptr+idx).to(tl.float32) |
|
advantage = tl.load(advantages_ptr).to(tl.float32) |
|
ref_logp = tl.load(ref_logp_ptr) |
|
logp = x - lse |
|
diff = ref_logp - logp |
|
kl = exp(diff) - diff - 1 |
|
loss = kl * beta - advantage |
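        # in the forward pass exp(logp - logp.detach()) == 1, so the policy term reduces to the advantage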
|
|
|
tl.store(loss_ptr, loss.to(loss_ptr.dtype.element_ty)) |
|
tl.store(lse_ptr, lse.to(lse_ptr.dtype.element_ty)) |
|
if save_kl: |
|
tl.store(loss_ptr+M, kl.to(loss_ptr.dtype.element_ty)) |
|
else: |
|
|
|
tl.store(loss_ptr, 0.0) |
|
if save_kl: |
|
tl.store(loss_ptr+M, 0.0) |
|
|
|
|
|
@triton.autotune( |
|
[triton.Config({'BLOCK_SIZE': BLOCK_SIZE}, num_warps=NUM_WARPS, num_stages=NUM_STAGES) |
|
for BLOCK_SIZE in [1024, 2048, 4096, 8192] |
|
for NUM_WARPS in [8, 16, 32] |
|
for NUM_STAGES in [1, 2, 4] |
|
], key=['B', 'N'] |
|
) |
|
@triton.jit |
|
def grpo_bwd_kernel( |
|
dloss_ptr, |
|
dlogits_ptr, |
|
logits_ptr, |
|
ref_logp_ptr, |
|
input_ids_ptr, |
|
advantages_ptr, |
|
completion_mask_ptr, |
|
lse_ptr, |
|
beta, |
|
B, |
|
N, |
|
L, |
|
start_idx, |
|
BLOCK_SIZE: tl.constexpr |
|
): |
|
|
|
row_idx = tl.program_id(0) |
|
off_b = row_idx // L |
|
|
|
N = tl.cast(N, tl.int64) |
|
|
|
dlogits_ptr += N * (row_idx + off_b) |
|
base_cols = tl.arange(0, BLOCK_SIZE) |
|
completion_mask_ptr += row_idx |
|
not_skip = tl.load(completion_mask_ptr).to(tl.int1) |
|
|
|
if not_skip == 1: |
|
lse_ptr += row_idx |
|
dloss_ptr += row_idx |
|
advantages_ptr += off_b |
|
ref_logp_ptr += row_idx |
|
logits_ptr += N * (row_idx + off_b) |
|
input_ids_ptr += row_idx + (off_b+1) * start_idx |
|
dloss = tl.load(dloss_ptr).to(tl.float32) |
|
lse = tl.load(lse_ptr).to(tl.float32) |
|
idx = tl.load(input_ids_ptr) |
|
x = tl.load(logits_ptr+idx).to(tl.float32) |
|
advantage = tl.load(advantages_ptr).to(tl.float32) |
|
ref_logp = tl.load(ref_logp_ptr) |
|
logp = x - lse |
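        # gradient of the per-token loss w.r.t. logp: -advantage from the ratio term,
        # beta * (1 - exp(ref_logp - logp)) from the KL term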
|
|
|
        dlogp = (beta * (1.0 - exp(ref_logp - logp)) - advantage) * dloss
|
|
|
for start_n in tl.range(0, N, BLOCK_SIZE): |
|
cols = start_n + base_cols |
|
mask = cols < N |
|
logits = tl.load(logits_ptr+cols, mask=mask, other=-float('inf')).to(tl.float32) |
|
probs = exp(logits - lse) |
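            # d(logp)/d(logits) = onehot(target) - softmax(logits), scaled by the upstream dlogp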
|
dlogits = tl.where(cols == idx, 1-probs, -probs) * dlogp |
|
|
|
tl.store(dlogits_ptr+cols, dlogits.to(dlogits_ptr.dtype.element_ty), mask=mask) |
|
else: |
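        # fully masked rows contribute no loss, so their logits receive zero gradient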
|
dlogits = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) |
|
for start_n in tl.range(0, N, BLOCK_SIZE): |
|
cols = start_n + base_cols |
|
mask = cols < N |
|
|
|
tl.store(dlogits_ptr+cols, dlogits.to(dlogits_ptr.dtype.element_ty), mask=mask) |
|
|
|
|
|
class GrpoLoss(torch.autograd.Function): |
|
|
|
    @staticmethod

    @input_guard
|
def forward(ctx, logits, ref_logp, input_ids, advantages, beta, completion_mask, save_kl): |
|
ctx.input_shape = logits.shape |
|
B, L_ADD_1, N = ctx.input_shape |
|
L = L_ADD_1 - 1 |
|
M = B * L |
|
input_ids_start_index = input_ids.size(1) - L |
|
|
|
if not save_kl: |
|
loss = torch.empty(B, L, device=logits.device, dtype=torch.float32) |
|
else: |
|
loss = torch.empty(B*2, L, device=logits.device, dtype=torch.float32) |
|
|
|
lse = torch.empty(B, L, device=logits.device, dtype=torch.float32) |
|
|
|
if completion_mask is None: |
|
completion_mask = torch.ones(B, L, device=logits.device, dtype=torch.int32) |
|
else: |
|
loss[:B].masked_fill_(completion_mask.logical_not(), 0.0) |
|
|
|
grpo_fwd_kernel[(M,)]( |
|
logits_ptr=logits, |
|
ref_logp_ptr=ref_logp, |
|
input_ids_ptr=input_ids, |
|
advantages_ptr=advantages, |
|
completion_mask_ptr=completion_mask, |
|
loss_ptr=loss, |
|
lse_ptr=lse, |
|
beta=beta, |
|
save_kl=save_kl, |
|
B=B, M=M, N=N, L=L, |
|
start_idx=input_ids_start_index, |
|
) |
|
ctx.beta = beta |
|
ctx.save_for_backward(lse, logits, input_ids, advantages, completion_mask) |
|
ctx.ref_logp = ref_logp |
|
return loss |
|
|
|
    @staticmethod

    @input_guard
|
def backward(ctx, dloss): |
|
|
|
lse, logits, input_ids, advantages, completion_mask = ctx.saved_tensors |
|
B, L_ADD_1, N = ctx.input_shape |
|
L = L_ADD_1 - 1 |
|
M = B * L |
|
|
|
input_ids_start_index = input_ids.size(1) - L |
|
|
|
dlogits = torch.empty_like(logits) |
|
|
|
grpo_bwd_kernel[(M,)]( |
|
dloss_ptr=dloss, |
|
dlogits_ptr=dlogits, |
|
logits_ptr=logits, |
|
ref_logp_ptr=ctx.ref_logp, |
|
input_ids_ptr=input_ids, |
|
advantages_ptr=advantages, |
|
completion_mask_ptr=completion_mask, |
|
lse_ptr=lse, |
|
beta=ctx.beta, |
|
B=B, N=N, L=L, |
|
start_idx=input_ids_start_index, |
|
) |
|
|
|
|
|
        # The last position of each sequence predicts the token after the completion and is
        # excluded from the loss, so its logits receive no gradient.
        dlogits[:, -1, :].fill_(0.0)
|
return dlogits.view(*ctx.input_shape), None, None, None, None, None, None |
|
|
|
|
|
def fused_grpo_loss(logits, ref_logp, input_ids, advantages, beta=0.1, completion_mask=None, save_kl=False) -> torch.Tensor: |
|
''' |
|
    Compute the GRPO loss with a fused Triton kernel: no extra memory is allocated for the log-softmax, and it runs roughly 6x faster on an A800.
|
|
|
Args: |
|
        logits: Tensor, [B, L+1, vocab_size], the raw model output; pass it as-is, not logits[:, :-1]
|
        ref_logp: Tensor, [B, L], per-token log probabilities from the reference model, already gathered for the completion tokens
|
        input_ids: Tensor, [B, K+L], the prompt_completion_ids, i.e. prompt ids concatenated with completion ids
|
        advantages: Tensor, [B], the advantage of each sequence
|
        beta: float, the weight of the KL penalty term
|
        completion_mask: Tensor, [B, L], mask applied to the per-token loss (1 = keep, 0 = ignore)
|
        save_kl: bool, if True, the per-token KL divergence is returned alongside the loss
|
|
|
    Returns:
|
        loss: Tensor, [B, L], the per-token GRPO loss, containing both the advantage term and the KL term
|
|
|
    NOTE: logits (and ref_logits) should be computed as follows:
|
logits_to_keep = completion_ids.size(1) |
|
|
|
def get_per_token_logits(model, input_ids, attention_mask, logits_to_keep): |
|
        # We add 1 to `logits_to_keep` because the last logit of the sequence is later excluded
|
logits = model( |
|
input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1 |
|
).logits |
|
return logits |
|
|
|
logits = get_per_token_logits(model, prompt_completion_ids, attention_mask, logits_to_keep) |
|
''' |
|
out = GrpoLoss.apply(logits, ref_logp, input_ids, advantages, beta, completion_mask, save_kl) |
|
if not save_kl: |
|
return out |
|
else: |
|
        return out.chunk(2, dim=0)
|
|
|
|
|
def grpo_loss_torch(logits, ref_logp, input_ids, advantages, beta=0.1, completion_mask=None, save_kl=False): |
|
def get_log_probs(logits, input_ids): |
|
per_token_logps = [] |
|
for logits_row, input_ids_row in zip(logits, input_ids[:, -logits.size(1):]): |
|
log_probs = logits_row.log_softmax(dim=-1) |
|
token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1) |
|
per_token_logps.append(token_log_prob) |
|
return torch.stack(per_token_logps) |
|
|
|
logits = logits[:, :-1] |
|
per_token_logps = get_log_probs(logits, input_ids) |
|
ref_per_token_logps = ref_logp |
|
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 |
|
|
|
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) |
|
per_token_loss = -(per_token_loss - beta * per_token_kl) |
|
if completion_mask is not None: |
|
per_token_loss *= completion_mask |
|
if save_kl: |
|
per_token_kl *= completion_mask |
|
return per_token_loss if not save_kl else (per_token_loss, per_token_kl) |
|
|
|
|
|
@torch.compile(fullgraph=True) |
|
def grpo_loss_with_old_logps( |
|
logps: torch.Tensor, |
|
ref_logps: torch.Tensor, |
|
old_logps: torch.Tensor, |
|
pad_mask: torch.Tensor, |
|
logits_to_keep: int, |
|
rewards: torch.Tensor, |
|
beta: float = 0.2, |
|
epsilon: float = 0.2 |
|
): |
|
""" |
|
Compute the GRPO (Group Relative Policy Optimization) loss. |
|
|
|
Args: |
|
logps (torch.Tensor): [Batch, Token_length] Log probabilities of the current policy. |
|
ref_logps (torch.Tensor):[Batch, Token_length] Log probabilities of the reference policy. |
|
old_logps (torch.Tensor): [Batch, Token_length] Log probabilities of the old policy. |
|
        pad_mask (torch.Tensor): [Batch, Token_length] Boolean mask, True for non-padding completion tokens.
|
        logits_to_keep (int): Number of completion tokens the loss is computed over.
|
rewards (torch.Tensor): [Batch] Rewards for each generation. |
|
        beta (float, defaults to 0.2): Weight of the KL divergence term.
|
        epsilon (float, defaults to 0.2): Clipping range for the importance ratio.
|
|
|
Returns: |
|
torch.Tensor: The computed GRPO loss. |
|
""" |
|
B = logps.shape[0] |
|
assert B > 1, "Batch * Num generations should be greater than 1" |
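    # rewards are reshaped to (-1, B); with one reward per sequence this normalizes over the whole flattened batch as a single group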
|
|
|
rewards_shaped = rewards.view(-1, B) |
|
advantages = (rewards_shaped - rewards_shaped.mean(dim=1, keepdim=True)) / \ |
|
(rewards_shaped.std(dim=1, keepdim=True) + 1e-8) |
|
advantages = advantages.view(-1) |
|
|
|
per_token_kl = torch.exp(ref_logps - logps) - (ref_logps - logps) - 1 |
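    # k3 estimator of the KL divergence between the current policy and the reference policy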
|
|
|
|
|
|
|
importance_weights = torch.exp(logps - old_logps) |
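    # PPO-style probability ratio between the current and old policies; it is clipped to [1 - epsilon, 1 + epsilon] just below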
|
|
|
|
|
importance_weights_clipped = torch.clamp(importance_weights, 1 - epsilon, 1 + epsilon) |
|
|
|
|
|
completion_mask = torch.arange(logits_to_keep, device=logps.device)[None, :] >= 0 |
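    # the arange comparison is always True (indices are >= 0), so the effective mask is just pad_mask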
|
|
|
|
|
completion_mask = completion_mask & pad_mask |
|
|
|
|
|
advantages = advantages.unsqueeze(1) |
|
|
|
|
|
|
|
token_loss = -(torch.min(advantages * importance_weights, advantages * |
|
importance_weights_clipped) - beta * per_token_kl) * completion_mask |
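    # clipped surrogate objective with a KL penalty, restricted to completion tokens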
|
|
|
|
|
    loss = token_loss.sum() / completion_mask.sum()
|
|
|
return loss |
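

if __name__ == "__main__":
    # Minimal correctness sketch, not part of the library's test suite. It assumes a CUDA device is
    # available and compares the fused Triton path against the pure PyTorch reference
    # `grpo_loss_torch` on random data; tolerances are loose to absorb the different softmax paths.
    torch.manual_seed(0)
    device = 'cuda'
    B, K, L, V = 2, 5, 7, 128  # batch, prompt length, completion length, vocab size
    logits = torch.randn(B, L + 1, V, device=device, requires_grad=True)
    ref_logp = torch.randn(B, L, device=device)  # stands in for the reference model's log-probs
    input_ids = torch.randint(0, V, (B, K + L), device=device)
    advantages = torch.randn(B, device=device)
    mask = torch.ones(B, L, device=device, dtype=torch.int32)

    loss_fused = fused_grpo_loss(logits, ref_logp, input_ids, advantages, beta=0.1, completion_mask=mask)
    loss_ref = grpo_loss_torch(logits, ref_logp, input_ids, advantages, beta=0.1, completion_mask=mask)
    torch.testing.assert_close(loss_fused, loss_ref, rtol=1e-3, atol=1e-3)

    grad_out = torch.randn_like(loss_fused)
    dlogits_fused, = torch.autograd.grad(loss_fused, logits, grad_out)
    dlogits_ref, = torch.autograd.grad(loss_ref, logits, grad_out)
    torch.testing.assert_close(dlogits_fused, dlogits_ref, rtol=1e-3, atol=1e-3)
    print("fused_grpo_loss matches grpo_loss_torch")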
|
|