mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
PLY color scaling
Summary: When a PLY file contains colors in byte format, these are now scaled from 0..255 to [0,1], as they should be Reviewed By: gkioxari Differential Revision: D27765254 fbshipit-source-id: 526b5f5149d5e8cbffd7412b411be52c935fa4ad
This commit is contained in:
parent
6c3fe952d1
commit
e9f4e0d086
@ -14,7 +14,7 @@ import warnings
|
|||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from io import BytesIO, TextIOBase
|
from io import BytesIO, TextIOBase
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -780,10 +780,24 @@ def _load_ply_raw(f, path_manager: PathManager) -> Tuple[_PlyHeader, dict]:
|
|||||||
|
|
||||||
def _get_verts_column_indices(
|
def _get_verts_column_indices(
|
||||||
vertex_head: _PlyElementType,
|
vertex_head: _PlyElementType,
|
||||||
) -> Tuple[List[int], Optional[List[int]]]:
|
) -> Tuple[List[int], Optional[List[int]], float]:
|
||||||
"""
|
"""
|
||||||
Get the columns of verts and verts_colors in the vertex
|
Get the columns of verts and verts_colors in the vertex
|
||||||
element of a parsed ply file.
|
element of a parsed ply file, together with a color scale factor.
|
||||||
|
When the colors are in byte format, they are scaled from 0..255 to [0,1].
|
||||||
|
Otherwise they are not scaled.
|
||||||
|
|
||||||
|
For example, if the vertex element looks as follows:
|
||||||
|
|
||||||
|
element vertex 892
|
||||||
|
property double x
|
||||||
|
property double y
|
||||||
|
property double z
|
||||||
|
property uchar red
|
||||||
|
property uchar green
|
||||||
|
property uchar blue
|
||||||
|
|
||||||
|
then the return value will be ([0,1,2], [6,7,8], 1.0/255)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
vertex_head: as returned from load_ply_raw.
|
vertex_head: as returned from load_ply_raw.
|
||||||
@ -792,6 +806,7 @@ def _get_verts_column_indices(
|
|||||||
point_idxs: List[int] of 3 point columns.
|
point_idxs: List[int] of 3 point columns.
|
||||||
color_idxs: List[int] of 3 color columns if they are present,
|
color_idxs: List[int] of 3 color columns if they are present,
|
||||||
otherwise None.
|
otherwise None.
|
||||||
|
color_scale: value to scale colors by.
|
||||||
"""
|
"""
|
||||||
point_idxs: List[Optional[int]] = [None, None, None]
|
point_idxs: List[Optional[int]] = [None, None, None]
|
||||||
color_idxs: List[Optional[int]] = [None, None, None]
|
color_idxs: List[Optional[int]] = [None, None, None]
|
||||||
@ -806,9 +821,17 @@ def _get_verts_column_indices(
|
|||||||
color_idxs[j] = i
|
color_idxs[j] = i
|
||||||
if None in point_idxs:
|
if None in point_idxs:
|
||||||
raise ValueError("Invalid vertices in file.")
|
raise ValueError("Invalid vertices in file.")
|
||||||
if None in color_idxs:
|
color_scale = 1.0
|
||||||
return point_idxs, None
|
if all(
|
||||||
return point_idxs, color_idxs
|
idx is not None and _PLY_TYPES[vertex_head.properties[idx].data_type].size == 1
|
||||||
|
for idx in color_idxs
|
||||||
|
):
|
||||||
|
color_scale = 1.0 / 255
|
||||||
|
return (
|
||||||
|
point_idxs,
|
||||||
|
None if None in color_idxs else cast(List[int], color_idxs),
|
||||||
|
color_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_verts(
|
def _get_verts(
|
||||||
@ -831,7 +854,7 @@ def _get_verts(
|
|||||||
if not isinstance(vertex, list):
|
if not isinstance(vertex, list):
|
||||||
raise ValueError("Invalid vertices in file.")
|
raise ValueError("Invalid vertices in file.")
|
||||||
vertex_head = next(head for head in header.elements if head.name == "vertex")
|
vertex_head = next(head for head in header.elements if head.name == "vertex")
|
||||||
point_idxs, color_idxs = _get_verts_column_indices(vertex_head)
|
point_idxs, color_idxs, color_scale = _get_verts_column_indices(vertex_head)
|
||||||
|
|
||||||
# Case of no vertices
|
# Case of no vertices
|
||||||
if vertex_head.count == 0:
|
if vertex_head.count == 0:
|
||||||
@ -856,7 +879,9 @@ def _get_verts(
|
|||||||
# so it was read as a single array and we can index straight into it.
|
# so it was read as a single array and we can index straight into it.
|
||||||
verts = torch.tensor(vertex[0][:, point_idxs], dtype=torch.float32)
|
verts = torch.tensor(vertex[0][:, point_idxs], dtype=torch.float32)
|
||||||
if color_idxs is not None:
|
if color_idxs is not None:
|
||||||
vertex_colors = torch.tensor(vertex[0][:, color_idxs], dtype=torch.float32)
|
vertex_colors = color_scale * torch.tensor(
|
||||||
|
vertex[0][:, color_idxs], dtype=torch.float32
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# The vertex element is heterogeneous. It was read as several arrays,
|
# The vertex element is heterogeneous. It was read as several arrays,
|
||||||
# part by part, where a part is a set of properties with the same type.
|
# part by part, where a part is a set of properties with the same type.
|
||||||
@ -887,6 +912,7 @@ def _get_verts(
|
|||||||
for color in range(3):
|
for color in range(3):
|
||||||
partnum, col = prop_to_partnum_col[color_idxs[color]]
|
partnum, col = prop_to_partnum_col[color_idxs[color]]
|
||||||
vertex_colors.numpy()[:, color] = vertex[partnum][:, col]
|
vertex_colors.numpy()[:, color] = vertex[partnum][:, col]
|
||||||
|
vertex_colors *= color_scale
|
||||||
|
|
||||||
return verts, vertex_colors
|
return verts, vertex_colors
|
||||||
|
|
||||||
|
@ -448,7 +448,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
|||||||
torch.FloatTensor([0, 1, 2]) + 7 * torch.arange(8)[:, None],
|
torch.FloatTensor([0, 1, 2]) + 7 * torch.arange(8)[:, None],
|
||||||
)
|
)
|
||||||
self.assertClose(
|
self.assertClose(
|
||||||
pointcloud.features_padded()[0],
|
pointcloud.features_padded()[0] * 255,
|
||||||
torch.FloatTensor([3, 4, 5]) + 7 * torch.arange(8)[:, None],
|
torch.FloatTensor([3, 4, 5]) + 7 * torch.arange(8)[:, None],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -518,7 +518,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
|||||||
self.assertEqual(pointcloud_gpu.device, torch.device("cuda:0"))
|
self.assertEqual(pointcloud_gpu.device, torch.device("cuda:0"))
|
||||||
pointcloud = pointcloud_gpu.to(torch.device("cpu"))
|
pointcloud = pointcloud_gpu.to(torch.device("cpu"))
|
||||||
expected_points = torch.tensor([[[2, 5, 3]]], dtype=torch.float32)
|
expected_points = torch.tensor([[[2, 5, 3]]], dtype=torch.float32)
|
||||||
expected_features = torch.tensor([[[4, 1, 6]]], dtype=torch.float32)
|
expected_features = torch.tensor([[[4, 1, 6]]], dtype=torch.float32) / 255.0
|
||||||
self.assertClose(pointcloud.points_padded(), expected_points)
|
self.assertClose(pointcloud.points_padded(), expected_points)
|
||||||
self.assertClose(pointcloud.features_padded(), expected_features)
|
self.assertClose(pointcloud.features_padded(), expected_features)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user