mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-22 07:10:34 +08:00
Joining mismatched texture maps on CUDA #175
Summary: Use nn.functional.interpolate instead of a TorchVision transform to resize texture maps to a common value. This works on all devices. This fixes issue #175. Also fix the condition so it only happens when needed. Reviewed By: nikhilaravi Differential Revision: D21324510 fbshipit-source-id: c50eb06514984995bd81f2c44079be6e0b4098e4
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e64e0d17ef
commit
0c595dcf5b
@@ -3,7 +3,7 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from torch.nn.functional import interpolate
|
||||
|
||||
from .utils import padded_to_list, padded_to_packed
|
||||
|
||||
@@ -18,10 +18,10 @@ def _pad_texture_maps(images: List[torch.Tensor]) -> torch.Tensor:
|
||||
Pad all texture images so they have the same height and width.
|
||||
|
||||
Args:
|
||||
images: list of N tensors of shape (H, W)
|
||||
images: list of N tensors of shape (H, W, 3)
|
||||
|
||||
Returns:
|
||||
tex_maps: Tensor of shape (N, max_H, max_W)
|
||||
tex_maps: Tensor of shape (N, max_H, max_W, 3)
|
||||
"""
|
||||
tex_maps = []
|
||||
max_H = 0
|
||||
@@ -35,15 +35,13 @@ def _pad_texture_maps(images: List[torch.Tensor]) -> torch.Tensor:
|
||||
tex_maps.append(im)
|
||||
max_shape = (max_H, max_W)
|
||||
|
||||
# If all texture images are not the same size then resize to the
|
||||
# largest size.
|
||||
resize = T.Compose([T.ToPILImage(), T.Resize(size=max_shape), T.ToTensor()])
|
||||
|
||||
for i, image in enumerate(tex_maps):
|
||||
if image.shape != max_shape:
|
||||
# ToPIL takes and returns a C x H x W tensor
|
||||
image = resize(image.permute(2, 0, 1)).permute(1, 2, 0)
|
||||
tex_maps[i] = image
|
||||
if image.shape[:2] != max_shape:
|
||||
image_BCHW = image.permute(2, 0, 1)[None]
|
||||
new_image_BCHW = interpolate(
|
||||
image_BCHW, size=max_shape, mode="bilinear", align_corners=False
|
||||
)
|
||||
tex_maps[i] = new_image_BCHW[0].permute(1, 2, 0)
|
||||
tex_maps = torch.stack(tex_maps, dim=0) # (num_tex_maps, max_H, max_W, 3)
|
||||
return tex_maps
|
||||
|
||||
|
||||
Reference in New Issue
Block a user