mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
defaulted grid_sizes in points2vols
Summary: Fix #873, that grid_sizes defaults to the wrong dtype in points2volumes code, and mask doesn't have a proper default. Reviewed By: nikhilaravi Differential Revision: D31503545 fbshipit-source-id: fa32a1a6074fc7ac7bdb362edfb5e5839866a472
This commit is contained in:
parent
2f2466f472
commit
34b1b4ab8b
@ -5,7 +5,7 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Tuple, Optional
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
@ -364,7 +364,7 @@ def add_points_features_to_volume_densities_features(
|
|||||||
# grid sizes shape (minibatch, 3)
|
# grid sizes shape (minibatch, 3)
|
||||||
grid_sizes = (
|
grid_sizes = (
|
||||||
torch.LongTensor(list(volume_densities.shape[2:]))
|
torch.LongTensor(list(volume_densities.shape[2:]))
|
||||||
.to(volume_densities)
|
.to(volume_densities.device)
|
||||||
.expand(volume_densities.shape[0], 3)
|
.expand(volume_densities.shape[0], 3)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -386,6 +386,10 @@ def add_points_features_to_volume_densities_features(
|
|||||||
splat = False
|
splat = False
|
||||||
else:
|
else:
|
||||||
raise ValueError('No such interpolation mode "%s"' % mode)
|
raise ValueError('No such interpolation mode "%s"' % mode)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
|
mask = points_3d.new_ones(1).expand(points_3d.shape[:2])
|
||||||
|
|
||||||
volume_densities, volume_features = _points_to_volumes(
|
volume_densities, volume_features = _points_to_volumes(
|
||||||
points_3d,
|
points_3d,
|
||||||
points_features,
|
points_features,
|
||||||
|
@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
|
|
||||||
from itertools import product
|
from itertools import product
|
||||||
from typing import Callable, Any
|
from typing import Any, Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from common_testing import get_random_cuda_device
|
from common_testing import get_random_cuda_device
|
||||||
@ -14,6 +14,7 @@ from fvcore.common.benchmark import benchmark
|
|||||||
from pytorch3d.common.workaround import symeig3x3
|
from pytorch3d.common.workaround import symeig3x3
|
||||||
from test_symeig3x3 import TestSymEig3x3
|
from test_symeig3x3 import TestSymEig3x3
|
||||||
|
|
||||||
|
|
||||||
torch.set_num_threads(1)
|
torch.set_num_threads(1)
|
||||||
|
|
||||||
CUDA_DEVICE = get_random_cuda_device()
|
CUDA_DEVICE = get_random_cuda_device()
|
||||||
|
@ -16,6 +16,7 @@ from pytorch3d.io import save_obj
|
|||||||
from pytorch3d.ops.iou_box3d import _box_planes, _box_triangles, box3d_overlap
|
from pytorch3d.ops.iou_box3d import _box_planes, _box_triangles, box3d_overlap
|
||||||
from pytorch3d.transforms.rotation_conversions import random_rotation
|
from pytorch3d.transforms.rotation_conversions import random_rotation
|
||||||
|
|
||||||
|
|
||||||
OBJECTRON_TO_PYTORCH3D_FACE_IDX = [0, 4, 6, 2, 1, 5, 7, 3]
|
OBJECTRON_TO_PYTORCH3D_FACE_IDX = [0, 4, 6, 2, 1, 5, 7, 3]
|
||||||
DATA_DIR = get_tests_dir() / "data"
|
DATA_DIR = get_tests_dir() / "data"
|
||||||
DEBUG = False
|
DEBUG = False
|
||||||
|
@ -12,7 +12,10 @@ from typing import Tuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from common_testing import TestCaseMixin
|
from common_testing import TestCaseMixin
|
||||||
from pytorch3d.ops import add_pointclouds_to_volumes
|
from pytorch3d.ops import (
|
||||||
|
add_pointclouds_to_volumes,
|
||||||
|
add_points_features_to_volume_densities_features,
|
||||||
|
)
|
||||||
from pytorch3d.ops.points_to_volumes import _points_to_volumes
|
from pytorch3d.ops.points_to_volumes import _points_to_volumes
|
||||||
from pytorch3d.ops.sample_points_from_meshes import sample_points_from_meshes
|
from pytorch3d.ops.sample_points_from_meshes import sample_points_from_meshes
|
||||||
from pytorch3d.structures.meshes import Meshes
|
from pytorch3d.structures.meshes import Meshes
|
||||||
@ -373,6 +376,17 @@ class TestPointsToVolumes(TestCaseMixin, unittest.TestCase):
|
|||||||
else:
|
else:
|
||||||
self.assertTrue(torch.isfinite(field.grad.data).all())
|
self.assertTrue(torch.isfinite(field.grad.data).all())
|
||||||
|
|
||||||
|
def test_defaulted_arguments(self):
|
||||||
|
points = torch.rand(30, 1000, 3)
|
||||||
|
features = torch.rand(30, 1000, 5)
|
||||||
|
_, densities = add_points_features_to_volume_densities_features(
|
||||||
|
points,
|
||||||
|
features,
|
||||||
|
torch.zeros(30, 1, 32, 32, 32),
|
||||||
|
torch.zeros(30, 5, 32, 32, 32),
|
||||||
|
)
|
||||||
|
self.assertClose(torch.sum(densities), torch.tensor(30 * 1000.0), atol=0.1)
|
||||||
|
|
||||||
def _check_volume_slice_color_density(
|
def _check_volume_slice_color_density(
|
||||||
self, V, split_dim, interp_mode, clr_gt, slice_type, border=3
|
self, V, split_dim, interp_mode, clr_gt, slice_type, border=3
|
||||||
):
|
):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user