mirror of
https://github.com/facebookresearch/sam2.git
synced 2025-09-18 12:42:48 +08:00
Move the sincos positional encoding to the input's device on the forward pass, and remove dynamic=True from torch.compile calls where it was causing accuracy regressions
This commit is contained in:
parent
9851575bf3
commit
11f99b63c8
@ -47,8 +47,8 @@ class PositionEmbeddingSine(nn.Module):
|
||||
# Warmup cache for cuda, to help with compilation
|
||||
device = torch.device("cuda")
|
||||
for stride in strides:
|
||||
cache_key = (image_size // stride, image_size // stride, device)
|
||||
self._pe(1, *cache_key)
|
||||
cache_key = (image_size // stride, image_size // stride)
|
||||
self._pe(1, device, *cache_key)
|
||||
|
||||
def _encode_xy(self, x, y):
|
||||
# The positions are expected to be normalized
|
||||
@ -87,10 +87,10 @@ class PositionEmbeddingSine(nn.Module):
|
||||
return pos
|
||||
|
||||
@torch.no_grad()
|
||||
def _pe(self, B, *cache_key):
|
||||
H, W, device = cache_key
|
||||
def _pe(self, B, device, *cache_key):
|
||||
H, W = cache_key
|
||||
if cache_key in self.cache:
|
||||
return self.cache[cache_key][None].repeat(B, 1, 1, 1)
|
||||
return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)
|
||||
|
||||
y_embed = (
|
||||
torch.arange(1, H + 1, dtype=torch.float32, device=device)
|
||||
@ -125,10 +125,9 @@ class PositionEmbeddingSine(nn.Module):
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x: torch.Tensor):
|
||||
device = torch.device("cuda") if x.is_cuda else x.device
|
||||
B = x.shape[0]
|
||||
cache_key = (x.shape[-2], x.shape[-1], device)
|
||||
return self._pe(B, *cache_key)
|
||||
cache_key = (x.shape[-2], x.shape[-1])
|
||||
return self._pe(B, x.device, *cache_key)
|
||||
|
||||
|
||||
class PositionEmbeddingRandom(nn.Module):
|
||||
|
@ -1188,35 +1188,28 @@ class SAM2VideoPredictorVOS(SAM2VideoPredictor):
|
||||
self.memory_encoder.forward,
|
||||
mode="max-autotune",
|
||||
fullgraph=True,
|
||||
dynamic=True,
|
||||
dynamic=False,
|
||||
)
|
||||
|
||||
self.memory_attention.forward = torch.compile(
|
||||
self.memory_attention.forward,
|
||||
mode="max-autotune",
|
||||
fullgraph=True,
|
||||
dynamic=True,
|
||||
)
|
||||
|
||||
self.sam_prompt_encoder.get_dense_pe = torch.compile(
|
||||
self.sam_prompt_encoder.get_dense_pe,
|
||||
mode="max-autotune",
|
||||
fullgraph=True,
|
||||
dynamic=True,
|
||||
dynamic=True, # Num. of memories varies
|
||||
)
|
||||
|
||||
self.sam_prompt_encoder.forward = torch.compile(
|
||||
self.sam_prompt_encoder.forward,
|
||||
mode="max-autotune",
|
||||
fullgraph=True,
|
||||
dynamic=True,
|
||||
dynamic=False, # Accuracy regression on True
|
||||
)
|
||||
|
||||
self.sam_mask_decoder.forward = torch.compile(
|
||||
self.sam_mask_decoder.forward,
|
||||
mode="max-autotune",
|
||||
fullgraph=True,
|
||||
dynamic=True,
|
||||
dynamic=False, # Accuracy regression on True
|
||||
)
|
||||
|
||||
def forward_image(self, img_batch: torch.Tensor):
|
||||
|
Loading…
x
Reference in New Issue
Block a user