code-reviewer / train_sft.py
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig
MODEL_ID = "Salesforce/codegen-350M-multi"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
# 1) load your local JSONL
ds = load_dataset("json", data_files="data/train_dataset.jsonl", split="train")
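# Each JSONL record is assumed to contain at least a "diff" and a "comments"
# field (the keys used in tokenize() below), e.g. (made-up example):
#   {"diff": "--- a/app.py\n+++ b/app.py\n...", "comments": "Consider handling the None case."}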
# 2) tokenize & format
def tokenize(example):
    # Build a single causal-LM training sequence: the diff as the prompt,
    # followed by the reviewer comments as the target text.
    prompt = f"DIFF:\n{example['diff']}\n\nOUTPUT FORMAT:\n"
    text = prompt + str(example["comments"])
    tokens = tokenizer(text, truncation=True, max_length=512)
    # For causal-LM fine-tuning, the labels are a copy of the input ids.
    tokens["labels"] = tokens["input_ids"].copy()
    return tokens
ds = ds.map(tokenize, remove_columns=ds.column_names, batched=False)
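# Quick sanity check (illustrative, safe to remove): decode the first tokenized
# example to confirm the prompt and target were concatenated as intended.
print(tokenizer.decode(ds[0]["input_ids"]))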
# 3) SFT arguments
training_args = SFTConfig(
    output_dir="sft-model",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,  # effective batch size: 2 x 8 = 16
    learning_rate=2e-5,
    max_steps=500,
    logging_steps=50,
    save_steps=200,
)
# 4) kick off the trainer
# Older TRL releases name the `processing_class` argument `tokenizer`.
trainer = SFTTrainer(model=model, args=training_args, train_dataset=ds, processing_class=tokenizer)
trainer.train()
trainer.save_model("sft-model")
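# ---------------------------------------------------------------------------
# Illustrative inference sketch (an assumption, not part of the original training
# flow): reload the saved checkpoint and generate review comments for a made-up
# diff. The tokenizer is reloaded from MODEL_ID because save_model() may not
# write tokenizer files to the output directory.
# ---------------------------------------------------------------------------
ft_model = AutoModelForCausalLM.from_pretrained("sft-model")
ft_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

example_diff = "--- a/utils.py\n+++ b/utils.py\n@@ -1,3 +1,4 @@\n+import os\n import sys\n"
inference_prompt = f"DIFF:\n{example_diff}\n\nOUTPUT FORMAT:\n"
inputs = ft_tokenizer(inference_prompt, return_tensors="pt")
outputs = ft_model.generate(**inputs, max_new_tokens=128)
# Print only the newly generated tokens (everything after the prompt).
print(ft_tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))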