allow packaging tools to override CUDA settings

Summary: This makes sure circle ci builds work with cuda even on machines with no gpu.

Reviewed By: gkioxari

Differential Revision: D19543957

fbshipit-source-id: 9cbfcd4fca22ebe89434ffa71c25d75dd18d2eb6
This commit is contained in:
Jeremy Reizenstein 2020-01-27 06:18:29 -08:00 committed by Facebook Github Bot
parent 674ee44ca8
commit 244b7eb80e

View File

@ -23,21 +23,31 @@ def get_extensions():
extra_compile_args = {"cxx": ["-std=c++17"]} extra_compile_args = {"cxx": ["-std=c++17"]}
define_macros = [] define_macros = []
if torch.cuda.is_available() and CUDA_HOME is not None: force_cuda = os.getenv("FORCE_CUDA", "0") == "1"
if (torch.cuda.is_available() and CUDA_HOME is not None) or force_cuda:
extension = CUDAExtension extension = CUDAExtension
sources += source_cuda sources += source_cuda
define_macros += [("WITH_CUDA", None)] define_macros += [("WITH_CUDA", None)]
extra_compile_args["nvcc"] = [ nvcc_args = [
"-DCUDA_HAS_FP16=1", "-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__", "-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__", "-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__", "-D__CUDA_NO_HALF2_OPERATORS__",
] ]
nvcc_flags_env = os.getenv("NVCC_FLAGS", "")
if nvcc_flags_env != "":
nvcc_args.extend(nvcc_flags_env.split(" "))
# It's better if pytorch can do this by default .. # It's better if pytorch can do this by default ..
CC = os.environ.get("CC", None) CC = os.environ.get("CC", None)
if CC is not None: if CC is not None:
extra_compile_args["nvcc"].append("-ccbin={}".format(CC)) CC_arg = "-ccbin={}".format(CC)
if CC_arg not in nvcc_args:
if any(arg.startswith("-ccbin") for arg in nvcc_args):
raise ValueError("Inconsistent ccbins")
nvcc_args.append(CC_arg)
extra_compile_args["nvcc"] = nvcc_args
sources = [os.path.join(extensions_dir, s) for s in sources] sources = [os.path.join(extensions_dir, s) for s in sources]