Render objects in a batch by the specified model_ids, categories or idxs for ShapeNetBase

Summary: Additional functionality for renderer in ShapeNetCore: users can select which objects to render by specifying their model_ids, or users could choose to render several random objects in some categories, or users could specify indices of the objects in the loaded dataset. (currently doesn't support changing lighting, still investigating why lighting is causing instability in renderings)

Reviewed By: nikhilaravi

Differential Revision: D22179594

fbshipit-source-id: 74c49094ffa3ea2eb71de9451f9e5da5053d356d
This commit is contained in:
Luya Gao
2020-07-14 14:52:21 -07:00
committed by Facebook GitHub Bot
parent 358e211cde
commit 22f2963cf1
12 changed files with 308 additions and 47 deletions

View File

@@ -1,4 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .shapenet import ShapeNetCore

View File

@@ -41,7 +41,7 @@ class ShapeNetCore(ShapeNetBase):
"""
super().__init__()
self.data_dir = data_dir
self.shapenet_dir = data_dir
if version not in [1, 2]:
raise ValueError("Version number must be either 1 or 2.")
self.model_dir = "model.obj" if version == 1 else "models/model_normalized.obj"
@@ -100,6 +100,7 @@ class ShapeNetCore(ShapeNetBase):
# Each grandchildren directory of data_dir contains an object, and the name
# of the directory is the object's model_id.
for synset in synset_set:
self.synset_starts[synset] = len(self.synset_ids)
for model in os.listdir(path.join(data_dir, synset)):
if not path.exists(path.join(data_dir, synset, model, self.model_dir)):
msg = (
@@ -110,6 +111,7 @@ class ShapeNetCore(ShapeNetBase):
continue
self.synset_ids.append(synset)
self.model_ids.append(model)
self.synset_lens[synset] = len(self.synset_ids) - self.synset_starts[synset]
def __getitem__(self, idx: int) -> Dict:
"""
@@ -128,7 +130,7 @@ class ShapeNetCore(ShapeNetBase):
"""
model = self._get_item_ids(idx)
model_path = path.join(
self.data_dir, model["synset_id"], model["model_id"], self.model_dir
self.shapenet_dir, model["synset_id"], model["model_id"], self.model_dir
)
model["verts"], faces, _ = load_obj(model_path)
model["faces"] = faces.verts_idx

View File

@@ -1,8 +1,11 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Dict
import warnings
from os import path
from typing import Dict, List, Optional
import torch
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.renderer import (
HardPhongShader,
MeshRasterizer,
@@ -11,7 +14,7 @@ from pytorch3d.renderer import (
PointLights,
RasterizationSettings,
)
from pytorch3d.structures import Meshes, Textures
from pytorch3d.structures import Textures
class ShapeNetBase(torch.utils.data.Dataset):
@@ -27,6 +30,11 @@ class ShapeNetBase(torch.utils.data.Dataset):
"""
self.synset_ids = []
self.model_ids = []
self.synset_inv = {}
self.synset_starts = {}
self.synset_lens = {}
self.shapenet_dir = ""
self.model_dir = ""
def __len__(self):
"""
@@ -67,30 +75,46 @@ class ShapeNetBase(torch.utils.data.Dataset):
return model
def render(
self, idx: int = 0, shader_type=HardPhongShader, device="cpu", **kwargs
self,
model_ids: Optional[List[str]] = None,
categories: Optional[List[str]] = None,
sample_nums: Optional[List[int]] = None,
idxs: Optional[List[int]] = None,
shader_type=HardPhongShader,
device="cpu",
**kwargs
) -> torch.Tensor:
"""
Renders a model by the given index.
If a list of model_ids are supplied, render all the objects by the given model_ids.
If no model_ids are supplied, but categories and sample_nums are specified, randomly
select a number of objects (number specified in sample_nums) in the given categories
and render these objects. If instead a list of idxs is specified, check if the idxs
are all valid and render models by the given idxs. Otherwise, randomly select a number
(first number in sample_nums, default is set to be 1) of models from the loaded dataset
and render these models.
Args:
idx: The index of model to be rendered in the dataset.
shader_type: select shading. Valid options include HardPhongShader (default),
model_ids: List[str] of model_ids of models intended to be rendered.
categories: List[str] of categories intended to be rendered. categories
and sample_nums must be specified at the same time. categories can be given
in the form of synset offsets or labels, or a combination of both.
sample_nums: List[int] of number of models to be randomly sampled from
each category. Could also contain one single integer, in which case it
will be broadcasted for every category.
idxs: List[int] of indices of models to be rendered in the dataset.
shader_type: Select shading. Valid options include HardPhongShader (default),
SoftPhongShader, HardGouraudShader, SoftGouraudShader, HardFlatShader,
SoftSilhouetteShader.
device: torch.device on which the tensors should be located.
**kwargs: Accepts any of the kwargs that the renderer supports.
Returns:
Rendered image of shape (1, H, W, 3).
Batch of rendered images of shape (N, H, W, 3).
"""
model = self.__getitem__(idx)
verts, faces = model["verts"], model["faces"]
verts_rgb = torch.ones_like(verts, device=device)[None]
mesh = Meshes(
verts=[verts.to(device)],
faces=[faces.to(device)],
textures=Textures(verts_rgb=verts_rgb.to(device)),
paths = self._handle_render_inputs(model_ids, categories, sample_nums, idxs)
meshes = load_objs_as_meshes(paths, device=device, load_textures=False)
meshes.textures = Textures(
verts_rgb=torch.ones_like(meshes.verts_padded(), device=device)
)
cameras = kwargs.get("cameras", OpenGLPerspectiveCameras()).to(device)
renderer = MeshRenderer(
@@ -104,4 +128,125 @@ class ShapeNetBase(torch.utils.data.Dataset):
lights=kwargs.get("lights", PointLights()).to(device),
),
)
return renderer(mesh)
return renderer(meshes)
def _handle_render_inputs(
self,
model_ids: Optional[List[str]] = None,
categories: Optional[List[str]] = None,
sample_nums: Optional[List[int]] = None,
idxs: Optional[List[int]] = None,
) -> List[str]:
"""
Helper function for converting user provided model_ids, categories and sample_nums
to indices of models in the loaded dataset. If model idxs are provided, we check if
the idxs are valid. If no models are specified, the first model in the loaded dataset
is chosen. The function returns the file paths to the selected models.
Args:
model_ids: List[str] of model_ids of models to be rendered.
categories: List[str] of categories to be rendered.
sample_nums: List[int] of number of models to be randomly sampled from
each category.
idxs: List[int] of indices of models to be rendered in the dataset.
Returns:
List of paths of models to be rendered.
"""
# Get corresponding indices if model_ids are supplied.
if model_ids is not None and len(model_ids) > 0:
idxs = []
for model_id in model_ids:
if model_id not in self.model_ids:
raise ValueError(
"model_id %s not found in the loaded dataset." % model_id
)
idxs.append(self.model_ids.index(model_id))
# Sample random models if categories and sample_nums are supplied and get
# the corresponding indices.
elif categories is not None and len(categories) > 0:
sample_nums = [1] if sample_nums is None else sample_nums
if len(categories) != len(sample_nums) and len(sample_nums) != 1:
raise ValueError(
"categories and sample_nums needs to be of the same length or "
"sample_nums needs to be of length 1."
)
idxs_tensor = torch.empty(0, dtype=torch.int32)
for i in range(len(categories)):
category = self.synset_inv.get(categories[i], categories[i])
if category not in self.synset_inv.values():
raise ValueError(
"Category %s is not in the loaded dataset." % category
)
# Broadcast if sample_nums has length of 1.
sample_num = sample_nums[i] if len(sample_nums) > 1 else sample_nums[0]
sampled_idxs = self._sample_idxs_from_category(
sample_num=sample_num, category=category
)
idxs_tensor = torch.cat((idxs_tensor, sampled_idxs))
idxs = idxs_tensor.tolist()
# Check if the indices are valid if idxs are supplied.
elif idxs is not None and len(idxs) > 0:
if any(idx < 0 or idx >= len(self.model_ids) for idx in idxs):
raise IndexError(
"One or more idx values are out of bounds. Indices need to be"
"between 0 and %s." % (len(self.model_ids) - 1)
)
# Check if sample_nums is specified, if so sample sample_nums[0] number
# of indices from the entire loaded dataset. Otherwise randomly select one
# index from the dataset.
else:
sample_nums = [1] if sample_nums is None else sample_nums
if len(sample_nums) > 1:
msg = (
"More than one sample sizes specified, now sampling "
"%d models from the dataset." % sample_nums[0]
)
warnings.warn(msg)
idxs = self._sample_idxs_from_category(sample_nums[0])
return [
path.join(
self.shapenet_dir,
self.synset_ids[idx],
self.model_ids[idx],
self.model_dir,
)
for idx in idxs
]
def _sample_idxs_from_category(
self, sample_num: int = 1, category: Optional[str] = None
) -> List[int]:
"""
Helper function for sampling a number of indices from the given category.
Args:
sample_num: number of indicies to be sampled from the given category.
category: category synset of the category to be sampled from. If not
specified, sample from all models in the loaded dataset.
"""
start = self.synset_starts[category] if category is not None else 0
range_len = (
self.synset_lens[category] if category is not None else self.__len__()
)
replacement = sample_num > range_len
sampled_idxs = (
torch.multinomial(
torch.ones((range_len), dtype=torch.float32),
sample_num,
replacement=replacement,
)
+ start
)
if replacement:
msg = (
"Sample size %d is larger than the number of objects in %s, "
"values sampled with replacement."
) % (
sample_num,
"category " + category if category is not None else "all categories",
)
warnings.warn(msg)
return sampled_idxs