mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
clean raysampler args
Summary: Don't copy from one part of config to another, rather do the copy within GenericModel. Reviewed By: davnov134 Differential Revision: D38248828 fbshipit-source-id: ff8af985c37ea1f7df9e0aa0a45a58df34c3f893
This commit is contained in:
parent
5f069dbb7e
commit
f45893b845
@ -191,10 +191,6 @@ model_factory_ImplicitronModelFactory_args:
|
||||
init_scale: 1.0
|
||||
ignore_input: false
|
||||
raysampler_AdaptiveRaySampler_args:
|
||||
image_width: 400
|
||||
image_height: 400
|
||||
sampling_mode_training: mask_sample
|
||||
sampling_mode_evaluation: full_grid
|
||||
n_pts_per_ray_training: 64
|
||||
n_pts_per_ray_evaluation: 64
|
||||
n_rays_per_image_sampled_from_mask: 1024
|
||||
@ -206,10 +202,6 @@ model_factory_ImplicitronModelFactory_args:
|
||||
- 0.0
|
||||
- 0.0
|
||||
raysampler_NearFarRaySampler_args:
|
||||
image_width: 400
|
||||
image_height: 400
|
||||
sampling_mode_training: mask_sample
|
||||
sampling_mode_evaluation: full_grid
|
||||
n_pts_per_ray_training: 64
|
||||
n_pts_per_ray_evaluation: 64
|
||||
n_rays_per_image_sampled_from_mask: 1024
|
||||
|
@ -9,7 +9,7 @@ import json
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
|
||||
from omegaconf import DictConfig, open_dict
|
||||
from omegaconf import DictConfig
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
expand_args_fields,
|
||||
registry,
|
||||
|
@ -11,7 +11,7 @@ import os
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
|
||||
from omegaconf import DictConfig, open_dict
|
||||
from omegaconf import DictConfig
|
||||
from pytorch3d.implicitron.dataset.dataset_map_provider import (
|
||||
DatasetMap,
|
||||
DatasetMapProviderBase,
|
||||
|
@ -16,6 +16,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
from omegaconf import DictConfig
|
||||
from pytorch3d.implicitron.models.metrics import (
|
||||
RegularizationMetricsBase,
|
||||
ViewMetricsBase,
|
||||
@ -27,7 +28,7 @@ from pytorch3d.implicitron.tools.config import (
|
||||
run_auto_creation,
|
||||
)
|
||||
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples
|
||||
from pytorch3d.implicitron.tools.utils import cat_dataclass, setattr_if_hasattr
|
||||
from pytorch3d.implicitron.tools.utils import cat_dataclass
|
||||
from pytorch3d.renderer import RayBundle, utils as rend_utils
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
from visdom import Visdom
|
||||
@ -615,20 +616,29 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
||||
self.image_feature_extractor.get_feat_dims()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def raysampler_tweak_args(cls, type, args: DictConfig) -> None:
|
||||
"""
|
||||
We don't expose certain fields of the raysampler because we want to set
|
||||
them from our own members.
|
||||
"""
|
||||
del args["sampling_mode_training"]
|
||||
del args["sampling_mode_evaluation"]
|
||||
del args["image_width"]
|
||||
del args["image_height"]
|
||||
|
||||
def create_raysampler(self):
|
||||
extra_args = {
|
||||
"sampling_mode_training": self.sampling_mode_training,
|
||||
"sampling_mode_evaluation": self.sampling_mode_evaluation,
|
||||
"image_width": self.render_image_width,
|
||||
"image_height": self.render_image_height,
|
||||
}
|
||||
raysampler_args = getattr(
|
||||
self, "raysampler_" + self.raysampler_class_type + "_args"
|
||||
)
|
||||
setattr_if_hasattr(
|
||||
raysampler_args, "sampling_mode_training", self.sampling_mode_training
|
||||
)
|
||||
setattr_if_hasattr(
|
||||
raysampler_args, "sampling_mode_evaluation", self.sampling_mode_evaluation
|
||||
)
|
||||
setattr_if_hasattr(raysampler_args, "image_width", self.render_image_width)
|
||||
setattr_if_hasattr(raysampler_args, "image_height", self.render_image_height)
|
||||
self.raysampler = registry.get(RaySamplerBase, self.raysampler_class_type)(
|
||||
**raysampler_args
|
||||
**raysampler_args, **extra_args
|
||||
)
|
||||
|
||||
def create_renderer(self):
|
||||
|
@ -157,15 +157,6 @@ def cat_dataclass(batch, tensor_collator: Callable):
|
||||
return type(elem)(**collated)
|
||||
|
||||
|
||||
def setattr_if_hasattr(obj, name, value):
|
||||
"""
|
||||
Same as setattr(obj, name, value), but does nothing in case `name` is
|
||||
not an attribe of `obj`.
|
||||
"""
|
||||
if hasattr(obj, name):
|
||||
setattr(obj, name, value)
|
||||
|
||||
|
||||
class Timer:
|
||||
"""
|
||||
A simple class for timing execution.
|
||||
|
@ -56,10 +56,6 @@ global_encoder_SequenceAutodecoder_args:
|
||||
init_scale: 1.0
|
||||
ignore_input: false
|
||||
raysampler_AdaptiveRaySampler_args:
|
||||
image_width: 400
|
||||
image_height: 400
|
||||
sampling_mode_training: mask_sample
|
||||
sampling_mode_evaluation: full_grid
|
||||
n_pts_per_ray_training: 64
|
||||
n_pts_per_ray_evaluation: 64
|
||||
n_rays_per_image_sampled_from_mask: 1024
|
||||
|
Loading…
x
Reference in New Issue
Block a user