File size: 919 Bytes
f62c8b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
"""Modified from https://github.com/kijai/ComfyUI-MochiWrapper
"""
import torch
import torch.nn as nn

def autocast_model_forward(cls, origin_dtype, *inputs, **kwargs):
    """Run ``cls.original_forward`` with the module temporarily cast to ``origin_dtype``.

    The module is cast to ``origin_dtype`` for the duration of the call and
    cast back to the dtype its ``weight`` had on entry afterwards, even when
    the wrapped forward raises.

    Args:
        cls: Module-like object exposing ``weight``, ``to(dtype)`` and
            ``original_forward``.
        origin_dtype: dtype to compute in.
        *inputs: Positional arguments for ``original_forward``. Floating-point
            tensors are cast to ``origin_dtype``; everything else (integer or
            boolean tensors, ``None``, plain Python values) is passed through
            unchanged.
        **kwargs: Keyword arguments forwarded to ``original_forward`` as-is.

    Returns:
        Whatever ``cls.original_forward`` returns.
    """
    weight_dtype = cls.weight.dtype
    cls.to(origin_dtype)
    try:
        # Only floating-point tensors take part in the dtype switch: casting
        # integer index tensors (e.g. embedding lookups) would break them, and
        # non-tensor arguments have no ``.to`` to call.
        cast_inputs = [
            arg.to(origin_dtype)
            if torch.is_tensor(arg) and arg.is_floating_point()
            else arg
            for arg in inputs
        ]
        return cls.original_forward(*cast_inputs, **kwargs)
    finally:
        # Always put the weights back in their storage dtype, otherwise a
        # failing forward would leave the module stuck in ``origin_dtype``.
        cls.to(weight_dtype)

def convert_weight_dtype_wrapper(module, origin_dtype):
    """Patch weighted submodules of ``module`` to compute in ``origin_dtype``.

    Every submodule (the root itself is skipped) that carries a non-``None``
    ``weight`` gets its ``forward`` replaced by a wrapper that delegates to
    ``autocast_model_forward``; the untouched forward is kept on the submodule
    as ``original_forward``.

    Calling this more than once on the same model is safe: the saved
    ``original_forward`` is never overwritten, only the wrapper (and therefore
    the ``origin_dtype`` it uses) is refreshed.

    Args:
        module: Root module whose ``named_modules()`` are patched in place.
        origin_dtype: dtype the wrapped submodules should compute in.
    """
    for name, submodule in module.named_modules():
        if name == "":
            # The empty name is the root module itself; only children are wrapped.
            continue
        # A ``weight`` attribute can exist but be ``None`` (e.g. normalisation
        # layers without affine parameters); there is no dtype to switch then.
        if getattr(submodule, "weight", None) is None:
            continue
        # Keep the first, real forward. Overwriting it with an earlier wrapper
        # would make the wrapper call itself and recurse forever.
        if not hasattr(submodule, "original_forward"):
            submodule.original_forward = submodule.forward
        # ``m=submodule`` binds the current submodule at definition time,
        # avoiding the late-binding closure pitfall inside the loop.
        submodule.forward = (
            lambda *inputs, m=submodule, **kwargs: autocast_model_forward(m, origin_dtype, *inputs, **kwargs)
        )