mirror of
				https://github.com/facebookresearch/sam2.git
				synced 2025-11-04 11:32:12 +08:00 
			
		
		
		
	remove .pin_memory() in obj_pos of SAM2Base to resolve and error in MPS (#495)
				
					
				
			In this PR, we remove `.pin_memory()` in `obj_pos` of `SAM2Base` to resolve and error in MPS. Investigations show that `.pin_memory()` causes an error of `Attempted to set the storage of a tensor on device "cpu" to a storage on different device "mps:0"`, as originally reported in https://github.com/facebookresearch/sam2/issues/487. (close https://github.com/facebookresearch/sam2/issues/487)
This commit is contained in:
		
							parent
							
								
									722d1d1511
								
							
						
					
					
						commit
						2b90b9f5ce
					
				@ -628,10 +628,8 @@ class SAM2Base(torch.nn.Module):
 | 
			
		||||
                    if self.add_tpos_enc_to_obj_ptrs:
 | 
			
		||||
                        t_diff_max = max_obj_ptrs_in_encoder - 1
 | 
			
		||||
                        tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
 | 
			
		||||
                        obj_pos = (
 | 
			
		||||
                            torch.tensor(pos_list)
 | 
			
		||||
                            .pin_memory()
 | 
			
		||||
                            .to(device=device, non_blocking=True)
 | 
			
		||||
                        obj_pos = torch.tensor(pos_list).to(
 | 
			
		||||
                            device=device, non_blocking=True
 | 
			
		||||
                        )
 | 
			
		||||
                        obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
 | 
			
		||||
                        obj_pos = self.obj_ptr_tpos_proj(obj_pos)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user