From 9851575bf3ee5d310f10c1edf9625515d003a280 Mon Sep 17 00:00:00 2001 From: Chay Ryali Date: Wed, 11 Dec 2024 03:02:28 +0000 Subject: [PATCH 1/4] speed optimizations cleanup --- README.md | 19 +- sam2/benchmark.py | 86 ++++++++ sam2/build_sam.py | 7 + sam2/configs/sam2.1/sam2.1_hiera_b+.yaml | 4 +- sam2/configs/sam2.1/sam2.1_hiera_l.yaml | 4 +- sam2/configs/sam2.1/sam2.1_hiera_s.yaml | 4 +- sam2/configs/sam2.1/sam2.1_hiera_t.yaml | 4 +- sam2/configs/sam2/sam2_hiera_b+.yaml | 4 +- sam2/configs/sam2/sam2_hiera_l.yaml | 4 +- sam2/configs/sam2/sam2_hiera_s.yaml | 4 +- sam2/configs/sam2/sam2_hiera_t.yaml | 4 +- sam2/modeling/backbones/utils.py | 10 +- sam2/modeling/position_encoding.py | 35 ++- sam2/modeling/sam/prompt_encoder.py | 32 ++- sam2/modeling/sam/transformer.py | 61 +----- sam2/modeling/sam2_base.py | 6 +- sam2/sam2_video_predictor.py | 260 +++++++++++++++++++++++ tools/vos_inference.py | 8 +- 18 files changed, 453 insertions(+), 103 deletions(-) create mode 100644 sam2/benchmark.py diff --git a/README.md b/README.md index 65654f5..89d1da5 100644 --- a/README.md +++ b/README.md @@ -158,10 +158,10 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): The table below shows the improved SAM 2.1 checkpoints released on September 29, 2024. | **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** | | :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: | -| sam2.1_hiera_tiny
([config](sam2/configs/sam2.1/sam2.1_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt)) | 38.9 | 47.2 | 76.5 | 71.8 | 77.3 |
-| sam2.1_hiera_small <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt)) | 46 | 43.3 (53.0 compiled\*) | 76.6 | 73.5 | 78.3 |
-| sam2.1_hiera_base_plus <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt)) | 80.8 | 34.8 (43.8 compiled\*) | 78.2 | 73.7 | 78.2 |
-| sam2.1_hiera_large <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt)) | 224.4 | 24.2 (30.2 compiled\*) | 79.5 | 74.6 | 80.6 |
+| sam2.1_hiera_tiny <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt)) | 38.9 | 91.2 | 76.5 | 71.8 | 77.3 |
+| sam2.1_hiera_small <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt)) | 46 | 84.8 | 76.6 | 73.5 | 78.3 |
+| sam2.1_hiera_base_plus <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt)) | 80.8 | 64.1 | 78.2 | 73.7 | 78.2 |
+| sam2.1_hiera_large <br />
([config](sam2/configs/sam2.1/sam2.1_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt)) | 224.4 | 39.5 | 79.5 | 74.6 | 80.6 | ### SAM 2 checkpoints @@ -169,13 +169,12 @@ The previous SAM 2 checkpoints released on July 29, 2024 can be found as follows | **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** | | :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: | -| sam2_hiera_tiny
([config](sam2/configs/sam2/sam2_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt)) | 38.9 | 47.2 | 75.0 | 70.9 | 75.3 |
-| sam2_hiera_small <br /> ([config](sam2/configs/sam2/sam2_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt)) | 46 | 43.3 (53.0 compiled\*) | 74.9 | 71.5 | 76.4 |
-| sam2_hiera_base_plus <br /> ([config](sam2/configs/sam2/sam2_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt)) | 80.8 | 34.8 (43.8 compiled\*) | 74.7 | 72.8 | 75.8 |
-| sam2_hiera_large <br /> ([config](sam2/configs/sam2/sam2_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt)) | 224.4 | 24.2 (30.2 compiled\*) | 76.0 | 74.6 | 79.8 |
-
-\* Compile the model by setting `compile_image_encoder: True` in the config.
+| sam2_hiera_tiny <br /> ([config](sam2/configs/sam2/sam2_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt)) | 38.9 | 91.5 | 75.0 | 70.9 | 75.3 |
+| sam2_hiera_small <br /> ([config](sam2/configs/sam2/sam2_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt)) | 46 | 85.6 | 74.9 | 71.5 | 76.4 |
+| sam2_hiera_base_plus <br /> ([config](sam2/configs/sam2/sam2_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt)) | 80.8 | 64.8 | 74.7 | 72.8 | 75.8 |
+| sam2_hiera_large <br />
([config](sam2/configs/sam2/sam2_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt)) | 224.4 | 39.7 | 76.0 | 74.6 | 79.8 | +Speed measured on an A100 with `torch 2.5.1, cuda 12.4`. See `benchmark.py` for an example on benchmarking (compiling all the model components). Compiling only the image encoder can be more flexible and also provide (a smaller) speed-up (set `compile_image_encoder: True` in the config). ## Segment Anything Video Dataset See [sav_dataset/README.md](sav_dataset/README.md) for details. diff --git a/sam2/benchmark.py b/sam2/benchmark.py new file mode 100644 index 0000000..5b25d7a --- /dev/null +++ b/sam2/benchmark.py @@ -0,0 +1,86 @@ +import os +import time + +import numpy as np +import torch +from tqdm import tqdm + +from sam2.build_sam import build_sam2_video_predictor + +# Only cuda supported +assert torch.cuda.is_available() +device = torch.device("cuda") + +torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() +if torch.cuda.get_device_properties(0).major >= 8: + # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + +# Config and checkpoint +sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt" +model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml" + +# Build video predictor with vos_optimized=True setting +predictor = build_sam2_video_predictor( + model_cfg, sam2_checkpoint, device=device, vos_optimized=True +) + + +# Initialize with video +video_dir = "notebooks/videos/bedroom" +# scan all the JPEG frame names in this directory +frame_names = [ + p + for p in os.listdir(video_dir) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] +] +frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) +inference_state = predictor.init_state(video_path=video_dir) + + +# Number of runs, warmup etc +warm_up, runs = 5, 25 +verbose = True +num_frames = len(frame_names) +total, count = 0, 0 +torch.cuda.empty_cache() + +# We will select an object with a click. 
+# See video_predictor_example.ipynb for more detailed explanation +ann_frame_idx, ann_obj_id = 0, 1 +# Add a positive click at (x, y) = (210, 350) +# For labels, `1` means positive click +points = np.array([[210, 350]], dtype=np.float32) +labels = np.array([1], np.int32) + +_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box( + inference_state=inference_state, + frame_idx=ann_frame_idx, + obj_id=ann_obj_id, + points=points, + labels=labels, +) + +# Warmup and then average FPS over several runs +with torch.autocast("cuda", torch.bfloat16): + with torch.inference_mode(): + for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"): + start = time.time() + # Start tracking + for ( + out_frame_idx, + out_obj_ids, + out_mask_logits, + ) in predictor.propagate_in_video(inference_state): + pass + + end = time.time() + total += end - start + count += 1 + if i == warm_up - 1: + print("Warmup FPS: ", count * num_frames / total) + total = 0 + count = 0 + +print("FPS: ", count * num_frames / total) diff --git a/sam2/build_sam.py b/sam2/build_sam.py index 7cfc451..3a3bef1 100644 --- a/sam2/build_sam.py +++ b/sam2/build_sam.py @@ -104,11 +104,18 @@ def build_sam2_video_predictor( mode="eval", hydra_overrides_extra=[], apply_postprocessing=True, + vos_optimized=False, **kwargs, ): hydra_overrides = [ "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", ] + if vos_optimized: + hydra_overrides = [ + "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictorVOS", + "++model.compile_image_encoder=True", # Let sam2_base handle this + ] + if apply_postprocessing: hydra_overrides_extra = hydra_overrides_extra.copy() hydra_overrides_extra += [ diff --git a/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml b/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml index cbee3cf..d7172f9 100644 --- a/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +++ b/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml @@ -36,7 +36,7 @@ model: self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 @@ -47,7 +47,7 @@ model: cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 diff --git a/sam2/configs/sam2.1/sam2.1_hiera_l.yaml b/sam2/configs/sam2.1/sam2.1_hiera_l.yaml index 33c9097..23073ea 100644 --- a/sam2/configs/sam2.1/sam2.1_hiera_l.yaml +++ b/sam2/configs/sam2.1/sam2.1_hiera_l.yaml @@ -40,7 +40,7 @@ model: self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 @@ -51,7 +51,7 @@ model: cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 diff --git a/sam2/configs/sam2.1/sam2.1_hiera_s.yaml b/sam2/configs/sam2.1/sam2.1_hiera_s.yaml index 8e803df..fd8d404 100644 --- a/sam2/configs/sam2.1/sam2.1_hiera_s.yaml +++ b/sam2/configs/sam2.1/sam2.1_hiera_s.yaml @@ -39,7 +39,7 @@ model: self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 @@ -50,7 +50,7 @@ model: cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + 
feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 diff --git a/sam2/configs/sam2.1/sam2.1_hiera_t.yaml b/sam2/configs/sam2.1/sam2.1_hiera_t.yaml index 983c2ea..e762aec 100644 --- a/sam2/configs/sam2.1/sam2.1_hiera_t.yaml +++ b/sam2/configs/sam2.1/sam2.1_hiera_t.yaml @@ -39,7 +39,7 @@ model: self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 @@ -50,7 +50,7 @@ model: cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 diff --git a/sam2/configs/sam2/sam2_hiera_b+.yaml b/sam2/configs/sam2/sam2_hiera_b+.yaml index 58f3eb8..0f435af 100644 --- a/sam2/configs/sam2/sam2_hiera_b+.yaml +++ b/sam2/configs/sam2/sam2_hiera_b+.yaml @@ -36,7 +36,7 @@ model: self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 @@ -47,7 +47,7 @@ model: cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 diff --git a/sam2/configs/sam2/sam2_hiera_l.yaml b/sam2/configs/sam2/sam2_hiera_l.yaml index 918667f..1092802 100644 --- a/sam2/configs/sam2/sam2_hiera_l.yaml +++ b/sam2/configs/sam2/sam2_hiera_l.yaml @@ -40,7 +40,7 @@ model: self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 @@ -51,7 +51,7 @@ model: cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 diff --git a/sam2/configs/sam2/sam2_hiera_s.yaml b/sam2/configs/sam2/sam2_hiera_s.yaml index 26e5d4d..174e414 100644 --- a/sam2/configs/sam2/sam2_hiera_s.yaml +++ b/sam2/configs/sam2/sam2_hiera_s.yaml @@ -39,7 +39,7 @@ model: self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 @@ -50,7 +50,7 @@ model: cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 diff --git a/sam2/configs/sam2/sam2_hiera_t.yaml b/sam2/configs/sam2/sam2_hiera_t.yaml index a62c903..121447a 100644 --- a/sam2/configs/sam2/sam2_hiera_t.yaml +++ b/sam2/configs/sam2/sam2_hiera_t.yaml @@ -39,7 +39,7 @@ model: self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 @@ -50,7 +50,7 @@ model: cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1 diff --git a/sam2/modeling/backbones/utils.py b/sam2/modeling/backbones/utils.py index 32d55c7..930b1b7 100644 --- a/sam2/modeling/backbones/utils.py +++ b/sam2/modeling/backbones/utils.py @@ -32,9 +32,7 @@ def window_partition(x, window_size): Hp, Wp = H + pad_h, W + pad_w x = x.view(B, Hp // window_size, window_size, Wp // 
window_size, window_size, C) - windows = ( - x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - ) + windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C) return windows, (Hp, Wp) @@ -52,13 +50,13 @@ def window_unpartition(windows, window_size, pad_hw, hw): Hp, Wp = pad_hw H, W = hw B = windows.shape[0] // (Hp * Wp // window_size // window_size) - x = windows.view( + x = windows.reshape( B, Hp // window_size, Wp // window_size, window_size, window_size, -1 ) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1) if Hp > H or Wp > W: - x = x[:, :H, :W, :].contiguous() + x = x[:, :H, :W, :] return x diff --git a/sam2/modeling/position_encoding.py b/sam2/modeling/position_encoding.py index 52ac226..f5993d3 100644 --- a/sam2/modeling/position_encoding.py +++ b/sam2/modeling/position_encoding.py @@ -25,6 +25,11 @@ class PositionEmbeddingSine(nn.Module): temperature: int = 10000, normalize: bool = True, scale: Optional[float] = None, + # Following settings only relevant + # for warmping up cache for compilation + warmup_cache: bool = True, + image_size: int = 1024, + strides: Tuple[int] = (4, 8, 16, 32), ): super().__init__() assert num_pos_feats % 2 == 0, "Expecting even model width" @@ -38,6 +43,12 @@ class PositionEmbeddingSine(nn.Module): self.scale = scale self.cache = {} + if warmup_cache and torch.cuda.is_available(): + # Warmup cache for cuda, to help with compilation + device = torch.device("cuda") + for stride in strides: + cache_key = (image_size // stride, image_size // stride, device) + self._pe(1, *cache_key) def _encode_xy(self, x, y): # The positions are expected to be normalized @@ -76,19 +87,20 @@ class PositionEmbeddingSine(nn.Module): return pos @torch.no_grad() - def forward(self, x: torch.Tensor): - cache_key = (x.shape[-2], x.shape[-1]) + def _pe(self, B, *cache_key): + H, W, device = cache_key if cache_key in self.cache: - return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) + return self.cache[cache_key][None].repeat(B, 1, 1, 1) + y_embed = ( - torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device) + torch.arange(1, H + 1, dtype=torch.float32, device=device) .view(1, -1, 1) - .repeat(x.shape[0], 1, x.shape[-1]) + .repeat(B, 1, W) ) x_embed = ( - torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device) + torch.arange(1, W + 1, dtype=torch.float32, device=device) .view(1, 1, -1) - .repeat(x.shape[0], x.shape[-2], 1) + .repeat(B, H, 1) ) if self.normalize: @@ -96,7 +108,7 @@ class PositionEmbeddingSine(nn.Module): y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale - dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device) dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) pos_x = x_embed[:, :, :, None] / dim_t @@ -111,6 +123,13 @@ class PositionEmbeddingSine(nn.Module): self.cache[cache_key] = pos[0] return pos + @torch.no_grad() + def forward(self, x: torch.Tensor): + device = torch.device("cuda") if x.is_cuda else x.device + B = x.shape[0] + cache_key = (x.shape[-2], x.shape[-1], device) + return self._pe(B, *cache_key) + class PositionEmbeddingRandom(nn.Module): """ diff --git a/sam2/modeling/sam/prompt_encoder.py b/sam2/modeling/sam/prompt_encoder.py index 6b3bbb9..c578762 100644 --- a/sam2/modeling/sam/prompt_encoder.py +++ 
b/sam2/modeling/sam/prompt_encoder.py @@ -92,12 +92,32 @@ class PromptEncoder(nn.Module): point_embedding = self.pe_layer.forward_with_coords( points, self.input_image_size ) - point_embedding[labels == -1] = 0.0 - point_embedding[labels == -1] += self.not_a_point_embed.weight - point_embedding[labels == 0] += self.point_embeddings[0].weight - point_embedding[labels == 1] += self.point_embeddings[1].weight - point_embedding[labels == 2] += self.point_embeddings[2].weight - point_embedding[labels == 3] += self.point_embeddings[3].weight + + point_embedding = torch.where( + (labels == -1).unsqueeze(-1), + torch.zeros_like(point_embedding) + self.not_a_point_embed.weight, + point_embedding, + ) + point_embedding = torch.where( + (labels == 0).unsqueeze(-1), + point_embedding + self.point_embeddings[0].weight, + point_embedding, + ) + point_embedding = torch.where( + (labels == 1).unsqueeze(-1), + point_embedding + self.point_embeddings[1].weight, + point_embedding, + ) + point_embedding = torch.where( + (labels == 2).unsqueeze(-1), + point_embedding + self.point_embeddings[2].weight, + point_embedding, + ) + point_embedding = torch.where( + (labels == 3).unsqueeze(-1), + point_embedding + self.point_embeddings[3].weight, + point_embedding, + ) return point_embedding def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: diff --git a/sam2/modeling/sam/transformer.py b/sam2/modeling/sam/transformer.py index b5b6fa2..f9fe9a3 100644 --- a/sam2/modeling/sam/transformer.py +++ b/sam2/modeling/sam/transformer.py @@ -4,9 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -import contextlib import math -import warnings from functools import partial from typing import Tuple, Type @@ -16,29 +14,6 @@ from torch import nn, Tensor from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis from sam2.modeling.sam2_utils import MLP -from sam2.utils.misc import get_sdpa_settings - -warnings.simplefilter(action="ignore", category=FutureWarning) -# Check whether Flash Attention is available (and use it by default) -OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings() -# A fallback setting to allow all available kernels if Flash Attention fails -ALLOW_ALL_KERNELS = False - - -def sdp_kernel_context(dropout_p): - """ - Get the context for the attention scaled dot-product kernel. We use Flash Attention - by default, but fall back to all available kernels if Flash Attention fails. 
- """ - if ALLOW_ALL_KERNELS: - return contextlib.nullcontext() - - return torch.backends.cuda.sdp_kernel( - enable_flash=USE_FLASH_ATTN, - # if Flash attention kernel is off, then math kernel needs to be enabled - enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON, - enable_mem_efficient=OLD_GPU, - ) class TwoWayTransformer(nn.Module): @@ -265,20 +240,7 @@ class Attention(nn.Module): dropout_p = self.dropout_p if self.training else 0.0 # Attention - try: - with sdp_kernel_context(dropout_p): - out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) - except Exception as e: - # Fall back to all kernels if the Flash attention kernel fails - warnings.warn( - f"Flash Attention kernel failed due to: {e}\nFalling back to all available " - f"kernels for scaled_dot_product_attention (which may have a slower speed).", - category=UserWarning, - stacklevel=2, - ) - global ALLOW_ALL_KERNELS - ALLOW_ALL_KERNELS = True - out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) out = self._recombine_heads(out) out = self.out_proj(out) @@ -296,7 +258,7 @@ class RoPEAttention(Attention): # whether to repeat q rope to match k length # this is needed for cross-attention to memories rope_k_repeat=False, - feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution + feat_sizes=(64, 64), # [w, h] for stride 16 feats at 1024 resolution **kwargs, ): super().__init__(*args, **kwargs) @@ -305,7 +267,9 @@ class RoPEAttention(Attention): compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta ) freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) - self.freqs_cis = freqs_cis + self.freqs_cis = ( + freqs_cis.to("cuda") if torch.cuda.is_available() else freqs_cis + ) self.rope_k_repeat = rope_k_repeat def forward( @@ -339,20 +303,7 @@ class RoPEAttention(Attention): dropout_p = self.dropout_p if self.training else 0.0 # Attention - try: - with sdp_kernel_context(dropout_p): - out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) - except Exception as e: - # Fall back to all kernels if the Flash attention kernel fails - warnings.warn( - f"Flash Attention kernel failed due to: {e}\nFalling back to all available " - f"kernels for scaled_dot_product_attention (which may have a slower speed).", - category=UserWarning, - stacklevel=2, - ) - global ALLOW_ALL_KERNELS - ALLOW_ALL_KERNELS = True - out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) out = self._recombine_heads(out) out = self.out_proj(out) diff --git a/sam2/modeling/sam2_base.py b/sam2/modeling/sam2_base.py index a5d243a..8aa1a0b 100644 --- a/sam2/modeling/sam2_base.py +++ b/sam2/modeling/sam2_base.py @@ -628,7 +628,11 @@ class SAM2Base(torch.nn.Module): if self.add_tpos_enc_to_obj_ptrs: t_diff_max = max_obj_ptrs_in_encoder - 1 tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim - obj_pos = torch.tensor(pos_list, device=device) + obj_pos = ( + torch.tensor(pos_list) + .pin_memory() + .to(device=device, non_blocking=True) + ) obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) obj_pos = self.obj_ptr_tpos_proj(obj_pos) obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim) diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py index c7e01cc..055cde5 100644 --- a/sam2/sam2_video_predictor.py +++ b/sam2/sam2_video_predictor.py @@ -8,6 +8,7 @@ import warnings from collections import 
OrderedDict import torch +import torch.nn.functional as F from tqdm import tqdm @@ -1170,3 +1171,262 @@ class SAM2VideoPredictor(SAM2Base): non_cond_frame_outputs.pop(t, None) for obj_output_dict in inference_state["output_dict_per_obj"].values(): obj_output_dict["non_cond_frame_outputs"].pop(t, None) + + +class SAM2VideoPredictorVOS(SAM2VideoPredictor): + """Optimized for the VOS setting""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._compile_all_components() + + def _compile_all_components(self): + print( + "Compiling all components for for vos setting. First time may be very slow." + ) + self.memory_encoder.forward = torch.compile( + self.memory_encoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=True, + ) + + self.memory_attention.forward = torch.compile( + self.memory_attention.forward, + mode="max-autotune", + fullgraph=True, + dynamic=True, + ) + + self.sam_prompt_encoder.get_dense_pe = torch.compile( + self.sam_prompt_encoder.get_dense_pe, + mode="max-autotune", + fullgraph=True, + dynamic=True, + ) + + self.sam_prompt_encoder.forward = torch.compile( + self.sam_prompt_encoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=True, + ) + + self.sam_mask_decoder.forward = torch.compile( + self.sam_mask_decoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=True, + ) + + def forward_image(self, img_batch: torch.Tensor): + """ + Identical to the corresponding method in the parent (SAM2VideoPredictor), but + cloning the backbone features and pos encoding to enable compilation. + """ + backbone_out = self.image_encoder(img_batch) + if self.use_high_res_features_in_sam: + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0( + backbone_out["backbone_fpn"][0] + ) + backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1( + backbone_out["backbone_fpn"][1] + ) + # Clone to help torch.compile + for i in range(len(backbone_out["backbone_fpn"])): + backbone_out["backbone_fpn"][i] = backbone_out["backbone_fpn"][i].clone() + backbone_out["vision_pos_enc"][i] = backbone_out["vision_pos_enc"][ + i + ].clone() + return backbone_out + + def _forward_sam_heads( + self, + backbone_features, + point_inputs=None, + mask_inputs=None, + high_res_features=None, + multimask_output=False, + ): + """ + Identical to the corresponding method in the parent (SAM2VideoPredictor), but + cloning the outputs of prompt_encoder and mask_decoder to enable compilation. 
+ """ + B = backbone_features.size(0) + device = backbone_features.device + assert backbone_features.size(1) == self.sam_prompt_embed_dim + assert backbone_features.size(2) == self.sam_image_embedding_size + assert backbone_features.size(3) == self.sam_image_embedding_size + + # a) Handle point prompts + if point_inputs is not None: + sam_point_coords = point_inputs["point_coords"] + sam_point_labels = point_inputs["point_labels"] + assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B + else: + # If no points are provide, pad with an empty point (with label -1) + sam_point_coords = torch.zeros(B, 1, 2, device=device) + sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) + + # b) Handle mask prompts + if mask_inputs is not None: + # If mask_inputs is provided, downsize it into low-res mask input if needed + # and feed it as a dense mask prompt into the SAM mask encoder + assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1) + if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size: + sam_mask_prompt = F.interpolate( + mask_inputs.float(), + size=self.sam_prompt_encoder.mask_input_size, + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + else: + sam_mask_prompt = mask_inputs + else: + # Otherwise, simply feed None (and SAM's prompt encoder will add + # a learned `no_mask_embed` to indicate no mask input in this case). + sam_mask_prompt = None + + sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( + points=(sam_point_coords, sam_point_labels), + boxes=None, + masks=sam_mask_prompt, + ) + # Clone image_pe and the outputs of sam_prompt_encoder + # to enable compilation + sparse_embeddings = sparse_embeddings.clone() + dense_embeddings = dense_embeddings.clone() + image_pe = self.sam_prompt_encoder.get_dense_pe().clone() + ( + low_res_multimasks, + ious, + sam_output_tokens, + object_score_logits, + ) = self.sam_mask_decoder( + image_embeddings=backbone_features, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=False, # the image is already batched + high_res_features=high_res_features, + ) + # Clone the output of sam_mask_decoder + # to enable compilation + low_res_multimasks = low_res_multimasks.clone() + ious = ious.clone() + sam_output_tokens = sam_output_tokens.clone() + object_score_logits = object_score_logits.clone() + + if self.pred_obj_scores: + is_obj_appearing = object_score_logits > 0 + + # Mask used for spatial memories is always a *hard* choice between obj and no obj, + # consistent with the actual mask prediction + low_res_multimasks = torch.where( + is_obj_appearing[:, None, None], + low_res_multimasks, + NO_OBJ_SCORE, + ) + + # convert masks from possibly bfloat16 (or float16) to float32 + # (older PyTorch versions before 2.1 don't support `interpolate` on bf16) + low_res_multimasks = low_res_multimasks.float() + high_res_multimasks = F.interpolate( + low_res_multimasks, + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + + sam_output_token = sam_output_tokens[:, 0] + if multimask_output: + # take the best mask prediction (with the highest IoU estimation) + best_iou_inds = torch.argmax(ious, dim=-1) + batch_inds = torch.arange(B, device=device) + low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + if 
sam_output_tokens.size(1) > 1: + sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] + else: + low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks + + # Extract object pointer from the SAM output token (with occlusion handling) + obj_ptr = self.obj_ptr_proj(sam_output_token) + if self.pred_obj_scores: + # Allow *soft* no obj ptr, unlike for masks + if self.soft_no_obj_ptr: + lambda_is_obj_appearing = object_score_logits.sigmoid() + else: + lambda_is_obj_appearing = is_obj_appearing.float() + + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def _encode_new_memory( + self, + current_vision_feats, + feat_sizes, + pred_masks_high_res, + object_score_logits, + is_mask_from_pts, + ): + """ + Identical to the corresponding method in the parent (SAM2VideoPredictor), but + cloning the memories and their pos enc to enable compilation. + """ + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + # top-level feature, (HW)BC => BCHW + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + if self.non_overlap_masks_for_mem_enc and not self.training: + # optionally, apply non-overlapping constraints to the masks (it's applied + # in the batch dimension and should only be used during eval, where all + # the objects come from the same video under batch size 1). + pred_masks_high_res = self._apply_non_overlapping_constraints( + pred_masks_high_res + ) + # scale the raw mask logits with a temperature before applying sigmoid + binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts + if binarize and not self.training: + mask_for_mem = (pred_masks_high_res > 0).float() + else: + # apply sigmoid on the raw mask logits to turn them into range (0, 1) + mask_for_mem = torch.sigmoid(pred_masks_high_res) + # apply scale and bias terms to the sigmoid probabilities + if self.sigmoid_scale_for_mem_enc != 1.0: + mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc + if self.sigmoid_bias_for_mem_enc != 0.0: + mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc + maskmem_out = self.memory_encoder( + pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied + ) + # Clone the feats and pos_enc to enable compilation + maskmem_features = maskmem_out["vision_features"].clone() + maskmem_pos_enc = [m.clone() for m in maskmem_out["vision_pos_enc"]] + # add a no-object embedding to the spatial memory to indicate that the frame + # is predicted to be occluded (i.e. 
no object is appearing in the frame) + if self.no_obj_embed_spatial is not None: + is_obj_appearing = (object_score_logits > 0).float() + maskmem_features += ( + 1 - is_obj_appearing[..., None, None] + ) * self.no_obj_embed_spatial[..., None, None].expand( + *maskmem_features.shape + ) + + return maskmem_features, maskmem_pos_enc diff --git a/tools/vos_inference.py b/tools/vos_inference.py index 5c40cda..ef3e8c6 100644 --- a/tools/vos_inference.py +++ b/tools/vos_inference.py @@ -375,7 +375,7 @@ def main(): parser.add_argument( "--sam2_checkpoint", type=str, - default="./checkpoints/sam2.1_hiera_b+.pt", + default="./checkpoints/sam2.1_hiera_base_plus.pt", help="path to the SAM 2 model checkpoint", ) parser.add_argument( @@ -434,6 +434,11 @@ def main(): help="whether to track objects that appear later in the video (i.e. not on the first frame; " "some VOS datasets like LVOS or YouTube-VOS don't have all objects appearing in the first frame)", ) + parser.add_argument( + "--use_vos_optimized_video_predictor", + action="store_true", + help="whether to use vos optimized video predictor with all modules compiled", + ) args = parser.parse_args() # if we use per-object PNG files, they could possibly overlap in inputs and outputs @@ -445,6 +450,7 @@ def main(): ckpt_path=args.sam2_checkpoint, apply_postprocessing=args.apply_postprocessing, hydra_overrides_extra=hydra_overrides_extra, + vos_optimized=args.use_vos_optimized_video_predictor, ) if args.use_all_masks: From 11f99b63c819822505977acb76052982d1a2baf2 Mon Sep 17 00:00:00 2001 From: Chay Ryali Date: Wed, 11 Dec 2024 06:29:00 +0000 Subject: [PATCH 2/4] move to sincos pos enc to device on fwd pass and remove dynamic causing accuracy regressions --- sam2/modeling/position_encoding.py | 15 +++++++-------- sam2/sam2_video_predictor.py | 15 ++++----------- 2 files changed, 11 insertions(+), 19 deletions(-) diff --git a/sam2/modeling/position_encoding.py b/sam2/modeling/position_encoding.py index f5993d3..2241d4c 100644 --- a/sam2/modeling/position_encoding.py +++ b/sam2/modeling/position_encoding.py @@ -47,8 +47,8 @@ class PositionEmbeddingSine(nn.Module): # Warmup cache for cuda, to help with compilation device = torch.device("cuda") for stride in strides: - cache_key = (image_size // stride, image_size // stride, device) - self._pe(1, *cache_key) + cache_key = (image_size // stride, image_size // stride) + self._pe(1, device, *cache_key) def _encode_xy(self, x, y): # The positions are expected to be normalized @@ -87,10 +87,10 @@ class PositionEmbeddingSine(nn.Module): return pos @torch.no_grad() - def _pe(self, B, *cache_key): - H, W, device = cache_key + def _pe(self, B, device, *cache_key): + H, W = cache_key if cache_key in self.cache: - return self.cache[cache_key][None].repeat(B, 1, 1, 1) + return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1) y_embed = ( torch.arange(1, H + 1, dtype=torch.float32, device=device) @@ -125,10 +125,9 @@ class PositionEmbeddingSine(nn.Module): @torch.no_grad() def forward(self, x: torch.Tensor): - device = torch.device("cuda") if x.is_cuda else x.device B = x.shape[0] - cache_key = (x.shape[-2], x.shape[-1], device) - return self._pe(B, *cache_key) + cache_key = (x.shape[-2], x.shape[-1]) + return self._pe(B, x.device, *cache_key) class PositionEmbeddingRandom(nn.Module): diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py index 055cde5..62f45e1 100644 --- a/sam2/sam2_video_predictor.py +++ b/sam2/sam2_video_predictor.py @@ -1188,35 +1188,28 @@ class 
SAM2VideoPredictorVOS(SAM2VideoPredictor): self.memory_encoder.forward, mode="max-autotune", fullgraph=True, - dynamic=True, + dynamic=False, ) self.memory_attention.forward = torch.compile( self.memory_attention.forward, mode="max-autotune", fullgraph=True, - dynamic=True, - ) - - self.sam_prompt_encoder.get_dense_pe = torch.compile( - self.sam_prompt_encoder.get_dense_pe, - mode="max-autotune", - fullgraph=True, - dynamic=True, + dynamic=True, # Num. of memories varies ) self.sam_prompt_encoder.forward = torch.compile( self.sam_prompt_encoder.forward, mode="max-autotune", fullgraph=True, - dynamic=True, + dynamic=False, # Accuracy regression on True ) self.sam_mask_decoder.forward = torch.compile( self.sam_mask_decoder.forward, mode="max-autotune", fullgraph=True, - dynamic=True, + dynamic=False, # Accuracy regression on True ) def forward_image(self, img_batch: torch.Tensor): From beacd9a521b75352f007eb45c3ed8f2ae498bade Mon Sep 17 00:00:00 2001 From: Chay Ryali Date: Wed, 11 Dec 2024 06:50:20 +0000 Subject: [PATCH 3/4] add license header --- sam2/benchmark.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sam2/benchmark.py b/sam2/benchmark.py index 5b25d7a..6519534 100644 --- a/sam2/benchmark.py +++ b/sam2/benchmark.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + import os import time From a5dc1d59240da6f537f4ebc5c698de544089ca69 Mon Sep 17 00:00:00 2001 From: Chay Ryali Date: Wed, 11 Dec 2024 07:07:23 +0000 Subject: [PATCH 4/4] update training config as well to be consistent --- .../sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml b/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml index 2046791..9b6faa7 100644 --- a/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +++ b/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml @@ -97,7 +97,7 @@ trainer: self_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] embedding_dim: 256 num_heads: 1 downsample_rate: 1 @@ -108,7 +108,7 @@ trainer: cross_attention: _target_: sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 - feat_sizes: [32, 32] + feat_sizes: [64, 64] rope_k_repeat: True embedding_dim: 256 num_heads: 1
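
Taken together, the patches above expose the compiled video predictor through a new `vos_optimized` flag on `build_sam2_video_predictor` (which selects `SAM2VideoPredictorVOS` and enables image-encoder compilation via a config override) and exercise it in `sam2/benchmark.py`. A minimal usage sketch mirroring that script (the config, checkpoint, and frame-directory paths are the example values from the patch and may need adjusting):

```python
import numpy as np
import torch

from sam2.build_sam import build_sam2_video_predictor

# vos_optimized=True selects SAM2VideoPredictorVOS, which torch.compile()s the
# memory encoder, memory attention, prompt encoder, and mask decoder, and sets
# the compile_image_encoder override for the image encoder.
predictor = build_sam2_video_predictor(
    "configs/sam2.1/sam2.1_hiera_b+.yaml",
    "checkpoints/sam2.1_hiera_base_plus.pt",
    device=torch.device("cuda"),
    vos_optimized=True,
)

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    state = predictor.init_state(video_path="notebooks/videos/bedroom")
    # One positive click at (210, 350) on frame 0 for object id 1, as in benchmark.py.
    predictor.add_new_points_or_box(
        inference_state=state,
        frame_idx=0,
        obj_id=1,
        points=np.array([[210, 350]], dtype=np.float32),
        labels=np.array([1], dtype=np.int32),
    )
    # The first propagation pass is slow while the compiled modules warm up.
    for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(state):
        pass  # per-frame mask logits for the tracked object are in mask_logits
```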