code-reviewer / train_sft.py
Erpg12's picture
fix: fix SFT training file
894ef1a
raw
history blame
2.64 kB
# train_sft.py
import sys
import json
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer
# Command-line switch: pass --dry-run to exercise the pipeline without training.
DRY_RUN = any(arg == "--dry-run" for arg in sys.argv)

# Base checkpoint; the tokenizer and the model are both loaded from it.
MODEL_ID = "Salesforce/codegen-350M-multi"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

# 1) load the JSONL training file as the "train" split
ds = load_dataset(
    "json",
    data_files="data/train_dataset.jsonl",
    split="train",
)
# 2) tokenize & format
def tokenize(example):
    """Turn one dataset row into causal-LM training features.

    The training text is a ``DIFF:`` prompt built from ``example['diff']``
    followed by the review comments serialized as JSON.

    Args:
        example: Mapping with a ``diff`` entry and, optionally, a
            ``comments`` (preferred) or ``comment`` entry.

    Returns:
        The output of the module-level ``tokenizer`` (truncated to 512
        tokens) with an added ``labels`` entry that is a copy of
        ``input_ids``.

    Raises:
        KeyError: If ``example`` has no ``diff`` entry.
    """
    prompt = f"DIFF:\n{example['diff']}\n\nOUTPUT FORMAT:\n"

    # Prefer "comments", then "comment", then an empty list.  A value of None
    # is treated like a missing key (e.g. a field that exists for other rows
    # but is empty for this one) so it is not serialized as "null".
    output = example.get("comments")
    if output is None:
        output = example.get("comment")
    if output is None:
        output = []

    # Use the serialized JSON directly: an encode/decode round trip only adds
    # work, and decoding may apply clean-up that alters the target text.
    text = prompt + json.dumps(output, ensure_ascii=False)

    tokens = tokenizer(text, truncation=True, max_length=512)
    # Labels mirror the inputs; copy so the two lists are independent objects.
    tokens["labels"] = tokens["input_ids"].copy()
    return tokens
# In dry‐run, only map a couple examples
if DRY_RUN:
    # Preview at most two rows; clamp to the dataset size so a file with
    # fewer than two examples does not make select() fail.
    sample = ds.select(range(min(2, len(ds))))
    print("Sample examples before tokenization:")
    for row in sample:
        print(row)

    # Drop the raw columns so only the tokenizer output remains.
    tokenized = sample.map(tokenize, remove_columns=sample.column_names)
    print("\nAfter tokenization, examples look like:")
    for row in tokenized:
        print({k: row[k] for k in ["input_ids", "labels"]})
else:
    # Real run: tokenize every row, replacing the raw columns.
    ds = ds.map(tokenize, remove_columns=ds.column_names)
# 3) configure args
training_args = TrainingArguments(
    output_dir="sft-model",  # where to write checkpoints
    overwrite_output_dir=True,
    do_train=True,  # this is a training run
    num_train_epochs=3,  # full passes over the data (superseded by max_steps)
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    max_steps=500,  # total optimization steps (overrides epochs)
    logging_strategy="steps",
    logging_steps=50,
    save_strategy="steps",
    save_steps=200,
    fp16=False,  # no half-precision on CPU
    # Disable WandB/other reporters.  This must be the string "none": passing
    # None means "use the library default", which may enable every installed
    # integration.
    report_to="none",
)
# 4) instantiate trainer
# Choose the dataset once: the small preview sample in dry-run mode,
# otherwise the fully tokenized training set.
train_dataset = tokenized if DRY_RUN else ds

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)

# Label the summary by mode so a real training run is not reported as a
# dry run; the dry-run output is unchanged.
banner = "DRY-RUN: Trainer instantiated" if DRY_RUN else "Trainer instantiated"
print(f"\n✅ {banner}:\n – model: {type(model)}\n – tokenizer: {type(tokenizer)}\n – train_dataset size: {len(train_dataset)}")
print(f" – SFTTrainingArguments: {training_args}")

if not DRY_RUN:
    # Only run the real training when --dry-run was not passed.
    trainer.train()
    trainer.save_model("sft-model")