From f3c1e0837c110f390165cc48e4923c2b2da14336 Mon Sep 17 00:00:00 2001 From: Roman Shapovalov Date: Mon, 7 Nov 2022 13:43:31 -0800 Subject: [PATCH] MC rasterize supports heterogeneous bundle; refactoring of bundle-to-padded Summary: Rasterize MC was not adapted to heterogeneous bundles. There are some caveats though: 1) on CO3D, we get up to 18 points per image, which is too few for a reasonable visualisation (see below); 2) rasterising for a batch of 100 is slow. I also moved the unpacking code close to the bundle to be able to reuse it. {F789678778} Reviewed By: bottler, davnov134 Differential Revision: D41008600 fbshipit-source-id: 9f10f1f9f9a174cf8c534b9b9859587d69832b71 --- pytorch3d/implicitron/models/generic_model.py | 65 ++------------- .../models/implicit_function/voxel_grid.py | 7 ++ .../voxel_grid_implicit_function.py | 9 +++ pytorch3d/implicitron/models/metrics.py | 40 ++++----- pytorch3d/implicitron/models/renderer/base.py | 52 +++++++++--- pytorch3d/implicitron/tools/rasterize_mc.py | 81 +++++++++++++++++++ pytorch3d/ops/packed_to_padded.py | 33 +++++--- pytorch3d/renderer/implicit/raysampling.py | 20 +++-- pytorch3d/renderer/implicit/utils.py | 4 +- tests/test_packed_to_padded.py | 10 +++ 10 files changed, 210 insertions(+), 111 deletions(-) diff --git a/pytorch3d/implicitron/models/generic_model.py b/pytorch3d/implicitron/models/generic_model.py index ba3c4a18..56ea080b 100644 --- a/pytorch3d/implicitron/models/generic_model.py +++ b/pytorch3d/implicitron/models/generic_model.py @@ -9,7 +9,6 @@ # which are part of implicitron. They ensure that the registry is prepopulated. import logging -import math import warnings from dataclasses import field from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union @@ -29,7 +28,7 @@ from pytorch3d.implicitron.tools.config import ( registry, run_auto_creation, ) -from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples +from pytorch3d.implicitron.tools.rasterize_mc import rasterize_sparse_ray_bundle from pytorch3d.implicitron.tools.utils import cat_dataclass from pytorch3d.renderer import utils as rend_utils @@ -502,9 +501,10 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 preds["images_render"], preds["depths_render"], preds["masks_render"], - ) = self._rasterize_mc_samples( - ray_bundle.xys, + ) = rasterize_sparse_ray_bundle( + ray_bundle, rendered.features, + (self.render_image_height, self.render_image_width), rendered.depths, masks=rendered.masks, ) @@ -828,61 +828,6 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 return image_rgb, fg_mask, depth_map - @torch.no_grad() - def _rasterize_mc_samples( - self, - xys: torch.Tensor, - features: torch.Tensor, - depth: torch.Tensor, - masks: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Rasterizes Monte-Carlo features back onto the image. - - Args: - xys: B x ... x 2 2D point locations in PyTorch3D NDC convention - features: B x ... x C tensor containing per-point rendered features. - depth: B x ... x 1 tensor containing per-point rendered depth. - """ - ba = xys.shape[0] - - # Flatten the features and xy locations. 
-        features_depth_ras = torch.cat(
-            (
-                features.reshape(ba, -1, features.shape[-1]),
-                depth.reshape(ba, -1, 1),
-            ),
-            dim=-1,
-        )
-        xys_ras = xys.reshape(ba, -1, 2)
-        if masks is not None:
-            masks_ras = masks.reshape(ba, -1, 1)
-        else:
-            masks_ras = None
-
-        if min(self.render_image_height, self.render_image_width) <= 0:
-            raise ValueError(
-                "Need to specify a positive"
-                " self.render_image_height and self.render_image_width"
-                " for MC rasterisation."
-            )
-
-        # Estimate the rasterization point radius so that we approximately fill
-        # the whole image given the number of rasterized points.
-        pt_radius = 2.0 / math.sqrt(xys.shape[1])
-
-        # Rasterize the samples.
-        features_depth_render, masks_render = rasterize_mc_samples(
-            xys_ras,
-            features_depth_ras,
-            (self.render_image_height, self.render_image_width),
-            radius=pt_radius,
-            masks=masks_ras,
-        )
-        images_render = features_depth_render[:, :-1]
-        depths_render = features_depth_render[:, -1:]
-        return images_render, depths_render, masks_render
-
 
 def _apply_chunked(func, chunk_generator, tensor_collator):
     """
@@ -940,7 +885,7 @@ def _chunk_generator(
 
     def _safe_slice(
         tensor: Optional[torch.Tensor], start_idx: int, end_idx: int
-    ) -> Optional[torch.Tensor]:
+    ) -> Any:
         return tensor[start_idx:end_idx] if tensor is not None else None
 
     for start_idx in iter:
diff --git a/pytorch3d/implicitron/models/implicit_function/voxel_grid.py b/pytorch3d/implicitron/models/implicit_function/voxel_grid.py
index 0ecca4b7..c9d518ea 100644
--- a/pytorch3d/implicitron/models/implicit_function/voxel_grid.py
+++ b/pytorch3d/implicitron/models/implicit_function/voxel_grid.py
@@ -15,6 +15,7 @@ these classes.
 """
 
+import logging
 import warnings
 from collections.abc import Mapping
 from dataclasses import dataclass, field
@@ -35,6 +36,9 @@ from pytorch3d.structures.volumes import VolumeLocator
 from .utils import interpolate_line, interpolate_plane, interpolate_volume
 
+logger = logging.getLogger(__name__)
+
+
 @dataclass
 class VoxelGridValuesBase:
     pass
@@ -250,6 +254,8 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
             )
             for name, shape in wanted_shapes.items()
         }
+        res = self.get_resolution(epoch)
+        logger.info(f"Changed grid resolution at epoch {epoch} to {res}")
     else:
         params = {
             name: (
@@ -261,6 +267,7 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
             )
             for name, tensor in vars(grid_values_with_wanted_resolution).items()
         }
+    # pyre-ignore[29]
     return self.values_type(**params), True
 
diff --git a/pytorch3d/implicitron/models/implicit_function/voxel_grid_implicit_function.py b/pytorch3d/implicitron/models/implicit_function/voxel_grid_implicit_function.py
index f106e10f..b21e253a 100644
--- a/pytorch3d/implicitron/models/implicit_function/voxel_grid_implicit_function.py
+++ b/pytorch3d/implicitron/models/implicit_function/voxel_grid_implicit_function.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import logging
 import math
 import warnings
 from dataclasses import fields
@@ -29,6 +30,9 @@ from pytorch3d.renderer import ray_bundle_to_ray_points
 from pytorch3d.renderer.cameras import CamerasBase
 from pytorch3d.renderer.implicit import HarmonicEmbedding
 
+logger = logging.getLogger(__name__)
+
+
 enable_get_default_args(HarmonicEmbedding)
 
 
@@ -491,6 +495,11 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
         max_indices = tuple(torch.max(non_zero_idxs, dim=0)[0])
         min_point, max_point = points[min_indices], points[max_indices]
 
+        logger.info(
+            f"Cropping at epoch {epoch} to bounding box "
+            f"[{min_point.tolist()}, {max_point.tolist()}]."
+        )
+
         # crop the voxel grids
         self.voxel_grid_density.crop_self(min_point, max_point)
         self.voxel_grid_color.crop_self(min_point, max_point)
diff --git a/pytorch3d/implicitron/models/metrics.py b/pytorch3d/implicitron/models/metrics.py
index 13387609..174c73e5 100644
--- a/pytorch3d/implicitron/models/metrics.py
+++ b/pytorch3d/implicitron/models/metrics.py
@@ -12,7 +12,7 @@ import torch
 from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
 from pytorch3d.implicitron.tools import metric_utils as utils
 from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
-from pytorch3d.ops import packed_to_padded, padded_to_packed
+from pytorch3d.ops import padded_to_packed
 from pytorch3d.renderer import utils as rend_utils
 
 from .renderer.base import RendererOutput
@@ -257,20 +257,7 @@ class ViewMetrics(ViewMetricsBase):
         # memory requirements. Instead of having one image for every element in
         # ray_bundle we can than have one image per unique sampled camera.
         if ray_bundle.is_packed():
-            # pyre-ignore[6]
-            cumsum = torch.cumsum(ray_bundle.camera_counts, dim=0, dtype=torch.long)
-            first_idxs = torch.cat(
-                (
-                    # pyre-ignore[16]
-                    ray_bundle.camera_counts.new_zeros((1,), dtype=torch.long),
-                    cumsum[:-1],
-                )
-            )
-            # pyre-ignore[16]
-            num_inputs = int(ray_bundle.camera_counts.sum())
-            # pyre-ignore[6]
-            max_size = int(torch.max(ray_bundle.camera_counts))
-            xys = packed_to_padded(xys, first_idxs, max_size)
+            xys, first_idxs, num_inputs = ray_bundle.get_padded_xys()
 
         # reshape the sampling grid as well
         # TODO: we can get rid of the singular dimension here and in _reshape_nongrid_var
@@ -278,23 +265,26 @@ class ViewMetrics(ViewMetricsBase):
         xys = xys.reshape(xys.shape[0], -1, 1, 2)
 
         # closure with the given xys
-        def sample(tensor, mode):
+        def sample_full(tensor, mode):
             if tensor is None:
                 return tensor
+            return rend_utils.ndc_grid_sample(tensor, xys, mode=mode)
+
+        def sample_packed(tensor, mode):
+            if tensor is None:
+                return tensor
+
+            # select images that correspond to sampled cameras if raybundle is packed
+            tensor = tensor[ray_bundle.camera_ids]
             if ray_bundle.is_packed():
                 # select images that corespond to sampled cameras if raybundle is packed
                 tensor = tensor[ray_bundle.camera_ids]
             result = rend_utils.ndc_grid_sample(tensor, xys, mode=mode)
-            if ray_bundle.is_packed():
-                # Images after sampling are in a form [batch, 3, max_num_rays, 1],
-                # packed_to_padded combines first two dimensions so we need to swap 1st
-                # and 2nd dimension. the result is [n_rays_total_training, 1, 3, 1]
-                # (we use keepdim=True).
-            result = result.transpose(1, 2)
-            result = padded_to_packed(result, first_idxs, num_inputs)[:, None]
-            result = result.transpose(1, 2)
+            return padded_to_packed(result, first_idxs, num_inputs, max_size_dim=2)[
+                :, :, None
+            ]  # the result is [n_rays_total_training, 3, 1, 1]
 
-            return result
+        sample = sample_packed if ray_bundle.is_packed() else sample_full
 
         # eval all results in this size
         image_rgb = sample(image_rgb, mode="bilinear")
diff --git a/pytorch3d/implicitron/models/renderer/base.py b/pytorch3d/implicitron/models/renderer/base.py
index 27ee1787..9b29bdeb 100644
--- a/pytorch3d/implicitron/models/renderer/base.py
+++ b/pytorch3d/implicitron/models/renderer/base.py
@@ -11,10 +11,11 @@ import dataclasses
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from pytorch3d.implicitron.tools.config import ReplaceableBase
+from pytorch3d.ops import packed_to_padded
 
 
 class EvaluationMode(Enum):
@@ -37,20 +38,23 @@ class ImplicitronRayBundle:
         in the respective 1D coordinate systems; see documentation for
         :func:`ray_bundle_to_ray_points` for the conversion formula.
 
-    camera_ids: A tensor of shape (N, ) which indicates which camera
-        was used to sample the rays. `N` is the number of different
-        sampled cameras.
-    camera_counts: A tensor of shape (N, ) which how many times the
-        coresponding camera in `camera_ids` was sampled.
-        `sum(camera_counts)==minibatch`
+    A ray bundle may represent rays from multiple cameras. In that case, cameras
+    are stored in packed form (i.e. rays from the same camera are stored in
+    consecutive elements). The following fields are then set:
+        camera_ids: A tensor of shape (N, ) which indicates which camera
+            was used to sample the rays. `N` is the number of different
+            sampled cameras.
+        camera_counts: A tensor of shape (N, ) which tells how many times the
+            corresponding camera in `camera_ids` was sampled.
+            `sum(camera_counts) == minibatch`, where `minibatch = origins.shape[0]`.
     """
 
     origins: torch.Tensor
     directions: torch.Tensor
     lengths: torch.Tensor
     xys: torch.Tensor
-    camera_ids: Optional[torch.Tensor] = None
-    camera_counts: Optional[torch.Tensor] = None
+    camera_ids: Optional[torch.LongTensor] = None
+    camera_counts: Optional[torch.LongTensor] = None
 
     def is_packed(self) -> bool:
         """
@@ -58,6 +62,36 @@ class ImplicitronRayBundle:
         """
         return self.camera_ids is not None and self.camera_counts is not None
 
+    def get_padded_xys(self) -> Tuple[torch.Tensor, torch.LongTensor, int]:
+        """
+        For a packed ray bundle, returns padded rays. Assumes the input bundle is packed
+        (i.e. `camera_ids` and `camera_counts` are set).
+
+        Returns:
+            - xys: Tensor of shape (N, max_size, ...) containing the padded
+                representation of the pixel coordinates,
+                where max_size is max of `camera_counts`. The values for camera id `i`
+                will be copied to `xys[i, :]`, with zeros padding out the extra inputs.
+            - first_idxs: exclusive cumulative sum of `camera_counts` defining the
+                boundaries between cameras in the packed representation
+            - num_inputs: the total number of rays in the bundle (`sum(camera_counts)`).
+        """
+        if not self.is_packed():
+            raise ValueError("get_padded_xys can be called only on a packed bundle")
+
+        camera_counts = self.camera_counts
+        assert camera_counts is not None
+
+        cumsum = torch.cumsum(camera_counts, dim=0, dtype=torch.long)
+        first_idxs = torch.cat(
+            (camera_counts.new_zeros((1,), dtype=torch.long), cumsum[:-1])
+        )
+        num_inputs = camera_counts.sum().item()
+        max_size = torch.max(camera_counts).item()
+        xys = packed_to_padded(self.xys, first_idxs, max_size)
+        # pyre-ignore [7] pytorch typeshed inaccuracy
+        return xys, first_idxs, num_inputs
+
 
 @dataclass
 class RendererOutput:
diff --git a/pytorch3d/implicitron/tools/rasterize_mc.py b/pytorch3d/implicitron/tools/rasterize_mc.py
index 20570a30..645a5bac 100644
--- a/pytorch3d/implicitron/tools/rasterize_mc.py
+++ b/pytorch3d/implicitron/tools/rasterize_mc.py
@@ -4,15 +4,96 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import math
 from typing import Optional, Tuple
 
 import torch
+from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
+from pytorch3d.ops import packed_to_padded
 from pytorch3d.renderer import PerspectiveCameras
 from pytorch3d.structures import Pointclouds
 
 from .point_cloud_utils import render_point_cloud_pytorch3d
 
 
+@torch.no_grad()
+def rasterize_sparse_ray_bundle(
+    ray_bundle: ImplicitronRayBundle,
+    features: torch.Tensor,
+    image_size_hw: Tuple[int, int],
+    depth: torch.Tensor,
+    masks: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Rasterizes sparse features corresponding to the coordinates defined by
+    the rays in the bundle.
+
+    Args:
+        ray_bundle: ray bundle object with B x ... x 2 pixel coordinates;
+            it can be packed.
+        features: B x ... x C tensor containing per-point rendered features.
+        image_size_hw: Tuple[image_height, image_width] containing
+            the size of the rasterized image.
+        depth: B x ... x 1 tensor containing per-point rendered depth.
+        masks: B x ... x 1 tensor containing the alpha mask of the
+            rendered features.
+
+    Returns:
+        - images_render: B x C x H x W tensor of rasterized features
+        - depths_render: B x 1 x H x W tensor of rasterized depth maps
+        - masks_render: B x 1 x H x W tensor of opacities after splatting
+    """
+    # Flatten the features and xy locations.
+    features_depth_ras = torch.cat(
+        (features.flatten(1, -2), depth.flatten(1, -2)), dim=-1
+    )
+    xys = ray_bundle.xys
+    masks_ras = None
+    if ray_bundle.is_packed():
+        camera_counts = ray_bundle.camera_counts
+        assert camera_counts is not None
+        xys, first_idxs, _ = ray_bundle.get_padded_xys()
+        masks_ras = (
+            torch.arange(xys.shape[1], device=xys.device)[:, None]
+            < camera_counts[:, None, None]
+        )
+
+        max_size = torch.max(camera_counts).item()
+        features_depth_ras = packed_to_padded(
+            features_depth_ras[:, 0], first_idxs, max_size
+        )
+        if masks is not None:
+            padded_mask = packed_to_padded(masks.flatten(1, -1), first_idxs, max_size)
+            masks_ras = padded_mask * masks_ras
+
+    xys_ras = xys.flatten(1, -2)
+
+    if masks_ras is None:
+        assert not ray_bundle.is_packed()
+        masks_ras = masks.flatten(1, -2) if masks is not None else None
+
+    if min(*image_size_hw) <= 0:
+        raise ValueError(
+            "Need to specify a positive image_size_hw for bundle rasterisation."
+        )
+
+    # Estimate the rasterization point radius so that we approximately fill
+    # the whole image given the number of rasterized points.
+    pt_radius = 2.0 / math.sqrt(xys.shape[1])
+
+    # Rasterize the samples.
+    features_depth_render, masks_render = rasterize_mc_samples(
+        xys_ras,
+        features_depth_ras,
+        image_size_hw,
+        radius=pt_radius,
+        masks=masks_ras,
+    )
+    images_render = features_depth_render[:, :-1]
+    depths_render = features_depth_render[:, -1:]
+    return images_render, depths_render, masks_render
+
+
 def rasterize_mc_samples(
     xys: torch.Tensor,
     feats: torch.Tensor,
diff --git a/pytorch3d/ops/packed_to_padded.py b/pytorch3d/ops/packed_to_padded.py
index 8fd2b718..7c209b5f 100644
--- a/pytorch3d/ops/packed_to_padded.py
+++ b/pytorch3d/ops/packed_to_padded.py
@@ -60,7 +60,9 @@ class _PackedToPadded(Function):
         return grad_input, None, None
 
 
-def packed_to_padded(inputs, first_idxs, max_size):
+def packed_to_padded(
+    inputs: torch.Tensor, first_idxs: torch.LongTensor, max_size: int
+) -> torch.Tensor:
     """
     Torch wrapper that handles allowed input shapes. See description below.
 
@@ -74,7 +76,7 @@ def packed_to_padded(inputs, first_idxs, max_size):
 
     Returns:
         inputs_padded: FloatTensor of shape (N, max_size) or (N, max_size, ...)
-            where max_size is max of `sizes`.  The values for batch element i
+            where max_size is max of `sizes`. The values for batch element i
             which start at `inputs[first_idxs[i]]` will be copied to
             `inputs_padded[i, :]`, with zeros padding out the extra inputs.
 
@@ -89,6 +91,7 @@ def packed_to_padded(inputs, first_idxs, max_size):
         inputs = inputs.unsqueeze(1)
     else:
         inputs = inputs.reshape(input_shape[0], -1)
+    # pyre-ignore [16]
    inputs_padded = _PackedToPadded.apply(inputs, first_idxs, max_size)
     # if flat is True, reshape output to (N, max_size) from (N, max_size, 1)
     # else reshape output to (N, max_size, ...)
@@ -147,39 +150,49 @@ class _PaddedToPacked(Function):
         return grad_input, None, None
 
 
-def padded_to_packed(inputs, first_idxs, num_inputs):
+def padded_to_packed(
+    inputs: torch.Tensor,
+    first_idxs: torch.LongTensor,
+    num_inputs: int,
+    max_size_dim: int = 1,
+) -> torch.Tensor:
     """
     Torch wrapper that handles allowed input shapes. See description below.
 
     Args:
-        inputs: FloatTensor of shape (N, max_size) or (N, max_size, ...),
+        inputs: FloatTensor of shape (N, ..., max_size) or (N, ..., max_size, ...),
             representing the padded tensor, e.g. areas for faces in a batch of
-            meshes.
+            meshes, where max_size occupies the max_size_dim-th position.
         first_idxs: LongTensor of shape (N,) where N is the number of
            elements in the batch and `first_idxs[i] = f` means that the
            inputs for batch element i begin at `inputs_packed[f]`.
        num_inputs: Number of packed entries (= F)
+        max_size_dim: the dimension to be packed
 
     Returns:
        inputs_packed: FloatTensor of shape (F,) or (F, ...) where
-            `inputs_packed[first_idx[i]:first_idx[i+1]] = inputs[i, :]`.
+            `inputs_packed[first_idx[i]:first_idx[i+1]] = inputs[i, ..., :delta[i]]`,
+            where `delta[i] = first_idx[i+1] - first_idx[i]`.
 
     To handle the allowed input shapes, we convert the inputs tensor of shape
-    (N, max_size) to (N, max_size, 1).  We reshape the output back to (F,) from
+    (N, max_size) to (N, max_size, 1). We reshape the output back to (F,) from
     (F, 1).
""" + n_dims = inputs.dim() + # move the variable dim to position 1 + inputs = inputs.movedim(max_size_dim, 1) + # if inputs is of shape (N, max_size), reshape into (N, max_size, 1)) input_shape = inputs.shape - n_dims = inputs.dim() if n_dims == 2: inputs = inputs.unsqueeze(2) else: inputs = inputs.reshape(*input_shape[:2], -1) + # pyre-ignore [16] inputs_packed = _PaddedToPacked.apply(inputs, first_idxs, num_inputs) # if input is flat, reshape output to (F,) from (F, 1) # else reshape output to (F, ...) if n_dims == 2: return inputs_packed.squeeze(1) - if n_dims == 3: - return inputs_packed + return inputs_packed.view(-1, *input_shape[2:]) diff --git a/pytorch3d/renderer/implicit/raysampling.py b/pytorch3d/renderer/implicit/raysampling.py index 4e978b67..919c14e8 100644 --- a/pytorch3d/renderer/implicit/raysampling.py +++ b/pytorch3d/renderer/implicit/raysampling.py @@ -207,7 +207,8 @@ class MultinomialRaysampler(torch.nn.Module): n_rays_per_image, ) = _sample_cameras_and_masks(n_rays_total, cameras, mask) else: - camera_ids = torch.arange(len(cameras), dtype=torch.long) + # pyre-ignore[9] + camera_ids: torch.LongTensor = torch.arange(len(cameras), dtype=torch.long) batch_size = cameras.R.shape[0] device = cameras.device @@ -438,7 +439,8 @@ class MonteCarloRaysampler(torch.nn.Module): n_rays_per_image, ) = _sample_cameras_and_masks(self._n_rays_total, cameras, None) else: - camera_ids = torch.arange(len(cameras), dtype=torch.long) + # pyre-ignore[9] + camera_ids: torch.LongTensor = torch.arange(len(cameras), dtype=torch.long) n_rays_per_image = self._n_rays_per_image batch_size = cameras.R.shape[0] @@ -716,7 +718,11 @@ def _jiggle_within_stratas(bin_centers: torch.Tensor) -> torch.Tensor: def _sample_cameras_and_masks( n_samples: int, cameras: CamerasBase, mask: Optional[torch.Tensor] = None ) -> Tuple[ - CamerasBase, Optional[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor + CamerasBase, + Optional[torch.Tensor], + torch.LongTensor, + torch.LongTensor, + torch.LongTensor, ]: """ Samples n_rays_total cameras and masks and returns them in a form @@ -740,6 +746,7 @@ def _sample_cameras_and_masks( dtype=torch.long, ) unique_ids, counts = torch.unique(sampled_ids, return_counts=True) + # pyre-ignore[7] return ( cameras[unique_ids], mask[unique_ids] if mask is not None else None, @@ -749,8 +756,9 @@ def _sample_cameras_and_masks( ) +# TODO: this function can be unified with ImplicitronRayBundle.get_padded_xys def _pack_ray_bundle( - ray_bundle: RayBundle, camera_ids: torch.Tensor, camera_counts: torch.Tensor + ray_bundle: RayBundle, camera_ids: torch.LongTensor, camera_counts: torch.LongTensor ) -> HeterogeneousRayBundle: """ Pack the raybundle from [n_cameras, max(rays_per_camera), ...] 
to @@ -765,9 +773,11 @@ def _pack_ray_bundle( Returns: HeterogeneousRayBundle where batch_size=sum(camera_counts) and n_rays_per_image=1 """ + # pyre-ignore[9] camera_counts = camera_counts.to(ray_bundle.origins.device) cumsum = torch.cumsum(camera_counts, dim=0, dtype=torch.long) - first_idxs = torch.cat( + # pyre-ignore[9] + first_idxs: torch.LongTensor = torch.cat( (camera_counts.new_zeros((1,), dtype=torch.long), cumsum[:-1]) ) num_inputs = int(camera_counts.sum()) diff --git a/pytorch3d/renderer/implicit/utils.py b/pytorch3d/renderer/implicit/utils.py index 6ccae29b..92b59c58 100644 --- a/pytorch3d/renderer/implicit/utils.py +++ b/pytorch3d/renderer/implicit/utils.py @@ -60,8 +60,8 @@ class HeterogeneousRayBundle: directions: torch.Tensor lengths: torch.Tensor xys: torch.Tensor - camera_ids: Optional[torch.Tensor] = None - camera_counts: Optional[torch.Tensor] = None + camera_ids: Optional[torch.LongTensor] = None + camera_counts: Optional[torch.LongTensor] = None def ray_bundle_to_ray_points( diff --git a/tests/test_packed_to_padded.py b/tests/test_packed_to_padded.py index 79dc35c0..e04c7256 100644 --- a/tests/test_packed_to_padded.py +++ b/tests/test_packed_to_padded.py @@ -188,6 +188,16 @@ class TestPackedToPadded(TestCaseMixin, unittest.TestCase): # check forward self.assertClose(values_packed, values_packed_torch) + if len(dims) > 0: + values_packed_dim2 = padded_to_packed( + values.transpose(1, 2), + mesh_to_faces_packed_first_idx, + num_faces_per_mesh.sum().item(), + max_size_dim=2, + ) + # check forward + self.assertClose(values_packed_dim2, values_packed_torch) + # check backward if len(dims) == 0: grad_inputs = torch.rand((num_faces_per_mesh.sum().item()), device=device)
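
Note for reviewers: the packed-bundle bookkeeping that now lives in
`ImplicitronRayBundle.get_padded_xys` reduces to an exclusive cumulative sum
over `camera_counts`. A minimal self-contained sketch in plain PyTorch (the
counts are toy values, not taken from CO3D):

    import torch

    # Toy counts: three cameras contributing 2, 3 and 1 rays respectively.
    camera_counts = torch.tensor([2, 3, 1], dtype=torch.long)

    cumsum = torch.cumsum(camera_counts, dim=0, dtype=torch.long)  # [2, 5, 6]
    first_idxs = torch.cat(
        (camera_counts.new_zeros((1,), dtype=torch.long), cumsum[:-1])
    )  # [0, 2, 5]: where each camera's rays start in the packed layout
    num_inputs = int(camera_counts.sum())  # 6: total rays along packed dim 0
    max_size = int(camera_counts.max())  # 3: width of the padded dimension

    # Validity mask for the padded layout, matching the broadcast used in
    # rasterize_sparse_ray_bundle: True where a padded slot holds a real ray.
    mask = (
        torch.arange(max_size)[:, None] < camera_counts[:, None, None]
    )  # shape (3, max_size, 1)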
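
A round trip through `packed_to_padded` and the new `max_size_dim` argument of
`padded_to_packed`, reusing `first_idxs` from the sketch above (assumes a
pytorch3d build that already contains this diff; the tensor contents are
arbitrary):

    import torch
    from pytorch3d.ops import packed_to_padded, padded_to_packed

    first_idxs = torch.tensor([0, 2, 5], dtype=torch.long)
    packed = torch.arange(6.0)[:, None]  # (F=6, C=1) packed per-ray features

    padded = packed_to_padded(packed, first_idxs, 3)  # (N=3, max_size=3, 1)
    # padded[0, 2:] and padded[2, 1:] are zero padding.

    # With this diff the padded dimension no longer has to be dim 1; here it
    # sits on dim 2 and is selected via max_size_dim.
    restored = padded_to_packed(
        padded.transpose(1, 2), first_idxs, 6, max_size_dim=2
    )
    assert torch.equal(restored, packed)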
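
Finally, the shape flow of `sample_packed` in metrics.py, traced on
hypothetical sizes (3 cameras, 3 channels, at most 3 rays per camera; only the
shapes matter here):

    import torch
    from pytorch3d.ops import padded_to_packed

    first_idxs = torch.tensor([0, 2, 5], dtype=torch.long)
    # [batch, channels, max_num_rays, 1], as produced by ndc_grid_sample.
    result = torch.rand(3, 3, 3, 1)

    packed = padded_to_packed(result, first_idxs, 6, max_size_dim=2)[:, :, None]
    print(packed.shape)  # torch.Size([6, 3, 1, 1]) == [n_rays_total, 3, 1, 1]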