mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-21 14:50:36 +08:00
fix recent lint
Summary: Flowing of compositing comments Reviewed By: nikhilaravi Differential Revision: D20556707 fbshipit-source-id: 4abdc85e4f65abd41c4a890b6895bc5e95b4576b
This commit is contained in:
committed by
Facebook GitHub Bot
parent
d57daa6f85
commit
27eb791e2f
@@ -21,11 +21,11 @@ class CompositeParams(NamedTuple):
|
||||
|
||||
class _CompositeAlphaPoints(torch.autograd.Function):
|
||||
"""
|
||||
Composite features within a z-buffer using alpha compositing. Given a zbuffer
|
||||
Composite features within a z-buffer using alpha compositing. Given a z-buffer
|
||||
with corresponding features and weights, these values are accumulated according
|
||||
to their weights such that features nearer in depth contribute more to the final
|
||||
feature than ones further away.
|
||||
|
||||
|
||||
Concretely this means:
|
||||
weighted_fs[b,c,i,j] = sum_k cum_alpha_k * features[c,pointsidx[b,k,i,j]]
|
||||
cum_alpha_k = alphas[b,k,i,j] * prod_l=0..k-1 (1 - alphas[b,l,i,j])
|
||||
@@ -37,9 +37,9 @@ class _CompositeAlphaPoints(torch.autograd.Function):
|
||||
Values should be in the interval [0, 1].
|
||||
pointsidx: int32 Tensor of shape (N, points_per_pixel, image_size, image_size)
|
||||
giving the indices of the nearest points at each pixel, sorted in z-order.
|
||||
Concretely pointsidx[n, k, y, x] = p means that features[:, p] is the feature of
|
||||
the kth closest point (along the z-direction) to pixel (y, x) in batch element n.
|
||||
This is weighted by alphas[n, k, y, x].
|
||||
Concretely pointsidx[n, k, y, x] = p means that features[:, p] is the
|
||||
feature of the kth closest point (along the z-direction) to pixel (y, x) in
|
||||
batch element n. This is weighted by alphas[n, k, y, x].
|
||||
|
||||
Returns:
|
||||
weighted_fs: Tensor of shape (N, C, image_size, image_size)
|
||||
@@ -69,7 +69,7 @@ class _CompositeAlphaPoints(torch.autograd.Function):
|
||||
|
||||
def alpha_composite(pointsidx, alphas, pt_clds, blend_params=None) -> torch.Tensor:
|
||||
"""
|
||||
Composite features within a z-buffer using alpha compositing. Given a zbuffer
|
||||
Composite features within a z-buffer using alpha compositing. Given a z-buffer
|
||||
with corresponding features and weights, these values are accumulated according
|
||||
to their weights such that features nearer in depth contribute more to the final
|
||||
feature than ones further away.
|
||||
@@ -80,15 +80,16 @@ def alpha_composite(pointsidx, alphas, pt_clds, blend_params=None) -> torch.Tens
|
||||
|
||||
|
||||
Args:
|
||||
pt_clds: Tensor of shape (N, C, P) giving the features of each point (can use RGB for example).
|
||||
pt_clds: Tensor of shape (N, C, P) giving the features of each point (can use
|
||||
RGB for example).
|
||||
alphas: float32 Tensor of shape (N, points_per_pixel, image_size,
|
||||
image_size) giving the weight of each point in the z-buffer.
|
||||
Values should be in the interval [0, 1].
|
||||
pointsidx: int32 Tensor of shape (N, points_per_pixel, image_size, image_size)
|
||||
giving the indices of the nearest points at each pixel, sorted in z-order.
|
||||
Concretely pointsidx[n, k, y, x] = p means that features[n, :, p] is the feature of
|
||||
the kth closest point (along the z-direction) to pixel (y, x) in batch element n.
|
||||
This is weighted by alphas[n, k, y, x].
|
||||
Concretely pointsidx[n, k, y, x] = p means that features[n, :, p] is the
|
||||
feature of the kth closest point (along the z-direction) to pixel (y, x) in
|
||||
batch element n. This is weighted by alphas[n, k, y, x].
|
||||
|
||||
Returns:
|
||||
Combined features: Tensor of shape (N, C, image_size, image_size)
|
||||
@@ -99,10 +100,10 @@ def alpha_composite(pointsidx, alphas, pt_clds, blend_params=None) -> torch.Tens
|
||||
|
||||
class _CompositeNormWeightedSumPoints(torch.autograd.Function):
|
||||
"""
|
||||
Composite features within a z-buffer using normalized weighted sum. Given a zbuffer
|
||||
Composite features within a z-buffer using normalized weighted sum. Given a z-buffer
|
||||
with corresponding features and weights, these values are accumulated
|
||||
according to their weights such that depth is ignored; the weights are used to perform
|
||||
a weighted sum.
|
||||
according to their weights such that depth is ignored; the weights are used to
|
||||
perform a weighted sum.
|
||||
|
||||
Concretely this means:
|
||||
weighted_fs[b,c,i,j] =
|
||||
@@ -115,9 +116,9 @@ class _CompositeNormWeightedSumPoints(torch.autograd.Function):
|
||||
Values should be in the interval [0, 1].
|
||||
pointsidx: int32 Tensor of shape (N, points_per_pixel, image_size, image_size)
|
||||
giving the indices of the nearest points at each pixel, sorted in z-order.
|
||||
Concretely pointsidx[n, k, y, x] = p means that features[:, p] is the feature of
|
||||
the kth closest point (along the z-direction) to pixel (y, x) in batch element n.
|
||||
This is weighted by alphas[n, k, y, x].
|
||||
Concretely pointsidx[n, k, y, x] = p means that features[:, p] is the
|
||||
feature of the kth closest point (along the z-direction) to pixel (y, x) in
|
||||
batch element n. This is weighted by alphas[n, k, y, x].
|
||||
|
||||
Returns:
|
||||
weighted_fs: Tensor of shape (N, C, image_size, image_size)
|
||||
@@ -147,10 +148,10 @@ class _CompositeNormWeightedSumPoints(torch.autograd.Function):
|
||||
|
||||
def norm_weighted_sum(pointsidx, alphas, pt_clds, blend_params=None) -> torch.Tensor:
|
||||
"""
|
||||
Composite features within a z-buffer using normalized weighted sum. Given a zbuffer
|
||||
Composite features within a z-buffer using normalized weighted sum. Given a z-buffer
|
||||
with corresponding features and weights, these values are accumulated
|
||||
according to their weights such that depth is ignored; the weights are used to perform
|
||||
a weighted sum.
|
||||
according to their weights such that depth is ignored; the weights are used to
|
||||
perform a weighted sum.
|
||||
|
||||
Concretely this means:
|
||||
weighted_fs[b,c,i,j] =
|
||||
@@ -164,9 +165,9 @@ def norm_weighted_sum(pointsidx, alphas, pt_clds, blend_params=None) -> torch.Te
|
||||
Values should be in the interval [0, 1].
|
||||
pointsidx: int32 Tensor of shape (N, points_per_pixel, image_size, image_size)
|
||||
giving the indices of the nearest points at each pixel, sorted in z-order.
|
||||
Concretely pointsidx[n, k, y, x] = p means that features[:, p] is the feature of
|
||||
the kth closest point (along the z-direction) to pixel (y, x) in batch element n.
|
||||
This is weighted by alphas[n, k, y, x].
|
||||
Concretely pointsidx[n, k, y, x] = p means that features[:, p] is the
|
||||
feature of the kth closest point (along the z-direction) to pixel (y, x) in
|
||||
batch element n. This is weighted by alphas[n, k, y, x].
|
||||
|
||||
Returns:
|
||||
Combined features: Tensor of shape (N, C, image_size, image_size)
|
||||
@@ -177,7 +178,7 @@ def norm_weighted_sum(pointsidx, alphas, pt_clds, blend_params=None) -> torch.Te
|
||||
|
||||
class _CompositeWeightedSumPoints(torch.autograd.Function):
|
||||
"""
|
||||
Composite features within a z-buffer using normalized weighted sum. Given a zbuffer
|
||||
Composite features within a z-buffer using normalized weighted sum. Given a z-buffer
|
||||
with corresponding features and weights, these values are accumulated
|
||||
according to their weights such that depth is ignored; the weights are used to
|
||||
perform a weighted sum. As opposed to norm weighted sum, the weights are not
|
||||
@@ -193,9 +194,9 @@ class _CompositeWeightedSumPoints(torch.autograd.Function):
|
||||
Values should be in the interval [0, 1].
|
||||
pointsidx: int32 Tensor of shape (N, points_per_pixel, image_size, image_size)
|
||||
giving the indices of the nearest points at each pixel, sorted in z-order.
|
||||
Concretely pointsidx[n, k, y, x] = p means that features[:, p] is the feature of
|
||||
the kth closest point (along the z-direction) to pixel (y, x) in batch element n.
|
||||
This is weighted by alphas[n, k, y, x].
|
||||
Concretely pointsidx[n, k, y, x] = p means that features[:, p] is the
|
||||
feature of the kth closest point (along the z-direction) to pixel (y, x) in
|
||||
batch element n. This is weighted by alphas[n, k, y, x].
|
||||
|
||||
Returns:
|
||||
weighted_fs: Tensor of shape (N, C, image_size, image_size)
|
||||
@@ -235,9 +236,9 @@ def weighted_sum(pointsidx, alphas, pt_clds, blend_params=None) -> torch.Tensor:
|
||||
Values should be in the interval [0, 1].
|
||||
pointsidx: int32 Tensor of shape (N, points_per_pixel, image_size, image_size)
|
||||
giving the indices of the nearest points at each pixel, sorted in z-order.
|
||||
Concretely pointsidx[n, k, y, x] = p means that features[:, p] is the feature of
|
||||
the kth closest point (along the z-direction) to pixel (y, x) in batch element n.
|
||||
This is weighted by alphas[n, k, y, x].
|
||||
Concretely pointsidx[n, k, y, x] = p means that features[:, p] is the
|
||||
feature of the kth closest point (along the z-direction) to pixel (y, x) in
|
||||
batch element n. This is weighted by alphas[n, k, y, x].
|
||||
|
||||
Returns:
|
||||
Combined features: Tensor of shape (N, C, image_size, image_size)
|
||||
|
||||
@@ -245,7 +245,8 @@ class Pointclouds(object):
|
||||
|
||||
Returns:
|
||||
3-element tuple of list, padded, num_channels.
|
||||
If aux_input is list, then padded is None. If aux_input is a tensor, then list is None.
|
||||
If aux_input is list, then padded is None. If aux_input is a tensor,
|
||||
then list is None.
|
||||
"""
|
||||
if aux_input is None or self._N == 0:
|
||||
return None, None, None
|
||||
|
||||
Reference in New Issue
Block a user