diff --git a/sam2/modeling/position_encoding.py b/sam2/modeling/position_encoding.py
index f5993d3..2241d4c 100644
--- a/sam2/modeling/position_encoding.py
+++ b/sam2/modeling/position_encoding.py
@@ -47,8 +47,8 @@ class PositionEmbeddingSine(nn.Module):
             # Warmup cache for cuda, to help with compilation
             device = torch.device("cuda")
             for stride in strides:
-                cache_key = (image_size // stride, image_size // stride, device)
-                self._pe(1, *cache_key)
+                cache_key = (image_size // stride, image_size // stride)
+                self._pe(1, device, *cache_key)
 
     def _encode_xy(self, x, y):
         # The positions are expected to be normalized
@@ -87,10 +87,10 @@ class PositionEmbeddingSine(nn.Module):
         return pos
 
     @torch.no_grad()
-    def _pe(self, B, *cache_key):
-        H, W, device = cache_key
+    def _pe(self, B, device, *cache_key):
+        H, W = cache_key
         if cache_key in self.cache:
-            return self.cache[cache_key][None].repeat(B, 1, 1, 1)
+            return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)
 
         y_embed = (
             torch.arange(1, H + 1, dtype=torch.float32, device=device)
@@ -125,10 +125,9 @@ class PositionEmbeddingSine(nn.Module):
 
     @torch.no_grad()
     def forward(self, x: torch.Tensor):
-        device = torch.device("cuda") if x.is_cuda else x.device
         B = x.shape[0]
-        cache_key = (x.shape[-2], x.shape[-1], device)
-        return self._pe(B, *cache_key)
+        cache_key = (x.shape[-2], x.shape[-1])
+        return self._pe(B, x.device, *cache_key)
 
 
 class PositionEmbeddingRandom(nn.Module):
diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py
index 055cde5..62f45e1 100644
--- a/sam2/sam2_video_predictor.py
+++ b/sam2/sam2_video_predictor.py
@@ -1188,35 +1188,28 @@ class SAM2VideoPredictorVOS(SAM2VideoPredictor):
             self.memory_encoder.forward,
             mode="max-autotune",
             fullgraph=True,
-            dynamic=True,
+            dynamic=False,
         )
 
         self.memory_attention.forward = torch.compile(
             self.memory_attention.forward,
             mode="max-autotune",
             fullgraph=True,
-            dynamic=True,
-        )
-
-        self.sam_prompt_encoder.get_dense_pe = torch.compile(
-            self.sam_prompt_encoder.get_dense_pe,
-            mode="max-autotune",
-            fullgraph=True,
-            dynamic=True,
+            dynamic=True,  # Num. of memories varies
         )
 
         self.sam_prompt_encoder.forward = torch.compile(
             self.sam_prompt_encoder.forward,
             mode="max-autotune",
             fullgraph=True,
-            dynamic=True,
+            dynamic=False,  # Accuracy regression on True
         )
 
         self.sam_mask_decoder.forward = torch.compile(
             self.sam_mask_decoder.forward,
             mode="max-autotune",
             fullgraph=True,
-            dynamic=True,
+            dynamic=False,  # Accuracy regression on True
         )
 
     def forward_image(self, img_batch: torch.Tensor):
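
Note (not part of the patch): a minimal, self-contained sketch of the caching convention the position_encoding.py hunks introduce, using a hypothetical PECacheSketch class and a zero-tensor placeholder for the real sine/cosine construction. The cache is keyed by (H, W) only, and a hit is moved to the caller's device, so entries built by the CUDA warmup also serve callers on other devices instead of missing on a device-keyed lookup.

    import torch

    class PECacheSketch:
        """Illustrative stand-in for PositionEmbeddingSine's cache handling."""

        def __init__(self, num_pos_feats=256):
            self.num_pos_feats = num_pos_feats
            self.cache = {}  # keyed by (H, W); device is no longer part of the key

        @torch.no_grad()
        def _pe(self, B, device, *cache_key):
            H, W = cache_key
            if cache_key in self.cache:
                # .to(device) is a no-op when the entry already lives on
                # `device`, and a copy otherwise (e.g. warmup on CUDA,
                # caller on CPU).
                return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)
            # Placeholder for the real sine/cosine positional encoding.
            pe = torch.zeros(self.num_pos_feats, H, W, device=device)
            self.cache[cache_key] = pe
            return pe[None].repeat(B, 1, 1, 1)

On the sam2_video_predictor.py side, per the inline comments in the patch itself: dynamic=True is kept only for memory attention, whose number of memories varies across frames, while the prompt encoder and mask decoder are pinned to dynamic=False because dynamic=True showed an accuracy regression.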