switch to a new implementation of the class SAM2VideoPredictor for per-object inference sam2/sam2_video_predictor.py

In this PR, we switch to a new implementation of the class `SAM2VideoPredictor` for in sam2/sam2_video_predictor.py, which allows for independent per-object inference.

Specifically, the new `SAM2VideoPredictor`:
* it handles the inference of each object separately, as if we are opening a separate session for each object
* it relaxes the assumption on prompting
  * previously if a frame receives clicks only for a subset of objects, the rest of (non-prompted) objects are assumed to be non-existing in this frame
  * now if a frame receives clicks only for a subset of objects, we don't make any assumptions for the remaining (non-prompted) objects
* it allows adding new objects after tracking starts
* (The previous implementation is backed up to `SAM2VideoPredictor` in sam2/sam2_video_predictor_legacy.py)

Also, fix a small typo `APP_URL` => `API_URL` in the doc.

Test plan: tested with the predictor notebook `notebooks/video_predictor_example.ipynb` and VOS script `tools/vos_inference.py`. Also tested with the demo.
This commit is contained in:
Ronghang Hu 2024-12-05 07:49:43 +00:00
parent 3297dd0eb0
commit c61e2475e6
3 changed files with 1299 additions and 328 deletions

View File

@ -105,7 +105,7 @@ cd demo/backend/server/
```bash
PYTORCH_ENABLE_MPS_FALLBACK=1 \
APP_ROOT="$(pwd)/../../../" \
APP_URL=http://localhost:7263 \
API_URL=http://localhost:7263 \
MODEL_SIZE=base_plus \
DATA_PATH="$(pwd)/../../data" \
DEFAULT_VIDEO_PATH=gallery/05_default_juggle.mp4 \

View File

@ -27,8 +27,6 @@ class SAM2VideoPredictor(SAM2Base):
# whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
# note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True)
clear_non_cond_mem_around_input=False,
# whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True).
clear_non_cond_mem_for_multi_obj=False,
# if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
# if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames
add_all_frames_to_correct_as_cond=False,
@ -38,7 +36,6 @@ class SAM2VideoPredictor(SAM2Base):
self.fill_hole_area = fill_hole_area
self.non_overlap_masks = non_overlap_masks
self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
@torch.inference_mode()
@ -88,11 +85,6 @@ class SAM2VideoPredictor(SAM2Base):
inference_state["obj_id_to_idx"] = OrderedDict()
inference_state["obj_idx_to_id"] = OrderedDict()
inference_state["obj_ids"] = []
# A storage to hold the model's tracking results and states on each frame
inference_state["output_dict"] = {
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
}
# Slice (view) of each object tracking results, sharing the same memory with "output_dict"
inference_state["output_dict_per_obj"] = {}
# A temporary storage to hold new outputs when user interact with a frame
@ -100,13 +92,8 @@ class SAM2VideoPredictor(SAM2Base):
inference_state["temp_output_dict_per_obj"] = {}
# Frames that already holds consolidated outputs from click or mask inputs
# (we directly use their consolidated outputs during tracking)
inference_state["consolidated_frame_inds"] = {
"cond_frame_outputs": set(), # set containing frame indices
"non_cond_frame_outputs": set(), # set containing frame indices
}
# metadata for each tracking frame (e.g. which direction it's tracked)
inference_state["tracking_has_started"] = False
inference_state["frames_already_tracked"] = {}
inference_state["frames_tracked_per_obj"] = {}
# Warm up the visual backbone and cache the image feature on frame 0
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
return inference_state
@ -134,9 +121,8 @@ class SAM2VideoPredictor(SAM2Base):
if obj_idx is not None:
return obj_idx
# This is a new object id not sent to the server before. We only allow adding
# new objects *before* the tracking starts.
allow_new_object = not inference_state["tracking_has_started"]
# We always allow adding new objects (including after tracking starts).
allow_new_object = True
if allow_new_object:
# get the next object slot
obj_idx = len(inference_state["obj_id_to_idx"])
@ -154,6 +140,7 @@ class SAM2VideoPredictor(SAM2Base):
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
}
inference_state["frames_tracked_per_obj"][obj_idx] = {}
return obj_idx
else:
raise RuntimeError(
@ -214,15 +201,6 @@ class SAM2VideoPredictor(SAM2Base):
"box prompt must be provided before any point prompt "
"(please use clear_old_points=True instead)"
)
if inference_state["tracking_has_started"]:
warnings.warn(
"You are adding a box after tracking starts. SAM 2 may not always be "
"able to incorporate a box prompt for *refinement*. If you intend to "
"use box prompt as an *initial* input before tracking, please call "
"'reset_state' on the inference state to restart from scratch.",
category=UserWarning,
stacklevel=2,
)
if not isinstance(box, torch.Tensor):
box = torch.tensor(box, dtype=torch.float32, device=points.device)
box_coords = box.reshape(1, 2, 2)
@ -252,12 +230,13 @@ class SAM2VideoPredictor(SAM2Base):
# frame, meaning that the inputs points are to generate segments on this frame without
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
# the input points will be used to correct the already tracked masks.
is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx]
is_init_cond_frame = frame_idx not in obj_frames_tracked
# whether to track in reverse time order
if is_init_cond_frame:
reverse = False
else:
reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
reverse = obj_frames_tracked[frame_idx]["reverse"]
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
# Add a frame to conditioning output if it's an initial conditioning frame or
@ -306,7 +285,6 @@ class SAM2VideoPredictor(SAM2Base):
inference_state,
frame_idx,
is_cond=is_cond,
run_mem_encoder=False,
consolidate_at_video_res=True,
)
_, video_res_masks = self._get_orig_video_res_output(
@ -357,12 +335,13 @@ class SAM2VideoPredictor(SAM2Base):
# frame, meaning that the inputs points are to generate segments on this frame without
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
# the input points will be used to correct the already tracked masks.
is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx]
is_init_cond_frame = frame_idx not in obj_frames_tracked
# whether to track in reverse time order
if is_init_cond_frame:
reverse = False
else:
reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
reverse = obj_frames_tracked[frame_idx]["reverse"]
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
# Add a frame to conditioning output if it's an initial conditioning frame or
@ -394,7 +373,6 @@ class SAM2VideoPredictor(SAM2Base):
inference_state,
frame_idx,
is_cond=is_cond,
run_mem_encoder=False,
consolidate_at_video_res=True,
)
_, video_res_masks = self._get_orig_video_res_output(
@ -429,7 +407,6 @@ class SAM2VideoPredictor(SAM2Base):
inference_state,
frame_idx,
is_cond,
run_mem_encoder,
consolidate_at_video_res=False,
):
"""
@ -446,7 +423,6 @@ class SAM2VideoPredictor(SAM2Base):
# Optionally, we allow consolidating the temporary outputs at the original
# video resolution (to provide a better editing experience for mask prompts).
if consolidate_at_video_res:
assert not run_mem_encoder, "memory encoder cannot run at video resolution"
consolidated_H = inference_state["video_height"]
consolidated_W = inference_state["video_width"]
consolidated_mask_key = "pred_masks_video_res"
@ -459,30 +435,13 @@ class SAM2VideoPredictor(SAM2Base):
# constraints to object scores. Its "pred_masks" are prefilled with a large
# negative value (NO_OBJ_SCORE) to represent missing objects.
consolidated_out = {
"maskmem_features": None,
"maskmem_pos_enc": None,
consolidated_mask_key: torch.full(
size=(batch_size, 1, consolidated_H, consolidated_W),
fill_value=NO_OBJ_SCORE,
dtype=torch.float32,
device=inference_state["storage_device"],
),
"obj_ptr": torch.full(
size=(batch_size, self.hidden_dim),
fill_value=NO_OBJ_SCORE,
dtype=torch.float32,
device=inference_state["device"],
),
"object_score_logits": torch.full(
size=(batch_size, 1),
# default to 10.0 for object_score_logits, i.e. assuming the object is
# present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
fill_value=10.0,
dtype=torch.float32,
device=inference_state["device"],
),
}
empty_mask_ptr = None
for obj_idx in range(batch_size):
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
@ -499,16 +458,6 @@ class SAM2VideoPredictor(SAM2Base):
# and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
# placeholder above) and set its object pointer to be a dummy pointer.
if out is None:
# Fill in dummy object pointers for those objects without any inputs or
# tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
# i.e. when we need to build the memory for tracking).
if run_mem_encoder:
if empty_mask_ptr is None:
empty_mask_ptr = self._get_empty_mask_ptr(
inference_state, frame_idx
)
# fill object pointer with a dummy pointer (based on an empty mask)
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
continue
# Add the temporary object output mask to consolidated output mask
obj_mask = out["pred_masks"]
@ -524,141 +473,74 @@ class SAM2VideoPredictor(SAM2Base):
align_corners=False,
)
consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[
"object_score_logits"
]
# Optionally, apply non-overlapping constraints on the consolidated scores
# and rerun the memory encoder
if run_mem_encoder:
device = inference_state["device"]
high_res_masks = torch.nn.functional.interpolate(
consolidated_out["pred_masks"].to(device, non_blocking=True),
size=(self.image_size, self.image_size),
mode="bilinear",
align_corners=False,
)
if self.non_overlap_masks_for_mem_enc:
high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
inference_state=inference_state,
frame_idx=frame_idx,
batch_size=batch_size,
high_res_masks=high_res_masks,
object_score_logits=consolidated_out["object_score_logits"],
is_mask_from_pts=True, # these frames are what the user interacted with
)
consolidated_out["maskmem_features"] = maskmem_features
consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
return consolidated_out
def _get_empty_mask_ptr(self, inference_state, frame_idx):
"""Get a dummy object pointer based on an empty mask on the current frame."""
# A dummy (empty) mask with a single object
batch_size = 1
mask_inputs = torch.zeros(
(batch_size, 1, self.image_size, self.image_size),
dtype=torch.float32,
device=inference_state["device"],
)
# Retrieve correct image features
(
_,
_,
current_vision_feats,
current_vision_pos_embeds,
feat_sizes,
) = self._get_image_feature(inference_state, frame_idx, batch_size)
# Feed the empty mask and image feature above to get a dummy object pointer
current_out = self.track_step(
frame_idx=frame_idx,
is_init_cond_frame=True,
current_vision_feats=current_vision_feats,
current_vision_pos_embeds=current_vision_pos_embeds,
feat_sizes=feat_sizes,
point_inputs=None,
mask_inputs=mask_inputs,
output_dict={},
num_frames=inference_state["num_frames"],
track_in_reverse=False,
run_mem_encoder=False,
prev_sam_mask_logits=None,
)
return current_out["obj_ptr"]
@torch.inference_mode()
def propagate_in_video_preflight(self, inference_state):
"""Prepare inference_state and consolidate temporary outputs before tracking."""
# Tracking has started and we don't allow adding new objects until session is reset.
inference_state["tracking_has_started"] = True
# Check and make sure that every object has received input points or masks.
batch_size = self._get_obj_num(inference_state)
if batch_size == 0:
raise RuntimeError(
"No input points or masks are provided for any object; please add inputs first."
)
# Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
# add them into "output_dict".
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
output_dict = inference_state["output_dict"]
# "consolidated_frame_inds" contains indices of those frames where consolidated
# temporary outputs have been added (either in this call or any previous calls
# to `propagate_in_video_preflight`).
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
for obj_idx in range(batch_size):
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
for is_cond in [False, True]:
# Separately consolidate conditioning and non-conditioning temp outputs
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
storage_key = (
"cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
)
# Find all the frames that contain temporary outputs for any objects
# (these should be the frames that have just received clicks for mask inputs
# via `add_new_points_or_box` or `add_new_mask`)
temp_frame_inds = set()
for obj_temp_output_dict in temp_output_dict_per_obj.values():
temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
consolidated_frame_inds[storage_key].update(temp_frame_inds)
# consolidate the temporary output across all objects on this frame
for frame_idx in temp_frame_inds:
consolidated_out = self._consolidate_temp_output_across_obj(
inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
for frame_idx, out in obj_temp_output_dict[storage_key].items():
# Run memory encoder on the temporary outputs (if the memory feature is missing)
if out["maskmem_features"] is None:
high_res_masks = torch.nn.functional.interpolate(
out["pred_masks"].to(inference_state["device"]),
size=(self.image_size, self.image_size),
mode="bilinear",
align_corners=False,
)
# merge them into "output_dict" and also create per-object slices
output_dict[storage_key][frame_idx] = consolidated_out
self._add_output_per_object(
inference_state, frame_idx, consolidated_out, storage_key
maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
inference_state=inference_state,
frame_idx=frame_idx,
batch_size=1, # run on the slice of a single object
high_res_masks=high_res_masks,
object_score_logits=out["object_score_logits"],
# these frames are what the user interacted with
is_mask_from_pts=True,
)
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
)
if clear_non_cond_mem:
out["maskmem_features"] = maskmem_features
out["maskmem_pos_enc"] = maskmem_pos_enc
obj_output_dict[storage_key][frame_idx] = out
if self.clear_non_cond_mem_around_input:
# clear non-conditioning memory of the surrounding frames
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
self._clear_obj_non_cond_mem_around_input(
inference_state, frame_idx, obj_idx
)
# clear temporary outputs in `temp_output_dict_per_obj`
for obj_temp_output_dict in temp_output_dict_per_obj.values():
obj_temp_output_dict[storage_key].clear()
# check and make sure that every object has received input points or masks
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
if len(obj_output_dict["cond_frame_outputs"]) == 0:
obj_id = self._obj_idx_to_id(inference_state, obj_idx)
raise RuntimeError(
f"No input points or masks are provided for object id {obj_id}; please add inputs first."
)
# edge case: if an output is added to "cond_frame_outputs", we remove any prior
# output on the same frame in "non_cond_frame_outputs"
for frame_idx in output_dict["cond_frame_outputs"]:
output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
for obj_output_dict in inference_state["output_dict_per_obj"].values():
for frame_idx in obj_output_dict["cond_frame_outputs"]:
obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
assert frame_idx in output_dict["cond_frame_outputs"]
consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
# Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
# with either points or mask inputs (which should be true under a correct workflow).
all_consolidated_frame_inds = (
consolidated_frame_inds["cond_frame_outputs"]
| consolidated_frame_inds["non_cond_frame_outputs"]
)
input_frames_inds = set()
for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
input_frames_inds.update(point_inputs_per_frame.keys())
for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
input_frames_inds.update(mask_inputs_per_frame.keys())
assert all_consolidated_frame_inds == input_frames_inds
@torch.inference_mode()
def propagate_in_video(
@ -671,21 +553,18 @@ class SAM2VideoPredictor(SAM2Base):
"""Propagate the input points across frames to track in the entire video."""
self.propagate_in_video_preflight(inference_state)
output_dict = inference_state["output_dict"]
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
obj_ids = inference_state["obj_ids"]
num_frames = inference_state["num_frames"]
batch_size = self._get_obj_num(inference_state)
if len(output_dict["cond_frame_outputs"]) == 0:
raise RuntimeError("No points are provided; please add points first")
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
)
# set start index, end index, and processing order
if start_frame_idx is None:
# default: start from the earliest frame with input points
start_frame_idx = min(output_dict["cond_frame_outputs"])
start_frame_idx = min(
t
for obj_output_dict in inference_state["output_dict_per_obj"].values()
for t in obj_output_dict["cond_frame_outputs"]
)
if max_frame_num_to_track is None:
# default: track all the frames in the video
max_frame_num_to_track = num_frames
@ -702,78 +581,53 @@ class SAM2VideoPredictor(SAM2Base):
processing_order = range(start_frame_idx, end_frame_idx + 1)
for frame_idx in tqdm(processing_order, desc="propagate in video"):
pred_masks_per_obj = [None] * batch_size
for obj_idx in range(batch_size):
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
# We skip those frames already in consolidated outputs (these are frames
# that received input clicks or mask). Note that we cannot directly run
# batched forward on them via `_run_single_frame_inference` because the
# number of clicks on each object might be different.
if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
if frame_idx in obj_output_dict["cond_frame_outputs"]:
storage_key = "cond_frame_outputs"
current_out = output_dict[storage_key][frame_idx]
current_out = obj_output_dict[storage_key][frame_idx]
pred_masks = current_out["pred_masks"]
if clear_non_cond_mem:
if self.clear_non_cond_mem_around_input:
# clear non-conditioning memory of the surrounding frames
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
storage_key = "non_cond_frame_outputs"
current_out = output_dict[storage_key][frame_idx]
pred_masks = current_out["pred_masks"]
self._clear_obj_non_cond_mem_around_input(
inference_state, frame_idx, obj_idx
)
else:
storage_key = "non_cond_frame_outputs"
current_out, pred_masks = self._run_single_frame_inference(
inference_state=inference_state,
output_dict=output_dict,
output_dict=obj_output_dict,
frame_idx=frame_idx,
batch_size=batch_size,
batch_size=1, # run on the slice of a single object
is_init_cond_frame=False,
point_inputs=None,
mask_inputs=None,
reverse=reverse,
run_mem_encoder=True,
)
output_dict[storage_key][frame_idx] = current_out
# Create slices of per-object outputs for subsequent interaction with each
# individual object after tracking.
self._add_output_per_object(
inference_state, frame_idx, current_out, storage_key
)
inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}
obj_output_dict[storage_key][frame_idx] = current_out
inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = {
"reverse": reverse
}
pred_masks_per_obj[obj_idx] = pred_masks
# Resize the output mask to the original video resolution (we directly use
# the mask scores on GPU for output to avoid any CPU conversion in between)
if len(pred_masks_per_obj) > 1:
all_pred_masks = torch.cat(pred_masks_per_obj, dim=0)
else:
all_pred_masks = pred_masks_per_obj[0]
_, video_res_masks = self._get_orig_video_res_output(
inference_state, pred_masks
inference_state, all_pred_masks
)
yield frame_idx, obj_ids, video_res_masks
def _add_output_per_object(
self, inference_state, frame_idx, current_out, storage_key
):
"""
Split a multi-object output into per-object output slices and add them into
`output_dict_per_obj`. The resulting slices share the same tensor storage.
"""
maskmem_features = current_out["maskmem_features"]
assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
maskmem_pos_enc = current_out["maskmem_pos_enc"]
assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
output_dict_per_obj = inference_state["output_dict_per_obj"]
for obj_idx, obj_output_dict in output_dict_per_obj.items():
obj_slice = slice(obj_idx, obj_idx + 1)
obj_out = {
"maskmem_features": None,
"maskmem_pos_enc": None,
"pred_masks": current_out["pred_masks"][obj_slice],
"obj_ptr": current_out["obj_ptr"][obj_slice],
"object_score_logits": current_out["object_score_logits"][obj_slice],
}
if maskmem_features is not None:
obj_out["maskmem_features"] = maskmem_features[obj_slice]
if maskmem_pos_enc is not None:
obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
obj_output_dict[storage_key][frame_idx] = obj_out
@torch.inference_mode()
def clear_all_prompts_in_frame(
self, inference_state, frame_idx, obj_id, need_output=True
@ -789,41 +643,14 @@ class SAM2VideoPredictor(SAM2Base):
temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None)
temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None)
# Check and see if there are still any inputs left on this frame
batch_size = self._get_obj_num(inference_state)
frame_has_input = False
for obj_idx2 in range(batch_size):
if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]:
frame_has_input = True
break
if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]:
frame_has_input = True
break
# If this frame has no remaining inputs for any objects, we further clear its
# conditioning frame status
if not frame_has_input:
output_dict = inference_state["output_dict"]
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx)
consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
# Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
out = output_dict["cond_frame_outputs"].pop(frame_idx, None)
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
if out is not None:
# The frame is not a conditioning frame anymore since it's not receiving inputs,
# so we "downgrade" its output (if exists) to a non-conditioning frame output.
output_dict["non_cond_frame_outputs"][frame_idx] = out
inference_state["frames_already_tracked"].pop(frame_idx, None)
# Similarly, do it for the sliced output on each object.
for obj_idx2 in range(batch_size):
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2]
obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
if obj_out is not None:
obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out
# If all the conditioning frames have been removed, we also clear the tracking outputs
if len(output_dict["cond_frame_outputs"]) == 0:
self._reset_tracking_results(inference_state)
obj_output_dict["non_cond_frame_outputs"][frame_idx] = out
inference_state["frames_tracked_per_obj"][obj_idx].pop(frame_idx, None)
if not need_output:
return
@ -837,7 +664,6 @@ class SAM2VideoPredictor(SAM2Base):
inference_state,
frame_idx,
is_cond=is_cond,
run_mem_encoder=False,
consolidate_at_video_res=True,
)
_, video_res_masks = self._get_orig_video_res_output(
@ -857,6 +683,7 @@ class SAM2VideoPredictor(SAM2Base):
inference_state["mask_inputs_per_obj"].clear()
inference_state["output_dict_per_obj"].clear()
inference_state["temp_output_dict_per_obj"].clear()
inference_state["frames_tracked_per_obj"].clear()
def _reset_tracking_results(self, inference_state):
"""Reset all tracking inputs and results across the videos."""
@ -870,12 +697,8 @@ class SAM2VideoPredictor(SAM2Base):
for v in inference_state["temp_output_dict_per_obj"].values():
v["cond_frame_outputs"].clear()
v["non_cond_frame_outputs"].clear()
inference_state["output_dict"]["cond_frame_outputs"].clear()
inference_state["output_dict"]["non_cond_frame_outputs"].clear()
inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
inference_state["tracking_has_started"] = False
inference_state["frames_already_tracked"].clear()
for v in inference_state["frames_tracked_per_obj"].values():
v.clear()
def _get_image_feature(self, inference_state, frame_idx, batch_size):
"""Compute the image features on a given frame."""
@ -1093,8 +916,6 @@ class SAM2VideoPredictor(SAM2Base):
inference_state["obj_ids"] = new_obj_ids
# Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
# (note that "consolidated_frame_inds" doesn't need to be updated in this step as
# it's already handled in Step 0)
def _map_keys(container):
new_kvs = []
for k in old_obj_inds:
@ -1107,30 +928,9 @@ class SAM2VideoPredictor(SAM2Base):
_map_keys(inference_state["mask_inputs_per_obj"])
_map_keys(inference_state["output_dict_per_obj"])
_map_keys(inference_state["temp_output_dict_per_obj"])
_map_keys(inference_state["frames_tracked_per_obj"])
# Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices.
def _slice_state(output_dict, storage_key):
for frame_idx, out in output_dict[storage_key].items():
out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds]
out["maskmem_pos_enc"] = [
x[remain_old_obj_inds] for x in out["maskmem_pos_enc"]
]
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out)
out["pred_masks"] = out["pred_masks"][remain_old_obj_inds]
out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds]
out["object_score_logits"] = out["object_score_logits"][
remain_old_obj_inds
]
# also update the per-object slices
self._add_output_per_object(
inference_state, frame_idx, out, storage_key
)
_slice_state(inference_state["output_dict"], "cond_frame_outputs")
_slice_state(inference_state["output_dict"], "non_cond_frame_outputs")
# Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which
# Step 3: Further collect the outputs on those frames in `obj_input_frames_inds`, which
# could show an updated mask for objects previously occluded by the object being removed
if need_output:
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
@ -1143,7 +943,6 @@ class SAM2VideoPredictor(SAM2Base):
inference_state,
frame_idx,
is_cond=is_cond,
run_mem_encoder=False,
consolidate_at_video_res=True,
)
_, video_res_masks = self._get_orig_video_res_output(
@ -1165,12 +964,12 @@ class SAM2VideoPredictor(SAM2Base):
r = self.memory_temporal_stride_for_eval
frame_idx_begin = frame_idx - r * self.num_maskmem
frame_idx_end = frame_idx + r * self.num_maskmem
output_dict = inference_state["output_dict"]
non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
batch_size = self._get_obj_num(inference_state)
for obj_idx in range(batch_size):
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
non_cond_frame_outputs = obj_output_dict["non_cond_frame_outputs"]
for t in range(frame_idx_begin, frame_idx_end + 1):
non_cond_frame_outputs.pop(t, None)
for obj_output_dict in inference_state["output_dict_per_obj"].values():
obj_output_dict["non_cond_frame_outputs"].pop(t, None)
class SAM2VideoPredictorVOS(SAM2VideoPredictor):

File diff suppressed because it is too large Load Diff