mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-03 04:12:48 +08:00
Summary: Adding a checkerboard mesh utility to Pytorch3d. Reviewed By: bottler Differential Revision: D39718916 fbshipit-source-id: d43cd30e566b5db068bae6eed0388057634428c8
90 lines
2.8 KiB
Python
90 lines
2.8 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
from typing import Optional, Tuple
|
|
|
|
import torch
|
|
from pytorch3d.common.compat import meshgrid_ij
|
|
from pytorch3d.renderer.mesh.textures import TexturesAtlas
|
|
from pytorch3d.structures.meshes import Meshes
|
|
|
|
|
|
def checkerboard(
|
|
radius: int = 4,
|
|
color1: Tuple[float, ...] = (0.0, 0.0, 0.0),
|
|
color2: Tuple[float, ...] = (1.0, 1.0, 1.0),
|
|
device: Optional[torch.types._device] = None,
|
|
) -> Meshes:
|
|
"""
|
|
Returns a mesh of squares in the xy-plane where each unit is one of the two given
|
|
colors and adjacent squares have opposite colors.
|
|
Args:
|
|
radius: how many squares in each direction from the origin
|
|
color1: background color
|
|
color2: foreground color (must have the same number of channels as color1)
|
|
Returns:
|
|
new Meshes object containing one mesh.
|
|
"""
|
|
|
|
if device is None:
|
|
device = torch.device("cpu")
|
|
if radius < 1:
|
|
raise ValueError("radius must be > 0")
|
|
|
|
num_verts_per_row = 2 * radius + 1
|
|
|
|
# construct 2D grid of 3D vertices
|
|
x = torch.arange(-radius, radius + 1, device=device)
|
|
grid_y, grid_x = meshgrid_ij(x, x)
|
|
verts = torch.stack(
|
|
[grid_x, grid_y, torch.zeros((2 * radius + 1, 2 * radius + 1))], dim=-1
|
|
)
|
|
verts = verts.view(1, -1, 3)
|
|
|
|
top_triangle_idx = torch.arange(0, num_verts_per_row * (num_verts_per_row - 1))
|
|
top_triangle_idx = torch.stack(
|
|
[
|
|
top_triangle_idx,
|
|
top_triangle_idx + 1,
|
|
top_triangle_idx + num_verts_per_row + 1,
|
|
],
|
|
dim=-1,
|
|
)
|
|
|
|
bottom_triangle_idx = top_triangle_idx[:, [0, 2, 1]] + torch.tensor(
|
|
[0, 0, num_verts_per_row - 1]
|
|
)
|
|
|
|
faces = torch.zeros(
|
|
(1, len(top_triangle_idx) + len(bottom_triangle_idx), 3),
|
|
dtype=torch.long,
|
|
device=device,
|
|
)
|
|
faces[0, ::2] = top_triangle_idx
|
|
faces[0, 1::2] = bottom_triangle_idx
|
|
|
|
# construct range of indices that excludes the boundary to avoid wrong triangles
|
|
indexing_range = torch.arange(0, 2 * num_verts_per_row * num_verts_per_row).view(
|
|
num_verts_per_row, num_verts_per_row, 2
|
|
)
|
|
indexing_range = indexing_range[:-1, :-1] # removes boundaries from list of indices
|
|
indexing_range = indexing_range.reshape(
|
|
2 * (num_verts_per_row - 1) * (num_verts_per_row - 1)
|
|
)
|
|
|
|
faces = faces[:, indexing_range]
|
|
|
|
# adding color
|
|
colors = torch.tensor(color1).repeat(2 * num_verts_per_row * num_verts_per_row, 1)
|
|
colors[2::4] = torch.tensor(color2)
|
|
colors[3::4] = torch.tensor(color2)
|
|
colors = colors[None, indexing_range, None, None]
|
|
|
|
texture_atlas = TexturesAtlas(colors)
|
|
|
|
return Meshes(verts=verts, faces=faces, textures=texture_atlas)
|