diff --git a/setup.py b/setup.py index f3add4b4..84d2665b 100755 --- a/setup.py +++ b/setup.py @@ -5,12 +5,31 @@ import glob import os import runpy import warnings +from typing import List, Optional import torch from setuptools import find_packages, setup from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension +def get_existing_ccbin(nvcc_args: List[str]) -> Optional[str]: + """ + Given a list of nvcc arguments, return the compiler if specified. + + Note from CUDA doc: Single value options and list options must have + arguments, which must follow the name of the option itself by either + one of more spaces or an equals character. + """ + last_arg = None + for arg in reversed(nvcc_args): + if arg == "-ccbin": + return last_arg + if arg.startswith("-ccbin="): + return arg[7:] + last_arg = arg + return None + + def get_extensions(): this_dir = os.path.dirname(os.path.abspath(__file__)) extensions_dir = os.path.join(this_dir, "pytorch3d", "csrc") @@ -61,13 +80,18 @@ def get_extensions(): # This is needed for pytorch 1.6 and earlier. See e.g. # https://github.com/facebookresearch/pytorch3d/issues/436 - CC = os.environ.get("CC", None) - if CC is not None: - CC_arg = "-ccbin={}".format(CC) - if CC_arg not in nvcc_args: - if any(arg.startswith("-ccbin") for arg in nvcc_args): - raise ValueError("Inconsistent ccbins") - nvcc_args.append(CC_arg) + # It is harmless after https://github.com/pytorch/pytorch/pull/47404 . + # But it can be problematic in torch 1.7.0 and 1.7.1 + if torch.__version__[:4] != "1.7.": + CC = os.environ.get("CC", None) + if CC is not None: + existing_CC = get_existing_ccbin(nvcc_args) + if existing_CC is None: + CC_arg = "-ccbin={}".format(CC) + nvcc_args.append(CC_arg) + elif existing_CC != CC: + msg = f"Inconsistent ccbins: {CC} and {existing_CC}" + raise ValueError(msg) extra_compile_args["nvcc"] = nvcc_args