mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
avoid warning in ndc_grid_sample
Summary: If you miss grid_sample in recent pytorch, it gives a warning, so stop doing this. Reviewed By: kjchalup Differential Revision: D36410619 fbshipit-source-id: 41dd4455298645c926f4d96c2084093b3f64ee2c
This commit is contained in:
parent
90d00f1b2b
commit
1702c85bec
@ -355,6 +355,8 @@ def convert_to_tensors_and_broadcast(
|
|||||||
def ndc_grid_sample(
|
def ndc_grid_sample(
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
grid_ndc: torch.Tensor,
|
grid_ndc: torch.Tensor,
|
||||||
|
*,
|
||||||
|
align_corners: bool = False,
|
||||||
**grid_sample_kwargs,
|
**grid_sample_kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
@ -368,6 +370,8 @@ def ndc_grid_sample(
|
|||||||
grid_ndc: A tensor of shape `(B, ..., 2)` denoting the set of
|
grid_ndc: A tensor of shape `(B, ..., 2)` denoting the set of
|
||||||
2D locations at which `input` is sampled.
|
2D locations at which `input` is sampled.
|
||||||
See [1] for a detailed description of the NDC coordinates.
|
See [1] for a detailed description of the NDC coordinates.
|
||||||
|
align_corners: Forwarded to the `torch.nn.functional.grid_sample`
|
||||||
|
call. See its docstring.
|
||||||
grid_sample_kwargs: Additional arguments forwarded to the
|
grid_sample_kwargs: Additional arguments forwarded to the
|
||||||
`torch.nn.functional.grid_sample` call. See the corresponding
|
`torch.nn.functional.grid_sample` call. See the corresponding
|
||||||
docstring for a listing of the corresponding arguments.
|
docstring for a listing of the corresponding arguments.
|
||||||
@ -393,7 +397,7 @@ def ndc_grid_sample(
|
|||||||
grid_flat = ndc_to_grid_sample_coords(grid_ndc_flat, input.shape[2:])
|
grid_flat = ndc_to_grid_sample_coords(grid_ndc_flat, input.shape[2:])
|
||||||
|
|
||||||
sampled_input_flat = torch.nn.functional.grid_sample(
|
sampled_input_flat = torch.nn.functional.grid_sample(
|
||||||
input, grid_flat, **grid_sample_kwargs
|
input, grid_flat, align_corners=align_corners, **grid_sample_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
sampled_input = sampled_input_flat.reshape([batch, input.shape[1], *spatial_size])
|
sampled_input = sampled_input_flat.reshape([batch, input.shape[1], *spatial_size])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user