# -*- coding: utf-8 -*- | |
from fla.utils import device_platform | |
def fp32_to_tf32_asm() -> str: | |
""" | |
Get the assembly code for converting FP32 to TF32. | |
""" | |
ASM_DICT = { | |
'nvidia': 'cvt.rna.tf32.f32 $0, $1;' | |
} | |
if device_platform in ASM_DICT: | |
return ASM_DICT[device_platform] | |
else: | |
# return empty string if the device is not supported | |
return "" | |