PLY color scaling

Summary: When a PLY file contains colors in byte format, these are now scaled from 0..255 to [0,1], as they should be

Reviewed By: gkioxari

Differential Revision: D27765254

fbshipit-source-id: 526b5f5149d5e8cbffd7412b411be52c935fa4ad
This commit is contained in:
Jeremy Reizenstein 2021-05-04 05:35:24 -07:00 committed by Facebook GitHub Bot
parent 6c3fe952d1
commit e9f4e0d086
2 changed files with 36 additions and 10 deletions

View File

@ -14,7 +14,7 @@ import warnings
from collections import namedtuple
from io import BytesIO, TextIOBase
from pathlib import Path
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union, cast
import numpy as np
import torch
@ -780,10 +780,24 @@ def _load_ply_raw(f, path_manager: PathManager) -> Tuple[_PlyHeader, dict]:
def _get_verts_column_indices(
vertex_head: _PlyElementType,
) -> Tuple[List[int], Optional[List[int]]]:
) -> Tuple[List[int], Optional[List[int]], float]:
"""
Get the columns of verts and verts_colors in the vertex
element of a parsed ply file.
element of a parsed ply file, together with a color scale factor.
When the colors are in byte format, they are scaled from 0..255 to [0,1].
Otherwise they are not scaled.
For example, if the vertex element looks as follows:
element vertex 892
property double x
property double y
property double z
property uchar red
property uchar green
property uchar blue
then the return value will be ([0,1,2], [6,7,8], 1.0/255)
Args:
vertex_head: as returned from load_ply_raw.
@ -792,6 +806,7 @@ def _get_verts_column_indices(
point_idxs: List[int] of 3 point columns.
color_idxs: List[int] of 3 color columns if they are present,
otherwise None.
color_scale: value to scale colors by.
"""
point_idxs: List[Optional[int]] = [None, None, None]
color_idxs: List[Optional[int]] = [None, None, None]
@ -806,9 +821,17 @@ def _get_verts_column_indices(
color_idxs[j] = i
if None in point_idxs:
raise ValueError("Invalid vertices in file.")
if None in color_idxs:
return point_idxs, None
return point_idxs, color_idxs
color_scale = 1.0
if all(
idx is not None and _PLY_TYPES[vertex_head.properties[idx].data_type].size == 1
for idx in color_idxs
):
color_scale = 1.0 / 255
return (
point_idxs,
None if None in color_idxs else cast(List[int], color_idxs),
color_scale,
)
def _get_verts(
@ -831,7 +854,7 @@ def _get_verts(
if not isinstance(vertex, list):
raise ValueError("Invalid vertices in file.")
vertex_head = next(head for head in header.elements if head.name == "vertex")
point_idxs, color_idxs = _get_verts_column_indices(vertex_head)
point_idxs, color_idxs, color_scale = _get_verts_column_indices(vertex_head)
# Case of no vertices
if vertex_head.count == 0:
@ -856,7 +879,9 @@ def _get_verts(
# so it was read as a single array and we can index straight into it.
verts = torch.tensor(vertex[0][:, point_idxs], dtype=torch.float32)
if color_idxs is not None:
vertex_colors = torch.tensor(vertex[0][:, color_idxs], dtype=torch.float32)
vertex_colors = color_scale * torch.tensor(
vertex[0][:, color_idxs], dtype=torch.float32
)
else:
# The vertex element is heterogeneous. It was read as several arrays,
# part by part, where a part is a set of properties with the same type.
@ -887,6 +912,7 @@ def _get_verts(
for color in range(3):
partnum, col = prop_to_partnum_col[color_idxs[color]]
vertex_colors.numpy()[:, color] = vertex[partnum][:, col]
vertex_colors *= color_scale
return verts, vertex_colors

View File

@ -448,7 +448,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
torch.FloatTensor([0, 1, 2]) + 7 * torch.arange(8)[:, None],
)
self.assertClose(
pointcloud.features_padded()[0],
pointcloud.features_padded()[0] * 255,
torch.FloatTensor([3, 4, 5]) + 7 * torch.arange(8)[:, None],
)
@ -518,7 +518,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
self.assertEqual(pointcloud_gpu.device, torch.device("cuda:0"))
pointcloud = pointcloud_gpu.to(torch.device("cpu"))
expected_points = torch.tensor([[[2, 5, 3]]], dtype=torch.float32)
expected_features = torch.tensor([[[4, 1, 6]]], dtype=torch.float32)
expected_features = torch.tensor([[[4, 1, 6]]], dtype=torch.float32) / 255.0
self.assertClose(pointcloud.points_padded(), expected_points)
self.assertClose(pointcloud.features_padded(), expected_features)