import torch
from torch import nn


class Sampler(nn.Module):
    """
    Optimized sampler that replaces per-sequence loops with vectorized operations.

    Performance optimizations:
    1. Batch processing instead of per-sequence loops, reducing Python loop overhead.
    2. PyTorch vectorized operations (e.g. torch.sort, torch.gather) for parallel computation.
    3. Mask operations that apply top-k filtering in one pass, avoiding per-sequence processing.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor, top_k: int = None):
        """
        Perform sampling with vectorized top-k filtering.

        Args:
            logits: Logits tensor with shape [batch_size, vocab_size].
            temperatures: Temperature parameters with shape [batch_size].
            top_k: Top-k value for filtering (uniform across all sequences).

        Returns:
            Sampled token IDs with shape [batch_size].
        """
        logits = logits.to(torch.float)
        greedy_tokens = logits.argmax(dim=-1)  # Greedy decoding result, used when temperature == 0
        logits.div_(temperatures.unsqueeze(dim=1))  # Apply temperature scaling
        # Apply uniform top-k filtering if top_k is provided
        if top_k is not None and top_k > 0:
            vocab_size = logits.size(-1)
            # Batch sorting for all sequences at once
            sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
            # Threshold for each sequence: the k-th largest logit
            k_value = min(top_k, vocab_size)  # Ensure k doesn't exceed the vocabulary size
            thresholds = sorted_logits[:, k_value - 1:k_value]  # Shape [batch_size, 1]
            thresholds = thresholds.expand(-1, vocab_size)  # Expand to match the logits shape
            # Keep only logits greater than or equal to the threshold
            mask = logits >= thresholds
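            # Worked example: with top_k=2 and row logits [3.0, 1.0, 2.5, 0.2], the sorted row is
            # [3.0, 2.5, 1.0, 0.2] and the threshold is 2.5, so only positions 0 and 2 stay finite
            # below (ties at the threshold are also kept, which may retain more than k tokens).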
            # Set logits outside the top-k to negative infinity
            logits = torch.where(mask, logits, torch.tensor(float('-inf'), device=logits.device))

        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Exponential-race (Gumbel-max style) trick: dividing each probability by independent
        # Exponential(1) noise and taking the argmax draws one token per row from the categorical
        # distribution without a Python loop or an explicit torch.multinomial call.
        sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
        # The greedy result overrides the sampled one wherever temperature == 0
        return torch.where(temperatures == 0, greedy_tokens, sample_tokens)
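

# Usage sketch (illustrative, not part of the original module): Sampler expects one logit row and
# one temperature per sequence; a temperature of 0 selects the greedy token for that row. The
# helper name, batch size, and vocabulary size below are assumptions made for this example.
def _sampler_usage_example():
    sampler = Sampler()
    logits = torch.randn(3, 32)                   # assumed batch_size=3, vocab_size=32
    temperatures = torch.tensor([0.0, 0.8, 1.2])  # 0.0 -> greedy decoding for the first row
    return sampler(logits, temperatures, top_k=10)  # one token ID per sequence, shape [3]
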

class RasSampler(nn.Module):
    """
    Optimized Repetition Aware Sampling (RAS) implementation.

    Performance optimizations:
    1. Vectorized nucleus sampling instead of a loop implementation, improving sampling efficiency.
    2. Tensor operations to calculate the repetition rate, reducing Python loop overhead.
    3. Optimized EOS handling logic, reducing unnecessary resampling.
    4. PyTorch vectorized operations for parallel computation.
    5. Batch processing for all sequences, dramatically improving throughput.
    6. Robust handling of sequences of any length, including empty sequences.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits: torch.Tensor, decoded_tokens_list: list,
                win_size: int = 10, tau_r: float = 0.1,
                top_p: float = 0.8, top_k: int = 25,
                eos_token: int = 6561, min_tokens: list[int] = None):
""" | |
Execute repetition-aware sampling using optimized vectorized operations with batch processing | |
Args: | |
logits: Input logits with shape [batch_size, vocab_size] | |
decoded_tokens_list: List of decoded tokens, each element is a token list for a batch | |
win_size: Window size for repetition detection (uniform across all batch items) | |
tau_r: Repetition threshold (uniform across all batch items) | |
top_p: Nucleus sampling probability threshold (uniform across all batch items) | |
top_k: Nucleus sampling top-k threshold (uniform across all batch items) | |
eos_token: End of sequence token ID (uniform across all batch items) | |
min_tokens: List of minimum tokens to generate before allowing EOS, one per batch item | |
Returns: | |
Selected token IDs | |
""" | |
batch_size = logits.size(0) | |
device = logits.device | |
result = torch.zeros(batch_size, dtype=torch.long, device=device) | |
# Set default values if not provided | |
if min_tokens is None: | |
min_tokens = [2] * batch_size | |
# Ensure min_tokens list has the correct length | |
assert len(min_tokens) == batch_size, f"min_tokens length {len(min_tokens)} != batch_size {batch_size}" | |
# Force continue decode first token | |
for i in range(batch_size): | |
if i < len(decoded_tokens_list) and len(decoded_tokens_list[i]) == 0: | |
logits[i, eos_token] = -float('inf') | |

        # 1. Nucleus (top-p / top-k) sampling for all sequences at once.
        # This can be fully batched because top_p and top_k are uniform across the batch.
        probs = torch.softmax(logits, dim=-1)
        sorted_probs, sorted_indices = probs.sort(dim=-1, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        # Top-p mask: keep sorted positions while the cumulative probability is still within top_p
        top_p_mask = cumulative_probs <= top_p
        # Top-k mask: the first top_k sorted positions are True
        top_k_mask = torch.zeros_like(top_p_mask)
        top_k_mask[:, :top_k] = True
        # Combine the masks
        mask = top_p_mask & top_k_mask
        # Ensure at least one token is selectable per sequence
        first_token_mask = torch.zeros_like(mask)
        first_token_mask[:, 0] = True
        mask = mask | first_token_mask
        # Renormalize the filtered distribution and sample positions in sorted order
        sample_probs = torch.where(mask, sorted_probs, torch.zeros_like(sorted_probs))
        sample_probs = sample_probs / sample_probs.sum(dim=-1, keepdim=True)
        sampled_indices = torch.multinomial(sample_probs, 1).squeeze(-1)
        # Map the sampled sorted positions back to vocabulary token IDs
        top_ids = torch.gather(sorted_indices, -1, sampled_indices.unsqueeze(-1)).squeeze(-1)
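        # Worked example of the mapping above: if sorted_indices for one row starts [57, 3, 912, ...]
        # and the multinomial draw picks sorted position 2, the emitted vocabulary ID is 912.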

        # 2. Check for repetition and fall back to random sampling where needed.
        # Collect the most recent win_size tokens for each sequence, tolerating empty or short histories.
        recent_tokens_list = []
        for i in range(batch_size):
            if i < len(decoded_tokens_list):
                tokens = decoded_tokens_list[i]
                if len(tokens) > 0:
                    start_idx = max(0, len(tokens) - win_size)
                    recent_tokens_list.append(tokens[start_idx:])
                else:
                    recent_tokens_list.append([])  # Empty list for an empty history
            else:
                recent_tokens_list.append([])  # Empty list for missing batch items

        # Only run repetition detection if at least one sequence has history
        if any(len(tokens) > 0 for tokens in recent_tokens_list):
            max_recent_len = max(len(tokens) for tokens in recent_tokens_list)
            if max_recent_len > 0:
                # Right-aligned padded tensor; the padding value -1 never matches a real token ID
                recent_tokens_tensor = torch.zeros((batch_size, max_recent_len), dtype=torch.long, device=device) - 1
                for i, tokens in enumerate(recent_tokens_list):
                    if len(tokens) > 0:
                        recent_tokens_tensor[i, -len(tokens):] = torch.tensor(tokens, device=device)
                # Count how often the sampled token appears in each sequence's recent window
                repetition_counts = torch.zeros(batch_size, device=device)
                for i in range(batch_size):
                    if len(recent_tokens_list[i]) > 0:
                        repetition_counts[i] = (recent_tokens_tensor[i] == top_ids[i]).sum()
                # Repetition rate = count / window length, avoiding division by zero for empty windows
                recent_lengths = torch.tensor([max(1, len(tokens)) for tokens in recent_tokens_list], device=device)
                repetition_rates = repetition_counts / recent_lengths
                # Sequences whose repetition rate reaches the threshold are resampled from the full distribution
                need_random = repetition_rates >= tau_r
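                # Worked example: with win_size=10, a full window in which the sampled token already
                # appears 3 times gives a repetition rate of 0.3, which exceeds the default tau_r=0.1
                # and triggers the random fallback below.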
                # Replace overly repetitive picks with a draw from the unfiltered distribution
                if need_random.any():
                    random_indices = torch.multinomial(probs[need_random], 1).squeeze(-1)
                    top_ids[need_random] = random_indices

        # 3. Handle EOS tokens.
        # Sequences shorter than their min_tokens budget must not emit EOS yet.
        ignore_eos_mask = torch.zeros(batch_size, dtype=torch.bool, device=device)
        for i in range(batch_size):
            if i < len(decoded_tokens_list):
                ignore_eos_mask[i] = len(decoded_tokens_list[i]) < min_tokens[i]
            else:
                ignore_eos_mask[i] = True  # Default to ignoring EOS for missing sequences
        is_eos_mask = top_ids == eos_token
        need_resample = ignore_eos_mask & is_eos_mask

        # Resample sequences that drew EOS too early
        if need_resample.any():
            max_trials = 100
            for attempt in range(max_trials):
                if not need_resample.any():
                    break
                # Draw replacement tokens for the sequences that still need resampling
                new_samples = torch.multinomial(probs[need_resample], 1).squeeze(-1)
                top_ids[need_resample] = new_samples
                # Update which sequences still need resampling
                is_eos_mask = top_ids == eos_token
                need_resample = ignore_eos_mask & is_eos_mask
            # If EOS tokens that should be ignored remain after max_trials, force a non-EOS token
            if need_resample.any():
                for i in range(batch_size):
                    if need_resample[i]:
                        # Take the second most likely token (or the first if the vocabulary has a single entry)
                        second_best_idx = 1 if sorted_indices.size(1) > 1 else 0
                        top_ids[i] = sorted_indices[i, second_best_idx]

        return top_ids
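

# Minimal end-to-end sketch, executed only when this file is run directly. The vocabulary size,
# batch size, and decoded-token histories below are illustrative assumptions rather than values
# from any particular model; eos_token keeps the default of 6561 defined above.
if __name__ == "__main__":
    torch.manual_seed(0)
    vocab_size = 8192  # assumed vocabulary size, larger than the default eos_token ID
    ras_sampler = RasSampler()
    logits = torch.randn(2, vocab_size)
    decoded_tokens_list = [[], [11, 42, 42, 42, 7]]  # first sequence has no history yet
    token_ids = ras_sampler(logits, decoded_tokens_list,
                            win_size=10, tau_r=0.1,
                            top_p=0.8, top_k=25,
                            min_tokens=[2, 2])
    print("RasSampler token ids:", token_ids.tolist())
    print("Sampler token ids:", _sampler_usage_example().tolist())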