Mirror of https://github.com/facebookresearch/sam2.git (synced 2025-11-04 19:42:12 +08:00)
switch to a new implementation of the class `SAM2VideoPredictor` for per-object inference in sam2/sam2_video_predictor.py
In this PR, we switch to a new implementation of the class `SAM2VideoPredictor` in sam2/sam2_video_predictor.py, which allows for independent per-object inference. Specifically, the new `SAM2VideoPredictor`:
* handles the inference of each object separately, as if we were opening a separate session for each object;
* relaxes the assumptions on prompting:
  * previously, if a frame received clicks for only a subset of objects, the remaining (non-prompted) objects were assumed to be non-existent in this frame;
  * now, if a frame receives clicks for only a subset of objects, we make no assumptions about the remaining (non-prompted) objects;
* allows adding new objects after tracking starts.

(The previous implementation is backed up as `SAM2VideoPredictor` in sam2/sam2_video_predictor_legacy.py.)

Also, fix a small typo `APP_URL` => `API_URL` in the doc.

Test plan: tested with the predictor notebook `notebooks/video_predictor_example.ipynb` and the VOS script `tools/vos_inference.py`. Also tested with the demo.
This commit is contained in: parent 3297dd0eb0, commit c61e2475e6
@@ -105,7 +105,7 @@ cd demo/backend/server/
```bash
PYTORCH_ENABLE_MPS_FALLBACK=1 \
APP_ROOT="$(pwd)/../../../" \
APP_URL=http://localhost:7263 \
API_URL=http://localhost:7263 \
MODEL_SIZE=base_plus \
DATA_PATH="$(pwd)/../../data" \
DEFAULT_VIDEO_PATH=gallery/05_default_juggle.mp4 \
@@ -27,8 +27,6 @@ class SAM2VideoPredictor(SAM2Base):
        # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
        # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True)
        clear_non_cond_mem_around_input=False,
        # whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True).
        clear_non_cond_mem_for_multi_obj=False,
        # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
        # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames
        add_all_frames_to_correct_as_cond=False,
@@ -38,7 +36,6 @@ class SAM2VideoPredictor(SAM2Base):
        self.fill_hole_area = fill_hole_area
        self.non_overlap_masks = non_overlap_masks
        self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
        self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
        self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond

    @torch.inference_mode()
@@ -88,11 +85,6 @@ class SAM2VideoPredictor(SAM2Base):
        inference_state["obj_id_to_idx"] = OrderedDict()
        inference_state["obj_idx_to_id"] = OrderedDict()
        inference_state["obj_ids"] = []
        # A storage to hold the model's tracking results and states on each frame
        inference_state["output_dict"] = {
            "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
        }
        # Slice (view) of each object tracking results, sharing the same memory with "output_dict"
        inference_state["output_dict_per_obj"] = {}
        # A temporary storage to hold new outputs when user interact with a frame
@@ -100,13 +92,8 @@ class SAM2VideoPredictor(SAM2Base):
        inference_state["temp_output_dict_per_obj"] = {}
        # Frames that already holds consolidated outputs from click or mask inputs
        # (we directly use their consolidated outputs during tracking)
        inference_state["consolidated_frame_inds"] = {
            "cond_frame_outputs": set(),  # set containing frame indices
            "non_cond_frame_outputs": set(),  # set containing frame indices
        }
        # metadata for each tracking frame (e.g. which direction it's tracked)
        inference_state["tracking_has_started"] = False
        inference_state["frames_already_tracked"] = {}
        inference_state["frames_tracked_per_obj"] = {}
        # Warm up the visual backbone and cache the image feature on frame 0
        self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
        return inference_state
@@ -134,9 +121,8 @@ class SAM2VideoPredictor(SAM2Base):
        if obj_idx is not None:
            return obj_idx

        # This is a new object id not sent to the server before. We only allow adding
        # new objects *before* the tracking starts.
        allow_new_object = not inference_state["tracking_has_started"]
        # We always allow adding new objects (including after tracking starts).
        allow_new_object = True
        if allow_new_object:
            # get the next object slot
            obj_idx = len(inference_state["obj_id_to_idx"])
@@ -154,6 +140,7 @@ class SAM2VideoPredictor(SAM2Base):
                "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
                "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            }
            inference_state["frames_tracked_per_obj"][obj_idx] = {}
            return obj_idx
        else:
            raise RuntimeError(
@@ -214,15 +201,6 @@ class SAM2VideoPredictor(SAM2Base):
                    "box prompt must be provided before any point prompt "
                    "(please use clear_old_points=True instead)"
                )
            if inference_state["tracking_has_started"]:
                warnings.warn(
                    "You are adding a box after tracking starts. SAM 2 may not always be "
                    "able to incorporate a box prompt for *refinement*. If you intend to "
                    "use box prompt as an *initial* input before tracking, please call "
                    "'reset_state' on the inference state to restart from scratch.",
                    category=UserWarning,
                    stacklevel=2,
                )
            if not isinstance(box, torch.Tensor):
                box = torch.tensor(box, dtype=torch.float32, device=points.device)
            box_coords = box.reshape(1, 2, 2)
@@ -252,12 +230,13 @@ class SAM2VideoPredictor(SAM2Base):
        # frame, meaning that the inputs points are to generate segments on this frame without
        # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
        # the input points will be used to correct the already tracked masks.
        is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
        obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx]
        is_init_cond_frame = frame_idx not in obj_frames_tracked
        # whether to track in reverse time order
        if is_init_cond_frame:
            reverse = False
        else:
            reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
            reverse = obj_frames_tracked[frame_idx]["reverse"]
        obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
        obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
        # Add a frame to conditioning output if it's an initial conditioning frame or
@@ -306,7 +285,6 @@ class SAM2VideoPredictor(SAM2Base):
            inference_state,
            frame_idx,
            is_cond=is_cond,
            run_mem_encoder=False,
            consolidate_at_video_res=True,
        )
        _, video_res_masks = self._get_orig_video_res_output(
@@ -357,12 +335,13 @@ class SAM2VideoPredictor(SAM2Base):
        # frame, meaning that the inputs points are to generate segments on this frame without
        # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
        # the input points will be used to correct the already tracked masks.
        is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
        obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx]
        is_init_cond_frame = frame_idx not in obj_frames_tracked
        # whether to track in reverse time order
        if is_init_cond_frame:
            reverse = False
        else:
            reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
            reverse = obj_frames_tracked[frame_idx]["reverse"]
        obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
        obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
        # Add a frame to conditioning output if it's an initial conditioning frame or
@@ -394,7 +373,6 @@ class SAM2VideoPredictor(SAM2Base):
            inference_state,
            frame_idx,
            is_cond=is_cond,
            run_mem_encoder=False,
            consolidate_at_video_res=True,
        )
        _, video_res_masks = self._get_orig_video_res_output(
@@ -429,7 +407,6 @@ class SAM2VideoPredictor(SAM2Base):
        inference_state,
        frame_idx,
        is_cond,
        run_mem_encoder,
        consolidate_at_video_res=False,
    ):
        """
@@ -446,7 +423,6 @@ class SAM2VideoPredictor(SAM2Base):
        # Optionally, we allow consolidating the temporary outputs at the original
        # video resolution (to provide a better editing experience for mask prompts).
        if consolidate_at_video_res:
            assert not run_mem_encoder, "memory encoder cannot run at video resolution"
            consolidated_H = inference_state["video_height"]
            consolidated_W = inference_state["video_width"]
            consolidated_mask_key = "pred_masks_video_res"
@@ -459,30 +435,13 @@ class SAM2VideoPredictor(SAM2Base):
        # constraints to object scores. Its "pred_masks" are prefilled with a large
        # negative value (NO_OBJ_SCORE) to represent missing objects.
        consolidated_out = {
            "maskmem_features": None,
            "maskmem_pos_enc": None,
            consolidated_mask_key: torch.full(
                size=(batch_size, 1, consolidated_H, consolidated_W),
                fill_value=NO_OBJ_SCORE,
                dtype=torch.float32,
                device=inference_state["storage_device"],
            ),
            "obj_ptr": torch.full(
                size=(batch_size, self.hidden_dim),
                fill_value=NO_OBJ_SCORE,
                dtype=torch.float32,
                device=inference_state["device"],
            ),
            "object_score_logits": torch.full(
                size=(batch_size, 1),
                # default to 10.0 for object_score_logits, i.e. assuming the object is
                # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
                fill_value=10.0,
                dtype=torch.float32,
                device=inference_state["device"],
            ),
        }
        empty_mask_ptr = None
        for obj_idx in range(batch_size):
            obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
            obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
@@ -499,16 +458,6 @@ class SAM2VideoPredictor(SAM2Base):
            # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
            # placeholder above) and set its object pointer to be a dummy pointer.
            if out is None:
                # Fill in dummy object pointers for those objects without any inputs or
                # tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
                # i.e. when we need to build the memory for tracking).
                if run_mem_encoder:
                    if empty_mask_ptr is None:
                        empty_mask_ptr = self._get_empty_mask_ptr(
                            inference_state, frame_idx
                        )
                    # fill object pointer with a dummy pointer (based on an empty mask)
                    consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
                continue
            # Add the temporary object output mask to consolidated output mask
            obj_mask = out["pred_masks"]
@@ -524,141 +473,74 @@ class SAM2VideoPredictor(SAM2Base):
                    align_corners=False,
                )
                consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
            consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
            consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[
                "object_score_logits"
            ]

        # Optionally, apply non-overlapping constraints on the consolidated scores
        # and rerun the memory encoder
        if run_mem_encoder:
            device = inference_state["device"]
            high_res_masks = torch.nn.functional.interpolate(
                consolidated_out["pred_masks"].to(device, non_blocking=True),
                size=(self.image_size, self.image_size),
                mode="bilinear",
                align_corners=False,
            )
            if self.non_overlap_masks_for_mem_enc:
                high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
            maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
                inference_state=inference_state,
                frame_idx=frame_idx,
                batch_size=batch_size,
                high_res_masks=high_res_masks,
                object_score_logits=consolidated_out["object_score_logits"],
                is_mask_from_pts=True,  # these frames are what the user interacted with
            )
            consolidated_out["maskmem_features"] = maskmem_features
            consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc

        return consolidated_out

    def _get_empty_mask_ptr(self, inference_state, frame_idx):
        """Get a dummy object pointer based on an empty mask on the current frame."""
        # A dummy (empty) mask with a single object
        batch_size = 1
        mask_inputs = torch.zeros(
            (batch_size, 1, self.image_size, self.image_size),
            dtype=torch.float32,
            device=inference_state["device"],
        )

        # Retrieve correct image features
        (
            _,
            _,
            current_vision_feats,
            current_vision_pos_embeds,
            feat_sizes,
        ) = self._get_image_feature(inference_state, frame_idx, batch_size)

        # Feed the empty mask and image feature above to get a dummy object pointer
        current_out = self.track_step(
            frame_idx=frame_idx,
            is_init_cond_frame=True,
            current_vision_feats=current_vision_feats,
            current_vision_pos_embeds=current_vision_pos_embeds,
            feat_sizes=feat_sizes,
            point_inputs=None,
            mask_inputs=mask_inputs,
            output_dict={},
            num_frames=inference_state["num_frames"],
            track_in_reverse=False,
            run_mem_encoder=False,
            prev_sam_mask_logits=None,
        )
        return current_out["obj_ptr"]

    @torch.inference_mode()
    def propagate_in_video_preflight(self, inference_state):
        """Prepare inference_state and consolidate temporary outputs before tracking."""
        # Tracking has started and we don't allow adding new objects until session is reset.
        inference_state["tracking_has_started"] = True
        # Check and make sure that every object has received input points or masks.
        batch_size = self._get_obj_num(inference_state)
        if batch_size == 0:
            raise RuntimeError(
                "No input points or masks are provided for any object; please add inputs first."
            )

        # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
        # add them into "output_dict".
        temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
        output_dict = inference_state["output_dict"]
        # "consolidated_frame_inds" contains indices of those frames where consolidated
        # temporary outputs have been added (either in this call or any previous calls
        # to `propagate_in_video_preflight`).
        consolidated_frame_inds = inference_state["consolidated_frame_inds"]
        for is_cond in [False, True]:
            # Separately consolidate conditioning and non-conditioning temp outputs
            storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
            # Find all the frames that contain temporary outputs for any objects
            # (these should be the frames that have just received clicks for mask inputs
            # via `add_new_points_or_box` or `add_new_mask`)
            temp_frame_inds = set()
            for obj_temp_output_dict in temp_output_dict_per_obj.values():
                temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
            consolidated_frame_inds[storage_key].update(temp_frame_inds)
            # consolidate the temporary output across all objects on this frame
            for frame_idx in temp_frame_inds:
                consolidated_out = self._consolidate_temp_output_across_obj(
                    inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
        for obj_idx in range(batch_size):
            obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
            obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
            for is_cond in [False, True]:
                # Separately consolidate conditioning and non-conditioning temp outputs
                storage_key = (
                    "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
                )
                # merge them into "output_dict" and also create per-object slices
                output_dict[storage_key][frame_idx] = consolidated_out
                self._add_output_per_object(
                    inference_state, frame_idx, consolidated_out, storage_key
                )
                clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
                    self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
                )
                if clear_non_cond_mem:
                    # clear non-conditioning memory of the surrounding frames
                    self._clear_non_cond_mem_around_input(inference_state, frame_idx)
                # Find all the frames that contain temporary outputs for any objects
                # (these should be the frames that have just received clicks for mask inputs
                # via `add_new_points_or_box` or `add_new_mask`)
                for frame_idx, out in obj_temp_output_dict[storage_key].items():
                    # Run memory encoder on the temporary outputs (if the memory feature is missing)
                    if out["maskmem_features"] is None:
                        high_res_masks = torch.nn.functional.interpolate(
                            out["pred_masks"].to(inference_state["device"]),
                            size=(self.image_size, self.image_size),
                            mode="bilinear",
                            align_corners=False,
                        )
                        maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
                            inference_state=inference_state,
                            frame_idx=frame_idx,
                            batch_size=1,  # run on the slice of a single object
                            high_res_masks=high_res_masks,
                            object_score_logits=out["object_score_logits"],
                            # these frames are what the user interacted with
                            is_mask_from_pts=True,
                        )
                        out["maskmem_features"] = maskmem_features
                        out["maskmem_pos_enc"] = maskmem_pos_enc

            # clear temporary outputs in `temp_output_dict_per_obj`
            for obj_temp_output_dict in temp_output_dict_per_obj.values():
                    obj_output_dict[storage_key][frame_idx] = out
                    if self.clear_non_cond_mem_around_input:
                        # clear non-conditioning memory of the surrounding frames
                        self._clear_obj_non_cond_mem_around_input(
                            inference_state, frame_idx, obj_idx
                        )

                # clear temporary outputs in `temp_output_dict_per_obj`
                obj_temp_output_dict[storage_key].clear()

        # edge case: if an output is added to "cond_frame_outputs", we remove any prior
        # output on the same frame in "non_cond_frame_outputs"
        for frame_idx in output_dict["cond_frame_outputs"]:
            output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
        for obj_output_dict in inference_state["output_dict_per_obj"].values():
            # check and make sure that every object has received input points or masks
            obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
            if len(obj_output_dict["cond_frame_outputs"]) == 0:
                obj_id = self._obj_idx_to_id(inference_state, obj_idx)
                raise RuntimeError(
                    f"No input points or masks are provided for object id {obj_id}; please add inputs first."
                )
            # edge case: if an output is added to "cond_frame_outputs", we remove any prior
            # output on the same frame in "non_cond_frame_outputs"
            for frame_idx in obj_output_dict["cond_frame_outputs"]:
                obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
        for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
            assert frame_idx in output_dict["cond_frame_outputs"]
            consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)

        # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
        # with either points or mask inputs (which should be true under a correct workflow).
        all_consolidated_frame_inds = (
            consolidated_frame_inds["cond_frame_outputs"]
            | consolidated_frame_inds["non_cond_frame_outputs"]
        )
        input_frames_inds = set()
        for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
            input_frames_inds.update(point_inputs_per_frame.keys())
        for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
            input_frames_inds.update(mask_inputs_per_frame.keys())
        assert all_consolidated_frame_inds == input_frames_inds

    @torch.inference_mode()
    def propagate_in_video(
@@ -671,21 +553,18 @@ class SAM2VideoPredictor(SAM2Base):
        """Propagate the input points across frames to track in the entire video."""
        self.propagate_in_video_preflight(inference_state)

        output_dict = inference_state["output_dict"]
        consolidated_frame_inds = inference_state["consolidated_frame_inds"]
        obj_ids = inference_state["obj_ids"]
        num_frames = inference_state["num_frames"]
        batch_size = self._get_obj_num(inference_state)
        if len(output_dict["cond_frame_outputs"]) == 0:
            raise RuntimeError("No points are provided; please add points first")
        clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
            self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
        )

        # set start index, end index, and processing order
        if start_frame_idx is None:
            # default: start from the earliest frame with input points
            start_frame_idx = min(output_dict["cond_frame_outputs"])
            start_frame_idx = min(
                t
                for obj_output_dict in inference_state["output_dict_per_obj"].values()
                for t in obj_output_dict["cond_frame_outputs"]
            )
        if max_frame_num_to_track is None:
            # default: track all the frames in the video
            max_frame_num_to_track = num_frames
@@ -702,78 +581,53 @@ class SAM2VideoPredictor(SAM2Base):
            processing_order = range(start_frame_idx, end_frame_idx + 1)

        for frame_idx in tqdm(processing_order, desc="propagate in video"):
            # We skip those frames already in consolidated outputs (these are frames
            # that received input clicks or mask). Note that we cannot directly run
            # batched forward on them via `_run_single_frame_inference` because the
            # number of clicks on each object might be different.
            if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
                storage_key = "cond_frame_outputs"
                current_out = output_dict[storage_key][frame_idx]
                pred_masks = current_out["pred_masks"]
                if clear_non_cond_mem:
                    # clear non-conditioning memory of the surrounding frames
                    self._clear_non_cond_mem_around_input(inference_state, frame_idx)
            elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
                storage_key = "non_cond_frame_outputs"
                current_out = output_dict[storage_key][frame_idx]
                pred_masks = current_out["pred_masks"]
            else:
                storage_key = "non_cond_frame_outputs"
                current_out, pred_masks = self._run_single_frame_inference(
                    inference_state=inference_state,
                    output_dict=output_dict,
                    frame_idx=frame_idx,
                    batch_size=batch_size,
                    is_init_cond_frame=False,
                    point_inputs=None,
                    mask_inputs=None,
                    reverse=reverse,
                    run_mem_encoder=True,
                )
                output_dict[storage_key][frame_idx] = current_out
            # Create slices of per-object outputs for subsequent interaction with each
            # individual object after tracking.
            self._add_output_per_object(
                inference_state, frame_idx, current_out, storage_key
            )
            inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}
            pred_masks_per_obj = [None] * batch_size
            for obj_idx in range(batch_size):
                obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
                # We skip those frames already in consolidated outputs (these are frames
                # that received input clicks or mask). Note that we cannot directly run
                # batched forward on them via `_run_single_frame_inference` because the
                # number of clicks on each object might be different.
                if frame_idx in obj_output_dict["cond_frame_outputs"]:
                    storage_key = "cond_frame_outputs"
                    current_out = obj_output_dict[storage_key][frame_idx]
                    pred_masks = current_out["pred_masks"]
                    if self.clear_non_cond_mem_around_input:
                        # clear non-conditioning memory of the surrounding frames
                        self._clear_obj_non_cond_mem_around_input(
                            inference_state, frame_idx, obj_idx
                        )
                else:
                    storage_key = "non_cond_frame_outputs"
                    current_out, pred_masks = self._run_single_frame_inference(
                        inference_state=inference_state,
                        output_dict=obj_output_dict,
                        frame_idx=frame_idx,
                        batch_size=1,  # run on the slice of a single object
                        is_init_cond_frame=False,
                        point_inputs=None,
                        mask_inputs=None,
                        reverse=reverse,
                        run_mem_encoder=True,
                    )
                    obj_output_dict[storage_key][frame_idx] = current_out

                inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = {
                    "reverse": reverse
                }
                pred_masks_per_obj[obj_idx] = pred_masks

            # Resize the output mask to the original video resolution (we directly use
            # the mask scores on GPU for output to avoid any CPU conversion in between)
            if len(pred_masks_per_obj) > 1:
                all_pred_masks = torch.cat(pred_masks_per_obj, dim=0)
            else:
                all_pred_masks = pred_masks_per_obj[0]
            _, video_res_masks = self._get_orig_video_res_output(
                inference_state, pred_masks
                inference_state, all_pred_masks
            )
            yield frame_idx, obj_ids, video_res_masks

    def _add_output_per_object(
        self, inference_state, frame_idx, current_out, storage_key
    ):
        """
        Split a multi-object output into per-object output slices and add them into
        `output_dict_per_obj`. The resulting slices share the same tensor storage.
        """
        maskmem_features = current_out["maskmem_features"]
        assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)

        maskmem_pos_enc = current_out["maskmem_pos_enc"]
        assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)

        output_dict_per_obj = inference_state["output_dict_per_obj"]
        for obj_idx, obj_output_dict in output_dict_per_obj.items():
            obj_slice = slice(obj_idx, obj_idx + 1)
            obj_out = {
                "maskmem_features": None,
                "maskmem_pos_enc": None,
                "pred_masks": current_out["pred_masks"][obj_slice],
                "obj_ptr": current_out["obj_ptr"][obj_slice],
                "object_score_logits": current_out["object_score_logits"][obj_slice],
            }
            if maskmem_features is not None:
                obj_out["maskmem_features"] = maskmem_features[obj_slice]
            if maskmem_pos_enc is not None:
                obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
            obj_output_dict[storage_key][frame_idx] = obj_out

    @torch.inference_mode()
    def clear_all_prompts_in_frame(
        self, inference_state, frame_idx, obj_id, need_output=True
@@ -789,41 +643,14 @@ class SAM2VideoPredictor(SAM2Base):
        temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None)
        temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None)

        # Check and see if there are still any inputs left on this frame
        batch_size = self._get_obj_num(inference_state)
        frame_has_input = False
        for obj_idx2 in range(batch_size):
            if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]:
                frame_has_input = True
                break
            if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]:
                frame_has_input = True
                break

        # If this frame has no remaining inputs for any objects, we further clear its
        # conditioning frame status
        if not frame_has_input:
            output_dict = inference_state["output_dict"]
            consolidated_frame_inds = inference_state["consolidated_frame_inds"]
            consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx)
            consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
            # Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
            out = output_dict["cond_frame_outputs"].pop(frame_idx, None)
            if out is not None:
                # The frame is not a conditioning frame anymore since it's not receiving inputs,
                # so we "downgrade" its output (if exists) to a non-conditioning frame output.
                output_dict["non_cond_frame_outputs"][frame_idx] = out
                inference_state["frames_already_tracked"].pop(frame_idx, None)
            # Similarly, do it for the sliced output on each object.
            for obj_idx2 in range(batch_size):
                obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2]
                obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
                if obj_out is not None:
                    obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out

            # If all the conditioning frames have been removed, we also clear the tracking outputs
            if len(output_dict["cond_frame_outputs"]) == 0:
                self._reset_tracking_results(inference_state)
        # Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
        obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
        out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
        if out is not None:
            # The frame is not a conditioning frame anymore since it's not receiving inputs,
            # so we "downgrade" its output (if exists) to a non-conditioning frame output.
            obj_output_dict["non_cond_frame_outputs"][frame_idx] = out
            inference_state["frames_tracked_per_obj"][obj_idx].pop(frame_idx, None)

        if not need_output:
            return
@@ -837,7 +664,6 @@ class SAM2VideoPredictor(SAM2Base):
            inference_state,
            frame_idx,
            is_cond=is_cond,
            run_mem_encoder=False,
            consolidate_at_video_res=True,
        )
        _, video_res_masks = self._get_orig_video_res_output(
@@ -857,6 +683,7 @@ class SAM2VideoPredictor(SAM2Base):
        inference_state["mask_inputs_per_obj"].clear()
        inference_state["output_dict_per_obj"].clear()
        inference_state["temp_output_dict_per_obj"].clear()
        inference_state["frames_tracked_per_obj"].clear()

    def _reset_tracking_results(self, inference_state):
        """Reset all tracking inputs and results across the videos."""
@@ -870,12 +697,8 @@ class SAM2VideoPredictor(SAM2Base):
        for v in inference_state["temp_output_dict_per_obj"].values():
            v["cond_frame_outputs"].clear()
            v["non_cond_frame_outputs"].clear()
        inference_state["output_dict"]["cond_frame_outputs"].clear()
        inference_state["output_dict"]["non_cond_frame_outputs"].clear()
        inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
        inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
        inference_state["tracking_has_started"] = False
        inference_state["frames_already_tracked"].clear()
        for v in inference_state["frames_tracked_per_obj"].values():
            v.clear()

    def _get_image_feature(self, inference_state, frame_idx, batch_size):
        """Compute the image features on a given frame."""
@@ -1093,8 +916,6 @@ class SAM2VideoPredictor(SAM2Base):
        inference_state["obj_ids"] = new_obj_ids

        # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
        # (note that "consolidated_frame_inds" doesn't need to be updated in this step as
        # it's already handled in Step 0)
        def _map_keys(container):
            new_kvs = []
            for k in old_obj_inds:
@@ -1107,30 +928,9 @@ class SAM2VideoPredictor(SAM2Base):
        _map_keys(inference_state["mask_inputs_per_obj"])
        _map_keys(inference_state["output_dict_per_obj"])
        _map_keys(inference_state["temp_output_dict_per_obj"])
        _map_keys(inference_state["frames_tracked_per_obj"])

        # Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices.
        def _slice_state(output_dict, storage_key):
            for frame_idx, out in output_dict[storage_key].items():
                out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds]
                out["maskmem_pos_enc"] = [
                    x[remain_old_obj_inds] for x in out["maskmem_pos_enc"]
                ]
                # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
                out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out)
                out["pred_masks"] = out["pred_masks"][remain_old_obj_inds]
                out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds]
                out["object_score_logits"] = out["object_score_logits"][
                    remain_old_obj_inds
                ]
                # also update the per-object slices
                self._add_output_per_object(
                    inference_state, frame_idx, out, storage_key
                )

        _slice_state(inference_state["output_dict"], "cond_frame_outputs")
        _slice_state(inference_state["output_dict"], "non_cond_frame_outputs")

        # Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which
        # Step 3: Further collect the outputs on those frames in `obj_input_frames_inds`, which
        # could show an updated mask for objects previously occluded by the object being removed
        if need_output:
            temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
@@ -1143,7 +943,6 @@ class SAM2VideoPredictor(SAM2Base):
                    inference_state,
                    frame_idx,
                    is_cond=is_cond,
                    run_mem_encoder=False,
                    consolidate_at_video_res=True,
                )
                _, video_res_masks = self._get_orig_video_res_output(
@@ -1165,12 +964,12 @@ class SAM2VideoPredictor(SAM2Base):
        r = self.memory_temporal_stride_for_eval
        frame_idx_begin = frame_idx - r * self.num_maskmem
        frame_idx_end = frame_idx + r * self.num_maskmem
        output_dict = inference_state["output_dict"]
        non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
        for t in range(frame_idx_begin, frame_idx_end + 1):
            non_cond_frame_outputs.pop(t, None)
            for obj_output_dict in inference_state["output_dict_per_obj"].values():
                obj_output_dict["non_cond_frame_outputs"].pop(t, None)
        batch_size = self._get_obj_num(inference_state)
        for obj_idx in range(batch_size):
            obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
            non_cond_frame_outputs = obj_output_dict["non_cond_frame_outputs"]
            for t in range(frame_idx_begin, frame_idx_end + 1):
                non_cond_frame_outputs.pop(t, None)


class SAM2VideoPredictorVOS(SAM2VideoPredictor):
sam2/sam2_video_predictor_legacy.py: 1172 lines, new file (file diff suppressed because it is too large)