Gemma-3-270M GRPO (math + CoT)

⚠️ Experimental training run. This is a personal experiment and not production-ready. The current checkpoint is from ~1200 GRPO steps; accuracy is unstable and may be low. Prompt format and weights may change. Please evaluate carefully before any use.

Small reasoning-tuned variant of Google’s Gemma-3-270M.
Two-stage recipe: SFT on math prompts with hidden <think>…</think> reasoning, then GRPO to reinforce correct final answers while discouraging overly long hidden reasoning.
Note: there is no <final> tag in this project. The final answer is emitted as \boxed{...} after the <think> block.

The model can emit <think>…</think> tokens; the examples below strip them by default.


✨ What’s inside

  • Base: google/gemma-3-270m
  • Objective:
    • SFT: learn prompt format + produce a boxed final answer.
    • GRPO: reward = 1.0 if the boxed answer matches ground truth (numeric or sympy-equivalent), else 0.0, minus a small penalty proportional to tokens inside <think>…</think>. KL regularization to the SFT reference.
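
A minimal sketch of that reward, assuming a regex-based \boxed{} extractor and sympy for equivalence; the function name, the λ value, and the use of whitespace-split word counts as a stand-in for token counts are illustrative, not the exact training code.

import re
import sympy

BOX_RE = re.compile(r"\\boxed\{([^{}]+)\}")
THINK_RE = re.compile(r"<think>(.*?)</think>", re.S)

def reward(completion: str, ground_truth: str, lam: float = 1e-3) -> float:
    # Correctness term: 1.0 if the boxed answer matches the reference, else 0.0.
    m = BOX_RE.search(completion)
    correct = 0.0
    if m:
        pred = m.group(1).strip()
        try:
            correct = float(sympy.simplify(f"({pred}) - ({ground_truth})") == 0)
        except Exception:
            correct = float(pred == ground_truth.strip())
    # Length penalty: proportional to the size of the hidden reasoning
    # (word count is used here as a rough proxy for token count).
    t = THINK_RE.search(completion)
    think_len = len(t.group(1).split()) if t else 0
    return correct - lam * think_len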

🧠 Prompt & output format

Training/eval wrapper (no <final> tag):

<prompt>
<YOUR_QUESTION_HERE>
</prompt>
<think>
…(internal scratch work)…
</think>
\boxed{FINAL_ANSWER}
  • The SFT example builder resembled format_sft_example(question, reasoning, final_answer); see the sketch after this list.
  • RL prompts use only: <prompt>…</prompt>\n<think>\n and expect the model to write reasoning + \boxed{...}.
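
A hedged sketch of such a builder under that format (the actual format_sft_example in the training code may differ in whitespace or EOS handling):

def format_sft_example(question: str, reasoning: str, final_answer: str) -> str:
    """Assemble one SFT training string in the wrapper shown above."""
    return (
        f"<prompt>\n{question}\n</prompt>\n"
        f"<think>\n{reasoning}\n</think>\n"
        f"\\boxed{{{final_answer}}}"
    )

# Example: format_sft_example("If 3x + 5 = 17, what is x?", "3x = 12, so x = 4.", "4")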

🚀 Quickstart

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch, re

MODEL_ID = "nirav-madhani/gemma3-270m-grpo-math"  # change if different

tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
    attn_implementation="eager",
).to(device).eval()

BOX_RE = re.compile(r"\\boxed\{([^{}]+)\}")  # matches \boxed{...} in the decoded text

def generate(question, max_new_tokens=160, temperature=0.2, top_p=0.95, return_boxed=True, show_think=False):
    prompt = f"<prompt>\n{question}\n</prompt>\n<think>\n"
    inputs = tok(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=temperature > 0,
            temperature=temperature,
            top_p=top_p,
            eos_token_id=tok.eos_token_id,
            pad_token_id=tok.pad_token_id,
            use_cache=True,
        )
    text = tok.decode(out[0], skip_special_tokens=True)
    if not show_think:
        text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.S)
    if return_boxed:
        m = BOX_RE.search(text)
        return m.group(1).strip() if m else text.strip()
    return text

print(generate("If 3x + 5 = 17, what is x?"))

πŸ—οΈ Training recipe (summary)

Stage 1 – SFT

  • transformers==4.55.x, trl==0.21.0
  • Tokenizer: Gemma-3 fast; pad_token = eos_token
  • Trainer + DataCollatorForLanguageModeling
  • max_seq_length = prompt_len + completion_len; small per-device batch with grad accumulation; BF16/FP16 on Ampere, else FP32.
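
A minimal Stage 1 sketch along those lines (the toy dataset, sequence length, output path, and hyperparameters are illustrative; the actual script may differ):

import torch
from datasets import Dataset
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          DataCollatorForLanguageModeling, Trainer, TrainingArguments)

tok = AutoTokenizer.from_pretrained("google/gemma-3-270m", use_fast=True)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
model = AutoModelForCausalLM.from_pretrained("google/gemma-3-270m", attn_implementation="eager")

# Toy dataset: each row is one full SFT string in the wrapper format above.
texts = [
    "<prompt>\nIf 3x + 5 = 17, what is x?\n</prompt>\n<think>\n3x = 12, so x = 4.\n</think>\n\\boxed{4}",
]
train_ds = Dataset.from_dict({"text": texts}).map(
    lambda b: tok(b["text"], truncation=True, max_length=448),  # ~ prompt_len + completion_len
    batched=True,
    remove_columns=["text"],
)

args = TrainingArguments(
    output_dir="sft-gemma3-270m",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
    logging_steps=10,
)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    data_collator=DataCollatorForLanguageModeling(tok, mlm=False),  # causal-LM labels
)
trainer.train()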

Stage 2 – RL (GRPO)

  • TRL GRPO (trl==0.21.0)
  • Policy & reference initialized from SFT
  • Reward:
    • +1.0 if \boxed{...} equals ground truth (float tolerance or sympy equivalence)
    • −λ * (#tokens inside <think>…</think>)
  • Knobs (fit ~15 GB VRAM; tune per GPU):
    • num_generations (K): 2–8
    • per_device_train_batch_size: 1–2
    • gradient_accumulation_steps: 4–8
      (ensure batch * accum * world_size is divisible by K)
    • max_prompt_length: ~160–256
    • max_completion_length: ~128–192
    • beta (KL): ~0.02
    • attn_implementation="eager"; enable use_cache=True if you have headroom
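
Wiring those knobs into TRL's GRPOConfig/GRPOTrainer might look roughly like this. The toy prompts, output paths, and the exact-match-only reward are assumptions for illustration; the fuller reward with sympy equivalence and the <think> length penalty is sketched under "What’s inside".

import re
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

BOX_RE = re.compile(r"\\boxed\{([^{}]+)\}")

def boxed_match_reward(completions, answer, **kwargs):
    # Exact-match-only stand-in for the full reward described above.
    return [
        1.0 if (m := BOX_RE.search(c)) and m.group(1).strip() == str(a).strip() else 0.0
        for c, a in zip(completions, answer)
    ]

questions = ["If 3x + 5 = 17, what is x?", "What is 7 * 8?", "Compute 15 - 9.", "If 2y = 10, what is y?"]
answers = ["4", "56", "6", "5"]
train_ds = Dataset.from_dict({
    "prompt": [f"<prompt>\n{q}\n</prompt>\n<think>\n" for q in questions],
    "answer": answers,
})

config = GRPOConfig(
    output_dir="grpo-gemma3-270m",
    num_generations=4,                 # K
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,     # 2 * 4 * 1 GPU = 8, divisible by K
    max_prompt_length=192,
    max_completion_length=160,
    beta=0.02,                         # KL penalty against the reference policy
    logging_steps=5,
)
trainer = GRPOTrainer(
    model="sft-gemma3-270m",           # start the policy from the SFT checkpoint
    reward_funcs=boxed_match_reward,
    args=config,
    train_dataset=train_ds,
)
trainer.train()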

Checkpointing

  • Checkpoints saved every N steps; keep last 3; persisted to Google Drive or /kaggle/working.
  • The inference loader picks the newest checkpoint-XXXX/, falling back to the RL root → SFT → base.
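
A small sketch of that fallback order (the directory names are illustrative):

import os
import re

def pick_model_dir(rl_dir="grpo-gemma3-270m", sft_dir="sft-gemma3-270m",
                   base="google/gemma-3-270m"):
    """Newest checkpoint-XXXX/ under the RL dir, else the RL root, else SFT, else base."""
    if os.path.isdir(rl_dir):
        ckpts = [d for d in os.listdir(rl_dir) if re.fullmatch(r"checkpoint-\d+", d)]
        if ckpts:
            return os.path.join(rl_dir, max(ckpts, key=lambda d: int(d.split("-")[1])))
        return rl_dir
    return sft_dir if os.path.isdir(sft_dir) else base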

📊 Evaluation

Early-stage. Evaluate on your math split or a GSM-style test set by extracting \boxed{…} and checking numeric or sympy equivalence (see the sketch after this list). Track:

  • reward mean/std, exact-match of final answers,
  • KL vs. reference,
  • output length and “think” token counts.
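
Building on the Quickstart's generate() helper, a minimal evaluation sketch covering the exact-match and think-length items above (the test pairs are placeholders; swap in the sympy check from the reward sketch for equivalence rather than exact match):

import re

BOX_RE = re.compile(r"\\boxed\{([^{}]+)\}")
THINK_RE = re.compile(r"<think>(.*?)</think>", re.S)

def evaluate(pairs):
    """pairs: list of (question, reference_answer) tuples."""
    hits, think_lens = 0, []
    for question, answer in pairs:
        text = generate(question, return_boxed=False, show_think=True)  # from the Quickstart
        m = BOX_RE.search(text)
        hits += int(bool(m) and m.group(1).strip() == str(answer).strip())
        t = THINK_RE.search(text)
        think_lens.append(len(t.group(1).split()) if t else 0)
    n = len(pairs)
    return {"exact_match": hits / n, "avg_think_words": sum(think_lens) / n}

# evaluate([("If 3x + 5 = 17, what is x?", "4")])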

βš–οΈ License & usage

  • Base license: Gemma models are under the Gemma license. This derivative remains subject to those terms.
  • Repo metadata sets license: gemma. Review Gemma’s terms before commercial use/redistribution.

🔒 Limitations & risks

  • 270M params is very small; expect brittleness outside narrow math tasks.
  • Hidden reasoning can be wrong; we hide it by default.
  • No built-in safety filtering.

🧩 Repro notes

  • Colab/Kaggle friendly; use attn_implementation="eager" to avoid FA mismatches.
  • GRPO progress bar’s “Training Loss” can be 0.0; monitor reward/KL/length.
  • Env tips:
    • set TRANSFORMERS_NO_TORCHVISION=1,
    • ensure compatible numpy/scikit-learn on Kaggle if transformers.generation pulls sklearn.

🙌 Acknowledgements

  • Base weights: Google’s google/gemma-3-270m
  • RL training: TRL (trl==0.21.0)

📣 Citation

@software{nirav_gemma3_270m_grpo_math_2025,
  title   = {Gemma-3-270M GRPO (math + CoT)},
  author  = {Nirav Madhani},
  year    = {2025},
  url     = {https://huggingface.co/nirav-madhani/gemma3-270m-grpo-math}
}