In [1]:
import json
import pytorch_lightning as pl
import torch
import torchmetrics

from datasets import load_dataset, load_metric

from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation

from torch import nn
from torch.utils.data import DataLoader, Dataset, random_split

from tqdm.notebook import tqdm

In [4]:
class SemanticSegmentationDataset(Dataset):
    """Image segmentation datasets."""

    def __init__(
        self, 
        dataset: torch.utils.data.dataset.Subset, 
        feature_extractor = SegformerFeatureExtractor(reduce_labels=True),
    ):
        """
        Initialize the dataset with the given feature extractor and split.

        Parameters
        ----------
        hub_dir : Dataset
            The dataset to use.
        feature_extractor : FeatureExtractor, optional
            The feature extractor to use. The default is SegformerFeatureExtractor.
        """
        self.dataset = dataset
        self.feature_extractor = feature_extractor
        self.length = len(self.dataset)
        print(f"Loaded {self.length} samples.")


    def __len__(self):
        """Return the number of samples in the dataset."""
        return self.length


    def __getitem__(self, index: int):
        """Get the sample at the given index."""
        image = self.dataset[index]["pixel_values"]
        label = self.dataset[index]["label"]

        encoded_inputs = self.feature_extractor(image, label, return_tensors="pt")

        for k, v in encoded_inputs.items():
            encoded_inputs[k].squeeze_() # remove batch dimension

        return encoded_inputs

In [3]:
BATCH_SIZE = 32
HUB_DIR = "segments/sidewalk-semantic"
EPOCHS = 200

In [5]:
dataset = load_dataset(HUB_DIR, split="train")

train_dataset, val_dataset = random_split(dataset, [int(0.8 * len(dataset)), len(dataset) - int(0.8 * len(dataset))])
train_dataset = SemanticSegmentationDataset(train_dataset)
val_dataset = SemanticSegmentationDataset(val_dataset)

Using custom data configuration segments--sidewalk-semantic-2-f89d0845be9cadc9
Reusing dataset parquet (/home/chainyo/.cache/huggingface/datasets/segments___parquet/segments--sidewalk-semantic-2-f89d0845be9cadc9/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901)


Loaded 800 samples.
Loaded 200 samples.


In [5]:
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

In [6]:
batch = next(iter(train_dataloader))

for k, v in batch.items():
    print(k, v.shape)

pixel_values torch.Size([32, 3, 512, 512])
labels torch.Size([32, 512, 512])


In [7]:
# class SidewalkSegmentationModel(pl.LightningModule):
#     def __init__(self, num_classes: int, learning_rate: float = 6e-5):
#         super().__init__()
#         self.model = SegformerForSemanticSegmentation.from_pretrained(
#             "nvidia/mit-b0", num_labels=num_classes, id2label=id2label, label2id=label2id,
#         )
#         self.learning_rate = learning_rate
#         self.metric = load_metric("mean_iou")

    
#     def forward(self, *args, **kwargs):
#         return self.model(*args, **kwargs)

    
#     def training_step(self, batch, batch_idx):
#         pixel_values = batch["pixel_values"]
#         labels = batch["labels"]

#         outputs = self(pixel_values=pixel_values, labels=labels)
#         loss, logits = outputs.loss, outputs.logits

    
#     def configure_optimizers(self) -> torch.optim.AdamW:
#         """
#         Configure the optimizer.
#         Returns
#         -------
#         torch.optim.AdamW
#             Optimizer for the model
#         """
#         return torch.optim.AdamW(model.parameters(), lr=self.learning_rate)

In [11]:
id2label_file = json.load(open("id2label.json", "r"))
id2label = {int(k): v for k, v in id2label_file.items()}
print(id2label)
label2id = {v: k for k, v in id2label_file.items()}
num_labels = len(id2label)

model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b0", num_labels=num_labels, id2label=id2label, label2id=label2id,
)

{0: 'unlabeled', 1: 'flat-road', 2: 'flat-sidewalk', 3: 'flat-crosswalk', 4: 'flat-cyclinglane', 5: 'flat-parkingdriveway', 6: 'flat-railtrack', 7: 'flat-curb', 8: 'human-person', 9: 'human-rider', 10: 'vehicle-car', 11: 'vehicle-truck', 12: 'vehicle-bus', 13: 'vehicle-tramtrain', 14: 'vehicle-motorcycle', 15: 'vehicle-bicycle', 16: 'vehicle-caravan', 17: 'vehicle-cartrailer', 18: 'construction-building', 19: 'construction-door', 20: 'construction-wall', 21: 'construction-fenceguardrail', 22: 'construction-bridge', 23: 'construction-tunnel', 24: 'construction-stairs', 25: 'object-pole', 26: 'object-trafficsign', 27: 'object-trafficlight', 28: 'nature-vegetation', 29: 'nature-terrain', 30: 'sky', 31: 'void-ground', 32: 'void-dynamic', 33: 'void-static', 34: 'void-unclear'}


Some weights of the model checkpoint at nvidia/mit-b0 were not used when initializing SegformerForSemanticSegmentation: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.classifier.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.linear_c.1.proj.weight', 'decode_head.classifier.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.ba

In [9]:
metric = load_metric("mean_iou")

In [10]:
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00006)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

model.train()
for epoch in range(EPOCHS):
    for index, batch in enumerate(tqdm(train_dataloader)):
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()

        outputs = model(pixel_values=pixel_values, labels=labels)
        loss, logits = outputs.loss, outputs.logits

        loss.backward()
        optimizer.step()

        with torch.no_grad():
            upsampled_logits = nn.functional.interpolate(
                logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
            )
            predicted = upsampled_logits.argmax(dim=1)
            metric.add_batch(predictions=predicted.detach().cpu().numpy(), references=labels.detach().cpu().numpy())

        if index % 100 == 0:
            metrics = metric.compute(num_labels=num_labels, ignore_index=255, reduce_labels=False)
            print(f"Epoch {epoch}/{EPOCHS} Batch {index}/{len(train_dataloader)} Loss {loss.item():.4f} Metrics {metrics}")

  0%|          | 0/25 [00:00<?, ?it/s]

  acc = total_area_intersect / total_area_label


Epoch 0/200 Batch 0/25 Loss 3.5735 Metrics {'mean_iou': 0.005477220659286406, 'mean_accuracy': 0.03572801337234697, 'overall_accuracy': 0.0286087999162653, 'per_category_iou': array([2.25093889e-02, 1.39043249e-03, 7.61652063e-03, 1.08658706e-02,
       1.00475237e-02, 0.00000000e+00, 2.03193907e-04, 1.38439962e-03,
       3.06388194e-05, 2.21495741e-03, 2.81951515e-05, 0.00000000e+00,
       0.00000000e+00, 1.13529929e-04, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 7.05787860e-03, 8.48350742e-03, 2.38476991e-03,
       1.76276224e-03, 2.17775404e-03, 0.00000000e+00, 1.04358386e-03,
       4.58714414e-04, 6.41413308e-05, 1.11431787e-04, 7.99247870e-02,
       8.28409255e-03, 2.07198485e-03, 6.12199013e-04, 6.53992687e-03,
       1.42611461e-02, 5.93919660e-05, 0.00000000e+00]), 'per_category_accuracy': array([2.49789663e-02, 1.40417891e-03, 5.70403118e-02, 1.19771803e-02,
       1.35858556e-02,            nan, 2.43287079e-04, 1.42133234e-02,
       9.06344411e-03, 2.8614168

  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 1/200 Batch 0/25 Loss 2.3349 Metrics {'mean_iou': 0.08249196196840773, 'mean_accuracy': 0.12242777101928329, 'overall_accuracy': 0.5340892205419953, 'per_category_iou': array([3.12479410e-01, 6.01364321e-01, 1.11562075e-02, 2.09116315e-02,
       1.95069293e-02, 0.00000000e+00, 1.09903909e-04, 1.13759900e-03,
       1.31593453e-03, 4.16974851e-01, 9.43479560e-04, 0.00000000e+00,
       0.00000000e+00, 2.71726824e-05, 7.02479634e-05, 1.71339744e-06,
       1.01240638e-04, 3.64420274e-01, 5.22553547e-03, 1.01179313e-03,
       5.40034112e-03, 6.14980455e-04, 0.00000000e+00, 1.37384839e-03,
       1.58339239e-03, 2.28333709e-05, 4.11671538e-05, 5.37431493e-01,
       6.31855647e-02, 4.88071958e-01, 5.19480641e-03, 4.67615947e-03,
       2.28171525e-02, 4.67269054e-05, 0.00000000e+00]), 'per_category_accuracy': array([4.06453404e-01, 7.78244778e-01, 2.33772733e-02, 2.33628638e-02,
       2.53283555e-02,            nan, 1.11846810e-04, 3.42729695e-03,
       9.61392116e-03, 6.59758916

  0%|          | 0/25 [00:00<?, ?it/s]

  iou = total_area_intersect / total_area_union


Epoch 2/200 Batch 0/25 Loss 1.9441 Metrics {'mean_iou': 0.1115145961827685, 'mean_accuracy': 0.1571059701593332, 'overall_accuracy': 0.6731429879739427, 'per_category_iou': array([4.45592658e-01, 7.27905738e-01, 3.53301085e-05, 1.04372074e-02,
       7.95093029e-03,            nan, 1.38316668e-06, 0.00000000e+00,
       0.00000000e+00, 4.80050947e-01, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 4.95525330e-01, 0.00000000e+00, 1.11859236e-05,
       8.15231791e-07, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       2.30751037e-06, 0.00000000e+00, 0.00000000e+00, 6.24028521e-01,
       1.05303497e-01, 7.82230424e-01, 0.00000000e+00, 0.00000000e+00,
       9.05398708e-04, 0.00000000e+00,            nan]), 'per_category_accuracy': array([7.33079709e-01, 9.03717795e-01, 3.57307878e-05, 1.05455168e-02,
       8.21302511e-03,            nan, 1.38319984e-06, 0.00000000e+00,
       0.00000000e+00, 8.08957376e-

  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 3/200 Batch 0/25 Loss 1.7086 Metrics {'mean_iou': 0.13337084618347653, 'mean_accuracy': 0.17885172041247846, 'overall_accuracy': 0.7166871222915292, 'per_category_iou': array([0.49122674, 0.77341676, 0.        , 0.1407711 , 0.00553198,
              nan, 0.        , 0.        , 0.        , 0.55152752,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.51584332, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.68907817, 0.42107346, 0.81276888,
       0.        , 0.        , 0.        , 0.        ,        nan]), 'per_category_accuracy': array([0.82412545, 0.91420412, 0.        , 0.14242155, 0.00561311,
              nan, 0.        , 0.        , 0.        , 0.84759725,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.85361568, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.    

  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 4/200 Batch 0/25 Loss 1.4155 Metrics {'mean_iou': 0.15564168770290954, 'mean_accuracy': 0.20019284726893685, 'overall_accuracy': 0.7574203051314757, 'per_category_iou': array([5.71505637e-01, 8.01399129e-01, 0.00000000e+00, 4.92403441e-01,
       5.41554099e-03,            nan, 2.31577543e-07, 0.00000000e+00,
       0.00000000e+00, 5.86567051e-01, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 5.51951880e-01, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 7.24133440e-01,
       5.67064833e-01, 8.35733013e-01, 0.00000000e+00, 0.00000000e+00,
       1.49737193e-06, 0.00000000e+00,            nan]), 'per_category_accuracy': array([8.47391615e-01, 9.31621860e-01, 0.00000000e+00, 5.38104498e-01,
       5.46801709e-03,            nan, 2.31577972e-07, 0.00000000e+00,
       0.00000000e+00, 8.70377721

  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 5/200 Batch 0/25 Loss 1.2700 Metrics {'mean_iou': 0.16264865782481297, 'mean_accuracy': 0.20757696789528982, 'overall_accuracy': 0.7689239961674764, 'per_category_iou': array([0.58414589, 0.81498541, 0.        , 0.57467831, 0.00689438,
              nan, 0.        , 0.        , 0.        , 0.61111453,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.55374014, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.74374917, 0.6398338 , 0.83826408,
       0.        , 0.        , 0.        , 0.        ,        nan]), 'per_category_accuracy': array([0.86273737, 0.93244115, 0.        , 0.66435561, 0.00696605,
              nan, 0.        , 0.        , 0.        , 0.87784737,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.87298997, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.    

  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 6/200 Batch 0/25 Loss 1.1401 Metrics {'mean_iou': 0.16748948766093116, 'mean_accuracy': 0.21181445128118748, 'overall_accuracy': 0.7795023598828317, 'per_category_iou': array([6.09680179e-01, 8.31918538e-01, 0.00000000e+00, 6.34236889e-01,
       1.24257235e-02,            nan, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 6.35402026e-01, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 5.64656796e-01, 0.00000000e+00, 3.65611521e-06,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 7.53253595e-01,
       6.35493134e-01, 8.50082556e-01, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00,            nan]), 'per_category_accuracy': array([8.81516256e-01, 9.41665624e-01, 0.00000000e+00, 7.39806944e-01,
       1.26589939e-02,            nan, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 8.90136686

  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 7/200 Batch 0/25 Loss 1.0940 Metrics {'mean_iou': 0.17137856945919164, 'mean_accuracy': 0.21495584433690398, 'overall_accuracy': 0.7881396456781556, 'per_category_iou': array([6.32120417e-01, 8.39814384e-01, 0.00000000e+00, 6.53276797e-01,
       2.36145329e-02,            nan, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 6.44513271e-01, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 5.70917228e-01, 0.00000000e+00, 5.49670921e-05,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 7.63117039e-01,
       6.70860701e-01, 8.57197972e-01, 0.00000000e+00, 0.00000000e+00,
       5.48363063e-06, 0.00000000e+00,            nan]), 'per_category_accuracy': array([9.08405231e-01, 9.45289137e-01, 0.00000000e+00, 7.43501061e-01,
       2.43273947e-02,            nan, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 8.96271845

  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 8/200 Batch 0/25 Loss 0.9463 Metrics {'mean_iou': 0.17478409568512296, 'mean_accuracy': 0.21773480038212398, 'overall_accuracy': 0.7927706422623841, 'per_category_iou': array([6.29223165e-01, 8.42392089e-01, 3.52355169e-04, 6.80396607e-01,
       5.21568283e-02,            nan, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 6.55000862e-01, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 5.75092138e-01, 0.00000000e+00, 7.12203670e-04,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       5.25684195e-05, 0.00000000e+00, 0.00000000e+00, 7.79574322e-01,
       6.87965551e-01, 8.64953847e-01, 0.00000000e+00, 0.00000000e+00,
       2.62043422e-06, 0.00000000e+00,            nan]), 'per_category_accuracy': array([9.02830276e-01, 9.46588697e-01, 3.52355621e-04, 7.80701400e-01,
       5.56946042e-02,            nan, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 8.98895404

  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 9/200 Batch 0/25 Loss 1.1287 Metrics {'mean_iou': 0.17933414157617142, 'mean_accuracy': 0.22211677037195932, 'overall_accuracy': 0.7957460527936768, 'per_category_iou': array([6.42468748e-01, 8.51322031e-01, 4.38690416e-02, 6.83716408e-01,
       1.05816211e-01,            nan, 3.89710724e-04, 0.00000000e+00,
       0.00000000e+00, 6.68283305e-01, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 5.86580529e-01, 0.00000000e+00, 1.00559462e-03,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       1.15812990e-04, 0.00000000e+00, 0.00000000e+00, 7.77224737e-01,
       6.87868711e-01, 8.69336423e-01, 0.00000000e+00, 0.00000000e+00,
       2.94092022e-05, 0.00000000e+00,            nan]), 'per_category_accuracy': array([9.05123309e-01, 9.50701912e-01, 4.38913906e-02, 7.96283141e-01,
       1.19598233e-01,            nan, 3.89758329e-04, 0.00000000e+00,
       0.00000000e+00, 8.97971821

  0%|          | 0/25 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [5]:
import numpy as np

tokenizer = SegformerFeatureExtractor(reduce_labels=True)

dataset = load_dataset(HUB_DIR, split="train")
length = len(dataset)
print(length)

encoded_dataset = tokenizer(
    images=dataset["pixel_values"], segmentation_maps=dataset["label"], return_tensors="pt"
)

Using custom data configuration segments--sidewalk-semantic-2-f89d0845be9cadc9
Reusing dataset parquet (/home/chainyo/.cache/huggingface/datasets/segments___parquet/segments--sidewalk-semantic-2-f89d0845be9cadc9/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901)


1000


In [9]:
pixel_values = encoded_dataset["pixel_values"]
labels = encoded_dataset["labels"]

In [10]:
type(pixel_values)

torch.Tensor

In [25]:
from torch.utils.data import DataLoader, Dataset, random_split, Subset

class SegmentationDataset(Dataset):
    def __init__(self, pixel_values: torch.Tensor, labels: torch.Tensor):
        self.pixel_values = pixel_values
        self.labels = labels
        assert pixel_values.shape[0] == labels.shape[0]
        self.length = pixel_values.shape[0]
        print(f"Created dataset with {self.length} samples")
    

    def __len__(self):
        return self.length


    def __getitem__(self, index):
        image = self.pixel_values[index]
        label = self.labels[index]

        encoded_inputs = BatchFeature({"pixel_values": image, "labels": label})

        return encoded_inputs

In [26]:
segmentation_dataset = SegmentationDataset(pixel_values, labels)

Created dataset with 1000 samples


In [27]:
test = segmentation_dataset[0]

In [28]:
test

{'pixel_values': tensor([[[ 0.0912, -0.1828, -0.1143,  ..., -0.5253, -0.5424, -0.6623],
         [-0.0116, -0.2342, -0.1486,  ..., -0.5424, -0.6109, -0.7137],
         [ 0.0398, -0.1828, -0.1314,  ..., -0.5082, -0.6281, -0.7650],
         ...,
         [ 1.1529,  1.3927,  1.0331,  ...,  0.4166,  0.3481,  0.3309],
         [ 0.9474,  1.1358,  1.3070,  ...,  0.5022,  0.3652,  0.3994],
         [ 0.6049,  1.3413,  1.1358,  ...,  1.2728,  0.6563,  0.8104]],

        [[ 0.3102, -0.2150, -0.3200,  ..., -0.3901, -0.4426, -0.5651],
         [ 0.2227, -0.2850, -0.3725,  ..., -0.4076, -0.5126, -0.6176],
         [ 0.2577, -0.2325, -0.3550,  ..., -0.3725, -0.5301, -0.6702],
         ...,
         [ 1.1506,  1.3957,  1.0280,  ...,  0.4678,  0.3978,  0.3803],
         [ 0.9405,  1.1331,  1.3081,  ...,  0.5553,  0.4153,  0.4503],
         [ 0.5903,  1.3431,  1.1331,  ...,  1.3431,  0.7129,  0.8704]],

        [[ 0.4788, -0.0267, -0.1312,  ...,  0.0431, -0.0441, -0.2010],
         [ 0.3916, -0.0790, 

In [20]:
segmentation_dataset[0][1].squeeze().unique()

tensor([ 0,  1,  2,  4,  6,  9, 17, 18, 19, 23, 24, 25, 27, 31, 32])

In [22]:
indices = np.arange(length)
train_indices, val_indices = random_split(indices, [int(length * 0.8), int(length * 0.2)])

train_dataset = SegmentationDataset(encoded_dataset, train_indices)
val_dataset = SegmentationDataset(encoded_dataset, val_indices)

In [26]:
train_dataset[0]

TypeError: list indices must be integers or slices, not str

In [23]:
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
valid_dataloader = DataLoader(val_dataset, batch_size=2)

In [24]:
batch = next(iter(train_dataloader))

TypeError: list indices must be integers or slices, not str

In [32]:
import json
from transformers import AutoConfig

id2label_file = json.load(open("id2label.json", "r"))
id2label = {int(k): v for k, v in id2label_file.items()}
num_labels = len(id2label)

config = AutoConfig.from_pretrained(f"nvidia/mit-b0")
config.num_labels = num_labels
config.id2label = id2label
config.label2id = {v: k for k, v in id2label_file.items()}
config.push_to_hub(".", repo_url="https://huggingface.co/ChainYo/segformer-sidewalk")

model = SegformerForSemanticSegmentation.from_pretrained(
    "/home/chainyo/code/segformer-sidewalk/checkpoints/epoch=44-step=1125.ckpt", 
    config=config,
)
model.push_to_hub(".", repo_url="https://huggingface.co/ChainYo/segformer-sidewalk")

/home/chainyo/code/segformer-sidewalk/. is already a clone of https://huggingface.co/ChainYo/segformer-sidewalk. Make sure you pull the latest changes with `repo.git_pull()`.
remote: Enforcing permissions...        
remote: Allowed refs: all        
To https://huggingface.co/ChainYo/segformer-sidewalk
   5d5f276..56db83f  main -> main



TypeError: __init__() got an unexpected keyword argument 'num_labels'