Spaces:
Running
on
A100
Running
on
A100
| # Copyright (c) 2025 NVIDIA CORPORATION. | |
| # Licensed under the MIT license. | |
| # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. | |
| # LICENSE is in incl_licenses directory. | |
| import torch | |
| import torch.nn as nn | |
| from torch.autograd.function import Function, InplaceFunction | |
| try: | |
| from .QAct import QAct_FPin, QAct_FPout | |
| from .Qconfig import qconfig | |
| from .QFunction import * | |
| from .utils import * | |
| except: | |
| from Qconfig import qconfig | |
| from utils import * | |
| from QFunction import * | |
| from .QAct import QAct_FPin, QAct_FPout | |
| import os | |
| from copy import deepcopy | |
| import matplotlib.pyplot as plt | |
| class QMul(nn.Module): | |
| def __init__(self, args=None, layer_type=""): | |
| super().__init__() | |
| self.args = deepcopy(args) | |
| self.layer_type = layer_type | |
| assert layer_type != "", "layer_type is not defined" | |
| assert layer_type in qconfig.qmul_config, f"{layer_type} not in qgelu_config" | |
| self.apply_quantize = list_has_common_element(args.qchoice, qconfig.qmul_config[layer_type]) | |
| self.fbit = self.args.fabit if self.args.fabit else self.Ubit | |
| self.bbit = self.args.babit if self.args.babit else self.Ubit | |
| quantize_flag = format_string_with_condition( | |
| layer_type, | |
| {"apply": self.apply_quantize}, | |
| self.args.symm, | |
| self.fbit, | |
| self.bbit, | |
| {"row": self.args.row_blocksize, "col": self.args.col_blocksize}, | |
| ) | |
| print(quantize_flag) | |
| self.Mul_in1 = QAct_FPout(args, layer_type=layer_type + "_in1") | |
| self.Mul_in2 = QAct_FPout(args, layer_type=layer_type + "_in2") | |
| self.Mul_out = QAct_FPin(args, layer_type=layer_type + "_out") | |
| def forward(self, Qinput1, Qinput2, Iscale1, Iscale2): | |
| # input shape is (Batch Size, Sequence Length, Hidden Size) | |
| input1 = self.Mul_in1(Qinput1, Iscale1) | |
| input2 = self.Mul_in2(Qinput2, Iscale2) | |
| output_fp = input1 * input2 | |
| Qoutput, Oscale = self.Mul_out(output_fp) | |
| return Qoutput, Oscale | |
| if __name__ == "__main__": | |
| Sum = torch.load("tensor/QAct_nan_epoch16.pt") | |