# Source: bert-base-uncased-stsb-TTM / modeling_bert.py
# Uploaded by memyprokotow: "Upload TTCompressedBertForSequenceClassification"
# (commit 0ecf33a)
"""This module uses parts of rut5compressed. It shares the same module
structure as the models used in the neural network compression experiments
with rut5compressed.
"""
from functools import partial
from typing import Optional, Tuple
import numpy as np
import torch as T
from transformers import BertForSequenceClassification
from .configuration_bert import TTCompressedBertConfig
from .linalg import ttd # noqa: F401 We need this import for HF.
from .modules import TTCompressedLinear
from .util import compress_linear_tt, map_module
class TTCompressedBertForSequenceClassification(BertForSequenceClassification):
    """BERT sequence classifier whose selected linear layers are replaced
    with tensor-train (TT) compressed linear layers.

    The modules to replace are selected by the :attr:`LAYERS` path pattern,
    which is handed to ``map_module`` together with the function that builds
    each replacement module.
    """

    # Path pattern passed to ``map_module`` to select the modules to replace:
    # the ``intermediate`` and ``output`` sub-modules of every encoder layer.
    # NOTE(review): whether this is a prefix or a full match is decided by
    # ``map_module``; confirm there before changing the pattern.
    LAYERS = r'/encoder/layer/\d+/(intermediate|output)'

    config_class = TTCompressedBertConfig

    def __init__(self, config: TTCompressedBertConfig,
                 shape: Optional[Tuple[Tuple[int, ...],
                                       Tuple[int, ...]]] = None,
                 rank: Optional[int] = None,
                 compress: bool = False):
        """Build the model and swap the selected layers for TT layers.

        Args:
            config: Model configuration. ``config.rank``, ``config.shape_in``
                and ``config.shape_out`` are used as fallbacks when ``rank``
                or ``shape`` are not given.
            shape: Pair ``(shape_in, shape_out)`` of factorizations for the
                TT layers. Defaults to ``(config.shape_in, config.shape_out)``
                converted to tuples.
            rank: TT rank. Falls back to ``config.rank`` when falsy
                (``None`` or ``0``).
            compress: If true, the selected modules are passed to
                ``compress_linear_tt`` (presumably a TT decomposition of the
                existing weights). Otherwise selected ``nn.Linear`` modules
                are replaced with randomly initialized TT layers via
                :meth:`convert` (presumably so compressed weights can then
                be loaded from a checkpoint).
        """
        super().__init__(config)
        # ``or`` (rather than an ``is None`` check) also maps rank=0 to the
        # configured rank; kept as-is for backward compatibility.
        self.rank = rank or config.rank
        self.shape = shape
        if self.shape is None:
            self.shape = (tuple(self.config.shape_in),
                          tuple(self.config.shape_out))

        if compress:
            compress_fn = partial(compress_linear_tt, rank=self.rank,
                                  shape=self.shape)
        else:
            compress_fn = self.convert
        self.bert = map_module(self.bert, compress_fn, self.LAYERS)

    def convert(self, module: T.nn.Module, path: str) -> T.nn.Module:
        """Replace a linear module with a randomly initialized TT layer.

        Used as the ``map_module`` callback when ``compress=False``.
        Modules that are not ``torch.nn.Linear`` are returned unchanged.

        Args:
            module: A module selected by :attr:`LAYERS`.
            path: Module path, presumably supplied by ``map_module``; unused
                here but required by the callback signature.

        Returns:
            A ``TTCompressedLinear`` of rank ``self.rank`` (with a bias iff
            ``module`` has one) when ``module`` is ``torch.nn.Linear``;
            otherwise ``module`` itself.
        """
        if not isinstance(module, T.nn.Linear):
            return module

        # ``self.shape`` is ``(shape_in, shape_out)``. A layer that maps to
        # a smaller dimension (in_features > out_features) uses the same
        # factorizations with input and output swapped. Layers with
        # in_features <= out_features (expanding or square) use them as is.
        in_shape, out_shape = self.shape
        if module.in_features > module.out_features:
            in_shape, out_shape = out_shape, in_shape
        shape = (in_shape, out_shape)
        bias = module.bias is not None
        return TTCompressedLinear.from_random(shape, self.rank, bias)
# Register this custom model with the ``AutoModelForSequenceClassification``
# auto class so the Hugging Face Auto API can resolve it.
TTCompressedBertForSequenceClassification.register_for_auto_class(
    'AutoModelForSequenceClassification')