mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-23 15:50:39 +08:00
Update NDC raysampler for non-square convention (#29)
Summary: - Old NDC convention had xy coords in [-1,1]x[-1,1] - New NDC convention has xy coords in [-1, 1]x[-u, u] or [-u, u]x[-1, 1] where u > 1 is the aspect ratio of the image. This PR fixes the NDC raysampler to use the new convention. Partial fix for https://github.com/facebookresearch/pytorch3d/issues/868 Pull Request resolved: https://github.com/fairinternal/pytorch3d/pull/29 Reviewed By: davnov134 Differential Revision: D31926148 Pulled By: bottler fbshipit-source-id: c6c42c60d1473b04e60ceb49c8c10951ddf03c74
This commit is contained in:
committed by
Facebook GitHub Bot
parent
bfeb82efa3
commit
cac6cb1b78
@@ -24,6 +24,69 @@ from pytorch3d.transforms import Rotate
|
||||
from test_cameras import init_random_cameras
|
||||
|
||||
|
||||
class TestNDCRaysamplerConvention(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
torch.manual_seed(42)
|
||||
|
||||
def test_ndc_convention(
|
||||
self,
|
||||
h=428,
|
||||
w=760,
|
||||
):
|
||||
device = torch.device("cuda")
|
||||
|
||||
camera = init_random_cameras(PerspectiveCameras, 1, random_z=True).to(device)
|
||||
|
||||
depth_map = torch.ones((1, 1, h, w)).to(device)
|
||||
|
||||
xyz = ray_bundle_to_ray_points(
|
||||
NDCGridRaysampler(
|
||||
image_width=w,
|
||||
image_height=h,
|
||||
n_pts_per_ray=1,
|
||||
min_depth=1.0,
|
||||
max_depth=1.0,
|
||||
)(camera)._replace(lengths=depth_map[:, 0, ..., None])
|
||||
).view(1, -1, 3)
|
||||
|
||||
# project pointcloud
|
||||
xy = camera.transform_points(xyz)[:, :, :2].squeeze()
|
||||
|
||||
xy_grid = self._get_ndc_grid(h, w, device)
|
||||
|
||||
self.assertClose(
|
||||
xy,
|
||||
xy_grid,
|
||||
atol=1e-4,
|
||||
)
|
||||
|
||||
def _get_ndc_grid(self, h, w, device):
|
||||
if w >= h:
|
||||
range_x = w / h
|
||||
range_y = 1.0
|
||||
else:
|
||||
range_x = 1.0
|
||||
range_y = h / w
|
||||
|
||||
half_pix_width = range_x / w
|
||||
half_pix_height = range_y / h
|
||||
|
||||
min_x = range_x - half_pix_width
|
||||
max_x = -range_x + half_pix_width
|
||||
min_y = range_y - half_pix_height
|
||||
max_y = -range_y + half_pix_height
|
||||
|
||||
y_grid, x_grid = torch.meshgrid(
|
||||
torch.linspace(min_y, max_y, h, dtype=torch.float32),
|
||||
torch.linspace(min_x, max_x, w, dtype=torch.float32),
|
||||
)
|
||||
|
||||
x_points = x_grid.contiguous().view(-1).to(device)
|
||||
y_points = y_grid.contiguous().view(-1).to(device)
|
||||
xy = torch.stack((x_points, y_points), dim=1)
|
||||
return xy
|
||||
|
||||
|
||||
class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
torch.manual_seed(42)
|
||||
@@ -147,12 +210,19 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
if issubclass(raysampler_type, NDCGridRaysampler):
|
||||
# adjust the gt bounds for NDCGridRaysampler
|
||||
half_pix_width = 1.0 / image_width
|
||||
half_pix_height = 1.0 / image_height
|
||||
min_x_ = 1.0 - half_pix_width
|
||||
max_x_ = -1.0 + half_pix_width
|
||||
min_y_ = 1.0 - half_pix_height
|
||||
max_y_ = -1.0 + half_pix_height
|
||||
if image_width >= image_height:
|
||||
range_x = image_width / image_height
|
||||
range_y = 1.0
|
||||
else:
|
||||
range_x = 1.0
|
||||
range_y = image_height / image_width
|
||||
|
||||
half_pix_width = range_x / image_width
|
||||
half_pix_height = range_y / image_height
|
||||
min_x_ = range_x - half_pix_width
|
||||
max_x_ = -range_x + half_pix_width
|
||||
min_y_ = range_y - half_pix_height
|
||||
max_y_ = -range_y + half_pix_height
|
||||
else:
|
||||
min_x_ = min_x
|
||||
max_x_ = max_x
|
||||
|
||||
Reference in New Issue
Block a user