Move load_stats to TrainingLoop

Summary:
Stats are logically connected to the training loop, not to the model. Hence, moving to the training loop.

Also removing resume_epoch from OptimizerFactory in favor of a single place - ModelFactory. This removes the need for config consistency checks etc.

Reviewed By: kjchalup

Differential Revision: D38313475

fbshipit-source-id: a1d188a63e28459df381ff98ad8acdcdb14887b7
This commit is contained in:
David Novotny 2022-08-02 15:40:53 -07:00 committed by Facebook GitHub Bot
parent 760305e044
commit c3f8dad55c
11 changed files with 259 additions and 189 deletions

View File

@ -3,6 +3,7 @@ defaults:
- _self_
exp_dir: ./data/exps/base/
training_loop_ImplicitronTrainingLoop_args:
visdom_port: 8097
visualize_interval: 0
max_epochs: 1000
data_source_ImplicitronDataSource_args:
@ -22,7 +23,6 @@ data_source_ImplicitronDataSource_args:
mask_depths: false
mask_images: false
model_factory_ImplicitronModelFactory_args:
visdom_port: 8097
model_GenericModel_args:
loss_weights:
loss_mask_bce: 1.0

View File

@ -147,9 +147,6 @@ class Experiment(Configurable): # pyre-ignore: 13
run_auto_creation(self)
def run(self) -> None:
# Make sure the config settings are self-consistent.
self._check_config_consistent()
# Initialize the accelerator if desired.
if no_accelerate:
accelerator = None
@ -176,9 +173,11 @@ class Experiment(Configurable): # pyre-ignore: 13
exp_dir=self.exp_dir,
)
stats = self.model_factory.load_stats(
exp_dir=self.exp_dir,
stats = self.training_loop.load_stats(
log_vars=model.log_vars,
exp_dir=self.exp_dir,
resume=self.model_factory.resume,
resume_epoch=self.model_factory.resume_epoch, # pyre-ignore [16]
)
start_epoch = stats.epoch + 1
@ -190,6 +189,8 @@ class Experiment(Configurable): # pyre-ignore: 13
exp_dir=self.exp_dir,
last_epoch=start_epoch,
model=model,
resume=self.model_factory.resume,
resume_epoch=self.model_factory.resume_epoch,
)
# Wrap all modules in the distributed library
@ -224,26 +225,6 @@ class Experiment(Configurable): # pyre-ignore: 13
seed=self.seed,
)
def _check_config_consistent(self) -> None:
if hasattr(self.optimizer_factory, "resume") and hasattr(
self.model_factory, "resume"
):
assert (
# pyre-ignore [16]
not self.optimizer_factory.resume
# pyre-ignore [16]
or self.model_factory.resume
), "Cannot resume the optimizer without resuming the model."
if hasattr(self.optimizer_factory, "resume_epoch") and hasattr(
self.model_factory, "resume_epoch"
):
assert (
# pyre-ignore [16]
self.optimizer_factory.resume_epoch
# pyre-ignore [16]
== self.model_factory.resume_epoch
), "Optimizer and model must resume from the same epoch."
def _setup_envvars_for_cluster() -> bool:
"""

View File

@ -6,13 +6,13 @@
import logging
import os
from typing import List, Optional
from typing import Optional
import torch.optim
from accelerate import Accelerator
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools import model_io
from pytorch3d.implicitron.tools.config import (
registry,
ReplaceableBase,
@ -24,6 +24,9 @@ logger = logging.getLogger(__name__)
class ModelFactoryBase(ReplaceableBase):
resume: bool = True # resume from the last checkpoint
def __call__(self, **kwargs) -> ImplicitronModelBase:
"""
Initialize the model (possibly from a previously saved state).
@ -45,27 +48,22 @@ class ImplicitronModelFactory(ModelFactoryBase): # pyre-ignore [13]
A factory class that initializes an implicit rendering model.
Members:
force_load: If True, throw a FileNotFoundError if `resume` is True but
a model checkpoint cannot be found.
model: An ImplicitronModelBase object.
resume: If True, attempt to load the last checkpoint from `exp_dir`
passed to __call__. Failure to do so will return a model with ini-
tial weights unless `force_load` is True.
tial weights unless `force_resume` is True.
resume_epoch: If `resume` is True: Resume a model at this epoch, or if
`resume_epoch` <= 0, then resume from the latest checkpoint.
visdom_env: The name of the Visdom environment to use for plotting.
visdom_port: The Visdom port.
visdom_server: Address of the Visdom server.
force_resume: If True, throw a FileNotFoundError if `resume` is True but
a model checkpoint cannot be found.
"""
force_load: bool = False
model: ImplicitronModelBase
model_class_type: str = "GenericModel"
resume: bool = False
resume: bool = True
resume_epoch: int = -1
visdom_env: str = ""
visdom_port: int = int(os.environ.get("VISDOM_PORT", 8097))
visdom_server: str = "http://127.0.0.1"
force_resume: bool = False
def __post_init__(self):
run_auto_creation(self)
@ -87,24 +85,27 @@ class ImplicitronModelFactory(ModelFactoryBase): # pyre-ignore [13]
model: The model with optionally loaded weights from checkpoint
Raise:
FileNotFoundError if `force_load` is True but checkpoint not found.
FileNotFoundError if `force_resume` is True but checkpoint not found.
"""
# Determine the network outputs that should be logged
if hasattr(self.model, "log_vars"):
log_vars = list(self.model.log_vars) # pyre-ignore [6]
log_vars = list(self.model.log_vars)
else:
log_vars = ["objective"]
# Retrieve the last checkpoint
if self.resume_epoch > 0:
# Resume from a certain epoch
model_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
if not os.path.isfile(model_path):
raise ValueError(f"Cannot find model from epoch {self.resume_epoch}.")
else:
# Retrieve the last checkpoint
model_path = model_io.find_last_checkpoint(exp_dir)
if model_path is not None:
logger.info("found previous model %s" % model_path)
if self.force_load or self.resume:
logger.info(" -> resuming")
logger.info(f"Found previous model {model_path}")
if self.force_resume or self.resume:
logger.info("Resuming.")
map_location = None
if accelerator is not None and not accelerator.is_local_main_process:
@ -120,81 +121,13 @@ class ImplicitronModelFactory(ModelFactoryBase): # pyre-ignore [13]
except RuntimeError as e:
logger.error(e)
logger.info(
"Cant load state dict in strict mode! -> trying non-strict"
"Cannot load state dict in strict mode! -> trying non-strict"
)
self.model.load_state_dict(model_state_dict, strict=False)
self.model.log_vars = log_vars # pyre-ignore [16]
self.model.log_vars = log_vars
else:
logger.info(" -> but not resuming -> starting from scratch")
elif self.force_load:
logger.info("Not resuming -> starting from scratch.")
elif self.force_resume:
raise FileNotFoundError(f"Cannot find a checkpoint in {exp_dir}!")
return self.model
def load_stats(
self,
log_vars: List[str],
exp_dir: str,
clear_stats: bool = False,
**kwargs,
) -> Stats:
"""
Load Stats that correspond to the model's log_vars.
Args:
log_vars: A list of variable names to log. Should be a subset of the
`preds` returned by the forward function of the corresponding
ImplicitronModelBase instance.
exp_dir: Root experiment directory.
clear_stats: If True, do not load stats from the checkpoint speci-
fied by self.resume and self.resume_epoch; instead, create a
fresh stats object.
stats: The stats structure (optionally loaded from checkpoint)
"""
# Init the stats struct
visdom_env_charts = (
vis_utils.get_visdom_env(self.visdom_env, exp_dir) + "_charts"
)
stats = Stats(
# log_vars should be a list, but OmegaConf might load them as ListConfig
list(log_vars),
plot_file=os.path.join(exp_dir, "train_stats.pdf"),
visdom_env=visdom_env_charts,
verbose=False,
visdom_server=self.visdom_server,
visdom_port=self.visdom_port,
)
if self.resume_epoch > 0:
model_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
else:
model_path = model_io.find_last_checkpoint(exp_dir)
if model_path is not None:
stats_path = model_io.get_stats_path(model_path)
stats_load = model_io.load_stats(stats_path)
# Determine if stats should be reset
if not clear_stats:
if stats_load is None:
logger.warning("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
last_epoch = model_io.parse_epoch_from_model_path(model_path)
logger.info(f"Estimated resume epoch = {last_epoch}")
# Reset the stats struct
for _ in range(last_epoch + 1):
stats.new_epoch()
assert last_epoch == stats.epoch
else:
stats = stats_load
# Update stats properties incase it was reset on load
stats.visdom_env = visdom_env_charts
stats.visdom_server = self.visdom_server
stats.visdom_port = self.visdom_port
stats.plot_file = os.path.join(exp_dir, "train_stats.pdf")
stats.synchronize_logged_vars(log_vars)
else:
logger.info(" -> clearing stats")
return stats

View File

@ -60,11 +60,6 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
multistep_lr_milestones: With MultiStepLR policy only: list of
increasing epoch indices at which the learning rate is modified.
momentum: Momentum factor for SGD optimizer.
resume: If True, attempt to load the last checkpoint from `exp_dir`
passed to __call__. Failure to do so will return a newly initialized
optimizer.
resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
`resume_epoch` <= 0, then resume from the latest checkpoint.
weight_decay: The optimizer weight_decay (L2 penalty on model weights).
"""
@ -76,8 +71,6 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
lr_policy: str = "MultiStepLR"
momentum: float = 0.9
multistep_lr_milestones: tuple = ()
resume: bool = False
resume_epoch: int = -1
weight_decay: float = 0.0
def __post_init__(self):
@ -89,6 +82,8 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
model: ImplicitronModelBase,
accelerator: Optional[Accelerator] = None,
exp_dir: Optional[str] = None,
resume: bool = True,
resume_epoch: int = -1,
**kwargs,
) -> Tuple[torch.optim.Optimizer, Any]:
"""
@ -100,7 +95,10 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
model: The model with optionally loaded weights.
accelerator: An optional Accelerator instance.
exp_dir: Root experiment directory.
resume: If True, attempt to load optimizer checkpoint from exp_dir.
Failure to do so will return a newly initialized optimizer.
resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
`resume_epoch` <= 0, then resume from the latest checkpoint.
Returns:
An optimizer module (optionally loaded from a checkpoint) and
a learning rate scheduler module (should be a subclass of torch.optim's
@ -131,13 +129,18 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
p_groups, lr=self.lr, betas=self.betas, weight_decay=self.weight_decay
)
else:
raise ValueError("no such solver type %s" % self.breed)
logger.info(" -> solver type = %s" % self.breed)
raise ValueError(f"No such solver type {self.breed}")
logger.info(f"Solver type = {self.breed}")
# Load state from checkpoint
optimizer_state = self._get_optimizer_state(exp_dir, accelerator)
optimizer_state = self._get_optimizer_state(
exp_dir,
accelerator,
resume_epoch=resume_epoch,
resume=resume,
)
if optimizer_state is not None:
logger.info(" -> setting loaded optimizer state")
logger.info("Setting loaded optimizer state.")
optimizer.load_state_dict(optimizer_state)
# Initialize the learning rate scheduler
@ -169,20 +172,31 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
self,
exp_dir: Optional[str],
accelerator: Optional[Accelerator] = None,
resume: bool = True,
resume_epoch: int = -1,
) -> Optional[Dict[str, Any]]:
"""
Load an optimizer state from a checkpoint.
resume: If True, attempt to load the last checkpoint from `exp_dir`
passed to __call__. Failure to do so will return a newly initialized
optimizer.
resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
`resume_epoch` <= 0, then resume from the latest checkpoint.
"""
if exp_dir is None or not self.resume:
if exp_dir is None or not resume:
return None
if self.resume_epoch > 0:
save_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
if resume_epoch > 0:
save_path = model_io.get_checkpoint(exp_dir, resume_epoch)
if not os.path.isfile(save_path):
raise FileNotFoundError(
f"Cannot find optimizer from epoch {resume_epoch}."
)
else:
save_path = model_io.find_last_checkpoint(exp_dir)
optimizer_state = None
if save_path is not None:
logger.info(f"Found previous optimizer state {save_path}.")
logger.info(" -> resuming")
logger.info(f"Found previous optimizer state {save_path} -> resuming.")
opt_path = model_io.get_optimizer_path(save_path)
if os.path.isfile(opt_path):
@ -193,5 +207,5 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
}
optimizer_state = torch.load(opt_path, map_location)
else:
optimizer_state = None
raise FileNotFoundError(f"Optimizer state {opt_path} does not exist.")
return optimizer_state

View File

@ -5,8 +5,9 @@
# LICENSE file in the root directory of this source tree.
import logging
import os
import time
from typing import Any, Optional
from typing import Any, List, Optional
import torch
from accelerate import Accelerator
@ -41,6 +42,16 @@ class TrainingLoopBase(ReplaceableBase):
) -> None:
raise NotImplementedError()
def load_stats(
self,
log_vars: List[str],
exp_dir: str,
resume: bool = True,
resume_epoch: int = -1,
**kwargs,
) -> Stats:
raise NotImplementedError()
@registry.register
class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
@ -64,6 +75,9 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
logged.
visualize_interval: The batch interval at which the visualizations
should be plotted
visdom_env: The name of the Visdom environment to use for plotting.
visdom_port: The Visdom port.
visdom_server: Address of the Visdom server.
"""
# Parameters of the outer training loop.
@ -77,10 +91,15 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
test_when_finished: bool = False
validation_interval: int = 1
# Parameters of a single training-validation step.
# Gradient clipping.
clip_grad: float = 0.0
# Visualization/logging parameters.
metric_print_interval: int = 5
visualize_interval: int = 1000
visdom_env: str = ""
visdom_port: int = int(os.environ.get("VISDOM_PORT", 8097))
visdom_server: str = "http://127.0.0.1"
def __post_init__(self):
run_auto_creation(self)
@ -202,6 +221,81 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
"Cannot evaluate and dump results to json, no test data provided."
)
def load_stats(
self,
log_vars: List[str],
exp_dir: str,
resume: bool = True,
resume_epoch: int = -1,
**kwargs,
) -> Stats:
"""
Load Stats that correspond to the model's log_vars and resume_epoch.
Args:
log_vars: A list of variable names to log. Should be a subset of the
`preds` returned by the forward function of the corresponding
ImplicitronModelBase instance.
exp_dir: Root experiment directory.
resume: If False, do not load stats from the checkpoint speci-
fied by resume and resume_epoch; instead, create a fresh stats object.
stats: The stats structure (optionally loaded from checkpoint)
"""
# Init the stats struct
visdom_env_charts = (
vis_utils.get_visdom_env(self.visdom_env, exp_dir) + "_charts"
)
stats = Stats(
# log_vars should be a list, but OmegaConf might load them as ListConfig
list(log_vars),
visdom_env=visdom_env_charts,
verbose=False,
visdom_server=self.visdom_server,
visdom_port=self.visdom_port,
)
model_path = None
if resume:
if resume_epoch > 0:
model_path = model_io.get_checkpoint(exp_dir, resume_epoch)
if not os.path.isfile(model_path):
raise FileNotFoundError(
f"Cannot find stats from epoch {resume_epoch}."
)
else:
model_path = model_io.find_last_checkpoint(exp_dir)
if model_path is not None:
stats_path = model_io.get_stats_path(model_path)
stats_load = model_io.load_stats(stats_path)
# Determine if stats should be reset
if resume:
if stats_load is None:
logger.warning("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
last_epoch = model_io.parse_epoch_from_model_path(model_path)
logger.info(f"Estimated resume epoch = {last_epoch}")
# Reset the stats struct
for _ in range(last_epoch + 1):
stats.new_epoch()
assert last_epoch == stats.epoch
else:
logger.info(f"Found previous stats in {stats_path} -> resuming.")
stats = stats_load
# Update stats properties incase it was reset on load
stats.visdom_env = visdom_env_charts
stats.visdom_server = self.visdom_server
stats.visdom_port = self.visdom_port
stats.plot_file = os.path.join(exp_dir, "train_stats.pdf")
stats.synchronize_logged_vars(log_vars)
else:
logger.info("Clearing stats")
return stats
def _training_or_validation_epoch(
self,
epoch: int,

View File

@ -124,14 +124,32 @@ data_source_ImplicitronDataSource_args:
dataset_length_val: 0
dataset_length_test: 0
model_factory_ImplicitronModelFactory_args:
force_load: false
resume: true
model_class_type: GenericModel
resume: false
resume_epoch: -1
visdom_env: ''
visdom_port: 8097
visdom_server: http://127.0.0.1
force_resume: false
model_GenericModel_args:
log_vars:
- loss_rgb_psnr_fg
- loss_rgb_psnr
- loss_rgb_mse
- loss_rgb_huber
- loss_depth_abs
- loss_depth_abs_fg
- loss_mask_neg_iou
- loss_mask_bce
- loss_mask_beta_prior
- loss_eikonal
- loss_density_tv
- loss_depth_neg_penalty
- loss_autodecoder_norm
- loss_prev_stage_rgb_mse
- loss_prev_stage_rgb_psnr_fg
- loss_prev_stage_rgb_psnr
- loss_prev_stage_mask_bce
- objective
- epoch
- sec/it
mask_images: true
mask_depths: true
render_image_width: 400
@ -162,27 +180,6 @@ model_factory_ImplicitronModelFactory_args:
loss_prev_stage_rgb_mse: 1.0
loss_mask_bce: 0.0
loss_prev_stage_mask_bce: 0.0
log_vars:
- loss_rgb_psnr_fg
- loss_rgb_psnr
- loss_rgb_mse
- loss_rgb_huber
- loss_depth_abs
- loss_depth_abs_fg
- loss_mask_neg_iou
- loss_mask_bce
- loss_mask_beta_prior
- loss_eikonal
- loss_density_tv
- loss_depth_neg_penalty
- loss_autodecoder_norm
- loss_prev_stage_rgb_mse
- loss_prev_stage_rgb_psnr_fg
- loss_prev_stage_rgb_psnr
- loss_prev_stage_mask_bce
- objective
- epoch
- sec/it
global_encoder_HarmonicTimeEncoder_args:
n_harmonic_functions: 10
append_input: true
@ -422,8 +419,6 @@ optimizer_factory_ImplicitronOptimizerFactory_args:
lr_policy: MultiStepLR
momentum: 0.9
multistep_lr_milestones: []
resume: false
resume_epoch: -1
weight_decay: 0.0
training_loop_ImplicitronTrainingLoop_args:
eval_only: false
@ -437,6 +432,9 @@ training_loop_ImplicitronTrainingLoop_args:
clip_grad: 0.0
metric_print_interval: 5
visualize_interval: 1000
visdom_env: ''
visdom_port: 8097
visdom_server: http://127.0.0.1
evaluator_ImplicitronEvaluator_args:
camera_difficulty_bin_breaks:
- 0.97

View File

@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
import os
import tempfile
import unittest
from pathlib import Path
@ -172,3 +173,48 @@ class TestNerfRepro(unittest.TestCase):
experiment_runner = experiment.Experiment(**cfg)
experiment.dump_cfg(cfg)
experiment_runner.run()
@unittest.skip("This test checks resuming of the NeRF training.")
def test_nerf_blender_resume(self):
# Train one train batch of NeRF, then resume for one more batch.
# Set env vars BLENDER_DATASET_ROOT and BLENDER_SINGLESEQ_CLASS first!
if not interactive_testing_requested():
return
with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)):
with tempfile.TemporaryDirectory() as exp_dir:
cfg = compose(config_name="repro_singleseq_nerf_blender", overrides=[])
cfg.exp_dir = exp_dir
# set dataset len to 1
# fmt: off
(
cfg
.data_source_ImplicitronDataSource_args
.data_loader_map_provider_SequenceDataLoaderMapProvider_args
.dataset_length_train
) = 1
# fmt: on
# run for one epoch
cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs = 1
experiment_runner = experiment.Experiment(**cfg)
experiment.dump_cfg(cfg)
experiment_runner.run()
# update num epochs + 2, let the optimizer resume
cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs = 3
experiment_runner = experiment.Experiment(**cfg)
experiment_runner.run()
# start from scratch
cfg.model_factory_ImplicitronModelFactory_args.resume = False
experiment_runner = experiment.Experiment(**cfg)
experiment_runner.run()
# force resume from epoch 1
cfg.model_factory_ImplicitronModelFactory_args.resume = True
cfg.model_factory_ImplicitronModelFactory_args.force_resume = True
cfg.model_factory_ImplicitronModelFactory_args.resume_epoch = 1
experiment_runner = experiment.Experiment(**cfg)
experiment_runner.run()

View File

@ -344,7 +344,7 @@ def export_scenes(
# Load the previously trained model
experiment = Experiment(config)
model = experiment.model_factory(force_load=True, load_model_only=True)
model = experiment.model_factory(force_resume=True)
model.cuda()
model.eval()

View File

@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
import torch
@ -45,6 +45,10 @@ class ImplicitronModelBase(ReplaceableBase, torch.nn.Module):
optimization.
"""
# The keys from `preds` (output of ImplicitronModelBase.forward) to be logged in
# the training loop.
log_vars: List[str] = field(default_factory=lambda: ["objective"])
def __init__(self) -> None:
super().__init__()

View File

@ -1,3 +1,24 @@
log_vars:
- loss_rgb_psnr_fg
- loss_rgb_psnr
- loss_rgb_mse
- loss_rgb_huber
- loss_depth_abs
- loss_depth_abs_fg
- loss_mask_neg_iou
- loss_mask_bce
- loss_mask_beta_prior
- loss_eikonal
- loss_density_tv
- loss_depth_neg_penalty
- loss_autodecoder_norm
- loss_prev_stage_rgb_mse
- loss_prev_stage_rgb_psnr_fg
- loss_prev_stage_rgb_psnr
- loss_prev_stage_mask_bce
- objective
- epoch
- sec/it
mask_images: true
mask_depths: true
render_image_width: 400
@ -28,27 +49,6 @@ loss_weights:
loss_prev_stage_rgb_mse: 1.0
loss_mask_bce: 0.0
loss_prev_stage_mask_bce: 0.0
log_vars:
- loss_rgb_psnr_fg
- loss_rgb_psnr
- loss_rgb_mse
- loss_rgb_huber
- loss_depth_abs
- loss_depth_abs_fg
- loss_mask_neg_iou
- loss_mask_bce
- loss_mask_beta_prior
- loss_eikonal
- loss_density_tv
- loss_depth_neg_penalty
- loss_autodecoder_norm
- loss_prev_stage_rgb_mse
- loss_prev_stage_rgb_psnr_fg
- loss_prev_stage_rgb_psnr
- loss_prev_stage_mask_bce
- objective
- epoch
- sec/it
global_encoder_SequenceAutodecoder_args:
autodecoder_args:
encoding_dim: 0

View File

@ -90,5 +90,5 @@ class TestGenericModel(unittest.TestCase):
remove_unused_components(instance_args)
yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
if DEBUG:
(DATA_DIR / "overrides.yaml_").write_text(yaml)
(DATA_DIR / "overrides_.yaml").write_text(yaml)
self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text())