File size: 3,377 Bytes
867532a
df2ac53
867532a
d41c4d4
 
 
 
 
 
 
 
 
8e899a8
d41c4d4
 
 
c10f559
d41c4d4
 
 
c10f559
 
d41c4d4
 
 
 
 
 
 
c10f559
 
 
 
 
 
 
 
 
 
 
 
d41c4d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e899a8
 
d41c4d4
 
8e899a8
d41c4d4
 
8e899a8
 
 
d41c4d4
 
 
 
8e899a8
 
 
 
 
 
 
 
d41c4d4
867532a
 
 
 
 
 
 
 
 
 
df2ac53
 
867532a
df2ac53
 
 
867532a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#!/usr/bin/env python
import zipfile
from argparse import ArgumentParser
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoImageProcessor, AutoModel
import tqdm


class ImageDataset(Dataset):
    """Dataset yielding DINOv2 CLS-token features per image listed in a CSV.

    Each item is ``{"features": Tensor(hidden_dim,), "observation_id": ...}``.
    NOTE: feature extraction runs inside ``__getitem__`` (CPU, no gradients),
    so the backbone forward pass happens in DataLoader workers.
    """

    def __init__(self, metadata_path, images_root_path, model_name="./dinov2"):
        """Load metadata and the pretrained processor/backbone.

        Args:
            metadata_path: CSV with at least `filename` and `observation_id`
                columns (one row per image).
            images_root_path: directory that `filename` entries are relative to.
            model_name: HF model id or local path for the DINOv2 backbone.
        """
        self.metadata_path = metadata_path
        self.metadata = pd.read_csv(metadata_path)
        self.images_root_path = images_root_path
        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        # Inference only — make eval mode explicit (disables dropout etc.).
        self.model.eval()

    def __len__(self):
        # One sample per metadata row.
        return len(self.metadata)

    def __getitem__(self, idx):
        row = self.metadata.iloc[idx]
        image_path = Path(self.images_root_path) / row.filename

        # Fix: force 3-channel RGB so grayscale / RGBA / palette images do not
        # break the processor, which expects RGB input.
        image = Image.open(image_path).convert("RGB")
        model_inputs = self.processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**model_inputs)
            last_hidden_states = outputs.last_hidden_state
        # Extract the CLS token of the single image in the batch.
        return {
            "features": last_hidden_states[0, 0],
            "observation_id": row.observation_id,
        }


class LinearClassifier(nn.Module):
    """A single linear layer whose outputs are log-probabilities.

    Maps `num_features`-dimensional inputs to `num_classes` scores and
    applies a log-softmax over the class dimension.
    """

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        # Submodule is named `model` on purpose: checkpoint state-dict keys
        # ("model.weight", "model.bias") depend on this attribute name.
        self.model = nn.Linear(num_features, num_classes)

    def forward(self, x):
        scores = self.model(x)
        log_probs = torch.log_softmax(scores, dim=1)
        return log_probs


def make_submission(
    test_metadata,
    model_path,
    output_csv_path="./submission.csv",
    images_root_path="/tmp/data/private_testset",
):
    """Run classification over the test images and write the submission CSV.

    Args:
        test_metadata: path to the test metadata CSV (`filename`,
            `observation_id` columns).
        model_path: checkpoint containing "hyper_parameters"
            (num_features/num_classes) and "state_dict" keys.
        output_csv_path: destination CSV with observation_id/class_id columns.
        images_root_path: directory holding the extracted test images.
    """
    # Robustness: fall back to CPU when no GPU is present instead of crashing.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
    checkpoint = torch.load(model_path, map_location=device)
    hparams = checkpoint["hyper_parameters"]
    model = LinearClassifier(hparams["num_features"], hparams["num_classes"])
    model.load_state_dict(checkpoint["state_dict"])
    model = model.to(device)
    model.eval()  # inference only

    dataloader = DataLoader(
        ImageDataset(test_metadata, images_root_path), batch_size=250
    )
    rows = []
    with torch.no_grad():  # no gradients needed at inference time
        for batch in tqdm.tqdm(dataloader):
            # Bug fix: observation ids are labels, not model input — keep them
            # on the CPU. Moving them to the GPU was wasted transfer and would
            # fail outright for non-numeric ids.
            observation_ids = batch["observation_id"]
            logits = model(batch["features"].to(device))
            class_ids = torch.argmax(logits, dim=1).cpu()
            for observation_id, class_id in zip(observation_ids, class_ids):
                rows.append(
                    {"observation_id": int(observation_id), "class_id": int(class_id)}
                )
    # Aggregate per observation: majority vote (mode) of the predicted
    # class ids across that observation's images; ties resolve to the
    # smallest class id (pd.Series.mode returns sorted values).
    submission_df = (
        pd.DataFrame(rows)
        .groupby("observation_id")
        .agg(lambda x: pd.Series.mode(x)[0])
        .reset_index()
    )
    submission_df.to_csv(output_csv_path, index=False)


def parse_args():
    """Parse the command-line options for the submission script."""
    argument_parser = ArgumentParser()
    argument_parser.add_argument("--model-path", type=str, default="./last.ckpt")
    argument_parser.add_argument(
        "--metadata-file-path",
        type=str,
        default="./SnakeCLEF2024-TestMetadata.csv",
    )
    parsed = argument_parser.parse_args()
    return parsed


if __name__ == "__main__":
    args = parse_args()
    # Unpack the private test images under /tmp/data so they land at the
    # default images_root_path ("/tmp/data/private_testset") that
    # make_submission expects.
    with zipfile.ZipFile("/tmp/data/private_testset.zip", "r") as zip_ref:
        zip_ref.extractall("/tmp/data")

    make_submission(test_metadata=args.metadata_file_path, model_path=args.model_path)