KyanChen's picture
Upload 1861 files
3b96cb1
# Copyright (c) OpenMMLab. All rights reserved.
import random
import torch.nn as nn
from mmpretrain.registry import MODELS
from .modules import FlamingoLayer, GatedCrossAttentionBlock
from .utils import getattr_recursive, setattr_recursive
@MODELS.register_module()
class FlamingoLMAdapter:
    """Mixin to add cross-attention layers to a language model.

    The mixin's ``forward`` conditions the wrapped decoder layers and then
    delegates to the next ``forward`` in the MRO via ``super().forward``.
    """

    @classmethod
    def extend_init(
        cls,
        base: object,
        vis_hidden_size: int,
        cross_attn_every_n_layers: int,
        use_media_placement_augmentation: bool,
        only_attend_previous: bool = False,
        decoder_layers_attr_name: str = 'model.layers',
    ):
        """Initialize Flamingo by adding gated cross attention to the decoder.

        Every decoder layer of ``base`` is wrapped in a ``FlamingoLayer``.
        A ``GatedCrossAttentionBlock`` is paired with every
        ``cross_attn_every_n_layers``-th layer (counting from 1); all other
        layers are paired with ``None``.

        Args:
            base (object): Base module; can be any object that represents an
                instance of a language model. It must already expose the
                methods of this mixin (``set_decoder_layers_attr_name``,
                ``_get_decoder_layers``, ``_set_decoder_layers``) and a
                ``config.hidden_size`` attribute.
            vis_hidden_size (int): Hidden size of vision embeddings.
            cross_attn_every_n_layers (int): Add a cross attention block
                every n decoder layers. Must be a positive integer.
            use_media_placement_augmentation (bool): Whether each forward
                call sets ``attend_previous`` to True with probability 0.5
                (media placement augmentation).
            only_attend_previous (bool): If True, ``attend_previous`` is
                always True in forward, taking precedence over
                ``use_media_placement_augmentation``. Defaults to False.
            decoder_layers_attr_name (str): Dotted attribute path of the
                decoder layer list inside ``base``.
                Defaults to ``'model.layers'``.

        Returns:
            object: ``base``, modified in place.

        Raises:
            ValueError: If ``cross_attn_every_n_layers`` is smaller than 1.
        """
        if cross_attn_every_n_layers < 1:
            raise ValueError('`cross_attn_every_n_layers` must be a positive '
                             f'integer, but got {cross_attn_every_n_layers}.')
        base.set_decoder_layers_attr_name(decoder_layers_attr_name)
        decoder_layers = base._get_decoder_layers()
        # ``None`` entries are placeholders for layers without cross attention.
        gated_cross_attn_layers = nn.ModuleList([
            GatedCrossAttentionBlock(
                dim=base.config.hidden_size, dim_visual=vis_hidden_size)
            if (layer_idx + 1) % cross_attn_every_n_layers == 0 else None
            for layer_idx, _ in enumerate(decoder_layers)
        ])
        base._set_decoder_layers(
            nn.ModuleList([
                FlamingoLayer(gated_cross_attn_layer, decoder_layer)
                for gated_cross_attn_layer, decoder_layer in zip(
                    gated_cross_attn_layers, decoder_layers)
            ]))
        base.use_media_placement_augmentation = use_media_placement_augmentation  # noqa
        base.initialized_flamingo = True
        base.only_attend_previous = only_attend_previous
        return base

    def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
        """Set the dotted attribute path used to locate the decoder layers."""
        self.decoder_layers_attr_name = decoder_layers_attr_name

    def _get_decoder_layers(self):
        """Get decoder layers according to ``decoder_layers_attr_name``."""
        return getattr_recursive(self, self.decoder_layers_attr_name)

    def _set_decoder_layers(self, value):
        """Set decoder layers according to ``decoder_layers_attr_name``."""
        setattr_recursive(self, self.decoder_layers_attr_name, value)

    def forward(self, *args, **kwargs):
        """Condition the Flamingo layers, then call the parent's forward.

        ``input_ids`` is read from ``kwargs`` when present, otherwise from the
        first positional argument. Media locations are the positions where
        ``input_ids`` equals ``self.media_token_id``.
        """
        input_ids = kwargs['input_ids'] if 'input_ids' in kwargs else args[0]
        # NOTE(review): ``media_token_id`` is not assigned in this class; it
        # is presumably set on the model by the caller — confirm.
        media_locations = input_ids == self.media_token_id
        if self.only_attend_previous:
            attend_previous = True
        elif self.use_media_placement_augmentation:
            # Re-drawn on every call: 50% chance of attend_previous=True.
            attend_previous = (random.random() < 0.5)
        else:
            attend_previous = False
        # Use the same accessor as the other methods so that the configured
        # ``decoder_layers_attr_name`` is respected.
        for layer in self._get_decoder_layers():
            layer.condition_media_locations(media_locations)
            layer.condition_attend_previous(attend_previous)
        # Call the next class's forward in the MRO (the wrapped LM).
        return super().forward(*args, **kwargs)

    def is_conditioned(self) -> bool:
        """Check whether all decoder layers are already conditioned."""
        return all(layer.is_conditioned()
                   for layer in self._get_decoder_layers())

    def clear_conditioned_layers(self):
        """Reset vision input, media locations and attend_previous to None on
        every decoder layer."""
        for layer in self._get_decoder_layers():
            layer.condition_vis_x(None)
            layer.condition_media_locations(None)
            layer.condition_attend_previous(None)