mirror of
https://github.com/facebookresearch/sam2.git
synced 2025-09-18 04:32:48 +08:00
switch to a new implementation of the class SAM2VideoPredictor
for per-object inference sam2/sam2_video_predictor.py
In this PR, we switch to a new implementation of the class `SAM2VideoPredictor` for in sam2/sam2_video_predictor.py, which allows for independent per-object inference. Specifically, the new `SAM2VideoPredictor`: * it handles the inference of each object separately, as if we are opening a separate session for each object * it relaxes the assumption on prompting * previously if a frame receives clicks only for a subset of objects, the rest of (non-prompted) objects are assumed to be non-existing in this frame * now if a frame receives clicks only for a subset of objects, we don't make any assumptions for the remaining (non-prompted) objects * it allows adding new objects after tracking starts * (The previous implementation is backed up to `SAM2VideoPredictor` in sam2/sam2_video_predictor_legacy.py) Also, fix a small typo `APP_URL` => `API_URL` in the doc. Test plan: tested with the predictor notebook `notebooks/video_predictor_example.ipynb` and VOS script `tools/vos_inference.py`. Also tested with the demo.
This commit is contained in:
parent
3297dd0eb0
commit
c61e2475e6
@ -105,7 +105,7 @@ cd demo/backend/server/
|
||||
```bash
|
||||
PYTORCH_ENABLE_MPS_FALLBACK=1 \
|
||||
APP_ROOT="$(pwd)/../../../" \
|
||||
APP_URL=http://localhost:7263 \
|
||||
API_URL=http://localhost:7263 \
|
||||
MODEL_SIZE=base_plus \
|
||||
DATA_PATH="$(pwd)/../../data" \
|
||||
DEFAULT_VIDEO_PATH=gallery/05_default_juggle.mp4 \
|
||||
|
@ -27,8 +27,6 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
# whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
|
||||
# note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True)
|
||||
clear_non_cond_mem_around_input=False,
|
||||
# whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True).
|
||||
clear_non_cond_mem_for_multi_obj=False,
|
||||
# if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
|
||||
# if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames
|
||||
add_all_frames_to_correct_as_cond=False,
|
||||
@ -38,7 +36,6 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
self.fill_hole_area = fill_hole_area
|
||||
self.non_overlap_masks = non_overlap_masks
|
||||
self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
|
||||
self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
|
||||
self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
|
||||
|
||||
@torch.inference_mode()
|
||||
@ -88,11 +85,6 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
inference_state["obj_id_to_idx"] = OrderedDict()
|
||||
inference_state["obj_idx_to_id"] = OrderedDict()
|
||||
inference_state["obj_ids"] = []
|
||||
# A storage to hold the model's tracking results and states on each frame
|
||||
inference_state["output_dict"] = {
|
||||
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
}
|
||||
# Slice (view) of each object tracking results, sharing the same memory with "output_dict"
|
||||
inference_state["output_dict_per_obj"] = {}
|
||||
# A temporary storage to hold new outputs when user interact with a frame
|
||||
@ -100,13 +92,8 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
inference_state["temp_output_dict_per_obj"] = {}
|
||||
# Frames that already holds consolidated outputs from click or mask inputs
|
||||
# (we directly use their consolidated outputs during tracking)
|
||||
inference_state["consolidated_frame_inds"] = {
|
||||
"cond_frame_outputs": set(), # set containing frame indices
|
||||
"non_cond_frame_outputs": set(), # set containing frame indices
|
||||
}
|
||||
# metadata for each tracking frame (e.g. which direction it's tracked)
|
||||
inference_state["tracking_has_started"] = False
|
||||
inference_state["frames_already_tracked"] = {}
|
||||
inference_state["frames_tracked_per_obj"] = {}
|
||||
# Warm up the visual backbone and cache the image feature on frame 0
|
||||
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
|
||||
return inference_state
|
||||
@ -134,9 +121,8 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
if obj_idx is not None:
|
||||
return obj_idx
|
||||
|
||||
# This is a new object id not sent to the server before. We only allow adding
|
||||
# new objects *before* the tracking starts.
|
||||
allow_new_object = not inference_state["tracking_has_started"]
|
||||
# We always allow adding new objects (including after tracking starts).
|
||||
allow_new_object = True
|
||||
if allow_new_object:
|
||||
# get the next object slot
|
||||
obj_idx = len(inference_state["obj_id_to_idx"])
|
||||
@ -154,6 +140,7 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
}
|
||||
inference_state["frames_tracked_per_obj"][obj_idx] = {}
|
||||
return obj_idx
|
||||
else:
|
||||
raise RuntimeError(
|
||||
@ -214,15 +201,6 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
"box prompt must be provided before any point prompt "
|
||||
"(please use clear_old_points=True instead)"
|
||||
)
|
||||
if inference_state["tracking_has_started"]:
|
||||
warnings.warn(
|
||||
"You are adding a box after tracking starts. SAM 2 may not always be "
|
||||
"able to incorporate a box prompt for *refinement*. If you intend to "
|
||||
"use box prompt as an *initial* input before tracking, please call "
|
||||
"'reset_state' on the inference state to restart from scratch.",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if not isinstance(box, torch.Tensor):
|
||||
box = torch.tensor(box, dtype=torch.float32, device=points.device)
|
||||
box_coords = box.reshape(1, 2, 2)
|
||||
@ -252,12 +230,13 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
# frame, meaning that the inputs points are to generate segments on this frame without
|
||||
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
|
||||
# the input points will be used to correct the already tracked masks.
|
||||
is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
|
||||
obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx]
|
||||
is_init_cond_frame = frame_idx not in obj_frames_tracked
|
||||
# whether to track in reverse time order
|
||||
if is_init_cond_frame:
|
||||
reverse = False
|
||||
else:
|
||||
reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
|
||||
reverse = obj_frames_tracked[frame_idx]["reverse"]
|
||||
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
||||
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
||||
# Add a frame to conditioning output if it's an initial conditioning frame or
|
||||
@ -306,7 +285,6 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
inference_state,
|
||||
frame_idx,
|
||||
is_cond=is_cond,
|
||||
run_mem_encoder=False,
|
||||
consolidate_at_video_res=True,
|
||||
)
|
||||
_, video_res_masks = self._get_orig_video_res_output(
|
||||
@ -357,12 +335,13 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
# frame, meaning that the inputs points are to generate segments on this frame without
|
||||
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
|
||||
# the input points will be used to correct the already tracked masks.
|
||||
is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
|
||||
obj_frames_tracked = inference_state["frames_tracked_per_obj"][obj_idx]
|
||||
is_init_cond_frame = frame_idx not in obj_frames_tracked
|
||||
# whether to track in reverse time order
|
||||
if is_init_cond_frame:
|
||||
reverse = False
|
||||
else:
|
||||
reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
|
||||
reverse = obj_frames_tracked[frame_idx]["reverse"]
|
||||
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
||||
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
||||
# Add a frame to conditioning output if it's an initial conditioning frame or
|
||||
@ -394,7 +373,6 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
inference_state,
|
||||
frame_idx,
|
||||
is_cond=is_cond,
|
||||
run_mem_encoder=False,
|
||||
consolidate_at_video_res=True,
|
||||
)
|
||||
_, video_res_masks = self._get_orig_video_res_output(
|
||||
@ -429,7 +407,6 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
inference_state,
|
||||
frame_idx,
|
||||
is_cond,
|
||||
run_mem_encoder,
|
||||
consolidate_at_video_res=False,
|
||||
):
|
||||
"""
|
||||
@ -446,7 +423,6 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
# Optionally, we allow consolidating the temporary outputs at the original
|
||||
# video resolution (to provide a better editing experience for mask prompts).
|
||||
if consolidate_at_video_res:
|
||||
assert not run_mem_encoder, "memory encoder cannot run at video resolution"
|
||||
consolidated_H = inference_state["video_height"]
|
||||
consolidated_W = inference_state["video_width"]
|
||||
consolidated_mask_key = "pred_masks_video_res"
|
||||
@ -459,30 +435,13 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
# constraints to object scores. Its "pred_masks" are prefilled with a large
|
||||
# negative value (NO_OBJ_SCORE) to represent missing objects.
|
||||
consolidated_out = {
|
||||
"maskmem_features": None,
|
||||
"maskmem_pos_enc": None,
|
||||
consolidated_mask_key: torch.full(
|
||||
size=(batch_size, 1, consolidated_H, consolidated_W),
|
||||
fill_value=NO_OBJ_SCORE,
|
||||
dtype=torch.float32,
|
||||
device=inference_state["storage_device"],
|
||||
),
|
||||
"obj_ptr": torch.full(
|
||||
size=(batch_size, self.hidden_dim),
|
||||
fill_value=NO_OBJ_SCORE,
|
||||
dtype=torch.float32,
|
||||
device=inference_state["device"],
|
||||
),
|
||||
"object_score_logits": torch.full(
|
||||
size=(batch_size, 1),
|
||||
# default to 10.0 for object_score_logits, i.e. assuming the object is
|
||||
# present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
|
||||
fill_value=10.0,
|
||||
dtype=torch.float32,
|
||||
device=inference_state["device"],
|
||||
),
|
||||
}
|
||||
empty_mask_ptr = None
|
||||
for obj_idx in range(batch_size):
|
||||
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
||||
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
||||
@ -499,16 +458,6 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
# and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
|
||||
# placeholder above) and set its object pointer to be a dummy pointer.
|
||||
if out is None:
|
||||
# Fill in dummy object pointers for those objects without any inputs or
|
||||
# tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
|
||||
# i.e. when we need to build the memory for tracking).
|
||||
if run_mem_encoder:
|
||||
if empty_mask_ptr is None:
|
||||
empty_mask_ptr = self._get_empty_mask_ptr(
|
||||
inference_state, frame_idx
|
||||
)
|
||||
# fill object pointer with a dummy pointer (based on an empty mask)
|
||||
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
|
||||
continue
|
||||
# Add the temporary object output mask to consolidated output mask
|
||||
obj_mask = out["pred_masks"]
|
||||
@ -524,141 +473,74 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
align_corners=False,
|
||||
)
|
||||
consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
|
||||
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
|
||||
consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[
|
||||
"object_score_logits"
|
||||
]
|
||||
|
||||
# Optionally, apply non-overlapping constraints on the consolidated scores
|
||||
# and rerun the memory encoder
|
||||
if run_mem_encoder:
|
||||
device = inference_state["device"]
|
||||
high_res_masks = torch.nn.functional.interpolate(
|
||||
consolidated_out["pred_masks"].to(device, non_blocking=True),
|
||||
size=(self.image_size, self.image_size),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)
|
||||
if self.non_overlap_masks_for_mem_enc:
|
||||
high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
|
||||
maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
|
||||
inference_state=inference_state,
|
||||
frame_idx=frame_idx,
|
||||
batch_size=batch_size,
|
||||
high_res_masks=high_res_masks,
|
||||
object_score_logits=consolidated_out["object_score_logits"],
|
||||
is_mask_from_pts=True, # these frames are what the user interacted with
|
||||
)
|
||||
consolidated_out["maskmem_features"] = maskmem_features
|
||||
consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
|
||||
|
||||
return consolidated_out
|
||||
|
||||
def _get_empty_mask_ptr(self, inference_state, frame_idx):
|
||||
"""Get a dummy object pointer based on an empty mask on the current frame."""
|
||||
# A dummy (empty) mask with a single object
|
||||
batch_size = 1
|
||||
mask_inputs = torch.zeros(
|
||||
(batch_size, 1, self.image_size, self.image_size),
|
||||
dtype=torch.float32,
|
||||
device=inference_state["device"],
|
||||
)
|
||||
|
||||
# Retrieve correct image features
|
||||
(
|
||||
_,
|
||||
_,
|
||||
current_vision_feats,
|
||||
current_vision_pos_embeds,
|
||||
feat_sizes,
|
||||
) = self._get_image_feature(inference_state, frame_idx, batch_size)
|
||||
|
||||
# Feed the empty mask and image feature above to get a dummy object pointer
|
||||
current_out = self.track_step(
|
||||
frame_idx=frame_idx,
|
||||
is_init_cond_frame=True,
|
||||
current_vision_feats=current_vision_feats,
|
||||
current_vision_pos_embeds=current_vision_pos_embeds,
|
||||
feat_sizes=feat_sizes,
|
||||
point_inputs=None,
|
||||
mask_inputs=mask_inputs,
|
||||
output_dict={},
|
||||
num_frames=inference_state["num_frames"],
|
||||
track_in_reverse=False,
|
||||
run_mem_encoder=False,
|
||||
prev_sam_mask_logits=None,
|
||||
)
|
||||
return current_out["obj_ptr"]
|
||||
|
||||
@torch.inference_mode()
|
||||
def propagate_in_video_preflight(self, inference_state):
|
||||
"""Prepare inference_state and consolidate temporary outputs before tracking."""
|
||||
# Tracking has started and we don't allow adding new objects until session is reset.
|
||||
inference_state["tracking_has_started"] = True
|
||||
# Check and make sure that every object has received input points or masks.
|
||||
batch_size = self._get_obj_num(inference_state)
|
||||
if batch_size == 0:
|
||||
raise RuntimeError(
|
||||
"No input points or masks are provided for any object; please add inputs first."
|
||||
)
|
||||
|
||||
# Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
|
||||
# add them into "output_dict".
|
||||
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
|
||||
output_dict = inference_state["output_dict"]
|
||||
# "consolidated_frame_inds" contains indices of those frames where consolidated
|
||||
# temporary outputs have been added (either in this call or any previous calls
|
||||
# to `propagate_in_video_preflight`).
|
||||
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
|
||||
for is_cond in [False, True]:
|
||||
# Separately consolidate conditioning and non-conditioning temp outputs
|
||||
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
||||
# Find all the frames that contain temporary outputs for any objects
|
||||
# (these should be the frames that have just received clicks for mask inputs
|
||||
# via `add_new_points_or_box` or `add_new_mask`)
|
||||
temp_frame_inds = set()
|
||||
for obj_temp_output_dict in temp_output_dict_per_obj.values():
|
||||
temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
|
||||
consolidated_frame_inds[storage_key].update(temp_frame_inds)
|
||||
# consolidate the temporary output across all objects on this frame
|
||||
for frame_idx in temp_frame_inds:
|
||||
consolidated_out = self._consolidate_temp_output_across_obj(
|
||||
inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
|
||||
for obj_idx in range(batch_size):
|
||||
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
||||
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
||||
for is_cond in [False, True]:
|
||||
# Separately consolidate conditioning and non-conditioning temp outputs
|
||||
storage_key = (
|
||||
"cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
||||
)
|
||||
# merge them into "output_dict" and also create per-object slices
|
||||
output_dict[storage_key][frame_idx] = consolidated_out
|
||||
self._add_output_per_object(
|
||||
inference_state, frame_idx, consolidated_out, storage_key
|
||||
)
|
||||
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
|
||||
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
|
||||
)
|
||||
if clear_non_cond_mem:
|
||||
# clear non-conditioning memory of the surrounding frames
|
||||
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
|
||||
# Find all the frames that contain temporary outputs for any objects
|
||||
# (these should be the frames that have just received clicks for mask inputs
|
||||
# via `add_new_points_or_box` or `add_new_mask`)
|
||||
for frame_idx, out in obj_temp_output_dict[storage_key].items():
|
||||
# Run memory encoder on the temporary outputs (if the memory feature is missing)
|
||||
if out["maskmem_features"] is None:
|
||||
high_res_masks = torch.nn.functional.interpolate(
|
||||
out["pred_masks"].to(inference_state["device"]),
|
||||
size=(self.image_size, self.image_size),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)
|
||||
maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
|
||||
inference_state=inference_state,
|
||||
frame_idx=frame_idx,
|
||||
batch_size=1, # run on the slice of a single object
|
||||
high_res_masks=high_res_masks,
|
||||
object_score_logits=out["object_score_logits"],
|
||||
# these frames are what the user interacted with
|
||||
is_mask_from_pts=True,
|
||||
)
|
||||
out["maskmem_features"] = maskmem_features
|
||||
out["maskmem_pos_enc"] = maskmem_pos_enc
|
||||
|
||||
# clear temporary outputs in `temp_output_dict_per_obj`
|
||||
for obj_temp_output_dict in temp_output_dict_per_obj.values():
|
||||
obj_output_dict[storage_key][frame_idx] = out
|
||||
if self.clear_non_cond_mem_around_input:
|
||||
# clear non-conditioning memory of the surrounding frames
|
||||
self._clear_obj_non_cond_mem_around_input(
|
||||
inference_state, frame_idx, obj_idx
|
||||
)
|
||||
|
||||
# clear temporary outputs in `temp_output_dict_per_obj`
|
||||
obj_temp_output_dict[storage_key].clear()
|
||||
|
||||
# edge case: if an output is added to "cond_frame_outputs", we remove any prior
|
||||
# output on the same frame in "non_cond_frame_outputs"
|
||||
for frame_idx in output_dict["cond_frame_outputs"]:
|
||||
output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
|
||||
for obj_output_dict in inference_state["output_dict_per_obj"].values():
|
||||
# check and make sure that every object has received input points or masks
|
||||
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
||||
if len(obj_output_dict["cond_frame_outputs"]) == 0:
|
||||
obj_id = self._obj_idx_to_id(inference_state, obj_idx)
|
||||
raise RuntimeError(
|
||||
f"No input points or masks are provided for object id {obj_id}; please add inputs first."
|
||||
)
|
||||
# edge case: if an output is added to "cond_frame_outputs", we remove any prior
|
||||
# output on the same frame in "non_cond_frame_outputs"
|
||||
for frame_idx in obj_output_dict["cond_frame_outputs"]:
|
||||
obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
|
||||
for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
|
||||
assert frame_idx in output_dict["cond_frame_outputs"]
|
||||
consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
|
||||
|
||||
# Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
|
||||
# with either points or mask inputs (which should be true under a correct workflow).
|
||||
all_consolidated_frame_inds = (
|
||||
consolidated_frame_inds["cond_frame_outputs"]
|
||||
| consolidated_frame_inds["non_cond_frame_outputs"]
|
||||
)
|
||||
input_frames_inds = set()
|
||||
for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
|
||||
input_frames_inds.update(point_inputs_per_frame.keys())
|
||||
for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
|
||||
input_frames_inds.update(mask_inputs_per_frame.keys())
|
||||
assert all_consolidated_frame_inds == input_frames_inds
|
||||
|
||||
@torch.inference_mode()
|
||||
def propagate_in_video(
|
||||
@ -671,21 +553,18 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
"""Propagate the input points across frames to track in the entire video."""
|
||||
self.propagate_in_video_preflight(inference_state)
|
||||
|
||||
output_dict = inference_state["output_dict"]
|
||||
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
|
||||
obj_ids = inference_state["obj_ids"]
|
||||
num_frames = inference_state["num_frames"]
|
||||
batch_size = self._get_obj_num(inference_state)
|
||||
if len(output_dict["cond_frame_outputs"]) == 0:
|
||||
raise RuntimeError("No points are provided; please add points first")
|
||||
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
|
||||
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
|
||||
)
|
||||
|
||||
# set start index, end index, and processing order
|
||||
if start_frame_idx is None:
|
||||
# default: start from the earliest frame with input points
|
||||
start_frame_idx = min(output_dict["cond_frame_outputs"])
|
||||
start_frame_idx = min(
|
||||
t
|
||||
for obj_output_dict in inference_state["output_dict_per_obj"].values()
|
||||
for t in obj_output_dict["cond_frame_outputs"]
|
||||
)
|
||||
if max_frame_num_to_track is None:
|
||||
# default: track all the frames in the video
|
||||
max_frame_num_to_track = num_frames
|
||||
@ -702,78 +581,53 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
processing_order = range(start_frame_idx, end_frame_idx + 1)
|
||||
|
||||
for frame_idx in tqdm(processing_order, desc="propagate in video"):
|
||||
# We skip those frames already in consolidated outputs (these are frames
|
||||
# that received input clicks or mask). Note that we cannot directly run
|
||||
# batched forward on them via `_run_single_frame_inference` because the
|
||||
# number of clicks on each object might be different.
|
||||
if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
|
||||
storage_key = "cond_frame_outputs"
|
||||
current_out = output_dict[storage_key][frame_idx]
|
||||
pred_masks = current_out["pred_masks"]
|
||||
if clear_non_cond_mem:
|
||||
# clear non-conditioning memory of the surrounding frames
|
||||
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
|
||||
elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
|
||||
storage_key = "non_cond_frame_outputs"
|
||||
current_out = output_dict[storage_key][frame_idx]
|
||||
pred_masks = current_out["pred_masks"]
|
||||
else:
|
||||
storage_key = "non_cond_frame_outputs"
|
||||
current_out, pred_masks = self._run_single_frame_inference(
|
||||
inference_state=inference_state,
|
||||
output_dict=output_dict,
|
||||
frame_idx=frame_idx,
|
||||
batch_size=batch_size,
|
||||
is_init_cond_frame=False,
|
||||
point_inputs=None,
|
||||
mask_inputs=None,
|
||||
reverse=reverse,
|
||||
run_mem_encoder=True,
|
||||
)
|
||||
output_dict[storage_key][frame_idx] = current_out
|
||||
# Create slices of per-object outputs for subsequent interaction with each
|
||||
# individual object after tracking.
|
||||
self._add_output_per_object(
|
||||
inference_state, frame_idx, current_out, storage_key
|
||||
)
|
||||
inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}
|
||||
pred_masks_per_obj = [None] * batch_size
|
||||
for obj_idx in range(batch_size):
|
||||
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
||||
# We skip those frames already in consolidated outputs (these are frames
|
||||
# that received input clicks or mask). Note that we cannot directly run
|
||||
# batched forward on them via `_run_single_frame_inference` because the
|
||||
# number of clicks on each object might be different.
|
||||
if frame_idx in obj_output_dict["cond_frame_outputs"]:
|
||||
storage_key = "cond_frame_outputs"
|
||||
current_out = obj_output_dict[storage_key][frame_idx]
|
||||
pred_masks = current_out["pred_masks"]
|
||||
if self.clear_non_cond_mem_around_input:
|
||||
# clear non-conditioning memory of the surrounding frames
|
||||
self._clear_obj_non_cond_mem_around_input(
|
||||
inference_state, frame_idx, obj_idx
|
||||
)
|
||||
else:
|
||||
storage_key = "non_cond_frame_outputs"
|
||||
current_out, pred_masks = self._run_single_frame_inference(
|
||||
inference_state=inference_state,
|
||||
output_dict=obj_output_dict,
|
||||
frame_idx=frame_idx,
|
||||
batch_size=1, # run on the slice of a single object
|
||||
is_init_cond_frame=False,
|
||||
point_inputs=None,
|
||||
mask_inputs=None,
|
||||
reverse=reverse,
|
||||
run_mem_encoder=True,
|
||||
)
|
||||
obj_output_dict[storage_key][frame_idx] = current_out
|
||||
|
||||
inference_state["frames_tracked_per_obj"][obj_idx][frame_idx] = {
|
||||
"reverse": reverse
|
||||
}
|
||||
pred_masks_per_obj[obj_idx] = pred_masks
|
||||
|
||||
# Resize the output mask to the original video resolution (we directly use
|
||||
# the mask scores on GPU for output to avoid any CPU conversion in between)
|
||||
if len(pred_masks_per_obj) > 1:
|
||||
all_pred_masks = torch.cat(pred_masks_per_obj, dim=0)
|
||||
else:
|
||||
all_pred_masks = pred_masks_per_obj[0]
|
||||
_, video_res_masks = self._get_orig_video_res_output(
|
||||
inference_state, pred_masks
|
||||
inference_state, all_pred_masks
|
||||
)
|
||||
yield frame_idx, obj_ids, video_res_masks
|
||||
|
||||
def _add_output_per_object(
|
||||
self, inference_state, frame_idx, current_out, storage_key
|
||||
):
|
||||
"""
|
||||
Split a multi-object output into per-object output slices and add them into
|
||||
`output_dict_per_obj`. The resulting slices share the same tensor storage.
|
||||
"""
|
||||
maskmem_features = current_out["maskmem_features"]
|
||||
assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
|
||||
|
||||
maskmem_pos_enc = current_out["maskmem_pos_enc"]
|
||||
assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
|
||||
|
||||
output_dict_per_obj = inference_state["output_dict_per_obj"]
|
||||
for obj_idx, obj_output_dict in output_dict_per_obj.items():
|
||||
obj_slice = slice(obj_idx, obj_idx + 1)
|
||||
obj_out = {
|
||||
"maskmem_features": None,
|
||||
"maskmem_pos_enc": None,
|
||||
"pred_masks": current_out["pred_masks"][obj_slice],
|
||||
"obj_ptr": current_out["obj_ptr"][obj_slice],
|
||||
"object_score_logits": current_out["object_score_logits"][obj_slice],
|
||||
}
|
||||
if maskmem_features is not None:
|
||||
obj_out["maskmem_features"] = maskmem_features[obj_slice]
|
||||
if maskmem_pos_enc is not None:
|
||||
obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
|
||||
obj_output_dict[storage_key][frame_idx] = obj_out
|
||||
|
||||
@torch.inference_mode()
|
||||
def clear_all_prompts_in_frame(
|
||||
self, inference_state, frame_idx, obj_id, need_output=True
|
||||
@ -789,41 +643,14 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None)
|
||||
temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None)
|
||||
|
||||
# Check and see if there are still any inputs left on this frame
|
||||
batch_size = self._get_obj_num(inference_state)
|
||||
frame_has_input = False
|
||||
for obj_idx2 in range(batch_size):
|
||||
if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]:
|
||||
frame_has_input = True
|
||||
break
|
||||
if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]:
|
||||
frame_has_input = True
|
||||
break
|
||||
|
||||
# If this frame has no remaining inputs for any objects, we further clear its
|
||||
# conditioning frame status
|
||||
if not frame_has_input:
|
||||
output_dict = inference_state["output_dict"]
|
||||
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
|
||||
consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx)
|
||||
consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
|
||||
# Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
|
||||
out = output_dict["cond_frame_outputs"].pop(frame_idx, None)
|
||||
if out is not None:
|
||||
# The frame is not a conditioning frame anymore since it's not receiving inputs,
|
||||
# so we "downgrade" its output (if exists) to a non-conditioning frame output.
|
||||
output_dict["non_cond_frame_outputs"][frame_idx] = out
|
||||
inference_state["frames_already_tracked"].pop(frame_idx, None)
|
||||
# Similarly, do it for the sliced output on each object.
|
||||
for obj_idx2 in range(batch_size):
|
||||
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2]
|
||||
obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
|
||||
if obj_out is not None:
|
||||
obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out
|
||||
|
||||
# If all the conditioning frames have been removed, we also clear the tracking outputs
|
||||
if len(output_dict["cond_frame_outputs"]) == 0:
|
||||
self._reset_tracking_results(inference_state)
|
||||
# Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
|
||||
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
||||
out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
|
||||
if out is not None:
|
||||
# The frame is not a conditioning frame anymore since it's not receiving inputs,
|
||||
# so we "downgrade" its output (if exists) to a non-conditioning frame output.
|
||||
obj_output_dict["non_cond_frame_outputs"][frame_idx] = out
|
||||
inference_state["frames_tracked_per_obj"][obj_idx].pop(frame_idx, None)
|
||||
|
||||
if not need_output:
|
||||
return
|
||||
@ -837,7 +664,6 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
inference_state,
|
||||
frame_idx,
|
||||
is_cond=is_cond,
|
||||
run_mem_encoder=False,
|
||||
consolidate_at_video_res=True,
|
||||
)
|
||||
_, video_res_masks = self._get_orig_video_res_output(
|
||||
@ -857,6 +683,7 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
inference_state["mask_inputs_per_obj"].clear()
|
||||
inference_state["output_dict_per_obj"].clear()
|
||||
inference_state["temp_output_dict_per_obj"].clear()
|
||||
inference_state["frames_tracked_per_obj"].clear()
|
||||
|
||||
def _reset_tracking_results(self, inference_state):
|
||||
"""Reset all tracking inputs and results across the videos."""
|
||||
@ -870,12 +697,8 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
for v in inference_state["temp_output_dict_per_obj"].values():
|
||||
v["cond_frame_outputs"].clear()
|
||||
v["non_cond_frame_outputs"].clear()
|
||||
inference_state["output_dict"]["cond_frame_outputs"].clear()
|
||||
inference_state["output_dict"]["non_cond_frame_outputs"].clear()
|
||||
inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
|
||||
inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
|
||||
inference_state["tracking_has_started"] = False
|
||||
inference_state["frames_already_tracked"].clear()
|
||||
for v in inference_state["frames_tracked_per_obj"].values():
|
||||
v.clear()
|
||||
|
||||
def _get_image_feature(self, inference_state, frame_idx, batch_size):
|
||||
"""Compute the image features on a given frame."""
|
||||
@ -1093,8 +916,6 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
inference_state["obj_ids"] = new_obj_ids
|
||||
|
||||
# Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
|
||||
# (note that "consolidated_frame_inds" doesn't need to be updated in this step as
|
||||
# it's already handled in Step 0)
|
||||
def _map_keys(container):
|
||||
new_kvs = []
|
||||
for k in old_obj_inds:
|
||||
@ -1107,30 +928,9 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
_map_keys(inference_state["mask_inputs_per_obj"])
|
||||
_map_keys(inference_state["output_dict_per_obj"])
|
||||
_map_keys(inference_state["temp_output_dict_per_obj"])
|
||||
_map_keys(inference_state["frames_tracked_per_obj"])
|
||||
|
||||
# Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices.
|
||||
def _slice_state(output_dict, storage_key):
|
||||
for frame_idx, out in output_dict[storage_key].items():
|
||||
out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds]
|
||||
out["maskmem_pos_enc"] = [
|
||||
x[remain_old_obj_inds] for x in out["maskmem_pos_enc"]
|
||||
]
|
||||
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
|
||||
out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out)
|
||||
out["pred_masks"] = out["pred_masks"][remain_old_obj_inds]
|
||||
out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds]
|
||||
out["object_score_logits"] = out["object_score_logits"][
|
||||
remain_old_obj_inds
|
||||
]
|
||||
# also update the per-object slices
|
||||
self._add_output_per_object(
|
||||
inference_state, frame_idx, out, storage_key
|
||||
)
|
||||
|
||||
_slice_state(inference_state["output_dict"], "cond_frame_outputs")
|
||||
_slice_state(inference_state["output_dict"], "non_cond_frame_outputs")
|
||||
|
||||
# Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which
|
||||
# Step 3: Further collect the outputs on those frames in `obj_input_frames_inds`, which
|
||||
# could show an updated mask for objects previously occluded by the object being removed
|
||||
if need_output:
|
||||
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
|
||||
@ -1143,7 +943,6 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
inference_state,
|
||||
frame_idx,
|
||||
is_cond=is_cond,
|
||||
run_mem_encoder=False,
|
||||
consolidate_at_video_res=True,
|
||||
)
|
||||
_, video_res_masks = self._get_orig_video_res_output(
|
||||
@ -1165,12 +964,12 @@ class SAM2VideoPredictor(SAM2Base):
|
||||
r = self.memory_temporal_stride_for_eval
|
||||
frame_idx_begin = frame_idx - r * self.num_maskmem
|
||||
frame_idx_end = frame_idx + r * self.num_maskmem
|
||||
output_dict = inference_state["output_dict"]
|
||||
non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
|
||||
for t in range(frame_idx_begin, frame_idx_end + 1):
|
||||
non_cond_frame_outputs.pop(t, None)
|
||||
for obj_output_dict in inference_state["output_dict_per_obj"].values():
|
||||
obj_output_dict["non_cond_frame_outputs"].pop(t, None)
|
||||
batch_size = self._get_obj_num(inference_state)
|
||||
for obj_idx in range(batch_size):
|
||||
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
||||
non_cond_frame_outputs = obj_output_dict["non_cond_frame_outputs"]
|
||||
for t in range(frame_idx_begin, frame_idx_end + 1):
|
||||
non_cond_frame_outputs.pop(t, None)
|
||||
|
||||
|
||||
class SAM2VideoPredictorVOS(SAM2VideoPredictor):
|
||||
|
1172
sam2/sam2_video_predictor_legacy.py
Normal file
1172
sam2/sam2_video_predictor_legacy.py
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user