mirror of
				https://github.com/facebookresearch/sam2.git
				synced 2025-11-04 11:32:12 +08:00 
			
		
		
		
	move sincos pos enc to device on fwd pass and remove dynamic compilation (dynamic=True) causing accuracy regressions
This commit is contained in:
		
							parent
							
								
									9851575bf3
								
							
						
					
					
						commit
						11f99b63c8
					
				@ -47,8 +47,8 @@ class PositionEmbeddingSine(nn.Module):
 | 
			
		||||
            # Warmup cache for cuda, to help with compilation
 | 
			
		||||
            device = torch.device("cuda")
 | 
			
		||||
            for stride in strides:
 | 
			
		||||
                cache_key = (image_size // stride, image_size // stride, device)
 | 
			
		||||
                self._pe(1, *cache_key)
 | 
			
		||||
                cache_key = (image_size // stride, image_size // stride)
 | 
			
		||||
                self._pe(1, device, *cache_key)
 | 
			
		||||
 | 
			
		||||
    def _encode_xy(self, x, y):
 | 
			
		||||
        # The positions are expected to be normalized
 | 
			
		||||
@ -87,10 +87,10 @@ class PositionEmbeddingSine(nn.Module):
 | 
			
		||||
        return pos
 | 
			
		||||
 | 
			
		||||
    @torch.no_grad()
 | 
			
		||||
    def _pe(self, B, *cache_key):
 | 
			
		||||
        H, W, device = cache_key
 | 
			
		||||
    def _pe(self, B, device, *cache_key):
 | 
			
		||||
        H, W = cache_key
 | 
			
		||||
        if cache_key in self.cache:
 | 
			
		||||
            return self.cache[cache_key][None].repeat(B, 1, 1, 1)
 | 
			
		||||
            return self.cache[cache_key].to(device)[None].repeat(B, 1, 1, 1)
 | 
			
		||||
 | 
			
		||||
        y_embed = (
 | 
			
		||||
            torch.arange(1, H + 1, dtype=torch.float32, device=device)
 | 
			
		||||
@ -125,10 +125,9 @@ class PositionEmbeddingSine(nn.Module):
 | 
			
		||||
 | 
			
		||||
    @torch.no_grad()
 | 
			
		||||
    def forward(self, x: torch.Tensor):
 | 
			
		||||
        device = torch.device("cuda") if x.is_cuda else x.device
 | 
			
		||||
        B = x.shape[0]
 | 
			
		||||
        cache_key = (x.shape[-2], x.shape[-1], device)
 | 
			
		||||
        return self._pe(B, *cache_key)
 | 
			
		||||
        cache_key = (x.shape[-2], x.shape[-1])
 | 
			
		||||
        return self._pe(B, x.device, *cache_key)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PositionEmbeddingRandom(nn.Module):
 | 
			
		||||
 | 
			
		||||
@ -1188,35 +1188,28 @@ class SAM2VideoPredictorVOS(SAM2VideoPredictor):
 | 
			
		||||
            self.memory_encoder.forward,
 | 
			
		||||
            mode="max-autotune",
 | 
			
		||||
            fullgraph=True,
 | 
			
		||||
            dynamic=True,
 | 
			
		||||
            dynamic=False,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.memory_attention.forward = torch.compile(
 | 
			
		||||
            self.memory_attention.forward,
 | 
			
		||||
            mode="max-autotune",
 | 
			
		||||
            fullgraph=True,
 | 
			
		||||
            dynamic=True,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.sam_prompt_encoder.get_dense_pe = torch.compile(
 | 
			
		||||
            self.sam_prompt_encoder.get_dense_pe,
 | 
			
		||||
            mode="max-autotune",
 | 
			
		||||
            fullgraph=True,
 | 
			
		||||
            dynamic=True,
 | 
			
		||||
            dynamic=True,  # Num. of memories varies
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.sam_prompt_encoder.forward = torch.compile(
 | 
			
		||||
            self.sam_prompt_encoder.forward,
 | 
			
		||||
            mode="max-autotune",
 | 
			
		||||
            fullgraph=True,
 | 
			
		||||
            dynamic=True,
 | 
			
		||||
            dynamic=False,  # Accuracy regression on True
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.sam_mask_decoder.forward = torch.compile(
 | 
			
		||||
            self.sam_mask_decoder.forward,
 | 
			
		||||
            mode="max-autotune",
 | 
			
		||||
            fullgraph=True,
 | 
			
		||||
            dynamic=True,
 | 
			
		||||
            dynamic=False,  # Accuracy regression on True
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def forward_image(self, img_batch: torch.Tensor):
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user