# -*- coding: utf-8 -*-
# https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py
"""
# Get the per-token log probabilities for the completions for the model and the reference model
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
input_ids = input_ids[:, -logits_to_keep:]
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
# See https://github.com/huggingface/trl/issues/2770
logits = logits[:, -logits_to_keep:]
return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
if return_outputs:
raise ValueError("The GRPOTrainer does not support returning outputs")
# Compute the per-token log probabilities for the model
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
# Compute the KL divergence between the model and the reference model
ref_per_token_logps = inputs["ref_per_token_logps"]
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
# x - x.detach() allows for preserving gradients from x
advantages = inputs["advantages"]
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -(per_token_loss - self.beta * per_token_kl)
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
# Log the metrics
completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
self._metrics["completion_length"].append(completion_length)
mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
return loss
"""
import torch
import triton
import triton.language as tl
from fla.ops.utils.op import exp, log
from fla.utils import input_guard
@triton.autotune(
[triton.Config({'BLOCK_SIZE': BLOCK_SIZE}, num_warps=NUM_WARPS, num_stages=NUM_STAGES)
for BLOCK_SIZE in [1024, 2048, 4096, 8192]
for NUM_WARPS in [8, 16, 32]
for NUM_STAGES in [1, 2, 4]
], key=['B', 'N']
)
@triton.jit
def grpo_fwd_kernel(
logits_ptr,
ref_logp_ptr,
input_ids_ptr,
advantages_ptr,
completion_mask_ptr,
loss_ptr,
lse_ptr,
beta,
save_kl: tl.constexpr,
B,
M,
N,
L,
start_idx,
BLOCK_SIZE: tl.constexpr
):
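    # One program per completion token (grid = B * L). Each program computes the
    # log-sum-exp of its vocabulary row online, derives the selected token's
    # log-prob, and stores the per-token loss beta * kl - advantage (plus the kl
    # value itself when save_kl is set).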
row_idx = tl.program_id(0)
off_b = row_idx // L
N = tl.cast(N, tl.int64)
loss_ptr += row_idx
completion_mask_ptr += row_idx
not_skip = tl.load(completion_mask_ptr).to(tl.int1)
if not_skip == 1:
ref_logp_ptr += row_idx
lse_ptr += row_idx
advantages_ptr += off_b
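        # logits has shape (B, L+1, V): grid row `row_idx` maps to the flat row
        # `row_idx + off_b`, which skips the extra last time step of every earlier
        # batch element. The matching completion token id sits at flat offset
        # `row_idx + (off_b + 1) * start_idx` of input_ids (shape (B, K+L), K = start_idx).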
logits_ptr += N * (row_idx + off_b)
input_ids_ptr += row_idx + (off_b+1) * start_idx
base_cols = tl.arange(0, BLOCK_SIZE)
m_i = -float("inf")
l_i = 0.0
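        # Online (streaming) log-sum-exp over the vocabulary: m_i is the running
        # maximum, l_i the running sum of exp(logits - m_i).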
for start_n in tl.range(0, N, BLOCK_SIZE):
cols = start_n + base_cols
mask = cols < N
logits = tl.load(logits_ptr+cols, mask=mask, other=-float('inf')).to(tl.float32)
m_ij = tl.max(logits)
new_m_i = tl.maximum(m_i, m_ij)
l_i = l_i * exp(m_i - new_m_i) + tl.sum(exp(logits - new_m_i))
m_i = new_m_i
lse = log(l_i) + m_i
idx = tl.load(input_ids_ptr)
x = tl.load(logits_ptr+idx).to(tl.float32)
advantage = tl.load(advantages_ptr).to(tl.float32)
ref_logp = tl.load(ref_logp_ptr)
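        # KL estimate: exp(r) - r - 1 with r = ref_logp - logp (the "k3" estimator);
        # the stored per-token loss equals the forward value of the reference
        # implementation, beta * kl - advantage.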
logp = x - lse
diff = ref_logp - logp
kl = exp(diff) - diff - 1
loss = kl * beta - advantage
tl.store(loss_ptr, loss.to(loss_ptr.dtype.element_ty))
tl.store(lse_ptr, lse.to(lse_ptr.dtype.element_ty))
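        # With save_kl, loss is a (2*B, L) buffer: the KL values live M = B*L
        # elements after their corresponding loss values.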
if save_kl:
tl.store(loss_ptr+M, kl.to(loss_ptr.dtype.element_ty))
else:
# store 0
tl.store(loss_ptr, 0.0)
if save_kl:
tl.store(loss_ptr+M, 0.0)
@triton.autotune(
[triton.Config({'BLOCK_SIZE': BLOCK_SIZE}, num_warps=NUM_WARPS, num_stages=NUM_STAGES)
for BLOCK_SIZE in [1024, 2048, 4096, 8192]
for NUM_WARPS in [8, 16, 32]
for NUM_STAGES in [1, 2, 4]
], key=['B', 'N']
)
@triton.jit
def grpo_bwd_kernel(
dloss_ptr,
dlogits_ptr,
logits_ptr,
ref_logp_ptr,
input_ids_ptr,
advantages_ptr,
completion_mask_ptr,
lse_ptr,
beta,
B,
N,
L,
start_idx,
BLOCK_SIZE: tl.constexpr
):
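    # One program per completion token. The backward pass recomputes the selected
    # token's log-prob from the saved log-sum-exp, forms dlogp = d(loss)/d(logp),
    # and scatters it through the softmax Jacobian:
    # dlogits = (one_hot(idx) - softmax(logits)) * dlogp.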
row_idx = tl.program_id(0) # B*L
off_b = row_idx // L
N = tl.cast(N, tl.int64)
dlogits_ptr += N * (row_idx + off_b)
base_cols = tl.arange(0, BLOCK_SIZE)
completion_mask_ptr += row_idx
not_skip = tl.load(completion_mask_ptr).to(tl.int1)
if not_skip == 1:
lse_ptr += row_idx
dloss_ptr += row_idx
advantages_ptr += off_b
ref_logp_ptr += row_idx
logits_ptr += N * (row_idx + off_b)
input_ids_ptr += row_idx + (off_b+1) * start_idx
dloss = tl.load(dloss_ptr).to(tl.float32)
lse = tl.load(lse_ptr).to(tl.float32)
idx = tl.load(input_ids_ptr)
x = tl.load(logits_ptr+idx).to(tl.float32)
advantage = tl.load(advantages_ptr).to(tl.float32)
ref_logp = tl.load(ref_logp_ptr)
logp = x - lse
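        # d(loss)/d(logp) = beta * (1 - exp(ref_logp - logp)) - advantage, scaled by
        # the incoming gradient dloss.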
dlogp = (beta * (-1.0 * exp(ref_logp - logp) + 1)
- advantage) * dloss
for start_n in tl.range(0, N, BLOCK_SIZE):
cols = start_n + base_cols
mask = cols < N
logits = tl.load(logits_ptr+cols, mask=mask, other=-float('inf')).to(tl.float32)
probs = exp(logits - lse)
dlogits = tl.where(cols == idx, 1-probs, -probs) * dlogp
tl.store(dlogits_ptr+cols, dlogits.to(dlogits_ptr.dtype.element_ty), mask=mask)
else:
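        # Masked-out completion tokens receive zero gradient.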
dlogits = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
for start_n in tl.range(0, N, BLOCK_SIZE):
cols = start_n + base_cols
mask = cols < N
tl.store(dlogits_ptr+cols, dlogits.to(dlogits_ptr.dtype.element_ty), mask=mask)
class GrpoLoss(torch.autograd.Function):
@input_guard
@staticmethod
def forward(ctx, logits, ref_logp, input_ids, advantages, beta, completion_mask, save_kl):
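        # logits: (B, L+1, V). The final time step predicts a token beyond the
        # completion, so it is excluded from the loss (and zeroed in backward).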
ctx.input_shape = logits.shape
B, L_ADD_1, N = ctx.input_shape
L = L_ADD_1 - 1
M = B * L
input_ids_start_index = input_ids.size(1) - L
if not save_kl:
loss = torch.empty(B, L, device=logits.device, dtype=torch.float32)
else:
loss = torch.empty(B*2, L, device=logits.device, dtype=torch.float32)
lse = torch.empty(B, L, device=logits.device, dtype=torch.float32)
if completion_mask is None:
completion_mask = torch.ones(B, L, device=logits.device, dtype=torch.int32)
else:
loss[:B].masked_fill_(completion_mask.logical_not(), 0.0)
grpo_fwd_kernel[(M,)](
logits_ptr=logits,
ref_logp_ptr=ref_logp,
input_ids_ptr=input_ids,
advantages_ptr=advantages,
completion_mask_ptr=completion_mask,
loss_ptr=loss,
lse_ptr=lse,
beta=beta,
save_kl=save_kl,
B=B, M=M, N=N, L=L,
start_idx=input_ids_start_index,
)
ctx.beta = beta
ctx.save_for_backward(lse, logits, input_ids, advantages, completion_mask)
ctx.ref_logp = ref_logp
return loss
@input_guard
@staticmethod
def backward(ctx, dloss):
        # The gradient of the logits comes from two parts: the reward (advantage) term and the KL term
lse, logits, input_ids, advantages, completion_mask = ctx.saved_tensors
B, L_ADD_1, N = ctx.input_shape
L = L_ADD_1 - 1
M = B * L
input_ids_start_index = input_ids.size(1) - L
dlogits = torch.empty_like(logits) # B, L_ADD_1, N
grpo_bwd_kernel[(M,)](
dloss_ptr=dloss,
dlogits_ptr=dlogits,
logits_ptr=logits,
ref_logp_ptr=ctx.ref_logp,
input_ids_ptr=input_ids,
advantages_ptr=advantages,
completion_mask_ptr=completion_mask,
lse_ptr=lse,
beta=ctx.beta,
B=B, N=N, L=L,
start_idx=input_ids_start_index,
)
# The last token in the completion is not used in the loss computation
# and therefore its gradient should be set to 0
dlogits[:, -1, :].fill_(0.0)
return dlogits.view(*ctx.input_shape), None, None, None, None, None, None
def fused_grpo_loss(logits, ref_logp, input_ids, advantages, beta=0.1, completion_mask=None, save_kl=False) -> torch.Tensor:
    '''
    Compute the GRPO loss with a fused Triton kernel; it saves memory (no additional usage)
    and is roughly 6x faster on an A800.
    Args:
        logits: Tensor, [B, L+1, vocab_size], the raw model output; do NOT slice it to logits[:, :-1]
        ref_logp: Tensor, [B, L], the per-token log-probs of the reference model (not raw ref_logits)
        input_ids: Tensor, [B, K+L], the prompt_completion_ids, i.e. prompt ids followed by completion ids
        advantages: Tensor, [B], the advantage of each prompt
        beta: float, the weight of the KL loss
        completion_mask: Tensor, [B, L], loss mask over the completion tokens
        save_kl: bool, if True the per-token KL is returned as well
    Returns:
        loss: Tensor, [B, L], the GRPO loss, containing both the advantage part and the KL part
    NOTE: logits (and ref_logits) are computed as follows:
        logits_to_keep = completion_ids.size(1)
        def get_per_token_logits(model, input_ids, attention_mask, logits_to_keep):
            # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
            logits = model(
                input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1
            ).logits
            return logits
        logits = get_per_token_logits(model, prompt_completion_ids, attention_mask, logits_to_keep)
    '''
out = GrpoLoss.apply(logits, ref_logp, input_ids, advantages, beta, completion_mask, save_kl)
if not save_kl:
return out
else:
        return out.chunk(2, dim=0)
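# Eager PyTorch reference mirroring the TRL computation quoted in the module
# docstring; handy for checking the fused Triton path numerically.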
def grpo_loss_torch(logits, ref_logp, input_ids, advantages, beta=0.1, completion_mask=None, save_kl=False):
def get_log_probs(logits, input_ids):
per_token_logps = []
for logits_row, input_ids_row in zip(logits, input_ids[:, -logits.size(1):]):
log_probs = logits_row.log_softmax(dim=-1)
token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
per_token_logps.append(token_log_prob)
return torch.stack(per_token_logps)
logits = logits[:, :-1]
per_token_logps = get_log_probs(logits, input_ids)
ref_per_token_logps = ref_logp
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -(per_token_loss - beta * per_token_kl)
if completion_mask is not None:
per_token_loss *= completion_mask
if save_kl:
per_token_kl *= completion_mask
return per_token_loss if not save_kl else (per_token_loss, per_token_kl)
@torch.compile(fullgraph=True)
def grpo_loss_with_old_logps(
logps: torch.Tensor,
ref_logps: torch.Tensor,
old_logps: torch.Tensor,
pad_mask: torch.Tensor,
logits_to_keep: int,
rewards: torch.Tensor,
beta: float = 0.2,
epsilon: float = 0.2
):
"""
Compute the GRPO (Group Relative Policy Optimization) loss.
Args:
logps (torch.Tensor): [Batch, Token_length] Log probabilities of the current policy.
ref_logps (torch.Tensor):[Batch, Token_length] Log probabilities of the reference policy.
old_logps (torch.Tensor): [Batch, Token_length] Log probabilities of the old policy.
        pad_mask (torch.Tensor): [Batch, Token_length] Boolean mask, True for non-padded completion tokens.
        logits_to_keep (int): Number of completion logits kept, used to build the completion mask.
        rewards (torch.Tensor): [Batch] Rewards for each generation.
        beta (float, optional): Weight of the KL divergence term. Defaults to 0.2.
        epsilon (float, optional): Clipping range for the importance weights. Defaults to 0.2.
Returns:
torch.Tensor: The computed GRPO loss.
"""
B = logps.shape[0]
assert B > 1, "Batch * Num generations should be greater than 1"
rewards_shaped = rewards.view(-1, B) # B,num_generations
advantages = (rewards_shaped - rewards_shaped.mean(dim=1, keepdim=True)) / \
(rewards_shaped.std(dim=1, keepdim=True) + 1e-8)
advantages = advantages.view(-1) # B*num_generations
    # Calculate the per-token KL divergence
per_token_kl = torch.exp(ref_logps - logps) - (ref_logps - logps) - 1
# Calculate the ratio of probabilities (importance weights)
# Importance weights are calculated as exp(log_pi_theta - log_pi_theta_old)
importance_weights = torch.exp(logps - old_logps)
# Clip the importance weights to the range [1 - epsilon, 1 + epsilon]
importance_weights_clipped = torch.clamp(importance_weights, 1 - epsilon, 1 + epsilon)
# Create a completion mask. It checks which positions are valid based on logits_to_keep
completion_mask = torch.arange(logits_to_keep, device=logps.device)[None, :] >= 0
# Combine the completion mask and padding mask
completion_mask = completion_mask & pad_mask # Ensure matching shape
    # Add an extra dimension to advantages to match the shape for element-wise multiplication
advantages = advantages.unsqueeze(1)
    # Calculate the per-token loss. It takes the minimum of the unclipped and clipped importance weights
    # and subtracts the KL divergence term weighted by beta, then multiplies by the completion mask
token_loss = -(torch.min(advantages * importance_weights, advantages *
importance_weights_clipped) - beta * per_token_kl) * completion_mask
    # Calculate the final loss by summing the per-token losses and normalizing by the number of valid tokens
    # (token_loss already carries the leading minus sign, so no extra negation is needed here)
    loss = token_loss.sum() / completion_mask.sum()
return loss
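

# A minimal smoke test (a sketch, not part of the original module): it compares the
# fused Triton loss and its gradient against the eager reference `grpo_loss_torch`
# on random tensors. The sizes below are assumed toy values.
if __name__ == "__main__":
    if torch.cuda.is_available():
        torch.manual_seed(0)
        B, K, L, V = 4, 8, 16, 1024  # batch, prompt length, completion length, vocab size (assumed)
        device = "cuda"
        logits = torch.randn(B, L + 1, V, device=device, requires_grad=True)
        ref_logp = -torch.rand(B, L, device=device)            # stand-in reference log-probs
        input_ids = torch.randint(0, V, (B, K + L), device=device)
        advantages = torch.randn(B, device=device)
        completion_mask = torch.ones(B, L, device=device, dtype=torch.int32)
        completion_mask[:, -2:] = 0                            # mask a few trailing tokens

        loss_fused = fused_grpo_loss(logits, ref_logp, input_ids, advantages,
                                     beta=0.1, completion_mask=completion_mask)
        loss_ref = grpo_loss_torch(logits, ref_logp, input_ids, advantages,
                                   beta=0.1, completion_mask=completion_mask)
        print("max |loss_fused - loss_ref|:", (loss_fused - loss_ref).abs().max().item())

        grad_out = torch.randn_like(loss_fused)
        dlogits_fused, = torch.autograd.grad(loss_fused, logits, grad_out, retain_graph=True)
        dlogits_ref, = torch.autograd.grad(loss_ref, logits, grad_out)
        print("max |dlogits_fused - dlogits_ref|:", (dlogits_fused - dlogits_ref).abs().max().item())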