Mirror of https://github.com/facebookresearch/sam2.git

commit 11f99b63c8 (parent 9851575bf3)

    move sincos pos enc to device on fwd pass and remove dynamic causing accuracy regressions
@@ -47,8 +47,8 @@ class PositionEmbeddingSine(nn.Module):
             # Warmup cache for cuda, to help with compilation
             device = torch.device("cuda")
             for stride in strides:
-                cache_key = (image_size // stride, image_size // stride, device)
-                self._pe(1, *cache_key)
+                cache_key = (image_size // stride, image_size // stride)
+                self._pe(1, device, *cache_key)

     def _encode_xy(self, x, y):
         # The positions are expected to be normalized
@@ -87,10 +87,10 @@ class PositionEmbeddingSine(nn.Module):
         return pos

     @torch.no_grad()
-    def _pe(self, B, *cache_key):
-        H, W, device = cache_key
+    def _pe(self, B, device, *cache_key):
+        H, W = cache_key
         if cache_key in self.cache:
-            return self.cache[cache_key][None].repeat(B, 1, 1, 1)
+            return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)

         y_embed = (
             torch.arange(1, H + 1, dtype=torch.float32, device=device)
@@ -125,10 +125,9 @@ class PositionEmbeddingSine(nn.Module):

     @torch.no_grad()
     def forward(self, x: torch.Tensor):
-        device = torch.device("cuda") if x.is_cuda else x.device
         B = x.shape[0]
-        cache_key = (x.shape[-2], x.shape[-1], device)
-        return self._pe(B, *cache_key)
+        cache_key = (x.shape[-2], x.shape[-1])
+        return self._pe(B, x.device, *cache_key)


 class PositionEmbeddingRandom(nn.Module):
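Taken together, the three hunks above make the sine/cosine cache device-agnostic: the key carries only the spatial size, the encoding is built once per (H, W), and the cached tensor is moved to the caller's device at lookup time. Below is a minimal, self-contained sketch of that scheme; the class name CachedSinePE and the condensed normalization are ours for illustration, not code from the repo.

import math

import torch
import torch.nn as nn


class CachedSinePE(nn.Module):
    """Condensed sine/cosine positional encoding with a device-agnostic cache."""

    def __init__(self, num_pos_feats=256, temperature=10000):
        super().__init__()
        self.num_pos_feats = num_pos_feats // 2  # half for y, half for x
        self.temperature = temperature
        self.scale = 2 * math.pi
        self.cache = {}

    @torch.no_grad()
    def _pe(self, B, device, *cache_key):
        # The key is (H, W) only; the device is passed separately, as in the diff.
        H, W = cache_key
        if cache_key in self.cache:
            # .to(device) is a no-op when the tensor already lives there, so a
            # single cache entry serves every device that asks for this size.
            return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)

        # Normalized row/column positions scaled to [0, 2*pi].
        y_embed = torch.arange(1, H + 1, dtype=torch.float32, device=device) / H * self.scale
        x_embed = torch.arange(1, W + 1, dtype=torch.float32, device=device) / W * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        # Interleave sin/cos over the feature dimension for each axis.
        pos_y = y_embed[:, None] / dim_t  # (H, C/2)
        pos_x = x_embed[:, None] / dim_t  # (W, C/2)
        pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1)
        pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)

        # Broadcast to a (C, H, W) grid: y features vary down rows, x across columns.
        pos = torch.cat(
            (pos_y[:, None, :].expand(H, W, -1), pos_x[None, :, :].expand(H, W, -1)),
            dim=2,
        ).permute(2, 0, 1)
        self.cache[cache_key] = pos  # stored without a batch dim, on the build device
        return pos[None].repeat(B, 1, 1, 1)

    @torch.no_grad()
    def forward(self, x: torch.Tensor):
        cache_key = (x.shape[-2], x.shape[-1])
        return self._pe(x.shape[0], x.device, *cache_key)


pe = CachedSinePE()
feats = torch.randn(2, 256, 64, 64)
print(pe(feats).shape)  # torch.Size([2, 256, 64, 64]); a second call hits the cache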
@@ -1188,35 +1188,28 @@ class SAM2VideoPredictorVOS(SAM2VideoPredictor):
             self.memory_encoder.forward,
             mode="max-autotune",
             fullgraph=True,
-            dynamic=True,
+            dynamic=False,
         )

         self.memory_attention.forward = torch.compile(
             self.memory_attention.forward,
             mode="max-autotune",
             fullgraph=True,
-            dynamic=True,
-        )
-
-        self.sam_prompt_encoder.get_dense_pe = torch.compile(
-            self.sam_prompt_encoder.get_dense_pe,
-            mode="max-autotune",
-            fullgraph=True,
-            dynamic=True,
+            dynamic=True,  # Num. of memories varies
         )

         self.sam_prompt_encoder.forward = torch.compile(
             self.sam_prompt_encoder.forward,
             mode="max-autotune",
             fullgraph=True,
-            dynamic=True,
+            dynamic=False,  # Accuracy regression on True
         )

         self.sam_mask_decoder.forward = torch.compile(
             self.sam_mask_decoder.forward,
             mode="max-autotune",
             fullgraph=True,
-            dynamic=True,
+            dynamic=False,  # Accuracy regression on True
         )

     def forward_image(self, img_batch: torch.Tensor):
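On the dynamic flag: with dynamic=False, torch.compile specializes kernels to the exact input shapes it sees and recompiles when a new shape arrives; dynamic=True instead compiles shape-polymorphic kernels up front. The hunk keeps dynamic=True only for memory_attention, whose input length varies with the number of memories, and pins the other modules to False after the observed accuracy regressions. A small sketch of the wrapping pattern follows, using a toy module of ours in place of the real predictor components.

import torch
import torch.nn as nn


class TinyDecoder(nn.Module):
    """Toy stand-in for a module like sam_mask_decoder."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(32, 32)

    def forward(self, x):
        return torch.relu(self.proj(x))


decoder = TinyDecoder()

# Same pattern as the diff: rebind .forward to its compiled version.
# dynamic=False makes the compiler assume static shapes; a new input
# shape triggers a fresh, shape-specialized compilation instead of
# reusing one symbolic-shape graph (the variant the commit reports as
# regressing accuracy for these modules).
decoder.forward = torch.compile(
    decoder.forward,
    mode="max-autotune",
    fullgraph=True,
    dynamic=False,
)

out = decoder(torch.randn(4, 32))  # compiles a graph specialized to (4, 32)
out = decoder(torch.randn(8, 32))  # new shape -> recompile, still static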