avoid warning in ndc_grid_sample

Summary: If you miss grid_sample in recent pytorch, it gives a warning, so stop doing this.

Reviewed By: kjchalup

Differential Revision: D36410619

fbshipit-source-id: 41dd4455298645c926f4d96c2084093b3f64ee2c
This commit is contained in:
Jeremy Reizenstein 2022-05-24 18:18:21 -07:00 committed by Facebook GitHub Bot
parent 90d00f1b2b
commit 1702c85bec

View File

@ -355,6 +355,8 @@ def convert_to_tensors_and_broadcast(
def ndc_grid_sample(
input: torch.Tensor,
grid_ndc: torch.Tensor,
*,
align_corners: bool = False,
**grid_sample_kwargs,
) -> torch.Tensor:
"""
@ -368,6 +370,8 @@ def ndc_grid_sample(
grid_ndc: A tensor of shape `(B, ..., 2)` denoting the set of
2D locations at which `input` is sampled.
See [1] for a detailed description of the NDC coordinates.
align_corners: Forwarded to the `torch.nn.functional.grid_sample`
call. See its docstring.
grid_sample_kwargs: Additional arguments forwarded to the
`torch.nn.functional.grid_sample` call. See the corresponding
docstring for a listing of the corresponding arguments.
@ -393,7 +397,7 @@ def ndc_grid_sample(
grid_flat = ndc_to_grid_sample_coords(grid_ndc_flat, input.shape[2:])
sampled_input_flat = torch.nn.functional.grid_sample(
input, grid_flat, **grid_sample_kwargs
input, grid_flat, align_corners=align_corners, **grid_sample_kwargs
)
sampled_input = sampled_input_flat.reshape([batch, input.shape[1], *spatial_size])