mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 06:10:34 +08:00
Better seeding of random engines
Summary: Currently, seeds are set only inside the train loop. However, this does not ensure that the model weights are initialized the same way everywhere, which makes all experiments irreproducible. This diff fixes it. Reviewed By: bottler Differential Revision: D38315840 fbshipit-source-id: 3d2ecebbc36072c2b68dd3cd8c5e30708e7dd808
This commit is contained in:
committed by
Facebook GitHub Bot
parent
0c3599e8ee
commit
80fc0ee0b6
@@ -5,11 +5,9 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from pytorch3d.implicitron.dataset.data_source import Task
|
||||
@@ -26,6 +24,8 @@ from pytorch3d.implicitron.tools.stats import Stats
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .utils import seed_all_random_engines
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -52,7 +52,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
|
||||
max_epochs: Train for this many epochs. Note that if the model was
|
||||
loaded from a checkpoint, we will restart training at the appropriate
|
||||
epoch and run for (max_epochs - checkpoint_epoch) epochs.
|
||||
seed: A random seed to ensure reproducibility.
|
||||
store_checkpoints: If True, store model and optimizer state checkpoints.
|
||||
store_checkpoints_purge: If >= 0, remove any checkpoints older or equal
|
||||
to this many epochs.
|
||||
@@ -73,7 +72,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
|
||||
evaluator: EvaluatorBase
|
||||
evaluator_class_type: str = "ImplicitronEvaluator"
|
||||
max_epochs: int = 1000
|
||||
seed: int = 0
|
||||
store_checkpoints: bool = True
|
||||
store_checkpoints_purge: int = 1
|
||||
test_interval: int = -1
|
||||
@@ -102,6 +100,7 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
|
||||
device: torch.device,
|
||||
exp_dir: str,
|
||||
stats: Stats,
|
||||
seed: int,
|
||||
task: Task,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -109,7 +108,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
|
||||
Entry point to run the training and validation loops
|
||||
based on the specified config file.
|
||||
"""
|
||||
_seed_all_random_engines(self.seed)
|
||||
start_epoch = stats.epoch + 1
|
||||
assert scheduler.last_epoch == stats.epoch + 1
|
||||
assert scheduler.last_epoch == start_epoch
|
||||
@@ -140,7 +138,7 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
|
||||
|
||||
# Make sure to re-seed random generators to ensure reproducibility
|
||||
# even after restart.
|
||||
_seed_all_random_engines(self.seed + epoch)
|
||||
seed_all_random_engines(seed + epoch)
|
||||
|
||||
cur_lr = float(scheduler.get_last_lr()[-1])
|
||||
logger.debug(f"scheduler lr = {cur_lr:1.2e}")
|
||||
@@ -357,9 +355,3 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
|
||||
model_io.safe_save_model(
|
||||
unwrapped_model, stats, outfile, optimizer=optimizer
|
||||
)
|
||||
|
||||
|
||||
def _seed_all_random_engines(seed: int) -> None:
    """Seed the NumPy, PyTorch and stdlib `random` generators with `seed`.

    Args:
        seed: The value passed unchanged to each of the three seeding calls.
    """
    # Same call order as before: NumPy, then PyTorch, then `random`.
    for seed_engine in (np.random.seed, torch.manual_seed, random.seed):
        seed_engine(seed)
|
||||
|
||||
17
projects/implicitron_trainer/impl/utils.py
Normal file
17
projects/implicitron_trainer/impl/utils.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def seed_all_random_engines(seed: int) -> None:
    """Seed the NumPy, PyTorch and stdlib `random` generators with `seed`.

    Args:
        seed: The value passed unchanged to each of the three seeding calls.
    """
    # Same call order as before: NumPy, then PyTorch, then `random`.
    for seed_engine in (np.random.seed, torch.manual_seed, random.seed):
        seed_engine(seed)
|
||||
Reference in New Issue
Block a user