diff --git a/README.md b/README.md
index 65654f5..89d1da5 100644
--- a/README.md
+++ b/README.md
@@ -158,10 +158,10 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
The table below shows the improved SAM 2.1 checkpoints released on September 29, 2024.
| **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** |
| :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: |
-| sam2.1_hiera_tiny
([config](sam2/configs/sam2.1/sam2.1_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt)) | 38.9 | 47.2 | 76.5 | 71.8 | 77.3 |
-| sam2.1_hiera_small
([config](sam2/configs/sam2.1/sam2.1_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt)) | 46 | 43.3 (53.0 compiled\*) | 76.6 | 73.5 | 78.3 |
-| sam2.1_hiera_base_plus
([config](sam2/configs/sam2.1/sam2.1_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt)) | 80.8 | 34.8 (43.8 compiled\*) | 78.2 | 73.7 | 78.2 |
-| sam2.1_hiera_large
([config](sam2/configs/sam2.1/sam2.1_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt)) | 224.4 | 24.2 (30.2 compiled\*) | 79.5 | 74.6 | 80.6 |
+| sam2.1_hiera_tiny
([config](sam2/configs/sam2.1/sam2.1_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt)) | 38.9 | 91.2 | 76.5 | 71.8 | 77.3 |
+| sam2.1_hiera_small
([config](sam2/configs/sam2.1/sam2.1_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt)) | 46 | 84.8 | 76.6 | 73.5 | 78.3 |
+| sam2.1_hiera_base_plus
([config](sam2/configs/sam2.1/sam2.1_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt)) | 80.8 | 64.1 | 78.2 | 73.7 | 78.2 |
+| sam2.1_hiera_large
([config](sam2/configs/sam2.1/sam2.1_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt)) | 224.4 | 39.5 | 79.5 | 74.6 | 80.6 |
### SAM 2 checkpoints
@@ -169,13 +169,12 @@ The previous SAM 2 checkpoints released on July 29, 2024 can be found as follows
| **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** |
| :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: |
-| sam2_hiera_tiny
([config](sam2/configs/sam2/sam2_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt)) | 38.9 | 47.2 | 75.0 | 70.9 | 75.3 |
-| sam2_hiera_small
([config](sam2/configs/sam2/sam2_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt)) | 46 | 43.3 (53.0 compiled\*) | 74.9 | 71.5 | 76.4 |
-| sam2_hiera_base_plus
([config](sam2/configs/sam2/sam2_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt)) | 80.8 | 34.8 (43.8 compiled\*) | 74.7 | 72.8 | 75.8 |
-| sam2_hiera_large
([config](sam2/configs/sam2/sam2_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt)) | 224.4 | 24.2 (30.2 compiled\*) | 76.0 | 74.6 | 79.8 |
-
-\* Compile the model by setting `compile_image_encoder: True` in the config.
+| sam2_hiera_tiny
([config](sam2/configs/sam2/sam2_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt)) | 38.9 | 91.5 | 75.0 | 70.9 | 75.3 |
+| sam2_hiera_small
([config](sam2/configs/sam2/sam2_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt)) | 46 | 85.6 | 74.9 | 71.5 | 76.4 |
+| sam2_hiera_base_plus
([config](sam2/configs/sam2/sam2_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt)) | 80.8 | 64.8 | 74.7 | 72.8 | 75.8 |
+| sam2_hiera_large
([config](sam2/configs/sam2/sam2_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt)) | 224.4 | 39.7 | 76.0 | 74.6 | 79.8 |
+Speed measured on an A100 with `torch 2.5.1, cuda 12.4`. See `benchmark.py` for an example on benchmarking (compiling all the model components). Compiling only the image encoder can be more flexible and also provide (a smaller) speed-up (set `compile_image_encoder: True` in the config).
## Segment Anything Video Dataset
See [sav_dataset/README.md](sav_dataset/README.md) for details.
diff --git a/sam2/benchmark.py b/sam2/benchmark.py
new file mode 100644
index 0000000..6519534
--- /dev/null
+++ b/sam2/benchmark.py
@@ -0,0 +1,92 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import time
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from sam2.build_sam import build_sam2_video_predictor
+
+# Only cuda supported
+assert torch.cuda.is_available()
+device = torch.device("cuda")
+
+torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+if torch.cuda.get_device_properties(0).major >= 8:
+ # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+
+# Config and checkpoint
+sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt"
+model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
+
+# Build video predictor with vos_optimized=True setting
+predictor = build_sam2_video_predictor(
+ model_cfg, sam2_checkpoint, device=device, vos_optimized=True
+)
+
+
+# Initialize with video
+video_dir = "notebooks/videos/bedroom"
+# scan all the JPEG frame names in this directory
+frame_names = [
+ p
+ for p in os.listdir(video_dir)
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
+]
+frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
+inference_state = predictor.init_state(video_path=video_dir)
+
+
+# Number of runs, warmup etc
+warm_up, runs = 5, 25
+verbose = True
+num_frames = len(frame_names)
+total, count = 0, 0
+torch.cuda.empty_cache()
+
+# We will select an object with a click.
+# See video_predictor_example.ipynb for more detailed explanation
+ann_frame_idx, ann_obj_id = 0, 1
+# Add a positive click at (x, y) = (210, 350)
+# For labels, `1` means positive click
+points = np.array([[210, 350]], dtype=np.float32)
+labels = np.array([1], np.int32)
+
+_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
+ inference_state=inference_state,
+ frame_idx=ann_frame_idx,
+ obj_id=ann_obj_id,
+ points=points,
+ labels=labels,
+)
+
+# Warmup and then average FPS over several runs
+with torch.autocast("cuda", torch.bfloat16):
+ with torch.inference_mode():
+ for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"):
+ start = time.time()
+ # Start tracking
+ for (
+ out_frame_idx,
+ out_obj_ids,
+ out_mask_logits,
+ ) in predictor.propagate_in_video(inference_state):
+ pass
+
+ end = time.time()
+ total += end - start
+ count += 1
+ if i == warm_up - 1:
+ print("Warmup FPS: ", count * num_frames / total)
+ total = 0
+ count = 0
+
+print("FPS: ", count * num_frames / total)
diff --git a/sam2/build_sam.py b/sam2/build_sam.py
index 7cfc451..3a3bef1 100644
--- a/sam2/build_sam.py
+++ b/sam2/build_sam.py
@@ -104,11 +104,18 @@ def build_sam2_video_predictor(
mode="eval",
hydra_overrides_extra=[],
apply_postprocessing=True,
+ vos_optimized=False,
**kwargs,
):
hydra_overrides = [
"++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
]
+ if vos_optimized:
+ hydra_overrides = [
+ "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictorVOS",
+ "++model.compile_image_encoder=True", # Let sam2_base handle this
+ ]
+
if apply_postprocessing:
hydra_overrides_extra = hydra_overrides_extra.copy()
hydra_overrides_extra += [
diff --git a/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml b/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml
index cbee3cf..d7172f9 100644
--- a/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml
+++ b/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml
@@ -36,7 +36,7 @@ model:
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
- feat_sizes: [32, 32]
+ feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
@@ -47,7 +47,7 @@ model:
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
- feat_sizes: [32, 32]
+ feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
diff --git a/sam2/configs/sam2.1/sam2.1_hiera_l.yaml b/sam2/configs/sam2.1/sam2.1_hiera_l.yaml
index 33c9097..23073ea 100644
--- a/sam2/configs/sam2.1/sam2.1_hiera_l.yaml
+++ b/sam2/configs/sam2.1/sam2.1_hiera_l.yaml
@@ -40,7 +40,7 @@ model:
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
- feat_sizes: [32, 32]
+ feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
@@ -51,7 +51,7 @@ model:
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
- feat_sizes: [32, 32]
+ feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
diff --git a/sam2/configs/sam2.1/sam2.1_hiera_s.yaml b/sam2/configs/sam2.1/sam2.1_hiera_s.yaml
index 8e803df..fd8d404 100644
--- a/sam2/configs/sam2.1/sam2.1_hiera_s.yaml
+++ b/sam2/configs/sam2.1/sam2.1_hiera_s.yaml
@@ -39,7 +39,7 @@ model:
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
- feat_sizes: [32, 32]
+ feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
@@ -50,7 +50,7 @@ model:
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
- feat_sizes: [32, 32]
+ feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
diff --git a/sam2/configs/sam2.1/sam2.1_hiera_t.yaml b/sam2/configs/sam2.1/sam2.1_hiera_t.yaml
index 983c2ea..e762aec 100644
--- a/sam2/configs/sam2.1/sam2.1_hiera_t.yaml
+++ b/sam2/configs/sam2.1/sam2.1_hiera_t.yaml
@@ -39,7 +39,7 @@ model:
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
- feat_sizes: [32, 32]
+ feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
@@ -50,7 +50,7 @@ model:
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
- feat_sizes: [32, 32]
+ feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
diff --git a/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml b/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml
index 2046791..9b6faa7 100644
--- a/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml
+++ b/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml
@@ -97,7 +97,7 @@ trainer:
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
- feat_sizes: [32, 32]
+ feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
@@ -108,7 +108,7 @@ trainer:
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
- feat_sizes: [32, 32]
+ feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
diff --git a/sam2/configs/sam2/sam2_hiera_b+.yaml b/sam2/configs/sam2/sam2_hiera_b+.yaml
index 58f3eb8..0f435af 100644
--- a/sam2/configs/sam2/sam2_hiera_b+.yaml
+++ b/sam2/configs/sam2/sam2_hiera_b+.yaml
@@ -36,7 +36,7 @@ model:
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
- feat_sizes: [32, 32]
+ feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
@@ -47,7 +47,7 @@ model:
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
- feat_sizes: [32, 32]
+ feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
diff --git a/sam2/configs/sam2/sam2_hiera_l.yaml b/sam2/configs/sam2/sam2_hiera_l.yaml
index 918667f..1092802 100644
--- a/sam2/configs/sam2/sam2_hiera_l.yaml
+++ b/sam2/configs/sam2/sam2_hiera_l.yaml
@@ -40,7 +40,7 @@ model:
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
- feat_sizes: [32, 32]
+ feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
@@ -51,7 +51,7 @@ model:
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
- feat_sizes: [32, 32]
+ feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
diff --git a/sam2/configs/sam2/sam2_hiera_s.yaml b/sam2/configs/sam2/sam2_hiera_s.yaml
index 26e5d4d..174e414 100644
--- a/sam2/configs/sam2/sam2_hiera_s.yaml
+++ b/sam2/configs/sam2/sam2_hiera_s.yaml
@@ -39,7 +39,7 @@ model:
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
- feat_sizes: [32, 32]
+ feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
@@ -50,7 +50,7 @@ model:
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
- feat_sizes: [32, 32]
+ feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
diff --git a/sam2/configs/sam2/sam2_hiera_t.yaml b/sam2/configs/sam2/sam2_hiera_t.yaml
index a62c903..121447a 100644
--- a/sam2/configs/sam2/sam2_hiera_t.yaml
+++ b/sam2/configs/sam2/sam2_hiera_t.yaml
@@ -39,7 +39,7 @@ model:
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
- feat_sizes: [32, 32]
+ feat_sizes: [64, 64]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
@@ -50,7 +50,7 @@ model:
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
- feat_sizes: [32, 32]
+ feat_sizes: [64, 64]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
diff --git a/sam2/modeling/backbones/utils.py b/sam2/modeling/backbones/utils.py
index 32d55c7..930b1b7 100644
--- a/sam2/modeling/backbones/utils.py
+++ b/sam2/modeling/backbones/utils.py
@@ -32,9 +32,7 @@ def window_partition(x, window_size):
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
- windows = (
- x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
- )
+ windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
return windows, (Hp, Wp)
@@ -52,13 +50,13 @@ def window_unpartition(windows, window_size, pad_hw, hw):
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
- x = windows.view(
+ x = windows.reshape(
B, Hp // window_size, Wp // window_size, window_size, window_size, -1
)
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)
if Hp > H or Wp > W:
- x = x[:, :H, :W, :].contiguous()
+ x = x[:, :H, :W, :]
return x
diff --git a/sam2/modeling/position_encoding.py b/sam2/modeling/position_encoding.py
index 52ac226..2241d4c 100644
--- a/sam2/modeling/position_encoding.py
+++ b/sam2/modeling/position_encoding.py
@@ -25,6 +25,11 @@ class PositionEmbeddingSine(nn.Module):
temperature: int = 10000,
normalize: bool = True,
scale: Optional[float] = None,
+ # Following settings only relevant
+ # for warmping up cache for compilation
+ warmup_cache: bool = True,
+ image_size: int = 1024,
+ strides: Tuple[int] = (4, 8, 16, 32),
):
super().__init__()
assert num_pos_feats % 2 == 0, "Expecting even model width"
@@ -38,6 +43,12 @@ class PositionEmbeddingSine(nn.Module):
self.scale = scale
self.cache = {}
+ if warmup_cache and torch.cuda.is_available():
+ # Warmup cache for cuda, to help with compilation
+ device = torch.device("cuda")
+ for stride in strides:
+ cache_key = (image_size // stride, image_size // stride)
+ self._pe(1, device, *cache_key)
def _encode_xy(self, x, y):
# The positions are expected to be normalized
@@ -76,19 +87,20 @@ class PositionEmbeddingSine(nn.Module):
return pos
@torch.no_grad()
- def forward(self, x: torch.Tensor):
- cache_key = (x.shape[-2], x.shape[-1])
+ def _pe(self, B, device, *cache_key):
+ H, W = cache_key
if cache_key in self.cache:
- return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
+ return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)
+
y_embed = (
- torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
+ torch.arange(1, H + 1, dtype=torch.float32, device=device)
.view(1, -1, 1)
- .repeat(x.shape[0], 1, x.shape[-1])
+ .repeat(B, 1, W)
)
x_embed = (
- torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
+ torch.arange(1, W + 1, dtype=torch.float32, device=device)
.view(1, 1, -1)
- .repeat(x.shape[0], x.shape[-2], 1)
+ .repeat(B, H, 1)
)
if self.normalize:
@@ -96,7 +108,7 @@ class PositionEmbeddingSine(nn.Module):
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
- dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
@@ -111,6 +123,12 @@ class PositionEmbeddingSine(nn.Module):
self.cache[cache_key] = pos[0]
return pos
+ @torch.no_grad()
+ def forward(self, x: torch.Tensor):
+ B = x.shape[0]
+ cache_key = (x.shape[-2], x.shape[-1])
+ return self._pe(B, x.device, *cache_key)
+
class PositionEmbeddingRandom(nn.Module):
"""
diff --git a/sam2/modeling/sam/prompt_encoder.py b/sam2/modeling/sam/prompt_encoder.py
index 6b3bbb9..c578762 100644
--- a/sam2/modeling/sam/prompt_encoder.py
+++ b/sam2/modeling/sam/prompt_encoder.py
@@ -92,12 +92,32 @@ class PromptEncoder(nn.Module):
point_embedding = self.pe_layer.forward_with_coords(
points, self.input_image_size
)
- point_embedding[labels == -1] = 0.0
- point_embedding[labels == -1] += self.not_a_point_embed.weight
- point_embedding[labels == 0] += self.point_embeddings[0].weight
- point_embedding[labels == 1] += self.point_embeddings[1].weight
- point_embedding[labels == 2] += self.point_embeddings[2].weight
- point_embedding[labels == 3] += self.point_embeddings[3].weight
+
+ point_embedding = torch.where(
+ (labels == -1).unsqueeze(-1),
+ torch.zeros_like(point_embedding) + self.not_a_point_embed.weight,
+ point_embedding,
+ )
+ point_embedding = torch.where(
+ (labels == 0).unsqueeze(-1),
+ point_embedding + self.point_embeddings[0].weight,
+ point_embedding,
+ )
+ point_embedding = torch.where(
+ (labels == 1).unsqueeze(-1),
+ point_embedding + self.point_embeddings[1].weight,
+ point_embedding,
+ )
+ point_embedding = torch.where(
+ (labels == 2).unsqueeze(-1),
+ point_embedding + self.point_embeddings[2].weight,
+ point_embedding,
+ )
+ point_embedding = torch.where(
+ (labels == 3).unsqueeze(-1),
+ point_embedding + self.point_embeddings[3].weight,
+ point_embedding,
+ )
return point_embedding
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
diff --git a/sam2/modeling/sam/transformer.py b/sam2/modeling/sam/transformer.py
index b5b6fa2..f9fe9a3 100644
--- a/sam2/modeling/sam/transformer.py
+++ b/sam2/modeling/sam/transformer.py
@@ -4,9 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
-import contextlib
import math
-import warnings
from functools import partial
from typing import Tuple, Type
@@ -16,29 +14,6 @@ from torch import nn, Tensor
from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
from sam2.modeling.sam2_utils import MLP
-from sam2.utils.misc import get_sdpa_settings
-
-warnings.simplefilter(action="ignore", category=FutureWarning)
-# Check whether Flash Attention is available (and use it by default)
-OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
-# A fallback setting to allow all available kernels if Flash Attention fails
-ALLOW_ALL_KERNELS = False
-
-
-def sdp_kernel_context(dropout_p):
- """
- Get the context for the attention scaled dot-product kernel. We use Flash Attention
- by default, but fall back to all available kernels if Flash Attention fails.
- """
- if ALLOW_ALL_KERNELS:
- return contextlib.nullcontext()
-
- return torch.backends.cuda.sdp_kernel(
- enable_flash=USE_FLASH_ATTN,
- # if Flash attention kernel is off, then math kernel needs to be enabled
- enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
- enable_mem_efficient=OLD_GPU,
- )
class TwoWayTransformer(nn.Module):
@@ -265,20 +240,7 @@ class Attention(nn.Module):
dropout_p = self.dropout_p if self.training else 0.0
# Attention
- try:
- with sdp_kernel_context(dropout_p):
- out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
- except Exception as e:
- # Fall back to all kernels if the Flash attention kernel fails
- warnings.warn(
- f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
- f"kernels for scaled_dot_product_attention (which may have a slower speed).",
- category=UserWarning,
- stacklevel=2,
- )
- global ALLOW_ALL_KERNELS
- ALLOW_ALL_KERNELS = True
- out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
out = self._recombine_heads(out)
out = self.out_proj(out)
@@ -296,7 +258,7 @@ class RoPEAttention(Attention):
# whether to repeat q rope to match k length
# this is needed for cross-attention to memories
rope_k_repeat=False,
- feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution
+ feat_sizes=(64, 64), # [w, h] for stride 16 feats at 1024 resolution
**kwargs,
):
super().__init__(*args, **kwargs)
@@ -305,7 +267,9 @@ class RoPEAttention(Attention):
compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
)
freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
- self.freqs_cis = freqs_cis
+ self.freqs_cis = (
+ freqs_cis.to("cuda") if torch.cuda.is_available() else freqs_cis
+ )
self.rope_k_repeat = rope_k_repeat
def forward(
@@ -339,20 +303,7 @@ class RoPEAttention(Attention):
dropout_p = self.dropout_p if self.training else 0.0
# Attention
- try:
- with sdp_kernel_context(dropout_p):
- out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
- except Exception as e:
- # Fall back to all kernels if the Flash attention kernel fails
- warnings.warn(
- f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
- f"kernels for scaled_dot_product_attention (which may have a slower speed).",
- category=UserWarning,
- stacklevel=2,
- )
- global ALLOW_ALL_KERNELS
- ALLOW_ALL_KERNELS = True
- out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
out = self._recombine_heads(out)
out = self.out_proj(out)
diff --git a/sam2/modeling/sam2_base.py b/sam2/modeling/sam2_base.py
index a5d243a..8aa1a0b 100644
--- a/sam2/modeling/sam2_base.py
+++ b/sam2/modeling/sam2_base.py
@@ -628,7 +628,11 @@ class SAM2Base(torch.nn.Module):
if self.add_tpos_enc_to_obj_ptrs:
t_diff_max = max_obj_ptrs_in_encoder - 1
tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
- obj_pos = torch.tensor(pos_list, device=device)
+ obj_pos = (
+ torch.tensor(pos_list)
+ .pin_memory()
+ .to(device=device, non_blocking=True)
+ )
obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
obj_pos = self.obj_ptr_tpos_proj(obj_pos)
obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py
index c7e01cc..62f45e1 100644
--- a/sam2/sam2_video_predictor.py
+++ b/sam2/sam2_video_predictor.py
@@ -8,6 +8,7 @@ import warnings
from collections import OrderedDict
import torch
+import torch.nn.functional as F
from tqdm import tqdm
@@ -1170,3 +1171,255 @@ class SAM2VideoPredictor(SAM2Base):
non_cond_frame_outputs.pop(t, None)
for obj_output_dict in inference_state["output_dict_per_obj"].values():
obj_output_dict["non_cond_frame_outputs"].pop(t, None)
+
+
+class SAM2VideoPredictorVOS(SAM2VideoPredictor):
+ """Optimized for the VOS setting"""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._compile_all_components()
+
+ def _compile_all_components(self):
+ print(
+ "Compiling all components for for vos setting. First time may be very slow."
+ )
+ self.memory_encoder.forward = torch.compile(
+ self.memory_encoder.forward,
+ mode="max-autotune",
+ fullgraph=True,
+ dynamic=False,
+ )
+
+ self.memory_attention.forward = torch.compile(
+ self.memory_attention.forward,
+ mode="max-autotune",
+ fullgraph=True,
+ dynamic=True, # Num. of memories varies
+ )
+
+ self.sam_prompt_encoder.forward = torch.compile(
+ self.sam_prompt_encoder.forward,
+ mode="max-autotune",
+ fullgraph=True,
+ dynamic=False, # Accuracy regression on True
+ )
+
+ self.sam_mask_decoder.forward = torch.compile(
+ self.sam_mask_decoder.forward,
+ mode="max-autotune",
+ fullgraph=True,
+ dynamic=False, # Accuracy regression on True
+ )
+
+ def forward_image(self, img_batch: torch.Tensor):
+ """
+ Identical to the corresponding method in the parent (SAM2VideoPredictor), but
+ cloning the backbone features and pos encoding to enable compilation.
+ """
+ backbone_out = self.image_encoder(img_batch)
+ if self.use_high_res_features_in_sam:
+ # precompute projected level 0 and level 1 features in SAM decoder
+ # to avoid running it again on every SAM click
+ backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
+ backbone_out["backbone_fpn"][0]
+ )
+ backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
+ backbone_out["backbone_fpn"][1]
+ )
+ # Clone to help torch.compile
+ for i in range(len(backbone_out["backbone_fpn"])):
+ backbone_out["backbone_fpn"][i] = backbone_out["backbone_fpn"][i].clone()
+ backbone_out["vision_pos_enc"][i] = backbone_out["vision_pos_enc"][
+ i
+ ].clone()
+ return backbone_out
+
+ def _forward_sam_heads(
+ self,
+ backbone_features,
+ point_inputs=None,
+ mask_inputs=None,
+ high_res_features=None,
+ multimask_output=False,
+ ):
+ """
+ Identical to the corresponding method in the parent (SAM2VideoPredictor), but
+ cloning the outputs of prompt_encoder and mask_decoder to enable compilation.
+ """
+ B = backbone_features.size(0)
+ device = backbone_features.device
+ assert backbone_features.size(1) == self.sam_prompt_embed_dim
+ assert backbone_features.size(2) == self.sam_image_embedding_size
+ assert backbone_features.size(3) == self.sam_image_embedding_size
+
+ # a) Handle point prompts
+ if point_inputs is not None:
+ sam_point_coords = point_inputs["point_coords"]
+ sam_point_labels = point_inputs["point_labels"]
+ assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
+ else:
+ # If no points are provide, pad with an empty point (with label -1)
+ sam_point_coords = torch.zeros(B, 1, 2, device=device)
+ sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
+
+ # b) Handle mask prompts
+ if mask_inputs is not None:
+ # If mask_inputs is provided, downsize it into low-res mask input if needed
+ # and feed it as a dense mask prompt into the SAM mask encoder
+ assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
+ if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
+ sam_mask_prompt = F.interpolate(
+ mask_inputs.float(),
+ size=self.sam_prompt_encoder.mask_input_size,
+ align_corners=False,
+ mode="bilinear",
+ antialias=True, # use antialias for downsampling
+ )
+ else:
+ sam_mask_prompt = mask_inputs
+ else:
+ # Otherwise, simply feed None (and SAM's prompt encoder will add
+ # a learned `no_mask_embed` to indicate no mask input in this case).
+ sam_mask_prompt = None
+
+ sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
+ points=(sam_point_coords, sam_point_labels),
+ boxes=None,
+ masks=sam_mask_prompt,
+ )
+ # Clone image_pe and the outputs of sam_prompt_encoder
+ # to enable compilation
+ sparse_embeddings = sparse_embeddings.clone()
+ dense_embeddings = dense_embeddings.clone()
+ image_pe = self.sam_prompt_encoder.get_dense_pe().clone()
+ (
+ low_res_multimasks,
+ ious,
+ sam_output_tokens,
+ object_score_logits,
+ ) = self.sam_mask_decoder(
+ image_embeddings=backbone_features,
+ image_pe=image_pe,
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ repeat_image=False, # the image is already batched
+ high_res_features=high_res_features,
+ )
+ # Clone the output of sam_mask_decoder
+ # to enable compilation
+ low_res_multimasks = low_res_multimasks.clone()
+ ious = ious.clone()
+ sam_output_tokens = sam_output_tokens.clone()
+ object_score_logits = object_score_logits.clone()
+
+ if self.pred_obj_scores:
+ is_obj_appearing = object_score_logits > 0
+
+ # Mask used for spatial memories is always a *hard* choice between obj and no obj,
+ # consistent with the actual mask prediction
+ low_res_multimasks = torch.where(
+ is_obj_appearing[:, None, None],
+ low_res_multimasks,
+ NO_OBJ_SCORE,
+ )
+
+ # convert masks from possibly bfloat16 (or float16) to float32
+ # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
+ low_res_multimasks = low_res_multimasks.float()
+ high_res_multimasks = F.interpolate(
+ low_res_multimasks,
+ size=(self.image_size, self.image_size),
+ mode="bilinear",
+ align_corners=False,
+ )
+
+ sam_output_token = sam_output_tokens[:, 0]
+ if multimask_output:
+ # take the best mask prediction (with the highest IoU estimation)
+ best_iou_inds = torch.argmax(ious, dim=-1)
+ batch_inds = torch.arange(B, device=device)
+ low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
+ high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
+ if sam_output_tokens.size(1) > 1:
+ sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
+ else:
+ low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
+
+ # Extract object pointer from the SAM output token (with occlusion handling)
+ obj_ptr = self.obj_ptr_proj(sam_output_token)
+ if self.pred_obj_scores:
+ # Allow *soft* no obj ptr, unlike for masks
+ if self.soft_no_obj_ptr:
+ lambda_is_obj_appearing = object_score_logits.sigmoid()
+ else:
+ lambda_is_obj_appearing = is_obj_appearing.float()
+
+ if self.fixed_no_obj_ptr:
+ obj_ptr = lambda_is_obj_appearing * obj_ptr
+ obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
+
+ return (
+ low_res_multimasks,
+ high_res_multimasks,
+ ious,
+ low_res_masks,
+ high_res_masks,
+ obj_ptr,
+ object_score_logits,
+ )
+
+ def _encode_new_memory(
+ self,
+ current_vision_feats,
+ feat_sizes,
+ pred_masks_high_res,
+ object_score_logits,
+ is_mask_from_pts,
+ ):
+ """
+ Identical to the corresponding method in the parent (SAM2VideoPredictor), but
+ cloning the memories and their pos enc to enable compilation.
+ """
+ B = current_vision_feats[-1].size(1) # batch size on this frame
+ C = self.hidden_dim
+ H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
+ # top-level feature, (HW)BC => BCHW
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
+ if self.non_overlap_masks_for_mem_enc and not self.training:
+ # optionally, apply non-overlapping constraints to the masks (it's applied
+ # in the batch dimension and should only be used during eval, where all
+ # the objects come from the same video under batch size 1).
+ pred_masks_high_res = self._apply_non_overlapping_constraints(
+ pred_masks_high_res
+ )
+ # scale the raw mask logits with a temperature before applying sigmoid
+ binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
+ if binarize and not self.training:
+ mask_for_mem = (pred_masks_high_res > 0).float()
+ else:
+ # apply sigmoid on the raw mask logits to turn them into range (0, 1)
+ mask_for_mem = torch.sigmoid(pred_masks_high_res)
+ # apply scale and bias terms to the sigmoid probabilities
+ if self.sigmoid_scale_for_mem_enc != 1.0:
+ mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
+ if self.sigmoid_bias_for_mem_enc != 0.0:
+ mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
+ maskmem_out = self.memory_encoder(
+ pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied
+ )
+ # Clone the feats and pos_enc to enable compilation
+ maskmem_features = maskmem_out["vision_features"].clone()
+ maskmem_pos_enc = [m.clone() for m in maskmem_out["vision_pos_enc"]]
+ # add a no-object embedding to the spatial memory to indicate that the frame
+ # is predicted to be occluded (i.e. no object is appearing in the frame)
+ if self.no_obj_embed_spatial is not None:
+ is_obj_appearing = (object_score_logits > 0).float()
+ maskmem_features += (
+ 1 - is_obj_appearing[..., None, None]
+ ) * self.no_obj_embed_spatial[..., None, None].expand(
+ *maskmem_features.shape
+ )
+
+ return maskmem_features, maskmem_pos_enc
diff --git a/tools/vos_inference.py b/tools/vos_inference.py
index 5c40cda..ef3e8c6 100644
--- a/tools/vos_inference.py
+++ b/tools/vos_inference.py
@@ -375,7 +375,7 @@ def main():
parser.add_argument(
"--sam2_checkpoint",
type=str,
- default="./checkpoints/sam2.1_hiera_b+.pt",
+ default="./checkpoints/sam2.1_hiera_base_plus.pt",
help="path to the SAM 2 model checkpoint",
)
parser.add_argument(
@@ -434,6 +434,11 @@ def main():
help="whether to track objects that appear later in the video (i.e. not on the first frame; "
"some VOS datasets like LVOS or YouTube-VOS don't have all objects appearing in the first frame)",
)
+ parser.add_argument(
+ "--use_vos_optimized_video_predictor",
+ action="store_true",
+ help="whether to use vos optimized video predictor with all modules compiled",
+ )
args = parser.parse_args()
# if we use per-object PNG files, they could possibly overlap in inputs and outputs
@@ -445,6 +450,7 @@ def main():
ckpt_path=args.sam2_checkpoint,
apply_postprocessing=args.apply_postprocessing,
hydra_overrides_extra=hydra_overrides_extra,
+ vos_optimized=args.use_vos_optimized_video_predictor,
)
if args.use_all_masks: