Spaces:
Paused
Paused
"""Modified from https://github.com/kijai/ComfyUI-MochiWrapper | |
""" | |
import torch | |
import torch.nn as nn | |
def autocast_model_forward(cls, origin_dtype, *inputs, **kwargs): | |
weight_dtype = cls.weight.dtype | |
cls.to(origin_dtype) | |
# Convert all inputs to the original dtype | |
inputs = [input.to(origin_dtype) for input in inputs] | |
out = cls.original_forward(*inputs, **kwargs) | |
cls.to(weight_dtype) | |
return out | |
def convert_weight_dtype_wrapper(module, origin_dtype): | |
for name, module in module.named_modules(): | |
if name == "": | |
continue | |
original_forward = module.forward | |
if hasattr(module, "weight"): | |
setattr(module, "original_forward", original_forward) | |
setattr( | |
module, | |
"forward", | |
lambda *inputs, m=module, **kwargs: autocast_model_forward(m, origin_dtype, *inputs, **kwargs) | |
) | |