mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-14 19:36:23 +08:00
enable_get_default_args to allow pickling get_default_args(f)
Summary:
Try again to solve https://github.com/facebookresearch/pytorch3d/issues/1144 pickling problem.
D35258561 (24260130ce) didn't work.
When writing a function or vanilla class C which you want people to be able to call get_default_args on, you must add the line enable_get_default_args(C) to it. This causes autogeneration of a hidden dataclass in the module.
Reviewed By: davnov134
Differential Revision: D35364410
fbshipit-source-id: 53f6e6fff43e7142ae18ca3b06de7d0c849ef965
This commit is contained in:
committed by
Facebook GitHub Bot
parent
4c48beb226
commit
e10a90140d
@@ -74,6 +74,7 @@ from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as eval
|
||||
from pytorch3d.implicitron.models.base import EvaluationMode, GenericModel
|
||||
from pytorch3d.implicitron.tools import model_io, vis_utils
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
enable_get_default_args,
|
||||
get_default_args_field,
|
||||
remove_unused_components,
|
||||
)
|
||||
@@ -304,6 +305,9 @@ def init_optimizer(
|
||||
return optimizer, scheduler
|
||||
|
||||
|
||||
enable_get_default_args(init_optimizer)
|
||||
|
||||
|
||||
def trainvalidate(
|
||||
model,
|
||||
stats,
|
||||
@@ -663,9 +667,7 @@ def _seed_all_random_engines(seed: int):
|
||||
@dataclass(eq=False)
|
||||
class ExperimentConfig:
|
||||
generic_model_args: DictConfig = get_default_args_field(GenericModel)
|
||||
solver_args: DictConfig = get_default_args_field(
|
||||
init_optimizer, _allow_untyped=True
|
||||
)
|
||||
solver_args: DictConfig = get_default_args_field(init_optimizer)
|
||||
dataset_args: DictConfig = get_default_args_field(dataset_zoo)
|
||||
dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
|
||||
architecture: str = "generic"
|
||||
|
||||
Reference in New Issue
Block a user