mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
clean renderer args
Summary: continued - don't duplicate inputs Reviewed By: kjchalup Differential Revision: D38248829 fbshipit-source-id: 2d56418ecbec9cc597c3cf0c122199e274661516
This commit is contained in:
parent
f45893b845
commit
078846d166
@ -43,7 +43,6 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
line_step_iters: 3
|
line_step_iters: 3
|
||||||
n_secant_steps: 8
|
n_secant_steps: 8
|
||||||
n_steps: 100
|
n_steps: 100
|
||||||
object_bounding_sphere: 8.0
|
|
||||||
sdf_threshold: 5.0e-05
|
sdf_threshold: 5.0e-05
|
||||||
ray_normal_coloring_network_args:
|
ray_normal_coloring_network_args:
|
||||||
d_in: 9
|
d_in: 9
|
||||||
|
@ -36,7 +36,6 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
line_step_iters: 3
|
line_step_iters: 3
|
||||||
n_secant_steps: 8
|
n_secant_steps: 8
|
||||||
n_steps: 100
|
n_steps: 100
|
||||||
object_bounding_sphere: 8.0
|
|
||||||
sdf_threshold: 5.0e-05
|
sdf_threshold: 5.0e-05
|
||||||
ray_normal_coloring_network_args:
|
ray_normal_coloring_network_args:
|
||||||
d_in: 9
|
d_in: 9
|
||||||
|
@ -241,15 +241,6 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
density_relu: true
|
density_relu: true
|
||||||
blend_output: false
|
blend_output: false
|
||||||
renderer_SignedDistanceFunctionRenderer_args:
|
renderer_SignedDistanceFunctionRenderer_args:
|
||||||
render_features_dimensions: 3
|
|
||||||
ray_tracer_args:
|
|
||||||
object_bounding_sphere: 1.0
|
|
||||||
sdf_threshold: 5.0e-05
|
|
||||||
line_search_step: 0.5
|
|
||||||
line_step_iters: 1
|
|
||||||
sphere_tracing_iters: 10
|
|
||||||
n_steps: 100
|
|
||||||
n_secant_steps: 8
|
|
||||||
ray_normal_coloring_network_args:
|
ray_normal_coloring_network_args:
|
||||||
feature_vector_size: 3
|
feature_vector_size: 3
|
||||||
mode: idr
|
mode: idr
|
||||||
@ -266,6 +257,13 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
bg_color:
|
bg_color:
|
||||||
- 0.0
|
- 0.0
|
||||||
soft_mask_alpha: 50.0
|
soft_mask_alpha: 50.0
|
||||||
|
ray_tracer_args:
|
||||||
|
sdf_threshold: 5.0e-05
|
||||||
|
line_search_step: 0.5
|
||||||
|
line_step_iters: 1
|
||||||
|
sphere_tracing_iters: 10
|
||||||
|
n_steps: 100
|
||||||
|
n_secant_steps: 8
|
||||||
image_feature_extractor_ResNetFeatureExtractor_args:
|
image_feature_extractor_ResNetFeatureExtractor_args:
|
||||||
name: resnet34
|
name: resnet34
|
||||||
pretrained: true
|
pretrained: true
|
||||||
|
@ -641,35 +641,32 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
|||||||
**raysampler_args, **extra_args
|
**raysampler_args, **extra_args
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def renderer_tweak_args(cls, type, args: DictConfig) -> None:
|
||||||
|
"""
|
||||||
|
We don't expose certain fields of the renderer because we want to set
|
||||||
|
them based on other inputs.
|
||||||
|
"""
|
||||||
|
args.pop("render_features_dimensions", None)
|
||||||
|
args.pop("object_bounding_sphere", None)
|
||||||
|
|
||||||
def create_renderer(self):
|
def create_renderer(self):
|
||||||
raysampler_args = getattr(
|
extra_args = {}
|
||||||
self, "raysampler_" + self.raysampler_class_type + "_args"
|
|
||||||
)
|
|
||||||
self.renderer_MultiPassEmissionAbsorptionRenderer_args[
|
|
||||||
"stratified_sampling_coarse_training"
|
|
||||||
] = raysampler_args["stratified_point_sampling_training"]
|
|
||||||
self.renderer_MultiPassEmissionAbsorptionRenderer_args[
|
|
||||||
"stratified_sampling_coarse_evaluation"
|
|
||||||
] = raysampler_args["stratified_point_sampling_evaluation"]
|
|
||||||
self.renderer_SignedDistanceFunctionRenderer_args[
|
|
||||||
"render_features_dimensions"
|
|
||||||
] = self.render_features_dimensions
|
|
||||||
|
|
||||||
if self.renderer_class_type == "SignedDistanceFunctionRenderer":
|
if self.renderer_class_type == "SignedDistanceFunctionRenderer":
|
||||||
if "scene_extent" not in raysampler_args:
|
extra_args["render_features_dimensions"] = self.render_features_dimensions
|
||||||
|
if not hasattr(self.raysampler, "scene_extent"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"SignedDistanceFunctionRenderer requires"
|
"SignedDistanceFunctionRenderer requires"
|
||||||
+ " a raysampler that defines the 'scene_extent' field"
|
+ " a raysampler that defines the 'scene_extent' field"
|
||||||
+ " (this field is supported by, e.g., the adaptive raysampler - "
|
+ " (this field is supported by, e.g., the adaptive raysampler - "
|
||||||
+ " self.raysampler_class_type='AdaptiveRaySampler')."
|
+ " self.raysampler_class_type='AdaptiveRaySampler')."
|
||||||
)
|
)
|
||||||
self.renderer_SignedDistanceFunctionRenderer_args.ray_tracer_args[
|
extra_args["object_bounding_sphere"] = self.raysampler.scene_extent
|
||||||
"object_bounding_sphere"
|
|
||||||
] = self.raysampler_AdaptiveRaySampler_args["scene_extent"]
|
|
||||||
|
|
||||||
renderer_args = getattr(self, "renderer_" + self.renderer_class_type + "_args")
|
renderer_args = getattr(self, "renderer_" + self.renderer_class_type + "_args")
|
||||||
self.renderer = registry.get(BaseRenderer, self.renderer_class_type)(
|
self.renderer = registry.get(BaseRenderer, self.renderer_class_type)(
|
||||||
**renderer_args
|
**renderer_args, **extra_args
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_implicit_function(self) -> None:
|
def create_implicit_function(self) -> None:
|
||||||
|
@ -53,10 +53,12 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
|
|||||||
fine rendering pass during training.
|
fine rendering pass during training.
|
||||||
n_pts_per_ray_fine_evaluation: The number of points sampled per ray for the
|
n_pts_per_ray_fine_evaluation: The number of points sampled per ray for the
|
||||||
fine rendering pass during evaluation.
|
fine rendering pass during evaluation.
|
||||||
stratified_sampling_coarse_training: Enable/disable stratified sampling during
|
stratified_sampling_coarse_training: Enable/disable stratified sampling in the
|
||||||
training.
|
refiner during training. Only matters if there are multiple implicit
|
||||||
stratified_sampling_coarse_evaluation: Enable/disable stratified sampling during
|
functions (i.e. in GenericModel if num_passes>1).
|
||||||
evaluation.
|
stratified_sampling_coarse_evaluation: Enable/disable stratified sampling in
|
||||||
|
the refiner during evaluation. Only matters if there are multiple implicit
|
||||||
|
functions (i.e. in GenericModel if num_passes>1).
|
||||||
append_coarse_samples_to_fine: Add the fine ray points to the coarse points
|
append_coarse_samples_to_fine: Add the fine ray points to the coarse points
|
||||||
after sampling.
|
after sampling.
|
||||||
density_noise_std_train: Standard deviation of the noise added to the
|
density_noise_std_train: Standard deviation of the noise added to the
|
||||||
|
@ -8,7 +8,11 @@ from typing import List, Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from omegaconf import DictConfig
|
from omegaconf import DictConfig
|
||||||
from pytorch3d.implicitron.tools.config import get_default_args_field, registry
|
from pytorch3d.implicitron.tools.config import (
|
||||||
|
get_default_args_field,
|
||||||
|
registry,
|
||||||
|
run_auto_creation,
|
||||||
|
)
|
||||||
from pytorch3d.implicitron.tools.utils import evaluating
|
from pytorch3d.implicitron.tools.utils import evaluating
|
||||||
from pytorch3d.renderer import RayBundle
|
from pytorch3d.renderer import RayBundle
|
||||||
|
|
||||||
@ -18,9 +22,10 @@ from .rgb_net import RayNormalColoringNetwork
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ignore[13]
|
||||||
render_features_dimensions: int = 3
|
render_features_dimensions: int = 3
|
||||||
ray_tracer_args: DictConfig = get_default_args_field(RayTracing)
|
object_bounding_sphere: float = 1.0
|
||||||
|
ray_tracer: RayTracing
|
||||||
ray_normal_coloring_network_args: DictConfig = get_default_args_field(
|
ray_normal_coloring_network_args: DictConfig = get_default_args_field(
|
||||||
RayNormalColoringNetwork
|
RayNormalColoringNetwork
|
||||||
)
|
)
|
||||||
@ -37,8 +42,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
f"Background color should have {render_features_dimensions} entries."
|
f"Background color should have {render_features_dimensions} entries."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.ray_tracer = RayTracing(**self.ray_tracer_args)
|
run_auto_creation(self)
|
||||||
self.object_bounding_sphere = self.ray_tracer_args.get("object_bounding_sphere")
|
|
||||||
|
|
||||||
self.ray_normal_coloring_network_args[
|
self.ray_normal_coloring_network_args[
|
||||||
"feature_vector_size"
|
"feature_vector_size"
|
||||||
@ -49,6 +53,17 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
|
|
||||||
self.register_buffer("_bg_color", torch.tensor(self.bg_color), persistent=False)
|
self.register_buffer("_bg_color", torch.tensor(self.bg_color), persistent=False)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def ray_tracer_tweak_args(cls, type, args: DictConfig) -> None:
|
||||||
|
del args["object_bounding_sphere"]
|
||||||
|
|
||||||
|
def create_ray_tracer(self) -> None:
|
||||||
|
self.ray_tracer = RayTracing(
|
||||||
|
# pyre-ignore[32]
|
||||||
|
**self.ray_tracer_args,
|
||||||
|
object_bounding_sphere=self.object_bounding_sphere,
|
||||||
|
)
|
||||||
|
|
||||||
def requires_object_mask(self) -> bool:
|
def requires_object_mask(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -97,7 +112,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
object_mask = object_mask.reshape(batch_size, -1)
|
object_mask = object_mask.reshape(batch_size, -1)
|
||||||
|
|
||||||
with torch.no_grad(), evaluating(implicit_function):
|
with torch.no_grad(), evaluating(implicit_function):
|
||||||
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
|
||||||
points, network_object_mask, dists = self.ray_tracer(
|
points, network_object_mask, dists = self.ray_tracer(
|
||||||
sdf=lambda x: implicit_function(x)[
|
sdf=lambda x: implicit_function(x)[
|
||||||
:, 0
|
:, 0
|
||||||
@ -128,7 +142,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
N = surface_points.shape[0]
|
N = surface_points.shape[0]
|
||||||
|
|
||||||
# Sample points for the eikonal loss
|
# Sample points for the eikonal loss
|
||||||
# pyre-fixme[9]
|
|
||||||
eik_bounding_box: float = self.object_bounding_sphere
|
eik_bounding_box: float = self.object_bounding_sphere
|
||||||
n_eik_points = batch_size * num_pixels // 2
|
n_eik_points = batch_size * num_pixels // 2
|
||||||
eikonal_points = torch.empty(
|
eikonal_points = torch.empty(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user