mirror of
				https://github.com/facebookresearch/sam2.git
				synced 2025-11-04 19:42:12 +08:00 
			
		
		
		
	speed optimizations cleanup
This commit is contained in:
		
							parent
							
								
									c2ec8e14a1
								
							
						
					
					
						commit
						9851575bf3
					
				
							
								
								
									
										19
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										19
									
								
								README.md
									
									
									
									
									
								
							@ -158,10 +158,10 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
 | 
			
		||||
The table below shows the improved SAM 2.1 checkpoints released on September 29, 2024.
 | 
			
		||||
|      **Model**       | **Size (M)** |    **Speed (FPS)**     | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** |
 | 
			
		||||
| :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: |
 | 
			
		||||
|   sam2.1_hiera_tiny <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt))    |     38.9     |          47.2          |        76.5         |        71.8        |       77.3        |
 | 
			
		||||
|   sam2.1_hiera_small <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt))   |      46      | 43.3 (53.0 compiled\*) |        76.6         |        73.5        |       78.3        |
 | 
			
		||||
| sam2.1_hiera_base_plus <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt)) |     80.8     | 34.8 (43.8 compiled\*) |        78.2         |        73.7        |       78.2        |
 | 
			
		||||
|   sam2.1_hiera_large <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt))   |    224.4     | 24.2 (30.2 compiled\*) |        79.5         |        74.6        |       80.6        |
 | 
			
		||||
|   sam2.1_hiera_tiny <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt))    |     38.9     |          91.2          |        76.5         |        71.8        |       77.3        |
 | 
			
		||||
|   sam2.1_hiera_small <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt))   |      46      |          84.8          |        76.6         |        73.5        |       78.3        |
 | 
			
		||||
| sam2.1_hiera_base_plus <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt)) |     80.8     |        64.1          |        78.2         |        73.7        |       78.2        |
 | 
			
		||||
|   sam2.1_hiera_large <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt))   |    224.4     |          39.5          |        79.5         |        74.6        |       80.6        |
 | 
			
		||||
 | 
			
		||||
### SAM 2 checkpoints
 | 
			
		||||
 | 
			
		||||
@ -169,13 +169,12 @@ The previous SAM 2 checkpoints released on July 29, 2024 can be found as follows
 | 
			
		||||
 | 
			
		||||
|      **Model**       | **Size (M)** |    **Speed (FPS)**     | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** |
 | 
			
		||||
| :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: |
 | 
			
		||||
|   sam2_hiera_tiny <br /> ([config](sam2/configs/sam2/sam2_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt))   |     38.9     |          47.2          |        75.0         |        70.9        |       75.3        |
 | 
			
		||||
|   sam2_hiera_small <br /> ([config](sam2/configs/sam2/sam2_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt))   |      46      | 43.3 (53.0 compiled\*) |        74.9         |        71.5        |       76.4        |
 | 
			
		||||
| sam2_hiera_base_plus <br /> ([config](sam2/configs/sam2/sam2_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt)) |     80.8     | 34.8 (43.8 compiled\*) |        74.7         |        72.8        |       75.8        |
 | 
			
		||||
|   sam2_hiera_large <br /> ([config](sam2/configs/sam2/sam2_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt))   |    224.4     | 24.2 (30.2 compiled\*) |        76.0         |        74.6        |       79.8        |
 | 
			
		||||
 | 
			
		||||
\* Compile the model by setting `compile_image_encoder: True` in the config.
 | 
			
		||||
|   sam2_hiera_tiny <br /> ([config](sam2/configs/sam2/sam2_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt))   |     38.9     |          91.5          |        75.0         |        70.9        |       75.3        |
 | 
			
		||||
|   sam2_hiera_small <br /> ([config](sam2/configs/sam2/sam2_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt))   |      46      |          85.6          |        74.9         |        71.5        |       76.4        |
 | 
			
		||||
| sam2_hiera_base_plus <br /> ([config](sam2/configs/sam2/sam2_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt)) |     80.8     |     64.8    |        74.7         |        72.8        |       75.8        |
 | 
			
		||||
|   sam2_hiera_large <br /> ([config](sam2/configs/sam2/sam2_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt))   |    224.4     | 39.7 |        76.0         |        74.6        |       79.8        |
 | 
			
		||||
 | 
			
		||||
Speed measured on an A100 with `torch 2.5.1, cuda 12.4`. See `benchmark.py` for an example on benchmarking (compiling all the model components). Compiling only the image encoder can be more flexible and also provide (a smaller) speed-up (set `compile_image_encoder: True` in the config).
 | 
			
		||||
## Segment Anything Video Dataset
 | 
			
		||||
 | 
			
		||||
See [sav_dataset/README.md](sav_dataset/README.md) for details.
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										86
									
								
								sam2/benchmark.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								sam2/benchmark.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,86 @@
 | 
			
		||||
import os
 | 
			
		||||
import time
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
from tqdm import tqdm
 | 
			
		||||
 | 
			
		||||
from sam2.build_sam import build_sam2_video_predictor
 | 
			
		||||
 | 
			
		||||
# Only cuda supported
 | 
			
		||||
assert torch.cuda.is_available()
 | 
			
		||||
device = torch.device("cuda")
 | 
			
		||||
 | 
			
		||||
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
 | 
			
		||||
if torch.cuda.get_device_properties(0).major >= 8:
 | 
			
		||||
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
 | 
			
		||||
    torch.backends.cuda.matmul.allow_tf32 = True
 | 
			
		||||
    torch.backends.cudnn.allow_tf32 = True
 | 
			
		||||
 | 
			
		||||
# Config and checkpoint
 | 
			
		||||
sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt"
 | 
			
		||||
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
 | 
			
		||||
 | 
			
		||||
# Build video predictor with vos_optimized=True setting
 | 
			
		||||
predictor = build_sam2_video_predictor(
 | 
			
		||||
    model_cfg, sam2_checkpoint, device=device, vos_optimized=True
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Initialize with video
 | 
			
		||||
video_dir = "notebooks/videos/bedroom"
 | 
			
		||||
# scan all the JPEG frame names in this directory
 | 
			
		||||
frame_names = [
 | 
			
		||||
    p
 | 
			
		||||
    for p in os.listdir(video_dir)
 | 
			
		||||
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
 | 
			
		||||
]
 | 
			
		||||
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
 | 
			
		||||
inference_state = predictor.init_state(video_path=video_dir)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Number of runs, warmup etc
 | 
			
		||||
warm_up, runs = 5, 25
 | 
			
		||||
verbose = True
 | 
			
		||||
num_frames = len(frame_names)
 | 
			
		||||
total, count = 0, 0
 | 
			
		||||
torch.cuda.empty_cache()
 | 
			
		||||
 | 
			
		||||
# We will select an object with a click.
 | 
			
		||||
# See video_predictor_example.ipynb for more detailed explanation
 | 
			
		||||
ann_frame_idx, ann_obj_id = 0, 1
 | 
			
		||||
# Add a positive click at (x, y) = (210, 350)
 | 
			
		||||
# For labels, `1` means positive click
 | 
			
		||||
points = np.array([[210, 350]], dtype=np.float32)
 | 
			
		||||
labels = np.array([1], np.int32)
 | 
			
		||||
 | 
			
		||||
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
 | 
			
		||||
    inference_state=inference_state,
 | 
			
		||||
    frame_idx=ann_frame_idx,
 | 
			
		||||
    obj_id=ann_obj_id,
 | 
			
		||||
    points=points,
 | 
			
		||||
    labels=labels,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
# Warmup and then average FPS over several runs
 | 
			
		||||
with torch.autocast("cuda", torch.bfloat16):
 | 
			
		||||
    with torch.inference_mode():
 | 
			
		||||
        for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"):
 | 
			
		||||
            start = time.time()
 | 
			
		||||
            # Start tracking
 | 
			
		||||
            for (
 | 
			
		||||
                out_frame_idx,
 | 
			
		||||
                out_obj_ids,
 | 
			
		||||
                out_mask_logits,
 | 
			
		||||
            ) in predictor.propagate_in_video(inference_state):
 | 
			
		||||
                pass
 | 
			
		||||
 | 
			
		||||
            end = time.time()
 | 
			
		||||
            total += end - start
 | 
			
		||||
            count += 1
 | 
			
		||||
            if i == warm_up - 1:
 | 
			
		||||
                print("Warmup FPS: ", count * num_frames / total)
 | 
			
		||||
                total = 0
 | 
			
		||||
                count = 0
 | 
			
		||||
 | 
			
		||||
print("FPS: ", count * num_frames / total)
 | 
			
		||||
@ -104,11 +104,18 @@ def build_sam2_video_predictor(
 | 
			
		||||
    mode="eval",
 | 
			
		||||
    hydra_overrides_extra=[],
 | 
			
		||||
    apply_postprocessing=True,
 | 
			
		||||
    vos_optimized=False,
 | 
			
		||||
    **kwargs,
 | 
			
		||||
):
 | 
			
		||||
    hydra_overrides = [
 | 
			
		||||
        "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
 | 
			
		||||
    ]
 | 
			
		||||
    if vos_optimized:
 | 
			
		||||
        hydra_overrides = [
 | 
			
		||||
            "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictorVOS",
 | 
			
		||||
            "++model.compile_image_encoder=True",  # Let sam2_base handle this
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
    if apply_postprocessing:
 | 
			
		||||
        hydra_overrides_extra = hydra_overrides_extra.copy()
 | 
			
		||||
        hydra_overrides_extra += [
 | 
			
		||||
 | 
			
		||||
@ -36,7 +36,7 @@ model:
 | 
			
		||||
      self_attention:
 | 
			
		||||
        _target_: sam2.modeling.sam.transformer.RoPEAttention
 | 
			
		||||
        rope_theta: 10000.0
 | 
			
		||||
        feat_sizes: [32, 32]
 | 
			
		||||
        feat_sizes: [64, 64]
 | 
			
		||||
        embedding_dim: 256
 | 
			
		||||
        num_heads: 1
 | 
			
		||||
        downsample_rate: 1
 | 
			
		||||
@ -47,7 +47,7 @@ model:
 | 
			
		||||
      cross_attention:
 | 
			
		||||
        _target_: sam2.modeling.sam.transformer.RoPEAttention
 | 
			
		||||
        rope_theta: 10000.0
 | 
			
		||||
        feat_sizes: [32, 32]
 | 
			
		||||
        feat_sizes: [64, 64]
 | 
			
		||||
        rope_k_repeat: True
 | 
			
		||||
        embedding_dim: 256
 | 
			
		||||
        num_heads: 1
 | 
			
		||||
 | 
			
		||||
@ -40,7 +40,7 @@ model:
 | 
			
		||||
      self_attention:
 | 
			
		||||
        _target_: sam2.modeling.sam.transformer.RoPEAttention
 | 
			
		||||
        rope_theta: 10000.0
 | 
			
		||||
        feat_sizes: [32, 32]
 | 
			
		||||
        feat_sizes: [64, 64]
 | 
			
		||||
        embedding_dim: 256
 | 
			
		||||
        num_heads: 1
 | 
			
		||||
        downsample_rate: 1
 | 
			
		||||
@ -51,7 +51,7 @@ model:
 | 
			
		||||
      cross_attention:
 | 
			
		||||
        _target_: sam2.modeling.sam.transformer.RoPEAttention
 | 
			
		||||
        rope_theta: 10000.0
 | 
			
		||||
        feat_sizes: [32, 32]
 | 
			
		||||
        feat_sizes: [64, 64]
 | 
			
		||||
        rope_k_repeat: True
 | 
			
		||||
        embedding_dim: 256
 | 
			
		||||
        num_heads: 1
 | 
			
		||||
 | 
			
		||||
@ -39,7 +39,7 @@ model:
 | 
			
		||||
      self_attention:
 | 
			
		||||
        _target_: sam2.modeling.sam.transformer.RoPEAttention
 | 
			
		||||
        rope_theta: 10000.0
 | 
			
		||||
        feat_sizes: [32, 32]
 | 
			
		||||
        feat_sizes: [64, 64]
 | 
			
		||||
        embedding_dim: 256
 | 
			
		||||
        num_heads: 1
 | 
			
		||||
        downsample_rate: 1
 | 
			
		||||
@ -50,7 +50,7 @@ model:
 | 
			
		||||
      cross_attention:
 | 
			
		||||
        _target_: sam2.modeling.sam.transformer.RoPEAttention
 | 
			
		||||
        rope_theta: 10000.0
 | 
			
		||||
        feat_sizes: [32, 32]
 | 
			
		||||
        feat_sizes: [64, 64]
 | 
			
		||||
        rope_k_repeat: True
 | 
			
		||||
        embedding_dim: 256
 | 
			
		||||
        num_heads: 1
 | 
			
		||||
 | 
			
		||||
@ -39,7 +39,7 @@ model:
 | 
			
		||||
      self_attention:
 | 
			
		||||
        _target_: sam2.modeling.sam.transformer.RoPEAttention
 | 
			
		||||
        rope_theta: 10000.0
 | 
			
		||||
        feat_sizes: [32, 32]
 | 
			
		||||
        feat_sizes: [64, 64]
 | 
			
		||||
        embedding_dim: 256
 | 
			
		||||
        num_heads: 1
 | 
			
		||||
        downsample_rate: 1
 | 
			
		||||
@ -50,7 +50,7 @@ model:
 | 
			
		||||
      cross_attention:
 | 
			
		||||
        _target_: sam2.modeling.sam.transformer.RoPEAttention
 | 
			
		||||
        rope_theta: 10000.0
 | 
			
		||||
        feat_sizes: [32, 32]
 | 
			
		||||
        feat_sizes: [64, 64]
 | 
			
		||||
        rope_k_repeat: True
 | 
			
		||||
        embedding_dim: 256
 | 
			
		||||
        num_heads: 1
 | 
			
		||||
 | 
			
		||||
@ -36,7 +36,7 @@ model:
 | 
			
		||||
      self_attention:
 | 
			
		||||
        _target_: sam2.modeling.sam.transformer.RoPEAttention
 | 
			
		||||
        rope_theta: 10000.0
 | 
			
		||||
        feat_sizes: [32, 32]
 | 
			
		||||
        feat_sizes: [64, 64]
 | 
			
		||||
        embedding_dim: 256
 | 
			
		||||
        num_heads: 1
 | 
			
		||||
        downsample_rate: 1
 | 
			
		||||
@ -47,7 +47,7 @@ model:
 | 
			
		||||
      cross_attention:
 | 
			
		||||
        _target_: sam2.modeling.sam.transformer.RoPEAttention
 | 
			
		||||
        rope_theta: 10000.0
 | 
			
		||||
        feat_sizes: [32, 32]
 | 
			
		||||
        feat_sizes: [64, 64]
 | 
			
		||||
        rope_k_repeat: True
 | 
			
		||||
        embedding_dim: 256
 | 
			
		||||
        num_heads: 1
 | 
			
		||||
 | 
			
		||||
@ -40,7 +40,7 @@ model:
 | 
			
		||||
      self_attention:
 | 
			
		||||
        _target_: sam2.modeling.sam.transformer.RoPEAttention
 | 
			
		||||
        rope_theta: 10000.0
 | 
			
		||||
        feat_sizes: [32, 32]
 | 
			
		||||
        feat_sizes: [64, 64]
 | 
			
		||||
        embedding_dim: 256
 | 
			
		||||
        num_heads: 1
 | 
			
		||||
        downsample_rate: 1
 | 
			
		||||
@ -51,7 +51,7 @@ model:
 | 
			
		||||
      cross_attention:
 | 
			
		||||
        _target_: sam2.modeling.sam.transformer.RoPEAttention
 | 
			
		||||
        rope_theta: 10000.0
 | 
			
		||||
        feat_sizes: [32, 32]
 | 
			
		||||
        feat_sizes: [64, 64]
 | 
			
		||||
        rope_k_repeat: True
 | 
			
		||||
        embedding_dim: 256
 | 
			
		||||
        num_heads: 1
 | 
			
		||||
 | 
			
		||||
@ -39,7 +39,7 @@ model:
 | 
			
		||||
      self_attention:
 | 
			
		||||
        _target_: sam2.modeling.sam.transformer.RoPEAttention
 | 
			
		||||
        rope_theta: 10000.0
 | 
			
		||||
        feat_sizes: [32, 32]
 | 
			
		||||
        feat_sizes: [64, 64]
 | 
			
		||||
        embedding_dim: 256
 | 
			
		||||
        num_heads: 1
 | 
			
		||||
        downsample_rate: 1
 | 
			
		||||
@ -50,7 +50,7 @@ model:
 | 
			
		||||
      cross_attention:
 | 
			
		||||
        _target_: sam2.modeling.sam.transformer.RoPEAttention
 | 
			
		||||
        rope_theta: 10000.0
 | 
			
		||||
        feat_sizes: [32, 32]
 | 
			
		||||
        feat_sizes: [64, 64]
 | 
			
		||||
        rope_k_repeat: True
 | 
			
		||||
        embedding_dim: 256
 | 
			
		||||
        num_heads: 1
 | 
			
		||||
 | 
			
		||||
@ -39,7 +39,7 @@ model:
 | 
			
		||||
      self_attention:
 | 
			
		||||
        _target_: sam2.modeling.sam.transformer.RoPEAttention
 | 
			
		||||
        rope_theta: 10000.0
 | 
			
		||||
        feat_sizes: [32, 32]
 | 
			
		||||
        feat_sizes: [64, 64]
 | 
			
		||||
        embedding_dim: 256
 | 
			
		||||
        num_heads: 1
 | 
			
		||||
        downsample_rate: 1
 | 
			
		||||
@ -50,7 +50,7 @@ model:
 | 
			
		||||
      cross_attention:
 | 
			
		||||
        _target_: sam2.modeling.sam.transformer.RoPEAttention
 | 
			
		||||
        rope_theta: 10000.0
 | 
			
		||||
        feat_sizes: [32, 32]
 | 
			
		||||
        feat_sizes: [64, 64]
 | 
			
		||||
        rope_k_repeat: True
 | 
			
		||||
        embedding_dim: 256
 | 
			
		||||
        num_heads: 1
 | 
			
		||||
 | 
			
		||||
@ -32,9 +32,7 @@ def window_partition(x, window_size):
 | 
			
		||||
    Hp, Wp = H + pad_h, W + pad_w
 | 
			
		||||
 | 
			
		||||
    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
 | 
			
		||||
    windows = (
 | 
			
		||||
        x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
 | 
			
		||||
    )
 | 
			
		||||
    windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
 | 
			
		||||
    return windows, (Hp, Wp)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -52,13 +50,13 @@ def window_unpartition(windows, window_size, pad_hw, hw):
 | 
			
		||||
    Hp, Wp = pad_hw
 | 
			
		||||
    H, W = hw
 | 
			
		||||
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
 | 
			
		||||
    x = windows.view(
 | 
			
		||||
    x = windows.reshape(
 | 
			
		||||
        B, Hp // window_size, Wp // window_size, window_size, window_size, -1
 | 
			
		||||
    )
 | 
			
		||||
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
 | 
			
		||||
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)
 | 
			
		||||
 | 
			
		||||
    if Hp > H or Wp > W:
 | 
			
		||||
        x = x[:, :H, :W, :].contiguous()
 | 
			
		||||
        x = x[:, :H, :W, :]
 | 
			
		||||
    return x
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -25,6 +25,11 @@ class PositionEmbeddingSine(nn.Module):
 | 
			
		||||
        temperature: int = 10000,
 | 
			
		||||
        normalize: bool = True,
 | 
			
		||||
        scale: Optional[float] = None,
 | 
			
		||||
        # Following settings only relevant
 | 
			
		||||
        # for warmping up cache for compilation
 | 
			
		||||
        warmup_cache: bool = True,
 | 
			
		||||
        image_size: int = 1024,
 | 
			
		||||
        strides: Tuple[int] = (4, 8, 16, 32),
 | 
			
		||||
    ):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        assert num_pos_feats % 2 == 0, "Expecting even model width"
 | 
			
		||||
@ -38,6 +43,12 @@ class PositionEmbeddingSine(nn.Module):
 | 
			
		||||
        self.scale = scale
 | 
			
		||||
 | 
			
		||||
        self.cache = {}
 | 
			
		||||
        if warmup_cache and torch.cuda.is_available():
 | 
			
		||||
            # Warmup cache for cuda, to help with compilation
 | 
			
		||||
            device = torch.device("cuda")
 | 
			
		||||
            for stride in strides:
 | 
			
		||||
                cache_key = (image_size // stride, image_size // stride, device)
 | 
			
		||||
                self._pe(1, *cache_key)
 | 
			
		||||
 | 
			
		||||
    def _encode_xy(self, x, y):
 | 
			
		||||
        # The positions are expected to be normalized
 | 
			
		||||
@ -76,19 +87,20 @@ class PositionEmbeddingSine(nn.Module):
 | 
			
		||||
        return pos
 | 
			
		||||
 | 
			
		||||
    @torch.no_grad()
 | 
			
		||||
    def forward(self, x: torch.Tensor):
 | 
			
		||||
        cache_key = (x.shape[-2], x.shape[-1])
 | 
			
		||||
    def _pe(self, B, *cache_key):
 | 
			
		||||
        H, W, device = cache_key
 | 
			
		||||
        if cache_key in self.cache:
 | 
			
		||||
            return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
 | 
			
		||||
            return self.cache[cache_key][None].repeat(B, 1, 1, 1)
 | 
			
		||||
 | 
			
		||||
        y_embed = (
 | 
			
		||||
            torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
 | 
			
		||||
            torch.arange(1, H + 1, dtype=torch.float32, device=device)
 | 
			
		||||
            .view(1, -1, 1)
 | 
			
		||||
            .repeat(x.shape[0], 1, x.shape[-1])
 | 
			
		||||
            .repeat(B, 1, W)
 | 
			
		||||
        )
 | 
			
		||||
        x_embed = (
 | 
			
		||||
            torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
 | 
			
		||||
            torch.arange(1, W + 1, dtype=torch.float32, device=device)
 | 
			
		||||
            .view(1, 1, -1)
 | 
			
		||||
            .repeat(x.shape[0], x.shape[-2], 1)
 | 
			
		||||
            .repeat(B, H, 1)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if self.normalize:
 | 
			
		||||
@ -96,7 +108,7 @@ class PositionEmbeddingSine(nn.Module):
 | 
			
		||||
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
 | 
			
		||||
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
 | 
			
		||||
 | 
			
		||||
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
 | 
			
		||||
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device)
 | 
			
		||||
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
 | 
			
		||||
 | 
			
		||||
        pos_x = x_embed[:, :, :, None] / dim_t
 | 
			
		||||
@ -111,6 +123,13 @@ class PositionEmbeddingSine(nn.Module):
 | 
			
		||||
        self.cache[cache_key] = pos[0]
 | 
			
		||||
        return pos
 | 
			
		||||
 | 
			
		||||
    @torch.no_grad()
 | 
			
		||||
    def forward(self, x: torch.Tensor):
 | 
			
		||||
        device = torch.device("cuda") if x.is_cuda else x.device
 | 
			
		||||
        B = x.shape[0]
 | 
			
		||||
        cache_key = (x.shape[-2], x.shape[-1], device)
 | 
			
		||||
        return self._pe(B, *cache_key)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PositionEmbeddingRandom(nn.Module):
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
@ -92,12 +92,32 @@ class PromptEncoder(nn.Module):
 | 
			
		||||
        point_embedding = self.pe_layer.forward_with_coords(
 | 
			
		||||
            points, self.input_image_size
 | 
			
		||||
        )
 | 
			
		||||
        point_embedding[labels == -1] = 0.0
 | 
			
		||||
        point_embedding[labels == -1] += self.not_a_point_embed.weight
 | 
			
		||||
        point_embedding[labels == 0] += self.point_embeddings[0].weight
 | 
			
		||||
        point_embedding[labels == 1] += self.point_embeddings[1].weight
 | 
			
		||||
        point_embedding[labels == 2] += self.point_embeddings[2].weight
 | 
			
		||||
        point_embedding[labels == 3] += self.point_embeddings[3].weight
 | 
			
		||||
 | 
			
		||||
        point_embedding = torch.where(
 | 
			
		||||
            (labels == -1).unsqueeze(-1),
 | 
			
		||||
            torch.zeros_like(point_embedding) + self.not_a_point_embed.weight,
 | 
			
		||||
            point_embedding,
 | 
			
		||||
        )
 | 
			
		||||
        point_embedding = torch.where(
 | 
			
		||||
            (labels == 0).unsqueeze(-1),
 | 
			
		||||
            point_embedding + self.point_embeddings[0].weight,
 | 
			
		||||
            point_embedding,
 | 
			
		||||
        )
 | 
			
		||||
        point_embedding = torch.where(
 | 
			
		||||
            (labels == 1).unsqueeze(-1),
 | 
			
		||||
            point_embedding + self.point_embeddings[1].weight,
 | 
			
		||||
            point_embedding,
 | 
			
		||||
        )
 | 
			
		||||
        point_embedding = torch.where(
 | 
			
		||||
            (labels == 2).unsqueeze(-1),
 | 
			
		||||
            point_embedding + self.point_embeddings[2].weight,
 | 
			
		||||
            point_embedding,
 | 
			
		||||
        )
 | 
			
		||||
        point_embedding = torch.where(
 | 
			
		||||
            (labels == 3).unsqueeze(-1),
 | 
			
		||||
            point_embedding + self.point_embeddings[3].weight,
 | 
			
		||||
            point_embedding,
 | 
			
		||||
        )
 | 
			
		||||
        return point_embedding
 | 
			
		||||
 | 
			
		||||
    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
 | 
			
		||||
@ -4,9 +4,7 @@
 | 
			
		||||
# This source code is licensed under the license found in the
 | 
			
		||||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
import contextlib
 | 
			
		||||
import math
 | 
			
		||||
import warnings
 | 
			
		||||
from functools import partial
 | 
			
		||||
from typing import Tuple, Type
 | 
			
		||||
 | 
			
		||||
@ -16,29 +14,6 @@ from torch import nn, Tensor
 | 
			
		||||
 | 
			
		||||
from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
 | 
			
		||||
from sam2.modeling.sam2_utils import MLP
 | 
			
		||||
from sam2.utils.misc import get_sdpa_settings
 | 
			
		||||
 | 
			
		||||
warnings.simplefilter(action="ignore", category=FutureWarning)
 | 
			
		||||
# Check whether Flash Attention is available (and use it by default)
 | 
			
		||||
OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
 | 
			
		||||
# A fallback setting to allow all available kernels if Flash Attention fails
 | 
			
		||||
ALLOW_ALL_KERNELS = False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def sdp_kernel_context(dropout_p):
 | 
			
		||||
    """
 | 
			
		||||
    Get the context for the attention scaled dot-product kernel. We use Flash Attention
 | 
			
		||||
    by default, but fall back to all available kernels if Flash Attention fails.
 | 
			
		||||
    """
 | 
			
		||||
    if ALLOW_ALL_KERNELS:
 | 
			
		||||
        return contextlib.nullcontext()
 | 
			
		||||
 | 
			
		||||
    return torch.backends.cuda.sdp_kernel(
 | 
			
		||||
        enable_flash=USE_FLASH_ATTN,
 | 
			
		||||
        # if Flash attention kernel is off, then math kernel needs to be enabled
 | 
			
		||||
        enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
 | 
			
		||||
        enable_mem_efficient=OLD_GPU,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TwoWayTransformer(nn.Module):
 | 
			
		||||
@ -265,20 +240,7 @@ class Attention(nn.Module):
 | 
			
		||||
 | 
			
		||||
        dropout_p = self.dropout_p if self.training else 0.0
 | 
			
		||||
        # Attention
 | 
			
		||||
        try:
 | 
			
		||||
            with sdp_kernel_context(dropout_p):
 | 
			
		||||
                out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            # Fall back to all kernels if the Flash attention kernel fails
 | 
			
		||||
            warnings.warn(
 | 
			
		||||
                f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
 | 
			
		||||
                f"kernels for scaled_dot_product_attention (which may have a slower speed).",
 | 
			
		||||
                category=UserWarning,
 | 
			
		||||
                stacklevel=2,
 | 
			
		||||
            )
 | 
			
		||||
            global ALLOW_ALL_KERNELS
 | 
			
		||||
            ALLOW_ALL_KERNELS = True
 | 
			
		||||
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
 | 
			
		||||
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
 | 
			
		||||
 | 
			
		||||
        out = self._recombine_heads(out)
 | 
			
		||||
        out = self.out_proj(out)
 | 
			
		||||
@ -296,7 +258,7 @@ class RoPEAttention(Attention):
 | 
			
		||||
        # whether to repeat q rope to match k length
 | 
			
		||||
        # this is needed for cross-attention to memories
 | 
			
		||||
        rope_k_repeat=False,
 | 
			
		||||
        feat_sizes=(32, 32),  # [w, h] for stride 16 feats at 512 resolution
 | 
			
		||||
        feat_sizes=(64, 64),  # [w, h] for stride 16 feats at 1024 resolution
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ):
 | 
			
		||||
        super().__init__(*args, **kwargs)
 | 
			
		||||
@ -305,7 +267,9 @@ class RoPEAttention(Attention):
 | 
			
		||||
            compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
 | 
			
		||||
        )
 | 
			
		||||
        freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
 | 
			
		||||
        self.freqs_cis = freqs_cis
 | 
			
		||||
        self.freqs_cis = (
 | 
			
		||||
            freqs_cis.to("cuda") if torch.cuda.is_available() else freqs_cis
 | 
			
		||||
        )
 | 
			
		||||
        self.rope_k_repeat = rope_k_repeat
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
@ -339,20 +303,7 @@ class RoPEAttention(Attention):
 | 
			
		||||
 | 
			
		||||
        dropout_p = self.dropout_p if self.training else 0.0
 | 
			
		||||
        # Attention
 | 
			
		||||
        try:
 | 
			
		||||
            with sdp_kernel_context(dropout_p):
 | 
			
		||||
                out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            # Fall back to all kernels if the Flash attention kernel fails
 | 
			
		||||
            warnings.warn(
 | 
			
		||||
                f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
 | 
			
		||||
                f"kernels for scaled_dot_product_attention (which may have a slower speed).",
 | 
			
		||||
                category=UserWarning,
 | 
			
		||||
                stacklevel=2,
 | 
			
		||||
            )
 | 
			
		||||
            global ALLOW_ALL_KERNELS
 | 
			
		||||
            ALLOW_ALL_KERNELS = True
 | 
			
		||||
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
 | 
			
		||||
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
 | 
			
		||||
 | 
			
		||||
        out = self._recombine_heads(out)
 | 
			
		||||
        out = self.out_proj(out)
 | 
			
		||||
 | 
			
		||||
@ -628,7 +628,11 @@ class SAM2Base(torch.nn.Module):
 | 
			
		||||
                    if self.add_tpos_enc_to_obj_ptrs:
 | 
			
		||||
                        t_diff_max = max_obj_ptrs_in_encoder - 1
 | 
			
		||||
                        tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
 | 
			
		||||
                        obj_pos = torch.tensor(pos_list, device=device)
 | 
			
		||||
                        obj_pos = (
 | 
			
		||||
                            torch.tensor(pos_list)
 | 
			
		||||
                            .pin_memory()
 | 
			
		||||
                            .to(device=device, non_blocking=True)
 | 
			
		||||
                        )
 | 
			
		||||
                        obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
 | 
			
		||||
                        obj_pos = self.obj_ptr_tpos_proj(obj_pos)
 | 
			
		||||
                        obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
 | 
			
		||||
 | 
			
		||||
@ -8,6 +8,7 @@ import warnings
 | 
			
		||||
from collections import OrderedDict
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
 | 
			
		||||
from tqdm import tqdm
 | 
			
		||||
 | 
			
		||||
@ -1170,3 +1171,262 @@ class SAM2VideoPredictor(SAM2Base):
 | 
			
		||||
            non_cond_frame_outputs.pop(t, None)
 | 
			
		||||
            for obj_output_dict in inference_state["output_dict_per_obj"].values():
 | 
			
		||||
                obj_output_dict["non_cond_frame_outputs"].pop(t, None)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SAM2VideoPredictorVOS(SAM2VideoPredictor):
 | 
			
		||||
    """Optimized for the VOS setting"""
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
        super().__init__(*args, **kwargs)
 | 
			
		||||
        self._compile_all_components()
 | 
			
		||||
 | 
			
		||||
    def _compile_all_components(self):
 | 
			
		||||
        print(
 | 
			
		||||
            "Compiling all components for for vos setting. First time may be very slow."
 | 
			
		||||
        )
 | 
			
		||||
        self.memory_encoder.forward = torch.compile(
 | 
			
		||||
            self.memory_encoder.forward,
 | 
			
		||||
            mode="max-autotune",
 | 
			
		||||
            fullgraph=True,
 | 
			
		||||
            dynamic=True,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.memory_attention.forward = torch.compile(
 | 
			
		||||
            self.memory_attention.forward,
 | 
			
		||||
            mode="max-autotune",
 | 
			
		||||
            fullgraph=True,
 | 
			
		||||
            dynamic=True,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.sam_prompt_encoder.get_dense_pe = torch.compile(
 | 
			
		||||
            self.sam_prompt_encoder.get_dense_pe,
 | 
			
		||||
            mode="max-autotune",
 | 
			
		||||
            fullgraph=True,
 | 
			
		||||
            dynamic=True,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.sam_prompt_encoder.forward = torch.compile(
 | 
			
		||||
            self.sam_prompt_encoder.forward,
 | 
			
		||||
            mode="max-autotune",
 | 
			
		||||
            fullgraph=True,
 | 
			
		||||
            dynamic=True,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.sam_mask_decoder.forward = torch.compile(
 | 
			
		||||
            self.sam_mask_decoder.forward,
 | 
			
		||||
            mode="max-autotune",
 | 
			
		||||
            fullgraph=True,
 | 
			
		||||
            dynamic=True,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def forward_image(self, img_batch: torch.Tensor):
 | 
			
		||||
        """
 | 
			
		||||
        Identical to the corresponding method in the parent (SAM2VideoPredictor), but
 | 
			
		||||
        cloning the backbone features and pos encoding to enable compilation.
 | 
			
		||||
        """
 | 
			
		||||
        backbone_out = self.image_encoder(img_batch)
 | 
			
		||||
        if self.use_high_res_features_in_sam:
 | 
			
		||||
            # precompute projected level 0 and level 1 features in SAM decoder
 | 
			
		||||
            # to avoid running it again on every SAM click
 | 
			
		||||
            backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
 | 
			
		||||
                backbone_out["backbone_fpn"][0]
 | 
			
		||||
            )
 | 
			
		||||
            backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
 | 
			
		||||
                backbone_out["backbone_fpn"][1]
 | 
			
		||||
            )
 | 
			
		||||
        # Clone to help torch.compile
 | 
			
		||||
        for i in range(len(backbone_out["backbone_fpn"])):
 | 
			
		||||
            backbone_out["backbone_fpn"][i] = backbone_out["backbone_fpn"][i].clone()
 | 
			
		||||
            backbone_out["vision_pos_enc"][i] = backbone_out["vision_pos_enc"][
 | 
			
		||||
                i
 | 
			
		||||
            ].clone()
 | 
			
		||||
        return backbone_out
 | 
			
		||||
 | 
			
		||||
    def _forward_sam_heads(
 | 
			
		||||
        self,
 | 
			
		||||
        backbone_features,
 | 
			
		||||
        point_inputs=None,
 | 
			
		||||
        mask_inputs=None,
 | 
			
		||||
        high_res_features=None,
 | 
			
		||||
        multimask_output=False,
 | 
			
		||||
    ):
 | 
			
		||||
        """
 | 
			
		||||
        Identical to the corresponding method in the parent (SAM2VideoPredictor), but
 | 
			
		||||
        cloning the outputs of prompt_encoder and mask_decoder to enable compilation.
 | 
			
		||||
        """
 | 
			
		||||
        B = backbone_features.size(0)
 | 
			
		||||
        device = backbone_features.device
 | 
			
		||||
        assert backbone_features.size(1) == self.sam_prompt_embed_dim
 | 
			
		||||
        assert backbone_features.size(2) == self.sam_image_embedding_size
 | 
			
		||||
        assert backbone_features.size(3) == self.sam_image_embedding_size
 | 
			
		||||
 | 
			
		||||
        # a) Handle point prompts
 | 
			
		||||
        if point_inputs is not None:
 | 
			
		||||
            sam_point_coords = point_inputs["point_coords"]
 | 
			
		||||
            sam_point_labels = point_inputs["point_labels"]
 | 
			
		||||
            assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
 | 
			
		||||
        else:
 | 
			
		||||
            # If no points are provide, pad with an empty point (with label -1)
 | 
			
		||||
            sam_point_coords = torch.zeros(B, 1, 2, device=device)
 | 
			
		||||
            sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
 | 
			
		||||
 | 
			
		||||
        # b) Handle mask prompts
 | 
			
		||||
        if mask_inputs is not None:
 | 
			
		||||
            # If mask_inputs is provided, downsize it into low-res mask input if needed
 | 
			
		||||
            # and feed it as a dense mask prompt into the SAM mask encoder
 | 
			
		||||
            assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
 | 
			
		||||
            if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
 | 
			
		||||
                sam_mask_prompt = F.interpolate(
 | 
			
		||||
                    mask_inputs.float(),
 | 
			
		||||
                    size=self.sam_prompt_encoder.mask_input_size,
 | 
			
		||||
                    align_corners=False,
 | 
			
		||||
                    mode="bilinear",
 | 
			
		||||
                    antialias=True,  # use antialias for downsampling
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                sam_mask_prompt = mask_inputs
 | 
			
		||||
        else:
 | 
			
		||||
            # Otherwise, simply feed None (and SAM's prompt encoder will add
 | 
			
		||||
            # a learned `no_mask_embed` to indicate no mask input in this case).
 | 
			
		||||
            sam_mask_prompt = None
 | 
			
		||||
 | 
			
		||||
        sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
 | 
			
		||||
            points=(sam_point_coords, sam_point_labels),
 | 
			
		||||
            boxes=None,
 | 
			
		||||
            masks=sam_mask_prompt,
 | 
			
		||||
        )
 | 
			
		||||
        # Clone image_pe and the outputs of sam_prompt_encoder
 | 
			
		||||
        # to enable compilation
 | 
			
		||||
        sparse_embeddings = sparse_embeddings.clone()
 | 
			
		||||
        dense_embeddings = dense_embeddings.clone()
 | 
			
		||||
        image_pe = self.sam_prompt_encoder.get_dense_pe().clone()
 | 
			
		||||
        (
 | 
			
		||||
            low_res_multimasks,
 | 
			
		||||
            ious,
 | 
			
		||||
            sam_output_tokens,
 | 
			
		||||
            object_score_logits,
 | 
			
		||||
        ) = self.sam_mask_decoder(
 | 
			
		||||
            image_embeddings=backbone_features,
 | 
			
		||||
            image_pe=image_pe,
 | 
			
		||||
            sparse_prompt_embeddings=sparse_embeddings,
 | 
			
		||||
            dense_prompt_embeddings=dense_embeddings,
 | 
			
		||||
            multimask_output=multimask_output,
 | 
			
		||||
            repeat_image=False,  # the image is already batched
 | 
			
		||||
            high_res_features=high_res_features,
 | 
			
		||||
        )
 | 
			
		||||
        # Clone the output of sam_mask_decoder
 | 
			
		||||
        # to enable compilation
 | 
			
		||||
        low_res_multimasks = low_res_multimasks.clone()
 | 
			
		||||
        ious = ious.clone()
 | 
			
		||||
        sam_output_tokens = sam_output_tokens.clone()
 | 
			
		||||
        object_score_logits = object_score_logits.clone()
 | 
			
		||||
 | 
			
		||||
        if self.pred_obj_scores:
 | 
			
		||||
            is_obj_appearing = object_score_logits > 0
 | 
			
		||||
 | 
			
		||||
            # Mask used for spatial memories is always a *hard* choice between obj and no obj,
 | 
			
		||||
            # consistent with the actual mask prediction
 | 
			
		||||
            low_res_multimasks = torch.where(
 | 
			
		||||
                is_obj_appearing[:, None, None],
 | 
			
		||||
                low_res_multimasks,
 | 
			
		||||
                NO_OBJ_SCORE,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # convert masks from possibly bfloat16 (or float16) to float32
 | 
			
		||||
        # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
 | 
			
		||||
        low_res_multimasks = low_res_multimasks.float()
 | 
			
		||||
        high_res_multimasks = F.interpolate(
 | 
			
		||||
            low_res_multimasks,
 | 
			
		||||
            size=(self.image_size, self.image_size),
 | 
			
		||||
            mode="bilinear",
 | 
			
		||||
            align_corners=False,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        sam_output_token = sam_output_tokens[:, 0]
 | 
			
		||||
        if multimask_output:
 | 
			
		||||
            # take the best mask prediction (with the highest IoU estimation)
 | 
			
		||||
            best_iou_inds = torch.argmax(ious, dim=-1)
 | 
			
		||||
            batch_inds = torch.arange(B, device=device)
 | 
			
		||||
            low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
 | 
			
		||||
            high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
 | 
			
		||||
            if sam_output_tokens.size(1) > 1:
 | 
			
		||||
                sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
 | 
			
		||||
        else:
 | 
			
		||||
            low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
 | 
			
		||||
 | 
			
		||||
        # Extract object pointer from the SAM output token (with occlusion handling)
 | 
			
		||||
        obj_ptr = self.obj_ptr_proj(sam_output_token)
 | 
			
		||||
        if self.pred_obj_scores:
 | 
			
		||||
            # Allow *soft* no obj ptr, unlike for masks
 | 
			
		||||
            if self.soft_no_obj_ptr:
 | 
			
		||||
                lambda_is_obj_appearing = object_score_logits.sigmoid()
 | 
			
		||||
            else:
 | 
			
		||||
                lambda_is_obj_appearing = is_obj_appearing.float()
 | 
			
		||||
 | 
			
		||||
            if self.fixed_no_obj_ptr:
 | 
			
		||||
                obj_ptr = lambda_is_obj_appearing * obj_ptr
 | 
			
		||||
            obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
 | 
			
		||||
 | 
			
		||||
        return (
 | 
			
		||||
            low_res_multimasks,
 | 
			
		||||
            high_res_multimasks,
 | 
			
		||||
            ious,
 | 
			
		||||
            low_res_masks,
 | 
			
		||||
            high_res_masks,
 | 
			
		||||
            obj_ptr,
 | 
			
		||||
            object_score_logits,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _encode_new_memory(
 | 
			
		||||
        self,
 | 
			
		||||
        current_vision_feats,
 | 
			
		||||
        feat_sizes,
 | 
			
		||||
        pred_masks_high_res,
 | 
			
		||||
        object_score_logits,
 | 
			
		||||
        is_mask_from_pts,
 | 
			
		||||
    ):
 | 
			
		||||
        """
 | 
			
		||||
        Identical to the corresponding method in the parent (SAM2VideoPredictor), but
 | 
			
		||||
        cloning the memories and their pos enc to enable compilation.
 | 
			
		||||
        """
 | 
			
		||||
        B = current_vision_feats[-1].size(1)  # batch size on this frame
 | 
			
		||||
        C = self.hidden_dim
 | 
			
		||||
        H, W = feat_sizes[-1]  # top-level (lowest-resolution) feature size
 | 
			
		||||
        # top-level feature, (HW)BC => BCHW
 | 
			
		||||
        pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
 | 
			
		||||
        if self.non_overlap_masks_for_mem_enc and not self.training:
 | 
			
		||||
            # optionally, apply non-overlapping constraints to the masks (it's applied
 | 
			
		||||
            # in the batch dimension and should only be used during eval, where all
 | 
			
		||||
            # the objects come from the same video under batch size 1).
 | 
			
		||||
            pred_masks_high_res = self._apply_non_overlapping_constraints(
 | 
			
		||||
                pred_masks_high_res
 | 
			
		||||
            )
 | 
			
		||||
        # scale the raw mask logits with a temperature before applying sigmoid
 | 
			
		||||
        binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
 | 
			
		||||
        if binarize and not self.training:
 | 
			
		||||
            mask_for_mem = (pred_masks_high_res > 0).float()
 | 
			
		||||
        else:
 | 
			
		||||
            # apply sigmoid on the raw mask logits to turn them into range (0, 1)
 | 
			
		||||
            mask_for_mem = torch.sigmoid(pred_masks_high_res)
 | 
			
		||||
        # apply scale and bias terms to the sigmoid probabilities
 | 
			
		||||
        if self.sigmoid_scale_for_mem_enc != 1.0:
 | 
			
		||||
            mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
 | 
			
		||||
        if self.sigmoid_bias_for_mem_enc != 0.0:
 | 
			
		||||
            mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
 | 
			
		||||
        maskmem_out = self.memory_encoder(
 | 
			
		||||
            pix_feat, mask_for_mem, skip_mask_sigmoid=True  # sigmoid already applied
 | 
			
		||||
        )
 | 
			
		||||
        # Clone the feats and pos_enc to enable compilation
 | 
			
		||||
        maskmem_features = maskmem_out["vision_features"].clone()
 | 
			
		||||
        maskmem_pos_enc = [m.clone() for m in maskmem_out["vision_pos_enc"]]
 | 
			
		||||
        # add a no-object embedding to the spatial memory to indicate that the frame
 | 
			
		||||
        # is predicted to be occluded (i.e. no object is appearing in the frame)
 | 
			
		||||
        if self.no_obj_embed_spatial is not None:
 | 
			
		||||
            is_obj_appearing = (object_score_logits > 0).float()
 | 
			
		||||
            maskmem_features += (
 | 
			
		||||
                1 - is_obj_appearing[..., None, None]
 | 
			
		||||
            ) * self.no_obj_embed_spatial[..., None, None].expand(
 | 
			
		||||
                *maskmem_features.shape
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        return maskmem_features, maskmem_pos_enc
 | 
			
		||||
 | 
			
		||||
@ -375,7 +375,7 @@ def main():
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--sam2_checkpoint",
 | 
			
		||||
        type=str,
 | 
			
		||||
        default="./checkpoints/sam2.1_hiera_b+.pt",
 | 
			
		||||
        default="./checkpoints/sam2.1_hiera_base_plus.pt",
 | 
			
		||||
        help="path to the SAM 2 model checkpoint",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
@ -434,6 +434,11 @@ def main():
 | 
			
		||||
        help="whether to track objects that appear later in the video (i.e. not on the first frame; "
 | 
			
		||||
        "some VOS datasets like LVOS or YouTube-VOS don't have all objects appearing in the first frame)",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--use_vos_optimized_video_predictor",
 | 
			
		||||
        action="store_true",
 | 
			
		||||
        help="whether to use vos optimized video predictor with all modules compiled",
 | 
			
		||||
    )
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
 | 
			
		||||
    # if we use per-object PNG files, they could possibly overlap in inputs and outputs
 | 
			
		||||
@ -445,6 +450,7 @@ def main():
 | 
			
		||||
        ckpt_path=args.sam2_checkpoint,
 | 
			
		||||
        apply_postprocessing=args.apply_postprocessing,
 | 
			
		||||
        hydra_overrides_extra=hydra_overrides_extra,
 | 
			
		||||
        vos_optimized=args.use_vos_optimized_video_predictor,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    if args.use_all_masks:
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user