mirror of
				https://github.com/facebookresearch/sam2.git
				synced 2025-11-04 19:42:12 +08:00 
			
		
		
		
	also catch errors during installation in case CUDAExtension cannot be loaded (#175)
				
					
				
			Previously we only catch build errors in `BuildExtension` in https://github.com/facebookresearch/segment-anything-2/pull/155. However, in some cases, the `CUDAExtension` instance might not load. So in this PR, we also catch such errors for `CUDAExtension`.
This commit is contained in:
		
							parent
							
								
									6ecb5ff8d0
								
							
						
					
					
						commit
						6186d1529a
					
				
							
								
								
									
										55
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										55
									
								
								setup.py
									
									
									
									
									
								
							@ -44,55 +44,64 @@ BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1"
 | 
			
		||||
# You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`.
 | 
			
		||||
BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1"
 | 
			
		||||
 | 
			
		||||
# Catch and skip errors during extension building and print a warning message
 | 
			
		||||
# (note that this message only shows up under verbose build mode
 | 
			
		||||
# "pip install -v -e ." or "python setup.py build_ext -v")
 | 
			
		||||
CUDA_ERROR_MSG = (
 | 
			
		||||
    "{}\n\n"
 | 
			
		||||
    "Failed to build the SAM 2 CUDA extension due to the error above. "
 | 
			
		||||
    "You can still use SAM 2, but some post-processing functionality may be limited "
 | 
			
		||||
    "(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_extensions():
 | 
			
		||||
    if not BUILD_CUDA:
 | 
			
		||||
        return []
 | 
			
		||||
 | 
			
		||||
    srcs = ["sam2/csrc/connected_components.cu"]
 | 
			
		||||
    compile_args = {
 | 
			
		||||
        "cxx": [],
 | 
			
		||||
        "nvcc": [
 | 
			
		||||
            "-DCUDA_HAS_FP16=1",
 | 
			
		||||
            "-D__CUDA_NO_HALF_OPERATORS__",
 | 
			
		||||
            "-D__CUDA_NO_HALF_CONVERSIONS__",
 | 
			
		||||
            "-D__CUDA_NO_HALF2_OPERATORS__",
 | 
			
		||||
        ],
 | 
			
		||||
    }
 | 
			
		||||
    ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
 | 
			
		||||
    try:
 | 
			
		||||
        srcs = ["sam2/csrc/connected_components.cu"]
 | 
			
		||||
        compile_args = {
 | 
			
		||||
            "cxx": [],
 | 
			
		||||
            "nvcc": [
 | 
			
		||||
                "-DCUDA_HAS_FP16=1",
 | 
			
		||||
                "-D__CUDA_NO_HALF_OPERATORS__",
 | 
			
		||||
                "-D__CUDA_NO_HALF_CONVERSIONS__",
 | 
			
		||||
                "-D__CUDA_NO_HALF2_OPERATORS__",
 | 
			
		||||
            ],
 | 
			
		||||
        }
 | 
			
		||||
        ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        if BUILD_ALLOW_ERRORS:
 | 
			
		||||
            print(CUDA_ERROR_MSG.format(e))
 | 
			
		||||
            ext_modules = []
 | 
			
		||||
        else:
 | 
			
		||||
            raise e
 | 
			
		||||
 | 
			
		||||
    return ext_modules
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BuildExtensionIgnoreErrors(BuildExtension):
 | 
			
		||||
    # Catch and skip errors during extension building and print a warning message
 | 
			
		||||
    # (note that this message only shows up under verbose build mode
 | 
			
		||||
    # "pip install -v -e ." or "python setup.py build_ext -v")
 | 
			
		||||
    ERROR_MSG = (
 | 
			
		||||
        "{}\n\n"
 | 
			
		||||
        "Failed to build the SAM 2 CUDA extension due to the error above. "
 | 
			
		||||
        "You can still use SAM 2, but some post-processing functionality may be limited "
 | 
			
		||||
        "(see https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).\n"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    def finalize_options(self):
 | 
			
		||||
        try:
 | 
			
		||||
            super().finalize_options()
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            print(self.ERROR_MSG.format(e))
 | 
			
		||||
            print(CUDA_ERROR_MSG.format(e))
 | 
			
		||||
            self.extensions = []
 | 
			
		||||
 | 
			
		||||
    def build_extensions(self):
 | 
			
		||||
        try:
 | 
			
		||||
            super().build_extensions()
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            print(self.ERROR_MSG.format(e))
 | 
			
		||||
            print(CUDA_ERROR_MSG.format(e))
 | 
			
		||||
            self.extensions = []
 | 
			
		||||
 | 
			
		||||
    def get_ext_filename(self, ext_name):
 | 
			
		||||
        try:
 | 
			
		||||
            return super().get_ext_filename(ext_name)
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            print(self.ERROR_MSG.format(e))
 | 
			
		||||
            print(CUDA_ERROR_MSG.format(e))
 | 
			
		||||
            self.extensions = []
 | 
			
		||||
            return "_C.so"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user