import lightning.pytorch as pl
import torchmetrics
from torch.optim import AdamW
from transformers import ViTForImageClassification
from torch import nn
from transformers.optimization import get_scheduler

class LightningViTRegressor(pl.LightningModule):
    """ViT backbone fine-tuned as a single-output image regressor."""

    def __init__(self):
        super().__init__()
        # Reuse the classification head with a single output unit so the
        # "logit" acts as the regression prediction.
        self.model = ViTForImageClassification.from_pretrained(
            "google/vit-base-patch16-224-in21k",
            num_labels=1,
        )
        # Torchmetrics objects used to report per-batch regression metrics.
        self.mse = torchmetrics.MeanSquaredError()
        self.mae = torchmetrics.MeanAbsoluteError()
        self.r2_score = torchmetrics.R2Score()

    def common_step(self, step_type, batch, batch_idx):
        # `batch` is a (pixel_values, target) pair; targets are expected to
        # have the same shape as the logits, i.e. (batch_size, 1).
        x, y = batch
        preds = self.model(x).logits
        loss = nn.functional.mse_loss(preds, y)
        mean_squared_error = self.mse(preds, y)
        mean_absolute_error = self.mae(preds, y)
        r2_score = self.r2_score(preds, y)
        to_log = {
            step_type + "_loss": loss,
            step_type + "_mse": mean_squared_error,
            step_type + "_mae": mean_absolute_error,
            step_type + "_r2_score": r2_score,
        }  # add more items if needed
        self.log_dict(to_log)
        return loss
      
    def training_step(self, batch, batch_idx):
        return self.common_step("train", batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        return self.common_step("val", batch, batch_idx)

    def test_step(self, batch, batch_idx):
        return self.common_step("test", batch, batch_idx)

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=1e-5)
        # Linear decay schedule over the full training run, stepped every batch.
        scheduler = get_scheduler(
            name="linear",
            optimizer=optimizer,
            num_warmup_steps=0,
            num_training_steps=self.trainer.estimated_stepping_batches,
        )
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]
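
# --- Usage sketch (not part of the original file) ---------------------------
# A minimal, hedged example of how this module could be trained with a
# Lightning Trainer. The random tensors below are hypothetical stand-ins for a
# real dataset; in practice images would be preprocessed to (3, 224, 224)
# tensors (e.g. with ViTImageProcessor) and targets shaped (1,). Running this
# downloads the pretrained ViT weights from the Hugging Face Hub.
if __name__ == "__main__":
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Dummy data purely for illustration.
    images = torch.randn(8, 3, 224, 224)
    targets = torch.randn(8, 1)
    train_loader = DataLoader(TensorDataset(images, targets), batch_size=4)
    val_loader = DataLoader(TensorDataset(images, targets), batch_size=4)

    model = LightningViTRegressor()
    trainer = pl.Trainer(max_epochs=1, accelerator="auto", devices=1)
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)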