File size: 368 Bytes
c7abadb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import os
import shutil


def find_cuda():
    """Locate the root directory of a CUDA toolkit installation.

    Lookup order:

    1. The ``CUDA_HOME`` environment variable, then ``CUDA_PATH``. Each one
       is used only if it is set, non-empty, and names an existing
       directory. A stale ``CUDA_HOME`` therefore no longer hides a valid
       ``CUDA_PATH``.
    2. The ``nvcc`` executable found on ``PATH``. Its path is resolved
       through any symlinks. The root is then taken to be the parent of
       the directory that contains ``nvcc`` (i.e. ``<root>/bin/nvcc``).

    Returns:
        The toolkit root path as a string, or ``None`` if neither source
        yields a candidate. The ``nvcc``-derived path is not further
        validated (e.g. for headers or libraries).
    """
    for var in ('CUDA_HOME', 'CUDA_PATH'):
        cuda_home = os.environ.get(var)
        if cuda_home and os.path.isdir(cuda_home):
            return cuda_home

    nvcc_path = shutil.which('nvcc')
    if nvcc_path:
        # Resolve symlinks first, so a shim directory on PATH
        # (e.g. /usr/local/bin/nvcc -> /opt/cuda/bin/nvcc) is not mistaken
        # for the toolkit root; then strip the trailing "bin/nvcc".
        real_nvcc = os.path.realpath(nvcc_path)
        return os.path.dirname(os.path.dirname(real_nvcc))

    return None