mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
PLY color scaling
Summary: When a PLY file contains colors in byte format, these are now scaled from 0..255 to [0,1], as they should be Reviewed By: gkioxari Differential Revision: D27765254 fbshipit-source-id: 526b5f5149d5e8cbffd7412b411be52c935fa4ad
This commit is contained in:
parent
6c3fe952d1
commit
e9f4e0d086
@ -14,7 +14,7 @@ import warnings
|
||||
from collections import namedtuple
|
||||
from io import BytesIO, TextIOBase
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -780,10 +780,24 @@ def _load_ply_raw(f, path_manager: PathManager) -> Tuple[_PlyHeader, dict]:
|
||||
|
||||
def _get_verts_column_indices(
|
||||
vertex_head: _PlyElementType,
|
||||
) -> Tuple[List[int], Optional[List[int]]]:
|
||||
) -> Tuple[List[int], Optional[List[int]], float]:
|
||||
"""
|
||||
Get the columns of verts and verts_colors in the vertex
|
||||
element of a parsed ply file.
|
||||
element of a parsed ply file, together with a color scale factor.
|
||||
When the colors are in byte format, they are scaled from 0..255 to [0,1].
|
||||
Otherwise they are not scaled.
|
||||
|
||||
For example, if the vertex element looks as follows:
|
||||
|
||||
element vertex 892
|
||||
property double x
|
||||
property double y
|
||||
property double z
|
||||
property uchar red
|
||||
property uchar green
|
||||
property uchar blue
|
||||
|
||||
then the return value will be ([0,1,2], [6,7,8], 1.0/255)
|
||||
|
||||
Args:
|
||||
vertex_head: as returned from load_ply_raw.
|
||||
@ -792,6 +806,7 @@ def _get_verts_column_indices(
|
||||
point_idxs: List[int] of 3 point columns.
|
||||
color_idxs: List[int] of 3 color columns if they are present,
|
||||
otherwise None.
|
||||
color_scale: value to scale colors by.
|
||||
"""
|
||||
point_idxs: List[Optional[int]] = [None, None, None]
|
||||
color_idxs: List[Optional[int]] = [None, None, None]
|
||||
@ -806,9 +821,17 @@ def _get_verts_column_indices(
|
||||
color_idxs[j] = i
|
||||
if None in point_idxs:
|
||||
raise ValueError("Invalid vertices in file.")
|
||||
if None in color_idxs:
|
||||
return point_idxs, None
|
||||
return point_idxs, color_idxs
|
||||
color_scale = 1.0
|
||||
if all(
|
||||
idx is not None and _PLY_TYPES[vertex_head.properties[idx].data_type].size == 1
|
||||
for idx in color_idxs
|
||||
):
|
||||
color_scale = 1.0 / 255
|
||||
return (
|
||||
point_idxs,
|
||||
None if None in color_idxs else cast(List[int], color_idxs),
|
||||
color_scale,
|
||||
)
|
||||
|
||||
|
||||
def _get_verts(
|
||||
@ -831,7 +854,7 @@ def _get_verts(
|
||||
if not isinstance(vertex, list):
|
||||
raise ValueError("Invalid vertices in file.")
|
||||
vertex_head = next(head for head in header.elements if head.name == "vertex")
|
||||
point_idxs, color_idxs = _get_verts_column_indices(vertex_head)
|
||||
point_idxs, color_idxs, color_scale = _get_verts_column_indices(vertex_head)
|
||||
|
||||
# Case of no vertices
|
||||
if vertex_head.count == 0:
|
||||
@ -856,7 +879,9 @@ def _get_verts(
|
||||
# so it was read as a single array and we can index straight into it.
|
||||
verts = torch.tensor(vertex[0][:, point_idxs], dtype=torch.float32)
|
||||
if color_idxs is not None:
|
||||
vertex_colors = torch.tensor(vertex[0][:, color_idxs], dtype=torch.float32)
|
||||
vertex_colors = color_scale * torch.tensor(
|
||||
vertex[0][:, color_idxs], dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
# The vertex element is heterogeneous. It was read as several arrays,
|
||||
# part by part, where a part is a set of properties with the same type.
|
||||
@ -887,6 +912,7 @@ def _get_verts(
|
||||
for color in range(3):
|
||||
partnum, col = prop_to_partnum_col[color_idxs[color]]
|
||||
vertex_colors.numpy()[:, color] = vertex[partnum][:, col]
|
||||
vertex_colors *= color_scale
|
||||
|
||||
return verts, vertex_colors
|
||||
|
||||
|
@ -448,7 +448,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
||||
torch.FloatTensor([0, 1, 2]) + 7 * torch.arange(8)[:, None],
|
||||
)
|
||||
self.assertClose(
|
||||
pointcloud.features_padded()[0],
|
||||
pointcloud.features_padded()[0] * 255,
|
||||
torch.FloatTensor([3, 4, 5]) + 7 * torch.arange(8)[:, None],
|
||||
)
|
||||
|
||||
@ -518,7 +518,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(pointcloud_gpu.device, torch.device("cuda:0"))
|
||||
pointcloud = pointcloud_gpu.to(torch.device("cpu"))
|
||||
expected_points = torch.tensor([[[2, 5, 3]]], dtype=torch.float32)
|
||||
expected_features = torch.tensor([[[4, 1, 6]]], dtype=torch.float32)
|
||||
expected_features = torch.tensor([[[4, 1, 6]]], dtype=torch.float32) / 255.0
|
||||
self.assertClose(pointcloud.points_padded(), expected_points)
|
||||
self.assertClose(pointcloud.features_padded(), expected_features)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user