Better seeding of random engines

Summary: Currently, seeds are set only inside the train loop. This does not ensure that model weights are initialized the same way across runs, which makes experiments irreproducible. This diff fixes that by seeding all random engines as soon as the Experiment object is constructed.

Reviewed By: bottler

Differential Revision: D38315840

fbshipit-source-id: 3d2ecebbc36072c2b68dd3cd8c5e30708e7dd808
David Novotny 2022-08-01 10:03:09 -07:00 committed by Facebook GitHub Bot
parent 0c3599e8ee
commit 80fc0ee0b6
4 changed files with 31 additions and 13 deletions
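For context on why the seeding moved: torch.manual_seed (and its numpy/random counterparts) only affects draws made after the call, so seeding inside the training loop left the earlier model-weight initialization unseeded. A minimal sketch of the failure mode, illustrative rather than taken from the diff:

import torch

def make_layer() -> torch.nn.Linear:
    # Weight initialization consumes random numbers at construction time.
    return torch.nn.Linear(4, 4)

# Old behaviour: the model is built before any seed is set, so its
# weights differ from run to run even though the loop seeds later.
unseeded = make_layer()

# New behaviour: seeding in Experiment.__post_init__ runs before
# run_auto_creation builds the model, so initialization is reproducible.
torch.manual_seed(42)
seeded_a = make_layer()
torch.manual_seed(42)
seeded_b = make_layer()
assert torch.equal(seeded_a.weight, seeded_b.weight)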


@@ -53,6 +53,7 @@ import warnings
 from dataclasses import field

 import hydra
 import torch
 from accelerate import Accelerator
 from omegaconf import DictConfig, OmegaConf
@@ -78,6 +79,7 @@ from pytorch3d.implicitron.tools.config import (
 from .impl.model_factory import ModelFactoryBase
 from .impl.optimizer_factory import OptimizerFactoryBase
 from .impl.training_loop import TrainingLoopBase
+from .impl.utils import seed_all_random_engines

 logger = logging.getLogger(__name__)
@@ -110,6 +112,7 @@ class Experiment(Configurable):  # pyre-ignore: 13
             scheduler.
         training_loop: An object that runs training given the outputs produced
             by the data_source, model_factory and optimizer_factory.
+        seed: A random seed to ensure reproducibility.
         detect_anomaly: Whether torch.autograd should detect anomalies. Useful
             for debugging, but might slow down the training.
         exp_dir: Root experimentation directory. Checkpoints and training stats
@@ -125,6 +128,7 @@ class Experiment(Configurable):  # pyre-ignore: 13
     training_loop: TrainingLoopBase
     training_loop_class_type: str = "ImplicitronTrainingLoop"

+    seed: int = 42
     detect_anomaly: bool = False
     exp_dir: str = "./data/default_experiment/"
@ -136,6 +140,10 @@ class Experiment(Configurable): # pyre-ignore: 13
) )
def __post_init__(self): def __post_init__(self):
seed_all_random_engines(
self.seed
) # Set all random engine seeds for reproducibility
run_auto_creation(self) run_auto_creation(self)
def run(self) -> None: def run(self) -> None:
@@ -214,6 +222,7 @@ class Experiment(Configurable):  # pyre-ignore: 13
             device=device,
             exp_dir=self.exp_dir,
             stats=stats,
+            seed=self.seed,
             task=task,
         )


@@ -5,11 +5,9 @@
 # LICENSE file in the root directory of this source tree.

 import logging
-import random
 import time
 from typing import Any, Optional

-import numpy as np
 import torch
 from accelerate import Accelerator
 from pytorch3d.implicitron.dataset.data_source import Task
@@ -26,6 +24,8 @@ from pytorch3d.implicitron.tools.stats import Stats
 from pytorch3d.renderer.cameras import CamerasBase
 from torch.utils.data import DataLoader

+from .utils import seed_all_random_engines
+
 logger = logging.getLogger(__name__)
@@ -52,7 +52,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
         max_epochs: Train for this many epochs. Note that if the model was
             loaded from a checkpoint, we will restart training at the appropriate
             epoch and run for (max_epochs - checkpoint_epoch) epochs.
-        seed: A random seed to ensure reproducibility.
         store_checkpoints: If True, store model and optimizer state checkpoints.
         store_checkpoints_purge: If >= 0, remove any checkpoints older or equal
             to this many epochs.
@ -73,7 +72,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
evaluator: EvaluatorBase evaluator: EvaluatorBase
evaluator_class_type: str = "ImplicitronEvaluator" evaluator_class_type: str = "ImplicitronEvaluator"
max_epochs: int = 1000 max_epochs: int = 1000
seed: int = 0
store_checkpoints: bool = True store_checkpoints: bool = True
store_checkpoints_purge: int = 1 store_checkpoints_purge: int = 1
test_interval: int = -1 test_interval: int = -1
@@ -102,6 +100,7 @@ class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
         device: torch.device,
         exp_dir: str,
         stats: Stats,
+        seed: int,
         task: Task,
         **kwargs,
     ):
@@ -109,7 +108,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
         Entry point to run the training and validation loops
         based on the specified config file.
         """
-        _seed_all_random_engines(self.seed)
         start_epoch = stats.epoch + 1
         assert scheduler.last_epoch == stats.epoch + 1
         assert scheduler.last_epoch == start_epoch
@@ -140,7 +138,7 @@ class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
             # Make sure to re-seed random generators to ensure reproducibility
             # even after restart.
-            _seed_all_random_engines(self.seed + epoch)
+            seed_all_random_engines(seed + epoch)

             cur_lr = float(scheduler.get_last_lr()[-1])
             logger.debug(f"scheduler lr = {cur_lr:1.2e}")
@@ -357,9 +355,3 @@ class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
             model_io.safe_save_model(
                 unwrapped_model, stats, outfile, optimizer=optimizer
             )
-
-
-def _seed_all_random_engines(seed: int) -> None:
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    random.seed(seed)
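The per-epoch call seed_all_random_engines(seed + epoch) keeps each epoch's randomness deterministic but distinct, so a job restarted from a checkpoint at epoch k sees the same random stream from epoch k onward as an uninterrupted run would, which is the restart property the inline comment asks for. A small sketch of that property, using Python's random as a stand-in for all three engines:

import random

def epoch_draws(seed: int, start_epoch: int, max_epochs: int) -> list:
    draws = []
    for epoch in range(start_epoch, max_epochs):
        # Mirrors seed_all_random_engines(seed + epoch) in the loop above.
        random.seed(seed + epoch)
        draws.append(random.random())
    return draws

# A full run and a run resumed at epoch 5 agree from epoch 5 onward.
full = epoch_draws(seed=42, start_epoch=0, max_epochs=10)
resumed = epoch_draws(seed=42, start_epoch=5, max_epochs=10)
assert full[5:] == resumed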


@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import random
+
+import numpy as np
+import torch
+
+
+def seed_all_random_engines(seed: int) -> None:
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    random.seed(seed)
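Note that torch.manual_seed sets the seed for all devices, so CUDA draws are covered along with CPU ones, alongside Python's random and NumPy. A quick self-check of the helper's determinism, with its body inlined so the snippet is self-contained (the real import path is project-internal):

import random

import numpy as np
import torch

def seed_all_random_engines(seed: int) -> None:
    # Same body as the helper added above.
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)

seed_all_random_engines(42)
first = (random.random(), np.random.rand(), torch.randn(2))
seed_all_random_engines(42)
second = (random.random(), np.random.rand(), torch.randn(2))
assert first[0] == second[0]
assert first[1] == second[1]
assert torch.equal(first[2], second[2])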


@@ -2,6 +2,7 @@ data_source_class_type: ImplicitronDataSource
 model_factory_class_type: ImplicitronModelFactory
 optimizer_factory_class_type: ImplicitronOptimizerFactory
 training_loop_class_type: ImplicitronTrainingLoop
+seed: 42
 detect_anomaly: false
 exp_dir: ./data/default_experiment/
 hydra:
@@ -429,7 +430,6 @@ training_loop_ImplicitronTrainingLoop_args:
   eval_only: false
   evaluator_class_type: ImplicitronEvaluator
   max_epochs: 1000
-  seed: 0
   store_checkpoints: true
   store_checkpoints_purge: 1
   test_interval: -1
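Since seed is now a top-level field of the Experiment config rather than a training-loop arg, it should be overridable like any other key. A hedged sketch using OmegaConf directly; the file name is illustrative, and the usual route would be the trainer's Hydra command line:

from omegaconf import OmegaConf

# Illustrative path: the YAML in this diff is the default experiment config.
cfg = OmegaConf.load("experiment.yaml")
cfg.seed = 7  # takes the place of the removed training-loop-level seed
assert cfg.seed == 7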