Gemma-3-270M GRPO (math + CoT)
⚠️ Experimental training run. This is a personal experiment and not production-ready. The current checkpoint is from ~1200 GRPO steps; accuracy is unstable and may be low. Prompt format and weights may change. Please evaluate carefully before any use.
A small reasoning-tuned variant of Google's Gemma-3-270M.

Two-stage recipe: SFT on math prompts with hidden `<think>…</think>` reasoning, then GRPO to reinforce correct final answers while discouraging overly long hidden reasoning.

Note: there is no `<final>` tag in this project. The final answer is emitted as `\boxed{...}` after the `<think>` block.

The model can emit `<think>…</think>` tokens; the examples below strip them by default.
✨ What's inside
- Base: `google/gemma-3-270m`
- Objective:
  - SFT: learn the prompt format + produce a boxed final answer.
  - GRPO: reward `= 1.0` if the boxed answer matches ground truth (numeric or `sympy`-equivalent), else `0.0`, minus a small penalty proportional to the number of tokens inside `<think>…</think>` (a sketch of such a reward follows this list). KL regularization to the SFT reference.
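A minimal sketch of a reward in this spirit (the name `grpo_reward`, the float tolerance, and the penalty coefficient `lam` are illustrative assumptions, not the repo's actual code):

```python
import re

BOX_RE = re.compile(r"\\boxed\{([^{}]+)\}")
THINK_RE = re.compile(r"<think>(.*?)</think>", re.S)

def grpo_reward(completion: str, ground_truth: str, lam: float = 1e-3) -> float:
    """Correctness reward minus a penalty on hidden-reasoning length (illustrative)."""
    m = BOX_RE.search(completion)
    try:
        # Numeric comparison with a small tolerance when both sides parse as floats.
        correct = abs(float(m.group(1)) - float(ground_truth)) < 1e-6 if m else False
    except ValueError:
        # Otherwise fall back to a plain string match on the boxed content.
        correct = m is not None and m.group(1).strip() == ground_truth.strip()
    think = THINK_RE.search(completion)
    think_tokens = len(think.group(1).split()) if think else 0  # rough whitespace token count
    return (1.0 if correct else 0.0) - lam * think_tokens
```

The KL regularization is applied by the trainer (the `beta` knob in the GRPO stage below), not inside the reward itself.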
🔧 Prompt & output format
Training/eval wrapper (no `<final>` tag):

```
<prompt>
<YOUR_QUESTION_HERE>
</prompt>
<think>
…(internal scratch work)…
</think>
\boxed{FINAL_ANSWER}
```
- SFT builder resembled: `format_sft_example(question, reasoning, final_answer)` (a sketch follows this list).
- RL prompts use only `<prompt>…</prompt>\n<think>\n` and expect the model to write reasoning + `\boxed{...}`.
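For orientation, a hedged sketch of what such a builder could look like; the real implementation may differ:

```python
def format_sft_example(question: str, reasoning: str, final_answer: str) -> str:
    # Assumed SFT target layout: wrapped prompt, hidden reasoning, then a boxed answer.
    return (
        f"<prompt>\n{question}\n</prompt>\n"
        f"<think>\n{reasoning}\n</think>\n"
        f"\\boxed{{{final_answer}}}"
    )

# Example: format_sft_example("If 3x + 5 = 17, what is x?", "3x = 12, so x = 4.", "4")
```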
🚀 Quickstart
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch, re

MODEL_ID = "nirav-madhani/gemma3-270m-grpo-math"  # change if different

tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
    attn_implementation="eager",
).to(device).eval()

# Matches the boxed final answer, e.g. \boxed{4}
BOX_RE = re.compile(r"\\boxed\{([^{}]+)\}")

def generate(question, max_new_tokens=160, temperature=0.2, top_p=0.95, return_boxed=True, show_think=False):
    # Wrap the question in the training prompt format; the model continues with reasoning + \boxed{...}
    prompt = f"<prompt>\n{question}\n</prompt>\n<think>\n"
    inputs = tok(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=temperature > 0,
            temperature=temperature,
            top_p=top_p,
            eos_token_id=tok.eos_token_id,
            pad_token_id=tok.pad_token_id,
            use_cache=True,
        )
    text = tok.decode(out[0], skip_special_tokens=True)
    if not show_think:
        # Strip the hidden reasoning block by default
        text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.S)
    if return_boxed:
        m = BOX_RE.search(text)
        return m.group(1).strip() if m else text.strip()
    return text

print(generate("If 3x + 5 = 17, what is x?"))
```
🏋️ Training recipe (summary)
Stage 1 – SFT
- `transformers==4.55.x`, `trl==0.21.0`
- Tokenizer: Gemma-3 fast; `pad_token = eos_token`
- `Trainer` + `DataCollatorForLanguageModeling`
- `max_seq_length = prompt_len + completion_len`; small per-device batch with grad accumulation; BF16/FP16 on Ampere, else FP32.
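A minimal sketch of this SFT setup, assuming a `datasets.Dataset` with a `text` column holding formatted examples; the dataset and hyperparameters below are illustrative:

```python
from datasets import Dataset
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          DataCollatorForLanguageModeling, Trainer, TrainingArguments)

MODEL_ID = "google/gemma-3-270m"
tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
tok.pad_token = tok.eos_token
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, attn_implementation="eager")

# Toy dataset: in the real run each "text" is one formatted SFT example
# (<prompt>…</prompt>\n<think>…</think>\n\boxed{…}).
train_ds = Dataset.from_dict({"text": [
    "<prompt>\nIf 3x + 5 = 17, what is x?\n</prompt>\n<think>\n3x = 12, so x = 4.\n</think>\n\\boxed{4}",
]})

def tokenize(batch):
    return tok(batch["text"], truncation=True, max_length=416)  # ~ prompt_len + completion_len

train_ds = train_ds.map(tokenize, batched=True, remove_columns=["text"])

args = TrainingArguments(
    output_dir="sft-gemma3-270m",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    num_train_epochs=1,
    bf16=True,        # BF16/FP16 on Ampere; drop back to fp32 otherwise
    logging_steps=10,
)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    data_collator=DataCollatorForLanguageModeling(tok, mlm=False),
)
trainer.train()
```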
Stage 2 – RL (GRPO)
- TRL GRPO (`trl==0.21.0`)
- Policy & reference initialized from SFT
- Reward: `+1.0` if `\boxed{...}` equals ground truth (float tolerance or `sympy` equivalence) − λ · (# tokens inside `<think>…</think>`)
- Knobs (fit ~15 GB VRAM; tune per GPU; a wiring sketch follows this list):
  - `num_generations` (K): 2–8
  - `per_device_train_batch_size`: 1–2
  - `gradient_accumulation_steps`: 4–8 (ensure `batch * accum * world_size` is divisible by K)
  - `max_prompt_length`: ~160–256
  - `max_completion_length`: ~128–192
  - `beta` (KL): ~0.02
  - `attn_implementation="eager"`; enable `use_cache=True` if you have headroom
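A minimal sketch of how these knobs might be wired with TRL's `GRPOTrainer`; the dataset, checkpoint path, and the simplified exact-match reward are illustrative:

```python
import re
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

BOX_RE = re.compile(r"\\boxed\{([^{}]+)\}")
THINK_RE = re.compile(r"<think>(.*?)</think>", re.S)
LAM = 1e-3  # illustrative think-length penalty coefficient

def reward_fn(completions, answer, **kwargs):
    # TRL passes generated completions plus dataset columns (here: "answer") to reward functions.
    rewards = []
    for comp, gt in zip(completions, answer):
        m = BOX_RE.search(comp)
        correct = 1.0 if (m and m.group(1).strip() == str(gt).strip()) else 0.0
        think = THINK_RE.search(comp)
        rewards.append(correct - LAM * (len(think.group(1).split()) if think else 0))
    return rewards

# Toy prompt/answer pair; real training uses the math split formatted as in the prompt section.
train_ds = Dataset.from_dict({
    "prompt": ["<prompt>\nIf 3x + 5 = 17, what is x?\n</prompt>\n<think>\n"],
    "answer": ["4"],
})

args = GRPOConfig(
    output_dir="grpo-gemma3-270m",
    num_generations=2,               # K
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,   # keep batch * accum * world_size divisible by K
    max_prompt_length=256,
    max_completion_length=192,
    beta=0.02,                       # KL regularization toward the reference policy
)
trainer = GRPOTrainer(
    model="path/to/sft-checkpoint",  # illustrative path: start from the SFT model
    reward_funcs=reward_fn,
    args=args,
    train_dataset=train_ds,
)
trainer.train()
```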
Checkpointing
- Checkpoints saved every N steps; keep last 3; persisted to Google Drive or `/kaggle/working`.
- Inference loader grabs the newest `checkpoint-XXXX/`, else falls back RL root → SFT → base (a sketch follows).
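A hedged sketch of that selection logic (function and directory names are illustrative):

```python
from pathlib import Path

def latest_checkpoint(run_dir: str) -> str | None:
    """Return the highest-numbered checkpoint-XXXX/ under run_dir, or None if there is none."""
    ckpts = sorted(Path(run_dir).glob("checkpoint-*"),
                   key=lambda p: int(p.name.split("-")[-1]))
    return str(ckpts[-1]) if ckpts else None

# Fallback chain: newest RL checkpoint -> RL run root -> SFT run -> base model.
model_path = latest_checkpoint("grpo-run") or "grpo-run"  # illustrative directory name
```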
📊 Evaluation
Early-stage. Evaluate on your math split or a GSM-style test set by extracting `\boxed{…}` and checking numeric or `sympy` equivalence (a helper sketch follows the list). Track:
- reward mean/std, exact match of final answers,
- KL vs. the reference,
- output length and "think" token counts.
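A hedged sketch of such an equivalence check (the helper name `answers_match` is an assumption):

```python
import sympy as sp

def answers_match(pred: str, truth: str, tol: float = 1e-6) -> bool:
    """True if two extracted answers agree numerically or symbolically."""
    # Fast path: plain numeric comparison with a small tolerance.
    try:
        return abs(float(pred) - float(truth)) <= tol
    except ValueError:
        pass
    # Fallback: expressions are equivalent if their difference simplifies to zero.
    try:
        return sp.simplify(sp.sympify(pred) - sp.sympify(truth)) == 0
    except (sp.SympifyError, TypeError):
        return False
```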
⚖️ License & usage
- Base license: Gemma models are under the Gemma license. This derivative remains subject to those terms.
- Repo metadata sets `license: gemma`. Review Gemma's terms before commercial use/redistribution.
🚧 Limitations & risks
- 270M params is very small; expect brittleness outside narrow math tasks.
- Hidden reasoning can be wrong; we hide it by default.
- No built-in safety filtering.
🧩 Repro notes
- Colab/Kaggle friendly; use `attn_implementation="eager"` to avoid FlashAttention mismatches.
- GRPO progress bar's "Training Loss" can be `0.0`; monitor reward/KL/length instead.
- Env tips:
  - set `TRANSFORMERS_NO_TORCHVISION=1`,
  - ensure compatible `numpy`/`scikit-learn` on Kaggle if `transformers.generation` pulls in sklearn.
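Applied in a notebook, those tips might look like this sketch (set the variable before importing `transformers`; exact package pins depend on the Kaggle image):

```python
import os
os.environ["TRANSFORMERS_NO_TORCHVISION"] = "1"  # avoid torchvision being pulled in

# On Kaggle, reinstall numpy/scikit-learn if transformers.generation complains about sklearn:
# %pip install --upgrade numpy scikit-learn
```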
🙏 Acknowledgements
- Base weights: Google `google/gemma-3-270m`
- RL training: TRL (`trl==0.21.0`)
📣 Citation
```bibtex
@software{nirav_gemma3_270m_grpo_math_2025,
  title  = {Gemma-3-270M GRPO (math + CoT)},
  author = {Nirav Madhani},
  year   = {2025},
  url    = {https://huggingface.co/nirav-madhani/gemma3-270m-grpo-math}
}
```