mirror of
https://github.com/facebookresearch/sam2.git
synced 2025-09-18 12:42:48 +08:00
switch to a new implementation of the class SAM2VideoPredictor
for per-object inference sam2/sam2_video_predictor.py
In this PR, we switch to a new implementation of the class `SAM2VideoPredictor` in sam2/sam2_video_predictor.py, which allows for independent per-object inference. Specifically, the new `SAM2VideoPredictor`: * it handles the inference of each object separately, as if we are opening a separate session for each object * it relaxes the assumption on prompting * previously if a frame receives clicks only for a subset of objects, the rest of (non-prompted) objects are assumed to be non-existing in this frame * now if a frame receives clicks only for a subset of objects, we don't make any assumptions for the remaining (non-prompted) objects * it allows adding new objects after tracking starts * (The previous implementation is backed up to `SAM2VideoPredictor` in sam2/sam2_video_predictor_legacy.py) Also, fix a small typo `APP_URL` => `API_URL` in the doc. Test plan: tested with the predictor notebook `notebooks/video_predictor_example.ipynb` and VOS script `tools/vos_inference.py`. Also tested with the demo.
This commit is contained in:
parent
3297dd0eb0
commit
c61e2475e6
@ -105,7 +105,7 @@ cd demo/backend/server/
|
|||||||
```bash
|
```bash
|
||||||
PYTORCH_ENABLE_MPS_FALLBACK=1 \
|
PYTORCH_ENABLE_MPS_FALLBACK=1 \
|
||||||
APP_ROOT="$(pwd)/../../../" \
|
APP_ROOT="$(pwd)/../../../" \
|
||||||
APP_URL=http://localhost:7263 \
|
API_URL=http://localhost:7263 \
|
||||||
MODEL_SIZE=base_plus \
|
MODEL_SIZE=base_plus \
|
||||||
DATA_PATH="$(pwd)/../../data" \
|
DATA_PATH="$(pwd)/../../data" \
|
||||||
DEFAULT_VIDEO_PATH=gallery/05_default_juggle.mp4 \
|
DEFAULT_VIDEO_PATH=gallery/05_default_juggle.mp4 \
|
||||||
|
@ -27,8 +27,6 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
# whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
|
# whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
|
||||||
# note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True)
|
# note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True)
|
||||||
clear_non_cond_mem_around_input=False,
|
clear_non_cond_mem_around_input=False,
|
||||||
# whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True).
|
|
||||||
clear_non_cond_mem_for_multi_obj=False,
|
|
||||||
# if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
|
# if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
|
||||||
# if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames
|
# if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames
|
||||||
add_all_frames_to_correct_as_cond=False,
|
add_all_frames_to_correct_as_cond=False,
|
||||||
@ -38,7 +36,6 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
self.fill_hole_area = fill_hole_area
|
self.fill_hole_area = fill_hole_area
|
||||||
self.non_overlap_masks = non_overlap_masks
|
self.non_overlap_masks = non_overlap_masks
|
||||||
self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
|
self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
|
||||||
self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
|
|
||||||
self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
|
self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
@ -88,11 +85,6 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
inference_state["obj_id_to_idx"] = OrderedDict()
|
inference_state["obj_id_to_idx"] = OrderedDict()
|
||||||
inference_state["obj_idx_to_id"] = OrderedDict()
|
inference_state["obj_idx_to_id"] = OrderedDict()
|
||||||
inference_state["obj_ids"] = []
|
inference_state["obj_ids"] = []
|
||||||
# A storage to hold the model's tracking results and states on each frame
|
|
||||||
inference_state["output_dict"] = {
|
|
||||||
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
|
||||||
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
|
||||||
}
|
|
||||||
# Slice (view) of each object tracking results, sharing the same memory with "output_dict"
|
# Slice (view) of each object tracking results, sharing the same memory with "output_dict"
|
||||||
inference_state["output_dict_per_obj"] = {}
|
inference_state["output_dict_per_obj"] = {}
|
||||||
# A temporary storage to hold new outputs when user interact with a frame
|
# A temporary storage to hold new outputs when user interact with a frame
|
||||||
@ -100,13 +92,8 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
inference_state["temp_output_dict_per_obj"] = {}
|
inference_state["temp_output_dict_per_obj"] = {}
|
||||||
# Frames that already holds consolidated outputs from click or mask inputs
|
# Frames that already holds consolidated outputs from click or mask inputs
|
||||||
# (we directly use their consolidated outputs during tracking)
|
# (we directly use their consolidated outputs during tracking)
|
||||||
inference_state["consolidated_frame_inds"] = {
|
|
||||||
"cond_frame_outputs": set(), # set containing frame indices
|
|
||||||
"non_cond_frame_outputs": set(), # set containing frame indices
|
|
||||||
}
|
|
||||||
# metadata for each tracking frame (e.g. which direction it's tracked)
|
# metadata for each tracking frame (e.g. which direction it's tracked)
|
||||||
inference_state["tracking_has_started"] = False
|
inference_state["frames_tracked_per_obj"] = {}
|
||||||
inference_state["frames_already_tracked"] = {}
|
|
||||||
# Warm up the visual backbone and cache the image feature on frame 0
|
# Warm up the visual backbone and cache the image feature on frame 0
|
||||||
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
|
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
|
||||||
return inference_state
|
return inference_state
|
||||||
@ -134,9 +121,8 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
if obj_idx is not None:
|
if obj_idx is not None:
|
||||||
return obj_idx
|
return obj_idx
|
||||||
|
|
||||||
# This is a new object id not sent to the server before. We only allow adding
|
# We always allow adding new objects (including after tracking starts).
|
||||||
# new objects *before* the tracking starts.
|
allow_new_object = True
|
||||||
allow_new_object = not inference_state["tracking_has_started"]
|
|
||||||
if allow_new_object:
|
if allow_new_object:
|
||||||
# get the next object slot
|
# get the next object slot
|
||||||
obj_idx = len(inference_state["obj_id_to_idx"])
|
obj_idx = len(inference_state["obj_id_to_idx"])
|
||||||
@ -154,6 +140,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||||
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||||
}
|
}
|
||||||
|
inference_state["frames_tracked_per_obj"][obj_idx] = {}
|
||||||
return obj_idx
|
return obj_idx
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
@ -214,15 +201,6 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
"box prompt must be provided before any point prompt "
|
"box prompt must be provided before any point prompt "
|
||||||
"(please use clear_old_points=True instead)"
|
"(please use clear_old_points=True instead)"
|
||||||
)
|
)
|
||||||
if inference_state["tracking_has_started"]:
|
|
||||||
warnings.warn(
|
|
||||||
"You are adding a box after tracking starts. SAM 2 may not always be "
|
|
||||||
"able to incorporate a box prompt for *refinement*. If you intend to "
|
|
||||||
"use box prompt as an *initial* input before tracking, please call "
|
|
||||||
"'reset_state' on the inference state to restart from scratch.",
|
|
||||||
category=UserWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
if not isinstance(box, torch.Tensor):
|
if not isinstance(box, torch.Tensor):
|
||||||
box = torch.tensor(box, dtype=torch.float32, device=points.device)
|
box = torch.tensor(box, dtype=torch.float32, device=points.device)
|
||||||
box_coords = box.reshape(1, 2, 2)
|
box_coords = box.reshape(1, 2, 2)
|
||||||
@ -252,12 +230,13 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
# frame, meaning that the inputs points are to generate segments on this frame without
|
# frame, meaning that the inputs points are to generate segments on this frame without
|
||||||
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
|
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
|
||||||
# the input points will be used to correct the already tracked masks.
|
# the input points will be used to correct the already tracked masks.
|
||||||
is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
|
obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx]
|
||||||
|
is_init_cond_frame = frame_idx not in obj_frames_tracked
|
||||||
# whether to track in reverse time order
|
# whether to track in reverse time order
|
||||||
if is_init_cond_frame:
|
if is_init_cond_frame:
|
||||||
reverse = False
|
reverse = False
|
||||||
else:
|
else:
|
||||||
reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
|
reverse = obj_frames_tracked[frame_idx]["reverse"]
|
||||||
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
||||||
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
||||||
# Add a frame to conditioning output if it's an initial conditioning frame or
|
# Add a frame to conditioning output if it's an initial conditioning frame or
|
||||||
@ -306,7 +285,6 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
inference_state,
|
inference_state,
|
||||||
frame_idx,
|
frame_idx,
|
||||||
is_cond=is_cond,
|
is_cond=is_cond,
|
||||||
run_mem_encoder=False,
|
|
||||||
consolidate_at_video_res=True,
|
consolidate_at_video_res=True,
|
||||||
)
|
)
|
||||||
_, video_res_masks = self._get_orig_video_res_output(
|
_, video_res_masks = self._get_orig_video_res_output(
|
||||||
@ -357,12 +335,13 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
# frame, meaning that the inputs points are to generate segments on this frame without
|
# frame, meaning that the inputs points are to generate segments on this frame without
|
||||||
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
|
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
|
||||||
# the input points will be used to correct the already tracked masks.
|
# the input points will be used to correct the already tracked masks.
|
||||||
is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
|
obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx]
|
||||||
|
is_init_cond_frame = frame_idx not in obj_frames_tracked
|
||||||
# whether to track in reverse time order
|
# whether to track in reverse time order
|
||||||
if is_init_cond_frame:
|
if is_init_cond_frame:
|
||||||
reverse = False
|
reverse = False
|
||||||
else:
|
else:
|
||||||
reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
|
reverse = obj_frames_tracked[frame_idx]["reverse"]
|
||||||
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
||||||
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
||||||
# Add a frame to conditioning output if it's an initial conditioning frame or
|
# Add a frame to conditioning output if it's an initial conditioning frame or
|
||||||
@ -394,7 +373,6 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
inference_state,
|
inference_state,
|
||||||
frame_idx,
|
frame_idx,
|
||||||
is_cond=is_cond,
|
is_cond=is_cond,
|
||||||
run_mem_encoder=False,
|
|
||||||
consolidate_at_video_res=True,
|
consolidate_at_video_res=True,
|
||||||
)
|
)
|
||||||
_, video_res_masks = self._get_orig_video_res_output(
|
_, video_res_masks = self._get_orig_video_res_output(
|
||||||
@ -429,7 +407,6 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
inference_state,
|
inference_state,
|
||||||
frame_idx,
|
frame_idx,
|
||||||
is_cond,
|
is_cond,
|
||||||
run_mem_encoder,
|
|
||||||
consolidate_at_video_res=False,
|
consolidate_at_video_res=False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -446,7 +423,6 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
# Optionally, we allow consolidating the temporary outputs at the original
|
# Optionally, we allow consolidating the temporary outputs at the original
|
||||||
# video resolution (to provide a better editing experience for mask prompts).
|
# video resolution (to provide a better editing experience for mask prompts).
|
||||||
if consolidate_at_video_res:
|
if consolidate_at_video_res:
|
||||||
assert not run_mem_encoder, "memory encoder cannot run at video resolution"
|
|
||||||
consolidated_H = inference_state["video_height"]
|
consolidated_H = inference_state["video_height"]
|
||||||
consolidated_W = inference_state["video_width"]
|
consolidated_W = inference_state["video_width"]
|
||||||
consolidated_mask_key = "pred_masks_video_res"
|
consolidated_mask_key = "pred_masks_video_res"
|
||||||
@ -459,30 +435,13 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
# constraints to object scores. Its "pred_masks" are prefilled with a large
|
# constraints to object scores. Its "pred_masks" are prefilled with a large
|
||||||
# negative value (NO_OBJ_SCORE) to represent missing objects.
|
# negative value (NO_OBJ_SCORE) to represent missing objects.
|
||||||
consolidated_out = {
|
consolidated_out = {
|
||||||
"maskmem_features": None,
|
|
||||||
"maskmem_pos_enc": None,
|
|
||||||
consolidated_mask_key: torch.full(
|
consolidated_mask_key: torch.full(
|
||||||
size=(batch_size, 1, consolidated_H, consolidated_W),
|
size=(batch_size, 1, consolidated_H, consolidated_W),
|
||||||
fill_value=NO_OBJ_SCORE,
|
fill_value=NO_OBJ_SCORE,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=inference_state["storage_device"],
|
device=inference_state["storage_device"],
|
||||||
),
|
),
|
||||||
"obj_ptr": torch.full(
|
|
||||||
size=(batch_size, self.hidden_dim),
|
|
||||||
fill_value=NO_OBJ_SCORE,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=inference_state["device"],
|
|
||||||
),
|
|
||||||
"object_score_logits": torch.full(
|
|
||||||
size=(batch_size, 1),
|
|
||||||
# default to 10.0 for object_score_logits, i.e. assuming the object is
|
|
||||||
# present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
|
|
||||||
fill_value=10.0,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=inference_state["device"],
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
empty_mask_ptr = None
|
|
||||||
for obj_idx in range(batch_size):
|
for obj_idx in range(batch_size):
|
||||||
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
||||||
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
||||||
@ -499,16 +458,6 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
# and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
|
# and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
|
||||||
# placeholder above) and set its object pointer to be a dummy pointer.
|
# placeholder above) and set its object pointer to be a dummy pointer.
|
||||||
if out is None:
|
if out is None:
|
||||||
# Fill in dummy object pointers for those objects without any inputs or
|
|
||||||
# tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
|
|
||||||
# i.e. when we need to build the memory for tracking).
|
|
||||||
if run_mem_encoder:
|
|
||||||
if empty_mask_ptr is None:
|
|
||||||
empty_mask_ptr = self._get_empty_mask_ptr(
|
|
||||||
inference_state, frame_idx
|
|
||||||
)
|
|
||||||
# fill object pointer with a dummy pointer (based on an empty mask)
|
|
||||||
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
|
|
||||||
continue
|
continue
|
||||||
# Add the temporary object output mask to consolidated output mask
|
# Add the temporary object output mask to consolidated output mask
|
||||||
obj_mask = out["pred_masks"]
|
obj_mask = out["pred_masks"]
|
||||||
@ -524,141 +473,74 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
align_corners=False,
|
align_corners=False,
|
||||||
)
|
)
|
||||||
consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
|
consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
|
||||||
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
|
|
||||||
consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[
|
|
||||||
"object_score_logits"
|
|
||||||
]
|
|
||||||
|
|
||||||
# Optionally, apply non-overlapping constraints on the consolidated scores
|
|
||||||
# and rerun the memory encoder
|
|
||||||
if run_mem_encoder:
|
|
||||||
device = inference_state["device"]
|
|
||||||
high_res_masks = torch.nn.functional.interpolate(
|
|
||||||
consolidated_out["pred_masks"].to(device, non_blocking=True),
|
|
||||||
size=(self.image_size, self.image_size),
|
|
||||||
mode="bilinear",
|
|
||||||
align_corners=False,
|
|
||||||
)
|
|
||||||
if self.non_overlap_masks_for_mem_enc:
|
|
||||||
high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
|
|
||||||
maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
|
|
||||||
inference_state=inference_state,
|
|
||||||
frame_idx=frame_idx,
|
|
||||||
batch_size=batch_size,
|
|
||||||
high_res_masks=high_res_masks,
|
|
||||||
object_score_logits=consolidated_out["object_score_logits"],
|
|
||||||
is_mask_from_pts=True, # these frames are what the user interacted with
|
|
||||||
)
|
|
||||||
consolidated_out["maskmem_features"] = maskmem_features
|
|
||||||
consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
|
|
||||||
|
|
||||||
return consolidated_out
|
return consolidated_out
|
||||||
|
|
||||||
def _get_empty_mask_ptr(self, inference_state, frame_idx):
|
|
||||||
"""Get a dummy object pointer based on an empty mask on the current frame."""
|
|
||||||
# A dummy (empty) mask with a single object
|
|
||||||
batch_size = 1
|
|
||||||
mask_inputs = torch.zeros(
|
|
||||||
(batch_size, 1, self.image_size, self.image_size),
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=inference_state["device"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Retrieve correct image features
|
|
||||||
(
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
current_vision_feats,
|
|
||||||
current_vision_pos_embeds,
|
|
||||||
feat_sizes,
|
|
||||||
) = self._get_image_feature(inference_state, frame_idx, batch_size)
|
|
||||||
|
|
||||||
# Feed the empty mask and image feature above to get a dummy object pointer
|
|
||||||
current_out = self.track_step(
|
|
||||||
frame_idx=frame_idx,
|
|
||||||
is_init_cond_frame=True,
|
|
||||||
current_vision_feats=current_vision_feats,
|
|
||||||
current_vision_pos_embeds=current_vision_pos_embeds,
|
|
||||||
feat_sizes=feat_sizes,
|
|
||||||
point_inputs=None,
|
|
||||||
mask_inputs=mask_inputs,
|
|
||||||
output_dict={},
|
|
||||||
num_frames=inference_state["num_frames"],
|
|
||||||
track_in_reverse=False,
|
|
||||||
run_mem_encoder=False,
|
|
||||||
prev_sam_mask_logits=None,
|
|
||||||
)
|
|
||||||
return current_out["obj_ptr"]
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def propagate_in_video_preflight(self, inference_state):
|
def propagate_in_video_preflight(self, inference_state):
|
||||||
"""Prepare inference_state and consolidate temporary outputs before tracking."""
|
"""Prepare inference_state and consolidate temporary outputs before tracking."""
|
||||||
# Tracking has started and we don't allow adding new objects until session is reset.
|
# Check and make sure that every object has received input points or masks.
|
||||||
inference_state["tracking_has_started"] = True
|
|
||||||
batch_size = self._get_obj_num(inference_state)
|
batch_size = self._get_obj_num(inference_state)
|
||||||
|
if batch_size == 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
"No input points or masks are provided for any object; please add inputs first."
|
||||||
|
)
|
||||||
|
|
||||||
# Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
|
# Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
|
||||||
# add them into "output_dict".
|
# add them into "output_dict".
|
||||||
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
|
for obj_idx in range(batch_size):
|
||||||
output_dict = inference_state["output_dict"]
|
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
||||||
# "consolidated_frame_inds" contains indices of those frames where consolidated
|
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
||||||
# temporary outputs have been added (either in this call or any previous calls
|
|
||||||
# to `propagate_in_video_preflight`).
|
|
||||||
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
|
|
||||||
for is_cond in [False, True]:
|
for is_cond in [False, True]:
|
||||||
# Separately consolidate conditioning and non-conditioning temp outputs
|
# Separately consolidate conditioning and non-conditioning temp outputs
|
||||||
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
storage_key = (
|
||||||
|
"cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
||||||
|
)
|
||||||
# Find all the frames that contain temporary outputs for any objects
|
# Find all the frames that contain temporary outputs for any objects
|
||||||
# (these should be the frames that have just received clicks for mask inputs
|
# (these should be the frames that have just received clicks for mask inputs
|
||||||
# via `add_new_points_or_box` or `add_new_mask`)
|
# via `add_new_points_or_box` or `add_new_mask`)
|
||||||
temp_frame_inds = set()
|
for frame_idx, out in obj_temp_output_dict[storage_key].items():
|
||||||
for obj_temp_output_dict in temp_output_dict_per_obj.values():
|
# Run memory encoder on the temporary outputs (if the memory feature is missing)
|
||||||
temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
|
if out["maskmem_features"] is None:
|
||||||
consolidated_frame_inds[storage_key].update(temp_frame_inds)
|
high_res_masks = torch.nn.functional.interpolate(
|
||||||
# consolidate the temporary output across all objects on this frame
|
out["pred_masks"].to(inference_state["device"]),
|
||||||
for frame_idx in temp_frame_inds:
|
size=(self.image_size, self.image_size),
|
||||||
consolidated_out = self._consolidate_temp_output_across_obj(
|
mode="bilinear",
|
||||||
inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
|
align_corners=False,
|
||||||
)
|
)
|
||||||
# merge them into "output_dict" and also create per-object slices
|
maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
|
||||||
output_dict[storage_key][frame_idx] = consolidated_out
|
inference_state=inference_state,
|
||||||
self._add_output_per_object(
|
frame_idx=frame_idx,
|
||||||
inference_state, frame_idx, consolidated_out, storage_key
|
batch_size=1, # run on the slice of a single object
|
||||||
|
high_res_masks=high_res_masks,
|
||||||
|
object_score_logits=out["object_score_logits"],
|
||||||
|
# these frames are what the user interacted with
|
||||||
|
is_mask_from_pts=True,
|
||||||
)
|
)
|
||||||
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
|
out["maskmem_features"] = maskmem_features
|
||||||
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
|
out["maskmem_pos_enc"] = maskmem_pos_enc
|
||||||
)
|
|
||||||
if clear_non_cond_mem:
|
obj_output_dict[storage_key][frame_idx] = out
|
||||||
|
if self.clear_non_cond_mem_around_input:
|
||||||
# clear non-conditioning memory of the surrounding frames
|
# clear non-conditioning memory of the surrounding frames
|
||||||
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
|
self._clear_obj_non_cond_mem_around_input(
|
||||||
|
inference_state, frame_idx, obj_idx
|
||||||
|
)
|
||||||
|
|
||||||
# clear temporary outputs in `temp_output_dict_per_obj`
|
# clear temporary outputs in `temp_output_dict_per_obj`
|
||||||
for obj_temp_output_dict in temp_output_dict_per_obj.values():
|
|
||||||
obj_temp_output_dict[storage_key].clear()
|
obj_temp_output_dict[storage_key].clear()
|
||||||
|
|
||||||
|
# check and make sure that every object has received input points or masks
|
||||||
|
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
||||||
|
if len(obj_output_dict["cond_frame_outputs"]) == 0:
|
||||||
|
obj_id = self._obj_idx_to_id(inference_state, obj_idx)
|
||||||
|
raise RuntimeError(
|
||||||
|
f"No input points or masks are provided for object id {obj_id}; please add inputs first."
|
||||||
|
)
|
||||||
# edge case: if an output is added to "cond_frame_outputs", we remove any prior
|
# edge case: if an output is added to "cond_frame_outputs", we remove any prior
|
||||||
# output on the same frame in "non_cond_frame_outputs"
|
# output on the same frame in "non_cond_frame_outputs"
|
||||||
for frame_idx in output_dict["cond_frame_outputs"]:
|
|
||||||
output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
|
|
||||||
for obj_output_dict in inference_state["output_dict_per_obj"].values():
|
|
||||||
for frame_idx in obj_output_dict["cond_frame_outputs"]:
|
for frame_idx in obj_output_dict["cond_frame_outputs"]:
|
||||||
obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
|
obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
|
||||||
for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
|
|
||||||
assert frame_idx in output_dict["cond_frame_outputs"]
|
|
||||||
consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
|
|
||||||
|
|
||||||
# Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
|
|
||||||
# with either points or mask inputs (which should be true under a correct workflow).
|
|
||||||
all_consolidated_frame_inds = (
|
|
||||||
consolidated_frame_inds["cond_frame_outputs"]
|
|
||||||
| consolidated_frame_inds["non_cond_frame_outputs"]
|
|
||||||
)
|
|
||||||
input_frames_inds = set()
|
|
||||||
for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
|
|
||||||
input_frames_inds.update(point_inputs_per_frame.keys())
|
|
||||||
for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
|
|
||||||
input_frames_inds.update(mask_inputs_per_frame.keys())
|
|
||||||
assert all_consolidated_frame_inds == input_frames_inds
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def propagate_in_video(
|
def propagate_in_video(
|
||||||
@ -671,21 +553,18 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
"""Propagate the input points across frames to track in the entire video."""
|
"""Propagate the input points across frames to track in the entire video."""
|
||||||
self.propagate_in_video_preflight(inference_state)
|
self.propagate_in_video_preflight(inference_state)
|
||||||
|
|
||||||
output_dict = inference_state["output_dict"]
|
|
||||||
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
|
|
||||||
obj_ids = inference_state["obj_ids"]
|
obj_ids = inference_state["obj_ids"]
|
||||||
num_frames = inference_state["num_frames"]
|
num_frames = inference_state["num_frames"]
|
||||||
batch_size = self._get_obj_num(inference_state)
|
batch_size = self._get_obj_num(inference_state)
|
||||||
if len(output_dict["cond_frame_outputs"]) == 0:
|
|
||||||
raise RuntimeError("No points are provided; please add points first")
|
|
||||||
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
|
|
||||||
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
|
|
||||||
)
|
|
||||||
|
|
||||||
# set start index, end index, and processing order
|
# set start index, end index, and processing order
|
||||||
if start_frame_idx is None:
|
if start_frame_idx is None:
|
||||||
# default: start from the earliest frame with input points
|
# default: start from the earliest frame with input points
|
||||||
start_frame_idx = min(output_dict["cond_frame_outputs"])
|
start_frame_idx = min(
|
||||||
|
t
|
||||||
|
for obj_output_dict in inference_state["output_dict_per_obj"].values()
|
||||||
|
for t in obj_output_dict["cond_frame_outputs"]
|
||||||
|
)
|
||||||
if max_frame_num_to_track is None:
|
if max_frame_num_to_track is None:
|
||||||
# default: track all the frames in the video
|
# default: track all the frames in the video
|
||||||
max_frame_num_to_track = num_frames
|
max_frame_num_to_track = num_frames
|
||||||
@ -702,78 +581,53 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
processing_order = range(start_frame_idx, end_frame_idx + 1)
|
processing_order = range(start_frame_idx, end_frame_idx + 1)
|
||||||
|
|
||||||
for frame_idx in tqdm(processing_order, desc="propagate in video"):
|
for frame_idx in tqdm(processing_order, desc="propagate in video"):
|
||||||
|
pred_masks_per_obj = [None] * batch_size
|
||||||
|
for obj_idx in range(batch_size):
|
||||||
|
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
||||||
# We skip those frames already in consolidated outputs (these are frames
|
# We skip those frames already in consolidated outputs (these are frames
|
||||||
# that received input clicks or mask). Note that we cannot directly run
|
# that received input clicks or mask). Note that we cannot directly run
|
||||||
# batched forward on them via `_run_single_frame_inference` because the
|
# batched forward on them via `_run_single_frame_inference` because the
|
||||||
# number of clicks on each object might be different.
|
# number of clicks on each object might be different.
|
||||||
if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
|
if frame_idx in obj_output_dict["cond_frame_outputs"]:
|
||||||
storage_key = "cond_frame_outputs"
|
storage_key = "cond_frame_outputs"
|
||||||
current_out = output_dict[storage_key][frame_idx]
|
current_out = obj_output_dict[storage_key][frame_idx]
|
||||||
pred_masks = current_out["pred_masks"]
|
pred_masks = current_out["pred_masks"]
|
||||||
if clear_non_cond_mem:
|
if self.clear_non_cond_mem_around_input:
|
||||||
# clear non-conditioning memory of the surrounding frames
|
# clear non-conditioning memory of the surrounding frames
|
||||||
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
|
self._clear_obj_non_cond_mem_around_input(
|
||||||
elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
|
inference_state, frame_idx, obj_idx
|
||||||
storage_key = "non_cond_frame_outputs"
|
)
|
||||||
current_out = output_dict[storage_key][frame_idx]
|
|
||||||
pred_masks = current_out["pred_masks"]
|
|
||||||
else:
|
else:
|
||||||
storage_key = "non_cond_frame_outputs"
|
storage_key = "non_cond_frame_outputs"
|
||||||
current_out, pred_masks = self._run_single_frame_inference(
|
current_out, pred_masks = self._run_single_frame_inference(
|
||||||
inference_state=inference_state,
|
inference_state=inference_state,
|
||||||
output_dict=output_dict,
|
output_dict=obj_output_dict,
|
||||||
frame_idx=frame_idx,
|
frame_idx=frame_idx,
|
||||||
batch_size=batch_size,
|
batch_size=1, # run on the slice of a single object
|
||||||
is_init_cond_frame=False,
|
is_init_cond_frame=False,
|
||||||
point_inputs=None,
|
point_inputs=None,
|
||||||
mask_inputs=None,
|
mask_inputs=None,
|
||||||
reverse=reverse,
|
reverse=reverse,
|
||||||
run_mem_encoder=True,
|
run_mem_encoder=True,
|
||||||
)
|
)
|
||||||
output_dict[storage_key][frame_idx] = current_out
|
obj_output_dict[storage_key][frame_idx] = current_out
|
||||||
# Create slices of per-object outputs for subsequent interaction with each
|
|
||||||
# individual object after tracking.
|
inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = {
|
||||||
self._add_output_per_object(
|
"reverse": reverse
|
||||||
inference_state, frame_idx, current_out, storage_key
|
}
|
||||||
)
|
pred_masks_per_obj[obj_idx] = pred_masks
|
||||||
inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}
|
|
||||||
|
|
||||||
# Resize the output mask to the original video resolution (we directly use
|
# Resize the output mask to the original video resolution (we directly use
|
||||||
# the mask scores on GPU for output to avoid any CPU conversion in between)
|
# the mask scores on GPU for output to avoid any CPU conversion in between)
|
||||||
|
if len(pred_masks_per_obj) > 1:
|
||||||
|
all_pred_masks = torch.cat(pred_masks_per_obj, dim=0)
|
||||||
|
else:
|
||||||
|
all_pred_masks = pred_masks_per_obj[0]
|
||||||
_, video_res_masks = self._get_orig_video_res_output(
|
_, video_res_masks = self._get_orig_video_res_output(
|
||||||
inference_state, pred_masks
|
inference_state, all_pred_masks
|
||||||
)
|
)
|
||||||
yield frame_idx, obj_ids, video_res_masks
|
yield frame_idx, obj_ids, video_res_masks
|
||||||
|
|
||||||
def _add_output_per_object(
|
|
||||||
self, inference_state, frame_idx, current_out, storage_key
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Split a multi-object output into per-object output slices and add them into
|
|
||||||
`output_dict_per_obj`. The resulting slices share the same tensor storage.
|
|
||||||
"""
|
|
||||||
maskmem_features = current_out["maskmem_features"]
|
|
||||||
assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
|
|
||||||
|
|
||||||
maskmem_pos_enc = current_out["maskmem_pos_enc"]
|
|
||||||
assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
|
|
||||||
|
|
||||||
output_dict_per_obj = inference_state["output_dict_per_obj"]
|
|
||||||
for obj_idx, obj_output_dict in output_dict_per_obj.items():
|
|
||||||
obj_slice = slice(obj_idx, obj_idx + 1)
|
|
||||||
obj_out = {
|
|
||||||
"maskmem_features": None,
|
|
||||||
"maskmem_pos_enc": None,
|
|
||||||
"pred_masks": current_out["pred_masks"][obj_slice],
|
|
||||||
"obj_ptr": current_out["obj_ptr"][obj_slice],
|
|
||||||
"object_score_logits": current_out["object_score_logits"][obj_slice],
|
|
||||||
}
|
|
||||||
if maskmem_features is not None:
|
|
||||||
obj_out["maskmem_features"] = maskmem_features[obj_slice]
|
|
||||||
if maskmem_pos_enc is not None:
|
|
||||||
obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
|
|
||||||
obj_output_dict[storage_key][frame_idx] = obj_out
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def clear_all_prompts_in_frame(
|
def clear_all_prompts_in_frame(
|
||||||
self, inference_state, frame_idx, obj_id, need_output=True
|
self, inference_state, frame_idx, obj_id, need_output=True
|
||||||
@ -789,41 +643,14 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None)
|
temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None)
|
||||||
temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None)
|
temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None)
|
||||||
|
|
||||||
# Check and see if there are still any inputs left on this frame
|
|
||||||
batch_size = self._get_obj_num(inference_state)
|
|
||||||
frame_has_input = False
|
|
||||||
for obj_idx2 in range(batch_size):
|
|
||||||
if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]:
|
|
||||||
frame_has_input = True
|
|
||||||
break
|
|
||||||
if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]:
|
|
||||||
frame_has_input = True
|
|
||||||
break
|
|
||||||
|
|
||||||
# If this frame has no remaining inputs for any objects, we further clear its
|
|
||||||
# conditioning frame status
|
|
||||||
if not frame_has_input:
|
|
||||||
output_dict = inference_state["output_dict"]
|
|
||||||
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
|
|
||||||
consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx)
|
|
||||||
consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
|
|
||||||
# Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
|
# Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
|
||||||
out = output_dict["cond_frame_outputs"].pop(frame_idx, None)
|
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
||||||
|
out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
|
||||||
if out is not None:
|
if out is not None:
|
||||||
# The frame is not a conditioning frame anymore since it's not receiving inputs,
|
# The frame is not a conditioning frame anymore since it's not receiving inputs,
|
||||||
# so we "downgrade" its output (if exists) to a non-conditioning frame output.
|
# so we "downgrade" its output (if exists) to a non-conditioning frame output.
|
||||||
output_dict["non_cond_frame_outputs"][frame_idx] = out
|
obj_output_dict["non_cond_frame_outputs"][frame_idx] = out
|
||||||
inference_state["frames_already_tracked"].pop(frame_idx, None)
|
inference_state["frames_tracked_per_obj"][obj_idx].pop(frame_idx, None)
|
||||||
# Similarly, do it for the sliced output on each object.
|
|
||||||
for obj_idx2 in range(batch_size):
|
|
||||||
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2]
|
|
||||||
obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
|
|
||||||
if obj_out is not None:
|
|
||||||
obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out
|
|
||||||
|
|
||||||
# If all the conditioning frames have been removed, we also clear the tracking outputs
|
|
||||||
if len(output_dict["cond_frame_outputs"]) == 0:
|
|
||||||
self._reset_tracking_results(inference_state)
|
|
||||||
|
|
||||||
if not need_output:
|
if not need_output:
|
||||||
return
|
return
|
||||||
@ -837,7 +664,6 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
inference_state,
|
inference_state,
|
||||||
frame_idx,
|
frame_idx,
|
||||||
is_cond=is_cond,
|
is_cond=is_cond,
|
||||||
run_mem_encoder=False,
|
|
||||||
consolidate_at_video_res=True,
|
consolidate_at_video_res=True,
|
||||||
)
|
)
|
||||||
_, video_res_masks = self._get_orig_video_res_output(
|
_, video_res_masks = self._get_orig_video_res_output(
|
||||||
@ -857,6 +683,7 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
inference_state["mask_inputs_per_obj"].clear()
|
inference_state["mask_inputs_per_obj"].clear()
|
||||||
inference_state["output_dict_per_obj"].clear()
|
inference_state["output_dict_per_obj"].clear()
|
||||||
inference_state["temp_output_dict_per_obj"].clear()
|
inference_state["temp_output_dict_per_obj"].clear()
|
||||||
|
inference_state["frames_tracked_per_obj"].clear()
|
||||||
|
|
||||||
def _reset_tracking_results(self, inference_state):
|
def _reset_tracking_results(self, inference_state):
|
||||||
"""Reset all tracking inputs and results across the videos."""
|
"""Reset all tracking inputs and results across the videos."""
|
||||||
@ -870,12 +697,8 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
for v in inference_state["temp_output_dict_per_obj"].values():
|
for v in inference_state["temp_output_dict_per_obj"].values():
|
||||||
v["cond_frame_outputs"].clear()
|
v["cond_frame_outputs"].clear()
|
||||||
v["non_cond_frame_outputs"].clear()
|
v["non_cond_frame_outputs"].clear()
|
||||||
inference_state["output_dict"]["cond_frame_outputs"].clear()
|
for v in inference_state["frames_tracked_per_obj"].values():
|
||||||
inference_state["output_dict"]["non_cond_frame_outputs"].clear()
|
v.clear()
|
||||||
inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
|
|
||||||
inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
|
|
||||||
inference_state["tracking_has_started"] = False
|
|
||||||
inference_state["frames_already_tracked"].clear()
|
|
||||||
|
|
||||||
def _get_image_feature(self, inference_state, frame_idx, batch_size):
|
def _get_image_feature(self, inference_state, frame_idx, batch_size):
|
||||||
"""Compute the image features on a given frame."""
|
"""Compute the image features on a given frame."""
|
||||||
@ -1093,8 +916,6 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
inference_state["obj_ids"] = new_obj_ids
|
inference_state["obj_ids"] = new_obj_ids
|
||||||
|
|
||||||
# Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
|
# Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
|
||||||
# (note that "consolidated_frame_inds" doesn't need to be updated in this step as
|
|
||||||
# it's already handled in Step 0)
|
|
||||||
def _map_keys(container):
|
def _map_keys(container):
|
||||||
new_kvs = []
|
new_kvs = []
|
||||||
for k in old_obj_inds:
|
for k in old_obj_inds:
|
||||||
@ -1107,30 +928,9 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
_map_keys(inference_state["mask_inputs_per_obj"])
|
_map_keys(inference_state["mask_inputs_per_obj"])
|
||||||
_map_keys(inference_state["output_dict_per_obj"])
|
_map_keys(inference_state["output_dict_per_obj"])
|
||||||
_map_keys(inference_state["temp_output_dict_per_obj"])
|
_map_keys(inference_state["temp_output_dict_per_obj"])
|
||||||
|
_map_keys(inference_state["frames_tracked_per_obj"])
|
||||||
|
|
||||||
# Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices.
|
# Step 3: Further collect the outputs on those frames in `obj_input_frames_inds`, which
|
||||||
def _slice_state(output_dict, storage_key):
|
|
||||||
for frame_idx, out in output_dict[storage_key].items():
|
|
||||||
out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds]
|
|
||||||
out["maskmem_pos_enc"] = [
|
|
||||||
x[remain_old_obj_inds] for x in out["maskmem_pos_enc"]
|
|
||||||
]
|
|
||||||
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
|
|
||||||
out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out)
|
|
||||||
out["pred_masks"] = out["pred_masks"][remain_old_obj_inds]
|
|
||||||
out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds]
|
|
||||||
out["object_score_logits"] = out["object_score_logits"][
|
|
||||||
remain_old_obj_inds
|
|
||||||
]
|
|
||||||
# also update the per-object slices
|
|
||||||
self._add_output_per_object(
|
|
||||||
inference_state, frame_idx, out, storage_key
|
|
||||||
)
|
|
||||||
|
|
||||||
_slice_state(inference_state["output_dict"], "cond_frame_outputs")
|
|
||||||
_slice_state(inference_state["output_dict"], "non_cond_frame_outputs")
|
|
||||||
|
|
||||||
# Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which
|
|
||||||
# could show an updated mask for objects previously occluded by the object being removed
|
# could show an updated mask for objects previously occluded by the object being removed
|
||||||
if need_output:
|
if need_output:
|
||||||
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
|
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
|
||||||
@ -1143,7 +943,6 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
inference_state,
|
inference_state,
|
||||||
frame_idx,
|
frame_idx,
|
||||||
is_cond=is_cond,
|
is_cond=is_cond,
|
||||||
run_mem_encoder=False,
|
|
||||||
consolidate_at_video_res=True,
|
consolidate_at_video_res=True,
|
||||||
)
|
)
|
||||||
_, video_res_masks = self._get_orig_video_res_output(
|
_, video_res_masks = self._get_orig_video_res_output(
|
||||||
@ -1165,12 +964,12 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
r = self.memory_temporal_stride_for_eval
|
r = self.memory_temporal_stride_for_eval
|
||||||
frame_idx_begin = frame_idx - r * self.num_maskmem
|
frame_idx_begin = frame_idx - r * self.num_maskmem
|
||||||
frame_idx_end = frame_idx + r * self.num_maskmem
|
frame_idx_end = frame_idx + r * self.num_maskmem
|
||||||
output_dict = inference_state["output_dict"]
|
batch_size = self._get_obj_num(inference_state)
|
||||||
non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
|
for obj_idx in range(batch_size):
|
||||||
|
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
||||||
|
non_cond_frame_outputs = obj_output_dict["non_cond_frame_outputs"]
|
||||||
for t in range(frame_idx_begin, frame_idx_end + 1):
|
for t in range(frame_idx_begin, frame_idx_end + 1):
|
||||||
non_cond_frame_outputs.pop(t, None)
|
non_cond_frame_outputs.pop(t, None)
|
||||||
for obj_output_dict in inference_state["output_dict_per_obj"].values():
|
|
||||||
obj_output_dict["non_cond_frame_outputs"].pop(t, None)
|
|
||||||
|
|
||||||
|
|
||||||
class SAM2VideoPredictorVOS(SAM2VideoPredictor):
|
class SAM2VideoPredictorVOS(SAM2VideoPredictor):
|
||||||
|
1172
sam2/sam2_video_predictor_legacy.py
Normal file
1172
sam2/sam2_video_predictor_legacy.py
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user