mirror of https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 22:30:35 +08:00
Render objects in a batch by the specified model_ids, categories or idxs for ShapeNetBase
Summary: Additional functionality for the renderer in ShapeNetCore: users can select which objects to render by specifying their model_ids, render several randomly sampled objects from given categories, or specify the indices (idxs) of the objects in the loaded dataset. (Changing the lighting is currently not supported; still investigating why lighting causes instability in the renderings.) Reviewed By: nikhilaravi Differential Revision: D22179594 fbshipit-source-id: 74c49094ffa3ea2eb71de9451f9e5da5053d356d
committed by Facebook GitHub Bot
parent 358e211cde
commit 22f2963cf1
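
For orientation, a rough usage sketch of the new batched render API described in the summary (a hedged sketch, not code from this commit: the dataset path, the model_id string, and the category labels are made-up placeholders, and the import path is assumed to be pytorch3d.datasets):

# Hedged usage sketch of the batched ShapeNetCore.render API added in this commit.
# SHAPENET_PATH, the model_id, and the category labels are placeholders.
from pytorch3d.datasets import ShapeNetCore  # assumed public import path

SHAPENET_PATH = "/path/to/ShapeNetCore.v2"  # placeholder
shapenet_dataset = ShapeNetCore(SHAPENET_PATH, version=2)

# 1) Render specific objects by their model_ids (placeholder id shown).
images = shapenet_dataset.render(model_ids=["3c408a4ad6d57c3651bc6269fcd1b4c0"])

# 2) Render randomly sampled objects per category; a single sample_num
#    is broadcast across all listed categories.
images = shapenet_dataset.render(categories=["chair", "lamp"], sample_nums=[2])

# 3) Render objects by their indices in the loaded dataset.
images = shapenet_dataset.render(idxs=[0, 7, 42])

# Each call returns a batch of rendered images of shape (N, H, W, 3).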
@@ -1,4 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from .shapenet import ShapeNetCore
@@ -41,7 +41,7 @@ class ShapeNetCore(ShapeNetBase):
        """
        super().__init__()
        self.data_dir = data_dir
        self.shapenet_dir = data_dir
        if version not in [1, 2]:
            raise ValueError("Version number must be either 1 or 2.")
        self.model_dir = "model.obj" if version == 1 else "models/model_normalized.obj"
@@ -100,6 +100,7 @@ class ShapeNetCore(ShapeNetBase):
        # Each grandchildren directory of data_dir contains an object, and the name
        # of the directory is the object's model_id.
        for synset in synset_set:
            self.synset_starts[synset] = len(self.synset_ids)
            for model in os.listdir(path.join(data_dir, synset)):
                if not path.exists(path.join(data_dir, synset, model, self.model_dir)):
                    msg = (
@@ -110,6 +111,7 @@ class ShapeNetCore(ShapeNetBase):
                    continue
                self.synset_ids.append(synset)
                self.model_ids.append(model)
            self.synset_lens[synset] = len(self.synset_ids) - self.synset_starts[synset]

    def __getitem__(self, idx: int) -> Dict:
        """
@@ -128,7 +130,7 @@ class ShapeNetCore(ShapeNetBase):
        """
        model = self._get_item_ids(idx)
        model_path = path.join(
            self.data_dir, model["synset_id"], model["model_id"], self.model_dir
            self.shapenet_dir, model["synset_id"], model["model_id"], self.model_dir
        )
        model["verts"], faces, _ = load_obj(model_path)
        model["faces"] = faces.verts_idx
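
The hunks above add per-synset bookkeeping to ShapeNetCore.__init__: for each synset, synset_starts records where its models begin in the flat model list and synset_lens records how many models it kept, which is what makes per-category sampling possible in ShapeNetBase below. A minimal standalone sketch of that pattern (the synset offsets and model names are made up):

# Standalone sketch of the synset_starts / synset_lens bookkeeping added above.
# The synset offsets and model names are made-up example data.
synset_ids, model_ids = [], []
synset_starts, synset_lens = {}, {}

dataset_layout = {"04379243": ["m1", "m2"], "03001627": ["m3"]}  # stands in for the directory walk

for synset, models in dataset_layout.items():
    synset_starts[synset] = len(synset_ids)  # first flat index for this synset
    for model in models:
        synset_ids.append(synset)
        model_ids.append(model)
    synset_lens[synset] = len(synset_ids) - synset_starts[synset]  # number of models kept

print(synset_starts)  # {'04379243': 0, '03001627': 2}
print(synset_lens)    # {'04379243': 2, '03001627': 1}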
@@ -1,8 +1,11 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from typing import Dict
import warnings
from os import path
from typing import Dict, List, Optional

import torch
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.renderer import (
    HardPhongShader,
    MeshRasterizer,
@@ -11,7 +14,7 @@ from pytorch3d.renderer import (
    PointLights,
    RasterizationSettings,
)
from pytorch3d.structures import Meshes, Textures
from pytorch3d.structures import Textures


class ShapeNetBase(torch.utils.data.Dataset):
@@ -27,6 +30,11 @@ class ShapeNetBase(torch.utils.data.Dataset):
        """
        self.synset_ids = []
        self.model_ids = []
        self.synset_inv = {}
        self.synset_starts = {}
        self.synset_lens = {}
        self.shapenet_dir = ""
        self.model_dir = ""

    def __len__(self):
        """
@@ -67,30 +75,46 @@ class ShapeNetBase(torch.utils.data.Dataset):
        return model

    def render(
        self, idx: int = 0, shader_type=HardPhongShader, device="cpu", **kwargs
        self,
        model_ids: Optional[List[str]] = None,
        categories: Optional[List[str]] = None,
        sample_nums: Optional[List[int]] = None,
        idxs: Optional[List[int]] = None,
        shader_type=HardPhongShader,
        device="cpu",
        **kwargs
    ) -> torch.Tensor:
        """
        Renders a model by the given index.
        If a list of model_ids is supplied, render all the objects by the given model_ids.
        If no model_ids are supplied, but categories and sample_nums are specified, randomly
        select a number of objects (number specified in sample_nums) in the given categories
        and render these objects. If instead a list of idxs is specified, check if the idxs
        are all valid and render models by the given idxs. Otherwise, randomly select a number
        (first number in sample_nums, default is set to be 1) of models from the loaded dataset
        and render these models.

        Args:
            idx: The index of the model to be rendered in the dataset.
            shader_type: select shading. Valid options include HardPhongShader (default),
            model_ids: List[str] of model_ids of models intended to be rendered.
            categories: List[str] of categories intended to be rendered. categories
                and sample_nums must be specified at the same time. categories can be given
                in the form of synset offsets or labels, or a combination of both.
            sample_nums: List[int] of number of models to be randomly sampled from
                each category. Could also contain one single integer, in which case it
                will be broadcast for every category.
            idxs: List[int] of indices of models to be rendered in the dataset.
            shader_type: Select shading. Valid options include HardPhongShader (default),
                SoftPhongShader, HardGouraudShader, SoftGouraudShader, HardFlatShader,
                SoftSilhouetteShader.
            device: torch.device on which the tensors should be located.
            **kwargs: Accepts any of the kwargs that the renderer supports.

        Returns:
            Rendered image of shape (1, H, W, 3).
            Batch of rendered images of shape (N, H, W, 3).
        """
        model = self.__getitem__(idx)
        verts, faces = model["verts"], model["faces"]
        verts_rgb = torch.ones_like(verts, device=device)[None]
        mesh = Meshes(
            verts=[verts.to(device)],
            faces=[faces.to(device)],
            textures=Textures(verts_rgb=verts_rgb.to(device)),
        paths = self._handle_render_inputs(model_ids, categories, sample_nums, idxs)
        meshes = load_objs_as_meshes(paths, device=device, load_textures=False)
        meshes.textures = Textures(
            verts_rgb=torch.ones_like(meshes.verts_padded(), device=device)
        )
        cameras = kwargs.get("cameras", OpenGLPerspectiveCameras()).to(device)
        renderer = MeshRenderer(
@@ -104,4 +128,125 @@ class ShapeNetBase(torch.utils.data.Dataset):
                lights=kwargs.get("lights", PointLights()).to(device),
            ),
        )
        return renderer(mesh)
        return renderer(meshes)
    def _handle_render_inputs(
        self,
        model_ids: Optional[List[str]] = None,
        categories: Optional[List[str]] = None,
        sample_nums: Optional[List[int]] = None,
        idxs: Optional[List[int]] = None,
    ) -> List[str]:
        """
        Helper function for converting user-provided model_ids, categories and sample_nums
        to indices of models in the loaded dataset. If model idxs are provided, we check if
        the idxs are valid. If no models are specified, the first model in the loaded dataset
        is chosen. The function returns the file paths to the selected models.

        Args:
            model_ids: List[str] of model_ids of models to be rendered.
            categories: List[str] of categories to be rendered.
            sample_nums: List[int] of number of models to be randomly sampled from
                each category.
            idxs: List[int] of indices of models to be rendered in the dataset.

        Returns:
            List of paths of models to be rendered.
        """
        # Get corresponding indices if model_ids are supplied.
        if model_ids is not None and len(model_ids) > 0:
            idxs = []
            for model_id in model_ids:
                if model_id not in self.model_ids:
                    raise ValueError(
                        "model_id %s not found in the loaded dataset." % model_id
                    )
                idxs.append(self.model_ids.index(model_id))

        # Sample random models if categories and sample_nums are supplied and get
        # the corresponding indices.
        elif categories is not None and len(categories) > 0:
            sample_nums = [1] if sample_nums is None else sample_nums
            if len(categories) != len(sample_nums) and len(sample_nums) != 1:
                raise ValueError(
                    "categories and sample_nums need to be of the same length or "
                    "sample_nums needs to be of length 1."
                )

            idxs_tensor = torch.empty(0, dtype=torch.int32)
            for i in range(len(categories)):
                category = self.synset_inv.get(categories[i], categories[i])
                if category not in self.synset_inv.values():
                    raise ValueError(
                        "Category %s is not in the loaded dataset." % category
                    )
                # Broadcast if sample_nums has length of 1.
                sample_num = sample_nums[i] if len(sample_nums) > 1 else sample_nums[0]
                sampled_idxs = self._sample_idxs_from_category(
                    sample_num=sample_num, category=category
                )
                idxs_tensor = torch.cat((idxs_tensor, sampled_idxs))
            idxs = idxs_tensor.tolist()
        # Check if the indices are valid if idxs are supplied.
        elif idxs is not None and len(idxs) > 0:
            if any(idx < 0 or idx >= len(self.model_ids) for idx in idxs):
                raise IndexError(
                    "One or more idx values are out of bounds. Indices need to be "
                    "between 0 and %s." % (len(self.model_ids) - 1)
                )
        # Check if sample_nums is specified, if so sample sample_nums[0] number
        # of indices from the entire loaded dataset. Otherwise randomly select one
        # index from the dataset.
        else:
            sample_nums = [1] if sample_nums is None else sample_nums
            if len(sample_nums) > 1:
                msg = (
                    "More than one sample size specified, now sampling "
                    "%d models from the dataset." % sample_nums[0]
                )
                warnings.warn(msg)
            idxs = self._sample_idxs_from_category(sample_nums[0])
        return [
            path.join(
                self.shapenet_dir,
                self.synset_ids[idx],
                self.model_ids[idx],
                self.model_dir,
            )
            for idx in idxs
        ]
    def _sample_idxs_from_category(
        self, sample_num: int = 1, category: Optional[str] = None
    ) -> List[int]:
        """
        Helper function for sampling a number of indices from the given category.

        Args:
            sample_num: number of indices to be sampled from the given category.
            category: category synset of the category to be sampled from. If not
                specified, sample from all models in the loaded dataset.
        """
        start = self.synset_starts[category] if category is not None else 0
        range_len = (
            self.synset_lens[category] if category is not None else self.__len__()
        )
        replacement = sample_num > range_len
        sampled_idxs = (
            torch.multinomial(
                torch.ones((range_len), dtype=torch.float32),
                sample_num,
                replacement=replacement,
            )
            + start
        )
        if replacement:
            msg = (
                "Sample size %d is larger than the number of objects in %s, "
                "values sampled with replacement."
            ) % (
                sample_num,
                "category " + category if category is not None else "all categories",
            )
            warnings.warn(msg)
        return sampled_idxs
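
To make the sampling step above concrete, here is a minimal standalone sketch (not part of the commit; the function name sample_block_idxs is made up) of the same idea: draw sample_num indices uniformly from a contiguous block [start, start + range_len), switching to sampling with replacement when more samples are requested than the block contains.

import torch

def sample_block_idxs(start: int, range_len: int, sample_num: int) -> torch.Tensor:
    # Uniform weights over the block; multinomial draws sample_num distinct
    # indices unless more samples are requested than the block holds.
    replacement = sample_num > range_len
    idxs = torch.multinomial(
        torch.ones(range_len, dtype=torch.float32),
        sample_num,
        replacement=replacement,
    )
    return idxs + start  # shift into the dataset-wide index range

# e.g. a category occupying dataset indices 100..149, sampling 3 of its models:
print(sample_block_idxs(start=100, range_len=50, sample_num=3))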