mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-01 03:12:49 +08:00
clean renderer args
Summary: continued - don't duplicate inputs Reviewed By: kjchalup Differential Revision: D38248829 fbshipit-source-id: 2d56418ecbec9cc597c3cf0c122199e274661516
This commit is contained in:
parent
f45893b845
commit
078846d166
@ -43,7 +43,6 @@ model_factory_ImplicitronModelFactory_args:
|
||||
line_step_iters: 3
|
||||
n_secant_steps: 8
|
||||
n_steps: 100
|
||||
object_bounding_sphere: 8.0
|
||||
sdf_threshold: 5.0e-05
|
||||
ray_normal_coloring_network_args:
|
||||
d_in: 9
|
||||
|
@ -36,7 +36,6 @@ model_factory_ImplicitronModelFactory_args:
|
||||
line_step_iters: 3
|
||||
n_secant_steps: 8
|
||||
n_steps: 100
|
||||
object_bounding_sphere: 8.0
|
||||
sdf_threshold: 5.0e-05
|
||||
ray_normal_coloring_network_args:
|
||||
d_in: 9
|
||||
|
@ -241,15 +241,6 @@ model_factory_ImplicitronModelFactory_args:
|
||||
density_relu: true
|
||||
blend_output: false
|
||||
renderer_SignedDistanceFunctionRenderer_args:
|
||||
render_features_dimensions: 3
|
||||
ray_tracer_args:
|
||||
object_bounding_sphere: 1.0
|
||||
sdf_threshold: 5.0e-05
|
||||
line_search_step: 0.5
|
||||
line_step_iters: 1
|
||||
sphere_tracing_iters: 10
|
||||
n_steps: 100
|
||||
n_secant_steps: 8
|
||||
ray_normal_coloring_network_args:
|
||||
feature_vector_size: 3
|
||||
mode: idr
|
||||
@ -266,6 +257,13 @@ model_factory_ImplicitronModelFactory_args:
|
||||
bg_color:
|
||||
- 0.0
|
||||
soft_mask_alpha: 50.0
|
||||
ray_tracer_args:
|
||||
sdf_threshold: 5.0e-05
|
||||
line_search_step: 0.5
|
||||
line_step_iters: 1
|
||||
sphere_tracing_iters: 10
|
||||
n_steps: 100
|
||||
n_secant_steps: 8
|
||||
image_feature_extractor_ResNetFeatureExtractor_args:
|
||||
name: resnet34
|
||||
pretrained: true
|
||||
|
@ -641,35 +641,32 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
||||
**raysampler_args, **extra_args
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def renderer_tweak_args(cls, type, args: DictConfig) -> None:
|
||||
"""
|
||||
We don't expose certain fields of the renderer because we want to set
|
||||
them based on other inputs.
|
||||
"""
|
||||
args.pop("render_features_dimensions", None)
|
||||
args.pop("object_bounding_sphere", None)
|
||||
|
||||
def create_renderer(self):
|
||||
raysampler_args = getattr(
|
||||
self, "raysampler_" + self.raysampler_class_type + "_args"
|
||||
)
|
||||
self.renderer_MultiPassEmissionAbsorptionRenderer_args[
|
||||
"stratified_sampling_coarse_training"
|
||||
] = raysampler_args["stratified_point_sampling_training"]
|
||||
self.renderer_MultiPassEmissionAbsorptionRenderer_args[
|
||||
"stratified_sampling_coarse_evaluation"
|
||||
] = raysampler_args["stratified_point_sampling_evaluation"]
|
||||
self.renderer_SignedDistanceFunctionRenderer_args[
|
||||
"render_features_dimensions"
|
||||
] = self.render_features_dimensions
|
||||
extra_args = {}
|
||||
|
||||
if self.renderer_class_type == "SignedDistanceFunctionRenderer":
|
||||
if "scene_extent" not in raysampler_args:
|
||||
extra_args["render_features_dimensions"] = self.render_features_dimensions
|
||||
if not hasattr(self.raysampler, "scene_extent"):
|
||||
raise ValueError(
|
||||
"SignedDistanceFunctionRenderer requires"
|
||||
+ " a raysampler that defines the 'scene_extent' field"
|
||||
+ " (this field is supported by, e.g., the adaptive raysampler - "
|
||||
+ " self.raysampler_class_type='AdaptiveRaySampler')."
|
||||
)
|
||||
self.renderer_SignedDistanceFunctionRenderer_args.ray_tracer_args[
|
||||
"object_bounding_sphere"
|
||||
] = self.raysampler_AdaptiveRaySampler_args["scene_extent"]
|
||||
extra_args["object_bounding_sphere"] = self.raysampler.scene_extent
|
||||
|
||||
renderer_args = getattr(self, "renderer_" + self.renderer_class_type + "_args")
|
||||
self.renderer = registry.get(BaseRenderer, self.renderer_class_type)(
|
||||
**renderer_args
|
||||
**renderer_args, **extra_args
|
||||
)
|
||||
|
||||
def create_implicit_function(self) -> None:
|
||||
|
@ -53,10 +53,12 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
|
||||
fine rendering pass during training.
|
||||
n_pts_per_ray_fine_evaluation: The number of points sampled per ray for the
|
||||
fine rendering pass during evaluation.
|
||||
stratified_sampling_coarse_training: Enable/disable stratified sampling during
|
||||
training.
|
||||
stratified_sampling_coarse_evaluation: Enable/disable stratified sampling during
|
||||
evaluation.
|
||||
stratified_sampling_coarse_training: Enable/disable stratified sampling in the
|
||||
refiner during training. Only matters if there are multiple implicit
|
||||
functions (i.e. in GenericModel if num_passes>1).
|
||||
stratified_sampling_coarse_evaluation: Enable/disable stratified sampling in
|
||||
the refiner during evaluation. Only matters if there are multiple implicit
|
||||
functions (i.e. in GenericModel if num_passes>1).
|
||||
append_coarse_samples_to_fine: Add the fine ray points to the coarse points
|
||||
after sampling.
|
||||
density_noise_std_train: Standard deviation of the noise added to the
|
||||
|
@ -8,7 +8,11 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
from pytorch3d.implicitron.tools.config import get_default_args_field, registry
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
get_default_args_field,
|
||||
registry,
|
||||
run_auto_creation,
|
||||
)
|
||||
from pytorch3d.implicitron.tools.utils import evaluating
|
||||
from pytorch3d.renderer import RayBundle
|
||||
|
||||
@ -18,9 +22,10 @@ from .rgb_net import RayNormalColoringNetwork
|
||||
|
||||
|
||||
@registry.register
|
||||
class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
||||
class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ignore[13]
|
||||
render_features_dimensions: int = 3
|
||||
ray_tracer_args: DictConfig = get_default_args_field(RayTracing)
|
||||
object_bounding_sphere: float = 1.0
|
||||
ray_tracer: RayTracing
|
||||
ray_normal_coloring_network_args: DictConfig = get_default_args_field(
|
||||
RayNormalColoringNetwork
|
||||
)
|
||||
@ -37,8 +42,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
||||
f"Background color should have {render_features_dimensions} entries."
|
||||
)
|
||||
|
||||
self.ray_tracer = RayTracing(**self.ray_tracer_args)
|
||||
self.object_bounding_sphere = self.ray_tracer_args.get("object_bounding_sphere")
|
||||
run_auto_creation(self)
|
||||
|
||||
self.ray_normal_coloring_network_args[
|
||||
"feature_vector_size"
|
||||
@ -49,6 +53,17 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
||||
|
||||
self.register_buffer("_bg_color", torch.tensor(self.bg_color), persistent=False)
|
||||
|
||||
@classmethod
|
||||
def ray_tracer_tweak_args(cls, type, args: DictConfig) -> None:
|
||||
del args["object_bounding_sphere"]
|
||||
|
||||
def create_ray_tracer(self) -> None:
|
||||
self.ray_tracer = RayTracing(
|
||||
# pyre-ignore[32]
|
||||
**self.ray_tracer_args,
|
||||
object_bounding_sphere=self.object_bounding_sphere,
|
||||
)
|
||||
|
||||
def requires_object_mask(self) -> bool:
|
||||
return True
|
||||
|
||||
@ -97,7 +112,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
||||
object_mask = object_mask.reshape(batch_size, -1)
|
||||
|
||||
with torch.no_grad(), evaluating(implicit_function):
|
||||
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
||||
points, network_object_mask, dists = self.ray_tracer(
|
||||
sdf=lambda x: implicit_function(x)[
|
||||
:, 0
|
||||
@ -128,7 +142,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
||||
N = surface_points.shape[0]
|
||||
|
||||
# Sample points for the eikonal loss
|
||||
# pyre-fixme[9]
|
||||
eik_bounding_box: float = self.object_bounding_sphere
|
||||
n_eik_points = batch_size * num_pixels // 2
|
||||
eikonal_points = torch.empty(
|
||||
|
Loading…
x
Reference in New Issue
Block a user