File size: 401 Bytes
bec1e88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# -*- coding: utf-8 -*-

from fla.utils import device_platform


def fp32_to_tf32_asm() -> str:
    """
    Return the assembly snippet that converts an FP32 value to TF32.

    The snippet is selected by the module-level ``device_platform``. Only the
    ``'nvidia'`` platform has an entry, the PTX ``cvt.rna.tf32.f32``
    instruction. Any other platform gets an empty string, signalling that no
    conversion snippet is available.

    NOTE(review): ``$0`` / ``$1`` look like operand placeholders for an
    inline-asm template (output, input); confirm against the callers.
    """
    # Per-platform assembly templates; unsupported platforms fall back to "".
    asm_by_platform = {'nvidia': 'cvt.rna.tf32.f32 $0, $1;'}
    return asm_by_platform.get(device_platform, "")