mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	clean renderer args
Summary: continued - don't duplicate inputs Reviewed By: kjchalup Differential Revision: D38248829 fbshipit-source-id: 2d56418ecbec9cc597c3cf0c122199e274661516
This commit is contained in:
		
							parent
							
								
									f45893b845
								
							
						
					
					
						commit
						078846d166
					
				@ -43,7 +43,6 @@ model_factory_ImplicitronModelFactory_args:
 | 
				
			|||||||
        line_step_iters: 3
 | 
					        line_step_iters: 3
 | 
				
			||||||
        n_secant_steps: 8
 | 
					        n_secant_steps: 8
 | 
				
			||||||
        n_steps: 100
 | 
					        n_steps: 100
 | 
				
			||||||
        object_bounding_sphere: 8.0
 | 
					 | 
				
			||||||
        sdf_threshold: 5.0e-05
 | 
					        sdf_threshold: 5.0e-05
 | 
				
			||||||
      ray_normal_coloring_network_args:
 | 
					      ray_normal_coloring_network_args:
 | 
				
			||||||
        d_in: 9
 | 
					        d_in: 9
 | 
				
			||||||
 | 
				
			|||||||
@ -36,7 +36,6 @@ model_factory_ImplicitronModelFactory_args:
 | 
				
			|||||||
        line_step_iters: 3
 | 
					        line_step_iters: 3
 | 
				
			||||||
        n_secant_steps: 8
 | 
					        n_secant_steps: 8
 | 
				
			||||||
        n_steps: 100
 | 
					        n_steps: 100
 | 
				
			||||||
        object_bounding_sphere: 8.0
 | 
					 | 
				
			||||||
        sdf_threshold: 5.0e-05
 | 
					        sdf_threshold: 5.0e-05
 | 
				
			||||||
      ray_normal_coloring_network_args:
 | 
					      ray_normal_coloring_network_args:
 | 
				
			||||||
        d_in: 9
 | 
					        d_in: 9
 | 
				
			||||||
 | 
				
			|||||||
@ -241,15 +241,6 @@ model_factory_ImplicitronModelFactory_args:
 | 
				
			|||||||
        density_relu: true
 | 
					        density_relu: true
 | 
				
			||||||
        blend_output: false
 | 
					        blend_output: false
 | 
				
			||||||
    renderer_SignedDistanceFunctionRenderer_args:
 | 
					    renderer_SignedDistanceFunctionRenderer_args:
 | 
				
			||||||
      render_features_dimensions: 3
 | 
					 | 
				
			||||||
      ray_tracer_args:
 | 
					 | 
				
			||||||
        object_bounding_sphere: 1.0
 | 
					 | 
				
			||||||
        sdf_threshold: 5.0e-05
 | 
					 | 
				
			||||||
        line_search_step: 0.5
 | 
					 | 
				
			||||||
        line_step_iters: 1
 | 
					 | 
				
			||||||
        sphere_tracing_iters: 10
 | 
					 | 
				
			||||||
        n_steps: 100
 | 
					 | 
				
			||||||
        n_secant_steps: 8
 | 
					 | 
				
			||||||
      ray_normal_coloring_network_args:
 | 
					      ray_normal_coloring_network_args:
 | 
				
			||||||
        feature_vector_size: 3
 | 
					        feature_vector_size: 3
 | 
				
			||||||
        mode: idr
 | 
					        mode: idr
 | 
				
			||||||
@ -266,6 +257,13 @@ model_factory_ImplicitronModelFactory_args:
 | 
				
			|||||||
      bg_color:
 | 
					      bg_color:
 | 
				
			||||||
      - 0.0
 | 
					      - 0.0
 | 
				
			||||||
      soft_mask_alpha: 50.0
 | 
					      soft_mask_alpha: 50.0
 | 
				
			||||||
 | 
					      ray_tracer_args:
 | 
				
			||||||
 | 
					        sdf_threshold: 5.0e-05
 | 
				
			||||||
 | 
					        line_search_step: 0.5
 | 
				
			||||||
 | 
					        line_step_iters: 1
 | 
				
			||||||
 | 
					        sphere_tracing_iters: 10
 | 
				
			||||||
 | 
					        n_steps: 100
 | 
				
			||||||
 | 
					        n_secant_steps: 8
 | 
				
			||||||
    image_feature_extractor_ResNetFeatureExtractor_args:
 | 
					    image_feature_extractor_ResNetFeatureExtractor_args:
 | 
				
			||||||
      name: resnet34
 | 
					      name: resnet34
 | 
				
			||||||
      pretrained: true
 | 
					      pretrained: true
 | 
				
			||||||
 | 
				
			|||||||
@ -641,35 +641,32 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
 | 
				
			|||||||
            **raysampler_args, **extra_args
 | 
					            **raysampler_args, **extra_args
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def renderer_tweak_args(cls, type, args: DictConfig) -> None:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        We don't expose certain fields of the renderer because we want to set
 | 
				
			||||||
 | 
					        them based on other inputs.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        args.pop("render_features_dimensions", None)
 | 
				
			||||||
 | 
					        args.pop("object_bounding_sphere", None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def create_renderer(self):
 | 
					    def create_renderer(self):
 | 
				
			||||||
        raysampler_args = getattr(
 | 
					        extra_args = {}
 | 
				
			||||||
            self, "raysampler_" + self.raysampler_class_type + "_args"
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.renderer_MultiPassEmissionAbsorptionRenderer_args[
 | 
					 | 
				
			||||||
            "stratified_sampling_coarse_training"
 | 
					 | 
				
			||||||
        ] = raysampler_args["stratified_point_sampling_training"]
 | 
					 | 
				
			||||||
        self.renderer_MultiPassEmissionAbsorptionRenderer_args[
 | 
					 | 
				
			||||||
            "stratified_sampling_coarse_evaluation"
 | 
					 | 
				
			||||||
        ] = raysampler_args["stratified_point_sampling_evaluation"]
 | 
					 | 
				
			||||||
        self.renderer_SignedDistanceFunctionRenderer_args[
 | 
					 | 
				
			||||||
            "render_features_dimensions"
 | 
					 | 
				
			||||||
        ] = self.render_features_dimensions
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if self.renderer_class_type == "SignedDistanceFunctionRenderer":
 | 
					        if self.renderer_class_type == "SignedDistanceFunctionRenderer":
 | 
				
			||||||
            if "scene_extent" not in raysampler_args:
 | 
					            extra_args["render_features_dimensions"] = self.render_features_dimensions
 | 
				
			||||||
 | 
					            if not hasattr(self.raysampler, "scene_extent"):
 | 
				
			||||||
                raise ValueError(
 | 
					                raise ValueError(
 | 
				
			||||||
                    "SignedDistanceFunctionRenderer requires"
 | 
					                    "SignedDistanceFunctionRenderer requires"
 | 
				
			||||||
                    + " a raysampler that defines the 'scene_extent' field"
 | 
					                    + " a raysampler that defines the 'scene_extent' field"
 | 
				
			||||||
                    + " (this field is supported by, e.g., the adaptive raysampler - "
 | 
					                    + " (this field is supported by, e.g., the adaptive raysampler - "
 | 
				
			||||||
                    + " self.raysampler_class_type='AdaptiveRaySampler')."
 | 
					                    + " self.raysampler_class_type='AdaptiveRaySampler')."
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
            self.renderer_SignedDistanceFunctionRenderer_args.ray_tracer_args[
 | 
					            extra_args["object_bounding_sphere"] = self.raysampler.scene_extent
 | 
				
			||||||
                "object_bounding_sphere"
 | 
					 | 
				
			||||||
            ] = self.raysampler_AdaptiveRaySampler_args["scene_extent"]
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        renderer_args = getattr(self, "renderer_" + self.renderer_class_type + "_args")
 | 
					        renderer_args = getattr(self, "renderer_" + self.renderer_class_type + "_args")
 | 
				
			||||||
        self.renderer = registry.get(BaseRenderer, self.renderer_class_type)(
 | 
					        self.renderer = registry.get(BaseRenderer, self.renderer_class_type)(
 | 
				
			||||||
            **renderer_args
 | 
					            **renderer_args, **extra_args
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def create_implicit_function(self) -> None:
 | 
					    def create_implicit_function(self) -> None:
 | 
				
			||||||
 | 
				
			|||||||
@ -53,10 +53,12 @@ class MultiPassEmissionAbsorptionRenderer(  # pyre-ignore: 13
 | 
				
			|||||||
            fine rendering pass during training.
 | 
					            fine rendering pass during training.
 | 
				
			||||||
        n_pts_per_ray_fine_evaluation: The number of points sampled per ray for the
 | 
					        n_pts_per_ray_fine_evaluation: The number of points sampled per ray for the
 | 
				
			||||||
            fine rendering pass during evaluation.
 | 
					            fine rendering pass during evaluation.
 | 
				
			||||||
        stratified_sampling_coarse_training: Enable/disable stratified sampling during
 | 
					        stratified_sampling_coarse_training: Enable/disable stratified sampling in the
 | 
				
			||||||
            training.
 | 
					            refiner during training. Only matters if there are multiple implicit
 | 
				
			||||||
        stratified_sampling_coarse_evaluation: Enable/disable stratified sampling during
 | 
					            functions (i.e. in GenericModel if num_passes>1).
 | 
				
			||||||
            evaluation.
 | 
					        stratified_sampling_coarse_evaluation: Enable/disable stratified sampling in
 | 
				
			||||||
 | 
					            the refiner during evaluation. Only matters if there are multiple implicit
 | 
				
			||||||
 | 
					            functions (i.e. in GenericModel if num_passes>1).
 | 
				
			||||||
        append_coarse_samples_to_fine: Add the fine ray points to the coarse points
 | 
					        append_coarse_samples_to_fine: Add the fine ray points to the coarse points
 | 
				
			||||||
            after sampling.
 | 
					            after sampling.
 | 
				
			||||||
        density_noise_std_train: Standard deviation of the noise added to the
 | 
					        density_noise_std_train: Standard deviation of the noise added to the
 | 
				
			||||||
 | 
				
			|||||||
@ -8,7 +8,11 @@ from typing import List, Optional, Tuple
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from omegaconf import DictConfig
 | 
					from omegaconf import DictConfig
 | 
				
			||||||
from pytorch3d.implicitron.tools.config import get_default_args_field, registry
 | 
					from pytorch3d.implicitron.tools.config import (
 | 
				
			||||||
 | 
					    get_default_args_field,
 | 
				
			||||||
 | 
					    registry,
 | 
				
			||||||
 | 
					    run_auto_creation,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
from pytorch3d.implicitron.tools.utils import evaluating
 | 
					from pytorch3d.implicitron.tools.utils import evaluating
 | 
				
			||||||
from pytorch3d.renderer import RayBundle
 | 
					from pytorch3d.renderer import RayBundle
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -18,9 +22,10 @@ from .rgb_net import RayNormalColoringNetwork
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@registry.register
 | 
					@registry.register
 | 
				
			||||||
class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
 | 
					class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):  # pyre-ignore[13]
 | 
				
			||||||
    render_features_dimensions: int = 3
 | 
					    render_features_dimensions: int = 3
 | 
				
			||||||
    ray_tracer_args: DictConfig = get_default_args_field(RayTracing)
 | 
					    object_bounding_sphere: float = 1.0
 | 
				
			||||||
 | 
					    ray_tracer: RayTracing
 | 
				
			||||||
    ray_normal_coloring_network_args: DictConfig = get_default_args_field(
 | 
					    ray_normal_coloring_network_args: DictConfig = get_default_args_field(
 | 
				
			||||||
        RayNormalColoringNetwork
 | 
					        RayNormalColoringNetwork
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
@ -37,8 +42,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
 | 
				
			|||||||
                f"Background color should have {render_features_dimensions} entries."
 | 
					                f"Background color should have {render_features_dimensions} entries."
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.ray_tracer = RayTracing(**self.ray_tracer_args)
 | 
					        run_auto_creation(self)
 | 
				
			||||||
        self.object_bounding_sphere = self.ray_tracer_args.get("object_bounding_sphere")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.ray_normal_coloring_network_args[
 | 
					        self.ray_normal_coloring_network_args[
 | 
				
			||||||
            "feature_vector_size"
 | 
					            "feature_vector_size"
 | 
				
			||||||
@ -49,6 +53,17 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        self.register_buffer("_bg_color", torch.tensor(self.bg_color), persistent=False)
 | 
					        self.register_buffer("_bg_color", torch.tensor(self.bg_color), persistent=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def ray_tracer_tweak_args(cls, type, args: DictConfig) -> None:
 | 
				
			||||||
 | 
					        del args["object_bounding_sphere"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def create_ray_tracer(self) -> None:
 | 
				
			||||||
 | 
					        self.ray_tracer = RayTracing(
 | 
				
			||||||
 | 
					            # pyre-ignore[32]
 | 
				
			||||||
 | 
					            **self.ray_tracer_args,
 | 
				
			||||||
 | 
					            object_bounding_sphere=self.object_bounding_sphere,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def requires_object_mask(self) -> bool:
 | 
					    def requires_object_mask(self) -> bool:
 | 
				
			||||||
        return True
 | 
					        return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -97,7 +112,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
 | 
				
			|||||||
        object_mask = object_mask.reshape(batch_size, -1)
 | 
					        object_mask = object_mask.reshape(batch_size, -1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        with torch.no_grad(), evaluating(implicit_function):
 | 
					        with torch.no_grad(), evaluating(implicit_function):
 | 
				
			||||||
            # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
 | 
					 | 
				
			||||||
            points, network_object_mask, dists = self.ray_tracer(
 | 
					            points, network_object_mask, dists = self.ray_tracer(
 | 
				
			||||||
                sdf=lambda x: implicit_function(x)[
 | 
					                sdf=lambda x: implicit_function(x)[
 | 
				
			||||||
                    :, 0
 | 
					                    :, 0
 | 
				
			||||||
@ -128,7 +142,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
 | 
				
			|||||||
            N = surface_points.shape[0]
 | 
					            N = surface_points.shape[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # Sample points for the eikonal loss
 | 
					            # Sample points for the eikonal loss
 | 
				
			||||||
            # pyre-fixme[9]
 | 
					 | 
				
			||||||
            eik_bounding_box: float = self.object_bounding_sphere
 | 
					            eik_bounding_box: float = self.object_bounding_sphere
 | 
				
			||||||
            n_eik_points = batch_size * num_pixels // 2
 | 
					            n_eik_points = batch_size * num_pixels // 2
 | 
				
			||||||
            eikonal_points = torch.empty(
 | 
					            eikonal_points = torch.empty(
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user