Mirror of https://github.com/facebookresearch/sam2.git (synced 2025-09-18 12:42:48 +08:00)

Merge pull request #153 from fairinternal/chay/improve_speed_v1
speed optimizations cleanup

Commit 3297dd0eb0

README.md (19 lines changed)
@@ -158,10 +158,10 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
 The table below shows the improved SAM 2.1 checkpoints released on September 29, 2024.
 | **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** |
 | :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: |
-| sam2.1_hiera_tiny <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt)) | 38.9 | 47.2 | 76.5 | 71.8 | 77.3 |
-| sam2.1_hiera_small <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt)) | 46 | 43.3 (53.0 compiled\*) | 76.6 | 73.5 | 78.3 |
-| sam2.1_hiera_base_plus <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt)) | 80.8 | 34.8 (43.8 compiled\*) | 78.2 | 73.7 | 78.2 |
-| sam2.1_hiera_large <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt)) | 224.4 | 24.2 (30.2 compiled\*) | 79.5 | 74.6 | 80.6 |
+| sam2.1_hiera_tiny <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt)) | 38.9 | 91.2 | 76.5 | 71.8 | 77.3 |
+| sam2.1_hiera_small <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt)) | 46 | 84.8 | 76.6 | 73.5 | 78.3 |
+| sam2.1_hiera_base_plus <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt)) | 80.8 | 64.1 | 78.2 | 73.7 | 78.2 |
+| sam2.1_hiera_large <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt)) | 224.4 | 39.5 | 79.5 | 74.6 | 80.6 |
 
 ### SAM 2 checkpoints
 
@@ -169,13 +169,12 @@ The previous SAM 2 checkpoints released on July 29, 2024 can be found as follows
 
 | **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** |
 | :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: |
-| sam2_hiera_tiny <br /> ([config](sam2/configs/sam2/sam2_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt)) | 38.9 | 47.2 | 75.0 | 70.9 | 75.3 |
-| sam2_hiera_small <br /> ([config](sam2/configs/sam2/sam2_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt)) | 46 | 43.3 (53.0 compiled\*) | 74.9 | 71.5 | 76.4 |
-| sam2_hiera_base_plus <br /> ([config](sam2/configs/sam2/sam2_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt)) | 80.8 | 34.8 (43.8 compiled\*) | 74.7 | 72.8 | 75.8 |
-| sam2_hiera_large <br /> ([config](sam2/configs/sam2/sam2_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt)) | 224.4 | 24.2 (30.2 compiled\*) | 76.0 | 74.6 | 79.8 |
+| sam2_hiera_tiny <br /> ([config](sam2/configs/sam2/sam2_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt)) | 38.9 | 91.5 | 75.0 | 70.9 | 75.3 |
+| sam2_hiera_small <br /> ([config](sam2/configs/sam2/sam2_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt)) | 46 | 85.6 | 74.9 | 71.5 | 76.4 |
+| sam2_hiera_base_plus <br /> ([config](sam2/configs/sam2/sam2_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt)) | 80.8 | 64.8 | 74.7 | 72.8 | 75.8 |
+| sam2_hiera_large <br /> ([config](sam2/configs/sam2/sam2_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt)) | 224.4 | 39.7 | 76.0 | 74.6 | 79.8 |
 
-\* Compile the model by setting `compile_image_encoder: True` in the config.
+Speed measured on an A100 with `torch 2.5.1, cuda 12.4`. See `benchmark.py` for an example on benchmarking (compiling all the model components). Compiling only the image encoder can be more flexible and also provide (a smaller) speed-up (set `compile_image_encoder: True` in the config).
 
 ## Segment Anything Video Dataset
 
 See [sav_dataset/README.md](sav_dataset/README.md) for details.
sam2/benchmark.py (new file, 92 lines)

@@ -0,0 +1,92 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import time
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from sam2.build_sam import build_sam2_video_predictor
+
+# Only cuda supported
+assert torch.cuda.is_available()
+device = torch.device("cuda")
+
+torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+if torch.cuda.get_device_properties(0).major >= 8:
+    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+
+# Config and checkpoint
+sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt"
+model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
+
+# Build video predictor with vos_optimized=True setting
+predictor = build_sam2_video_predictor(
+    model_cfg, sam2_checkpoint, device=device, vos_optimized=True
+)
+
+
+# Initialize with video
+video_dir = "notebooks/videos/bedroom"
+# scan all the JPEG frame names in this directory
+frame_names = [
+    p
+    for p in os.listdir(video_dir)
+    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
+]
+frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
+inference_state = predictor.init_state(video_path=video_dir)
+
+
+# Number of runs, warmup etc
+warm_up, runs = 5, 25
+verbose = True
+num_frames = len(frame_names)
+total, count = 0, 0
+torch.cuda.empty_cache()
+
+# We will select an object with a click.
+# See video_predictor_example.ipynb for more detailed explanation
+ann_frame_idx, ann_obj_id = 0, 1
+# Add a positive click at (x, y) = (210, 350)
+# For labels, `1` means positive click
+points = np.array([[210, 350]], dtype=np.float32)
+labels = np.array([1], np.int32)
+
+_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
+    inference_state=inference_state,
+    frame_idx=ann_frame_idx,
+    obj_id=ann_obj_id,
+    points=points,
+    labels=labels,
+)
+
+# Warmup and then average FPS over several runs
+with torch.autocast("cuda", torch.bfloat16):
+    with torch.inference_mode():
+        for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"):
+            start = time.time()
+            # Start tracking
+            for (
+                out_frame_idx,
+                out_obj_ids,
+                out_mask_logits,
+            ) in predictor.propagate_in_video(inference_state):
+                pass
+
+            end = time.time()
+            total += end - start
+            count += 1
+            if i == warm_up - 1:
+                print("Warmup FPS: ", count * num_frames / total)
+                total = 0
+                count = 0
+
+print("FPS: ", count * num_frames / total)
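The throughput printed at the end is simply frames processed divided by wall-clock seconds, with the counters reset once the warm-up runs finish. A minimal sketch of that accounting, not part of the PR, using made-up numbers purely for illustration:

# Hypothetical numbers: 20 timed runs over a 200-frame clip, 50 s total
count, num_frames, total = 20, 200, 50.0
fps = count * num_frames / total  # 20 * 200 / 50 = 80.0 frames per second
print("FPS:", fps)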
@@ -104,11 +104,18 @@ def build_sam2_video_predictor(
     mode="eval",
     hydra_overrides_extra=[],
     apply_postprocessing=True,
+    vos_optimized=False,
     **kwargs,
 ):
     hydra_overrides = [
         "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
     ]
+    if vos_optimized:
+        hydra_overrides = [
+            "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictorVOS",
+            "++model.compile_image_encoder=True",  # Let sam2_base handle this
+        ]
+
     if apply_postprocessing:
         hydra_overrides_extra = hydra_overrides_extra.copy()
         hydra_overrides_extra += [
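For reference, a hedged usage sketch of the new switch (not part of the diff; it assumes a CUDA device and reuses the checkpoint/config paths from sam2/benchmark.py): passing vos_optimized=True swaps the predictor target to SAM2VideoPredictorVOS and turns on image-encoder compilation, while the existing hydra_overrides_extra route can still compile only the image encoder on the standard predictor.

from sam2.build_sam import build_sam2_video_predictor

checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"

# Fully compiled, VOS-optimized predictor (uses SAM2VideoPredictorVOS)
predictor = build_sam2_video_predictor(
    model_cfg, checkpoint, device="cuda", vos_optimized=True
)

# Alternative: standard predictor with only the image encoder compiled
predictor_img_enc_only = build_sam2_video_predictor(
    model_cfg,
    checkpoint,
    device="cuda",
    hydra_overrides_extra=["++model.compile_image_encoder=True"],
)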
@@ -36,7 +36,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -47,7 +47,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1
@@ -40,7 +40,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -51,7 +51,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1
@@ -39,7 +39,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -50,7 +50,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1
@@ -39,7 +39,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -50,7 +50,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1
@@ -97,7 +97,7 @@ trainer:
           self_attention:
             _target_: sam2.modeling.sam.transformer.RoPEAttention
            rope_theta: 10000.0
-            feat_sizes: [32, 32]
+            feat_sizes: [64, 64]
            embedding_dim: 256
            num_heads: 1
            downsample_rate: 1
@@ -108,7 +108,7 @@ trainer:
           cross_attention:
             _target_: sam2.modeling.sam.transformer.RoPEAttention
            rope_theta: 10000.0
-            feat_sizes: [32, 32]
+            feat_sizes: [64, 64]
            rope_k_repeat: True
            embedding_dim: 256
            num_heads: 1
@@ -36,7 +36,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -47,7 +47,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1
@@ -40,7 +40,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -51,7 +51,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1
@@ -39,7 +39,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -50,7 +50,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1
@@ -39,7 +39,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -50,7 +50,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1
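The [32, 32] to [64, 64] bump in these configs matches the stride-16 feature map at the 1024-pixel input resolution the models run at; the same change to the RoPEAttention default appears later in this diff, where the comment is updated from "512 resolution" to "1024 resolution". A quick check of the arithmetic, not part of the PR:

image_size, stride = 1024, 16
feat_hw = image_size // stride          # 1024 // 16 = 64
assert (feat_hw, feat_hw) == (64, 64)   # the old [32, 32] matched 512 // 16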
@@ -32,9 +32,7 @@ def window_partition(x, window_size):
     Hp, Wp = H + pad_h, W + pad_w
 
     x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
-    windows = (
-        x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
-    )
+    windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
     return windows, (Hp, Wp)
 
 
@@ -52,13 +50,13 @@ def window_unpartition(windows, window_size, pad_hw, hw):
     Hp, Wp = pad_hw
     H, W = hw
     B = windows.shape[0] // (Hp * Wp // window_size // window_size)
-    x = windows.view(
+    x = windows.reshape(
         B, Hp // window_size, Wp // window_size, window_size, window_size, -1
     )
-    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)
 
     if Hp > H or Wp > W:
-        x = x[:, :H, :W, :].contiguous()
+        x = x[:, :H, :W, :]
     return x
 
 
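Both window helpers now use reshape in place of contiguous().view(...); reshape copies only when the layout actually requires it, which keeps the traced graph simpler for torch.compile (that reading is inferred, the diff itself gives no rationale). A minimal check, not part of the PR, that the two forms agree on a permuted, non-contiguous tensor:

import torch

x = torch.arange(2 * 4 * 4 * 3, dtype=torch.float32).reshape(2, 4, 4, 3)
p = x.permute(0, 2, 1, 3)            # non-contiguous view
old = p.contiguous().view(-1, 4, 3)  # previous style
new = p.reshape(-1, 4, 3)            # new style, copies only if needed
assert torch.equal(old, new)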
@@ -25,6 +25,11 @@ class PositionEmbeddingSine(nn.Module):
         temperature: int = 10000,
         normalize: bool = True,
         scale: Optional[float] = None,
+        # Following settings only relevant
+        # for warming up cache for compilation
+        warmup_cache: bool = True,
+        image_size: int = 1024,
+        strides: Tuple[int] = (4, 8, 16, 32),
     ):
         super().__init__()
         assert num_pos_feats % 2 == 0, "Expecting even model width"
@@ -38,6 +43,12 @@ class PositionEmbeddingSine(nn.Module):
         self.scale = scale
 
         self.cache = {}
+        if warmup_cache and torch.cuda.is_available():
+            # Warmup cache for cuda, to help with compilation
+            device = torch.device("cuda")
+            for stride in strides:
+                cache_key = (image_size // stride, image_size // stride)
+                self._pe(1, device, *cache_key)
 
     def _encode_xy(self, x, y):
         # The positions are expected to be normalized
@@ -76,19 +87,20 @@ class PositionEmbeddingSine(nn.Module):
         return pos
 
     @torch.no_grad()
-    def forward(self, x: torch.Tensor):
-        cache_key = (x.shape[-2], x.shape[-1])
+    def _pe(self, B, device, *cache_key):
+        H, W = cache_key
         if cache_key in self.cache:
-            return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
+            return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)
+
         y_embed = (
-            torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
+            torch.arange(1, H + 1, dtype=torch.float32, device=device)
             .view(1, -1, 1)
-            .repeat(x.shape[0], 1, x.shape[-1])
+            .repeat(B, 1, W)
         )
         x_embed = (
-            torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
+            torch.arange(1, W + 1, dtype=torch.float32, device=device)
             .view(1, 1, -1)
-            .repeat(x.shape[0], x.shape[-2], 1)
+            .repeat(B, H, 1)
         )
 
         if self.normalize:
@@ -96,7 +108,7 @@ class PositionEmbeddingSine(nn.Module):
             y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
             x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
 
-        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device)
         dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
 
         pos_x = x_embed[:, :, :, None] / dim_t
@@ -111,6 +123,12 @@ class PositionEmbeddingSine(nn.Module):
         self.cache[cache_key] = pos[0]
         return pos
 
+    @torch.no_grad()
+    def forward(self, x: torch.Tensor):
+        B = x.shape[0]
+        cache_key = (x.shape[-2], x.shape[-1])
+        return self._pe(B, x.device, *cache_key)
+
 
 class PositionEmbeddingRandom(nn.Module):
     """
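With these changes the constructor pre-populates the (H, W)-keyed cache for every feature resolution the model will actually query, so forward becomes a cache lookup and torch.compile never has to trace the arange-based construction. Under the new defaults the warmed keys are, as a quick sketch (not part of the PR):

image_size, strides = 1024, (4, 8, 16, 32)
cache_keys = [(image_size // s, image_size // s) for s in strides]
print(cache_keys)  # [(256, 256), (128, 128), (64, 64), (32, 32)]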
@@ -92,12 +92,32 @@ class PromptEncoder(nn.Module):
         point_embedding = self.pe_layer.forward_with_coords(
             points, self.input_image_size
         )
-        point_embedding[labels == -1] = 0.0
-        point_embedding[labels == -1] += self.not_a_point_embed.weight
-        point_embedding[labels == 0] += self.point_embeddings[0].weight
-        point_embedding[labels == 1] += self.point_embeddings[1].weight
-        point_embedding[labels == 2] += self.point_embeddings[2].weight
-        point_embedding[labels == 3] += self.point_embeddings[3].weight
+
+        point_embedding = torch.where(
+            (labels == -1).unsqueeze(-1),
+            torch.zeros_like(point_embedding) + self.not_a_point_embed.weight,
+            point_embedding,
+        )
+        point_embedding = torch.where(
+            (labels == 0).unsqueeze(-1),
+            point_embedding + self.point_embeddings[0].weight,
+            point_embedding,
+        )
+        point_embedding = torch.where(
+            (labels == 1).unsqueeze(-1),
+            point_embedding + self.point_embeddings[1].weight,
+            point_embedding,
+        )
+        point_embedding = torch.where(
+            (labels == 2).unsqueeze(-1),
+            point_embedding + self.point_embeddings[2].weight,
+            point_embedding,
+        )
+        point_embedding = torch.where(
+            (labels == 3).unsqueeze(-1),
+            point_embedding + self.point_embeddings[3].weight,
+            point_embedding,
+        )
         return point_embedding
 
     def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
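The masked in-place additions are rewritten as out-of-place torch.where selects, which avoids data-dependent boolean indexing and keeps the prompt encoder traceable with fullgraph=True (again an inferred rationale, not stated in the diff). A small equivalence check, not part of the PR:

import torch

emb = torch.randn(2, 3, 4)
labels = torch.tensor([[1, 0, -1], [0, 1, 1]])
w = torch.randn(4)

old = emb.clone()
old[labels == 1] += w  # previous style: masked in-place add

new = torch.where((labels == 1).unsqueeze(-1), emb + w, emb)
assert torch.allclose(old, new)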
@@ -4,9 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-import contextlib
 import math
-import warnings
 from functools import partial
 from typing import Tuple, Type
 
@@ -16,29 +14,6 @@ from torch import nn, Tensor
 
 from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
 from sam2.modeling.sam2_utils import MLP
-from sam2.utils.misc import get_sdpa_settings
-
-warnings.simplefilter(action="ignore", category=FutureWarning)
-# Check whether Flash Attention is available (and use it by default)
-OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
-# A fallback setting to allow all available kernels if Flash Attention fails
-ALLOW_ALL_KERNELS = False
-
-
-def sdp_kernel_context(dropout_p):
-    """
-    Get the context for the attention scaled dot-product kernel. We use Flash Attention
-    by default, but fall back to all available kernels if Flash Attention fails.
-    """
-    if ALLOW_ALL_KERNELS:
-        return contextlib.nullcontext()
-
-    return torch.backends.cuda.sdp_kernel(
-        enable_flash=USE_FLASH_ATTN,
-        # if Flash attention kernel is off, then math kernel needs to be enabled
-        enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
-        enable_mem_efficient=OLD_GPU,
-    )
-
 
 class TwoWayTransformer(nn.Module):
@@ -265,19 +240,6 @@ class Attention(nn.Module):
 
         dropout_p = self.dropout_p if self.training else 0.0
         # Attention
-        try:
-            with sdp_kernel_context(dropout_p):
-                out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
-        except Exception as e:
-            # Fall back to all kernels if the Flash attention kernel fails
-            warnings.warn(
-                f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
-                f"kernels for scaled_dot_product_attention (which may have a slower speed).",
-                category=UserWarning,
-                stacklevel=2,
-            )
-            global ALLOW_ALL_KERNELS
-            ALLOW_ALL_KERNELS = True
-            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
 
         out = self._recombine_heads(out)
@@ -296,7 +258,7 @@ class RoPEAttention(Attention):
         # whether to repeat q rope to match k length
         # this is needed for cross-attention to memories
         rope_k_repeat=False,
-        feat_sizes=(32, 32),  # [w, h] for stride 16 feats at 512 resolution
+        feat_sizes=(64, 64),  # [w, h] for stride 16 feats at 1024 resolution
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
@@ -305,7 +267,9 @@ class RoPEAttention(Attention):
             compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
         )
         freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
-        self.freqs_cis = freqs_cis
+        self.freqs_cis = (
+            freqs_cis.to("cuda") if torch.cuda.is_available() else freqs_cis
+        )
         self.rope_k_repeat = rope_k_repeat
 
     def forward(
@@ -339,19 +303,6 @@ class RoPEAttention(Attention):
 
         dropout_p = self.dropout_p if self.training else 0.0
         # Attention
-        try:
-            with sdp_kernel_context(dropout_p):
-                out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
-        except Exception as e:
-            # Fall back to all kernels if the Flash attention kernel fails
-            warnings.warn(
-                f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
-                f"kernels for scaled_dot_product_attention (which may have a slower speed).",
-                category=UserWarning,
-                stacklevel=2,
-            )
-            global ALLOW_ALL_KERNELS
-            ALLOW_ALL_KERNELS = True
-            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
 
         out = self._recombine_heads(out)
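With the try/except fallback removed, both attention classes call F.scaled_dot_product_attention directly and let PyTorch choose the kernel. If a specific backend still needs to be pinned, recent PyTorch (2.3+) exposes a context manager for that; a sketch of that option, not part of the PR and not something the new code relies on:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 2, 16, 32)
allowed = [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
with sdpa_kernel(allowed):
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)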
@@ -628,7 +628,11 @@ class SAM2Base(torch.nn.Module):
                 if self.add_tpos_enc_to_obj_ptrs:
                     t_diff_max = max_obj_ptrs_in_encoder - 1
                     tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
-                    obj_pos = torch.tensor(pos_list, device=device)
+                    obj_pos = (
+                        torch.tensor(pos_list)
+                        .pin_memory()
+                        .to(device=device, non_blocking=True)
+                    )
                     obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
                     obj_pos = self.obj_ptr_tpos_proj(obj_pos)
                     obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
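Building the tensor on the host, pinning it, and copying with non_blocking=True lets this small transfer overlap with GPU work instead of forcing a synchronization (an inferred rationale; the diff only shows the mechanics). A minimal illustration of the pattern, not part of the PR, requiring a CUDA device:

import torch

pos_list = [0, 1, 2, 3]  # stand-in for the object-pointer frame offsets
if torch.cuda.is_available():
    obj_pos = (
        torch.tensor(pos_list)
        .pin_memory()
        .to(device="cuda", non_blocking=True)
    )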
@@ -8,6 +8,7 @@ import warnings
 from collections import OrderedDict
 
 import torch
+import torch.nn.functional as F
 
 from tqdm import tqdm
 
@@ -1170,3 +1171,255 @@ class SAM2VideoPredictor(SAM2Base):
             non_cond_frame_outputs.pop(t, None)
         for obj_output_dict in inference_state["output_dict_per_obj"].values():
             obj_output_dict["non_cond_frame_outputs"].pop(t, None)
+
+
+class SAM2VideoPredictorVOS(SAM2VideoPredictor):
+    """Optimized for the VOS setting"""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._compile_all_components()
+
+    def _compile_all_components(self):
+        print(
+            "Compiling all components for vos setting. First time may be very slow."
+        )
+        self.memory_encoder.forward = torch.compile(
+            self.memory_encoder.forward,
+            mode="max-autotune",
+            fullgraph=True,
+            dynamic=False,
+        )
+
+        self.memory_attention.forward = torch.compile(
+            self.memory_attention.forward,
+            mode="max-autotune",
+            fullgraph=True,
+            dynamic=True,  # Num. of memories varies
+        )
+
+        self.sam_prompt_encoder.forward = torch.compile(
+            self.sam_prompt_encoder.forward,
+            mode="max-autotune",
+            fullgraph=True,
+            dynamic=False,  # Accuracy regression on True
+        )
+
+        self.sam_mask_decoder.forward = torch.compile(
+            self.sam_mask_decoder.forward,
+            mode="max-autotune",
+            fullgraph=True,
+            dynamic=False,  # Accuracy regression on True
+        )
+
+    def forward_image(self, img_batch: torch.Tensor):
+        """
+        Identical to the corresponding method in the parent (SAM2VideoPredictor), but
+        cloning the backbone features and pos encoding to enable compilation.
+        """
+        backbone_out = self.image_encoder(img_batch)
+        if self.use_high_res_features_in_sam:
+            # precompute projected level 0 and level 1 features in SAM decoder
+            # to avoid running it again on every SAM click
+            backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
+                backbone_out["backbone_fpn"][0]
+            )
+            backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
+                backbone_out["backbone_fpn"][1]
+            )
+        # Clone to help torch.compile
+        for i in range(len(backbone_out["backbone_fpn"])):
+            backbone_out["backbone_fpn"][i] = backbone_out["backbone_fpn"][i].clone()
+            backbone_out["vision_pos_enc"][i] = backbone_out["vision_pos_enc"][
+                i
+            ].clone()
+        return backbone_out
+
+    def _forward_sam_heads(
+        self,
+        backbone_features,
+        point_inputs=None,
+        mask_inputs=None,
+        high_res_features=None,
+        multimask_output=False,
+    ):
+        """
+        Identical to the corresponding method in the parent (SAM2VideoPredictor), but
+        cloning the outputs of prompt_encoder and mask_decoder to enable compilation.
+        """
+        B = backbone_features.size(0)
+        device = backbone_features.device
+        assert backbone_features.size(1) == self.sam_prompt_embed_dim
+        assert backbone_features.size(2) == self.sam_image_embedding_size
+        assert backbone_features.size(3) == self.sam_image_embedding_size
+
+        # a) Handle point prompts
+        if point_inputs is not None:
+            sam_point_coords = point_inputs["point_coords"]
+            sam_point_labels = point_inputs["point_labels"]
+            assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
+        else:
+            # If no points are provided, pad with an empty point (with label -1)
+            sam_point_coords = torch.zeros(B, 1, 2, device=device)
+            sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
+
+        # b) Handle mask prompts
+        if mask_inputs is not None:
+            # If mask_inputs is provided, downsize it into low-res mask input if needed
+            # and feed it as a dense mask prompt into the SAM mask encoder
+            assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
+            if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
+                sam_mask_prompt = F.interpolate(
+                    mask_inputs.float(),
+                    size=self.sam_prompt_encoder.mask_input_size,
+                    align_corners=False,
+                    mode="bilinear",
+                    antialias=True,  # use antialias for downsampling
+                )
+            else:
+                sam_mask_prompt = mask_inputs
+        else:
+            # Otherwise, simply feed None (and SAM's prompt encoder will add
+            # a learned `no_mask_embed` to indicate no mask input in this case).
+            sam_mask_prompt = None
+
+        sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
+            points=(sam_point_coords, sam_point_labels),
+            boxes=None,
+            masks=sam_mask_prompt,
+        )
+        # Clone image_pe and the outputs of sam_prompt_encoder
+        # to enable compilation
+        sparse_embeddings = sparse_embeddings.clone()
+        dense_embeddings = dense_embeddings.clone()
+        image_pe = self.sam_prompt_encoder.get_dense_pe().clone()
+        (
+            low_res_multimasks,
+            ious,
+            sam_output_tokens,
+            object_score_logits,
+        ) = self.sam_mask_decoder(
+            image_embeddings=backbone_features,
+            image_pe=image_pe,
+            sparse_prompt_embeddings=sparse_embeddings,
+            dense_prompt_embeddings=dense_embeddings,
+            multimask_output=multimask_output,
+            repeat_image=False,  # the image is already batched
+            high_res_features=high_res_features,
+        )
+        # Clone the output of sam_mask_decoder
+        # to enable compilation
+        low_res_multimasks = low_res_multimasks.clone()
+        ious = ious.clone()
+        sam_output_tokens = sam_output_tokens.clone()
+        object_score_logits = object_score_logits.clone()
+
+        if self.pred_obj_scores:
+            is_obj_appearing = object_score_logits > 0
+
+            # Mask used for spatial memories is always a *hard* choice between obj and no obj,
+            # consistent with the actual mask prediction
+            low_res_multimasks = torch.where(
+                is_obj_appearing[:, None, None],
+                low_res_multimasks,
+                NO_OBJ_SCORE,
+            )
+
+        # convert masks from possibly bfloat16 (or float16) to float32
+        # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
+        low_res_multimasks = low_res_multimasks.float()
+        high_res_multimasks = F.interpolate(
+            low_res_multimasks,
+            size=(self.image_size, self.image_size),
+            mode="bilinear",
+            align_corners=False,
+        )
+
+        sam_output_token = sam_output_tokens[:, 0]
+        if multimask_output:
+            # take the best mask prediction (with the highest IoU estimation)
+            best_iou_inds = torch.argmax(ious, dim=-1)
+            batch_inds = torch.arange(B, device=device)
+            low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
+            high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
+            if sam_output_tokens.size(1) > 1:
+                sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
+        else:
+            low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
+
+        # Extract object pointer from the SAM output token (with occlusion handling)
+        obj_ptr = self.obj_ptr_proj(sam_output_token)
+        if self.pred_obj_scores:
+            # Allow *soft* no obj ptr, unlike for masks
+            if self.soft_no_obj_ptr:
+                lambda_is_obj_appearing = object_score_logits.sigmoid()
+            else:
+                lambda_is_obj_appearing = is_obj_appearing.float()
+
+            if self.fixed_no_obj_ptr:
+                obj_ptr = lambda_is_obj_appearing * obj_ptr
+            obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
+
+        return (
+            low_res_multimasks,
+            high_res_multimasks,
+            ious,
+            low_res_masks,
+            high_res_masks,
+            obj_ptr,
+            object_score_logits,
+        )
+
+    def _encode_new_memory(
+        self,
+        current_vision_feats,
+        feat_sizes,
+        pred_masks_high_res,
+        object_score_logits,
+        is_mask_from_pts,
+    ):
+        """
+        Identical to the corresponding method in the parent (SAM2VideoPredictor), but
+        cloning the memories and their pos enc to enable compilation.
+        """
+        B = current_vision_feats[-1].size(1)  # batch size on this frame
+        C = self.hidden_dim
+        H, W = feat_sizes[-1]  # top-level (lowest-resolution) feature size
+        # top-level feature, (HW)BC => BCHW
+        pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
+        if self.non_overlap_masks_for_mem_enc and not self.training:
+            # optionally, apply non-overlapping constraints to the masks (it's applied
+            # in the batch dimension and should only be used during eval, where all
+            # the objects come from the same video under batch size 1).
+            pred_masks_high_res = self._apply_non_overlapping_constraints(
+                pred_masks_high_res
+            )
+        # scale the raw mask logits with a temperature before applying sigmoid
+        binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
+        if binarize and not self.training:
+            mask_for_mem = (pred_masks_high_res > 0).float()
+        else:
+            # apply sigmoid on the raw mask logits to turn them into range (0, 1)
+            mask_for_mem = torch.sigmoid(pred_masks_high_res)
+        # apply scale and bias terms to the sigmoid probabilities
+        if self.sigmoid_scale_for_mem_enc != 1.0:
+            mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
+        if self.sigmoid_bias_for_mem_enc != 0.0:
+            mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
+        maskmem_out = self.memory_encoder(
+            pix_feat, mask_for_mem, skip_mask_sigmoid=True  # sigmoid already applied
+        )
+        # Clone the feats and pos_enc to enable compilation
+        maskmem_features = maskmem_out["vision_features"].clone()
+        maskmem_pos_enc = [m.clone() for m in maskmem_out["vision_pos_enc"]]
+        # add a no-object embedding to the spatial memory to indicate that the frame
+        # is predicted to be occluded (i.e. no object is appearing in the frame)
+        if self.no_obj_embed_spatial is not None:
+            is_obj_appearing = (object_score_logits > 0).float()
+            maskmem_features += (
+                1 - is_obj_appearing[..., None, None]
+            ) * self.no_obj_embed_spatial[..., None, None].expand(
+                *maskmem_features.shape
+            )
+
+        return maskmem_features, maskmem_pos_enc
@@ -375,7 +375,7 @@ def main():
     parser.add_argument(
         "--sam2_checkpoint",
         type=str,
-        default="./checkpoints/sam2.1_hiera_b+.pt",
+        default="./checkpoints/sam2.1_hiera_base_plus.pt",
         help="path to the SAM 2 model checkpoint",
     )
     parser.add_argument(
@@ -434,6 +434,11 @@ def main():
         help="whether to track objects that appear later in the video (i.e. not on the first frame; "
         "some VOS datasets like LVOS or YouTube-VOS don't have all objects appearing in the first frame)",
     )
+    parser.add_argument(
+        "--use_vos_optimized_video_predictor",
+        action="store_true",
+        help="whether to use vos optimized video predictor with all modules compiled",
+    )
     args = parser.parse_args()
 
     # if we use per-object PNG files, they could possibly overlap in inputs and outputs
@@ -445,6 +450,7 @@ def main():
         ckpt_path=args.sam2_checkpoint,
         apply_postprocessing=args.apply_postprocessing,
         hydra_overrides_extra=hydra_overrides_extra,
+        vos_optimized=args.use_vos_optimized_video_predictor,
     )
 
     if args.use_all_masks: