mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Test rendering models for R2N2
Summary: Adding a render function for R2N2. Reviewed By: nikhilaravi Differential Revision: D22230228 fbshipit-source-id: a9f588ddcba15bb5d8be1401f68d730e810b4251
This commit is contained in:
		
							parent
							
								
									49b4ce1acc
								
							
						
					
					
						commit
						5636eb6152
					
				@ -64,6 +64,7 @@ class R2N2(ShapeNetBase):
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            synset_set.add(synset)
 | 
			
		||||
            self.synset_starts[synset] = len(self.synset_ids)
 | 
			
		||||
            models = split_dict[synset].keys()
 | 
			
		||||
            for model in models:
 | 
			
		||||
                # Examine if the given model is present in the ShapeNetCore path.
 | 
			
		||||
@ -78,6 +79,7 @@ class R2N2(ShapeNetBase):
 | 
			
		||||
                    continue
 | 
			
		||||
                self.synset_ids.append(synset)
 | 
			
		||||
                self.model_ids.append(model)
 | 
			
		||||
            self.synset_lens[synset] = len(self.synset_ids) - self.synset_starts[synset]
 | 
			
		||||
 | 
			
		||||
        # Examine if all the synsets in the standard R2N2 mapping are present.
 | 
			
		||||
        # Update self.synset_inv so that it only includes the loaded categories.
 | 
			
		||||
 | 
			
		||||
@ -34,7 +34,7 @@ class ShapeNetBase(torch.utils.data.Dataset):
 | 
			
		||||
        self.synset_starts = {}
 | 
			
		||||
        self.synset_lens = {}
 | 
			
		||||
        self.shapenet_dir = ""
 | 
			
		||||
        self.model_dir = ""
 | 
			
		||||
        self.model_dir = "model.obj"
 | 
			
		||||
 | 
			
		||||
    def __len__(self):
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										
											BIN
										
									
								
								tests/data/test_r2n2_render_by_categories_0.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								tests/data/test_r2n2_render_by_categories_0.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 2.2 KiB  | 
							
								
								
									
										
											BIN
										
									
								
								tests/data/test_r2n2_render_by_categories_1.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								tests/data/test_r2n2_render_by_categories_1.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 3.7 KiB  | 
							
								
								
									
										
											BIN
										
									
								
								tests/data/test_r2n2_render_by_categories_2.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								tests/data/test_r2n2_render_by_categories_2.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 2.3 KiB  | 
							
								
								
									
										
											BIN
										
									
								
								tests/data/test_r2n2_render_by_idxs_and_ids_0.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								tests/data/test_r2n2_render_by_idxs_and_ids_0.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 2.1 KiB  | 
							
								
								
									
										
											BIN
										
									
								
								tests/data/test_r2n2_render_by_idxs_and_ids_1.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								tests/data/test_r2n2_render_by_idxs_and_ids_1.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 2.0 KiB  | 
							
								
								
									
										
											BIN
										
									
								
								tests/data/test_r2n2_render_by_idxs_and_ids_2.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								tests/data/test_r2n2_render_by_idxs_and_ids_2.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 2.0 KiB  | 
@ -5,10 +5,19 @@ Sanity checks for loading R2N2.
 | 
			
		||||
import json
 | 
			
		||||
import os
 | 
			
		||||
import unittest
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
from common_testing import TestCaseMixin
 | 
			
		||||
from common_testing import TestCaseMixin, load_rgb_image
 | 
			
		||||
from PIL import Image
 | 
			
		||||
from pytorch3d.datasets import R2N2, collate_batched_meshes
 | 
			
		||||
from pytorch3d.renderer import (
 | 
			
		||||
    OpenGLPerspectiveCameras,
 | 
			
		||||
    PointLights,
 | 
			
		||||
    RasterizationSettings,
 | 
			
		||||
    look_at_view_transform,
 | 
			
		||||
)
 | 
			
		||||
from torch.utils.data import DataLoader
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -17,6 +26,9 @@ R2N2_PATH = None
 | 
			
		||||
SHAPENET_PATH = None
 | 
			
		||||
SPLITS_PATH = None
 | 
			
		||||
 | 
			
		||||
DEBUG = False
 | 
			
		||||
DATA_DIR = Path(__file__).resolve().parent / "data"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestR2N2(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
@ -44,16 +56,14 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
    def test_load_R2N2(self):
 | 
			
		||||
        """
 | 
			
		||||
        Test loading the train split of R2N2. Check the loaded dataset return items
 | 
			
		||||
        of the correct shapes and types.
 | 
			
		||||
        Test the loaded train split of R2N2 return items of the correct shapes and types.
 | 
			
		||||
        """
 | 
			
		||||
        # Load dataset in the train split.
 | 
			
		||||
        split = "train"
 | 
			
		||||
        r2n2_dataset = R2N2(split, SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
 | 
			
		||||
        r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
 | 
			
		||||
 | 
			
		||||
        # Check total number of objects in the dataset is correct.
 | 
			
		||||
        with open(SPLITS_PATH) as splits:
 | 
			
		||||
            split_dict = json.load(splits)[split]
 | 
			
		||||
            split_dict = json.load(splits)["train"]
 | 
			
		||||
        model_nums = [len(split_dict[synset].keys()) for synset in split_dict.keys()]
 | 
			
		||||
        self.assertEqual(len(r2n2_dataset), sum(model_nums))
 | 
			
		||||
 | 
			
		||||
@ -75,8 +85,7 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        the correct shapes and types are returned.
 | 
			
		||||
        """
 | 
			
		||||
        # Load dataset in the train split.
 | 
			
		||||
        split = "train"
 | 
			
		||||
        r2n2_dataset = R2N2(split, SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
 | 
			
		||||
        r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
 | 
			
		||||
 | 
			
		||||
        # Randomly retrieve several objects from the dataset and collate them.
 | 
			
		||||
        collated_meshes = collate_batched_meshes(
 | 
			
		||||
@ -109,3 +118,117 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        self.assertEqual(len(object_batch["label"]), batch_size)
 | 
			
		||||
        self.assertEqual(object_batch["mesh"].verts_padded().shape[0], batch_size)
 | 
			
		||||
        self.assertEqual(object_batch["mesh"].faces_padded().shape[0], batch_size)
 | 
			
		||||
 | 
			
		||||
    def test_catch_render_arg_errors(self):
 | 
			
		||||
        """
 | 
			
		||||
        Test rendering R2N2 with an invalid model_id, category or index, and
 | 
			
		||||
        catch corresponding errors.
 | 
			
		||||
        """
 | 
			
		||||
        # Load dataset in the train split.
 | 
			
		||||
        r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
 | 
			
		||||
 | 
			
		||||
        # Try loading with an invalid model_id and catch error.
 | 
			
		||||
        with self.assertRaises(ValueError) as err:
 | 
			
		||||
            r2n2_dataset.render(model_ids=["lamp0"])
 | 
			
		||||
        self.assertTrue("not found in the loaded dataset" in str(err.exception))
 | 
			
		||||
 | 
			
		||||
        # Try loading with an index out of bounds and catch error.
 | 
			
		||||
        with self.assertRaises(IndexError) as err:
 | 
			
		||||
            r2n2_dataset.render(idxs=[1000000])
 | 
			
		||||
        self.assertTrue("are out of bounds" in str(err.exception))
 | 
			
		||||
 | 
			
		||||
    def test_render_r2n2(self):
 | 
			
		||||
        """
 | 
			
		||||
        Test rendering objects from R2N2 selected both by indices and model_ids.
 | 
			
		||||
        """
 | 
			
		||||
        # Set up device and seed for random selections.
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
        torch.manual_seed(39)
 | 
			
		||||
 | 
			
		||||
        # Load dataset in the train split.
 | 
			
		||||
        r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
 | 
			
		||||
 | 
			
		||||
        # Render first three models in the dataset.
 | 
			
		||||
        R, T = look_at_view_transform(1.0, 1.0, 90)
 | 
			
		||||
        cameras = OpenGLPerspectiveCameras(R=R, T=T, device=device)
 | 
			
		||||
        raster_settings = RasterizationSettings(image_size=512)
 | 
			
		||||
        lights = PointLights(
 | 
			
		||||
            location=torch.tensor([0.0, 1.0, -2.0], device=device)[None],
 | 
			
		||||
            # TODO: debug the source of the discrepancy in two images when rendering on GPU.
 | 
			
		||||
            diffuse_color=((0, 0, 0),),
 | 
			
		||||
            specular_color=((0, 0, 0),),
 | 
			
		||||
            device=device,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        r2n2_by_idxs = r2n2_dataset.render(
 | 
			
		||||
            idxs=list(range(3)),
 | 
			
		||||
            device=device,
 | 
			
		||||
            cameras=cameras,
 | 
			
		||||
            raster_settings=raster_settings,
 | 
			
		||||
            lights=lights,
 | 
			
		||||
        )
 | 
			
		||||
        # Check that there are three images in the batch.
 | 
			
		||||
        self.assertEqual(r2n2_by_idxs.shape[0], 3)
 | 
			
		||||
 | 
			
		||||
        # Compare the rendered models to the reference images.
 | 
			
		||||
        for idx in range(3):
 | 
			
		||||
            r2n2_by_idxs_rgb = r2n2_by_idxs[idx, ..., :3].squeeze().cpu()
 | 
			
		||||
            if DEBUG:
 | 
			
		||||
                Image.fromarray((r2n2_by_idxs_rgb.numpy() * 255).astype(np.uint8)).save(
 | 
			
		||||
                    DATA_DIR / ("DEBUG_r2n2_render_by_idxs_%s.png" % idx)
 | 
			
		||||
                )
 | 
			
		||||
            image_ref = load_rgb_image(
 | 
			
		||||
                "test_r2n2_render_by_idxs_and_ids_%s.png" % idx, DATA_DIR
 | 
			
		||||
            )
 | 
			
		||||
            self.assertClose(r2n2_by_idxs_rgb, image_ref, atol=0.05)
 | 
			
		||||
 | 
			
		||||
        # Render the same models but by model_ids this time.
 | 
			
		||||
        r2n2_by_model_ids = r2n2_dataset.render(
 | 
			
		||||
            model_ids=[
 | 
			
		||||
                "1a4a8592046253ab5ff61a3a2a0e2484",
 | 
			
		||||
                "1a04dcce7027357ab540cc4083acfa57",
 | 
			
		||||
                "1a9d0480b74d782698f5bccb3529a48d",
 | 
			
		||||
            ],
 | 
			
		||||
            device=device,
 | 
			
		||||
            cameras=cameras,
 | 
			
		||||
            raster_settings=raster_settings,
 | 
			
		||||
            lights=lights,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Compare the rendered models to the reference images.
 | 
			
		||||
        for idx in range(3):
 | 
			
		||||
            r2n2_by_model_ids_rgb = r2n2_by_model_ids[idx, ..., :3].squeeze().cpu()
 | 
			
		||||
            if DEBUG:
 | 
			
		||||
                Image.fromarray(
 | 
			
		||||
                    (r2n2_by_model_ids_rgb.numpy() * 255).astype(np.uint8)
 | 
			
		||||
                ).save(DATA_DIR / ("DEBUG_r2n2_render_by_model_ids_%s.png" % idx))
 | 
			
		||||
            image_ref = load_rgb_image(
 | 
			
		||||
                "test_r2n2_render_by_idxs_and_ids_%s.png" % idx, DATA_DIR
 | 
			
		||||
            )
 | 
			
		||||
            self.assertClose(r2n2_by_model_ids_rgb, image_ref, atol=0.05)
 | 
			
		||||
 | 
			
		||||
        ###############################
 | 
			
		||||
        # Test rendering by categories
 | 
			
		||||
        ###############################
 | 
			
		||||
 | 
			
		||||
        # Render a mixture of categories.
 | 
			
		||||
        categories = ["chair", "lamp"]
 | 
			
		||||
        mixed_objs = r2n2_dataset.render(
 | 
			
		||||
            categories=categories,
 | 
			
		||||
            sample_nums=[1, 2],
 | 
			
		||||
            device=device,
 | 
			
		||||
            cameras=cameras,
 | 
			
		||||
            raster_settings=raster_settings,
 | 
			
		||||
            lights=lights,
 | 
			
		||||
        )
 | 
			
		||||
        # Compare the rendered models to the reference images.
 | 
			
		||||
        for idx in range(3):
 | 
			
		||||
            mixed_rgb = mixed_objs[idx, ..., :3].squeeze().cpu()
 | 
			
		||||
            if DEBUG:
 | 
			
		||||
                Image.fromarray((mixed_rgb.numpy() * 255).astype(np.uint8)).save(
 | 
			
		||||
                    DATA_DIR / ("DEBUG_r2n2_render_by_categories_%s.png" % idx)
 | 
			
		||||
                )
 | 
			
		||||
            image_ref = load_rgb_image(
 | 
			
		||||
                "test_r2n2_render_by_categories_%s.png" % idx, DATA_DIR
 | 
			
		||||
            )
 | 
			
		||||
            self.assertClose(mixed_rgb, image_ref, atol=0.05)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user