import pytorch_lightning as pl
import torch
import torch.nn.functional as F

from .gates import DiffMaskGateInput
from argparse import ArgumentParser
from math import sqrt
from pytorch_lightning.core.optimizer import LightningOptimizer
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from transformers import (
    get_constant_schedule_with_warmup,
    get_constant_schedule,
    ViTForImageClassification,
)
from transformers.models.vit.configuration_vit import ViTConfig
from typing import Optional, Union
from utils.getters_setters import vit_getter, vit_setter
from utils.metrics import accuracy_precision_recall_f1
from utils.optimizer import LookaheadAdam


class ImageInterpretationNet(pl.LightningModule):
    @staticmethod
    def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
        parser = parent_parser.add_argument_group("Vision DiffMask")
        parser.add_argument(
            "--alpha",
            type=float,
            default=20.0,
            help="Initial value for the Lagrangian.",
        )
        parser.add_argument(
            "--lr",
            type=float,
            default=2e-5,
            help="Learning rate for DiffMask.",
        )
        parser.add_argument(
            "--eps",
            type=float,
            default=0.1,
            help="KL divergence tolerance.",
        )
        parser.add_argument(
            "--no_placeholder",
            action="store_true",
            help="Whether to disable the placeholder embeddings.",
        )
        parser.add_argument(
            "--lr_placeholder",
            type=float,
            default=1e-3,
            help="Learning rate for the placeholder (mask) vectors.",
        )
        parser.add_argument(
            "--lr_alpha",
            type=float,
            default=0.3,
            help="Learning rate for the Lagrangian optimizer.",
        )
        parser.add_argument(
            "--mul_activation",
            type=float,
            default=15.0,
            help="Value to multiply the gate activations by.",
        )
        parser.add_argument(
            "--add_activation",
            type=float,
            default=8.0,
            help="Value to add to the gate activations.",
        )
        parser.add_argument(
            "--weighted_layer_distribution",
            action="store_true",
            help="Whether to use a weighted distribution when picking a layer in the DiffMask forward pass.",
        )
        return parent_parser
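
    # Usage sketch (added for illustration, not part of the original module): these
    # flags are meant to be merged into a training script's parser. The Trainer flags
    # below assume a PyTorch Lightning 1.x-style `add_argparse_args`.
    #
    #   parser = ArgumentParser()
    #   parser = pl.Trainer.add_argparse_args(parser)
    #   parser = ImageInterpretationNet.add_model_specific_args(parser)
    #   args = parser.parse_args()
    #   # e.g. args.alpha == 20.0 and args.lr == 2e-5 unless overridden on the CLI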

    # Declare variables that will be initialized later
    model: ViTForImageClassification

    def __init__(
        self,
        model_cfg: ViTConfig,
        alpha: float = 1,
        lr: float = 3e-4,
        eps: float = 0.1,
        eps_valid: float = 0.8,
        acc_valid: float = 0.75,
        lr_placeholder: float = 1e-3,
        lr_alpha: float = 0.3,
        mul_activation: float = 10.0,
        add_activation: float = 5.0,
        placeholder: bool = True,
        weighted_layer_pred: bool = False,
    ):
        """A PyTorch Lightning Module for the VisionDiffMask model on the Vision Transformer.

        Args:
            model_cfg (ViTConfig): the configuration of the Vision Transformer model
            alpha (float): the initial value for the Lagrangian
            lr (float): the learning rate for the DiffMask gates
            eps (float): the tolerance for the KL divergence
            eps_valid (float): the tolerance for the KL divergence in the validation step
            acc_valid (float): the accuracy threshold for the validation step
            lr_placeholder (float): the learning rate for the learnable masking embeddings
            lr_alpha (float): the learning rate for the Lagrangian
            mul_activation (float): the value to multiply the gate activations by
            add_activation (float): the value to add to the gate activations
            placeholder (bool): whether to use placeholder embeddings or a zero vector
            weighted_layer_pred (bool): whether to use a weighted distribution when picking a layer
        """
        super().__init__()

        # Save the hyperparameters
        self.save_hyperparameters()

        # Create DiffMask instance
        self.gate = DiffMaskGateInput(
            hidden_size=model_cfg.hidden_size,
            hidden_attention=model_cfg.hidden_size // 4,
            num_hidden_layers=model_cfg.num_hidden_layers + 2,
            max_position_embeddings=1,
            mul_activation=mul_activation,
            add_activation=add_activation,
            placeholder=placeholder,
        )

        # Create the Lagrangian values for the dual optimization
        self.alpha = torch.nn.ParameterList(
            [
                torch.nn.Parameter(torch.ones(()) * alpha)
                for _ in range(model_cfg.num_hidden_layers + 2)
            ]
        )

        # Register buffers for running metrics
        self.register_buffer(
            "running_acc", torch.ones((model_cfg.num_hidden_layers + 2,))
        )
        self.register_buffer(
            "running_l0", torch.ones((model_cfg.num_hidden_layers + 2,))
        )
        self.register_buffer(
            "running_steps", torch.zeros((model_cfg.num_hidden_layers + 2,))
        )

    def set_vision_transformer(self, model: ViTForImageClassification):
        """Set the Vision Transformer model to be used with this module."""
        # Save the model instance as a class attribute
        self.model = model

        # Freeze the model's parameters
        for param in self.model.parameters():
            param.requires_grad = False
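
    # Usage sketch (illustrative, not from the original file): wrapping a pretrained
    # ViT with the explainer. The checkpoint name is an assumption; any
    # ViTForImageClassification checkpoint works as long as its config is passed in.
    #
    #   vit = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
    #   diffmask = ImageInterpretationNet(model_cfg=vit.config)
    #   diffmask.set_vision_transformer(vit)  # the ViT stays frozen; only the gates train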

    def forward_explainer(
        self, x: Tensor, attribution: bool = False
    ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, int, int]:
        """Performs a forward pass through the explainer (VisionDiffMask) model."""
        # Get the original logits and hidden states from the model
        logits_orig, hidden_states = vit_getter(self.model, x)

        # Add the [CLS] token to deal with the shape mismatch in the self.gate() call
        patch_embeddings = hidden_states[0]
        batch_size = len(patch_embeddings)
        cls_tokens = self.model.vit.embeddings.cls_token.expand(batch_size, -1, -1)
        hidden_states[0] = torch.cat((cls_tokens, patch_embeddings), dim=1)

        # Select the layer to generate the mask from in this pass
        n_hidden = len(hidden_states)
        if self.hparams.weighted_layer_pred:
            # If weighted layer prediction is enabled, use a weighted distribution
            # instead of uniformly picking a layer after a number of steps
            low_weight = (
                lambda i: self.running_acc[i] > 0.75
                and self.running_l0[i] < 0.1
                and self.running_steps[i] > 100
            )
            layers = torch.tensor(list(range(n_hidden)))
            p = torch.tensor([0.1 if low_weight(i) else 1 for i in range(n_hidden)])
            p = p / p.sum()
            idx = p.multinomial(num_samples=1)
            layer_pred = layers[idx].item()
        else:
            layer_pred = torch.randint(n_hidden, ()).item()

        # Set the layer to drop to 0, since we are only interested in masking the input
        layer_drop = 0

        (
            new_hidden_state,
            gates,
            expected_L0,
            gates_full,
            expected_L0_full,
        ) = self.gate(
            hidden_states=hidden_states,
            layer_pred=None
            if attribution
            else layer_pred,  # if attribution, we get all the hidden states
        )

        # Create the list of the new hidden states for the new forward pass
        new_hidden_states = (
            [None] * layer_drop
            + [new_hidden_state]
            + [None] * (n_hidden - layer_drop - 1)
        )

        # Get the new logits from the masked input
        logits, _ = vit_setter(self.model, x, new_hidden_states)

        return (
            logits,
            logits_orig,
            gates,
            expected_L0,
            gates_full,
            expected_L0_full,
            layer_drop,
            layer_pred,
        )
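
    # Note (added for clarity; shapes are assumptions for a 224x224 input with
    # 16x16 patches): `vit_setter` re-runs the ViT while substituting the masked
    # embedding at `layer_drop` (always the input layer here).
    #
    #   x:                (B, 3, 224, 224)
    #   hidden_states[0]: (B, 197, hidden_size)   # [CLS] + 14*14 patch tokens
    #   gates:            one value per token position, aligned with those 197 tokens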

    def get_mask(
        self,
        x: Tensor,
        idx: int = -1,
        aggregated_mask: bool = True,
    ) -> dict[str, Tensor]:
        """Generates a mask for the given input.

        Args:
            x: the input to generate the mask for
            idx: the index of the layer to generate the mask from
            aggregated_mask: whether to use an aggregated mask from each layer

        Returns:
            a dictionary containing the mask, the KL divergence and the predicted class
        """
        # Pass through the forward explainer with attribution=True
        (
            logits,
            logits_orig,
            gates,
            expected_L0,
            gates_full,
            expected_L0_full,
            layer_drop,
            layer_pred,
        ) = self.forward_explainer(x, attribution=True)

        # Calculate the KL divergence
        kl_div = torch.distributions.kl_divergence(
            torch.distributions.Categorical(logits=logits_orig),
            torch.distributions.Categorical(logits=logits),
        )

        # Get the predicted class
        pred_class = logits.argmax(-1)

        # Calculate the mask
        if aggregated_mask:
            mask = expected_L0_full[:, :, idx].exp()
        else:
            mask = gates_full[:, :, idx]
        # Drop the [CLS] token
        mask = mask[:, 1:]

        C, H, W = x.shape[1:]  # channels, height, width
        B, P = mask.shape  # batch, patches
        N = int(sqrt(P))  # patches per side
        S = int(H / N)  # patch size

        # Reshape the mask to match the input shape
        mask = mask.reshape(B, 1, N, N)
        mask = F.interpolate(mask, scale_factor=S)
        mask = mask.reshape(B, H, W)

        return {
            "mask": mask,
            "kl_div": kl_div,
            "pred_class": pred_class,
            "logits": logits,
            "logits_orig": logits_orig,
        }

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x).logits

    def training_step(self, batch: tuple[Tensor, Tensor], *args, **kwargs) -> dict:
        # Unpack the batch
        x, y = batch

        # Pass the batch through the explainer (VisionDiffMask) model
        (
            logits,
            logits_orig,
            gates,
            expected_L0,
            gates_full,
            expected_L0_full,
            layer_drop,
            layer_pred,
        ) = self.forward_explainer(x)

        # Calculate the KL-divergence loss term
        loss_c = (
            torch.distributions.kl_divergence(
                torch.distributions.Categorical(logits=logits_orig),
                torch.distributions.Categorical(logits=logits),
            )
            - self.hparams.eps
        )

        # Calculate the L0 loss term
        loss_g = expected_L0.mean(-1)

        # Calculate the full loss term
        loss = self.alpha[layer_pred] * loss_c + loss_g

        # Calculate the accuracy
        acc, _, _, _ = accuracy_precision_recall_f1(
            logits.argmax(-1), logits_orig.argmax(-1), average=True
        )

        # Calculate the average L0 loss
        l0 = expected_L0.exp().mean(-1)

        outputs_dict = {
            "loss_c": loss_c.mean(-1),
            "loss_g": loss_g.mean(-1),
            "alpha": self.alpha[layer_pred].mean(-1),
            "acc": acc,
            "l0": l0.mean(-1),
            "layer_pred": layer_pred,
            "r_acc": self.running_acc[layer_pred],
            "r_l0": self.running_l0[layer_pred],
            "r_steps": self.running_steps[layer_pred],
            "debug_loss": loss.mean(-1),
        }

        outputs_dict = {
            "loss": loss.mean(-1),
            **outputs_dict,
            "log": outputs_dict,
            "progress_bar": outputs_dict,
        }

        self.log(
            "loss", outputs_dict["loss"], on_step=True, on_epoch=True, prog_bar=True
        )
        self.log(
            "loss_c", outputs_dict["loss_c"], on_step=True, on_epoch=True, prog_bar=True
        )
        self.log(
            "loss_g", outputs_dict["loss_g"], on_step=True, on_epoch=True, prog_bar=True
        )
        self.log("acc", outputs_dict["acc"], on_step=True, on_epoch=True, prog_bar=True)
        self.log("l0", outputs_dict["l0"], on_step=True, on_epoch=True, prog_bar=True)
        self.log(
            "alpha", outputs_dict["alpha"], on_step=True, on_epoch=True, prog_bar=True
        )

        outputs_dict = {
            "{}{}".format("" if self.training else "val_", k): v
            for k, v in outputs_dict.items()
        }

        # Update the running metrics for the selected layer
        if self.training:
            self.running_acc[layer_pred] = (
                self.running_acc[layer_pred] * 0.9 + acc * 0.1
            )
            self.running_l0[layer_pred] = (
                self.running_l0[layer_pred] * 0.9 + l0.mean(-1) * 0.1
            )
            self.running_steps[layer_pred] += 1

        return outputs_dict
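
    # Note (added for clarity): the loss above is a Lagrangian relaxation of the
    # constrained objective
    #
    #   minimize  E[L0]   subject to   KL(p_orig || p_masked) <= eps,
    #
    # implemented as loss = alpha * (KL - eps) + E[L0]. The multiplier alpha is
    # updated by gradient ascent in `optimizer_step` below.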

    def validation_epoch_end(self, outputs: list[dict]):
        # Average the logged metrics over the validation set
        outputs_dict = {
            k: [e[k] for e in outputs if k in e]
            for k in ("val_loss_c", "val_loss_g", "val_acc", "val_l0")
        }
        outputs_dict = {k: sum(v) / len(v) for k, v in outputs_dict.items()}

        # loss_c was logged with eps already subtracted; add it back for reporting
        outputs_dict["val_loss_c"] += self.hparams.eps

        # Use the L0 loss as the validation loss only if the KL and accuracy
        # constraints are satisfied; otherwise report an infinite loss so that
        # checkpointing never selects an unconstrained model
        outputs_dict = {
            "val_loss": outputs_dict["val_l0"]
            if outputs_dict["val_loss_c"] <= self.hparams.eps_valid
            and outputs_dict["val_acc"] >= self.hparams.acc_valid
            else torch.full_like(outputs_dict["val_l0"], float("inf")),
            **outputs_dict,
            "log": outputs_dict,
        }

        return outputs_dict

    def configure_optimizers(self) -> tuple[list[Optimizer], list[_LRScheduler]]:
        optimizers = [
            LookaheadAdam(
                params=[
                    {
                        "params": self.gate.g_hat.parameters(),
                        "lr": self.hparams.lr,
                    },
                    {
                        "params": self.gate.placeholder.parameters()
                        if isinstance(self.gate.placeholder, torch.nn.ParameterList)
                        else [self.gate.placeholder],
                        "lr": self.hparams.lr_placeholder,
                    },
                ],
                # centered=True,  # this is for LookaheadRMSprop
            ),
            LookaheadAdam(
                params=[self.alpha]
                if isinstance(self.alpha, torch.Tensor)
                else self.alpha.parameters(),
                lr=self.hparams.lr_alpha,
            ),
        ]

        schedulers = [
            {
                "scheduler": get_constant_schedule_with_warmup(optimizers[0], 12 * 100),
                "interval": "step",
            },
            get_constant_schedule(optimizers[1]),
        ]

        return optimizers, schedulers

    def optimizer_step(
        self,
        epoch: int,
        batch_idx: int,
        optimizer: Union[Optimizer, LightningOptimizer],
        optimizer_idx: int = 0,
        optimizer_closure: Optional[callable] = None,
        on_tpu: bool = False,
        using_native_amp: bool = False,
        using_lbfgs: bool = False,
    ):
        # Optimizer 0: Minimize the loss w.r.t. DiffMask's parameters
        if optimizer_idx == 0:
            # Gradient descent on the gate and placeholder parameters
            optimizer.step(closure=optimizer_closure)
            optimizer.zero_grad()
            for g in optimizer.param_groups:
                for p in g["params"]:
                    p.grad = None

        # Optimizer 1: Maximize the loss w.r.t. the Lagrangian
        elif optimizer_idx == 1:
            # Reverse the sign of the Lagrangian's gradients
            for i in range(len(self.alpha)):
                if self.alpha[i].grad is not None:
                    self.alpha[i].grad *= -1

            # Gradient ascent on the Lagrangian
            optimizer.step(closure=optimizer_closure)
            optimizer.zero_grad()
            for g in optimizer.param_groups:
                for p in g["params"]:
                    p.grad = None

            # Clip the Lagrangian's values to [0, 200]
            for i in range(len(self.alpha)):
                self.alpha[i].data = torch.where(
                    self.alpha[i].data < 0,
                    torch.full_like(self.alpha[i].data, 0),
                    self.alpha[i].data,
                )
                self.alpha[i].data = torch.where(
                    self.alpha[i].data > 200,
                    torch.full_like(self.alpha[i].data, 200),
                    self.alpha[i].data,
                )
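
    # Note (illustrative): flipping the sign of alpha's gradients before stepping the
    # second optimizer turns its descent update into ascent on the constraint term,
    # roughly alpha <- clamp(alpha + lr_alpha * d(loss)/d(alpha), 0, 200), up to the
    # optimizer's internal statistics. The clamp bounds 0 and 200 are the ones
    # hard-coded above.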

    def on_save_checkpoint(self, ckpt: dict):
        # Remove the ViT from the checkpoint, since it can be loaded dynamically
        keys = list(ckpt["state_dict"].keys())
        for key in keys:
            if key.startswith("model."):
                del ckpt["state_dict"][key]