mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-01-17 11:50:35 +08:00
Extract finding directories for test data
Summary: Make common functions for finding directories where test data is found, instead of lots of tests using their own `__file__` while trying to get ./tests/data and the tutorials data. Reviewed By: nikhilaravi Differential Revision: D27633701 fbshipit-source-id: 1467bb6018cea16eba3cab097d713116d51071e9
This commit is contained in:
committed by
Facebook GitHub Bot
parent
24ee279005
commit
1216b5765a
@@ -8,7 +8,11 @@ from pathlib import Path
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from common_testing import (
|
||||
TestCaseMixin,
|
||||
get_pytorch3d_dir,
|
||||
get_tests_dir,
|
||||
)
|
||||
from iopath.common.file_io import PathManager
|
||||
from pytorch3d.io import IO, load_obj, load_objs_as_meshes, save_obj
|
||||
from pytorch3d.io.mtl_io import (
|
||||
@@ -475,7 +479,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(actual_file, expected_file)
|
||||
|
||||
def test_load_mtl(self):
|
||||
DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
|
||||
DATA_DIR = get_pytorch3d_dir() / "docs/tutorials/data"
|
||||
obj_filename = "cow_mesh/cow.obj"
|
||||
filename = os.path.join(DATA_DIR, obj_filename)
|
||||
verts, faces, aux = load_obj(filename)
|
||||
@@ -559,7 +563,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
def test_load_mtl_texture_atlas_compare_softras(self):
|
||||
# Load saved texture atlas created with SoftRas.
|
||||
device = torch.device("cuda:0")
|
||||
DATA_DIR = Path(__file__).resolve().parent.parent
|
||||
DATA_DIR = get_pytorch3d_dir()
|
||||
obj_filename = DATA_DIR / "docs/tutorials/data/cow_mesh/cow.obj"
|
||||
expected_atlas_fname = DATA_DIR / "tests/data/cow_texture_atlas_softras.pt"
|
||||
|
||||
@@ -590,7 +594,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(expected_atlas, aux.texture_atlas, atol=5e-5)
|
||||
|
||||
def test_load_mtl_noload(self):
|
||||
DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
|
||||
DATA_DIR = get_pytorch3d_dir() / "docs/tutorials/data"
|
||||
obj_filename = "cow_mesh/cow.obj"
|
||||
filename = os.path.join(DATA_DIR, obj_filename)
|
||||
verts, faces, aux = load_obj(filename, load_textures=False)
|
||||
@@ -628,7 +632,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
self.assertTrue(aux.verts_uvs is None)
|
||||
|
||||
def test_load_obj_mlt_no_image(self):
|
||||
DATA_DIR = Path(__file__).resolve().parent / "data"
|
||||
DATA_DIR = get_tests_dir() / "data"
|
||||
obj_filename = "obj_mtl_no_image/model.obj"
|
||||
filename = os.path.join(DATA_DIR, obj_filename)
|
||||
R = 8
|
||||
@@ -657,7 +661,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(list(aux.material_colors.keys()), ["material_1"])
|
||||
|
||||
def test_load_obj_missing_texture(self):
|
||||
DATA_DIR = Path(__file__).resolve().parent / "data"
|
||||
DATA_DIR = get_tests_dir() / "data"
|
||||
obj_filename = "missing_files_obj/model.obj"
|
||||
filename = os.path.join(DATA_DIR, obj_filename)
|
||||
with self.assertWarnsRegex(UserWarning, "Texture file does not exist"):
|
||||
@@ -672,7 +676,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(faces.verts_idx, expected_faces))
|
||||
|
||||
def test_load_obj_missing_texture_noload(self):
|
||||
DATA_DIR = Path(__file__).resolve().parent / "data"
|
||||
DATA_DIR = get_tests_dir() / "data"
|
||||
obj_filename = "missing_files_obj/model.obj"
|
||||
filename = os.path.join(DATA_DIR, obj_filename)
|
||||
verts, faces, aux = load_obj(filename, load_textures=False)
|
||||
@@ -688,7 +692,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
self.assertTrue(aux.texture_images is None)
|
||||
|
||||
def test_load_obj_missing_mtl(self):
|
||||
DATA_DIR = Path(__file__).resolve().parent / "data"
|
||||
DATA_DIR = get_tests_dir() / "data"
|
||||
obj_filename = "missing_files_obj/model2.obj"
|
||||
filename = os.path.join(DATA_DIR, obj_filename)
|
||||
with self.assertWarnsRegex(UserWarning, "Mtl file does not exist"):
|
||||
@@ -703,7 +707,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(faces.verts_idx, expected_faces))
|
||||
|
||||
def test_load_obj_missing_mtl_noload(self):
|
||||
DATA_DIR = Path(__file__).resolve().parent / "data"
|
||||
DATA_DIR = get_tests_dir() / "data"
|
||||
obj_filename = "missing_files_obj/model2.obj"
|
||||
filename = os.path.join(DATA_DIR, obj_filename)
|
||||
verts, faces, aux = load_obj(filename, load_textures=False)
|
||||
@@ -760,7 +764,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
mesh.textures.atlas_padded(), mesh3.textures.atlas_padded()
|
||||
)
|
||||
|
||||
DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
|
||||
DATA_DIR = get_pytorch3d_dir() / "docs/tutorials/data"
|
||||
obj_filename = DATA_DIR / "cow_mesh/cow.obj"
|
||||
|
||||
mesh = load_objs_as_meshes([obj_filename])
|
||||
|
||||
Reference in New Issue
Block a user