mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Optimize list_to_packed to avoid for loop (#1737)
Summary: For larger N and Mi values (e.g. N=154, Mi=238) I noticed that list_to_packed() had become a bottleneck for my application. By removing the for loop and running on GPU, I see a 10-20x speedup. Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1737 Reviewed By: MichaelRamamonjisoa Differential Revision: D54187993 Pulled By: bottler fbshipit-source-id: 16399a24cb63b48c30460c7d960abef603b115d0
This commit is contained in:
		
							parent
							
								
									128be02fc0
								
							
						
					
					
						commit
						ccf22911d4
					
				@ -135,22 +135,21 @@ def list_to_packed(x: List[torch.Tensor]):
 | 
			
		||||
        - **item_packed_to_list_idx**: tensor of shape sum(Mi) containing the
 | 
			
		||||
          index of the element in the list the item belongs to.
 | 
			
		||||
    """
 | 
			
		||||
    N = len(x)
 | 
			
		||||
    num_items = torch.zeros(N, dtype=torch.int64, device=x[0].device)
 | 
			
		||||
    item_packed_first_idx = torch.zeros(N, dtype=torch.int64, device=x[0].device)
 | 
			
		||||
    item_packed_to_list_idx = []
 | 
			
		||||
    cur = 0
 | 
			
		||||
    for i, y in enumerate(x):
 | 
			
		||||
        num = len(y)
 | 
			
		||||
        num_items[i] = num
 | 
			
		||||
        item_packed_first_idx[i] = cur
 | 
			
		||||
        item_packed_to_list_idx.append(
 | 
			
		||||
            torch.full((num,), i, dtype=torch.int64, device=y.device)
 | 
			
		||||
        )
 | 
			
		||||
        cur += num
 | 
			
		||||
 | 
			
		||||
    if not x:
 | 
			
		||||
        raise ValueError("Input list is empty")
 | 
			
		||||
    device = x[0].device
 | 
			
		||||
    sizes = [xi.shape[0] for xi in x]
 | 
			
		||||
    sizes_total = sum(sizes)
 | 
			
		||||
    num_items = torch.tensor(sizes, dtype=torch.int64, device=device)
 | 
			
		||||
    item_packed_first_idx = torch.zeros_like(num_items)
 | 
			
		||||
    item_packed_first_idx[1:] = torch.cumsum(num_items[:-1], dim=0)
 | 
			
		||||
    item_packed_to_list_idx = torch.arange(
 | 
			
		||||
        sizes_total, dtype=torch.int64, device=device
 | 
			
		||||
    )
 | 
			
		||||
    item_packed_to_list_idx = (
 | 
			
		||||
        torch.bucketize(item_packed_to_list_idx, item_packed_first_idx, right=True) - 1
 | 
			
		||||
    )
 | 
			
		||||
    x_packed = torch.cat(x, dim=0)
 | 
			
		||||
    item_packed_to_list_idx = torch.cat(item_packed_to_list_idx, dim=0)
 | 
			
		||||
 | 
			
		||||
    return x_packed, num_items, item_packed_first_idx, item_packed_to_list_idx
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user