from transformers import (
    AutoModelForImageClassification,
    AutoImageProcessor,
    TrainingArguments,
    Trainer,
)
from datasets import load_dataset
import os

def _build_transform(processor):
    """Return a batched ``set_transform`` callback that preprocesses images.

    The callback converts every image in ``examples["image"]`` to RGB (so
    single-channel/grayscale inputs become three-channel), runs ``processor``
    on the batch with ``return_tensors="pt"``, and copies ``examples["label"]``
    into the ``"labels"`` key, which the Trainer forwards to the model as the
    loss target.
    """

    def process(examples):
        images = [img.convert("RGB") for img in examples["image"]]
        inputs = processor(images=images, return_tensors="pt")
        inputs["labels"] = examples["label"]
        return inputs

    return process


def _pick_eval_split(dataset, preferred=("validation", "test")):
    """Return the first split name from *preferred* that exists in *dataset*.

    Raises:
        KeyError: if none of the preferred split names is present.
    """
    for name in preferred:
        if name in dataset:
            return name
    raise KeyError(
        f"dataset has none of the evaluation splits {preferred!r}; "
        f"found {list(dataset)!r}"
    )


def _label_maps(names):
    """Return ``(id2label, label2id)`` dicts for an ordered sequence of class names."""
    id2label = dict(enumerate(names))
    label2id = {name: idx for idx, name in id2label.items()}
    return id2label, label2id


def train(
    *,
    model_name="SupremoUGH/image-classification-model",
    dataset_name="ylecun/mnist",
    output_dir="./results",
    save_dir="./saved_model",
):
    """Fine-tune an image-classification checkpoint and save the result.

    Args:
        model_name: Hub id or local path of the checkpoint whose image
            processor and model are loaded.
        dataset_name: Dataset id passed to ``datasets.load_dataset``. It must
            provide a ``"train"`` split with ``"image"`` and ``"label"``
            columns, plus a ``"validation"`` or ``"test"`` split for evaluation.
        output_dir: Directory for Trainer checkpoints and logs.
        save_dir: Directory the fine-tuned model and its image processor are
            written to.

    Raises:
        KeyError: if the dataset has neither a ``"validation"`` nor a
            ``"test"`` split.
    """
    dataset = load_dataset(dataset_name)
    # Resolve the eval split up front so a bad dataset fails before any
    # model download or training work.
    eval_split = _pick_eval_split(dataset)

    processor = AutoImageProcessor.from_pretrained(model_name)
    # set_transform preprocesses lazily, batch by batch, when rows are
    # accessed, rather than materialising pixel tensors for the whole dataset.
    dataset.set_transform(_build_transform(processor))

    # When the label column is a ClassLabel, size the classification head from
    # the dataset's classes. A checkpoint whose head has a different number of
    # outputs gets that head re-initialised (ignore_mismatched_sizes) instead of
    # silently training with the wrong number of classes.
    label_names = getattr(dataset["train"].features["label"], "names", None)
    label_kwargs = {}
    if label_names:
        id2label, label2id = _label_maps(label_names)
        label_kwargs = {
            "num_labels": len(label_names),
            "id2label": id2label,
            "label2id": label2id,
            "ignore_mismatched_sizes": True,
        }
    model = AutoModelForImageClassification.from_pretrained(model_name, **label_kwargs)

    training_args = TrainingArguments(
        output_dir=output_dir,
        # Keep the raw "image" column. The transform above reads it, and the
        # Trainer would otherwise drop columns not accepted by model.forward().
        remove_unused_columns=False,
        per_device_train_batch_size=16,  # lower this if you hit out-of-memory errors
        eval_strategy="steps",
        num_train_epochs=3,
        fp16=False,  # train in full precision
        save_steps=500,
        eval_steps=500,
        logging_steps=100,
        learning_rate=2e-4,
        push_to_hub=False,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset[eval_split],
    )
    trainer.train()

    # Trainer.save_model handles distributed/sharded setups and writes only from
    # the saving process, so it is called on every process.
    os.makedirs(save_dir, exist_ok=True)
    trainer.save_model(save_dir)
    if trainer.is_world_process_zero():
        # Store the processor next to the weights so save_dir can be reloaded
        # with AutoImageProcessor.from_pretrained(save_dir).
        processor.save_pretrained(save_dir)
        print(f"Model saved to {save_dir}")


def _main():
    """Script entry point: run training with its default configuration."""
    train()


if __name__ == "__main__":
    _main()