Extract finding directories for test data

Summary: Make common functions for finding directories where test data is found, instead of lots of tests using their own `__file__`  while trying to get ./tests/data and the tutorials data.

Reviewed By: nikhilaravi

Differential Revision: D27633701

fbshipit-source-id: 1467bb6018cea16eba3cab097d713116d51071e9
This commit is contained in:
Rong Rong (AI Infra) 2021-04-08 20:02:16 -07:00 committed by Facebook GitHub Bot
parent 24ee279005
commit 1216b5765a
13 changed files with 81 additions and 50 deletions

View File

@ -9,6 +9,20 @@ import torch
from PIL import Image from PIL import Image
def get_tests_dir() -> Path:
"""
Returns Path for the directory containing this file.
"""
return Path(__file__).resolve().parent
def get_pytorch3d_dir() -> Path:
"""
Returns Path for the root PyTorch3D directory.
"""
return get_tests_dir().parent
def load_rgb_image(filename: str, data_dir: Union[str, Path]): def load_rgb_image(filename: str, data_dir: Union[str, Path]):
filepath = data_dir / filename filepath = data_dir / filename
with Image.open(filepath) as raw_image: with Image.open(filepath) as raw_image:

View File

@ -3,7 +3,8 @@ import json
import os import os
import unittest import unittest
from collections import Counter from collections import Counter
from pathlib import Path
from common_testing import get_pytorch3d_dir, get_tests_dir
# This file groups together tests which look at the code without running it. # This file groups together tests which look at the code without running it.
@ -16,7 +17,7 @@ class TestBuild(unittest.TestCase):
def test_name_clash(self): def test_name_clash(self):
# For setup.py, all translation units need distinct names, so we # For setup.py, all translation units need distinct names, so we
# cannot have foo.cu and foo.cpp, even in different directories. # cannot have foo.cu and foo.cpp, even in different directories.
test_dir = Path(__file__).resolve().parent test_dir = get_tests_dir()
source_dir = test_dir.parent / "pytorch3d" source_dir = test_dir.parent / "pytorch3d"
stems = [] stems = []
@ -30,7 +31,7 @@ class TestBuild(unittest.TestCase):
@unittest.skipIf(in_conda_build, "In conda build") @unittest.skipIf(in_conda_build, "In conda build")
def test_copyright(self): def test_copyright(self):
test_dir = Path(__file__).resolve().parent test_dir = get_tests_dir()
root_dir = test_dir.parent root_dir = test_dir.parent
extensions = ("py", "cu", "cuh", "cpp", "h", "hpp", "sh") extensions = ("py", "cu", "cuh", "cpp", "h", "hpp", "sh")
@ -63,8 +64,8 @@ class TestBuild(unittest.TestCase):
@unittest.skipIf(in_conda_build, "In conda build") @unittest.skipIf(in_conda_build, "In conda build")
def test_valid_ipynbs(self): def test_valid_ipynbs(self):
# Check that the ipython notebooks are valid json # Check that the ipython notebooks are valid json
test_dir = Path(__file__).resolve().parent root_dir = get_pytorch3d_dir()
tutorials_dir = test_dir.parent / "docs" / "tutorials" tutorials_dir = root_dir / "docs" / "tutorials"
tutorials = sorted(tutorials_dir.glob("*.ipynb")) tutorials = sorted(tutorials_dir.glob("*.ipynb"))
for tutorial in tutorials: for tutorial in tutorials:

View File

@ -8,7 +8,11 @@ from pathlib import Path
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
import torch import torch
from common_testing import TestCaseMixin from common_testing import (
TestCaseMixin,
get_pytorch3d_dir,
get_tests_dir,
)
from iopath.common.file_io import PathManager from iopath.common.file_io import PathManager
from pytorch3d.io import IO, load_obj, load_objs_as_meshes, save_obj from pytorch3d.io import IO, load_obj, load_objs_as_meshes, save_obj
from pytorch3d.io.mtl_io import ( from pytorch3d.io.mtl_io import (
@ -475,7 +479,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
self.assertEqual(actual_file, expected_file) self.assertEqual(actual_file, expected_file)
def test_load_mtl(self): def test_load_mtl(self):
DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data" DATA_DIR = get_pytorch3d_dir() / "docs/tutorials/data"
obj_filename = "cow_mesh/cow.obj" obj_filename = "cow_mesh/cow.obj"
filename = os.path.join(DATA_DIR, obj_filename) filename = os.path.join(DATA_DIR, obj_filename)
verts, faces, aux = load_obj(filename) verts, faces, aux = load_obj(filename)
@ -559,7 +563,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
def test_load_mtl_texture_atlas_compare_softras(self): def test_load_mtl_texture_atlas_compare_softras(self):
# Load saved texture atlas created with SoftRas. # Load saved texture atlas created with SoftRas.
device = torch.device("cuda:0") device = torch.device("cuda:0")
DATA_DIR = Path(__file__).resolve().parent.parent DATA_DIR = get_pytorch3d_dir()
obj_filename = DATA_DIR / "docs/tutorials/data/cow_mesh/cow.obj" obj_filename = DATA_DIR / "docs/tutorials/data/cow_mesh/cow.obj"
expected_atlas_fname = DATA_DIR / "tests/data/cow_texture_atlas_softras.pt" expected_atlas_fname = DATA_DIR / "tests/data/cow_texture_atlas_softras.pt"
@ -590,7 +594,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
self.assertClose(expected_atlas, aux.texture_atlas, atol=5e-5) self.assertClose(expected_atlas, aux.texture_atlas, atol=5e-5)
def test_load_mtl_noload(self): def test_load_mtl_noload(self):
DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data" DATA_DIR = get_pytorch3d_dir() / "docs/tutorials/data"
obj_filename = "cow_mesh/cow.obj" obj_filename = "cow_mesh/cow.obj"
filename = os.path.join(DATA_DIR, obj_filename) filename = os.path.join(DATA_DIR, obj_filename)
verts, faces, aux = load_obj(filename, load_textures=False) verts, faces, aux = load_obj(filename, load_textures=False)
@ -628,7 +632,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
self.assertTrue(aux.verts_uvs is None) self.assertTrue(aux.verts_uvs is None)
def test_load_obj_mlt_no_image(self): def test_load_obj_mlt_no_image(self):
DATA_DIR = Path(__file__).resolve().parent / "data" DATA_DIR = get_tests_dir() / "data"
obj_filename = "obj_mtl_no_image/model.obj" obj_filename = "obj_mtl_no_image/model.obj"
filename = os.path.join(DATA_DIR, obj_filename) filename = os.path.join(DATA_DIR, obj_filename)
R = 8 R = 8
@ -657,7 +661,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
self.assertEqual(list(aux.material_colors.keys()), ["material_1"]) self.assertEqual(list(aux.material_colors.keys()), ["material_1"])
def test_load_obj_missing_texture(self): def test_load_obj_missing_texture(self):
DATA_DIR = Path(__file__).resolve().parent / "data" DATA_DIR = get_tests_dir() / "data"
obj_filename = "missing_files_obj/model.obj" obj_filename = "missing_files_obj/model.obj"
filename = os.path.join(DATA_DIR, obj_filename) filename = os.path.join(DATA_DIR, obj_filename)
with self.assertWarnsRegex(UserWarning, "Texture file does not exist"): with self.assertWarnsRegex(UserWarning, "Texture file does not exist"):
@ -672,7 +676,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
self.assertTrue(torch.allclose(faces.verts_idx, expected_faces)) self.assertTrue(torch.allclose(faces.verts_idx, expected_faces))
def test_load_obj_missing_texture_noload(self): def test_load_obj_missing_texture_noload(self):
DATA_DIR = Path(__file__).resolve().parent / "data" DATA_DIR = get_tests_dir() / "data"
obj_filename = "missing_files_obj/model.obj" obj_filename = "missing_files_obj/model.obj"
filename = os.path.join(DATA_DIR, obj_filename) filename = os.path.join(DATA_DIR, obj_filename)
verts, faces, aux = load_obj(filename, load_textures=False) verts, faces, aux = load_obj(filename, load_textures=False)
@ -688,7 +692,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
self.assertTrue(aux.texture_images is None) self.assertTrue(aux.texture_images is None)
def test_load_obj_missing_mtl(self): def test_load_obj_missing_mtl(self):
DATA_DIR = Path(__file__).resolve().parent / "data" DATA_DIR = get_tests_dir() / "data"
obj_filename = "missing_files_obj/model2.obj" obj_filename = "missing_files_obj/model2.obj"
filename = os.path.join(DATA_DIR, obj_filename) filename = os.path.join(DATA_DIR, obj_filename)
with self.assertWarnsRegex(UserWarning, "Mtl file does not exist"): with self.assertWarnsRegex(UserWarning, "Mtl file does not exist"):
@ -703,7 +707,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
self.assertTrue(torch.allclose(faces.verts_idx, expected_faces)) self.assertTrue(torch.allclose(faces.verts_idx, expected_faces))
def test_load_obj_missing_mtl_noload(self): def test_load_obj_missing_mtl_noload(self):
DATA_DIR = Path(__file__).resolve().parent / "data" DATA_DIR = get_tests_dir() / "data"
obj_filename = "missing_files_obj/model2.obj" obj_filename = "missing_files_obj/model2.obj"
filename = os.path.join(DATA_DIR, obj_filename) filename = os.path.join(DATA_DIR, obj_filename)
verts, faces, aux = load_obj(filename, load_textures=False) verts, faces, aux = load_obj(filename, load_textures=False)
@ -760,7 +764,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
mesh.textures.atlas_padded(), mesh3.textures.atlas_padded() mesh.textures.atlas_padded(), mesh3.textures.atlas_padded()
) )
DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data" DATA_DIR = get_pytorch3d_dir() / "docs/tutorials/data"
obj_filename = DATA_DIR / "cow_mesh/cow.obj" obj_filename = DATA_DIR / "cow_mesh/cow.obj"
mesh = load_objs_as_meshes([obj_filename]) mesh = load_objs_as_meshes([obj_filename])

View File

@ -2,10 +2,9 @@
import os import os
import pickle import pickle
import unittest import unittest
from pathlib import Path
import torch import torch
from common_testing import TestCaseMixin from common_testing import TestCaseMixin, get_tests_dir
from pytorch3d.ops.marching_cubes import marching_cubes_naive from pytorch3d.ops.marching_cubes import marching_cubes_naive
@ -641,7 +640,7 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
volume, isolevel=64, return_local_coords=False volume, isolevel=64, return_local_coords=False
) )
DATA_DIR = Path(__file__).resolve().parent / "data" DATA_DIR = get_tests_dir() / "data"
data_filename = "test_marching_cubes_data/sphere_level64.pickle" data_filename = "test_marching_cubes_data/sphere_level64.pickle"
filename = os.path.join(DATA_DIR, data_filename) filename = os.path.join(DATA_DIR, data_filename)
with open(filename, "rb") as file: with open(filename, "rb") as file:
@ -677,7 +676,7 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
volume = volume.permute(0, 3, 2, 1) # (B, D, H, W) volume = volume.permute(0, 3, 2, 1) # (B, D, H, W)
verts, faces = marching_cubes_naive(volume, isolevel=0.001) verts, faces = marching_cubes_naive(volume, isolevel=0.001)
DATA_DIR = Path(__file__).resolve().parent / "data" DATA_DIR = get_tests_dir() / "data"
data_filename = "test_marching_cubes_data/double_ellipsoid.pickle" data_filename = "test_marching_cubes_data/double_ellipsoid.pickle"
filename = os.path.join(DATA_DIR, data_filename) filename = os.path.join(DATA_DIR, data_filename)
with open(filename, "rb") as file: with open(filename, "rb") as file:

View File

@ -1,11 +1,10 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest import unittest
from pathlib import Path
import numpy as np import numpy as np
import torch import torch
from common_testing import TestCaseMixin from common_testing import TestCaseMixin, get_tests_dir
from pytorch3d.ops import points_alignment from pytorch3d.ops import points_alignment
from pytorch3d.structures.pointclouds import Pointclouds from pytorch3d.structures.pointclouds import Pointclouds
from pytorch3d.transforms import rotation_conversions from pytorch3d.transforms import rotation_conversions
@ -40,7 +39,7 @@ class TestICP(TestCaseMixin, unittest.TestCase):
super().setUp() super().setUp()
torch.manual_seed(42) torch.manual_seed(42)
np.random.seed(42) np.random.seed(42)
trimesh_results_path = Path(__file__).resolve().parent / "data/icp_data.pth" trimesh_results_path = get_tests_dir() / "data/icp_data.pth"
self.trimesh_results = torch.load(trimesh_results_path) self.trimesh_results = torch.load(trimesh_results_path)
@staticmethod @staticmethod

View File

@ -9,7 +9,7 @@ from pathlib import Path
import numpy as np import numpy as np
import torch import torch
from common_testing import TestCaseMixin, load_rgb_image from common_testing import TestCaseMixin, load_rgb_image, get_tests_dir
from PIL import Image from PIL import Image
from pytorch3d.datasets import ( from pytorch3d.datasets import (
R2N2, R2N2,
@ -37,7 +37,7 @@ VOXELS_REL_PATH = "ShapeNetVox"
DEBUG = False DEBUG = False
DATA_DIR = Path(__file__).resolve().parent / "data" DATA_DIR = get_tests_dir() / "data"
class TestR2N2(TestCaseMixin, unittest.TestCase): class TestR2N2(TestCaseMixin, unittest.TestCase):

View File

@ -2,11 +2,15 @@
import unittest import unittest
from itertools import product from itertools import product
from pathlib import Path
import numpy as np import numpy as np
import torch import torch
from common_testing import TestCaseMixin, load_rgb_image from common_testing import (
TestCaseMixin,
load_rgb_image,
get_pytorch3d_dir,
get_tests_dir,
)
from PIL import Image from PIL import Image
from pytorch3d.io import load_obj from pytorch3d.io import load_obj
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
@ -42,7 +46,7 @@ from pytorch3d.utils import torus
DEBUG = False DEBUG = False
DATA_DIR = Path(__file__).resolve().parent / "data" DATA_DIR = get_tests_dir() / "data"
# Verts/Faces for a simple mesh with two faces. # Verts/Faces for a simple mesh with two faces.
verts0 = torch.tensor( verts0 = torch.tensor(
@ -449,7 +453,7 @@ class TestRasterizeRectangleImagesMeshes(TestCaseMixin, unittest.TestCase):
Test a larger textured mesh is rendered correctly in a non square image. Test a larger textured mesh is rendered correctly in a non square image.
""" """
device = torch.device("cuda:0") device = torch.device("cuda:0")
obj_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data" obj_dir = get_pytorch3d_dir() / "docs/tutorials/data"
obj_filename = obj_dir / "cow_mesh/cow.obj" obj_filename = obj_dir / "cow_mesh/cow.obj"
# Load mesh + texture # Load mesh + texture

View File

@ -2,10 +2,10 @@
import unittest import unittest
from pathlib import Path
import numpy as np import numpy as np
import torch import torch
from common_testing import get_tests_dir
from PIL import Image from PIL import Image
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
@ -17,7 +17,7 @@ from pytorch3d.structures import Pointclouds
from pytorch3d.utils.ico_sphere import ico_sphere from pytorch3d.utils.ico_sphere import ico_sphere
DATA_DIR = Path(__file__).resolve().parent / "data" DATA_DIR = get_tests_dir() / "data"
DEBUG = False # Set DEBUG to true to save outputs from the tests. DEBUG = False # Set DEBUG to true to save outputs from the tests.

View File

@ -6,11 +6,15 @@ Sanity checks for output images from the renderer.
""" """
import os import os
import unittest import unittest
from pathlib import Path
import numpy as np import numpy as np
import torch import torch
from common_testing import TestCaseMixin, load_rgb_image from common_testing import (
TestCaseMixin,
load_rgb_image,
get_pytorch3d_dir,
get_tests_dir,
)
from PIL import Image from PIL import Image
from pytorch3d.io import load_obj from pytorch3d.io import load_obj
from pytorch3d.renderer.cameras import ( from pytorch3d.renderer.cameras import (
@ -46,7 +50,7 @@ from pytorch3d.utils.torus import torus
# If DEBUG=True, save out images generated in the tests for debugging. # If DEBUG=True, save out images generated in the tests for debugging.
# All saved images have prefix DEBUG_ # All saved images have prefix DEBUG_
DEBUG = False DEBUG = False
DATA_DIR = Path(__file__).resolve().parent / "data" DATA_DIR = get_tests_dir() / "data"
class TestRenderMeshes(TestCaseMixin, unittest.TestCase): class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
@ -384,7 +388,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
The pupils in the eyes of the cow should always be looking to the left. The pupils in the eyes of the cow should always be looking to the left.
""" """
device = torch.device("cuda:0") device = torch.device("cuda:0")
obj_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data" obj_dir = get_pytorch3d_dir() / "docs/tutorials/data"
obj_filename = obj_dir / "cow_mesh/cow.obj" obj_filename = obj_dir / "cow_mesh/cow.obj"
# Load mesh + texture # Load mesh + texture
@ -966,7 +970,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
Also check that the backward pass for texture atlas rendering is differentiable. Also check that the backward pass for texture atlas rendering is differentiable.
""" """
device = torch.device("cuda:0") device = torch.device("cuda:0")
obj_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data" obj_dir = get_pytorch3d_dir() / "docs/tutorials/data"
obj_filename = obj_dir / "cow_mesh/cow.obj" obj_filename = obj_dir / "cow_mesh/cow.obj"
# Load mesh and texture as a per face texture atlas. # Load mesh and texture as a per face texture atlas.

View File

@ -9,12 +9,11 @@ See pytorch3d/renderer/mesh/clip.py for more details about the
clipping process. clipping process.
""" """
import unittest import unittest
from pathlib import Path
import imageio import imageio
import numpy as np import numpy as np
import torch import torch
from common_testing import TestCaseMixin, load_rgb_image from common_testing import TestCaseMixin, load_rgb_image, get_tests_dir
from pytorch3d.io import save_obj from pytorch3d.io import save_obj
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.lighting import PointLights from pytorch3d.renderer.lighting import PointLights
@ -34,7 +33,7 @@ from pytorch3d.structures.meshes import Meshes
# If DEBUG=True, save out images generated in the tests for debugging. # If DEBUG=True, save out images generated in the tests for debugging.
# All saved images have prefix DEBUG_ # All saved images have prefix DEBUG_
DEBUG = False DEBUG = False
DATA_DIR = Path(__file__).resolve().parent / "data" DATA_DIR = get_tests_dir() / "data"
class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase): class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):

View File

@ -7,11 +7,15 @@ Sanity checks for output images from the pointcloud renderer.
import unittest import unittest
import warnings import warnings
from os import path from os import path
from pathlib import Path
import numpy as np import numpy as np
import torch import torch
from common_testing import TestCaseMixin, load_rgb_image from common_testing import (
TestCaseMixin,
load_rgb_image,
get_pytorch3d_dir,
get_tests_dir,
)
from PIL import Image from PIL import Image
from pytorch3d.renderer.cameras import ( from pytorch3d.renderer.cameras import (
FoVOrthographicCameras, FoVOrthographicCameras,
@ -36,7 +40,7 @@ from pytorch3d.utils.ico_sphere import ico_sphere
# If DEBUG=True, save out images generated in the tests for debugging. # If DEBUG=True, save out images generated in the tests for debugging.
# All saved images have prefix DEBUG_ # All saved images have prefix DEBUG_
DEBUG = False DEBUG = False
DATA_DIR = Path(__file__).resolve().parent / "data" DATA_DIR = get_tests_dir() / "data"
class TestRenderPoints(TestCaseMixin, unittest.TestCase): class TestRenderPoints(TestCaseMixin, unittest.TestCase):
@ -216,7 +220,7 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase):
def test_pointcloud_with_features(self): def test_pointcloud_with_features(self):
device = torch.device("cuda:0") device = torch.device("cuda:0")
file_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data" file_dir = get_pytorch3d_dir() / "docs/tutorials/data"
pointcloud_filename = file_dir / "PittsburghBridge/pointcloud.npz" pointcloud_filename = file_dir / "PittsburghBridge/pointcloud.npz"
# Note, this file is too large to check in to the repo. # Note, this file is too large to check in to the repo.

View File

@ -2,11 +2,15 @@
import unittest import unittest
from pathlib import Path
import numpy as np import numpy as np
import torch import torch
from common_testing import TestCaseMixin, get_random_cuda_device from common_testing import (
TestCaseMixin,
get_random_cuda_device,
get_pytorch3d_dir,
get_tests_dir,
)
from PIL import Image from PIL import Image
from pytorch3d.io import load_objs_as_meshes from pytorch3d.io import load_objs_as_meshes
from pytorch3d.ops import sample_points_from_meshes from pytorch3d.ops import sample_points_from_meshes
@ -26,7 +30,7 @@ from pytorch3d.utils.ico_sphere import ico_sphere
# If DEBUG=True, save out images generated in the tests for debugging. # If DEBUG=True, save out images generated in the tests for debugging.
# All saved images have prefix DEBUG_ # All saved images have prefix DEBUG_
DEBUG = False DEBUG = False
DATA_DIR = Path(__file__).resolve().parent / "data" DATA_DIR = get_tests_dir() / "data"
class TestSamplePoints(TestCaseMixin, unittest.TestCase): class TestSamplePoints(TestCaseMixin, unittest.TestCase):
@ -261,7 +265,7 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase):
Confirm that torch.multinomial does not sample elements which have Confirm that torch.multinomial does not sample elements which have
zero probability using a real example of input from a training run. zero probability using a real example of input from a training run.
""" """
weights = torch.load(Path(__file__).resolve().parent / "weights.pt") weights = torch.load(get_tests_dir() / "weights.pt")
S = 4096 S = 4096
num_trials = 100 num_trials = 100
for _ in range(0, num_trials): for _ in range(0, num_trials):
@ -378,7 +382,7 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase):
# the cow mesh and its texture uv to a pointcloud with texture # the cow mesh and its texture uv to a pointcloud with texture
device = torch.device("cuda:0") device = torch.device("cuda:0")
obj_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data" obj_dir = get_pytorch3d_dir() / "docs/tutorials/data"
obj_filename = obj_dir / "cow_mesh/cow.obj" obj_filename = obj_dir / "cow_mesh/cow.obj"
for text_type in ("uv", "atlas"): for text_type in ("uv", "atlas"):

View File

@ -4,11 +4,10 @@ Sanity checks for loading ShapeNetCore.
""" """
import os import os
import unittest import unittest
from pathlib import Path
import numpy as np import numpy as np
import torch import torch
from common_testing import TestCaseMixin, load_rgb_image from common_testing import TestCaseMixin, load_rgb_image, get_tests_dir
from PIL import Image from PIL import Image
from pytorch3d.datasets import ShapeNetCore, collate_batched_meshes from pytorch3d.datasets import ShapeNetCore, collate_batched_meshes
from pytorch3d.renderer import ( from pytorch3d.renderer import (
@ -26,7 +25,7 @@ VERSION = 1
# If DEBUG=True, save out images generated in the tests for debugging. # If DEBUG=True, save out images generated in the tests for debugging.
# All saved images have prefix DEBUG_ # All saved images have prefix DEBUG_
DEBUG = False DEBUG = False
DATA_DIR = Path(__file__).resolve().parent / "data" DATA_DIR = get_tests_dir() / "data"
class TestShapenetCore(TestCaseMixin, unittest.TestCase): class TestShapenetCore(TestCaseMixin, unittest.TestCase):