enable_get_default_args to allow pickling get_default_args(f)

Summary:
Try again to solve https://github.com/facebookresearch/pytorch3d/issues/1144 pickling problem.
D35258561 (24260130ce) didn't work.

When writing a function or vanilla class C which you want people to be able to call get_default_args on, you must add the line enable_get_default_args(C) to it. This causes autogeneration of a hidden dataclass in the module.

Reviewed By: davnov134

Differential Revision: D35364410

fbshipit-source-id: 53f6e6fff43e7142ae18ca3b06de7d0c849ef965
This commit is contained in:
Jeremy Reizenstein
2022-04-06 03:32:31 -07:00
committed by Facebook GitHub Bot
parent 4c48beb226
commit e10a90140d
7 changed files with 102 additions and 50 deletions

View File

@@ -74,6 +74,7 @@ from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as eval
from pytorch3d.implicitron.models.base import EvaluationMode, GenericModel
from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools.config import (
enable_get_default_args,
get_default_args_field,
remove_unused_components,
)
@@ -304,6 +305,9 @@ def init_optimizer(
return optimizer, scheduler
enable_get_default_args(init_optimizer)
def trainvalidate(
model,
stats,
@@ -663,9 +667,7 @@ def _seed_all_random_engines(seed: int):
@dataclass(eq=False)
class ExperimentConfig:
generic_model_args: DictConfig = get_default_args_field(GenericModel)
solver_args: DictConfig = get_default_args_field(
init_optimizer, _allow_untyped=True
)
solver_args: DictConfig = get_default_args_field(init_optimizer)
dataset_args: DictConfig = get_default_args_field(dataset_zoo)
dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
architecture: str = "generic"