mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 09:52:11 +08:00 
			
		
		
		
	fix device error
Summary: When using `sample_farthest_points` with `lengths`, it throws an error because of the device mismatch between `lengths` and `torch.rand(lengths.size())` on GPU. Reviewed By: bottler Differential Revision: D82378997 fbshipit-source-id: 8e929256177d543d1dd1249e8488f70e03e4101f
This commit is contained in:
		
							parent
							
								
									d098beb7a7
								
							
						
					
					
						commit
						7711bf34a8
					
				@ -89,7 +89,9 @@ def sample_farthest_points(
 | 
			
		||||
        if constant_length:
 | 
			
		||||
            start_idxs = torch.randint(high=P, size=(N,), device=device)
 | 
			
		||||
        else:
 | 
			
		||||
            start_idxs = (lengths * torch.rand(lengths.size())).to(torch.int64)
 | 
			
		||||
            start_idxs = (lengths * torch.rand(lengths.size(), device=device)).to(
 | 
			
		||||
                torch.int64
 | 
			
		||||
            )
 | 
			
		||||
    else:
 | 
			
		||||
        start_idxs = torch.zeros_like(lengths)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user