mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-15 17:05:58 +08:00
Rendering texturing fixes
Summary: Fix errors raised by issue on GitHub - extending mesh textures + rendering with Gourad and Phong shaders. https://github.com/facebookresearch/pytorch3d/issues/97 Reviewed By: gkioxari Differential Revision: D20319610 fbshipit-source-id: d1c692ff0b9397a77a9b829c5c731790de70c09f
This commit is contained in:
committed by
Facebook GitHub Bot
parent
f580ce1385
commit
5d3cc3569a
@@ -223,27 +223,32 @@ class TensorProperties(object):
|
||||
self with all properties reshaped. e.g. a property with shape (N, 3)
|
||||
is transformed to shape (B, 3).
|
||||
"""
|
||||
# Iterate through the attributes of the class which are tensors.
|
||||
for k in dir(self):
|
||||
v = getattr(self, k)
|
||||
if torch.is_tensor(v):
|
||||
if v.shape[0] > 1:
|
||||
# There are different values for each batch element
|
||||
# so gather these using the batch_idx
|
||||
idx_dims = batch_idx.shape
|
||||
# so gather these using the batch_idx.
|
||||
# First clone the input batch_idx tensor before
|
||||
# modifying it.
|
||||
_batch_idx = batch_idx.clone()
|
||||
idx_dims = _batch_idx.shape
|
||||
tensor_dims = v.shape
|
||||
if len(idx_dims) > len(tensor_dims):
|
||||
msg = "batch_idx cannot have more dimensions than %s. "
|
||||
msg += "got shape %r and %s has shape %r"
|
||||
raise ValueError(msg % (k, idx_dims, k, tensor_dims))
|
||||
if idx_dims != tensor_dims:
|
||||
# To use torch.gather the index tensor (batch_idx) has
|
||||
# To use torch.gather the index tensor (_batch_idx) has
|
||||
# to have the same shape as the input tensor.
|
||||
new_dims = len(tensor_dims) - len(idx_dims)
|
||||
new_shape = idx_dims + (1,) * new_dims
|
||||
expand_dims = (-1,) + tensor_dims[1:]
|
||||
batch_idx = batch_idx.view(*new_shape)
|
||||
batch_idx = batch_idx.expand(*expand_dims)
|
||||
v = v.gather(0, batch_idx)
|
||||
_batch_idx = _batch_idx.view(*new_shape)
|
||||
_batch_idx = _batch_idx.expand(*expand_dims)
|
||||
|
||||
v = v.gather(0, _batch_idx)
|
||||
setattr(self, k, v)
|
||||
return self
|
||||
|
||||
|
||||
Reference in New Issue
Block a user