mirror of
				https://github.com/facebookresearch/sam2.git
				synced 2025-11-04 19:42:12 +08:00 
			
		
		
		
	Merge pull request #205 from facebookresearch/haitham/fix_hf_image_predictor
Fix HF image predictor
This commit is contained in:
		
						commit
						0db838b117
					
				@ -53,6 +53,7 @@ class SAM2AutomaticMaskGenerator:
 | 
				
			|||||||
        output_mode: str = "binary_mask",
 | 
					        output_mode: str = "binary_mask",
 | 
				
			||||||
        use_m2m: bool = False,
 | 
					        use_m2m: bool = False,
 | 
				
			||||||
        multimask_output: bool = True,
 | 
					        multimask_output: bool = True,
 | 
				
			||||||
 | 
					        **kwargs,
 | 
				
			||||||
    ) -> None:
 | 
					    ) -> None:
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Using a SAM 2 model, generates masks for the entire image.
 | 
					        Using a SAM 2 model, generates masks for the entire image.
 | 
				
			||||||
@ -148,6 +149,23 @@ class SAM2AutomaticMaskGenerator:
 | 
				
			|||||||
        self.use_m2m = use_m2m
 | 
					        self.use_m2m = use_m2m
 | 
				
			||||||
        self.multimask_output = multimask_output
 | 
					        self.multimask_output = multimask_output
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator":
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Load a pretrained model from the Hugging Face hub.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Arguments:
 | 
				
			||||||
 | 
					          model_id (str): The Hugging Face repository ID.
 | 
				
			||||||
 | 
					          **kwargs: Additional arguments to pass to the model constructor.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					          (SAM2AutomaticMaskGenerator): The loaded model.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        from sam2.build_sam import build_sam2_hf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        sam_model = build_sam2_hf(model_id, **kwargs)
 | 
				
			||||||
 | 
					        return cls(sam_model, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @torch.no_grad()
 | 
					    @torch.no_grad()
 | 
				
			||||||
    def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
 | 
					    def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
 | 
				
			|||||||
@ -19,6 +19,7 @@ def build_sam2(
 | 
				
			|||||||
    mode="eval",
 | 
					    mode="eval",
 | 
				
			||||||
    hydra_overrides_extra=[],
 | 
					    hydra_overrides_extra=[],
 | 
				
			||||||
    apply_postprocessing=True,
 | 
					    apply_postprocessing=True,
 | 
				
			||||||
 | 
					    **kwargs,
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if apply_postprocessing:
 | 
					    if apply_postprocessing:
 | 
				
			||||||
@ -47,6 +48,7 @@ def build_sam2_video_predictor(
 | 
				
			|||||||
    mode="eval",
 | 
					    mode="eval",
 | 
				
			||||||
    hydra_overrides_extra=[],
 | 
					    hydra_overrides_extra=[],
 | 
				
			||||||
    apply_postprocessing=True,
 | 
					    apply_postprocessing=True,
 | 
				
			||||||
 | 
					    **kwargs,
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    hydra_overrides = [
 | 
					    hydra_overrides = [
 | 
				
			||||||
        "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
 | 
					        "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
 | 
				
			||||||
 | 
				
			|||||||
@ -24,6 +24,7 @@ class SAM2ImagePredictor:
 | 
				
			|||||||
        mask_threshold=0.0,
 | 
					        mask_threshold=0.0,
 | 
				
			||||||
        max_hole_area=0.0,
 | 
					        max_hole_area=0.0,
 | 
				
			||||||
        max_sprinkle_area=0.0,
 | 
					        max_sprinkle_area=0.0,
 | 
				
			||||||
 | 
					        **kwargs,
 | 
				
			||||||
    ) -> None:
 | 
					    ) -> None:
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Uses SAM-2 to calculate the image embedding for an image, and then
 | 
					        Uses SAM-2 to calculate the image embedding for an image, and then
 | 
				
			||||||
@ -33,8 +34,10 @@ class SAM2ImagePredictor:
 | 
				
			|||||||
          sam_model (Sam-2): The model to use for mask prediction.
 | 
					          sam_model (Sam-2): The model to use for mask prediction.
 | 
				
			||||||
          mask_threshold (float): The threshold to use when converting mask logits
 | 
					          mask_threshold (float): The threshold to use when converting mask logits
 | 
				
			||||||
            to binary masks. Masks are thresholded at 0 by default.
 | 
					            to binary masks. Masks are thresholded at 0 by default.
 | 
				
			||||||
          fill_hole_area (int): If fill_hole_area > 0, we fill small holes in up to
 | 
					          max_hole_area (int): If max_hole_area > 0, we fill small holes in up to
 | 
				
			||||||
            the maximum area of fill_hole_area in low_res_masks.
 | 
					            the maximum area of max_hole_area in low_res_masks.
 | 
				
			||||||
 | 
					          max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to
 | 
				
			||||||
 | 
					            the maximum area of max_sprinkle_area in low_res_masks.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        super().__init__()
 | 
					        super().__init__()
 | 
				
			||||||
        self.model = sam_model
 | 
					        self.model = sam_model
 | 
				
			||||||
@ -77,7 +80,7 @@ class SAM2ImagePredictor:
 | 
				
			|||||||
        from sam2.build_sam import build_sam2_hf
 | 
					        from sam2.build_sam import build_sam2_hf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        sam_model = build_sam2_hf(model_id, **kwargs)
 | 
					        sam_model = build_sam2_hf(model_id, **kwargs)
 | 
				
			||||||
        return cls(sam_model)
 | 
					        return cls(sam_model, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @torch.no_grad()
 | 
					    @torch.no_grad()
 | 
				
			||||||
    def set_image(
 | 
					    def set_image(
 | 
				
			||||||
 | 
				
			|||||||
@ -121,7 +121,7 @@ class SAM2VideoPredictor(SAM2Base):
 | 
				
			|||||||
        from sam2.build_sam import build_sam2_video_predictor_hf
 | 
					        from sam2.build_sam import build_sam2_video_predictor_hf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
 | 
					        sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
 | 
				
			||||||
        return cls(sam_model)
 | 
					        return sam_model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _obj_id_to_idx(self, inference_state, obj_id):
 | 
					    def _obj_id_to_idx(self, inference_state, obj_id):
 | 
				
			||||||
        """Map client-side object id to model-side object index."""
 | 
					        """Map client-side object id to model-side object index."""
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user