|
import torch |
|
from torch import nn |
|
from whisper.model import Linear, Conv1d, LayerNorm, Whisper |
|
|
|
|
|
def replace_modules(model: nn.Module, only_linear: bool = False):
    """
    Recursively replace ``Linear``/``Conv1d``/``LayerNorm`` from
    :class:`whisper.model` with the equivalent modules from :mod:`torch.nn`,
    copying the trained weights into the replacements. Operates in place.

    :param model: root module whose subtree is rewritten.
    :param only_linear: when ``True``, only ``Linear`` layers are replaced;
        ``Conv1d`` and ``LayerNorm`` are left untouched (useful before
        dynamic quantization, which targets ``nn.Linear``).
    """
    # Snapshot the children first: replacing a child via setattr mutates
    # the module's internal registry while we iterate.
    for name, module in list(model.named_children()):
        replaced = True
        if isinstance(module, Linear):
            setattr(model, name, nn.Linear(module.in_features, module.out_features,
                                           bias=module.bias is not None))
        elif not only_linear and isinstance(module, Conv1d):
            setattr(model, name, nn.Conv1d(module.in_channels, module.out_channels,
                                           kernel_size=module.kernel_size,
                                           stride=module.stride,
                                           padding=module.padding,
                                           bias=module.bias is not None))
        elif not only_linear and isinstance(module, LayerNorm):
            # NOTE(review): assumes a 1-D normalized_shape, which holds for
            # the LayerNorms used throughout Whisper.
            setattr(model, name, nn.LayerNorm(module.normalized_shape[0]))
        else:
            replaced = False
            # BUGFIX: propagate ``only_linear`` into the recursion. Previously
            # the flag was dropped here, so nested Conv1d/LayerNorm modules
            # were replaced even when the caller asked for Linear-only.
            replace_modules(module, only_linear=only_linear)

        if replaced:
            # Copy the trained parameters into the fresh torch.nn module.
            getattr(model, name).load_state_dict(module.state_dict())
|
|
|
|
|
def ptdq_linear(model: "Whisper"):
    """
    Apply post-training dynamic quantization (qint8) to a
    :class:`whisper.model.Whisper` instance, in place.

    The model is first moved to CPU (dynamic quantization is CPU-only) and
    its custom ``Linear`` layers are swapped for :class:`torch.nn.Linear`,
    which is what ``quantize_dynamic`` targets. A ``dq`` attribute is set on
    the model as a marker that quantization was applied.
    """
    model.cpu()
    replace_modules(model, only_linear=True)
    torch.quantization.quantize_dynamic(
        model, {nn.Linear}, dtype=torch.qint8, inplace=True
    )
    model.dq = True
|
|