import torch
from torch import nn
from whisper.model import Linear, Conv1d, LayerNorm, Whisper


def replace_modules(model: nn.Module, only_linear: bool = False):
    """
    Replace ``Linear``/``Conv1d``/``LayerNorm`` from :class:`whisper.model` with their
    equivalents from :mod:`torch.nn` so that :func:`torch.quantization.quantize_dynamic`
    can target them. If ``only_linear`` is ``True``, only ``Linear`` layers are replaced.
    """
    for name in model.__dict__.get('_modules', {}):
        module = getattr(model, name)
        update = True
        if isinstance(module, Linear):
            setattr(model, name, nn.Linear(module.in_features, module.out_features,
                                           bias=module.bias is not None))
        elif not only_linear and isinstance(module, Conv1d):
            setattr(model, name, nn.Conv1d(module.in_channels, module.out_channels,
                                           kernel_size=module.kernel_size,
                                           stride=module.stride,
                                           padding=module.padding,
                                           bias=module.bias is not None))
        elif not only_linear and isinstance(module, LayerNorm):
            setattr(model, name, nn.LayerNorm(module.normalized_shape[0]))
        else:
            # Not a module we replace directly; recurse into its children,
            # propagating ``only_linear`` so the flag applies at every depth.
            update = False
            replace_modules(module, only_linear=only_linear)

        if update:
            # Copy the original parameters into the replacement module.
            getattr(model, name).load_state_dict(module.state_dict())


def ptdq_linear(model: "Whisper"):
    """
    Apply post-training dynamic quantization to the ``Linear`` layers of an instance of
    :class:`whisper.model.Whisper`, in place.
    """
    model.cpu()  # dynamic quantization is only supported on CPU
    replace_modules(model, only_linear=True)
    torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8, inplace=True)
    setattr(model, 'dq', True)  # mark the model as dynamically quantized
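
# Example usage (a minimal sketch; assumes the ``openai-whisper`` package is installed
# and that an audio file exists; the 'base' checkpoint and 'audio.mp3' name are only
# illustrative):
#
#     import whisper
#
#     model = whisper.load_model('base')      # any Whisper checkpoint
#     ptdq_linear(model)                       # replace and int8-quantize Linear layers on CPU
#     result = model.transcribe('audio.mp3')   # inference now runs with quantized weights
#     print(result['text'])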