mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-22 23:30:35 +08:00
Better seeding of random engines
Summary: Currently, seeds are set only inside the train loop. But this does not ensure that the model weights are initialized the same way everywhere which makes all experiments irreproducible. This diff fixes it. Reviewed By: bottler Differential Revision: D38315840 fbshipit-source-id: 3d2ecebbc36072c2b68dd3cd8c5e30708e7dd808
This commit is contained in:
committed by
Facebook GitHub Bot
parent
0c3599e8ee
commit
80fc0ee0b6
@@ -53,6 +53,7 @@ import warnings
|
||||
from dataclasses import field
|
||||
|
||||
import hydra
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
@@ -78,6 +79,7 @@ from pytorch3d.implicitron.tools.config import (
|
||||
from .impl.model_factory import ModelFactoryBase
|
||||
from .impl.optimizer_factory import OptimizerFactoryBase
|
||||
from .impl.training_loop import TrainingLoopBase
|
||||
from .impl.utils import seed_all_random_engines
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -110,6 +112,7 @@ class Experiment(Configurable): # pyre-ignore: 13
|
||||
scheduler.
|
||||
training_loop: An object that runs training given the outputs produced
|
||||
by the data_source, model_factory and optimizer_factory.
|
||||
seed: A random seed to ensure reproducibility.
|
||||
detect_anomaly: Whether torch.autograd should detect anomalies. Useful
|
||||
for debugging, but might slow down the training.
|
||||
exp_dir: Root experimentation directory. Checkpoints and training stats
|
||||
@@ -125,6 +128,7 @@ class Experiment(Configurable): # pyre-ignore: 13
|
||||
training_loop: TrainingLoopBase
|
||||
training_loop_class_type: str = "ImplicitronTrainingLoop"
|
||||
|
||||
seed: int = 42
|
||||
detect_anomaly: bool = False
|
||||
exp_dir: str = "./data/default_experiment/"
|
||||
|
||||
@@ -136,6 +140,10 @@ class Experiment(Configurable): # pyre-ignore: 13
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
seed_all_random_engines(
|
||||
self.seed
|
||||
) # Set all random engine seeds for reproducibility
|
||||
|
||||
run_auto_creation(self)
|
||||
|
||||
def run(self) -> None:
|
||||
@@ -214,6 +222,7 @@ class Experiment(Configurable): # pyre-ignore: 13
|
||||
device=device,
|
||||
exp_dir=self.exp_dir,
|
||||
stats=stats,
|
||||
seed=self.seed,
|
||||
task=task,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user