Move sincos pos enc to device on fwd pass and remove dynamic=True where it causes accuracy regressions

This commit is contained in:
Chay Ryali 2024-12-11 06:29:00 +00:00
parent 9851575bf3
commit 11f99b63c8
2 changed files with 11 additions and 19 deletions

View File

@ -47,8 +47,8 @@ class PositionEmbeddingSine(nn.Module):
# Warmup cache for cuda, to help with compilation # Warmup cache for cuda, to help with compilation
device = torch.device("cuda") device = torch.device("cuda")
for stride in strides: for stride in strides:
cache_key = (image_size // stride, image_size // stride, device) cache_key = (image_size // stride, image_size // stride)
self._pe(1, *cache_key) self._pe(1, device, *cache_key)
def _encode_xy(self, x, y): def _encode_xy(self, x, y):
# The positions are expected to be normalized # The positions are expected to be normalized
@ -87,10 +87,10 @@ class PositionEmbeddingSine(nn.Module):
return pos return pos
@torch.no_grad() @torch.no_grad()
def _pe(self, B, *cache_key): def _pe(self, B, device, *cache_key):
H, W, device = cache_key H, W = cache_key
if cache_key in self.cache: if cache_key in self.cache:
return self.cache[cache_key][None].repeat(B, 1, 1, 1) return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)
y_embed = ( y_embed = (
torch.arange(1, H + 1, dtype=torch.float32, device=device) torch.arange(1, H + 1, dtype=torch.float32, device=device)
@ -125,10 +125,9 @@ class PositionEmbeddingSine(nn.Module):
@torch.no_grad() @torch.no_grad()
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
device = torch.device("cuda") if x.is_cuda else x.device
B = x.shape[0] B = x.shape[0]
cache_key = (x.shape[-2], x.shape[-1], device) cache_key = (x.shape[-2], x.shape[-1])
return self._pe(B, *cache_key) return self._pe(B, x.device, *cache_key)
class PositionEmbeddingRandom(nn.Module): class PositionEmbeddingRandom(nn.Module):

View File

@ -1188,35 +1188,28 @@ class SAM2VideoPredictorVOS(SAM2VideoPredictor):
self.memory_encoder.forward, self.memory_encoder.forward,
mode="max-autotune", mode="max-autotune",
fullgraph=True, fullgraph=True,
dynamic=True, dynamic=False,
) )
self.memory_attention.forward = torch.compile( self.memory_attention.forward = torch.compile(
self.memory_attention.forward, self.memory_attention.forward,
mode="max-autotune", mode="max-autotune",
fullgraph=True, fullgraph=True,
dynamic=True, dynamic=True, # Num. of memories varies
)
self.sam_prompt_encoder.get_dense_pe = torch.compile(
self.sam_prompt_encoder.get_dense_pe,
mode="max-autotune",
fullgraph=True,
dynamic=True,
) )
self.sam_prompt_encoder.forward = torch.compile( self.sam_prompt_encoder.forward = torch.compile(
self.sam_prompt_encoder.forward, self.sam_prompt_encoder.forward,
mode="max-autotune", mode="max-autotune",
fullgraph=True, fullgraph=True,
dynamic=True, dynamic=False, # Accuracy regression on True
) )
self.sam_mask_decoder.forward = torch.compile( self.sam_mask_decoder.forward = torch.compile(
self.sam_mask_decoder.forward, self.sam_mask_decoder.forward,
mode="max-autotune", mode="max-autotune",
fullgraph=True, fullgraph=True,
dynamic=True, dynamic=False, # Accuracy regression on True
) )
def forward_image(self, img_batch: torch.Tensor): def forward_image(self, img_batch: torch.Tensor):