mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 06:10:34 +08:00
Replace pluggable components to create a proper Configurable hierarchy.
Summary:
This large diff rewrites a significant portion of Implicitron's config hierarchy. The new hierarchy, and some of the default implementation classes, are as follows:
```
Experiment
data_source: ImplicitronDataSource
dataset_map_provider
data_loader_map_provider
model_factory: ImplicitronModelFactory
model: GenericModel
optimizer_factory: ImplicitronOptimizerFactory
training_loop: ImplicitronTrainingLoop
evaluator: ImplicitronEvaluator
```
1) Experiment (used to be ExperimentConfig) is now a top-level Configurable and contains as members mainly (mostly new) high-level factory Configurables.
2) Experiment's job is to run factories, do some accelerate setup and then pass the results to the main training loop.
3) ImplicitronOptimizerFactory and ImplicitronModelFactory are new high-level factories that create the optimizer, scheduler, model, and stats objects.
4) TrainingLoop is a new configurable that runs the main training loop and the inner train-validate step.
5) Evaluator is a new configurable that TrainingLoop uses to run validation/test steps.
6) GenericModel is not the only model choice anymore. Instead, ImplicitronModelBase (by default instantiated with GenericModel) is a member of Experiment and can be easily replaced by a custom implementation by the user.
All the new Configurables are children of ReplaceableBase, and can be easily replaced with custom implementations.
In addition, I added support for the exponential LR schedule, updated the config files and the test, as well as added a config file that reproduces NERF results and a test to run the repro experiment.
Reviewed By: bottler
Differential Revision: D37723227
fbshipit-source-id: b36bee880d6aa53efdd2abfaae4489d8ab1e8a27
This commit is contained in:
committed by
Facebook GitHub Bot
parent
6b481595f0
commit
1b0584f7bd
@@ -40,6 +40,9 @@ class DataSourceBase(ReplaceableBase):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_task(self) -> Task:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@registry.register
|
||||
class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
|
||||
|
||||
161
pytorch3d/implicitron/evaluation/evaluator.py
Normal file
161
pytorch3d/implicitron/evaluation/evaluator.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import copy
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import lpips
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from pytorch3d.implicitron.dataset import utils as ds_utils
|
||||
from pytorch3d.implicitron.dataset.data_source import Task
|
||||
|
||||
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
|
||||
from pytorch3d.implicitron.models.base_model import EvaluationMode, ImplicitronModelBase
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
registry,
|
||||
ReplaceableBase,
|
||||
run_auto_creation,
|
||||
)
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EvaluatorBase(ReplaceableBase):
|
||||
"""
|
||||
Evaluate a trained model on given data. Returns a dict of loss/objective
|
||||
names and their values.
|
||||
"""
|
||||
|
||||
def run(
|
||||
self, model: ImplicitronModelBase, dataloader: DataLoader, **kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Evaluate the results of Implicitron training.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@registry.register
|
||||
class ImplicitronEvaluator(EvaluatorBase):
|
||||
"""
|
||||
Evaluate the results of Implicitron training.
|
||||
|
||||
Members:
|
||||
camera_difficulty_bin_breaks: low/medium vals to divide camera difficulties into
|
||||
[0-eps, low, medium, 1+eps].
|
||||
"""
|
||||
|
||||
camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
|
||||
def run(
|
||||
self,
|
||||
model: ImplicitronModelBase,
|
||||
dataloader: DataLoader,
|
||||
task: Task,
|
||||
all_train_cameras: Optional[CamerasBase],
|
||||
device: torch.device,
|
||||
dump_to_json: bool = False,
|
||||
exp_dir: Optional[str] = None,
|
||||
epoch: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Evaluate the results of Implicitron training. Optionally, dump results to
|
||||
exp_dir/results_test.json.
|
||||
|
||||
Args:
|
||||
model: A (trained) model to evaluate.
|
||||
dataloader: A test dataloader.
|
||||
task: Type of the novel-view synthesis task we're working on.
|
||||
all_train_cameras: Camera instances we used for training.
|
||||
device: A torch device.
|
||||
dump_to_json: If True, will dump the results to a json file.
|
||||
exp_dir: Root expeirment directory.
|
||||
epoch: Evaluation epoch (to be stored in the results dict).
|
||||
|
||||
Returns:
|
||||
A dictionary of results.
|
||||
"""
|
||||
lpips_model = lpips.LPIPS(net="vgg")
|
||||
lpips_model = lpips_model.to(device)
|
||||
|
||||
model.eval()
|
||||
|
||||
per_batch_eval_results = []
|
||||
logger.info("Evaluating model ...")
|
||||
for frame_data in tqdm.tqdm(dataloader):
|
||||
frame_data = frame_data.to(device)
|
||||
|
||||
# mask out the unknown images so that the model does not see them
|
||||
frame_data_for_eval = _get_eval_frame_data(frame_data)
|
||||
|
||||
with torch.no_grad():
|
||||
preds = model(
|
||||
**{
|
||||
**frame_data_for_eval,
|
||||
"evaluation_mode": EvaluationMode.EVALUATION,
|
||||
}
|
||||
)
|
||||
implicitron_render = copy.deepcopy(preds["implicitron_render"])
|
||||
per_batch_eval_results.append(
|
||||
evaluate.eval_batch(
|
||||
frame_data,
|
||||
implicitron_render,
|
||||
bg_color="black",
|
||||
lpips_model=lpips_model,
|
||||
source_cameras=all_train_cameras,
|
||||
)
|
||||
)
|
||||
|
||||
_, category_result = evaluate.summarize_nvs_eval_results(
|
||||
per_batch_eval_results, task, self.camera_difficulty_bin_breaks
|
||||
)
|
||||
|
||||
results = category_result["results"]
|
||||
if dump_to_json:
|
||||
_dump_to_json(epoch, exp_dir, results)
|
||||
|
||||
return category_result["results"]
|
||||
|
||||
|
||||
def _dump_to_json(
|
||||
epoch: Optional[int], exp_dir: Optional[str], results: List[Dict[str, Any]]
|
||||
) -> None:
|
||||
if epoch is not None:
|
||||
for r in results:
|
||||
r["eval_epoch"] = int(epoch)
|
||||
logger.info("Evaluation results")
|
||||
|
||||
evaluate.pretty_print_nvs_metrics(results)
|
||||
if exp_dir is None:
|
||||
raise ValueError("Cannot save results to json without a specified save path.")
|
||||
with open(os.path.join(exp_dir, "results_test.json"), "w") as f:
|
||||
json.dump(results, f)
|
||||
|
||||
|
||||
def _get_eval_frame_data(frame_data: Any) -> Any:
|
||||
"""
|
||||
Masks the unknown image data to make sure we cannot use it at model evaluation time.
|
||||
"""
|
||||
frame_data_for_eval = copy.deepcopy(frame_data)
|
||||
is_known = ds_utils.is_known_frame(frame_data.frame_type).type_as(
|
||||
frame_data.image_rgb
|
||||
)[:, None, None, None]
|
||||
for k in ("image_rgb", "depth_map", "fg_probability", "mask_crop"):
|
||||
value_masked = getattr(frame_data_for_eval, k).clone() * is_known
|
||||
setattr(frame_data_for_eval, k, value_masked)
|
||||
return frame_data_for_eval
|
||||
@@ -37,10 +37,12 @@ class ImplicitronRender:
|
||||
)
|
||||
|
||||
|
||||
class ImplicitronModelBase(ReplaceableBase):
|
||||
class ImplicitronModelBase(ReplaceableBase, torch.nn.Module):
|
||||
"""
|
||||
Replaceable abstract base for all image generation / rendering models.
|
||||
`forward()` method produces a render with a depth map.
|
||||
`forward()` method produces a render with a depth map. Derives from Module
|
||||
so we can rely on basic functionality provided to torch for model
|
||||
optimization.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
|
||||
@@ -16,10 +16,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
from pytorch3d.implicitron.models.metrics import ( # noqa
|
||||
RegularizationMetrics,
|
||||
from pytorch3d.implicitron.models.metrics import (
|
||||
RegularizationMetricsBase,
|
||||
ViewMetrics,
|
||||
ViewMetricsBase,
|
||||
)
|
||||
from pytorch3d.implicitron.tools import image_utils, vis_utils
|
||||
@@ -67,7 +65,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@registry.register
|
||||
class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
||||
class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
||||
"""
|
||||
GenericModel is a wrapper for the neural implicit
|
||||
rendering and reconstruction pipeline which consists
|
||||
|
||||
@@ -22,7 +22,7 @@ from .renderer.base import EvaluationMode
|
||||
|
||||
|
||||
@registry.register
|
||||
class ModelDBIR(ImplicitronModelBase, torch.nn.Module):
|
||||
class ModelDBIR(ImplicitronModelBase):
|
||||
"""
|
||||
A simple depth-based image rendering model.
|
||||
|
||||
|
||||
@@ -218,7 +218,7 @@ class AdaptiveRaySampler(AbstractMaskRaySampler):
|
||||
|
||||
def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
|
||||
"""
|
||||
Returns the adaptivelly calculated near/far planes.
|
||||
Returns the adaptively calculated near/far planes.
|
||||
"""
|
||||
min_depth, max_depth = camera_utils.get_min_max_depth_bounds(
|
||||
cameras, self._scene_center, self.scene_extent
|
||||
|
||||
@@ -74,6 +74,7 @@ class Stats(object):
|
||||
"""
|
||||
stats logging object useful for gathering statistics of training a deep net in pytorch
|
||||
Example:
|
||||
```
|
||||
# init stats structure that logs statistics 'objective' and 'top1e'
|
||||
stats = Stats( ('objective','top1e') )
|
||||
network = init_net() # init a pytorch module (=nueral network)
|
||||
@@ -94,6 +95,7 @@ class Stats(object):
|
||||
# stores the training plots into '/tmp/epoch_stats.pdf'
|
||||
# and plots into a visdom server running at localhost (if running)
|
||||
stats.plot_stats(plot_file='/tmp/epoch_stats.pdf')
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -14,20 +14,22 @@ from visdom import Visdom
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_visdom_env(cfg):
|
||||
def get_visdom_env(visdom_env: str, exp_dir: str) -> str:
|
||||
"""
|
||||
Parse out visdom environment name from the input config.
|
||||
|
||||
Args:
|
||||
cfg: The global config file.
|
||||
visdom_env: Name of the wisdom environment, could be empty string.
|
||||
exp_dir: Root experiment directory.
|
||||
|
||||
Returns:
|
||||
visdom_env: The name of the visdom environment.
|
||||
visdom_env: The name of the visdom environment. If the given visdom_env is
|
||||
empty, return the name of the bottom directory in exp_dir.
|
||||
"""
|
||||
if len(cfg.visdom_env) == 0:
|
||||
visdom_env = cfg.exp_dir.split("/")[-1]
|
||||
if len(visdom_env) == 0:
|
||||
visdom_env = exp_dir.split("/")[-1]
|
||||
else:
|
||||
visdom_env = cfg.visdom_env
|
||||
visdom_env = visdom_env
|
||||
return visdom_env
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user