Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-08-01 03:12:49 +08:00
Move load_stats to TrainingLoop
Summary: Stats are logically connected to the training loop, not to the model; hence, load_stats moves to the training loop. Also removes resume_epoch from OptimizerFactory in favor of a single place, ModelFactory, which removes the need for config consistency checks.

Reviewed By: kjchalup

Differential Revision: D38313475

fbshipit-source-id: a1d188a63e28459df381ff98ad8acdcdb14887b7
This commit is contained in:
parent 760305e044
commit c3f8dad55c
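Before the diff, a minimal sketch of the resume flow after this change, reconstructed from the hunks below. The surrounding `Experiment.run` context and anything not shown in the diff (e.g. the exact `model_factory` call arguments) are assumptions, not verbatim code:

```python
# Hypothetical excerpt of Experiment.run() after this commit, reconstructed
# from the diff below; the model_factory call arguments are assumed.
model = self.model_factory(accelerator=accelerator, exp_dir=self.exp_dir)

# Stats are now loaded by the training loop rather than the model factory,
# with the resume flags taken from the single source of truth: ModelFactory.
stats = self.training_loop.load_stats(
    log_vars=model.log_vars,
    exp_dir=self.exp_dir,
    resume=self.model_factory.resume,
    resume_epoch=self.model_factory.resume_epoch,
)
start_epoch = stats.epoch + 1

# The optimizer factory no longer carries its own resume/resume_epoch
# fields; they are passed in explicitly from the model factory.
optimizer, scheduler = self.optimizer_factory(
    exp_dir=self.exp_dir,
    last_epoch=start_epoch,
    model=model,
    resume=self.model_factory.resume,
    resume_epoch=self.model_factory.resume_epoch,
)
```

This mirrors the `Experiment.run` hunks below; `load_stats` itself moves essentially verbatim from `ImplicitronModelFactory` to `ImplicitronTrainingLoop`, gaining explicit `resume`/`resume_epoch` arguments in place of the old `clear_stats` flag.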
@@ -3,6 +3,7 @@ defaults:
- _self_
exp_dir: ./data/exps/base/
training_loop_ImplicitronTrainingLoop_args:
visdom_port: 8097
visualize_interval: 0
max_epochs: 1000
data_source_ImplicitronDataSource_args:
@@ -22,7 +23,6 @@ data_source_ImplicitronDataSource_args:
mask_depths: false
mask_images: false
model_factory_ImplicitronModelFactory_args:
visdom_port: 8097
model_GenericModel_args:
loss_weights:
loss_mask_bce: 1.0
@@ -147,9 +147,6 @@ class Experiment(Configurable): # pyre-ignore: 13
run_auto_creation(self)

def run(self) -> None:
# Make sure the config settings are self-consistent.
self._check_config_consistent()

# Initialize the accelerator if desired.
if no_accelerate:
accelerator = None
@@ -176,9 +173,11 @@ class Experiment(Configurable): # pyre-ignore: 13
exp_dir=self.exp_dir,
)

stats = self.model_factory.load_stats(
exp_dir=self.exp_dir,
stats = self.training_loop.load_stats(
log_vars=model.log_vars,
exp_dir=self.exp_dir,
resume=self.model_factory.resume,
resume_epoch=self.model_factory.resume_epoch, # pyre-ignore [16]
)
start_epoch = stats.epoch + 1
@@ -190,6 +189,8 @@ class Experiment(Configurable): # pyre-ignore: 13
exp_dir=self.exp_dir,
last_epoch=start_epoch,
model=model,
resume=self.model_factory.resume,
resume_epoch=self.model_factory.resume_epoch,
)

# Wrap all modules in the distributed library
@@ -224,26 +225,6 @@ class Experiment(Configurable): # pyre-ignore: 13
seed=self.seed,
)

def _check_config_consistent(self) -> None:
if hasattr(self.optimizer_factory, "resume") and hasattr(
self.model_factory, "resume"
):
assert (
# pyre-ignore [16]
not self.optimizer_factory.resume
# pyre-ignore [16]
or self.model_factory.resume
), "Cannot resume the optimizer without resuming the model."
if hasattr(self.optimizer_factory, "resume_epoch") and hasattr(
self.model_factory, "resume_epoch"
):
assert (
# pyre-ignore [16]
self.optimizer_factory.resume_epoch
# pyre-ignore [16]
== self.model_factory.resume_epoch
), "Optimizer and model must resume from the same epoch."


def _setup_envvars_for_cluster() -> bool:
"""
@@ -6,13 +6,13 @@

import logging
import os
from typing import List, Optional
from typing import Optional

import torch.optim

from accelerate import Accelerator
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools import model_io
from pytorch3d.implicitron.tools.config import (
registry,
ReplaceableBase,
@@ -24,6 +24,9 @@ logger = logging.getLogger(__name__)


class ModelFactoryBase(ReplaceableBase):

resume: bool = True # resume from the last checkpoint

def __call__(self, **kwargs) -> ImplicitronModelBase:
"""
Initialize the model (possibly from a previously saved state).
@@ -45,27 +48,22 @@ class ImplicitronModelFactory(ModelFactoryBase): # pyre-ignore [13]
A factory class that initializes an implicit rendering model.

Members:
force_load: If True, throw a FileNotFoundError if `resume` is True but
a model checkpoint cannot be found.
model: An ImplicitronModelBase object.
resume: If True, attempt to load the last checkpoint from `exp_dir`
passed to __call__. Failure to do so will return a model with ini-
tial weights unless `force_load` is True.
tial weights unless `force_resume` is True.
resume_epoch: If `resume` is True: Resume a model at this epoch, or if
`resume_epoch` <= 0, then resume from the latest checkpoint.
visdom_env: The name of the Visdom environment to use for plotting.
visdom_port: The Visdom port.
visdom_server: Address of the Visdom server.
force_resume: If True, throw a FileNotFoundError if `resume` is True but
a model checkpoint cannot be found.

"""

force_load: bool = False
model: ImplicitronModelBase
model_class_type: str = "GenericModel"
resume: bool = False
resume: bool = True
resume_epoch: int = -1
visdom_env: str = ""
visdom_port: int = int(os.environ.get("VISDOM_PORT", 8097))
visdom_server: str = "http://127.0.0.1"
force_resume: bool = False

def __post_init__(self):
run_auto_creation(self)
@@ -87,24 +85,27 @@ class ImplicitronModelFactory(ModelFactoryBase): # pyre-ignore [13]
model: The model with optionally loaded weights from checkpoint

Raise:
FileNotFoundError if `force_load` is True but checkpoint not found.
FileNotFoundError if `force_resume` is True but checkpoint not found.
"""
# Determine the network outputs that should be logged
if hasattr(self.model, "log_vars"):
log_vars = list(self.model.log_vars) # pyre-ignore [6]
log_vars = list(self.model.log_vars)
else:
log_vars = ["objective"]

# Retrieve the last checkpoint
if self.resume_epoch > 0:
# Resume from a certain epoch
model_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
if not os.path.isfile(model_path):
raise ValueError(f"Cannot find model from epoch {self.resume_epoch}.")
else:
# Retrieve the last checkpoint
model_path = model_io.find_last_checkpoint(exp_dir)

if model_path is not None:
logger.info("found previous model %s" % model_path)
if self.force_load or self.resume:
logger.info(" -> resuming")
logger.info(f"Found previous model {model_path}")
if self.force_resume or self.resume:
logger.info("Resuming.")

map_location = None
if accelerator is not None and not accelerator.is_local_main_process:
@@ -120,81 +121,13 @@ class ImplicitronModelFactory(ModelFactoryBase): # pyre-ignore [13]
except RuntimeError as e:
logger.error(e)
logger.info(
"Cant load state dict in strict mode! -> trying non-strict"
"Cannot load state dict in strict mode! -> trying non-strict"
)
self.model.load_state_dict(model_state_dict, strict=False)
self.model.log_vars = log_vars # pyre-ignore [16]
self.model.log_vars = log_vars
else:
logger.info(" -> but not resuming -> starting from scratch")
elif self.force_load:
logger.info("Not resuming -> starting from scratch.")
elif self.force_resume:
raise FileNotFoundError(f"Cannot find a checkpoint in {exp_dir}!")

return self.model

def load_stats(
self,
log_vars: List[str],
exp_dir: str,
clear_stats: bool = False,
**kwargs,
) -> Stats:
"""
Load Stats that correspond to the model's log_vars.

Args:
log_vars: A list of variable names to log. Should be a subset of the
`preds` returned by the forward function of the corresponding
ImplicitronModelBase instance.
exp_dir: Root experiment directory.
clear_stats: If True, do not load stats from the checkpoint speci-
fied by self.resume and self.resume_epoch; instead, create a
fresh stats object.

stats: The stats structure (optionally loaded from checkpoint)
"""
# Init the stats struct
visdom_env_charts = (
vis_utils.get_visdom_env(self.visdom_env, exp_dir) + "_charts"
)
stats = Stats(
# log_vars should be a list, but OmegaConf might load them as ListConfig
list(log_vars),
plot_file=os.path.join(exp_dir, "train_stats.pdf"),
visdom_env=visdom_env_charts,
verbose=False,
visdom_server=self.visdom_server,
visdom_port=self.visdom_port,
)
if self.resume_epoch > 0:
model_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
else:
model_path = model_io.find_last_checkpoint(exp_dir)

if model_path is not None:
stats_path = model_io.get_stats_path(model_path)
stats_load = model_io.load_stats(stats_path)

# Determine if stats should be reset
if not clear_stats:
if stats_load is None:
logger.warning("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
last_epoch = model_io.parse_epoch_from_model_path(model_path)
logger.info(f"Estimated resume epoch = {last_epoch}")

# Reset the stats struct
for _ in range(last_epoch + 1):
stats.new_epoch()
assert last_epoch == stats.epoch
else:
stats = stats_load

# Update stats properties incase it was reset on load
stats.visdom_env = visdom_env_charts
stats.visdom_server = self.visdom_server
stats.visdom_port = self.visdom_port
stats.plot_file = os.path.join(exp_dir, "train_stats.pdf")
stats.synchronize_logged_vars(log_vars)
else:
logger.info(" -> clearing stats")

return stats
@@ -60,11 +60,6 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
multistep_lr_milestones: With MultiStepLR policy only: list of
increasing epoch indices at which the learning rate is modified.
momentum: Momentum factor for SGD optimizer.
resume: If True, attempt to load the last checkpoint from `exp_dir`
passed to __call__. Failure to do so will return a newly initialized
optimizer.
resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
`resume_epoch` <= 0, then resume from the latest checkpoint.
weight_decay: The optimizer weight_decay (L2 penalty on model weights).
"""
@@ -76,8 +71,6 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
lr_policy: str = "MultiStepLR"
momentum: float = 0.9
multistep_lr_milestones: tuple = ()
resume: bool = False
resume_epoch: int = -1
weight_decay: float = 0.0

def __post_init__(self):
@@ -89,6 +82,8 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
model: ImplicitronModelBase,
accelerator: Optional[Accelerator] = None,
exp_dir: Optional[str] = None,
resume: bool = True,
resume_epoch: int = -1,
**kwargs,
) -> Tuple[torch.optim.Optimizer, Any]:
"""
@@ -100,7 +95,10 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
model: The model with optionally loaded weights.
accelerator: An optional Accelerator instance.
exp_dir: Root experiment directory.

resume: If True, attempt to load optimizer checkpoint from exp_dir.
Failure to do so will return a newly initialized optimizer.
resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
`resume_epoch` <= 0, then resume from the latest checkpoint.
Returns:
An optimizer module (optionally loaded from a checkpoint) and
a learning rate scheduler module (should be a subclass of torch.optim's
@@ -131,13 +129,18 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
p_groups, lr=self.lr, betas=self.betas, weight_decay=self.weight_decay
)
else:
raise ValueError("no such solver type %s" % self.breed)
logger.info(" -> solver type = %s" % self.breed)
raise ValueError(f"No such solver type {self.breed}")
logger.info(f"Solver type = {self.breed}")

# Load state from checkpoint
optimizer_state = self._get_optimizer_state(exp_dir, accelerator)
optimizer_state = self._get_optimizer_state(
exp_dir,
accelerator,
resume_epoch=resume_epoch,
resume=resume,
)
if optimizer_state is not None:
logger.info(" -> setting loaded optimizer state")
logger.info("Setting loaded optimizer state.")
optimizer.load_state_dict(optimizer_state)

# Initialize the learning rate scheduler
@@ -169,20 +172,31 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
self,
exp_dir: Optional[str],
accelerator: Optional[Accelerator] = None,
resume: bool = True,
resume_epoch: int = -1,
) -> Optional[Dict[str, Any]]:
"""
Load an optimizer state from a checkpoint.

resume: If True, attempt to load the last checkpoint from `exp_dir`
passed to __call__. Failure to do so will return a newly initialized
optimizer.
resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
`resume_epoch` <= 0, then resume from the latest checkpoint.
"""
if exp_dir is None or not self.resume:
if exp_dir is None or not resume:
return None
if self.resume_epoch > 0:
save_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
if resume_epoch > 0:
save_path = model_io.get_checkpoint(exp_dir, resume_epoch)
if not os.path.isfile(save_path):
raise FileNotFoundError(
f"Cannot find optimizer from epoch {resume_epoch}."
)
else:
save_path = model_io.find_last_checkpoint(exp_dir)
optimizer_state = None
if save_path is not None:
logger.info(f"Found previous optimizer state {save_path}.")
logger.info(" -> resuming")
logger.info(f"Found previous optimizer state {save_path} -> resuming.")
opt_path = model_io.get_optimizer_path(save_path)

if os.path.isfile(opt_path):
@@ -193,5 +207,5 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
}
optimizer_state = torch.load(opt_path, map_location)
else:
optimizer_state = None
raise FileNotFoundError(f"Optimizer state {opt_path} does not exist.")
return optimizer_state
@@ -5,8 +5,9 @@
# LICENSE file in the root directory of this source tree.

import logging
import os
import time
from typing import Any, Optional
from typing import Any, List, Optional

import torch
from accelerate import Accelerator
@@ -41,6 +42,16 @@ class TrainingLoopBase(ReplaceableBase):
) -> None:
raise NotImplementedError()

def load_stats(
self,
log_vars: List[str],
exp_dir: str,
resume: bool = True,
resume_epoch: int = -1,
**kwargs,
) -> Stats:
raise NotImplementedError()


@registry.register
class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
@@ -64,6 +75,9 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
logged.
visualize_interval: The batch interval at which the visualizations
should be plotted
visdom_env: The name of the Visdom environment to use for plotting.
visdom_port: The Visdom port.
visdom_server: Address of the Visdom server.
"""

# Parameters of the outer training loop.
@@ -77,10 +91,15 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
test_when_finished: bool = False
validation_interval: int = 1

# Parameters of a single training-validation step.
# Gradient clipping.
clip_grad: float = 0.0

# Visualization/logging parameters.
metric_print_interval: int = 5
visualize_interval: int = 1000
visdom_env: str = ""
visdom_port: int = int(os.environ.get("VISDOM_PORT", 8097))
visdom_server: str = "http://127.0.0.1"

def __post_init__(self):
run_auto_creation(self)
@@ -202,6 +221,81 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
"Cannot evaluate and dump results to json, no test data provided."
)

def load_stats(
self,
log_vars: List[str],
exp_dir: str,
resume: bool = True,
resume_epoch: int = -1,
**kwargs,
) -> Stats:
"""
Load Stats that correspond to the model's log_vars and resume_epoch.

Args:
log_vars: A list of variable names to log. Should be a subset of the
`preds` returned by the forward function of the corresponding
ImplicitronModelBase instance.
exp_dir: Root experiment directory.
resume: If False, do not load stats from the checkpoint speci-
fied by resume and resume_epoch; instead, create a fresh stats object.

stats: The stats structure (optionally loaded from checkpoint)
"""
# Init the stats struct
visdom_env_charts = (
vis_utils.get_visdom_env(self.visdom_env, exp_dir) + "_charts"
)
stats = Stats(
# log_vars should be a list, but OmegaConf might load them as ListConfig
list(log_vars),
visdom_env=visdom_env_charts,
verbose=False,
visdom_server=self.visdom_server,
visdom_port=self.visdom_port,
)

model_path = None
if resume:
if resume_epoch > 0:
model_path = model_io.get_checkpoint(exp_dir, resume_epoch)
if not os.path.isfile(model_path):
raise FileNotFoundError(
f"Cannot find stats from epoch {resume_epoch}."
)
else:
model_path = model_io.find_last_checkpoint(exp_dir)

if model_path is not None:
stats_path = model_io.get_stats_path(model_path)
stats_load = model_io.load_stats(stats_path)

# Determine if stats should be reset
if resume:
if stats_load is None:
logger.warning("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
last_epoch = model_io.parse_epoch_from_model_path(model_path)
logger.info(f"Estimated resume epoch = {last_epoch}")

# Reset the stats struct
for _ in range(last_epoch + 1):
stats.new_epoch()
assert last_epoch == stats.epoch
else:
logger.info(f"Found previous stats in {stats_path} -> resuming.")
stats = stats_load

# Update stats properties incase it was reset on load
stats.visdom_env = visdom_env_charts
stats.visdom_server = self.visdom_server
stats.visdom_port = self.visdom_port
stats.plot_file = os.path.join(exp_dir, "train_stats.pdf")
stats.synchronize_logged_vars(log_vars)
else:
logger.info("Clearing stats")

return stats

def _training_or_validation_epoch(
self,
epoch: int,
@@ -124,14 +124,32 @@ data_source_ImplicitronDataSource_args:
dataset_length_val: 0
dataset_length_test: 0
model_factory_ImplicitronModelFactory_args:
force_load: false
resume: true
model_class_type: GenericModel
resume: false
resume_epoch: -1
visdom_env: ''
visdom_port: 8097
visdom_server: http://127.0.0.1
force_resume: false
model_GenericModel_args:
log_vars:
- loss_rgb_psnr_fg
- loss_rgb_psnr
- loss_rgb_mse
- loss_rgb_huber
- loss_depth_abs
- loss_depth_abs_fg
- loss_mask_neg_iou
- loss_mask_bce
- loss_mask_beta_prior
- loss_eikonal
- loss_density_tv
- loss_depth_neg_penalty
- loss_autodecoder_norm
- loss_prev_stage_rgb_mse
- loss_prev_stage_rgb_psnr_fg
- loss_prev_stage_rgb_psnr
- loss_prev_stage_mask_bce
- objective
- epoch
- sec/it
mask_images: true
mask_depths: true
render_image_width: 400
@@ -162,27 +180,6 @@ model_factory_ImplicitronModelFactory_args:
loss_prev_stage_rgb_mse: 1.0
loss_mask_bce: 0.0
loss_prev_stage_mask_bce: 0.0
log_vars:
- loss_rgb_psnr_fg
- loss_rgb_psnr
- loss_rgb_mse
- loss_rgb_huber
- loss_depth_abs
- loss_depth_abs_fg
- loss_mask_neg_iou
- loss_mask_bce
- loss_mask_beta_prior
- loss_eikonal
- loss_density_tv
- loss_depth_neg_penalty
- loss_autodecoder_norm
- loss_prev_stage_rgb_mse
- loss_prev_stage_rgb_psnr_fg
- loss_prev_stage_rgb_psnr
- loss_prev_stage_mask_bce
- objective
- epoch
- sec/it
global_encoder_HarmonicTimeEncoder_args:
n_harmonic_functions: 10
append_input: true
@@ -422,8 +419,6 @@ optimizer_factory_ImplicitronOptimizerFactory_args:
lr_policy: MultiStepLR
momentum: 0.9
multistep_lr_milestones: []
resume: false
resume_epoch: -1
weight_decay: 0.0
training_loop_ImplicitronTrainingLoop_args:
eval_only: false
@@ -437,6 +432,9 @@ training_loop_ImplicitronTrainingLoop_args:
clip_grad: 0.0
metric_print_interval: 5
visualize_interval: 1000
visdom_env: ''
visdom_port: 8097
visdom_server: http://127.0.0.1
evaluator_ImplicitronEvaluator_args:
camera_difficulty_bin_breaks:
- 0.97
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import os
import tempfile
import unittest
from pathlib import Path
@@ -172,3 +173,48 @@ class TestNerfRepro(unittest.TestCase):
experiment_runner = experiment.Experiment(**cfg)
experiment.dump_cfg(cfg)
experiment_runner.run()

@unittest.skip("This test checks resuming of the NeRF training.")
def test_nerf_blender_resume(self):
# Train one train batch of NeRF, then resume for one more batch.
# Set env vars BLENDER_DATASET_ROOT and BLENDER_SINGLESEQ_CLASS first!
if not interactive_testing_requested():
return
with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)):
with tempfile.TemporaryDirectory() as exp_dir:
cfg = compose(config_name="repro_singleseq_nerf_blender", overrides=[])
cfg.exp_dir = exp_dir

# set dataset len to 1

# fmt: off
(
cfg
.data_source_ImplicitronDataSource_args
.data_loader_map_provider_SequenceDataLoaderMapProvider_args
.dataset_length_train
) = 1
# fmt: on

# run for one epoch
cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs = 1
experiment_runner = experiment.Experiment(**cfg)
experiment.dump_cfg(cfg)
experiment_runner.run()

# update num epochs + 2, let the optimizer resume
cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs = 3
experiment_runner = experiment.Experiment(**cfg)
experiment_runner.run()

# start from scratch
cfg.model_factory_ImplicitronModelFactory_args.resume = False
experiment_runner = experiment.Experiment(**cfg)
experiment_runner.run()

# force resume from epoch 1
cfg.model_factory_ImplicitronModelFactory_args.resume = True
cfg.model_factory_ImplicitronModelFactory_args.force_resume = True
cfg.model_factory_ImplicitronModelFactory_args.resume_epoch = 1
experiment_runner = experiment.Experiment(**cfg)
experiment_runner.run()
@@ -344,7 +344,7 @@ def export_scenes(

# Load the previously trained model
experiment = Experiment(config)
model = experiment.model_factory(force_load=True, load_model_only=True)
model = experiment.model_factory(force_resume=True)
model.cuda()
model.eval()
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

import torch
@@ -45,6 +45,10 @@ class ImplicitronModelBase(ReplaceableBase, torch.nn.Module):
optimization.
"""

# The keys from `preds` (output of ImplicitronModelBase.forward) to be logged in
# the training loop.
log_vars: List[str] = field(default_factory=lambda: ["objective"])

def __init__(self) -> None:
super().__init__()
@@ -1,3 +1,24 @@
log_vars:
- loss_rgb_psnr_fg
- loss_rgb_psnr
- loss_rgb_mse
- loss_rgb_huber
- loss_depth_abs
- loss_depth_abs_fg
- loss_mask_neg_iou
- loss_mask_bce
- loss_mask_beta_prior
- loss_eikonal
- loss_density_tv
- loss_depth_neg_penalty
- loss_autodecoder_norm
- loss_prev_stage_rgb_mse
- loss_prev_stage_rgb_psnr_fg
- loss_prev_stage_rgb_psnr
- loss_prev_stage_mask_bce
- objective
- epoch
- sec/it
mask_images: true
mask_depths: true
render_image_width: 400
@@ -28,27 +49,6 @@ loss_weights:
loss_prev_stage_rgb_mse: 1.0
loss_mask_bce: 0.0
loss_prev_stage_mask_bce: 0.0
log_vars:
- loss_rgb_psnr_fg
- loss_rgb_psnr
- loss_rgb_mse
- loss_rgb_huber
- loss_depth_abs
- loss_depth_abs_fg
- loss_mask_neg_iou
- loss_mask_bce
- loss_mask_beta_prior
- loss_eikonal
- loss_density_tv
- loss_depth_neg_penalty
- loss_autodecoder_norm
- loss_prev_stage_rgb_mse
- loss_prev_stage_rgb_psnr_fg
- loss_prev_stage_rgb_psnr
- loss_prev_stage_mask_bce
- objective
- epoch
- sec/it
global_encoder_SequenceAutodecoder_args:
autodecoder_args:
encoding_dim: 0
@@ -90,5 +90,5 @@ class TestGenericModel(unittest.TestCase):
remove_unused_components(instance_args)
yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
if DEBUG:
(DATA_DIR / "overrides.yaml_").write_text(yaml)
(DATA_DIR / "overrides_.yaml").write_text(yaml)
self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text())