Better seeding of random engines

Summary: Currently, seeds are set only inside the train loop. But this does not ensure that the model weights are initialized the same way everywhere which makes all experiments irreproducible. This diff fixes it.

Reviewed By: bottler

Differential Revision: D38315840

fbshipit-source-id: 3d2ecebbc36072c2b68dd3cd8c5e30708e7dd808
This commit is contained in:
David Novotny
2022-08-01 10:03:09 -07:00
committed by Facebook GitHub Bot
parent 0c3599e8ee
commit 80fc0ee0b6
4 changed files with 31 additions and 13 deletions

View File

@@ -53,6 +53,7 @@ import warnings
from dataclasses import field
import hydra
import torch
from accelerate import Accelerator
from omegaconf import DictConfig, OmegaConf
@@ -78,6 +79,7 @@ from pytorch3d.implicitron.tools.config import (
from .impl.model_factory import ModelFactoryBase
from .impl.optimizer_factory import OptimizerFactoryBase
from .impl.training_loop import TrainingLoopBase
from .impl.utils import seed_all_random_engines
logger = logging.getLogger(__name__)
@@ -110,6 +112,7 @@ class Experiment(Configurable): # pyre-ignore: 13
scheduler.
training_loop: An object that runs training given the outputs produced
by the data_source, model_factory and optimizer_factory.
seed: A random seed to ensure reproducibility.
detect_anomaly: Whether torch.autograd should detect anomalies. Useful
for debugging, but might slow down the training.
exp_dir: Root experimentation directory. Checkpoints and training stats
@@ -125,6 +128,7 @@ class Experiment(Configurable): # pyre-ignore: 13
training_loop: TrainingLoopBase
training_loop_class_type: str = "ImplicitronTrainingLoop"
seed: int = 42
detect_anomaly: bool = False
exp_dir: str = "./data/default_experiment/"
@@ -136,6 +140,10 @@ class Experiment(Configurable): # pyre-ignore: 13
)
def __post_init__(self):
seed_all_random_engines(
self.seed
) # Set all random engine seeds for reproducibility
run_auto_creation(self)
def run(self) -> None:
@@ -214,6 +222,7 @@ class Experiment(Configurable): # pyre-ignore: 13
device=device,
exp_dir=self.exp_dir,
stats=stats,
seed=self.seed,
task=task,
)