|
|
""" |
|
|
Full definition of the LunarisCodex language model, all in this single file.

This is a refactored and simplified Llama-style model, created by adapting the
robust, industry-standard components of the `Instella` (OLMo) architecture into
a clean, minimal, and self-contained structure.

It has also been refactored to include KV caching for efficient inference.
|
|
|
|
|
Architectural Choices: |
|
|
- Pre-normalization using RMSNorm: Normalizes inputs to each layer rather than outputs, |
|
|
providing better gradient flow and training stability |
|
|
- Rotary Positional Embeddings (RoPE): Encodes position information directly into |
|
|
query/key vectors using rotation matrices in complex space |
|
|
- SwiGLU as the feed-forward network's activation function: Combines Swish activation |
|
|
with a gating mechanism for better performance than ReLU |
|
|
- Grouped-Query Attention (GQA): Reduces memory usage by sharing key/value heads |
|
|
across multiple query heads while maintaining performance |
|
|
- Tied input and output embedding weights: Reduces parameters by sharing the token |
|
|
embedding matrix with the final projection layer |
|
|
- KV Caching: Stores computed key/value pairs to avoid recomputation during generation |
|
|
""" |
|
|
|
|
|
import math |
|
|
from dataclasses import dataclass |
|
|
import inspect |
|
|
from typing import Optional, Tuple, List |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.nn import functional as F |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class LunarisCodexConfig: |
|
|
""" |
|
|
Configuration class for the LunarisCodex model. |
|
|
|
|
|
Args: |
|
|
d_model: Hidden dimension size (embedding dimension) |
|
|
n_layers: Number of transformer blocks |
|
|
n_heads: Number of attention heads for queries |
|
|
        n_kv_heads: Number of key/value heads (for GQA). Must evenly divide n_heads; if equal to n_heads, this is standard multi-head attention (MHA)
|
|
vocab_size: Size of the vocabulary |
|
|
multiple_of: Ensures FFN hidden dimension is a multiple of this (for efficiency) |
|
|
ffn_hidden_multiplier: Multiplier for FFN hidden dimension size |
|
|
max_seq_len: Maximum sequence length the model can handle |
|
|
rope_theta: Base frequency for RoPE (10000 is standard) |
|
|
dropout: Dropout probability for regularization |
|
|
""" |
|
|
d_model: int = 768 |
|
|
n_layers: int = 12 |
|
|
n_heads: int = 12 |
|
|
n_kv_heads: int = 12 |
|
|
vocab_size: int = 50257 |
|
|
multiple_of: int = 256 |
|
|
ffn_hidden_multiplier: float = 4.0 |
|
|
max_seq_len: int = 1024 |
|
|
rope_theta: float = 10000.0 |
|
|
dropout: float = 0.0 |
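
# Illustrative sketch (not a tuned preset): grouped-query attention is enabled by
# setting n_kv_heads to a divisor of n_heads, e.g.
#
#   gqa_config = LunarisCodexConfig(d_model=768, n_heads=12, n_kv_heads=4)
#
# so that every 3 query heads share one key/value head, shrinking the KV cache by 3x.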
|
|
|
|
|
|
|
|
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: |
|
|
""" |
|
|
Precomputes the rotary frequencies in complex number format for RoPE. |
|
|
|
|
|
RoPE works by rotating query and key vectors in pairs of dimensions using |
|
|
rotation matrices. In complex space, rotation by angle θ is multiplication |
|
|
by e^(iθ). We precompute these rotation factors for all positions and |
|
|
dimension pairs. |
|
|
|
|
|
Math behind RoPE: |
|
|
- For each dimension pair (d_i, d_i+1), we define a rotation frequency: 1/theta^(2i/dim) |
|
|
- At position t, the rotation angle is: t * frequency |
|
|
- The complex rotation factor is: e^(i * t * frequency) = cos(t*freq) + i*sin(t*freq) |
|
|
|
|
|
Args: |
|
|
dim: The head dimension (d_model // n_heads) |
|
|
end: Maximum sequence length to precompute for |
|
|
theta: Base frequency (typically 10000) |
|
|
|
|
|
Returns: |
|
|
Complex tensor of shape (end, dim//2) containing rotation factors |
|
|
""" |
|
|
|
|
|
|
|
|
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) |
|
|
|
|
|
|
|
|
t = torch.arange(end, device=freqs.device, dtype=torch.float32) |
|
|
|
|
|
|
|
|
freqs = torch.outer(t, freqs) |
|
|
|
|
|
|
|
|
|
|
|
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) |
|
|
return freqs_cis |
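

# Shape sanity check (illustrative sketch, using the default config above): with
# d_model=768 and n_heads=12, head_dim is 64, so
#
#   freqs_cis = precompute_freqs_cis(64, 1024)   # complex64 tensor of shape (1024, 32)
#
# i.e. one rotation factor per position and per dimension pair.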
|
|
|
|
|
|
|
|
def apply_rotary_emb( |
|
|
xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor |
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
""" |
|
|
Applies rotary positional embeddings to query and key tensors. |
|
|
|
|
|
RoPE encodes position by rotating the query and key vectors in pairs of |
|
|
dimensions. This is done by treating consecutive pairs as complex numbers |
|
|
and multiplying by the precomputed rotation factors. |
|
|
|
|
|
Args: |
|
|
xq: Query tensor of shape (batch, heads, seq_len, head_dim) |
|
|
xk: Key tensor of shape (batch, heads, seq_len, head_dim) |
|
|
freqs_cis: Complex rotation factors of shape (seq_len, head_dim//2) |
|
|
|
|
|
Returns: |
|
|
Tuple of (rotated_queries, rotated_keys) with same shapes as input |
|
|
""" |
|
|
|
|
|
|
|
|
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) |
|
|
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) |
|
|
|
|
|
|
|
|
freqs_cis = freqs_cis.unsqueeze(0).unsqueeze(0) |
|
|
|
|
|
|
|
|
|
|
|
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) |
|
|
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) |
|
|
|
|
|
|
|
|
return xq_out.type_as(xq), xk_out.type_as(xk) |
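

# Usage sketch (shapes only; assumes batch=2, heads=12, seq_len=16, head_dim=64):
#
#   freqs = precompute_freqs_cis(64, 16)           # (16, 32), complex
#   q = torch.randn(2, 12, 16, 64)
#   k = torch.randn(2, 12, 16, 64)
#   q_rot, k_rot = apply_rotary_emb(q, k, freqs)   # same shapes as q and k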
|
|
|
|
|
|
|
|
class RMSNorm(nn.Module): |
|
|
""" |
|
|
Root Mean Square Layer Normalization. |
|
|
|
|
|
RMSNorm normalizes by the RMS (root mean square) of the input rather than |
|
|
mean and variance like LayerNorm. This is more stable and efficient. |
|
|
|
|
|
Formula: RMSNorm(x) = x / sqrt(mean(x²) + eps) * weight |
|
|
|
|
|
Why upcast to float32: Mixed precision training uses float16 for speed, |
|
|
but normalization operations need higher precision to avoid numerical |
|
|
instability. We compute in float32 then cast back. |
|
|
""" |
|
|
|
|
|
def __init__(self, dim: int, eps: float = 1e-5): |
|
|
""" |
|
|
Args: |
|
|
dim: Input dimension to normalize |
|
|
eps: Small constant for numerical stability |
|
|
""" |
|
|
super().__init__() |
|
|
self.eps = eps |
|
|
self.weight = nn.Parameter(torch.ones(dim)) |
|
|
|
|
|
def _norm(self, x: torch.Tensor): |
|
|
""" |
|
|
Compute RMS normalization. |
|
|
|
|
|
RMS = sqrt(mean(x²)) provides a measure of the magnitude of x. |
|
|
We multiply by the reciprocal (rsqrt) for efficiency. |
|
|
""" |
|
|
|
|
|
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) |
|
|
|
|
|
def forward(self, x: torch.Tensor): |
|
|
""" |
|
|
Apply RMSNorm with mixed precision support. |
|
|
|
|
|
The forward pass is stable with mixed-precision training by computing |
|
|
the normalization in float32 and then casting back to the input dtype. |
|
|
""" |
|
|
output_dtype = x.dtype |
|
|
x = self._norm(x.float()).to(output_dtype) |
|
|
return x * self.weight |
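

# Minimal sketch: RMSNorm preserves the input shape and, before the learned scale,
# rescales each feature vector to approximately unit RMS:
#
#   norm = RMSNorm(768)
#   y = norm(torch.randn(2, 16, 768))   # y.shape == (2, 16, 768)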
|
|
|
|
|
|
|
|
class Attention(nn.Module): |
|
|
""" |
|
|
Grouped-Query Attention module with KV Caching. |
|
|
|
|
|
GQA reduces memory usage by having fewer key/value heads than query heads. |
|
|
Multiple query heads share the same key/value heads, reducing the KV cache size |
|
|
while maintaining most of the performance of full multi-head attention. |
|
|
|
|
|
KV Caching stores computed key/value pairs from previous tokens to avoid |
|
|
recomputation during autoregressive generation. |
|
|
""" |
|
|
|
|
|
def __init__(self, config: LunarisCodexConfig): |
|
|
""" |
|
|
Initialize the attention module. |
|
|
|
|
|
Args: |
|
|
config: Model configuration containing attention parameters |
|
|
""" |
|
|
super().__init__() |
|
|
        assert config.d_model % config.n_heads == 0
        assert config.n_heads % config.n_kv_heads == 0, "n_kv_heads must evenly divide n_heads"
|
|
self.n_heads = config.n_heads |
|
|
self.n_kv_heads = config.n_kv_heads |
|
|
self.head_dim = config.d_model // config.n_heads |
|
|
|
|
|
|
|
|
self.q_proj = nn.Linear(config.d_model, config.n_heads * self.head_dim, bias=False) |
|
|
self.k_proj = nn.Linear(config.d_model, config.n_kv_heads * self.head_dim, bias=False) |
|
|
self.v_proj = nn.Linear(config.d_model, config.n_kv_heads * self.head_dim, bias=False) |
|
|
self.o_proj = nn.Linear(config.n_heads * self.head_dim, config.d_model, bias=False) |
|
|
self.dropout = nn.Dropout(config.dropout) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
x: torch.Tensor, |
|
|
freqs_cis: torch.Tensor, |
|
|
past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
|
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: |
|
|
""" |
|
|
Forward pass of the attention mechanism. |
|
|
|
|
|
Args: |
|
|
x: Input tensor of shape (batch, seq_len, d_model) |
|
|
freqs_cis: RoPE rotation factors |
|
|
past_kv: Cached key/value pairs from previous tokens (for generation) |
|
|
|
|
|
Returns: |
|
|
Tuple of (attention_output, new_kv_cache) |
|
|
""" |
|
|
B, T, C = x.shape |
|
|
|
|
|
|
|
|
q = self.q_proj(x) |
|
|
k = self.k_proj(x) |
|
|
v = self.v_proj(x) |
|
|
|
|
|
|
|
|
q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) |
|
|
k = k.view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) |
|
|
v = v.view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) |
|
|
|
|
|
|
|
|
|
|
|
q, k = apply_rotary_emb(q, k, freqs_cis) |
|
|
|
|
|
|
|
|
        # Append the new keys/values to the cache along the sequence dimension and
        # return the updated cache so the caller can reuse it on the next step.
        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat((past_k, k), dim=-2)
            v = torch.cat((past_v, v), dim=-2)
        present_kv = (k, v)
|
|
|
|
|
|
|
|
|
|
|
        # GQA: replicate each key/value head so that every query head has a matching
        # key/value head before calling scaled_dot_product_attention.
        if self.n_kv_heads < self.n_heads:
            n_repeats = self.n_heads // self.n_kv_heads
            k = k.repeat_interleave(n_repeats, dim=1)
            v = v.repeat_interleave(n_repeats, dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        # A causal mask is only needed during prefill (no cache). During cached decoding
        # the single new query token may attend to every cached position, so no mask is
        # required (this assumes T == 1 whenever past_kv is provided, as in `generate`).
        is_causal = past_kv is None
        y = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
|
|
|
|
|
|
|
|
y = y.transpose(1, 2).contiguous().view(B, T, C) |
|
|
y = self.dropout(self.o_proj(y)) |
|
|
|
|
|
return y, present_kv |
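

# KV-cache shape sketch (illustrative, using the default config: 12 KV heads, head_dim 64):
# after prefilling a prompt of T tokens, each layer's cache entry is a pair of tensors
#
#   k, v : (batch, n_kv_heads, T, head_dim)   e.g. (1, 12, T, 64)
#
# and each subsequent decode step concatenates one more position along dim=-2.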
|
|
|
|
|
|
|
|
class FeedForward(nn.Module): |
|
|
""" |
|
|
SwiGLU Feed-Forward Network. |
|
|
|
|
|
    SwiGLU combines the Swish activation function with a gating mechanism:

        FFN(x) = W2( Swish(W1 x) ⊙ (W3 x) )

    where ⊙ is element-wise multiplication.
|
|
|
|
|
This provides better performance than ReLU-based FFNs by: |
|
|
1. Swish activation: smoother than ReLU, better gradient flow |
|
|
2. Gating mechanism: allows the network to control information flow |
|
|
""" |
|
|
|
|
|
def __init__(self, config: LunarisCodexConfig): |
|
|
""" |
|
|
Initialize the feed-forward network. |
|
|
|
|
|
The hidden dimension is calculated as: |
|
|
1. Start with d_model * ffn_hidden_multiplier |
|
|
2. Adjust for SwiGLU (multiply by 2/3) |
|
|
3. Round up to nearest multiple of 'multiple_of' for efficiency |
|
|
""" |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
hidden_dim = int(config.ffn_hidden_multiplier * config.d_model) |
|
|
hidden_dim = int(2 * hidden_dim / 3) |
|
|
|
|
|
hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of) |
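
        # Worked example with the defaults (d_model=768, ffn_hidden_multiplier=4.0, multiple_of=256):
        # 768 * 4.0 = 3072  ->  2 * 3072 / 3 = 2048  ->  already a multiple of 256  ->  hidden_dim = 2048.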
|
|
|
|
|
|
|
|
self.w1 = nn.Linear(config.d_model, hidden_dim, bias=False) |
|
|
self.w3 = nn.Linear(config.d_model, hidden_dim, bias=False) |
|
|
self.w2 = nn.Linear(hidden_dim, config.d_model, bias=False) |
|
|
self.dropout = nn.Dropout(config.dropout) |
|
|
|
|
|
def forward(self, x: torch.Tensor): |
|
|
""" |
|
|
Apply SwiGLU activation. |
|
|
|
|
|
        Formula: output = W2( Swish(W1(x)) ⊙ W3(x) )
        where Swish(x) = x * sigmoid(x), implemented here by F.silu.
|
|
""" |
|
|
|
|
|
|
|
|
swiglu = F.silu(self.w1(x)) * self.w3(x) |
|
|
return self.dropout(self.w2(swiglu)) |
|
|
|
|
|
|
|
|
class Block(nn.Module): |
|
|
""" |
|
|
A single Transformer block using pre-normalization architecture. |
|
|
|
|
|
Pre-normalization (used here) vs Post-normalization: |
|
|
    - Pre-norm:  Norm → Attention → Add, then Norm → FFN → Add
    - Post-norm: Attention → Add → Norm, then FFN → Add → Norm
|
|
|
|
|
Pre-normalization provides better gradient flow and training stability |
|
|
because the residual connections carry the original gradient directly. |
|
|
""" |
|
|
|
|
|
def __init__(self, config: LunarisCodexConfig): |
|
|
super().__init__() |
|
|
self.attention = Attention(config) |
|
|
self.feed_forward = FeedForward(config) |
|
|
self.attention_norm = RMSNorm(config.d_model) |
|
|
self.ffn_norm = RMSNorm(config.d_model) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
x: torch.Tensor, |
|
|
freqs_cis: torch.Tensor, |
|
|
past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
|
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: |
|
|
""" |
|
|
Forward pass of the transformer block. |
|
|
|
|
|
Architecture: Pre-norm with residual connections |
|
|
1. x + Attention(RMSNorm(x)) |
|
|
2. x + FFN(RMSNorm(x)) |
|
|
|
|
|
Args: |
|
|
x: Input tensor |
|
|
freqs_cis: RoPE rotation factors |
|
|
past_kv: KV cache from previous tokens |
|
|
|
|
|
Returns: |
|
|
Tuple of (block_output, updated_kv_cache) |
|
|
""" |
|
|
|
|
|
|
|
|
attn_output, new_kv = self.attention(self.attention_norm(x), freqs_cis, past_kv) |
|
|
h = x + attn_output |
|
|
|
|
|
|
|
|
out = h + self.feed_forward(self.ffn_norm(h)) |
|
|
|
|
|
return out, new_kv |
|
|
|
|
|
|
|
|
class LunarisCodex(nn.Module): |
|
|
""" |
|
|
Complete LunarisCodex Language Model. |
|
|
|
|
|
This is a Llama-style decoder-only transformer with: |
|
|
- Pre-normalization architecture for better training stability |
|
|
- RoPE for positional encoding |
|
|
- SwiGLU activation in FFN |
|
|
- Grouped-Query Attention for efficiency |
|
|
- KV caching for fast inference |
|
|
- Weight tying between input embeddings and output projection |
|
|
""" |
|
|
|
|
|
def __init__(self, config: LunarisCodexConfig): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
|
|
|
|
|
|
self.transformer = nn.ModuleDict(dict( |
|
|
wte = nn.Embedding(config.vocab_size, config.d_model), |
|
|
h = nn.ModuleList([Block(config) for _ in range(config.n_layers)]), |
|
|
ln_f = RMSNorm(config.d_model), |
|
|
drop = nn.Dropout(config.dropout), |
|
|
)) |
|
|
|
|
|
|
|
|
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.transformer.wte.weight = self.lm_head.weight |
|
|
|
|
|
|
|
|
|
|
|
freqs_cis = precompute_freqs_cis( |
|
|
self.config.d_model // self.config.n_heads, |
|
|
self.config.max_seq_len, |
|
|
self.config.rope_theta, |
|
|
) |
|
|
self.register_buffer("freqs_cis", freqs_cis, persistent=False) |
|
|
|
|
|
|
|
|
self.apply(self._init_weights) |
|
|
|
|
|
|
|
|
print(f"Number of parameters: {self.get_num_params()/1e6:.2f}M") |
|
|
|
|
|
def get_num_params(self) -> int: |
|
|
"""Count the number of trainable parameters.""" |
|
|
return sum(p.numel() for p in self.parameters() if p.requires_grad) |
|
|
|
|
|
def _init_weights(self, module): |
|
|
""" |
|
|
Initialize model weights using scaled initialization. |
|
|
|
|
|
Standard initialization for most weights, with special scaled initialization |
|
|
for residual projections to prevent activation variance from growing with depth. |
|
|
""" |
|
|
if isinstance(module, nn.Linear): |
|
|
|
|
|
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) |
|
|
if module.bias is not None: |
|
|
torch.nn.init.zeros_(module.bias) |
|
|
elif isinstance(module, nn.Embedding): |
|
|
|
|
|
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) |
|
|
|
|
|
|
|
|
|
|
|
if isinstance(module, (Attention, FeedForward)): |
|
|
for name, p in module.named_parameters(): |
|
|
if name.endswith("o_proj.weight") or name.endswith("w2.weight"): |
|
|
|
|
|
torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layers)) |
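
    # Worked example (defaults, n_layers=12): residual projections are initialized with
    # std = 0.02 / sqrt(2 * 12) ≈ 0.0041, which keeps the variance of the residual stream
    # roughly constant as depth grows.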
|
|
|
|
|
def forward( |
|
|
self, |
|
|
idx: torch.Tensor, |
|
|
targets: Optional[torch.Tensor] = None, |
|
|
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, |
|
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]: |
|
|
""" |
|
|
Forward pass of the model. |
|
|
|
|
|
Args: |
|
|
idx: Input token indices of shape (batch, seq_len) |
|
|
targets: Target token indices for training (optional) |
|
|
past_key_values: KV cache from previous forward passes (for generation) |
|
|
|
|
|
Returns: |
|
|
Tuple of (logits, loss, new_kv_cache) |
|
|
            - logits: Unnormalized scores over the vocabulary: shape (B, T, vocab_size)
              when targets are given, otherwise (B, 1, vocab_size) for the last position only
|
|
- loss: Cross-entropy loss (only if targets provided) |
|
|
- new_kv_cache: Updated KV cache for next iteration |
|
|
""" |
|
|
B, T = idx.shape |
|
|
|
|
|
|
|
|
|
|
|
start_pos = past_key_values[0][0].shape[-2] if past_key_values is not None else 0 |
|
|
|
|
|
|
|
|
assert start_pos + T <= self.config.max_seq_len, \ |
|
|
f"Cannot forward, sequence length {start_pos + T} exceeds model's max_seq_len {self.config.max_seq_len}" |
|
|
|
|
|
|
|
|
x = self.transformer.wte(idx) |
|
|
x = self.transformer.drop(x) |
|
|
|
|
|
|
|
|
freqs_cis = self.freqs_cis[start_pos : start_pos + T] |
|
|
|
|
|
|
|
|
new_past_key_values = [] |
|
|
for i, block in enumerate(self.transformer.h): |
|
|
|
|
|
past_kv_for_block = past_key_values[i] if past_key_values is not None else None |
|
|
x, new_kv = block(x, freqs_cis, past_kv_for_block) |
|
|
new_past_key_values.append(new_kv) |
|
|
|
|
|
|
|
|
x = self.transformer.ln_f(x) |
|
|
|
|
|
|
|
|
if targets is not None: |
|
|
|
|
|
logits = self.lm_head(x) |
|
|
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) |
|
|
        else:
            # Inference-time optimization: only compute logits for the last position,
            # since that is the only token sampled during generation.
            logits = self.lm_head(x[:, [-1], :])
            loss = None
|
|
|
|
|
return logits, loss, new_past_key_values |
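
    # Call-pattern sketch (illustrative):
    #
    #   training :  logits, loss, _  = model(idx, targets=targets)   # logits: (B, T, vocab_size)
    #   prefill  :  logits, _, cache = model(prompt)                  # logits: (B, 1, vocab_size)
    #   decode   :  logits, _, cache = model(next_token, past_key_values=cache)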
|
|
|
|
|
def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): |
|
|
""" |
|
|
Configure the optimizer with weight decay applied only to 2D parameters. |
|
|
|
|
|
Weight decay is applied to matrices (2D tensors) but not to biases and |
|
|
layer norm parameters (1D tensors) for better training dynamics. |
|
|
""" |
|
|
|
|
|
param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad} |
|
|
|
|
|
|
|
|
decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] |
|
|
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] |
|
|
|
|
|
|
|
|
optim_groups = [ |
|
|
{'params': decay_params, 'weight_decay': weight_decay}, |
|
|
{'params': nodecay_params, 'weight_decay': 0.0} |
|
|
] |
|
|
|
|
|
|
|
|
num_decay_params = sum(p.numel() for p in decay_params) |
|
|
num_nodecay_params = sum(p.numel() for p in nodecay_params) |
|
|
print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") |
|
|
print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") |
|
|
|
|
|
|
|
|
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters |
|
|
use_fused = fused_available and device_type == 'cuda' |
|
|
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, fused=use_fused) |
|
|
print(f"using fused AdamW: {use_fused}") |
|
|
return optimizer |
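
    # Usage sketch (hypothetical hyperparameters, not a recommendation):
    #
    #   optimizer = model.configure_optimizers(
    #       weight_decay=0.1, learning_rate=3e-4, betas=(0.9, 0.95), device_type="cuda")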
|
|
|
|
|
@torch.no_grad() |
|
|
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): |
|
|
""" |
|
|
Generate text using the model with efficient KV caching. |
|
|
|
|
|
The generation process has two phases: |
|
|
1. Prefill: Process the entire input prompt to build the initial KV cache |
|
|
2. Decode: Generate tokens one by one, reusing the KV cache |
|
|
|
|
|
Args: |
|
|
idx: Input token indices (prompt) |
|
|
max_new_tokens: Maximum number of tokens to generate |
|
|
temperature: Sampling temperature (higher = more random) |
|
|
top_k: Keep only top k tokens for sampling (None = no filtering) |
|
|
|
|
|
Returns: |
|
|
Generated token sequence including the original prompt |
|
|
""" |
|
|
self.eval() |
|
|
past_key_values = None |
|
|
|
|
|
for _ in range(max_new_tokens): |
|
|
|
|
|
            # Stop when the context is full: on the first step this is the prompt
            # length; afterwards it is the number of positions already in the KV cache.
            current_len = past_key_values[0][0].shape[-2] if past_key_values else idx.shape[1]
            if current_len >= self.config.max_seq_len:
                break
|
|
|
|
|
|
|
|
|
|
|
            # Prefill: feed the whole prompt on the first step; afterwards, with the KV
            # cache in place, feed only the most recently sampled token.
            idx_cond = idx if past_key_values is None else idx[:, -1:]
|
|
|
|
|
|
|
|
logits, _, past_key_values = self(idx_cond, past_key_values=past_key_values) |
|
|
|
|
|
|
|
|
logits = logits[:, -1, :] / temperature |
|
|
|
|
|
|
|
|
if top_k is not None: |
|
|
v, _ = torch.topk(logits, min(top_k, logits.size(-1))) |
|
|
logits[logits < v[:, [-1]]] = -float('Inf') |
|
|
|
|
|
|
|
|
probs = F.softmax(logits, dim=-1) |
|
|
idx_next = torch.multinomial(probs, num_samples=1) |
|
|
|
|
|
|
|
|
idx = torch.cat((idx, idx_next), dim=1) |
|
|
|
|
|
self.train() |
|
|
return idx |
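

# A minimal smoke-test sketch (assumed usage, not part of any training recipe):
# build a tiny model on CPU, run one training-style forward pass, then generate a few tokens.
if __name__ == "__main__":
    config = LunarisCodexConfig(
        d_model=64, n_layers=2, n_heads=4, n_kv_heads=2, vocab_size=256, max_seq_len=128
    )
    model = LunarisCodex(config)
    idx = torch.randint(0, config.vocab_size, (1, 16))
    targets = torch.randint(0, config.vocab_size, (1, 16))
    logits, loss, _ = model(idx, targets=targets)
    print(f"logits: {tuple(logits.shape)}, loss: {loss.item():.4f}")
    out = model.generate(idx, max_new_tokens=8, temperature=1.0, top_k=50)
    print(f"generated: {tuple(out.shape)}")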
|
|
|