Optimize list_to_packed to avoid for loop (#1737)

Summary:
For larger N and Mi value (e.g. N=154, Mi=238) I notice list_to_packed() has become a bottleneck for my application. By removing the for loop and running on GPU, i see a 10-20 x speedup.

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1737

Reviewed By: MichaelRamamonjisoa

Differential Revision: D54187993

Pulled By: bottler

fbshipit-source-id: 16399a24cb63b48c30460c7d960abef603b115d0
This commit is contained in:
Ruishen Lyu 2024-04-02 07:50:25 -07:00 committed by Facebook GitHub Bot
parent 128be02fc0
commit ccf22911d4

View File

@ -135,22 +135,21 @@ def list_to_packed(x: List[torch.Tensor]):
- **item_packed_to_list_idx**: tensor of shape sum(Mi) containing the - **item_packed_to_list_idx**: tensor of shape sum(Mi) containing the
index of the element in the list the item belongs to. index of the element in the list the item belongs to.
""" """
N = len(x) if not x:
num_items = torch.zeros(N, dtype=torch.int64, device=x[0].device) raise ValueError("Input list is empty")
item_packed_first_idx = torch.zeros(N, dtype=torch.int64, device=x[0].device) device = x[0].device
item_packed_to_list_idx = [] sizes = [xi.shape[0] for xi in x]
cur = 0 sizes_total = sum(sizes)
for i, y in enumerate(x): num_items = torch.tensor(sizes, dtype=torch.int64, device=device)
num = len(y) item_packed_first_idx = torch.zeros_like(num_items)
num_items[i] = num item_packed_first_idx[1:] = torch.cumsum(num_items[:-1], dim=0)
item_packed_first_idx[i] = cur item_packed_to_list_idx = torch.arange(
item_packed_to_list_idx.append( sizes_total, dtype=torch.int64, device=device
torch.full((num,), i, dtype=torch.int64, device=y.device) )
) item_packed_to_list_idx = (
cur += num torch.bucketize(item_packed_to_list_idx, item_packed_first_idx, right=True) - 1
)
x_packed = torch.cat(x, dim=0) x_packed = torch.cat(x, dim=0)
item_packed_to_list_idx = torch.cat(item_packed_to_list_idx, dim=0)
return x_packed, num_items, item_packed_first_idx, item_packed_to_list_idx return x_packed, num_items, item_packed_first_idx, item_packed_to_list_idx