Update load obj and compare with SoftRas

Summary:
Updated the load obj function to support creating of a per face texture map using the information in an .mtl file. Uses the approach from in SoftRasterizer.

Currently I have ported in the SoftRasterizer code but this is only to help with comparison and will  be deleted before landing. The ShapeNet Test data will also be deleted.

Here is the [Design doc](https://docs.google.com/document/d/1AUcLP4QwVSqlfLAUfbjM9ic5vYn9P54Ha8QbcVXW2eI/edit?usp=sharing).

## Added
- texture atlas creation functions in PyTorch based on the SoftRas cuda implementation
- tests to compare SoftRas vs PyTorch3D implementation to verify it matches (using real shapenet data with meshes consisting of multiple textures)
- benchmarks tests

## Remaining todo:
- add more tests for obj io to test the new functions and the two texturing options
- replace the shapenet data with the output from SoftRas saved as a file.

# MAIN FILES TO REVIEW

- `obj_io.py`
- `test_obj_io.py` [still some tests to be added but have comparisons with SoftRas for now]

The reference SoftRas implementations are in `softras_load_obj.py` and `load_textures.cu`.

Reviewed By: gkioxari

Differential Revision: D20754859

fbshipit-source-id: 42ace9dfb73f26e29d800c763f56d5b66c60c5e2
This commit is contained in:
Nikhila Ravi
2020-04-23 19:31:50 -07:00
committed by Facebook GitHub Bot
parent 85c396f822
commit c9267ab7af
5 changed files with 785 additions and 168 deletions

View File

@@ -1,5 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from itertools import product
from fvcore.common.benchmark import benchmark
from test_obj_io import TestMeshObjIO
from test_ply_io import TestMeshPlyIO
@@ -61,3 +63,35 @@ def bm_save_load() -> None:
complex_kwargs_list,
warmup_iters=1,
)
# Texture loading benchmarks
kwargs_list = [{"R": 2}, {"R": 4}, {"R": 10}, {"R": 15}, {"R": 20}]
benchmark(
TestMeshObjIO.bm_load_texture_atlas,
"PYTORCH3D_TEXTURE_ATLAS",
kwargs_list,
warmup_iters=1,
)
kwargs_list = []
S = [64, 256, 1024]
F = [100, 1000, 10000]
R = [5, 10, 20]
test_cases = product(S, F, R)
for case in test_cases:
s, f, r = case
kwargs_list.append({"S": s, "F": f, "R": r})
benchmark(
TestMeshObjIO.bm_bilinear_sampling_vectorized,
"BILINEAR_VECTORIZED",
kwargs_list,
warmup_iters=1,
)
benchmark(
TestMeshObjIO.bm_bilinear_sampling_grid_sample,
"BILINEAR_GRID_SAMPLE",
kwargs_list,
warmup_iters=1,
)

View File

@@ -2,12 +2,17 @@
import os
import unittest
import warnings
from io import StringIO
from pathlib import Path
import torch
from common_testing import TestCaseMixin
from pytorch3d.io import load_obj, load_objs_as_meshes, save_obj
from pytorch3d.io.mtl_io import (
_bilinear_interpolation_grid_sample,
_bilinear_interpolation_vectorized,
)
from pytorch3d.structures import Meshes, Textures, join_meshes_as_batch
from pytorch3d.utils import torus
@@ -47,8 +52,9 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
)
self.assertTrue(torch.all(verts == expected_verts))
self.assertTrue(torch.all(faces.verts_idx == expected_faces))
self.assertTrue(faces.normals_idx == [])
self.assertTrue(faces.textures_idx == [])
padded_vals = -torch.ones_like(faces.verts_idx)
self.assertTrue(torch.all(faces.normals_idx == padded_vals))
self.assertTrue(torch.all(faces.textures_idx == padded_vals))
self.assertTrue(
torch.all(faces.materials_idx == -torch.ones(len(expected_faces)))
)
@@ -118,8 +124,12 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
[[0.749279, 0.501284], [0.999110, 0.501077], [0.999455, 0.750380]],
dtype=torch.float32,
)
expected_faces_normals_idx = torch.tensor([[1, 1, 1]], dtype=torch.int64)
expected_faces_textures_idx = torch.tensor([[0, 0, 1]], dtype=torch.int64)
expected_faces_normals_idx = -torch.ones_like(expected_faces, dtype=torch.int64)
expected_faces_normals_idx[4, :] = torch.tensor([1, 1, 1], dtype=torch.int64)
expected_faces_textures_idx = -torch.ones_like(
expected_faces, dtype=torch.int64
)
expected_faces_textures_idx[4, :] = torch.tensor([0, 0, 1], dtype=torch.int64)
self.assertTrue(torch.all(verts == expected_verts))
self.assertTrue(torch.all(faces.verts_idx == expected_faces))
@@ -160,7 +170,8 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
self.assertClose(faces.normals_idx, expected_faces_normals_idx)
self.assertClose(normals, expected_normals)
self.assertClose(verts, expected_verts)
self.assertTrue(faces.textures_idx == [])
# Textures idx padded with -1.
self.assertClose(faces.textures_idx, torch.ones_like(faces.verts_idx) * -1)
self.assertTrue(textures is None)
self.assertTrue(materials is None)
self.assertTrue(tex_maps is None)
@@ -195,7 +206,9 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
self.assertClose(faces.textures_idx, expected_faces_textures_idx)
self.assertClose(expected_textures, textures)
self.assertClose(expected_verts, verts)
self.assertTrue(faces.normals_idx == [])
self.assertTrue(
torch.all(faces.normals_idx == -torch.ones_like(faces.textures_idx))
)
self.assertTrue(normals is None)
self.assertTrue(materials is None)
self.assertTrue(tex_maps is None)
@@ -408,6 +421,9 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"shininess": torch.tensor([10.0], dtype=dtype),
}
}
# Texture atlas is not created as `create_texture_atlas=True` was
# not set in the load_obj args
self.assertTrue(aux.texture_atlas is None)
# Check that there is an image with material name material_1.
self.assertTrue(tuple(tex_maps.keys()) == ("material_1",))
self.assertTrue(torch.is_tensor(tuple(tex_maps.values())[0]))
@@ -423,6 +439,36 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
torch.allclose(materials[n1][k1], expected_materials[n2][k2])
)
def test_load_mtl_texture_atlas_compare_softras(self):
# Load saved texture atlas created with SoftRas.
device = torch.device("cuda:0")
DATA_DIR = Path(__file__).resolve().parent.parent
obj_filename = DATA_DIR / "docs/tutorials/data/cow_mesh/cow.obj"
expected_atlas_fname = DATA_DIR / "tests/data/cow_texture_atlas_softras.pt"
# Note, the reference texture atlas generated using SoftRas load_obj function
# is too large to check in to the repo. Download the file to run the test locally.
if not os.path.exists(expected_atlas_fname):
url = "https://dl.fbaipublicfiles.com/pytorch3d/data/tests/cow_texture_atlas_softras.pt"
msg = (
"cow_texture_atlas_softras.pt not found, download from %s, save it at the path %s, and rerun"
% (url, expected_atlas_fname)
)
warnings.warn(msg)
return True
expected_atlas = torch.load(expected_atlas_fname)
_, _, aux = load_obj(
obj_filename,
load_textures=True,
device=device,
create_texture_atlas=True,
texture_atlas_size=15,
texture_wrap="repeat",
)
self.assertClose(expected_atlas, aux.texture_atlas, atol=5e-5)
def test_load_mtl_noload(self):
DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
obj_filename = "cow_mesh/cow.obj"
@@ -629,3 +675,51 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
meshes = torus(r=0.25, R=1.0, sides=N, rings=2 * N)
[verts], [faces] = meshes.verts_list(), meshes.faces_list()
return TestMeshObjIO._bm_load_obj(verts, faces, decimal_places=5)
@staticmethod
def bm_load_texture_atlas(R: int):
device = torch.device("cuda:0")
torch.cuda.set_device(device)
DATA_DIR = "/data/users/nikhilar/fbsource/fbcode/vision/fair/pytorch3d/docs/"
obj_filename = os.path.join(DATA_DIR, "tutorials/data/cow_mesh/cow.obj")
torch.cuda.synchronize()
def load():
load_obj(
obj_filename,
load_textures=True,
device=device,
create_texture_atlas=True,
texture_atlas_size=R,
)
torch.cuda.synchronize()
return load
@staticmethod
def bm_bilinear_sampling_vectorized(S: int, F: int, R: int):
device = torch.device("cuda:0")
torch.cuda.set_device(device)
image = torch.rand((S, S, 3))
grid = torch.rand((F, R, R, 2))
torch.cuda.synchronize()
def load():
_bilinear_interpolation_vectorized(image, grid)
torch.cuda.synchronize()
return load
@staticmethod
def bm_bilinear_sampling_grid_sample(S: int, F: int, R: int):
device = torch.device("cuda:0")
torch.cuda.set_device(device)
image = torch.rand((S, S, 3))
grid = torch.rand((F, R, R, 2))
torch.cuda.synchronize()
def load():
_bilinear_interpolation_grid_sample(image, grid)
torch.cuda.synchronize()
return load