mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 09:52:11 +08:00 
			
		
		
		
	Removing dynamic shape ops and boolean indexing in matrix_to_quaternion
Summary: The current implementation of `matrix_to_quaternion` and `_sqrt_positive_part` uses boolean indexing, which can slow down performance and cause incompatibility with `torch.compile` unless `torch._dynamo.config.capture_dynamic_output_shape_ops` is set to `True`. To enhance performance and compatibility, I recommend using `torch.gather` to select the best-conditioned quaternions and `F.relu` instead of `x>0` (bottler's suggestion) For a detailed comparison of the implementation differences when using `torch.compile`, please refer to my Bento notebook N7438339. Reviewed By: bottler Differential Revision: D77176230 fbshipit-source-id: 9a6a2e0015b5865056297d5f45badc3c425b93ce
This commit is contained in:
		
							parent
							
								
									6020323d94
								
							
						
					
					
						commit
						71db7a0ea2
					
				@ -95,13 +95,7 @@ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
    Returns torch.sqrt(torch.max(0, x))
 | 
			
		||||
    but with a zero subgradient where x is 0.
 | 
			
		||||
    """
 | 
			
		||||
    ret = torch.zeros_like(x)
 | 
			
		||||
    positive_mask = x > 0
 | 
			
		||||
    if torch.is_grad_enabled():
 | 
			
		||||
        ret[positive_mask] = torch.sqrt(x[positive_mask])
 | 
			
		||||
    else:
 | 
			
		||||
        ret = torch.where(positive_mask, torch.sqrt(x), ret)
 | 
			
		||||
    return ret
 | 
			
		||||
    return torch.sqrt(F.relu(x))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
@ -160,9 +154,10 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
 | 
			
		||||
    # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
 | 
			
		||||
    # forall i; we pick the best-conditioned one (with the largest denominator)
 | 
			
		||||
    out = quat_candidates[
 | 
			
		||||
        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
 | 
			
		||||
    ].reshape(batch_dim + (4,))
 | 
			
		||||
    indices = q_abs.argmax(dim=-1, keepdim=True)
 | 
			
		||||
    expand_dims = list(batch_dim) + [1, 4]
 | 
			
		||||
    gather_indices = indices.unsqueeze(-1).expand(expand_dims)
 | 
			
		||||
    out = torch.gather(quat_candidates, -2, gather_indices).squeeze(-2)
 | 
			
		||||
    return standardize_quaternion(out)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user