make ExperimentConfig Configurable

Summary: Preparing for pluggables in experiment.py

Reviewed By: davnov134

Differential Revision: D36830674

fbshipit-source-id: eab499d1bc19c690798fbf7da547544df7e88fa5
This commit is contained in:
Jeremy Reizenstein 2022-06-10 12:22:46 -07:00 committed by Facebook GitHub Bot
parent 6275283202
commit c0f88e04a0

View File

@ -53,7 +53,7 @@ import os
import random import random
import time import time
import warnings import warnings
from dataclasses import dataclass, field from dataclasses import field
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional, Tuple
import hydra import hydra
@ -73,7 +73,9 @@ from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as eval
from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
from pytorch3d.implicitron.tools import model_io, vis_utils from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools.config import ( from pytorch3d.implicitron.tools.config import (
Configurable,
enable_get_default_args, enable_get_default_args,
expand_args_fields,
get_default_args_field, get_default_args_field,
remove_unused_components, remove_unused_components,
) )
@ -671,8 +673,7 @@ def _seed_all_random_engines(seed: int):
random.seed(seed) random.seed(seed)
@dataclass(eq=False) class ExperimentConfig(Configurable):
class ExperimentConfig:
generic_model_args: DictConfig = get_default_args_field(GenericModel) generic_model_args: DictConfig = get_default_args_field(GenericModel)
solver_args: DictConfig = get_default_args_field(init_optimizer) solver_args: DictConfig = get_default_args_field(init_optimizer)
data_source_args: DictConfig = get_default_args_field(ImplicitronDataSource) data_source_args: DictConfig = get_default_args_field(ImplicitronDataSource)
@ -705,6 +706,8 @@ class ExperimentConfig:
) )
expand_args_fields(ExperimentConfig)
if __name__ == "__main__": if __name__ == "__main__":
cs = hydra.core.config_store.ConfigStore.instance() cs = hydra.core.config_store.ConfigStore.instance()
cs.store(name="default_config", node=ExperimentConfig) cs.store(name="default_config", node=ExperimentConfig)