mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-07-31 10:52:50 +08:00
Adding a Checkerboard mesh utility to Pytorch3d
Summary: Adding a checkerboard mesh utility to Pytorch3d. Reviewed By: bottler Differential Revision: D39718916 fbshipit-source-id: d43cd30e566b5db068bae6eed0388057634428c8
This commit is contained in:
parent
f34da3d3b6
commit
ce3fce49d7
@ -10,6 +10,7 @@ from .camera_conversions import (
|
||||
pulsar_from_cameras_projection,
|
||||
pulsar_from_opencv_projection,
|
||||
)
|
||||
from .checkerboard import checkerboard
|
||||
from .ico_sphere import ico_sphere
|
||||
from .torus import torus
|
||||
|
||||
|
89
pytorch3d/utils/checkerboard.py
Normal file
89
pytorch3d/utils/checkerboard.py
Normal file
@ -0,0 +1,89 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d.common.compat import meshgrid_ij
|
||||
from pytorch3d.renderer.mesh.textures import TexturesAtlas
|
||||
from pytorch3d.structures.meshes import Meshes
|
||||
|
||||
|
||||
def checkerboard(
|
||||
radius: int = 4,
|
||||
color1: Tuple[float, ...] = (0.0, 0.0, 0.0),
|
||||
color2: Tuple[float, ...] = (1.0, 1.0, 1.0),
|
||||
device: Optional[torch.types._device] = None,
|
||||
) -> Meshes:
|
||||
"""
|
||||
Returns a mesh of squares in the xy-plane where each unit is one of the two given
|
||||
colors and adjacent squares have opposite colors.
|
||||
Args:
|
||||
radius: how many squares in each direction from the origin
|
||||
color1: background color
|
||||
color2: foreground color (must have the same number of channels as color1)
|
||||
Returns:
|
||||
new Meshes object containing one mesh.
|
||||
"""
|
||||
|
||||
if device is None:
|
||||
device = torch.device("cpu")
|
||||
if radius < 1:
|
||||
raise ValueError("radius must be > 0")
|
||||
|
||||
num_verts_per_row = 2 * radius + 1
|
||||
|
||||
# construct 2D grid of 3D vertices
|
||||
x = torch.arange(-radius, radius + 1, device=device)
|
||||
grid_y, grid_x = meshgrid_ij(x, x)
|
||||
verts = torch.stack(
|
||||
[grid_x, grid_y, torch.zeros((2 * radius + 1, 2 * radius + 1))], dim=-1
|
||||
)
|
||||
verts = verts.view(1, -1, 3)
|
||||
|
||||
top_triangle_idx = torch.arange(0, num_verts_per_row * (num_verts_per_row - 1))
|
||||
top_triangle_idx = torch.stack(
|
||||
[
|
||||
top_triangle_idx,
|
||||
top_triangle_idx + 1,
|
||||
top_triangle_idx + num_verts_per_row + 1,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
bottom_triangle_idx = top_triangle_idx[:, [0, 2, 1]] + torch.tensor(
|
||||
[0, 0, num_verts_per_row - 1]
|
||||
)
|
||||
|
||||
faces = torch.zeros(
|
||||
(1, len(top_triangle_idx) + len(bottom_triangle_idx), 3),
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
faces[0, ::2] = top_triangle_idx
|
||||
faces[0, 1::2] = bottom_triangle_idx
|
||||
|
||||
# construct range of indices that excludes the boundary to avoid wrong triangles
|
||||
indexing_range = torch.arange(0, 2 * num_verts_per_row * num_verts_per_row).view(
|
||||
num_verts_per_row, num_verts_per_row, 2
|
||||
)
|
||||
indexing_range = indexing_range[:-1, :-1] # removes boundaries from list of indices
|
||||
indexing_range = indexing_range.reshape(
|
||||
2 * (num_verts_per_row - 1) * (num_verts_per_row - 1)
|
||||
)
|
||||
|
||||
faces = faces[:, indexing_range]
|
||||
|
||||
# adding color
|
||||
colors = torch.tensor(color1).repeat(2 * num_verts_per_row * num_verts_per_row, 1)
|
||||
colors[2::4] = torch.tensor(color2)
|
||||
colors[3::4] = torch.tensor(color2)
|
||||
colors = colors[None, indexing_range, None, None]
|
||||
|
||||
texture_atlas = TexturesAtlas(colors)
|
||||
|
||||
return Meshes(verts=verts, faces=faces, textures=texture_atlas)
|
21
tests/test_checkerboard.py
Normal file
21
tests/test_checkerboard.py
Normal file
@ -0,0 +1,21 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from pytorch3d.utils import checkerboard
|
||||
|
||||
from .common_testing import TestCaseMixin
|
||||
|
||||
|
||||
class TestCheckerboard(TestCaseMixin, unittest.TestCase):
|
||||
def test_simple(self):
|
||||
board = checkerboard(5)
|
||||
verts = board.verts_packed()
|
||||
expect = torch.tensor([5.0, 5.0, 0])
|
||||
self.assertClose(verts.min(dim=0).values, -expect)
|
||||
self.assertClose(verts.max(dim=0).values, expect)
|
Loading…
x
Reference in New Issue
Block a user