Merge pull request #153 from fairinternal/chay/improve_speed_v1

speed optimizations cleanup
Chay Ryali 2024-12-10 23:08:28 -08:00 committed by GitHub
commit 3297dd0eb0
19 changed files with 453 additions and 105 deletions

View File

@@ -158,10 +158,10 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
 The table below shows the improved SAM 2.1 checkpoints released on September 29, 2024.
 | **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** |
 | :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: |
-| sam2.1_hiera_tiny <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt)) | 38.9 | 47.2 | 76.5 | 71.8 | 77.3 |
+| sam2.1_hiera_tiny <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt)) | 38.9 | 91.2 | 76.5 | 71.8 | 77.3 |
-| sam2.1_hiera_small <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt)) | 46 | 43.3 (53.0 compiled\*) | 76.6 | 73.5 | 78.3 |
+| sam2.1_hiera_small <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt)) | 46 | 84.8 | 76.6 | 73.5 | 78.3 |
-| sam2.1_hiera_base_plus <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt)) | 80.8 | 34.8 (43.8 compiled\*) | 78.2 | 73.7 | 78.2 |
+| sam2.1_hiera_base_plus <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt)) | 80.8 | 64.1 | 78.2 | 73.7 | 78.2 |
-| sam2.1_hiera_large <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt)) | 224.4 | 24.2 (30.2 compiled\*) | 79.5 | 74.6 | 80.6 |
+| sam2.1_hiera_large <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt)) | 224.4 | 39.5 | 79.5 | 74.6 | 80.6 |
 ### SAM 2 checkpoints
@@ -169,13 +169,12 @@ The previous SAM 2 checkpoints released on July 29, 2024 can be found as follows
 | **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** |
 | :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: |
-| sam2_hiera_tiny <br /> ([config](sam2/configs/sam2/sam2_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt)) | 38.9 | 47.2 | 75.0 | 70.9 | 75.3 |
+| sam2_hiera_tiny <br /> ([config](sam2/configs/sam2/sam2_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt)) | 38.9 | 91.5 | 75.0 | 70.9 | 75.3 |
-| sam2_hiera_small <br /> ([config](sam2/configs/sam2/sam2_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt)) | 46 | 43.3 (53.0 compiled\*) | 74.9 | 71.5 | 76.4 |
+| sam2_hiera_small <br /> ([config](sam2/configs/sam2/sam2_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt)) | 46 | 85.6 | 74.9 | 71.5 | 76.4 |
-| sam2_hiera_base_plus <br /> ([config](sam2/configs/sam2/sam2_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt)) | 80.8 | 34.8 (43.8 compiled\*) | 74.7 | 72.8 | 75.8 |
+| sam2_hiera_base_plus <br /> ([config](sam2/configs/sam2/sam2_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt)) | 80.8 | 64.8 | 74.7 | 72.8 | 75.8 |
-| sam2_hiera_large <br /> ([config](sam2/configs/sam2/sam2_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt)) | 224.4 | 24.2 (30.2 compiled\*) | 76.0 | 74.6 | 79.8 |
+| sam2_hiera_large <br /> ([config](sam2/configs/sam2/sam2_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt)) | 224.4 | 39.7 | 76.0 | 74.6 | 79.8 |
-\* Compile the model by setting `compile_image_encoder: True` in the config.
+Speed measured on an A100 with `torch 2.5.1, cuda 12.4`. See `benchmark.py` for an example on benchmarking (compiling all the model components). Compiling only the image encoder can be more flexible and also provide (a smaller) speed-up (set `compile_image_encoder: True` in the config).
 ## Segment Anything Video Dataset
 See [sav_dataset/README.md](sav_dataset/README.md) for details.
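The new README text above describes two ways to get the compilation speed-up. A minimal sketch of both, using the base-plus config and checkpoint from the table (paths are illustrative, not mandated by this PR):

```python
import torch

from sam2.build_sam import build_sam2_video_predictor

# Option 1 (what sam2/benchmark.py in this PR does): compile all major model
# components by selecting the VOS-optimized predictor class.
predictor = build_sam2_video_predictor(
    "configs/sam2.1/sam2.1_hiera_b+.yaml",
    "checkpoints/sam2.1_hiera_base_plus.pt",
    device=torch.device("cuda"),
    vos_optimized=True,
)

# Option 2 (smaller but more flexible speed-up): compile only the image encoder
# via the same hydra override the builder uses internally, instead of editing the yaml.
predictor_enc_only = build_sam2_video_predictor(
    "configs/sam2.1/sam2.1_hiera_b+.yaml",
    "checkpoints/sam2.1_hiera_base_plus.pt",
    device=torch.device("cuda"),
    hydra_overrides_extra=["++model.compile_image_encoder=True"],
)
```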

sam2/benchmark.py Normal file
View File

@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import time

import numpy as np
import torch
from tqdm import tqdm

from sam2.build_sam import build_sam2_video_predictor

# Only cuda supported
assert torch.cuda.is_available()
device = torch.device("cuda")

torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Config and checkpoint
sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"

# Build video predictor with vos_optimized=True setting
predictor = build_sam2_video_predictor(
    model_cfg, sam2_checkpoint, device=device, vos_optimized=True
)

# Initialize with video
video_dir = "notebooks/videos/bedroom"
# scan all the JPEG frame names in this directory
frame_names = [
    p
    for p in os.listdir(video_dir)
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

inference_state = predictor.init_state(video_path=video_dir)

# Number of runs, warmup etc
warm_up, runs = 5, 25
verbose = True
num_frames = len(frame_names)
total, count = 0, 0
torch.cuda.empty_cache()

# We will select an object with a click.
# See video_predictor_example.ipynb for more detailed explanation
ann_frame_idx, ann_obj_id = 0, 1
# Add a positive click at (x, y) = (210, 350)
# For labels, `1` means positive click
points = np.array([[210, 350]], dtype=np.float32)
labels = np.array([1], np.int32)

_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

# Warmup and then average FPS over several runs
with torch.autocast("cuda", torch.bfloat16):
    with torch.inference_mode():
        for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"):
            start = time.time()
            # Start tracking
            for (
                out_frame_idx,
                out_obj_ids,
                out_mask_logits,
            ) in predictor.propagate_in_video(inference_state):
                pass

            end = time.time()
            total += end - start
            count += 1

            if i == warm_up - 1:
                print("Warmup FPS: ", count * num_frames / total)
                total = 0
                count = 0

print("FPS: ", count * num_frames / total)

View File

@@ -104,11 +104,18 @@ def build_sam2_video_predictor(
     mode="eval",
     hydra_overrides_extra=[],
     apply_postprocessing=True,
+    vos_optimized=False,
     **kwargs,
 ):
     hydra_overrides = [
         "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
     ]
+    if vos_optimized:
+        hydra_overrides = [
+            "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictorVOS",
+            "++model.compile_image_encoder=True",  # Let sam2_base handle this
+        ]
+
     if apply_postprocessing:
         hydra_overrides_extra = hydra_overrides_extra.copy()
         hydra_overrides_extra += [

View File

@@ -36,7 +36,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -47,7 +47,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1
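The `feat_sizes` bump from `[32, 32]` to `[64, 64]` in this config (and the identical config diffs below) tracks the comment change on `RoPEAttention` later in this PR: the RoPE table is now precomputed for the stride-16 feature map of the default 1024x1024 input. A quick sketch of that arithmetic:

```python
# Why the new default RoPE feat_sizes is [64, 64].
image_size = 1024  # default SAM 2 input resolution
stride = 16        # stride of the memory-attention feature map
feat_size = image_size // stride
print([feat_size, feat_size])  # [64, 64]; the old default assumed 512 // 16 = 32
```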

View File

@@ -40,7 +40,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -51,7 +51,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1

View File

@@ -39,7 +39,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -50,7 +50,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1

View File

@@ -39,7 +39,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -50,7 +50,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1

View File

@@ -97,7 +97,7 @@ trainer:
         self_attention:
           _target_: sam2.modeling.sam.transformer.RoPEAttention
           rope_theta: 10000.0
-          feat_sizes: [32, 32]
+          feat_sizes: [64, 64]
           embedding_dim: 256
           num_heads: 1
           downsample_rate: 1
@@ -108,7 +108,7 @@ trainer:
         cross_attention:
           _target_: sam2.modeling.sam.transformer.RoPEAttention
           rope_theta: 10000.0
-          feat_sizes: [32, 32]
+          feat_sizes: [64, 64]
           rope_k_repeat: True
           embedding_dim: 256
           num_heads: 1

View File

@@ -36,7 +36,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -47,7 +47,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1

View File

@@ -40,7 +40,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -51,7 +51,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1

View File

@@ -39,7 +39,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -50,7 +50,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1

View File

@@ -39,7 +39,7 @@ model:
       self_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         embedding_dim: 256
         num_heads: 1
         downsample_rate: 1
@@ -50,7 +50,7 @@ model:
       cross_attention:
         _target_: sam2.modeling.sam.transformer.RoPEAttention
         rope_theta: 10000.0
-        feat_sizes: [32, 32]
+        feat_sizes: [64, 64]
         rope_k_repeat: True
         embedding_dim: 256
         num_heads: 1

View File

@@ -32,9 +32,7 @@ def window_partition(x, window_size):
     Hp, Wp = H + pad_h, W + pad_w
     x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
-    windows = (
-        x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
-    )
+    windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
     return windows, (Hp, Wp)
@@ -52,13 +50,13 @@ def window_unpartition(windows, window_size, pad_hw, hw):
     Hp, Wp = pad_hw
     H, W = hw
     B = windows.shape[0] // (Hp * Wp // window_size // window_size)
-    x = windows.view(
+    x = windows.reshape(
         B, Hp // window_size, Wp // window_size, window_size, window_size, -1
     )
-    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)
     if Hp > H or Wp > W:
-        x = x[:, :H, :W, :].contiguous()
+        x = x[:, :H, :W, :]
     return x
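The windowing helpers above now use `reshape` instead of `.contiguous().view(...)`. A small standalone illustration (not SAM 2 code) of why that substitution is safe: `reshape` accepts non-contiguous inputs and copies only when it has to, whereas `view` requires a compatible memory layout.

```python
import torch

x = torch.randn(2, 4, 6).permute(0, 2, 1)  # non-contiguous after permute
y = x.reshape(-1, 4)                        # works; copies only if needed
try:
    x.view(-1, 4)                           # view needs a compatible layout
except RuntimeError as e:
    print("view failed:", e)
```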

View File

@@ -25,6 +25,11 @@ class PositionEmbeddingSine(nn.Module):
         temperature: int = 10000,
         normalize: bool = True,
         scale: Optional[float] = None,
+        # Following settings only relevant
+        # for warming up cache for compilation
+        warmup_cache: bool = True,
+        image_size: int = 1024,
+        strides: Tuple[int] = (4, 8, 16, 32),
     ):
         super().__init__()
         assert num_pos_feats % 2 == 0, "Expecting even model width"
@@ -38,6 +43,12 @@ class PositionEmbeddingSine(nn.Module):
         self.scale = scale
         self.cache = {}
+        if warmup_cache and torch.cuda.is_available():
+            # Warmup cache for cuda, to help with compilation
+            device = torch.device("cuda")
+            for stride in strides:
+                cache_key = (image_size // stride, image_size // stride)
+                self._pe(1, device, *cache_key)
     def _encode_xy(self, x, y):
         # The positions are expected to be normalized
@@ -76,19 +87,20 @@ class PositionEmbeddingSine(nn.Module):
         return pos
     @torch.no_grad()
-    def forward(self, x: torch.Tensor):
-        cache_key = (x.shape[-2], x.shape[-1])
+    def _pe(self, B, device, *cache_key):
+        H, W = cache_key
         if cache_key in self.cache:
-            return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
+            return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)
         y_embed = (
-            torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
+            torch.arange(1, H + 1, dtype=torch.float32, device=device)
             .view(1, -1, 1)
-            .repeat(x.shape[0], 1, x.shape[-1])
+            .repeat(B, 1, W)
         )
         x_embed = (
-            torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
+            torch.arange(1, W + 1, dtype=torch.float32, device=device)
            .view(1, 1, -1)
-            .repeat(x.shape[0], x.shape[-2], 1)
+            .repeat(B, H, 1)
         )
         if self.normalize:
@@ -96,7 +108,7 @@ class PositionEmbeddingSine(nn.Module):
         y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
         x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
-        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device)
         dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
         pos_x = x_embed[:, :, :, None] / dim_t
@@ -111,6 +123,12 @@ class PositionEmbeddingSine(nn.Module):
         self.cache[cache_key] = pos[0]
         return pos
+    @torch.no_grad()
+    def forward(self, x: torch.Tensor):
+        B = x.shape[0]
+        cache_key = (x.shape[-2], x.shape[-1])
+        return self._pe(B, x.device, *cache_key)
 class PositionEmbeddingRandom(nn.Module):
     """

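After this refactor, `forward` only reads its input's batch size, device, and spatial size and delegates to `_pe`, whose cache is pre-filled at construction for the default strides at a 1024 input. A usage sketch (the `num_pos_feats` value and tensor shapes below are assumptions for illustration, not taken from a particular config):

```python
import torch

from sam2.modeling.position_encoding import PositionEmbeddingSine

# Cache is warmed at construction for 1024 // {4, 8, 16, 32} = 256/128/64/32.
pe_layer = PositionEmbeddingSine(num_pos_feats=256)

feat = torch.zeros(2, 256, 64, 64)  # e.g. a stride-16 feature map at 1024 input
pos = pe_layer(feat)                # cache hit on key (64, 64), repeated to batch size 2
print(pos.shape)                    # torch.Size([2, 256, 64, 64])
```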
View File

@@ -92,12 +92,32 @@ class PromptEncoder(nn.Module):
         point_embedding = self.pe_layer.forward_with_coords(
             points, self.input_image_size
         )
-        point_embedding[labels == -1] = 0.0
-        point_embedding[labels == -1] += self.not_a_point_embed.weight
-        point_embedding[labels == 0] += self.point_embeddings[0].weight
-        point_embedding[labels == 1] += self.point_embeddings[1].weight
-        point_embedding[labels == 2] += self.point_embeddings[2].weight
-        point_embedding[labels == 3] += self.point_embeddings[3].weight
+        point_embedding = torch.where(
+            (labels == -1).unsqueeze(-1),
+            torch.zeros_like(point_embedding) + self.not_a_point_embed.weight,
+            point_embedding,
+        )
+        point_embedding = torch.where(
+            (labels == 0).unsqueeze(-1),
+            point_embedding + self.point_embeddings[0].weight,
+            point_embedding,
+        )
+        point_embedding = torch.where(
+            (labels == 1).unsqueeze(-1),
+            point_embedding + self.point_embeddings[1].weight,
+            point_embedding,
+        )
+        point_embedding = torch.where(
+            (labels == 2).unsqueeze(-1),
+            point_embedding + self.point_embeddings[2].weight,
+            point_embedding,
+        )
+        point_embedding = torch.where(
+            (labels == 3).unsqueeze(-1),
+            point_embedding + self.point_embeddings[3].weight,
+            point_embedding,
+        )
         return point_embedding
     def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
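Replacing the mask-indexed in-place updates with `torch.where` keeps the point embedding purely functional, presumably so `sam_prompt_encoder.forward` can be compiled with `fullgraph=True` in `SAM2VideoPredictorVOS` further down. A toy illustration (shapes invented) that the two formulations agree:

```python
import torch

emb = torch.randn(2, 3, 4)                      # (batch, num_points, embed_dim)
labels = torch.tensor([[1, 0, -1], [0, 1, 1]])  # point labels per prompt
w = torch.randn(4)                              # one learned point-type embedding

# In-place, mask-indexed version (data-dependent, awkward under torch.compile):
ref = emb.clone()
ref[labels == 1] += w

# Functional version used in this PR:
out = torch.where((labels == 1).unsqueeze(-1), emb + w, emb)

assert torch.allclose(out, ref)
```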

View File

@@ -4,9 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-import contextlib
 import math
-import warnings
 from functools import partial
 from typing import Tuple, Type
@@ -16,29 +14,6 @@ from torch import nn, Tensor
 from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
 from sam2.modeling.sam2_utils import MLP
-from sam2.utils.misc import get_sdpa_settings
-warnings.simplefilter(action="ignore", category=FutureWarning)
-# Check whether Flash Attention is available (and use it by default)
-OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
-# A fallback setting to allow all available kernels if Flash Attention fails
-ALLOW_ALL_KERNELS = False
-def sdp_kernel_context(dropout_p):
-    """
-    Get the context for the attention scaled dot-product kernel. We use Flash Attention
-    by default, but fall back to all available kernels if Flash Attention fails.
-    """
-    if ALLOW_ALL_KERNELS:
-        return contextlib.nullcontext()
-    return torch.backends.cuda.sdp_kernel(
-        enable_flash=USE_FLASH_ATTN,
-        # if Flash attention kernel is off, then math kernel needs to be enabled
-        enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
-        enable_mem_efficient=OLD_GPU,
-    )
 class TwoWayTransformer(nn.Module):
@@ -265,19 +240,6 @@ class Attention(nn.Module):
         dropout_p = self.dropout_p if self.training else 0.0
         # Attention
-        try:
-            with sdp_kernel_context(dropout_p):
-                out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
-        except Exception as e:
-            # Fall back to all kernels if the Flash attention kernel fails
-            warnings.warn(
-                f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
-                f"kernels for scaled_dot_product_attention (which may have a slower speed).",
-                category=UserWarning,
-                stacklevel=2,
-            )
-            global ALLOW_ALL_KERNELS
-            ALLOW_ALL_KERNELS = True
-            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
         out = self._recombine_heads(out)
@@ -296,7 +258,7 @@ class RoPEAttention(Attention):
         # whether to repeat q rope to match k length
         # this is needed for cross-attention to memories
         rope_k_repeat=False,
-        feat_sizes=(32, 32),  # [w, h] for stride 16 feats at 512 resolution
+        feat_sizes=(64, 64),  # [w, h] for stride 16 feats at 1024 resolution
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
@@ -305,7 +267,9 @@ class RoPEAttention(Attention):
             compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
         )
         freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
-        self.freqs_cis = freqs_cis
+        self.freqs_cis = (
+            freqs_cis.to("cuda") if torch.cuda.is_available() else freqs_cis
+        )
         self.rope_k_repeat = rope_k_repeat
     def forward(
@@ -339,19 +303,6 @@ class RoPEAttention(Attention):
         dropout_p = self.dropout_p if self.training else 0.0
         # Attention
-        try:
-            with sdp_kernel_context(dropout_p):
-                out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
-        except Exception as e:
-            # Fall back to all kernels if the Flash attention kernel fails
-            warnings.warn(
-                f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
-                f"kernels for scaled_dot_product_attention (which may have a slower speed).",
-                category=UserWarning,
-                stacklevel=2,
-            )
-            global ALLOW_ALL_KERNELS
-            ALLOW_ALL_KERNELS = True
-            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
         out = self._recombine_heads(out)
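With the custom kernel-selection context and its fallback machinery removed, attention reduces to a single `F.scaled_dot_product_attention` call, and PyTorch picks the backend (flash, memory-efficient, or math) automatically. A minimal standalone call for reference (toy shapes):

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 64, 32)  # (batch, heads, seq_len, head_dim)
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
print(out.shape)  # torch.Size([1, 8, 64, 32])
```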

View File

@@ -628,7 +628,11 @@ class SAM2Base(torch.nn.Module):
                 if self.add_tpos_enc_to_obj_ptrs:
                     t_diff_max = max_obj_ptrs_in_encoder - 1
                     tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
-                    obj_pos = torch.tensor(pos_list, device=device)
+                    obj_pos = (
+                        torch.tensor(pos_list)
+                        .pin_memory()
+                        .to(device=device, non_blocking=True)
+                    )
                     obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
                     obj_pos = self.obj_ptr_tpos_proj(obj_pos)
                     obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
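The object-pointer positions are now staged through pinned host memory so the host-to-device copy can overlap with GPU work. The general pattern outside of SAM 2, as a sketch (pinning requires a CUDA-capable build):

```python
import torch

pos_list = [0, 1, 2, 3]
if torch.cuda.is_available():
    # Page-locked host tensor; non_blocking=True lets the copy run
    # asynchronously with respect to the current CUDA stream.
    obj_pos = torch.tensor(pos_list).pin_memory().to(device="cuda", non_blocking=True)
else:
    obj_pos = torch.tensor(pos_list)
print(obj_pos.device)
```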

View File

@@ -8,6 +8,7 @@ import warnings
 from collections import OrderedDict

 import torch
+import torch.nn.functional as F
 from tqdm import tqdm
@@ -1170,3 +1171,255 @@ class SAM2VideoPredictor(SAM2Base):
            non_cond_frame_outputs.pop(t, None)
        for obj_output_dict in inference_state["output_dict_per_obj"].values():
            obj_output_dict["non_cond_frame_outputs"].pop(t, None)


class SAM2VideoPredictorVOS(SAM2VideoPredictor):
    """Optimized for the VOS setting"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._compile_all_components()

    def _compile_all_components(self):
        print(
            "Compiling all components for vos setting. First time may be very slow."
        )
        self.memory_encoder.forward = torch.compile(
            self.memory_encoder.forward,
            mode="max-autotune",
            fullgraph=True,
            dynamic=False,
        )

        self.memory_attention.forward = torch.compile(
            self.memory_attention.forward,
            mode="max-autotune",
            fullgraph=True,
            dynamic=True,  # Num. of memories varies
        )

        self.sam_prompt_encoder.forward = torch.compile(
            self.sam_prompt_encoder.forward,
            mode="max-autotune",
            fullgraph=True,
            dynamic=False,  # Accuracy regression on True
        )

        self.sam_mask_decoder.forward = torch.compile(
            self.sam_mask_decoder.forward,
            mode="max-autotune",
            fullgraph=True,
            dynamic=False,  # Accuracy regression on True
        )

    def forward_image(self, img_batch: torch.Tensor):
        """
        Identical to the corresponding method in the parent (SAM2VideoPredictor), but
        cloning the backbone features and pos encoding to enable compilation.
        """
        backbone_out = self.image_encoder(img_batch)
        if self.use_high_res_features_in_sam:
            # precompute projected level 0 and level 1 features in SAM decoder
            # to avoid running it again on every SAM click
            backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
                backbone_out["backbone_fpn"][0]
            )
            backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
                backbone_out["backbone_fpn"][1]
            )
        # Clone to help torch.compile
        for i in range(len(backbone_out["backbone_fpn"])):
            backbone_out["backbone_fpn"][i] = backbone_out["backbone_fpn"][i].clone()
            backbone_out["vision_pos_enc"][i] = backbone_out["vision_pos_enc"][
                i
            ].clone()
        return backbone_out

    def _forward_sam_heads(
        self,
        backbone_features,
        point_inputs=None,
        mask_inputs=None,
        high_res_features=None,
        multimask_output=False,
    ):
        """
        Identical to the corresponding method in the parent (SAM2VideoPredictor), but
        cloning the outputs of prompt_encoder and mask_decoder to enable compilation.
        """
        B = backbone_features.size(0)
        device = backbone_features.device
        assert backbone_features.size(1) == self.sam_prompt_embed_dim
        assert backbone_features.size(2) == self.sam_image_embedding_size
        assert backbone_features.size(3) == self.sam_image_embedding_size

        # a) Handle point prompts
        if point_inputs is not None:
            sam_point_coords = point_inputs["point_coords"]
            sam_point_labels = point_inputs["point_labels"]
            assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
        else:
            # If no points are provided, pad with an empty point (with label -1)
            sam_point_coords = torch.zeros(B, 1, 2, device=device)
            sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)

        # b) Handle mask prompts
        if mask_inputs is not None:
            # If mask_inputs is provided, downsize it into low-res mask input if needed
            # and feed it as a dense mask prompt into the SAM mask encoder
            assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
            if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
                sam_mask_prompt = F.interpolate(
                    mask_inputs.float(),
                    size=self.sam_prompt_encoder.mask_input_size,
                    align_corners=False,
                    mode="bilinear",
                    antialias=True,  # use antialias for downsampling
                )
            else:
                sam_mask_prompt = mask_inputs
        else:
            # Otherwise, simply feed None (and SAM's prompt encoder will add
            # a learned `no_mask_embed` to indicate no mask input in this case).
            sam_mask_prompt = None

        sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
            points=(sam_point_coords, sam_point_labels),
            boxes=None,
            masks=sam_mask_prompt,
        )
        # Clone image_pe and the outputs of sam_prompt_encoder
        # to enable compilation
        sparse_embeddings = sparse_embeddings.clone()
        dense_embeddings = dense_embeddings.clone()
        image_pe = self.sam_prompt_encoder.get_dense_pe().clone()
        (
            low_res_multimasks,
            ious,
            sam_output_tokens,
            object_score_logits,
        ) = self.sam_mask_decoder(
            image_embeddings=backbone_features,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            repeat_image=False,  # the image is already batched
            high_res_features=high_res_features,
        )
        # Clone the output of sam_mask_decoder
        # to enable compilation
        low_res_multimasks = low_res_multimasks.clone()
        ious = ious.clone()
        sam_output_tokens = sam_output_tokens.clone()
        object_score_logits = object_score_logits.clone()

        if self.pred_obj_scores:
            is_obj_appearing = object_score_logits > 0

            # Mask used for spatial memories is always a *hard* choice between obj and no obj,
            # consistent with the actual mask prediction
            low_res_multimasks = torch.where(
                is_obj_appearing[:, None, None],
                low_res_multimasks,
                NO_OBJ_SCORE,
            )

        # convert masks from possibly bfloat16 (or float16) to float32
        # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
        low_res_multimasks = low_res_multimasks.float()
        high_res_multimasks = F.interpolate(
            low_res_multimasks,
            size=(self.image_size, self.image_size),
            mode="bilinear",
            align_corners=False,
        )

        sam_output_token = sam_output_tokens[:, 0]
        if multimask_output:
            # take the best mask prediction (with the highest IoU estimation)
            best_iou_inds = torch.argmax(ious, dim=-1)
            batch_inds = torch.arange(B, device=device)
            low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            if sam_output_tokens.size(1) > 1:
                sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
        else:
            low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks

        # Extract object pointer from the SAM output token (with occlusion handling)
        obj_ptr = self.obj_ptr_proj(sam_output_token)
        if self.pred_obj_scores:
            # Allow *soft* no obj ptr, unlike for masks
            if self.soft_no_obj_ptr:
                lambda_is_obj_appearing = object_score_logits.sigmoid()
            else:
                lambda_is_obj_appearing = is_obj_appearing.float()

            if self.fixed_no_obj_ptr:
                obj_ptr = lambda_is_obj_appearing * obj_ptr
            obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr

        return (
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        )

    def _encode_new_memory(
        self,
        current_vision_feats,
        feat_sizes,
        pred_masks_high_res,
        object_score_logits,
        is_mask_from_pts,
    ):
        """
        Identical to the corresponding method in the parent (SAM2VideoPredictor), but
        cloning the memories and their pos enc to enable compilation.
        """
        B = current_vision_feats[-1].size(1)  # batch size on this frame
        C = self.hidden_dim
        H, W = feat_sizes[-1]  # top-level (lowest-resolution) feature size
        # top-level feature, (HW)BC => BCHW
        pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
        if self.non_overlap_masks_for_mem_enc and not self.training:
            # optionally, apply non-overlapping constraints to the masks (it's applied
            # in the batch dimension and should only be used during eval, where all
            # the objects come from the same video under batch size 1).
            pred_masks_high_res = self._apply_non_overlapping_constraints(
                pred_masks_high_res
            )
        # scale the raw mask logits with a temperature before applying sigmoid
        binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
        if binarize and not self.training:
            mask_for_mem = (pred_masks_high_res > 0).float()
        else:
            # apply sigmoid on the raw mask logits to turn them into range (0, 1)
            mask_for_mem = torch.sigmoid(pred_masks_high_res)
        # apply scale and bias terms to the sigmoid probabilities
        if self.sigmoid_scale_for_mem_enc != 1.0:
            mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
        if self.sigmoid_bias_for_mem_enc != 0.0:
            mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc

        maskmem_out = self.memory_encoder(
            pix_feat, mask_for_mem, skip_mask_sigmoid=True  # sigmoid already applied
        )
        # Clone the feats and pos_enc to enable compilation
        maskmem_features = maskmem_out["vision_features"].clone()
        maskmem_pos_enc = [m.clone() for m in maskmem_out["vision_pos_enc"]]
        # add a no-object embedding to the spatial memory to indicate that the frame
        # is predicted to be occluded (i.e. no object is appearing in the frame)
        if self.no_obj_embed_spatial is not None:
            is_obj_appearing = (object_score_logits > 0).float()
            maskmem_features += (
                1 - is_obj_appearing[..., None, None]
            ) * self.no_obj_embed_spatial[..., None, None].expand(
                *maskmem_features.shape
            )

        return maskmem_features, maskmem_pos_enc

View File

@@ -375,7 +375,7 @@ def main():
     parser.add_argument(
         "--sam2_checkpoint",
         type=str,
-        default="./checkpoints/sam2.1_hiera_b+.pt",
+        default="./checkpoints/sam2.1_hiera_base_plus.pt",
         help="path to the SAM 2 model checkpoint",
     )
     parser.add_argument(
@@ -434,6 +434,11 @@ def main():
         help="whether to track objects that appear later in the video (i.e. not on the first frame; "
         "some VOS datasets like LVOS or YouTube-VOS don't have all objects appearing in the first frame)",
     )
+    parser.add_argument(
+        "--use_vos_optimized_video_predictor",
+        action="store_true",
+        help="whether to use vos optimized video predictor with all modules compiled",
+    )
     args = parser.parse_args()

     # if we use per-object PNG files, they could possibly overlap in inputs and outputs
@@ -445,6 +450,7 @@ def main():
         ckpt_path=args.sam2_checkpoint,
         apply_postprocessing=args.apply_postprocessing,
         hydra_overrides_extra=hydra_overrides_extra,
+        vos_optimized=args.use_vos_optimized_video_predictor,
     )
     if args.use_all_masks: