mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	clean raysampler args
Summary: Don't copy from one part of config to another, rather do the copy within GenericModel. Reviewed By: davnov134 Differential Revision: D38248828 fbshipit-source-id: ff8af985c37ea1f7df9e0aa0a45a58df34c3f893
This commit is contained in:
		
							parent
							
								
									5f069dbb7e
								
							
						
					
					
						commit
						f45893b845
					
				@ -191,10 +191,6 @@ model_factory_ImplicitronModelFactory_args:
 | 
				
			|||||||
        init_scale: 1.0
 | 
					        init_scale: 1.0
 | 
				
			||||||
        ignore_input: false
 | 
					        ignore_input: false
 | 
				
			||||||
    raysampler_AdaptiveRaySampler_args:
 | 
					    raysampler_AdaptiveRaySampler_args:
 | 
				
			||||||
      image_width: 400
 | 
					 | 
				
			||||||
      image_height: 400
 | 
					 | 
				
			||||||
      sampling_mode_training: mask_sample
 | 
					 | 
				
			||||||
      sampling_mode_evaluation: full_grid
 | 
					 | 
				
			||||||
      n_pts_per_ray_training: 64
 | 
					      n_pts_per_ray_training: 64
 | 
				
			||||||
      n_pts_per_ray_evaluation: 64
 | 
					      n_pts_per_ray_evaluation: 64
 | 
				
			||||||
      n_rays_per_image_sampled_from_mask: 1024
 | 
					      n_rays_per_image_sampled_from_mask: 1024
 | 
				
			||||||
@ -206,10 +202,6 @@ model_factory_ImplicitronModelFactory_args:
 | 
				
			|||||||
      - 0.0
 | 
					      - 0.0
 | 
				
			||||||
      - 0.0
 | 
					      - 0.0
 | 
				
			||||||
    raysampler_NearFarRaySampler_args:
 | 
					    raysampler_NearFarRaySampler_args:
 | 
				
			||||||
      image_width: 400
 | 
					 | 
				
			||||||
      image_height: 400
 | 
					 | 
				
			||||||
      sampling_mode_training: mask_sample
 | 
					 | 
				
			||||||
      sampling_mode_evaluation: full_grid
 | 
					 | 
				
			||||||
      n_pts_per_ray_training: 64
 | 
					      n_pts_per_ray_training: 64
 | 
				
			||||||
      n_pts_per_ray_evaluation: 64
 | 
					      n_pts_per_ray_evaluation: 64
 | 
				
			||||||
      n_rays_per_image_sampled_from_mask: 1024
 | 
					      n_rays_per_image_sampled_from_mask: 1024
 | 
				
			||||||
 | 
				
			|||||||
@ -9,7 +9,7 @@ import json
 | 
				
			|||||||
import os
 | 
					import os
 | 
				
			||||||
from typing import Dict, List, Optional, Tuple, Type
 | 
					from typing import Dict, List, Optional, Tuple, Type
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from omegaconf import DictConfig, open_dict
 | 
					from omegaconf import DictConfig
 | 
				
			||||||
from pytorch3d.implicitron.tools.config import (
 | 
					from pytorch3d.implicitron.tools.config import (
 | 
				
			||||||
    expand_args_fields,
 | 
					    expand_args_fields,
 | 
				
			||||||
    registry,
 | 
					    registry,
 | 
				
			||||||
 | 
				
			|||||||
@ -11,7 +11,7 @@ import os
 | 
				
			|||||||
import warnings
 | 
					import warnings
 | 
				
			||||||
from typing import Dict, List, Optional, Tuple, Type
 | 
					from typing import Dict, List, Optional, Tuple, Type
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from omegaconf import DictConfig, open_dict
 | 
					from omegaconf import DictConfig
 | 
				
			||||||
from pytorch3d.implicitron.dataset.dataset_map_provider import (
 | 
					from pytorch3d.implicitron.dataset.dataset_map_provider import (
 | 
				
			||||||
    DatasetMap,
 | 
					    DatasetMap,
 | 
				
			||||||
    DatasetMapProviderBase,
 | 
					    DatasetMapProviderBase,
 | 
				
			||||||
 | 
				
			|||||||
@ -16,6 +16,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
import tqdm
 | 
					import tqdm
 | 
				
			||||||
 | 
					from omegaconf import DictConfig
 | 
				
			||||||
from pytorch3d.implicitron.models.metrics import (
 | 
					from pytorch3d.implicitron.models.metrics import (
 | 
				
			||||||
    RegularizationMetricsBase,
 | 
					    RegularizationMetricsBase,
 | 
				
			||||||
    ViewMetricsBase,
 | 
					    ViewMetricsBase,
 | 
				
			||||||
@ -27,7 +28,7 @@ from pytorch3d.implicitron.tools.config import (
 | 
				
			|||||||
    run_auto_creation,
 | 
					    run_auto_creation,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples
 | 
					from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples
 | 
				
			||||||
from pytorch3d.implicitron.tools.utils import cat_dataclass, setattr_if_hasattr
 | 
					from pytorch3d.implicitron.tools.utils import cat_dataclass
 | 
				
			||||||
from pytorch3d.renderer import RayBundle, utils as rend_utils
 | 
					from pytorch3d.renderer import RayBundle, utils as rend_utils
 | 
				
			||||||
from pytorch3d.renderer.cameras import CamerasBase
 | 
					from pytorch3d.renderer.cameras import CamerasBase
 | 
				
			||||||
from visdom import Visdom
 | 
					from visdom import Visdom
 | 
				
			||||||
@ -615,20 +616,29 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
 | 
				
			|||||||
            self.image_feature_extractor.get_feat_dims()
 | 
					            self.image_feature_extractor.get_feat_dims()
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def raysampler_tweak_args(cls, type, args: DictConfig) -> None:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        We don't expose certain fields of the raysampler because we want to set
 | 
				
			||||||
 | 
					        them from our own members.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        del args["sampling_mode_training"]
 | 
				
			||||||
 | 
					        del args["sampling_mode_evaluation"]
 | 
				
			||||||
 | 
					        del args["image_width"]
 | 
				
			||||||
 | 
					        del args["image_height"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def create_raysampler(self):
 | 
					    def create_raysampler(self):
 | 
				
			||||||
 | 
					        extra_args = {
 | 
				
			||||||
 | 
					            "sampling_mode_training": self.sampling_mode_training,
 | 
				
			||||||
 | 
					            "sampling_mode_evaluation": self.sampling_mode_evaluation,
 | 
				
			||||||
 | 
					            "image_width": self.render_image_width,
 | 
				
			||||||
 | 
					            "image_height": self.render_image_height,
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
        raysampler_args = getattr(
 | 
					        raysampler_args = getattr(
 | 
				
			||||||
            self, "raysampler_" + self.raysampler_class_type + "_args"
 | 
					            self, "raysampler_" + self.raysampler_class_type + "_args"
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        setattr_if_hasattr(
 | 
					 | 
				
			||||||
            raysampler_args, "sampling_mode_training", self.sampling_mode_training
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        setattr_if_hasattr(
 | 
					 | 
				
			||||||
            raysampler_args, "sampling_mode_evaluation", self.sampling_mode_evaluation
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        setattr_if_hasattr(raysampler_args, "image_width", self.render_image_width)
 | 
					 | 
				
			||||||
        setattr_if_hasattr(raysampler_args, "image_height", self.render_image_height)
 | 
					 | 
				
			||||||
        self.raysampler = registry.get(RaySamplerBase, self.raysampler_class_type)(
 | 
					        self.raysampler = registry.get(RaySamplerBase, self.raysampler_class_type)(
 | 
				
			||||||
            **raysampler_args
 | 
					            **raysampler_args, **extra_args
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def create_renderer(self):
 | 
					    def create_renderer(self):
 | 
				
			||||||
 | 
				
			|||||||
@ -157,15 +157,6 @@ def cat_dataclass(batch, tensor_collator: Callable):
 | 
				
			|||||||
    return type(elem)(**collated)
 | 
					    return type(elem)(**collated)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def setattr_if_hasattr(obj, name, value):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Same as setattr(obj, name, value), but does nothing in case `name` is
 | 
					 | 
				
			||||||
    not an attribe of `obj`.
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    if hasattr(obj, name):
 | 
					 | 
				
			||||||
        setattr(obj, name, value)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class Timer:
 | 
					class Timer:
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    A simple class for timing execution.
 | 
					    A simple class for timing execution.
 | 
				
			||||||
 | 
				
			|||||||
@ -56,10 +56,6 @@ global_encoder_SequenceAutodecoder_args:
 | 
				
			|||||||
    init_scale: 1.0
 | 
					    init_scale: 1.0
 | 
				
			||||||
    ignore_input: false
 | 
					    ignore_input: false
 | 
				
			||||||
raysampler_AdaptiveRaySampler_args:
 | 
					raysampler_AdaptiveRaySampler_args:
 | 
				
			||||||
  image_width: 400
 | 
					 | 
				
			||||||
  image_height: 400
 | 
					 | 
				
			||||||
  sampling_mode_training: mask_sample
 | 
					 | 
				
			||||||
  sampling_mode_evaluation: full_grid
 | 
					 | 
				
			||||||
  n_pts_per_ray_training: 64
 | 
					  n_pts_per_ray_training: 64
 | 
				
			||||||
  n_pts_per_ray_evaluation: 64
 | 
					  n_pts_per_ray_evaluation: 64
 | 
				
			||||||
  n_rays_per_image_sampled_from_mask: 1024
 | 
					  n_rays_per_image_sampled_from_mask: 1024
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user