From 1702c85beca5e69fe068db94e955319d668ae417 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Tue, 24 May 2022 18:18:21 -0700 Subject: [PATCH] avoid warning in ndc_grid_sample Summary: If you miss grid_sample in recent pytorch, it gives a warning, so stop doing this. Reviewed By: kjchalup Differential Revision: D36410619 fbshipit-source-id: 41dd4455298645c926f4d96c2084093b3f64ee2c --- pytorch3d/renderer/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytorch3d/renderer/utils.py b/pytorch3d/renderer/utils.py index 0e2b7fa2..414c7f79 100644 --- a/pytorch3d/renderer/utils.py +++ b/pytorch3d/renderer/utils.py @@ -355,6 +355,8 @@ def convert_to_tensors_and_broadcast( def ndc_grid_sample( input: torch.Tensor, grid_ndc: torch.Tensor, + *, + align_corners: bool = False, **grid_sample_kwargs, ) -> torch.Tensor: """ @@ -368,6 +370,8 @@ def ndc_grid_sample( grid_ndc: A tensor of shape `(B, ..., 2)` denoting the set of 2D locations at which `input` is sampled. See [1] for a detailed description of the NDC coordinates. + align_corners: Forwarded to the `torch.nn.functional.grid_sample` + call. See its docstring. grid_sample_kwargs: Additional arguments forwarded to the `torch.nn.functional.grid_sample` call. See the corresponding docstring for a listing of the corresponding arguments. @@ -393,7 +397,7 @@ def ndc_grid_sample( grid_flat = ndc_to_grid_sample_coords(grid_ndc_flat, input.shape[2:]) sampled_input_flat = torch.nn.functional.grid_sample( - input, grid_flat, **grid_sample_kwargs + input, grid_flat, align_corners=align_corners, **grid_sample_kwargs ) sampled_input = sampled_input_flat.reshape([batch, input.shape[1], *spatial_size])