| """ | |
| Minimal command: | |
| python training_loop.py --hub_dir "segments/sidewalk-semantic" --push_to_hub | |
| Maximal command: | |
| python training_loop.py --hub_dir "segments/sidewalk-semantic" --batch_size 32 --learning_rate 6e-5 --model_flavor 0 --seed 42 --split train --push_to_hub | |
| """ | |
| import json | |
| import torch | |
| from pytorch_lightning import Trainer, callbacks, seed_everything | |
| from pytorch_lightning.loggers import WandbLogger | |
| from transformers import AutoConfig, SegformerForSemanticSegmentation, SegformerFeatureExtractor | |
| from dataloader import SidewalkSegmentationDataLoader | |
| from model import SidewalkSegmentationModel | |
def main(
    hub_dir: str,
    batch_size: int = 32,
    learning_rate: float = 6e-5,
    model_flavor: int = 0,
    seed: int = 42,
    split: str = "train",
    push_to_hub: bool = False,
):
    """Train the sidewalk segmentation model and optionally publish it to the Hub.

    The label mapping is read from ``id2label.json`` in the current working
    directory. Training monitors ``val_mean_iou`` (higher is better) for both
    checkpointing and early stopping.

    Args:
        hub_dir: Forwarded as ``hub_dir`` to ``SidewalkSegmentationDataLoader``.
        batch_size: Forwarded as ``batch_size`` to the data module.
        learning_rate: Forwarded as ``learning_rate`` to ``SidewalkSegmentationModel``.
        model_flavor: Selects the ``nvidia/mit-b{model_flavor}`` configuration and
            is embedded in the Hub repository names used when pushing.
        seed: Value passed to ``seed_everything``.
        split: Forwarded as ``split`` to the data module.
        push_to_hub: When true, push the configuration and the best checkpoint
            to the Hub once training has finished.
    """
    seed_everything(seed)

    logger = WandbLogger(project="sidewalk-segmentation")
    gpu_value = 1 if torch.cuda.is_available() else 0

    # Use a context manager so the file handle is closed deterministically.
    with open("id2label.json", "r", encoding="utf-8") as id2label_fp:
        id2label_file = json.load(id2label_fp)
    # JSON object keys are always strings; the model config expects integer ids.
    id2label = {int(k): v for k, v in id2label_file.items()}
    num_labels = len(id2label)

    model = SidewalkSegmentationModel(
        num_labels=num_labels,
        id2label=id2label,
        model_flavor=model_flavor,
        learning_rate=learning_rate,
    )
    data_module = SidewalkSegmentationDataLoader(
        hub_dir=hub_dir,
        batch_size=batch_size,
        split=split,
    )
    data_module.setup()

    checkpoint_callback = callbacks.ModelCheckpoint(
        dirpath="checkpoints",
        save_top_k=1,
        verbose=True,
        monitor="val_mean_iou",
        mode="max",
    )
    early_stopping_callback = callbacks.EarlyStopping(
        monitor="val_mean_iou",
        patience=5,
        verbose=True,
        mode="max",
    )
    trainer = Trainer(
        max_epochs=200,
        progress_bar_refresh_rate=10,
        gpus=gpu_value,
        logger=logger,
        callbacks=[checkpoint_callback, early_stopping_callback],
        deterministic=False,
    )
    trainer.fit(model, data_module)

    if push_to_hub:
        repo_dir = f"flavors/b{model_flavor}"
        repo_url = f"https://huggingface.co/ChainYo/segformer-{model_flavor}-sidewalk"

        config = AutoConfig.from_pretrained(f"nvidia/mit-b{model_flavor}")
        config.num_labels = num_labels
        config.id2label = id2label
        # Invert the int-keyed mapping so label2id values are ints, not the
        # raw string keys that come out of the JSON file.
        config.label2id = {label: idx for idx, label in id2label.items()}
        config.push_to_hub(repo_dir, repo_url=repo_url)

        # ModelCheckpoint exposes the best checkpoint as `best_model_path`.
        checkpoint_path = checkpoint_callback.best_model_path
        # NOTE(review): this passes the Lightning checkpoint path straight to
        # `from_pretrained` -- confirm the saved file is in a format it can load.
        model = SegformerForSemanticSegmentation.from_pretrained(checkpoint_path, config=config)
        model.push_to_hub(repo_dir, repo_url=repo_url)
if __name__ == "__main__":
    # Command-line entry point: each option mirrors one keyword argument of main().
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--hub_dir", required=True, type=str)
    cli.add_argument("--batch_size", default=32, type=int)
    cli.add_argument("--learning_rate", default=6e-5, type=float)
    cli.add_argument("--model_flavor", default=0, type=int)
    cli.add_argument("--seed", default=42, type=int)
    cli.add_argument("--split", default="train", type=str)
    cli.add_argument("--push_to_hub", action="store_true")

    options = cli.parse_args()
    # The parsed attribute names match main()'s parameter names one-to-one,
    # so the namespace can be forwarded directly as keyword arguments.
    main(**vars(options))