mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	avoid CPU/GPU sync in sample_farthest_points
Summary: Optimizing sample_farthest_points by reducing CPU/GPU sync: 1. Replacing the iterative randint calls for the starting indexes with 1 function call, if length is constant 2. Avoiding the sync when fetching the maximum number of sample points, if we sample the same amount 3. Initializing 1 tensor for samples and indices. Compare https://fburl.com/mlhub/7wk0xi98 Before {F1980383703} after {F1980383707} Histograms match pretty closely {F1980464338} Reviewed By: bottler Differential Revision: D78731869 fbshipit-source-id: 060528ae7a1e0fbbd005d129c151eaf9405841de
This commit is contained in:
		
							parent
							
								
									e3d3a67a89
								
							
						
					
					
						commit
						5043d15361
					
				@ -107,7 +107,8 @@ at::Tensor FarthestPointSamplingCuda(
 | 
				
			|||||||
    const at::Tensor& points, // (N, P, 3)
 | 
					    const at::Tensor& points, // (N, P, 3)
 | 
				
			||||||
    const at::Tensor& lengths, // (N,)
 | 
					    const at::Tensor& lengths, // (N,)
 | 
				
			||||||
    const at::Tensor& K, // (N,)
 | 
					    const at::Tensor& K, // (N,)
 | 
				
			||||||
    const at::Tensor& start_idxs) {
 | 
					    const at::Tensor& start_idxs,
 | 
				
			||||||
 | 
					    const int64_t max_K_known = -1) {
 | 
				
			||||||
  // Check inputs are on the same device
 | 
					  // Check inputs are on the same device
 | 
				
			||||||
  at::TensorArg p_t{points, "points", 1}, lengths_t{lengths, "lengths", 2},
 | 
					  at::TensorArg p_t{points, "points", 1}, lengths_t{lengths, "lengths", 2},
 | 
				
			||||||
      k_t{K, "K", 3}, start_idxs_t{start_idxs, "start_idxs", 4};
 | 
					      k_t{K, "K", 3}, start_idxs_t{start_idxs, "start_idxs", 4};
 | 
				
			||||||
@ -129,7 +130,12 @@ at::Tensor FarthestPointSamplingCuda(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  const int64_t N = points.size(0);
 | 
					  const int64_t N = points.size(0);
 | 
				
			||||||
  const int64_t P = points.size(1);
 | 
					  const int64_t P = points.size(1);
 | 
				
			||||||
  const int64_t max_K = at::max(K).item<int64_t>();
 | 
					  int64_t max_K;
 | 
				
			||||||
 | 
					  if (max_K_known > 0) {
 | 
				
			||||||
 | 
					    max_K = max_K_known;
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    max_K = at::max(K).item<int64_t>();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Initialize the output tensor with the sampled indices
 | 
					  // Initialize the output tensor with the sampled indices
 | 
				
			||||||
  auto idxs = at::full({N, max_K}, -1, lengths.options());
 | 
					  auto idxs = at::full({N, max_K}, -1, lengths.options());
 | 
				
			||||||
 | 
				
			|||||||
@ -43,7 +43,8 @@ at::Tensor FarthestPointSamplingCuda(
 | 
				
			|||||||
    const at::Tensor& points,
 | 
					    const at::Tensor& points,
 | 
				
			||||||
    const at::Tensor& lengths,
 | 
					    const at::Tensor& lengths,
 | 
				
			||||||
    const at::Tensor& K,
 | 
					    const at::Tensor& K,
 | 
				
			||||||
    const at::Tensor& start_idxs);
 | 
					    const at::Tensor& start_idxs,
 | 
				
			||||||
 | 
					    const int64_t max_K_known = -1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
at::Tensor FarthestPointSamplingCpu(
 | 
					at::Tensor FarthestPointSamplingCpu(
 | 
				
			||||||
    const at::Tensor& points,
 | 
					    const at::Tensor& points,
 | 
				
			||||||
@ -56,14 +57,16 @@ at::Tensor FarthestPointSampling(
 | 
				
			|||||||
    const at::Tensor& points,
 | 
					    const at::Tensor& points,
 | 
				
			||||||
    const at::Tensor& lengths,
 | 
					    const at::Tensor& lengths,
 | 
				
			||||||
    const at::Tensor& K,
 | 
					    const at::Tensor& K,
 | 
				
			||||||
    const at::Tensor& start_idxs) {
 | 
					    const at::Tensor& start_idxs,
 | 
				
			||||||
 | 
					    const int64_t max_K_known = -1) {
 | 
				
			||||||
  if (points.is_cuda() || lengths.is_cuda() || K.is_cuda()) {
 | 
					  if (points.is_cuda() || lengths.is_cuda() || K.is_cuda()) {
 | 
				
			||||||
#ifdef WITH_CUDA
 | 
					#ifdef WITH_CUDA
 | 
				
			||||||
    CHECK_CUDA(points);
 | 
					    CHECK_CUDA(points);
 | 
				
			||||||
    CHECK_CUDA(lengths);
 | 
					    CHECK_CUDA(lengths);
 | 
				
			||||||
    CHECK_CUDA(K);
 | 
					    CHECK_CUDA(K);
 | 
				
			||||||
    CHECK_CUDA(start_idxs);
 | 
					    CHECK_CUDA(start_idxs);
 | 
				
			||||||
    return FarthestPointSamplingCuda(points, lengths, K, start_idxs);
 | 
					    return FarthestPointSamplingCuda(
 | 
				
			||||||
 | 
					        points, lengths, K, start_idxs, max_K_known);
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    AT_ERROR("Not compiled with GPU support.");
 | 
					    AT_ERROR("Not compiled with GPU support.");
 | 
				
			||||||
#endif
 | 
					#endif
 | 
				
			||||||
 | 
				
			|||||||
@ -55,6 +55,7 @@ def sample_farthest_points(
 | 
				
			|||||||
    N, P, D = points.shape
 | 
					    N, P, D = points.shape
 | 
				
			||||||
    device = points.device
 | 
					    device = points.device
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    constant_length = lengths is None
 | 
				
			||||||
    # Validate inputs
 | 
					    # Validate inputs
 | 
				
			||||||
    if lengths is None:
 | 
					    if lengths is None:
 | 
				
			||||||
        lengths = torch.full((N,), P, dtype=torch.int64, device=device)
 | 
					        lengths = torch.full((N,), P, dtype=torch.int64, device=device)
 | 
				
			||||||
@ -65,7 +66,9 @@ def sample_farthest_points(
 | 
				
			|||||||
            raise ValueError("A value in lengths was too large.")
 | 
					            raise ValueError("A value in lengths was too large.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # TODO: support providing K as a ratio of the total number of points instead of as an int
 | 
					    # TODO: support providing K as a ratio of the total number of points instead of as an int
 | 
				
			||||||
 | 
					    max_K = -1
 | 
				
			||||||
    if isinstance(K, int):
 | 
					    if isinstance(K, int):
 | 
				
			||||||
 | 
					        max_K = K
 | 
				
			||||||
        K = torch.full((N,), K, dtype=torch.int64, device=device)
 | 
					        K = torch.full((N,), K, dtype=torch.int64, device=device)
 | 
				
			||||||
    elif isinstance(K, list):
 | 
					    elif isinstance(K, list):
 | 
				
			||||||
        K = torch.tensor(K, dtype=torch.int64, device=device)
 | 
					        K = torch.tensor(K, dtype=torch.int64, device=device)
 | 
				
			||||||
@ -82,15 +85,17 @@ def sample_farthest_points(
 | 
				
			|||||||
        K = K.to(torch.int64)
 | 
					        K = K.to(torch.int64)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Generate the starting indices for sampling
 | 
					    # Generate the starting indices for sampling
 | 
				
			||||||
    start_idxs = torch.zeros_like(lengths)
 | 
					 | 
				
			||||||
    if random_start_point:
 | 
					    if random_start_point:
 | 
				
			||||||
        for n in range(N):
 | 
					        if constant_length:
 | 
				
			||||||
            # pyre-fixme[6]: For 1st param expected `int` but got `Tensor`.
 | 
					            start_idxs = torch.randint(high=P, size=(N,), device=device)
 | 
				
			||||||
            start_idxs[n] = torch.randint(high=lengths[n], size=(1,)).item()
 | 
					        else:
 | 
				
			||||||
 | 
					            start_idxs = (lengths * torch.rand(lengths.size())).to(torch.int64)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        start_idxs = torch.zeros_like(lengths)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    with torch.no_grad():
 | 
					    with torch.no_grad():
 | 
				
			||||||
        # pyre-fixme[16]: `pytorch3d_._C` has no attribute `sample_farthest_points`.
 | 
					        # pyre-fixme[16]: `pytorch3d_._C` has no attribute `sample_farthest_points`.
 | 
				
			||||||
        idx = _C.sample_farthest_points(points, lengths, K, start_idxs)
 | 
					        idx = _C.sample_farthest_points(points, lengths, K, start_idxs, max_K)
 | 
				
			||||||
    sampled_points = masked_gather(points, idx)
 | 
					    sampled_points = masked_gather(points, idx)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return sampled_points, idx
 | 
					    return sampled_points, idx
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user