mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
clean raysampler args
Summary: Don't copy from one part of config to another, rather do the copy within GenericModel. Reviewed By: davnov134 Differential Revision: D38248828 fbshipit-source-id: ff8af985c37ea1f7df9e0aa0a45a58df34c3f893
This commit is contained in:
parent
5f069dbb7e
commit
f45893b845
@ -191,10 +191,6 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
init_scale: 1.0
|
init_scale: 1.0
|
||||||
ignore_input: false
|
ignore_input: false
|
||||||
raysampler_AdaptiveRaySampler_args:
|
raysampler_AdaptiveRaySampler_args:
|
||||||
image_width: 400
|
|
||||||
image_height: 400
|
|
||||||
sampling_mode_training: mask_sample
|
|
||||||
sampling_mode_evaluation: full_grid
|
|
||||||
n_pts_per_ray_training: 64
|
n_pts_per_ray_training: 64
|
||||||
n_pts_per_ray_evaluation: 64
|
n_pts_per_ray_evaluation: 64
|
||||||
n_rays_per_image_sampled_from_mask: 1024
|
n_rays_per_image_sampled_from_mask: 1024
|
||||||
@ -206,10 +202,6 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
- 0.0
|
- 0.0
|
||||||
- 0.0
|
- 0.0
|
||||||
raysampler_NearFarRaySampler_args:
|
raysampler_NearFarRaySampler_args:
|
||||||
image_width: 400
|
|
||||||
image_height: 400
|
|
||||||
sampling_mode_training: mask_sample
|
|
||||||
sampling_mode_evaluation: full_grid
|
|
||||||
n_pts_per_ray_training: 64
|
n_pts_per_ray_training: 64
|
||||||
n_pts_per_ray_evaluation: 64
|
n_pts_per_ray_evaluation: 64
|
||||||
n_rays_per_image_sampled_from_mask: 1024
|
n_rays_per_image_sampled_from_mask: 1024
|
||||||
|
@ -9,7 +9,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
from typing import Dict, List, Optional, Tuple, Type
|
from typing import Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
from omegaconf import DictConfig, open_dict
|
from omegaconf import DictConfig
|
||||||
from pytorch3d.implicitron.tools.config import (
|
from pytorch3d.implicitron.tools.config import (
|
||||||
expand_args_fields,
|
expand_args_fields,
|
||||||
registry,
|
registry,
|
||||||
|
@ -11,7 +11,7 @@ import os
|
|||||||
import warnings
|
import warnings
|
||||||
from typing import Dict, List, Optional, Tuple, Type
|
from typing import Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
from omegaconf import DictConfig, open_dict
|
from omegaconf import DictConfig
|
||||||
from pytorch3d.implicitron.dataset.dataset_map_provider import (
|
from pytorch3d.implicitron.dataset.dataset_map_provider import (
|
||||||
DatasetMap,
|
DatasetMap,
|
||||||
DatasetMapProviderBase,
|
DatasetMapProviderBase,
|
||||||
|
@ -16,6 +16,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
|
from omegaconf import DictConfig
|
||||||
from pytorch3d.implicitron.models.metrics import (
|
from pytorch3d.implicitron.models.metrics import (
|
||||||
RegularizationMetricsBase,
|
RegularizationMetricsBase,
|
||||||
ViewMetricsBase,
|
ViewMetricsBase,
|
||||||
@ -27,7 +28,7 @@ from pytorch3d.implicitron.tools.config import (
|
|||||||
run_auto_creation,
|
run_auto_creation,
|
||||||
)
|
)
|
||||||
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples
|
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples
|
||||||
from pytorch3d.implicitron.tools.utils import cat_dataclass, setattr_if_hasattr
|
from pytorch3d.implicitron.tools.utils import cat_dataclass
|
||||||
from pytorch3d.renderer import RayBundle, utils as rend_utils
|
from pytorch3d.renderer import RayBundle, utils as rend_utils
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
from visdom import Visdom
|
from visdom import Visdom
|
||||||
@ -615,20 +616,29 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
|||||||
self.image_feature_extractor.get_feat_dims()
|
self.image_feature_extractor.get_feat_dims()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def raysampler_tweak_args(cls, type, args: DictConfig) -> None:
|
||||||
|
"""
|
||||||
|
We don't expose certain fields of the raysampler because we want to set
|
||||||
|
them from our own members.
|
||||||
|
"""
|
||||||
|
del args["sampling_mode_training"]
|
||||||
|
del args["sampling_mode_evaluation"]
|
||||||
|
del args["image_width"]
|
||||||
|
del args["image_height"]
|
||||||
|
|
||||||
def create_raysampler(self):
|
def create_raysampler(self):
|
||||||
|
extra_args = {
|
||||||
|
"sampling_mode_training": self.sampling_mode_training,
|
||||||
|
"sampling_mode_evaluation": self.sampling_mode_evaluation,
|
||||||
|
"image_width": self.render_image_width,
|
||||||
|
"image_height": self.render_image_height,
|
||||||
|
}
|
||||||
raysampler_args = getattr(
|
raysampler_args = getattr(
|
||||||
self, "raysampler_" + self.raysampler_class_type + "_args"
|
self, "raysampler_" + self.raysampler_class_type + "_args"
|
||||||
)
|
)
|
||||||
setattr_if_hasattr(
|
|
||||||
raysampler_args, "sampling_mode_training", self.sampling_mode_training
|
|
||||||
)
|
|
||||||
setattr_if_hasattr(
|
|
||||||
raysampler_args, "sampling_mode_evaluation", self.sampling_mode_evaluation
|
|
||||||
)
|
|
||||||
setattr_if_hasattr(raysampler_args, "image_width", self.render_image_width)
|
|
||||||
setattr_if_hasattr(raysampler_args, "image_height", self.render_image_height)
|
|
||||||
self.raysampler = registry.get(RaySamplerBase, self.raysampler_class_type)(
|
self.raysampler = registry.get(RaySamplerBase, self.raysampler_class_type)(
|
||||||
**raysampler_args
|
**raysampler_args, **extra_args
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_renderer(self):
|
def create_renderer(self):
|
||||||
|
@ -157,15 +157,6 @@ def cat_dataclass(batch, tensor_collator: Callable):
|
|||||||
return type(elem)(**collated)
|
return type(elem)(**collated)
|
||||||
|
|
||||||
|
|
||||||
def setattr_if_hasattr(obj, name, value):
|
|
||||||
"""
|
|
||||||
Same as setattr(obj, name, value), but does nothing in case `name` is
|
|
||||||
not an attribe of `obj`.
|
|
||||||
"""
|
|
||||||
if hasattr(obj, name):
|
|
||||||
setattr(obj, name, value)
|
|
||||||
|
|
||||||
|
|
||||||
class Timer:
|
class Timer:
|
||||||
"""
|
"""
|
||||||
A simple class for timing execution.
|
A simple class for timing execution.
|
||||||
|
@ -56,10 +56,6 @@ global_encoder_SequenceAutodecoder_args:
|
|||||||
init_scale: 1.0
|
init_scale: 1.0
|
||||||
ignore_input: false
|
ignore_input: false
|
||||||
raysampler_AdaptiveRaySampler_args:
|
raysampler_AdaptiveRaySampler_args:
|
||||||
image_width: 400
|
|
||||||
image_height: 400
|
|
||||||
sampling_mode_training: mask_sample
|
|
||||||
sampling_mode_evaluation: full_grid
|
|
||||||
n_pts_per_ray_training: 64
|
n_pts_per_ray_training: 64
|
||||||
n_pts_per_ray_evaluation: 64
|
n_pts_per_ray_evaluation: 64
|
||||||
n_rays_per_image_sampled_from_mask: 1024
|
n_rays_per_image_sampled_from_mask: 1024
|
||||||
|
Loading…
x
Reference in New Issue
Block a user