Thomas Polasek 055ab3a2e3 Convert directory fbcode/vision to use the Ruff Formatter
Summary:
Converts the directory specified to use the Ruff formatter in pyfmt

ruff_dog

If this diff causes merge conflicts when rebasing, please run
`hg status -n -0 --change . -I '**/*.{py,pyi}' | xargs -0 arc pyfmt`
on your diff, and amend any changes before rebasing onto latest.
That should help reduce or eliminate any merge conflicts.

allow-large-files

Reviewed By: bottler

Differential Revision: D66472063

fbshipit-source-id: 35841cb397e4f8e066e2159550d2f56b403b1bef
2024-11-26 02:38:20 -08:00

138 lines
4.6 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
import logging
import os
from typing import Optional
import torch.optim
from accelerate import Accelerator
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.tools import model_io
from pytorch3d.implicitron.tools.config import (
registry,
ReplaceableBase,
run_auto_creation,
)
from pytorch3d.implicitron.tools.stats import Stats
logger = logging.getLogger(__name__)
class ModelFactoryBase(ReplaceableBase):
resume: bool = True # resume from the last checkpoint
def __call__(self, **kwargs) -> ImplicitronModelBase:
"""
Initialize the model (possibly from a previously saved state).
Returns: An instance of ImplicitronModelBase.
"""
raise NotImplementedError()
def load_stats(self, **kwargs) -> Stats:
"""
Initialize or load a Stats object.
"""
raise NotImplementedError()
@registry.register
class ImplicitronModelFactory(ModelFactoryBase):
"""
A factory class that initializes an implicit rendering model.
Members:
model: An ImplicitronModelBase object.
resume: If True, attempt to load the last checkpoint from `exp_dir`
passed to __call__. Failure to do so will return a model with ini-
tial weights unless `force_resume` is True.
resume_epoch: If `resume` is True: Resume a model at this epoch, or if
`resume_epoch` <= 0, then resume from the latest checkpoint.
force_resume: If True, throw a FileNotFoundError if `resume` is True but
a model checkpoint cannot be found.
"""
# pyre-fixme[13]: Attribute `model` is never initialized.
model: ImplicitronModelBase
model_class_type: str = "GenericModel"
resume: bool = True
resume_epoch: int = -1
force_resume: bool = False
def __post_init__(self):
run_auto_creation(self)
def __call__(
self,
exp_dir: str,
accelerator: Optional[Accelerator] = None,
) -> ImplicitronModelBase:
"""
Returns an instance of `ImplicitronModelBase`, possibly loaded from a
checkpoint (if self.resume, self.resume_epoch specify so).
Args:
exp_dir: Root experiment directory.
accelerator: An Accelerator object.
Returns:
model: The model with optionally loaded weights from checkpoint
Raise:
FileNotFoundError if `force_resume` is True but checkpoint not found.
"""
# Determine the network outputs that should be logged
if hasattr(self.model, "log_vars"):
log_vars = list(self.model.log_vars)
else:
log_vars = ["objective"]
if self.resume_epoch > 0:
# Resume from a certain epoch
model_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
if not os.path.isfile(model_path):
raise ValueError(f"Cannot find model from epoch {self.resume_epoch}.")
else:
# Retrieve the last checkpoint
model_path = model_io.find_last_checkpoint(exp_dir)
if model_path is not None:
logger.info(f"Found previous model {model_path}")
if self.force_resume or self.resume:
logger.info("Resuming.")
map_location = None
if accelerator is not None and not accelerator.is_local_main_process:
map_location = {
"cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
}
model_state_dict = torch.load(
model_io.get_model_path(model_path),
map_location=map_location,
weights_only=True,
)
try:
self.model.load_state_dict(model_state_dict, strict=True)
except RuntimeError as e:
logger.error(e)
logger.info(
"Cannot load state dict in strict mode! -> trying non-strict"
)
self.model.load_state_dict(model_state_dict, strict=False)
self.model.log_vars = log_vars
else:
logger.info("Not resuming -> starting from scratch.")
elif self.force_resume:
raise FileNotFoundError(f"Cannot find a checkpoint in {exp_dir}!")
return self.model