clean renderer args

Summary: continued - don't duplicate inputs

Reviewed By: kjchalup

Differential Revision: D38248829

fbshipit-source-id: 2d56418ecbec9cc597c3cf0c122199e274661516
This commit is contained in:
Jeremy Reizenstein 2022-08-03 12:37:31 -07:00 committed by Facebook GitHub Bot
parent f45893b845
commit 078846d166
6 changed files with 47 additions and 39 deletions

View File

@ -43,7 +43,6 @@ model_factory_ImplicitronModelFactory_args:
line_step_iters: 3 line_step_iters: 3
n_secant_steps: 8 n_secant_steps: 8
n_steps: 100 n_steps: 100
object_bounding_sphere: 8.0
sdf_threshold: 5.0e-05 sdf_threshold: 5.0e-05
ray_normal_coloring_network_args: ray_normal_coloring_network_args:
d_in: 9 d_in: 9

View File

@ -36,7 +36,6 @@ model_factory_ImplicitronModelFactory_args:
line_step_iters: 3 line_step_iters: 3
n_secant_steps: 8 n_secant_steps: 8
n_steps: 100 n_steps: 100
object_bounding_sphere: 8.0
sdf_threshold: 5.0e-05 sdf_threshold: 5.0e-05
ray_normal_coloring_network_args: ray_normal_coloring_network_args:
d_in: 9 d_in: 9

View File

@ -241,15 +241,6 @@ model_factory_ImplicitronModelFactory_args:
density_relu: true density_relu: true
blend_output: false blend_output: false
renderer_SignedDistanceFunctionRenderer_args: renderer_SignedDistanceFunctionRenderer_args:
render_features_dimensions: 3
ray_tracer_args:
object_bounding_sphere: 1.0
sdf_threshold: 5.0e-05
line_search_step: 0.5
line_step_iters: 1
sphere_tracing_iters: 10
n_steps: 100
n_secant_steps: 8
ray_normal_coloring_network_args: ray_normal_coloring_network_args:
feature_vector_size: 3 feature_vector_size: 3
mode: idr mode: idr
@ -266,6 +257,13 @@ model_factory_ImplicitronModelFactory_args:
bg_color: bg_color:
- 0.0 - 0.0
soft_mask_alpha: 50.0 soft_mask_alpha: 50.0
ray_tracer_args:
sdf_threshold: 5.0e-05
line_search_step: 0.5
line_step_iters: 1
sphere_tracing_iters: 10
n_steps: 100
n_secant_steps: 8
image_feature_extractor_ResNetFeatureExtractor_args: image_feature_extractor_ResNetFeatureExtractor_args:
name: resnet34 name: resnet34
pretrained: true pretrained: true

View File

@ -641,35 +641,32 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
**raysampler_args, **extra_args **raysampler_args, **extra_args
) )
@classmethod
def renderer_tweak_args(cls, type, args: DictConfig) -> None:
"""
We don't expose certain fields of the renderer because we want to set
them based on other inputs.
"""
args.pop("render_features_dimensions", None)
args.pop("object_bounding_sphere", None)
def create_renderer(self): def create_renderer(self):
raysampler_args = getattr( extra_args = {}
self, "raysampler_" + self.raysampler_class_type + "_args"
)
self.renderer_MultiPassEmissionAbsorptionRenderer_args[
"stratified_sampling_coarse_training"
] = raysampler_args["stratified_point_sampling_training"]
self.renderer_MultiPassEmissionAbsorptionRenderer_args[
"stratified_sampling_coarse_evaluation"
] = raysampler_args["stratified_point_sampling_evaluation"]
self.renderer_SignedDistanceFunctionRenderer_args[
"render_features_dimensions"
] = self.render_features_dimensions
if self.renderer_class_type == "SignedDistanceFunctionRenderer": if self.renderer_class_type == "SignedDistanceFunctionRenderer":
if "scene_extent" not in raysampler_args: extra_args["render_features_dimensions"] = self.render_features_dimensions
if not hasattr(self.raysampler, "scene_extent"):
raise ValueError( raise ValueError(
"SignedDistanceFunctionRenderer requires" "SignedDistanceFunctionRenderer requires"
+ " a raysampler that defines the 'scene_extent' field" + " a raysampler that defines the 'scene_extent' field"
+ " (this field is supported by, e.g., the adaptive raysampler - " + " (this field is supported by, e.g., the adaptive raysampler - "
+ " self.raysampler_class_type='AdaptiveRaySampler')." + " self.raysampler_class_type='AdaptiveRaySampler')."
) )
self.renderer_SignedDistanceFunctionRenderer_args.ray_tracer_args[ extra_args["object_bounding_sphere"] = self.raysampler.scene_extent
"object_bounding_sphere"
] = self.raysampler_AdaptiveRaySampler_args["scene_extent"]
renderer_args = getattr(self, "renderer_" + self.renderer_class_type + "_args") renderer_args = getattr(self, "renderer_" + self.renderer_class_type + "_args")
self.renderer = registry.get(BaseRenderer, self.renderer_class_type)( self.renderer = registry.get(BaseRenderer, self.renderer_class_type)(
**renderer_args **renderer_args, **extra_args
) )
def create_implicit_function(self) -> None: def create_implicit_function(self) -> None:

View File

@ -53,10 +53,12 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
fine rendering pass during training. fine rendering pass during training.
n_pts_per_ray_fine_evaluation: The number of points sampled per ray for the n_pts_per_ray_fine_evaluation: The number of points sampled per ray for the
fine rendering pass during evaluation. fine rendering pass during evaluation.
stratified_sampling_coarse_training: Enable/disable stratified sampling during stratified_sampling_coarse_training: Enable/disable stratified sampling in the
training. refiner during training. Only matters if there are multiple implicit
stratified_sampling_coarse_evaluation: Enable/disable stratified sampling during functions (i.e. in GenericModel if num_passes>1).
evaluation. stratified_sampling_coarse_evaluation: Enable/disable stratified sampling in
the refiner during evaluation. Only matters if there are multiple implicit
functions (i.e. in GenericModel if num_passes>1).
append_coarse_samples_to_fine: Add the fine ray points to the coarse points append_coarse_samples_to_fine: Add the fine ray points to the coarse points
after sampling. after sampling.
density_noise_std_train: Standard deviation of the noise added to the density_noise_std_train: Standard deviation of the noise added to the

View File

@ -8,7 +8,11 @@ from typing import List, Optional, Tuple
import torch import torch
from omegaconf import DictConfig from omegaconf import DictConfig
from pytorch3d.implicitron.tools.config import get_default_args_field, registry from pytorch3d.implicitron.tools.config import (
get_default_args_field,
registry,
run_auto_creation,
)
from pytorch3d.implicitron.tools.utils import evaluating from pytorch3d.implicitron.tools.utils import evaluating
from pytorch3d.renderer import RayBundle from pytorch3d.renderer import RayBundle
@ -18,9 +22,10 @@ from .rgb_net import RayNormalColoringNetwork
@registry.register @registry.register
class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ignore[13]
render_features_dimensions: int = 3 render_features_dimensions: int = 3
ray_tracer_args: DictConfig = get_default_args_field(RayTracing) object_bounding_sphere: float = 1.0
ray_tracer: RayTracing
ray_normal_coloring_network_args: DictConfig = get_default_args_field( ray_normal_coloring_network_args: DictConfig = get_default_args_field(
RayNormalColoringNetwork RayNormalColoringNetwork
) )
@ -37,8 +42,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
f"Background color should have {render_features_dimensions} entries." f"Background color should have {render_features_dimensions} entries."
) )
self.ray_tracer = RayTracing(**self.ray_tracer_args) run_auto_creation(self)
self.object_bounding_sphere = self.ray_tracer_args.get("object_bounding_sphere")
self.ray_normal_coloring_network_args[ self.ray_normal_coloring_network_args[
"feature_vector_size" "feature_vector_size"
@ -49,6 +53,17 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
self.register_buffer("_bg_color", torch.tensor(self.bg_color), persistent=False) self.register_buffer("_bg_color", torch.tensor(self.bg_color), persistent=False)
@classmethod
def ray_tracer_tweak_args(cls, type, args: DictConfig) -> None:
del args["object_bounding_sphere"]
def create_ray_tracer(self) -> None:
self.ray_tracer = RayTracing(
# pyre-ignore[32]
**self.ray_tracer_args,
object_bounding_sphere=self.object_bounding_sphere,
)
def requires_object_mask(self) -> bool: def requires_object_mask(self) -> bool:
return True return True
@ -97,7 +112,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
object_mask = object_mask.reshape(batch_size, -1) object_mask = object_mask.reshape(batch_size, -1)
with torch.no_grad(), evaluating(implicit_function): with torch.no_grad(), evaluating(implicit_function):
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
points, network_object_mask, dists = self.ray_tracer( points, network_object_mask, dists = self.ray_tracer(
sdf=lambda x: implicit_function(x)[ sdf=lambda x: implicit_function(x)[
:, 0 :, 0
@ -128,7 +142,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
N = surface_points.shape[0] N = surface_points.shape[0]
# Sample points for the eikonal loss # Sample points for the eikonal loss
# pyre-fixme[9]
eik_bounding_box: float = self.object_bounding_sphere eik_bounding_box: float = self.object_bounding_sphere
n_eik_points = batch_size * num_pixels // 2 n_eik_points = batch_size * num_pixels // 2
eikonal_points = torch.empty( eikonal_points = torch.empty(