Updates to support Accelerate and multi-GPU training (#37)

Summary:
## Changes:
- Added the Accelerate library and refactored `experiment.py` to use it
- Moved `init_optimizer` and `ExperimentConfig` into separate files to be compatible with submitit/hydra
- Modified the data loaders etc. to work well with the Accelerate DDP wrappers
- Loading/saving checkpoints now incorporates an unwrapping step to remove the DDP wrapper from the model (see the sketch below)
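The save/load pattern looks roughly like this (a minimal sketch, not the exact `experiment.py` code; the toy model and the `model.pth` path are placeholders):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = accelerator.prepare(torch.nn.Linear(4, 4))  # stand-in for the wrapped model

# Saving: strip the DDP wrapper and write only from the main process.
if accelerator.is_local_main_process:
    unwrapped = accelerator.unwrap_model(model)
    torch.save(unwrapped.state_dict(), "model.pth")
accelerator.wait_for_everyone()

# Loading: on non-main processes, remap tensors saved from cuda:0
# onto this process's own GPU.
map_location = None
if not accelerator.is_local_main_process:
    map_location = {"cuda:0": f"cuda:{accelerator.local_process_index}"}
state_dict = torch.load("model.pth", map_location=map_location)
```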

## Tests

Tested with both `torchrun` and `submitit/hydra` on two GPUs locally. Here are the commands:

**Torchrun**

Modules loaded:
```sh
1) anaconda3/2021.05   2) cuda/11.3   3) NCCL/2.9.8-3-cuda.11.3   4) gcc/5.2.0 (unload gcc when using submitit)
```

```sh
torchrun --nnodes=1 --nproc_per_node=2 experiment.py --config-path ./configs --config-name repro_singleseq_nerf_test
```

**Submitit/Hydra Local test**

```sh
~/pytorch3d/projects/implicitron_trainer$ HYDRA_FULL_ERROR=1 python3.9 experiment.py --config-name repro_singleseq_nerf_test --multirun --config-path ./configs  hydra/launcher=submitit_local hydra.launcher.gpus_per_node=2 hydra.launcher.tasks_per_node=2 hydra.launcher.nodes=1
```

**Submitit/Hydra distributed test**

```sh
~/implicitron/pytorch3d$ python3.9 experiment.py --config-name repro_singleseq_nerf_test --multirun --config-path ./configs  hydra/launcher=submitit_slurm hydra.launcher.gpus_per_node=8 hydra.launcher.tasks_per_node=8 hydra.launcher.nodes=1 hydra.launcher.partition=learnlab hydra.launcher.timeout_min=4320
```

## TODOS:
- Fix distributed evaluation: it currently fails because the inputs to the evaluation function are not in a format that can be gathered across GPUs (gathering requires nested lists/tuples/dicts of objects that satisfy `is_torch_tensor`), and `frame_data` contains a `Cameras` object. See the sketch below.
- Refactor the `accelerator` object so it is accessible to all functions instead of being passed around everywhere? Perhaps introduce a `Trainer` class and store it as an attribute.
- Update the README with installation instructions for Accelerate, plus commands for running jobs with torchrun and submitit/hydra.
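For reference, a minimal sketch of the gather constraint (the metric dict below is a hypothetical example; `Accelerator.gather` only recurses into nested lists/tuples/dicts of tensors):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Gathering works on (nested) lists/tuples/dicts of tensors...
batch_metrics = {"psnr": torch.tensor([31.2], device=accelerator.device)}
gathered = accelerator.gather(batch_metrics)  # OK

# ...but raises a TypeError for custom objects such as Cameras or
# ImplicitronRender. A fix would first reduce frame_data to plain
# tensors, e.g. (hypothetical field access):
#   accelerator.gather({"R": frame_data.camera.R, "T": frame_data.camera.T})
```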

X-link: https://github.com/fairinternal/pytorch3d/pull/37

Reviewed By: davnov134, kjchalup

Differential Revision: D37543870

Pulled By: bottler

fbshipit-source-id: be9eb4e91244d4fe3740d87dafec622ae1e0cf76
Author: Nikhila Ravi (2022-07-11 19:29:58 -07:00), committed by Facebook GitHub Bot
Commit: aa8b03f31d (parent: 57a40b3688)
7 changed files with 290 additions and 162 deletions

**projects/implicitron_trainer/experiment.py**

@@ -45,7 +45,6 @@ The outputs of the experiment are saved and logged in multiple ways:
config file.
"""
import copy
import json
import logging
@@ -53,7 +52,6 @@ import os
import random
import time
import warnings
from dataclasses import field
from typing import Any, Dict, Optional, Tuple
import hydra
@@ -61,6 +59,7 @@ import lpips
import numpy as np
import torch
import tqdm
from accelerate import Accelerator
from omegaconf import DictConfig, OmegaConf
from packaging import version
from pytorch3d.implicitron.dataset import utils as ds_utils
@@ -69,17 +68,20 @@ from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Tas
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
from pytorch3d.implicitron.models.renderer.multipass_ea import (
MultiPassEmissionAbsorptionRenderer,
)
from pytorch3d.implicitron.models.renderer.ray_sampler import AdaptiveRaySampler
from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools.config import (
Configurable,
enable_get_default_args,
expand_args_fields,
get_default_args_field,
remove_unused_components,
)
from pytorch3d.implicitron.tools.stats import Stats
from pytorch3d.renderer.cameras import CamerasBase
from .impl.experiment_config import ExperimentConfig
from .impl.optimization import init_optimizer
logger = logging.getLogger(__name__)
@@ -101,6 +103,7 @@ def init_model(
force_load: bool = False,
clear_stats: bool = False,
load_model_only: bool = False,
accelerator: Accelerator = None,
) -> Tuple[GenericModel, Stats, Optional[Dict[str, Any]]]:
"""
Returns an instance of `GenericModel`.
@@ -161,12 +164,20 @@ def init_model(
logger.info("found previous model %s" % model_path)
if force_load or cfg.resume:
logger.info(" -> resuming")
map_location = None
if not accelerator.is_local_main_process:
map_location = {
"cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
}
if load_model_only:
model_state_dict = torch.load(model_io.get_model_path(model_path))
model_state_dict = torch.load(
model_io.get_model_path(model_path), map_location=map_location
)
stats_load, optimizer_state = None, None
else:
model_state_dict, stats_load, optimizer_state = model_io.load_model(
model_path
model_path, map_location=map_location
)
# Determine if stats should be reset
@@ -210,101 +221,6 @@ def init_model(
return model, stats, optimizer_state
def init_optimizer(
model: GenericModel,
optimizer_state: Optional[Dict[str, Any]],
last_epoch: int,
breed: str = "adam",
weight_decay: float = 0.0,
lr_policy: str = "multistep",
lr: float = 0.0005,
gamma: float = 0.1,
momentum: float = 0.9,
betas: Tuple[float, ...] = (0.9, 0.999),
milestones: tuple = (),
max_epochs: int = 1000,
):
"""
Initialize the optimizer (optionally from checkpoint state)
and the learning rate scheduler.
Args:
model: The model with optionally loaded weights
optimizer_state: The state dict for the optimizer. If None
it has not been loaded from checkpoint
last_epoch: If the model was loaded from checkpoint this will be the
number of the last epoch that was saved
breed: The type of optimizer to use e.g. adam
weight_decay: The optimizer weight_decay (L2 penalty on model weights)
lr_policy: The policy to use for learning rate. Currently, only "multistep"
is supported.
lr: The value for the initial learning rate
gamma: Multiplicative factor of learning rate decay
momentum: Momentum factor for SGD optimizer
betas: Coefficients used for computing running averages of gradient and its square
in the Adam optimizer
milestones: List of increasing epoch indices at which the learning rate is
modified
max_epochs: The maximum number of epochs to run the optimizer for
Returns:
optimizer: Optimizer module, optionally loaded from checkpoint
scheduler: Learning rate scheduler module
Raises:
ValueError if `breed` or `lr_policy` are not supported.
"""
# Get the parameters to optimize
if hasattr(model, "_get_param_groups"): # use the model function
# pyre-ignore[29]
p_groups = model._get_param_groups(lr, wd=weight_decay)
else:
allprm = [prm for prm in model.parameters() if prm.requires_grad]
p_groups = [{"params": allprm, "lr": lr}]
# Initialize the optimizer
if breed == "sgd":
optimizer = torch.optim.SGD(
p_groups, lr=lr, momentum=momentum, weight_decay=weight_decay
)
elif breed == "adagrad":
optimizer = torch.optim.Adagrad(p_groups, lr=lr, weight_decay=weight_decay)
elif breed == "adam":
optimizer = torch.optim.Adam(
p_groups, lr=lr, betas=betas, weight_decay=weight_decay
)
else:
raise ValueError("no such solver type %s" % breed)
logger.info(" -> solver type = %s" % breed)
# Load state from checkpoint
if optimizer_state is not None:
logger.info(" -> setting loaded optimizer state")
optimizer.load_state_dict(optimizer_state)
# Initialize the learning rate scheduler
if lr_policy == "multistep":
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=milestones,
gamma=gamma,
)
else:
raise ValueError("no such lr policy %s" % lr_policy)
# When loading from checkpoint, this will make sure that the
# lr is correctly set even after returning
for _ in range(last_epoch):
scheduler.step()
optimizer.zero_grad()
return optimizer, scheduler
enable_get_default_args(init_optimizer)
def trainvalidate(
model,
stats,
@@ -318,6 +234,7 @@ def trainvalidate(
visdom_env_root: str = "trainvalidate",
clip_grad: float = 0.0,
device: str = "cuda:0",
accelerator: Accelerator = None,
**kwargs,
) -> None:
"""
@@ -365,11 +282,11 @@
# Iterate through the batches
n_batches = len(loader)
for it, batch in enumerate(loader):
for it, net_input in enumerate(loader):
last_iter = it == n_batches - 1
# move to gpu where possible (in place)
net_input = batch.to(device)
net_input = net_input.to(accelerator.device)
# run the forward pass
if not validation:
@@ -395,7 +312,11 @@
stats.print(stat_set=trainmode, max_it=n_batches)
# visualize results
if visualize_interval > 0 and it % visualize_interval == 0:
if (
accelerator.is_local_main_process
and visualize_interval > 0
and it % visualize_interval == 0
):
prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
model.visualize(
@@ -410,7 +331,7 @@
loss = preds[bp_var]
assert torch.isfinite(loss).all(), "Non-finite loss!"
# backprop
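# (accelerator.backward replaces loss.backward and applies loss scaling when mixed precision is enabled)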
loss.backward()
accelerator.backward(loss)
if clip_grad > 0.0:
# Optionally clip the gradient norms.
total_norm = torch.nn.utils.clip_grad_norm(
@@ -425,12 +346,22 @@
optimizer.step()
def run_training(cfg: DictConfig, device: str = "cpu") -> None:
def run_training(cfg: DictConfig) -> None:
"""
Entry point to run the training and validation loops
based on the specified config file.
"""
# Initialize the accelerator
accelerator = Accelerator(device_placement=False)
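# device_placement=False: the model and each batch are moved to accelerator.device manually below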
logger.info(accelerator.state)
device = accelerator.device
logger.info(f"Running experiment on device: {device}")
if accelerator.is_local_main_process:
logger.info(OmegaConf.to_yaml(cfg))
# set the debug mode
if cfg.detect_anomaly:
logger.info("Anomaly detection!")
@@ -455,11 +386,11 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
all_train_cameras = datasource.get_all_train_cameras()
# init the model
model, stats, optimizer_state = init_model(cfg)
model, stats, optimizer_state = init_model(cfg, accelerator=accelerator)
start_epoch = stats.epoch + 1
# move model to gpu
model.to(device)
model.to(accelerator.device)
# only run evaluation on the test dataloader
if cfg.eval_only:
@@ -472,6 +403,7 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
model,
stats,
device=device,
accelerator=accelerator,
)
return
@@ -487,6 +419,16 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
assert scheduler.last_epoch == stats.epoch + 1
assert scheduler.last_epoch == start_epoch
# Wrap all modules in the distributed library
# Note: we don't pass the scheduler to prepare as it
# doesn't need to be stepped at each optimizer step
(
model,
optimizer,
train_loader,
val_loader,
) = accelerator.prepare(model, optimizer, dataloaders.train, dataloaders.val)
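# prepare() wraps the model in DDP and shards the dataloaders across processes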
past_scheduler_lrs = []
# loop through epochs
for epoch in range(start_epoch, cfg.solver_args.max_epochs):
@@ -506,25 +448,27 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
model,
stats,
epoch,
dataloaders.train,
train_loader,
optimizer,
False,
visdom_env_root=vis_utils.get_visdom_env(cfg),
device=device,
accelerator=accelerator,
**cfg,
)
# val loop (optional)
if dataloaders.val is not None and epoch % cfg.validation_interval == 0:
if val_loader is not None and epoch % cfg.validation_interval == 0:
trainvalidate(
model,
stats,
epoch,
dataloaders.val,
val_loader,
optimizer,
True,
visdom_env_root=vis_utils.get_visdom_env(cfg),
device=device,
accelerator=accelerator,
**cfg,
)
@@ -541,18 +485,22 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
task,
camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
device=device,
accelerator=accelerator,
)
assert stats.epoch == epoch, "inconsistent stats!"
# delete previous models if required
# save model
if cfg.store_checkpoints:
# save model only on the main process
if cfg.store_checkpoints and accelerator.is_local_main_process:
if cfg.store_checkpoints_purge > 0:
for prev_epoch in range(epoch - cfg.store_checkpoints_purge):
model_io.purge_epoch(cfg.exp_dir, prev_epoch)
outfile = model_io.get_checkpoint(cfg.exp_dir, epoch)
model_io.safe_save_model(model, stats, outfile, optimizer=optimizer)
unwrapped_model = accelerator.unwrap_model(model)
model_io.safe_save_model(
unwrapped_model, stats, outfile, optimizer=optimizer
)
scheduler.step()
@@ -582,6 +530,7 @@ def _eval_and_dump(
model,
stats,
device,
accelerator: Accelerator = None,
) -> None:
"""
Run the evaluation loop with the test data loader and
@@ -600,6 +549,7 @@
task,
camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
device=device,
accelerator=accelerator,
)
# add the evaluation epoch to the results
@@ -634,19 +584,20 @@ def _run_eval(
task: Task,
camera_difficulty_bin_breaks: Tuple[float, float],
device,
accelerator: Accelerator = None,
):
"""
Run the evaluation loop on the test dataloader
"""
lpips_model = lpips.LPIPS(net="vgg")
lpips_model = lpips_model.to(device)
lpips_model = lpips_model.to(accelerator.device)
model.eval()
per_batch_eval_results = []
logger.info("Evaluating model ...")
for frame_data in tqdm.tqdm(loader):
frame_data = frame_data.to(device)
frame_data = frame_data.to(accelerator.device)
# mask out the unknown images so that the model does not see them
frame_data_for_eval = _get_eval_frame_data(frame_data)
@@ -655,7 +606,15 @@
preds = model(
**{**frame_data_for_eval, "evaluation_mode": EvaluationMode.EVALUATION}
)
# TODO: Cannot use accelerate gather for two reasons:
# (1) TypeError: Can't apply _gpu_gather_one on object of type
# <class 'pytorch3d.implicitron.models.base_model.ImplicitronRender'>,
# only of nested list/tuple/dicts of objects that satisfy is_torch_tensor.
# (2) The same error occurs for frame_data, which contains Cameras.
implicitron_render = copy.deepcopy(preds["implicitron_render"])
per_batch_eval_results.append(
evaluate.eval_batch(
frame_data,
@@ -673,62 +632,65 @@
return category_result["results"]
def _seed_all_random_engines(seed: int):
def _seed_all_random_engines(seed: int) -> None:
np.random.seed(seed)
torch.manual_seed(seed)
random.seed(seed)
class ExperimentConfig(Configurable):
generic_model_args: DictConfig = get_default_args_field(GenericModel)
solver_args: DictConfig = get_default_args_field(init_optimizer)
data_source_args: DictConfig = get_default_args_field(ImplicitronDataSource)
architecture: str = "generic"
detect_anomaly: bool = False
eval_only: bool = False
exp_dir: str = "./data/default_experiment/"
exp_idx: int = 0
gpu_idx: int = 0
metric_print_interval: int = 5
resume: bool = True
resume_epoch: int = -1
seed: int = 0
store_checkpoints: bool = True
store_checkpoints_purge: int = 1
test_interval: int = -1
test_when_finished: bool = False
validation_interval: int = 1
visdom_env: str = ""
visdom_port: int = 8097
visdom_server: str = "http://127.0.0.1"
visualize_interval: int = 1000
clip_grad: float = 0.0
camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
hydra: dict = field(
default_factory=lambda: {
"run": {"dir": "."}, # Make hydra not change the working dir.
"output_subdir": None, # disable storing the .hydra logs
}
)
def _setup_envvars_for_cluster(cfg) -> bool:
"""
Prepares to run on cluster if relevant.
Returns whether FAIR cluster in use.
"""
# TODO: How much of this is needed in general?
try:
import submitit
except ImportError:
return False
try:
# Only needed when launching on cluster with slurm and submitit
job_env = submitit.JobEnvironment()
except RuntimeError:
return False
os.environ["LOCAL_RANK"] = str(job_env.local_rank)
os.environ["RANK"] = str(job_env.global_rank)
os.environ["WORLD_SIZE"] = str(job_env.num_tasks)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "42918"
logger.info(
"Num tasks %s, global_rank %s"
% (str(job_env.num_tasks), str(job_env.global_rank))
)
return True
expand_args_fields(ExperimentConfig)
cs = hydra.core.config_store.ConfigStore.instance()
cs.store(name="default_config", node=ExperimentConfig)
@hydra.main(config_path="./configs/", config_name="default_config")
def experiment(cfg: DictConfig) -> None:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
# Set the device
device = "cpu"
if torch.cuda.is_available() and cfg.gpu_idx < torch.cuda.device_count():
device = f"cuda:{cfg.gpu_idx}"
logger.info(f"Running experiment on device: {device}")
run_training(cfg, device)
# CUDA_VISIBLE_DEVICES must have been set.
if "CUDA_DEVICE_ORDER" not in os.environ:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
if not _setup_envvars_for_cluster():
logger.info("Running locally")
# TODO: The following may be needed for hydra/submitit to work
expand_args_fields(GenericModel)
expand_args_fields(AdaptiveRaySampler)
expand_args_fields(MultiPassEmissionAbsorptionRenderer)
expand_args_fields(ImplicitronDataSource)
run_training(cfg)
if __name__ == "__main__":

**impl/__init__.py (new; filename inferred from the new impl package)**

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

**impl/experiment_config.py (new)**

@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import field
from typing import Tuple
from omegaconf import DictConfig
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.tools.config import Configurable, get_default_args_field
from .optimization import init_optimizer
class ExperimentConfig(Configurable):
generic_model_args: DictConfig = get_default_args_field(GenericModel)
solver_args: DictConfig = get_default_args_field(init_optimizer)
data_source_args: DictConfig = get_default_args_field(ImplicitronDataSource)
architecture: str = "generic"
detect_anomaly: bool = False
eval_only: bool = False
exp_dir: str = "./data/default_experiment/"
exp_idx: int = 0
gpu_idx: int = 0
metric_print_interval: int = 5
resume: bool = True
resume_epoch: int = -1
seed: int = 0
store_checkpoints: bool = True
store_checkpoints_purge: int = 1
test_interval: int = -1
test_when_finished: bool = False
validation_interval: int = 1
visdom_env: str = ""
visdom_port: int = 8097
visdom_server: str = "http://127.0.0.1"
visualize_interval: int = 1000
clip_grad: float = 0.0
camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
hydra: dict = field(
default_factory=lambda: {
"run": {"dir": "."}, # Make hydra not change the working dir.
"output_subdir": None, # disable storing the .hydra logs
}
)

**impl/optimization.py (new)**

@@ -0,0 +1,109 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Any, Dict, Optional, Tuple
import torch
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.tools.config import enable_get_default_args
logger = logging.getLogger(__name__)
def init_optimizer(
model: GenericModel,
optimizer_state: Optional[Dict[str, Any]],
last_epoch: int,
breed: str = "adam",
weight_decay: float = 0.0,
lr_policy: str = "multistep",
lr: float = 0.0005,
gamma: float = 0.1,
momentum: float = 0.9,
betas: Tuple[float, ...] = (0.9, 0.999),
milestones: tuple = (),
max_epochs: int = 1000,
):
"""
Initialize the optimizer (optionally from checkpoint state)
and the learning rate scheduler.
Args:
model: The model with optionally loaded weights
optimizer_state: The state dict for the optimizer. If None
it has not been loaded from checkpoint
last_epoch: If the model was loaded from checkpoint this will be the
number of the last epoch that was saved
breed: The type of optimizer to use e.g. adam
weight_decay: The optimizer weight_decay (L2 penalty on model weights)
lr_policy: The policy to use for learning rate. Currently, only "multistep"
is supported.
lr: The value for the initial learning rate
gamma: Multiplicative factor of learning rate decay
momentum: Momentum factor for SGD optimizer
betas: Coefficients used for computing running averages of gradient and its square
in the Adam optimizer
milestones: List of increasing epoch indices at which the learning rate is
modified
max_epochs: The maximum number of epochs to run the optimizer for
Returns:
optimizer: Optimizer module, optionally loaded from checkpoint
scheduler: Learning rate scheduler module
Raises:
ValueError if `breed` or `lr_policy` are not supported.
"""
# Get the parameters to optimize
if hasattr(model, "_get_param_groups"): # use the model function
# pyre-ignore[29]
p_groups = model._get_param_groups(lr, wd=weight_decay)
else:
allprm = [prm for prm in model.parameters() if prm.requires_grad]
p_groups = [{"params": allprm, "lr": lr}]
# Initialize the optimizer
if breed == "sgd":
optimizer = torch.optim.SGD(
p_groups, lr=lr, momentum=momentum, weight_decay=weight_decay
)
elif breed == "adagrad":
optimizer = torch.optim.Adagrad(p_groups, lr=lr, weight_decay=weight_decay)
elif breed == "adam":
optimizer = torch.optim.Adam(
p_groups, lr=lr, betas=betas, weight_decay=weight_decay
)
else:
raise ValueError("no such solver type %s" % breed)
logger.info(" -> solver type = %s" % breed)
# Load state from checkpoint
if optimizer_state is not None:
logger.info(" -> setting loaded optimizer state")
optimizer.load_state_dict(optimizer_state)
# Initialize the learning rate scheduler
if lr_policy == "multistep":
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=milestones,
gamma=gamma,
)
else:
raise ValueError("no such lr policy %s" % lr_policy)
# When loading from checkpoint, this will make sure that the
# lr is correctly set even after returning
for _ in range(last_epoch):
scheduler.step()
optimizer.zero_grad()
return optimizer, scheduler
enable_get_default_args(init_optimizer)

**tests for experiment.py (imports updated)**

@@ -8,11 +8,12 @@ import os
import unittest
from pathlib import Path
import experiment
import torch
from hydra import compose, initialize_config_dir
from omegaconf import OmegaConf
from .. import experiment
def interactive_testing_requested() -> bool:
"""

**visualization script (render_sequence; now imports init_model relatively)**

@@ -21,7 +21,6 @@ from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as Fu
from experiment import init_model
from omegaconf import OmegaConf
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
@@ -38,6 +37,8 @@ from pytorch3d.implicitron.tools.vis_utils import (
)
from tqdm import tqdm
from .experiment import init_model
def render_sequence(
dataset: DatasetBase,

**pytorch3d/implicitron/tools/model_io.py**

@@ -9,6 +9,7 @@ import logging
import os
import shutil
import tempfile
from typing import Optional
import torch
@@ -99,14 +100,14 @@ def save_model(model, stats, fl, optimizer=None, cfg=None):
return flstats, flmodel, flopt
def load_model(fl):
def load_model(fl, map_location: Optional[dict]):
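# map_location is passed through to torch.load (e.g. to remap cuda:0 tensors on non-main ranks)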
flstats = get_stats_path(fl)
flmodel = get_model_path(fl)
flopt = get_optimizer_path(fl)
model_state_dict = torch.load(flmodel)
model_state_dict = torch.load(flmodel, map_location=map_location)
stats = load_stats(flstats)
if os.path.isfile(flopt):
optimizer = torch.load(flopt)
optimizer = torch.load(flopt, map_location=map_location)
else:
optimizer = None