Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-11-04 18:02:14 +08:00)
Adding renderer for ShapeNetBase

Summary: Adding a renderer to ShapeNetCore. (Note that the lights are currently turned off for the test; will investigate why lighting causes instability in rendering.)

Reviewed By: nikhilaravi

Differential Revision: D22102673

fbshipit-source-id: a704756a1e93b61d5a879f0e5ee14ebcb0df49d7
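For orientation, a minimal usage sketch of the renderer this commit adds, modeled on the test added below; SHAPENET_PATH here is a placeholder for a local ShapeNetCore download and is not part of the commit:

import torch
from pytorch3d.datasets import ShapeNetCore
from pytorch3d.renderer import (
    OpenGLPerspectiveCameras,
    PointLights,
    RasterizationSettings,
    look_at_view_transform,
)

SHAPENET_PATH = "/path/to/ShapeNetCore.v1"  # placeholder: point at a local download
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

dataset = ShapeNetCore(SHAPENET_PATH, synsets=["piano"])
R, T = look_at_view_transform(1.0, 1.0, 90)
images = dataset.render(
    0,
    device=device,
    cameras=OpenGLPerspectiveCameras(R=R, T=T, device=device),
    raster_settings=RasterizationSettings(image_size=512),
    # The test below additionally zeroes the diffuse and specular terms of the lights.
    lights=PointLights(location=((0.0, 1.0, -2.0),), device=device),
)
rgb = images[0, ..., :3].cpu()  # rendered RGB image of the first piano model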
This commit is contained in:
parent 09c1762939
commit 358e211cde
@@ -1,4 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from .shapenet_core import ShapeNetCore
@@ -5,15 +5,16 @@ import os
import warnings
from os import path
from pathlib import Path
from typing import Dict

import torch
from pytorch3d.datasets.shapenet_base import ShapeNetBase
from pytorch3d.io import load_obj


SYNSET_DICT_DIR = Path(__file__).resolve().parent


class ShapeNetCore(torch.utils.data.Dataset):
class ShapeNetCore(ShapeNetBase):
    """
    This class loads ShapeNetCore from a given directory into a Dataset object.
    ShapeNetCore is a subset of the ShapeNet dataset and can be downloaded from
@@ -23,6 +24,7 @@ class ShapeNetCore(torch.utils.data.Dataset):
    def __init__(self, data_dir, synsets=None, version: int = 1):
        """
        Store each object's synset id and models id from data_dir.

        Args:
            data_dir: Path to ShapeNetCore data.
            synsets: List of synset categories to load from ShapeNetCore in the form of
@@ -38,6 +40,7 @@ class ShapeNetCore(torch.utils.data.Dataset):
                version 1.

        """
        super().__init__()
        self.data_dir = data_dir
        if version not in [1, 2]:
            raise ValueError("Version number must be either 1 or 2.")
@@ -48,7 +51,7 @@ class ShapeNetCore(torch.utils.data.Dataset):
        with open(path.join(SYNSET_DICT_DIR, dict_file), "r") as read_dict:
            self.synset_dict = json.load(read_dict)
        # Inverse dicitonary mapping synset labels to corresponding offsets.
        synset_inv = {label: offset for offset, label in self.synset_dict.items()}
        self.synset_inv = {label: offset for offset, label in self.synset_dict.items()}

        # If categories are specified, check if each category is in the form of either
        # synset offset or synset label, and if the category exists in the given directory.
@@ -60,62 +63,61 @@ class ShapeNetCore(torch.utils.data.Dataset):
                    path.isdir(path.join(data_dir, synset))
                ):
                    synset_set.add(synset)
                elif (synset in synset_inv.keys()) and (
                    (path.isdir(path.join(data_dir, synset_inv[synset])))
                elif (synset in self.synset_inv.keys()) and (
                    (path.isdir(path.join(data_dir, self.synset_inv[synset])))
                ):
                    synset_set.add(synset_inv[synset])
                    synset_set.add(self.synset_inv[synset])
                else:
                    msg = """Synset category %s either not part of ShapeNetCore dataset
                         or cannot be found in %s.""" % (
                        synset,
                        data_dir,
                    )
                    msg = (
                        "Synset category %s either not part of ShapeNetCore dataset "
                        "or cannot be found in %s."
                    ) % (synset, data_dir)
                    warnings.warn(msg)
        # If no category is given, load every category in the given directory.
        # Ignore synset folders not included in the official mapping.
        else:
            synset_set = {
                synset
                for synset in os.listdir(data_dir)
                if path.isdir(path.join(data_dir, synset))
                and synset in self.synset_dict
            }
            for synset in synset_set:
                if synset not in self.synset_dict.keys():
                    msg = """Synset category %s(%s) is part of ShapeNetCore ver.%s
                        but not found in %s.""" % (
                        synset,
                        self.synset_dict[synset],
                        version,
                        data_dir,
                    )
                    warnings.warn(msg)

        # Check if there are any categories in the official mapping that are not loaded.
        # Update self.synset_inv so that it only includes the loaded categories.
        synset_not_present = set(self.synset_dict.keys()).difference(synset_set)
        [self.synset_inv.pop(self.synset_dict[synset]) for synset in synset_not_present]

        if len(synset_not_present) > 0:
            msg = (
                "The following categories are included in ShapeNetCore ver.%d's "
                "official mapping but not found in the dataset location %s: %s"
                ""
            ) % (version, data_dir, ", ".join(synset_not_present))
            warnings.warn(msg)

        # Extract model_id of each object from directory names.
        # Each grandchildren directory of data_dir contains an object, and the name
        # of the directory is the object's model_id.
        self.synset_ids = []
        self.model_ids = []
        for synset in synset_set:
            for model in os.listdir(path.join(data_dir, synset)):
                if not path.exists(path.join(data_dir, synset, model, self.model_dir)):
                    msg = """ Object file not found in the model directory %s
                        under synset directory %s.""" % (
                        model,
                        synset,
                    )
                    msg = (
                        "Object file not found in the model directory %s "
                        "under synset directory %s."
                    ) % (model, synset)
                    warnings.warn(msg)
                else:
                    self.synset_ids.append(synset)
                    self.model_ids.append(model)
                    continue
                self.synset_ids.append(synset)
                self.model_ids.append(model)

    def __len__(self):
        """
        Return number of total models in shapenet core.
        """
        return len(self.model_ids)

    def __getitem__(self, idx):
    def __getitem__(self, idx: int) -> Dict:
        """
        Read a model by the given index.

        Args:
            idx: The idx of the model to be retrieved in the dataset.

        Returns:
            dictionary with following keys:
            - verts: FloatTensor of shape (V, 3).
@@ -124,9 +126,7 @@ class ShapeNetCore(torch.utils.data.Dataset):
            - model_id (str): model id
            - label (str): synset label.
        """
        model = {}
        model["synset_id"] = self.synset_ids[idx]
        model["model_id"] = self.model_ids[idx]
        model = self._get_item_ids(idx)
        model_path = path.join(
            self.data_dir, model["synset_id"], model["model_id"], self.model_dir
        )
pytorch3d/datasets/shapenet_base.py (new file, 107 lines)
@@ -0,0 +1,107 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from typing import Dict

import torch
from pytorch3d.renderer import (
    HardPhongShader,
    MeshRasterizer,
    MeshRenderer,
    OpenGLPerspectiveCameras,
    PointLights,
    RasterizationSettings,
)
from pytorch3d.structures import Meshes, Textures


class ShapeNetBase(torch.utils.data.Dataset):
    """
    'ShapeNetBase' implements a base Dataset for ShapeNet and R2N2 with helper methods.
    It is not intended to be used on its own as a Dataset for a Dataloader. Both __init__
    and __getitem__ need to be implemented.
    """

    def __init__(self):
        """
        Set up lists of synset_ids and model_ids.
        """
        self.synset_ids = []
        self.model_ids = []

    def __len__(self):
        """
        Return number of total models in the loaded dataset.
        """
        return len(self.model_ids)

    def __getitem__(self, idx) -> Dict:
        """
        Read a model by the given index. Need to be implemented for every child class
        of ShapeNetBase.

        Args:
            idx: The idx of the model to be retrieved in the dataset.

        Returns:
            dictionary containing information about the model.
        """
        raise NotImplementedError(
            "__getitem__ should be implemented in the child class of ShapeNetBase"
        )

    def _get_item_ids(self, idx) -> Dict:
        """
        Read a model by the given index.

        Args:
            idx: The idx of the model to be retrieved in the dataset.

        Returns:
            dictionary with following keys:
            - synset_id (str): synset id
            - model_id (str): model id
        """
        model = {}
        model["synset_id"] = self.synset_ids[idx]
        model["model_id"] = self.model_ids[idx]
        return model

    def render(
        self, idx: int = 0, shader_type=HardPhongShader, device="cpu", **kwargs
    ) -> torch.Tensor:
        """
        Renders a model by the given index.

        Args:
            idx: The index of model to be rendered in the dataset.
            shader_type: select shading. Valid options include HardPhongShader (default),
                SoftPhongShader, HardGouraudShader, SoftGouraudShader, HardFlatShader,
                SoftSilhouetteShader.
            device: torch.device on which the tensors should be located.
            **kwargs: Accepts any of the kwargs that the renderer supports.

        Returns:
            Rendered image of shape (1, H, W, 3).
        """

        model = self.__getitem__(idx)
        verts, faces = model["verts"], model["faces"]
        verts_rgb = torch.ones_like(verts, device=device)[None]
        mesh = Meshes(
            verts=[verts.to(device)],
            faces=[faces.to(device)],
            textures=Textures(verts_rgb=verts_rgb.to(device)),
        )
        cameras = kwargs.get("cameras", OpenGLPerspectiveCameras()).to(device)
        renderer = MeshRenderer(
            rasterizer=MeshRasterizer(
                cameras=cameras,
                raster_settings=kwargs.get("raster_settings", RasterizationSettings()),
            ),
            shader=shader_type(
                device=device,
                cameras=cameras,
                lights=kwargs.get("lights", PointLights()).to(device),
            ),
        )
        return renderer(mesh)
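The class docstring above notes that ShapeNetBase is not meant to be used on its own: subclasses supply __init__ and __getitem__, and _get_item_ids and render then come for free. A minimal sketch of such a subclass (the TinyMeshDataset name and its hard-coded triangle are illustrative assumptions, not part of this commit):

import torch
from pytorch3d.datasets.shapenet_base import ShapeNetBase


class TinyMeshDataset(ShapeNetBase):
    # Hypothetical subclass: registers a single fake model and returns one triangle.
    def __init__(self):
        super().__init__()
        self.synset_ids = ["00000000"]
        self.model_ids = ["triangle"]

    def __getitem__(self, idx):
        model = self._get_item_ids(idx)  # fills synset_id / model_id from the base class
        model["verts"] = torch.tensor(
            [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
        )
        model["faces"] = torch.tensor([[0, 1, 2]], dtype=torch.int64)
        return model


# The inherited render() then works on top of __getitem__, e.g.
# image = TinyMeshDataset().render(0)  # default camera, lights and raster settings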
tests/data/test_shapenet_core_render_piano.png (new binary file, 3.2 KiB; binary file not shown)
@@ -6,17 +6,32 @@ import os
import random
import unittest
import warnings
from pathlib import Path

import numpy as np
import torch
from common_testing import TestCaseMixin
from common_testing import TestCaseMixin, load_rgb_image
from PIL import Image
from pytorch3d.datasets import ShapeNetCore
from pytorch3d.renderer import (
    OpenGLPerspectiveCameras,
    PointLights,
    RasterizationSettings,
    look_at_view_transform,
)


SHAPENET_PATH = None
# If DEBUG=True, save out images generated in the tests for debugging.
# All saved images have prefix DEBUG_
DEBUG = False
DATA_DIR = Path(__file__).resolve().parent / "data"


class TestShapenetCore(TestCaseMixin, unittest.TestCase):
    def test_load_shapenet_core(self):
        # Setup
        device = torch.device("cuda:0")

        # The ShapeNet dataset is not provided in the repo.
        # Download this separately and update the `shapenet_path`
@@ -31,7 +46,7 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
            warnings.warn(msg)
            return True

        # Try load ShapeNetCore with an invalid version number and catch error.
        # Try loading ShapeNetCore with an invalid version number and catch error.
        with self.assertRaises(ValueError) as err:
            ShapeNetCore(SHAPENET_PATH, version=3)
        self.assertTrue("Version number must be either 1 or 2." in str(err.exception))
@@ -93,3 +108,31 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
            for offset in subset_offsets
        ]
        self.assertEqual(len(shapenet_subset), sum(subset_model_nums))

        # Render the first image in the piano category.
        R, T = look_at_view_transform(1.0, 1.0, 90)
        piano_dataset = ShapeNetCore(SHAPENET_PATH, synsets=["piano"])

        cameras = OpenGLPerspectiveCameras(R=R, T=T, device=device)
        raster_settings = RasterizationSettings(image_size=512)
        lights = PointLights(
            location=torch.tensor([0.0, 1.0, -2.0], device=device)[None],
            # TODO: debug the source of the discrepancy in two images when rendering on GPU.
            diffuse_color=((0, 0, 0),),
            specular_color=((0, 0, 0),),
            device=device,
        )
        images = piano_dataset.render(
            0,
            device=device,
            cameras=cameras,
            raster_settings=raster_settings,
            lights=lights,
        )
        rgb = images[0, ..., :3].squeeze().cpu()
        if DEBUG:
            Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                DATA_DIR / "DEBUG_shapenet_core_render_piano.png"
            )
        image_ref = load_rgb_image("test_shapenet_core_render_piano.png", DATA_DIR)
        self.assertClose(rgb, image_ref, atol=0.05)
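For reference, "turning the lights off" in the test above amounts to zeroing the diffuse and specular components so that only the ambient term contributes; a minimal sketch of that configuration (the lights_off name is illustrative):

from pytorch3d.renderer import PointLights

# Only ambient shading remains; diffuse and specular are zeroed out, matching the
# workaround in the test while the GPU lighting discrepancy is investigated.
lights_off = PointLights(
    ambient_color=((1.0, 1.0, 1.0),),
    diffuse_color=((0.0, 0.0, 0.0),),
    specular_color=((0.0, 0.0, 0.0),),
)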