mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 22:00:35 +08:00
Update load obj and compare with SoftRas
Summary: Updated the load obj function to support creating of a per face texture map using the information in an .mtl file. Uses the approach from in SoftRasterizer. Currently I have ported in the SoftRasterizer code but this is only to help with comparison and will be deleted before landing. The ShapeNet Test data will also be deleted. Here is the [Design doc](https://docs.google.com/document/d/1AUcLP4QwVSqlfLAUfbjM9ic5vYn9P54Ha8QbcVXW2eI/edit?usp=sharing). ## Added - texture atlas creation functions in PyTorch based on the SoftRas cuda implementation - tests to compare SoftRas vs PyTorch3D implementation to verify it matches (using real shapenet data with meshes consisting of multiple textures) - benchmarks tests ## Remaining todo: - add more tests for obj io to test the new functions and the two texturing options - replace the shapenet data with the output from SoftRas saved as a file. # MAIN FILES TO REVIEW - `obj_io.py` - `test_obj_io.py` [still some tests to be added but have comparisons with SoftRas for now] The reference SoftRas implementations are in `softras_load_obj.py` and `load_textures.cu`. Reviewed By: gkioxari Differential Revision: D20754859 fbshipit-source-id: 42ace9dfb73f26e29d800c763f56d5b66c60c5e2
This commit is contained in:
committed by
Facebook GitHub Bot
parent
85c396f822
commit
c9267ab7af
@@ -1,5 +1,7 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from itertools import product
|
||||
|
||||
from fvcore.common.benchmark import benchmark
|
||||
from test_obj_io import TestMeshObjIO
|
||||
from test_ply_io import TestMeshPlyIO
|
||||
@@ -61,3 +63,35 @@ def bm_save_load() -> None:
|
||||
complex_kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
||||
|
||||
# Texture loading benchmarks
|
||||
kwargs_list = [{"R": 2}, {"R": 4}, {"R": 10}, {"R": 15}, {"R": 20}]
|
||||
benchmark(
|
||||
TestMeshObjIO.bm_load_texture_atlas,
|
||||
"PYTORCH3D_TEXTURE_ATLAS",
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
||||
|
||||
kwargs_list = []
|
||||
S = [64, 256, 1024]
|
||||
F = [100, 1000, 10000]
|
||||
R = [5, 10, 20]
|
||||
test_cases = product(S, F, R)
|
||||
|
||||
for case in test_cases:
|
||||
s, f, r = case
|
||||
kwargs_list.append({"S": s, "F": f, "R": r})
|
||||
|
||||
benchmark(
|
||||
TestMeshObjIO.bm_bilinear_sampling_vectorized,
|
||||
"BILINEAR_VECTORIZED",
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
||||
benchmark(
|
||||
TestMeshObjIO.bm_bilinear_sampling_grid_sample,
|
||||
"BILINEAR_GRID_SAMPLE",
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
||||
|
||||
@@ -2,12 +2,17 @@
|
||||
|
||||
import os
|
||||
import unittest
|
||||
import warnings
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.io import load_obj, load_objs_as_meshes, save_obj
|
||||
from pytorch3d.io.mtl_io import (
|
||||
_bilinear_interpolation_grid_sample,
|
||||
_bilinear_interpolation_vectorized,
|
||||
)
|
||||
from pytorch3d.structures import Meshes, Textures, join_meshes_as_batch
|
||||
from pytorch3d.utils import torus
|
||||
|
||||
@@ -47,8 +52,9 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
self.assertTrue(torch.all(verts == expected_verts))
|
||||
self.assertTrue(torch.all(faces.verts_idx == expected_faces))
|
||||
self.assertTrue(faces.normals_idx == [])
|
||||
self.assertTrue(faces.textures_idx == [])
|
||||
padded_vals = -torch.ones_like(faces.verts_idx)
|
||||
self.assertTrue(torch.all(faces.normals_idx == padded_vals))
|
||||
self.assertTrue(torch.all(faces.textures_idx == padded_vals))
|
||||
self.assertTrue(
|
||||
torch.all(faces.materials_idx == -torch.ones(len(expected_faces)))
|
||||
)
|
||||
@@ -118,8 +124,12 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
[[0.749279, 0.501284], [0.999110, 0.501077], [0.999455, 0.750380]],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
expected_faces_normals_idx = torch.tensor([[1, 1, 1]], dtype=torch.int64)
|
||||
expected_faces_textures_idx = torch.tensor([[0, 0, 1]], dtype=torch.int64)
|
||||
expected_faces_normals_idx = -torch.ones_like(expected_faces, dtype=torch.int64)
|
||||
expected_faces_normals_idx[4, :] = torch.tensor([1, 1, 1], dtype=torch.int64)
|
||||
expected_faces_textures_idx = -torch.ones_like(
|
||||
expected_faces, dtype=torch.int64
|
||||
)
|
||||
expected_faces_textures_idx[4, :] = torch.tensor([0, 0, 1], dtype=torch.int64)
|
||||
|
||||
self.assertTrue(torch.all(verts == expected_verts))
|
||||
self.assertTrue(torch.all(faces.verts_idx == expected_faces))
|
||||
@@ -160,7 +170,8 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(faces.normals_idx, expected_faces_normals_idx)
|
||||
self.assertClose(normals, expected_normals)
|
||||
self.assertClose(verts, expected_verts)
|
||||
self.assertTrue(faces.textures_idx == [])
|
||||
# Textures idx padded with -1.
|
||||
self.assertClose(faces.textures_idx, torch.ones_like(faces.verts_idx) * -1)
|
||||
self.assertTrue(textures is None)
|
||||
self.assertTrue(materials is None)
|
||||
self.assertTrue(tex_maps is None)
|
||||
@@ -195,7 +206,9 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(faces.textures_idx, expected_faces_textures_idx)
|
||||
self.assertClose(expected_textures, textures)
|
||||
self.assertClose(expected_verts, verts)
|
||||
self.assertTrue(faces.normals_idx == [])
|
||||
self.assertTrue(
|
||||
torch.all(faces.normals_idx == -torch.ones_like(faces.textures_idx))
|
||||
)
|
||||
self.assertTrue(normals is None)
|
||||
self.assertTrue(materials is None)
|
||||
self.assertTrue(tex_maps is None)
|
||||
@@ -408,6 +421,9 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
"shininess": torch.tensor([10.0], dtype=dtype),
|
||||
}
|
||||
}
|
||||
# Texture atlas is not created as `create_texture_atlas=True` was
|
||||
# not set in the load_obj args
|
||||
self.assertTrue(aux.texture_atlas is None)
|
||||
# Check that there is an image with material name material_1.
|
||||
self.assertTrue(tuple(tex_maps.keys()) == ("material_1",))
|
||||
self.assertTrue(torch.is_tensor(tuple(tex_maps.values())[0]))
|
||||
@@ -423,6 +439,36 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
torch.allclose(materials[n1][k1], expected_materials[n2][k2])
|
||||
)
|
||||
|
||||
def test_load_mtl_texture_atlas_compare_softras(self):
|
||||
# Load saved texture atlas created with SoftRas.
|
||||
device = torch.device("cuda:0")
|
||||
DATA_DIR = Path(__file__).resolve().parent.parent
|
||||
obj_filename = DATA_DIR / "docs/tutorials/data/cow_mesh/cow.obj"
|
||||
expected_atlas_fname = DATA_DIR / "tests/data/cow_texture_atlas_softras.pt"
|
||||
|
||||
# Note, the reference texture atlas generated using SoftRas load_obj function
|
||||
# is too large to check in to the repo. Download the file to run the test locally.
|
||||
if not os.path.exists(expected_atlas_fname):
|
||||
url = "https://dl.fbaipublicfiles.com/pytorch3d/data/tests/cow_texture_atlas_softras.pt"
|
||||
msg = (
|
||||
"cow_texture_atlas_softras.pt not found, download from %s, save it at the path %s, and rerun"
|
||||
% (url, expected_atlas_fname)
|
||||
)
|
||||
warnings.warn(msg)
|
||||
return True
|
||||
|
||||
expected_atlas = torch.load(expected_atlas_fname)
|
||||
_, _, aux = load_obj(
|
||||
obj_filename,
|
||||
load_textures=True,
|
||||
device=device,
|
||||
create_texture_atlas=True,
|
||||
texture_atlas_size=15,
|
||||
texture_wrap="repeat",
|
||||
)
|
||||
|
||||
self.assertClose(expected_atlas, aux.texture_atlas, atol=5e-5)
|
||||
|
||||
def test_load_mtl_noload(self):
|
||||
DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
|
||||
obj_filename = "cow_mesh/cow.obj"
|
||||
@@ -629,3 +675,51 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
meshes = torus(r=0.25, R=1.0, sides=N, rings=2 * N)
|
||||
[verts], [faces] = meshes.verts_list(), meshes.faces_list()
|
||||
return TestMeshObjIO._bm_load_obj(verts, faces, decimal_places=5)
|
||||
|
||||
@staticmethod
|
||||
def bm_load_texture_atlas(R: int):
|
||||
device = torch.device("cuda:0")
|
||||
torch.cuda.set_device(device)
|
||||
DATA_DIR = "/data/users/nikhilar/fbsource/fbcode/vision/fair/pytorch3d/docs/"
|
||||
obj_filename = os.path.join(DATA_DIR, "tutorials/data/cow_mesh/cow.obj")
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def load():
|
||||
load_obj(
|
||||
obj_filename,
|
||||
load_textures=True,
|
||||
device=device,
|
||||
create_texture_atlas=True,
|
||||
texture_atlas_size=R,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return load
|
||||
|
||||
@staticmethod
|
||||
def bm_bilinear_sampling_vectorized(S: int, F: int, R: int):
|
||||
device = torch.device("cuda:0")
|
||||
torch.cuda.set_device(device)
|
||||
image = torch.rand((S, S, 3))
|
||||
grid = torch.rand((F, R, R, 2))
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def load():
|
||||
_bilinear_interpolation_vectorized(image, grid)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return load
|
||||
|
||||
@staticmethod
|
||||
def bm_bilinear_sampling_grid_sample(S: int, F: int, R: int):
|
||||
device = torch.device("cuda:0")
|
||||
torch.cuda.set_device(device)
|
||||
image = torch.rand((S, S, 3))
|
||||
grid = torch.rand((F, R, R, 2))
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def load():
|
||||
_bilinear_interpolation_grid_sample(image, grid)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return load
|
||||
|
||||
Reference in New Issue
Block a user