Mirror of https://github.com/facebookresearch/sam2.git (synced 2025-11-04 11:32:12 +08:00)
	Add interface for box prompt in SAM 2 video predictor (#174)
This PR adds an example of providing a box prompt in SAM 2 as input to the `add_new_points_or_box` API (renamed from `add_new_points`, which is kept for backward compatibility). If `box` is provided, it is added as the first two points, with labels 2 and 3, ahead of any user-provided points (consistent with how SAM 2 is trained). The video predictor notebook `notebooks/video_predictor_example.ipynb` is updated to include segmenting from a box prompt as an example.
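As a usage illustration (not part of the diff below), here is a minimal sketch of segmenting from a box prompt with the new API. The config/checkpoint names, video path, frame index, object id, and box coordinates are hypothetical placeholders; `build_sam2_video_predictor` is assumed to be the repo's usual entry point.

```python
import torch
from sam2.build_sam import build_sam2_video_predictor  # assumed repo entry point

# Hypothetical config/checkpoint names -- adjust to your local setup.
predictor = build_sam2_video_predictor(
    "sam2_hiera_l.yaml", "./checkpoints/sam2_hiera_large.pt"
)

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    state = predictor.init_state("<your_video>")  # video path, as in the README snippet

    # A box prompt is (x_min, y_min, x_max, y_max) in pixel coordinates; internally
    # it is encoded as the first two points with labels 2 and 3 (see the diff below).
    frame_idx, object_ids, masks = predictor.add_new_points_or_box(
        inference_state=state,
        frame_idx=0,             # annotate the first frame (hypothetical choice)
        obj_id=1,                # caller-chosen object id
        box=[300, 0, 500, 400],  # hypothetical box coordinates
    )

    # propagate the prompt to get masklets throughout the video
    for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
        ...
```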
This commit is contained in:
parent 6ba4c65cb2
commit 6ecb5ff8d0
@@ -92,14 +92,14 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
     state = predictor.init_state(<your_video>)
 
     # add new prompts and instantly get the output on the same frame
-    frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>):
+    frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, <your_prompts>):
 
     # propagate the prompts to get masklets throughout the video
     for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
         ...
 ```
 
-Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) for details on how to add prompts, make refinements, and track multiple objects in videos.
+Please refer to the examples in [video_predictor_example.ipynb](./notebooks/video_predictor_example.ipynb) for details on how to add click or box prompts, make refinements, and track multiple objects in videos.
 
 ## Load from 🤗 Hugging Face
 
@@ -130,7 +130,7 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
     state = predictor.init_state(<your_video>)
 
     # add new prompts and instantly get the output on the same frame
-    frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>):
+    frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, <your_prompts>):
 
     # propagate the prompts to get masklets throughout the video
     for frame_idx, object_ids, masks in predictor.propagate_in_video(state):

File diff suppressed because one or more lines are too long
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import warnings
 from collections import OrderedDict
 
 import torch
@@ -163,29 +164,66 @@ class SAM2VideoPredictor(SAM2Base):
         return len(inference_state["obj_idx_to_id"])
 
     @torch.inference_mode()
-    def add_new_points(
+    def add_new_points_or_box(
         self,
         inference_state,
         frame_idx,
         obj_id,
-        points,
-        labels,
+        points=None,
+        labels=None,
         clear_old_points=True,
         normalize_coords=True,
+        box=None,
     ):
         """Add new points to a frame."""
         obj_idx = self._obj_id_to_idx(inference_state, obj_id)
         point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
         mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
 
-        if not isinstance(points, torch.Tensor):
+        if (points is not None) != (labels is not None):
+            raise ValueError("points and labels must be provided together")
+        if points is None and box is None:
+            raise ValueError("at least one of points or box must be provided as input")
+
+        if points is None:
+            points = torch.zeros(0, 2, dtype=torch.float32)
+        elif not isinstance(points, torch.Tensor):
             points = torch.tensor(points, dtype=torch.float32)
-        if not isinstance(labels, torch.Tensor):
+        if labels is None:
+            labels = torch.zeros(0, dtype=torch.int32)
+        elif not isinstance(labels, torch.Tensor):
             labels = torch.tensor(labels, dtype=torch.int32)
         if points.dim() == 2:
             points = points.unsqueeze(0)  # add batch dimension
         if labels.dim() == 1:
             labels = labels.unsqueeze(0)  # add batch dimension
+
+        # If `box` is provided, we add it as the first two points with labels 2 and 3
+        # along with the user-provided points (consistent with how SAM 2 is trained).
+        if box is not None:
+            if not clear_old_points:
+                raise ValueError(
+                    "cannot add box without clearing old points, since "
+                    "box prompt must be provided before any point prompt "
+                    "(please use clear_old_points=True instead)"
+                )
+            if inference_state["tracking_has_started"]:
+                warnings.warn(
+                    "You are adding a box after tracking starts. SAM 2 may not always be "
+                    "able to incorporate a box prompt for *refinement*. If you intend to "
+                    "use box prompt as an *initial* input before tracking, please call "
+                    "'reset_state' on the inference state to restart from scratch.",
+                    category=UserWarning,
+                    stacklevel=2,
+                )
+            if not isinstance(box, torch.Tensor):
+                box = torch.tensor(box, dtype=torch.float32, device=points.device)
+            box_coords = box.reshape(1, 2, 2)
+            box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
+            box_labels = box_labels.reshape(1, 2)
+            points = torch.cat([box_coords, points], dim=1)
+            labels = torch.cat([box_labels, labels], dim=1)
+
         if normalize_coords:
             video_H = inference_state["video_height"]
             video_W = inference_state["video_width"]
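A standalone sketch of the corner-point encoding above, showing what the `reshape` and `cat` calls produce for a hypothetical box plus one positive click (all coordinates are made up for illustration):

```python
import torch

box = torch.tensor([300.0, 0.0, 500.0, 400.0])  # hypothetical (x_min, y_min, x_max, y_max)
box_coords = box.reshape(1, 2, 2)  # [[[300., 0.], [500., 400.]]] -- the two box corners
box_labels = torch.tensor([2, 3], dtype=torch.int32).reshape(1, 2)  # corner labels 2 and 3

points = torch.tensor([[[400.0, 200.0]]])        # one positive click, shape (1, 1, 2)
labels = torch.tensor([[1]], dtype=torch.int32)  # label 1 = positive click, shape (1, 1)

points = torch.cat([box_coords, points], dim=1)  # shape (1, 3, 2): box corners come first
labels = torch.cat([box_labels, labels], dim=1)  # shape (1, 3): [[2, 3, 1]]
```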
@@ -268,6 +306,10 @@ class SAM2VideoPredictor(SAM2Base):
         )
         return frame_idx, obj_ids, video_res_masks
 
+    def add_new_points(self, *args, **kwargs):
+        """Deprecated method. Please use `add_new_points_or_box` instead."""
+        return self.add_new_points_or_box(*args, **kwargs)
+
     @torch.inference_mode()
     def add_new_mask(
         self,
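Since the old name now simply delegates, existing point-based callers keep working unchanged. For example, reusing the `predictor` and `state` from the earlier sketch (coordinates and object id hypothetical):

```python
# Equivalent calls: the first goes through the backward-compatibility shim.
frame_idx, object_ids, masks = predictor.add_new_points(
    inference_state=state, frame_idx=0, obj_id=2,
    points=[[250.0, 220.0]], labels=[1],
)
frame_idx, object_ids, masks = predictor.add_new_points_or_box(
    inference_state=state, frame_idx=0, obj_id=2,
    points=[[250.0, 220.0]], labels=[1],
)
```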
@@ -548,7 +590,7 @@ class SAM2VideoPredictor(SAM2Base):
             storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
             # Find all the frames that contain temporary outputs for any objects
             # (these should be the frames that have just received clicks for mask inputs
-            # via `add_new_points` or `add_new_mask`)
+            # via `add_new_points_or_box` or `add_new_mask`)
             temp_frame_inds = set()
             for obj_temp_output_dict in temp_output_dict_per_obj.values():
                 temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())