mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
avoid warning in ndc_grid_sample
Summary: If you miss grid_sample in recent pytorch, it gives a warning, so stop doing this. Reviewed By: kjchalup Differential Revision: D36410619 fbshipit-source-id: 41dd4455298645c926f4d96c2084093b3f64ee2c
This commit is contained in:
parent
90d00f1b2b
commit
1702c85bec
@ -355,6 +355,8 @@ def convert_to_tensors_and_broadcast(
|
||||
def ndc_grid_sample(
|
||||
input: torch.Tensor,
|
||||
grid_ndc: torch.Tensor,
|
||||
*,
|
||||
align_corners: bool = False,
|
||||
**grid_sample_kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@ -368,6 +370,8 @@ def ndc_grid_sample(
|
||||
grid_ndc: A tensor of shape `(B, ..., 2)` denoting the set of
|
||||
2D locations at which `input` is sampled.
|
||||
See [1] for a detailed description of the NDC coordinates.
|
||||
align_corners: Forwarded to the `torch.nn.functional.grid_sample`
|
||||
call. See its docstring.
|
||||
grid_sample_kwargs: Additional arguments forwarded to the
|
||||
`torch.nn.functional.grid_sample` call. See the corresponding
|
||||
docstring for a listing of the corresponding arguments.
|
||||
@ -393,7 +397,7 @@ def ndc_grid_sample(
|
||||
grid_flat = ndc_to_grid_sample_coords(grid_ndc_flat, input.shape[2:])
|
||||
|
||||
sampled_input_flat = torch.nn.functional.grid_sample(
|
||||
input, grid_flat, **grid_sample_kwargs
|
||||
input, grid_flat, align_corners=align_corners, **grid_sample_kwargs
|
||||
)
|
||||
|
||||
sampled_input = sampled_input_flat.reshape([batch, input.shape[1], *spatial_size])
|
||||
|
Loading…
x
Reference in New Issue
Block a user