Simplify _xy_grid computation in raysampling

Summary: Remove the need of tuple and reversed in the raysampling xy_grid computation

Reviewed By: bottler

Differential Revision: D45269342

fbshipit-source-id: d0e4c0923b9a2cca674b35e8d64862043a0eab3b
This commit is contained in:
Emilien Garreau 2023-04-27 03:07:37 -07:00 committed by Facebook GitHub Bot
parent 32e1992924
commit 823ab75d27

View File

@ -109,17 +109,11 @@ class MultinomialRaysampler(torch.nn.Module):
self._stratified_sampling = stratified_sampling
# get the initial grid of image xy coords
_xy_grid = torch.stack(
tuple(
reversed(
meshgrid_ij(
torch.linspace(min_y, max_y, image_height, dtype=torch.float32),
torch.linspace(min_x, max_x, image_width, dtype=torch.float32),
)
)
),
dim=-1,
y, x = meshgrid_ij(
torch.linspace(min_y, max_y, image_height, dtype=torch.float32),
torch.linspace(min_x, max_x, image_width, dtype=torch.float32),
)
_xy_grid = torch.stack([x, y], dim=-1)
self.register_buffer("_xy_grid", _xy_grid, persistent=False)