mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
allow packaging tools to override CUDA settings
Summary: This makes sure circle ci builds work with cuda even on machines with no gpu. Reviewed By: gkioxari Differential Revision: D19543957 fbshipit-source-id: 9cbfcd4fca22ebe89434ffa71c25d75dd18d2eb6
This commit is contained in:
parent
674ee44ca8
commit
244b7eb80e
16
setup.py
16
setup.py
@ -23,21 +23,31 @@ def get_extensions():
|
|||||||
extra_compile_args = {"cxx": ["-std=c++17"]}
|
extra_compile_args = {"cxx": ["-std=c++17"]}
|
||||||
define_macros = []
|
define_macros = []
|
||||||
|
|
||||||
if torch.cuda.is_available() and CUDA_HOME is not None:
|
force_cuda = os.getenv("FORCE_CUDA", "0") == "1"
|
||||||
|
if (torch.cuda.is_available() and CUDA_HOME is not None) or force_cuda:
|
||||||
extension = CUDAExtension
|
extension = CUDAExtension
|
||||||
sources += source_cuda
|
sources += source_cuda
|
||||||
define_macros += [("WITH_CUDA", None)]
|
define_macros += [("WITH_CUDA", None)]
|
||||||
extra_compile_args["nvcc"] = [
|
nvcc_args = [
|
||||||
"-DCUDA_HAS_FP16=1",
|
"-DCUDA_HAS_FP16=1",
|
||||||
"-D__CUDA_NO_HALF_OPERATORS__",
|
"-D__CUDA_NO_HALF_OPERATORS__",
|
||||||
"-D__CUDA_NO_HALF_CONVERSIONS__",
|
"-D__CUDA_NO_HALF_CONVERSIONS__",
|
||||||
"-D__CUDA_NO_HALF2_OPERATORS__",
|
"-D__CUDA_NO_HALF2_OPERATORS__",
|
||||||
]
|
]
|
||||||
|
nvcc_flags_env = os.getenv("NVCC_FLAGS", "")
|
||||||
|
if nvcc_flags_env != "":
|
||||||
|
nvcc_args.extend(nvcc_flags_env.split(" "))
|
||||||
|
|
||||||
# It's better if pytorch can do this by default ..
|
# It's better if pytorch can do this by default ..
|
||||||
CC = os.environ.get("CC", None)
|
CC = os.environ.get("CC", None)
|
||||||
if CC is not None:
|
if CC is not None:
|
||||||
extra_compile_args["nvcc"].append("-ccbin={}".format(CC))
|
CC_arg = "-ccbin={}".format(CC)
|
||||||
|
if CC_arg not in nvcc_args:
|
||||||
|
if any(arg.startswith("-ccbin") for arg in nvcc_args):
|
||||||
|
raise ValueError("Inconsistent ccbins")
|
||||||
|
nvcc_args.append(CC_arg)
|
||||||
|
|
||||||
|
extra_compile_args["nvcc"] = nvcc_args
|
||||||
|
|
||||||
sources = [os.path.join(extensions_dir, s) for s in sources]
|
sources = [os.path.join(extensions_dir, s) for s in sources]
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user