clean raysampler args

Summary: Don't copy from one part of config to another, rather do the copy within GenericModel.

Reviewed By: davnov134

Differential Revision: D38248828

fbshipit-source-id: ff8af985c37ea1f7df9e0aa0a45a58df34c3f893
This commit is contained in:
Jeremy Reizenstein 2022-08-03 12:37:31 -07:00 committed by Facebook GitHub Bot
parent 5f069dbb7e
commit f45893b845
6 changed files with 22 additions and 33 deletions

View File

@ -191,10 +191,6 @@ model_factory_ImplicitronModelFactory_args:
init_scale: 1.0
ignore_input: false
raysampler_AdaptiveRaySampler_args:
image_width: 400
image_height: 400
sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid
n_pts_per_ray_training: 64
n_pts_per_ray_evaluation: 64
n_rays_per_image_sampled_from_mask: 1024
@ -206,10 +202,6 @@ model_factory_ImplicitronModelFactory_args:
- 0.0
- 0.0
raysampler_NearFarRaySampler_args:
image_width: 400
image_height: 400
sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid
n_pts_per_ray_training: 64
n_pts_per_ray_evaluation: 64
n_rays_per_image_sampled_from_mask: 1024

View File

@ -9,7 +9,7 @@ import json
import os
from typing import Dict, List, Optional, Tuple, Type
from omegaconf import DictConfig, open_dict
from omegaconf import DictConfig
from pytorch3d.implicitron.tools.config import (
expand_args_fields,
registry,

View File

@ -11,7 +11,7 @@ import os
import warnings
from typing import Dict, List, Optional, Tuple, Type
from omegaconf import DictConfig, open_dict
from omegaconf import DictConfig
from pytorch3d.implicitron.dataset.dataset_map_provider import (
DatasetMap,
DatasetMapProviderBase,

View File

@ -16,6 +16,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import tqdm
from omegaconf import DictConfig
from pytorch3d.implicitron.models.metrics import (
RegularizationMetricsBase,
ViewMetricsBase,
@ -27,7 +28,7 @@ from pytorch3d.implicitron.tools.config import (
run_auto_creation,
)
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples
from pytorch3d.implicitron.tools.utils import cat_dataclass, setattr_if_hasattr
from pytorch3d.implicitron.tools.utils import cat_dataclass
from pytorch3d.renderer import RayBundle, utils as rend_utils
from pytorch3d.renderer.cameras import CamerasBase
from visdom import Visdom
@ -615,20 +616,29 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
self.image_feature_extractor.get_feat_dims()
)
@classmethod
def raysampler_tweak_args(cls, type, args: DictConfig) -> None:
"""
We don't expose certain fields of the raysampler because we want to set
them from our own members.
"""
del args["sampling_mode_training"]
del args["sampling_mode_evaluation"]
del args["image_width"]
del args["image_height"]
def create_raysampler(self):
extra_args = {
"sampling_mode_training": self.sampling_mode_training,
"sampling_mode_evaluation": self.sampling_mode_evaluation,
"image_width": self.render_image_width,
"image_height": self.render_image_height,
}
raysampler_args = getattr(
self, "raysampler_" + self.raysampler_class_type + "_args"
)
setattr_if_hasattr(
raysampler_args, "sampling_mode_training", self.sampling_mode_training
)
setattr_if_hasattr(
raysampler_args, "sampling_mode_evaluation", self.sampling_mode_evaluation
)
setattr_if_hasattr(raysampler_args, "image_width", self.render_image_width)
setattr_if_hasattr(raysampler_args, "image_height", self.render_image_height)
self.raysampler = registry.get(RaySamplerBase, self.raysampler_class_type)(
**raysampler_args
**raysampler_args, **extra_args
)
def create_renderer(self):

View File

@ -157,15 +157,6 @@ def cat_dataclass(batch, tensor_collator: Callable):
return type(elem)(**collated)
def setattr_if_hasattr(obj, name, value):
"""
Same as setattr(obj, name, value), but does nothing in case `name` is
not an attribe of `obj`.
"""
if hasattr(obj, name):
setattr(obj, name, value)
class Timer:
"""
A simple class for timing execution.

View File

@ -56,10 +56,6 @@ global_encoder_SequenceAutodecoder_args:
init_scale: 1.0
ignore_input: false
raysampler_AdaptiveRaySampler_args:
image_width: 400
image_height: 400
sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid
n_pts_per_ray_training: 64
n_pts_per_ray_evaluation: 64
n_rays_per_image_sampled_from_mask: 1024