Replace pluggable components to create a proper Configurable hierarchy.

Summary:
This large diff rewrites a significant portion of Implicitron's config hierarchy. The new hierarchy, and some of the default implementation classes, are as follows:
```
Experiment
    data_source: ImplicitronDataSource
        dataset_map_provider
        data_loader_map_provider
    model_factory: ImplicitronModelFactory
        model: GenericModel
    optimizer_factory: ImplicitronOptimizerFactory
    training_loop: ImplicitronTrainingLoop
        evaluator: ImplicitronEvaluator
```

1) Experiment (used to be ExperimentConfig) is now a top-level Configurable whose members are mainly (mostly new) high-level factory Configurables.
2) Experiment's job is to run the factories, do some accelerate setup, and then pass the results to the main training loop.
3) ImplicitronOptimizerFactory and ImplicitronModelFactory are new high-level factories that create the optimizer, scheduler, model, and stats objects.
4) TrainingLoop is a new Configurable that runs the main training loop and the inner train-validate step.
5) Evaluator is a new Configurable that TrainingLoop uses to run validation/test steps.
6) GenericModel is no longer the only model choice. Instead, ImplicitronModelBase (instantiated with GenericModel by default) is a member of Experiment and can easily be replaced by a custom user implementation.

All the new Configurables are children of ReplaceableBase, and can be easily replaced with custom implementations.
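For illustration, swapping in a custom component only takes registering a subclass of the relevant base. A minimal sketch (`MyModel` and its `forward` body are hypothetical; `registry` and `ImplicitronModelBase` are the ones shown in the diffs below):
```python
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.tools.config import registry


@registry.register
class MyModel(ImplicitronModelBase):
    # Hypothetical replacement model. The base-class contract is that
    # forward() produces a render with a depth map.
    def forward(self, **kwargs):
        raise NotImplementedError()
```
Once registered, the implementation can be selected by name in the experiment config.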

In addition, I added support for the exponential LR schedule, updated the config files and the tests, and added a config file that reproduces NeRF results along with a test that runs the repro experiment.
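A minimal sketch of what the exponential schedule amounts to, using a standard torch scheduler (`model`, `max_epochs`, and `run_training_epoch` are hypothetical stand-ins, and the decay value is illustrative rather than the config default):
```python
import torch

# Exponential decay: lr_t = lr_0 * gamma ** t.
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.97)

for epoch in range(max_epochs):
    run_training_epoch(model, optimizer)  # hypothetical helper
    scheduler.step()  # decay the learning rate once per epoch
```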

Reviewed By: bottler

Differential Revision: D37723227

fbshipit-source-id: b36bee880d6aa53efdd2abfaae4489d8ab1e8a27
Author: Krzysztof Chalupka
Date: 2022-07-29 17:32:51 -07:00
Committed by: Facebook GitHub Bot
Parent: 6b481595f0
Commit: 1b0584f7bd
42 changed files with 2045 additions and 1478 deletions


@@ -40,6 +40,9 @@ class DataSourceBase(ReplaceableBase):
         """
         raise NotImplementedError()
 
+    def get_task(self) -> Task:
+        raise NotImplementedError()
+
 
 @registry.register
 class ImplicitronDataSource(DataSourceBase):  # pyre-ignore[13]
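A minimal sketch of what the new `get_task` hook asks of a custom data source (`MyDataSource` is hypothetical, the `Task` member is an assumption, and the base's other abstract methods are omitted):
```python
from pytorch3d.implicitron.dataset.data_source import DataSourceBase, Task
from pytorch3d.implicitron.tools.config import registry


@registry.register
class MyDataSource(DataSourceBase):
    def get_task(self) -> Task:
        # Assumed enum member; return the task your datasets correspond to.
        return Task.SINGLE_SEQUENCE
```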


@@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import json
import logging
import os
from typing import Any, Dict, List, Optional, Tuple

import lpips
import torch
import tqdm
from pytorch3d.implicitron.dataset import utils as ds_utils
from pytorch3d.implicitron.dataset.data_source import Task
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
from pytorch3d.implicitron.models.base_model import EvaluationMode, ImplicitronModelBase
from pytorch3d.implicitron.tools.config import (
    registry,
    ReplaceableBase,
    run_auto_creation,
)
from pytorch3d.renderer.cameras import CamerasBase
from torch.utils.data import DataLoader

logger = logging.getLogger(__name__)


class EvaluatorBase(ReplaceableBase):
    """
    Evaluate a trained model on given data. Returns a dict of loss/objective
    names and their values.
    """

    def run(
        self, model: ImplicitronModelBase, dataloader: DataLoader, **kwargs
    ) -> Dict[str, Any]:
        """
        Evaluate the results of Implicitron training.
        """
        raise NotImplementedError()


@registry.register
class ImplicitronEvaluator(EvaluatorBase):
    """
    Evaluate the results of Implicitron training.

    Members:
        camera_difficulty_bin_breaks: low/medium vals to divide camera difficulties into
            [0-eps, low, medium, 1+eps].
    """

    camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98

    def __post_init__(self):
        run_auto_creation(self)

    def run(
        self,
        model: ImplicitronModelBase,
        dataloader: DataLoader,
        task: Task,
        all_train_cameras: Optional[CamerasBase],
        device: torch.device,
        dump_to_json: bool = False,
        exp_dir: Optional[str] = None,
        epoch: Optional[int] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Evaluate the results of Implicitron training. Optionally, dump results to
        exp_dir/results_test.json.

        Args:
            model: A (trained) model to evaluate.
            dataloader: A test dataloader.
            task: Type of the novel-view synthesis task we're working on.
            all_train_cameras: Camera instances we used for training.
            device: A torch device.
            dump_to_json: If True, will dump the results to a json file.
            exp_dir: Root experiment directory.
            epoch: Evaluation epoch (to be stored in the results dict).

        Returns:
            A dictionary of results.
        """
        lpips_model = lpips.LPIPS(net="vgg")
        lpips_model = lpips_model.to(device)

        model.eval()

        per_batch_eval_results = []
        logger.info("Evaluating model ...")
        for frame_data in tqdm.tqdm(dataloader):
            frame_data = frame_data.to(device)

            # mask out the unknown images so that the model does not see them
            frame_data_for_eval = _get_eval_frame_data(frame_data)

            with torch.no_grad():
                preds = model(
                    **{
                        **frame_data_for_eval,
                        "evaluation_mode": EvaluationMode.EVALUATION,
                    }
                )

            implicitron_render = copy.deepcopy(preds["implicitron_render"])
            per_batch_eval_results.append(
                evaluate.eval_batch(
                    frame_data,
                    implicitron_render,
                    bg_color="black",
                    lpips_model=lpips_model,
                    source_cameras=all_train_cameras,
                )
            )

        _, category_result = evaluate.summarize_nvs_eval_results(
            per_batch_eval_results, task, self.camera_difficulty_bin_breaks
        )

        results = category_result["results"]
        if dump_to_json:
            _dump_to_json(epoch, exp_dir, results)

        return category_result["results"]


def _dump_to_json(
    epoch: Optional[int], exp_dir: Optional[str], results: List[Dict[str, Any]]
) -> None:
    if epoch is not None:
        for r in results:
            r["eval_epoch"] = int(epoch)
    logger.info("Evaluation results")
    evaluate.pretty_print_nvs_metrics(results)
    if exp_dir is None:
        raise ValueError("Cannot save results to json without a specified save path.")
    with open(os.path.join(exp_dir, "results_test.json"), "w") as f:
        json.dump(results, f)


def _get_eval_frame_data(frame_data: Any) -> Any:
    """
    Masks the unknown image data to make sure we cannot use it at model evaluation time.
    """
    frame_data_for_eval = copy.deepcopy(frame_data)
    is_known = ds_utils.is_known_frame(frame_data.frame_type).type_as(
        frame_data.image_rgb
    )[:, None, None, None]
    for k in ("image_rgb", "depth_map", "fg_probability", "mask_crop"):
        value_masked = getattr(frame_data_for_eval, k).clone() * is_known
        setattr(frame_data_for_eval, k, value_masked)
    return frame_data_for_eval
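A hedged usage sketch of the evaluator above, roughly as a training loop would invoke it (`model`, `test_dataloader`, `task`, and `train_cameras` are illustrative stand-ins, and the path is made up):
```python
evaluator = ImplicitronEvaluator()
results = evaluator.run(
    model=model,  # a trained ImplicitronModelBase
    dataloader=test_dataloader,  # DataLoader over FrameData batches
    task=task,  # from the data source's get_task()
    all_train_cameras=train_cameras,  # cameras seen in training, or None
    device=torch.device("cuda"),
    dump_to_json=True,
    exp_dir="/tmp/my_experiment",
    epoch=100,
)
```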


@@ -37,10 +37,12 @@ class ImplicitronRender:
     )
 
 
-class ImplicitronModelBase(ReplaceableBase):
+class ImplicitronModelBase(ReplaceableBase, torch.nn.Module):
     """
     Replaceable abstract base for all image generation / rendering models.
-    `forward()` method produces a render with a depth map.
+    `forward()` method produces a render with a depth map. Derives from Module
+    so we can rely on basic functionality provided to torch for model
+    optimization.
     """
 
     def __init__(self) -> None:
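Because the base class now derives from `torch.nn.Module`, every replaceable model picks up the standard torch machinery for free. A short illustrative sketch (`model` stands for any registered `ImplicitronModelBase` subclass instance, `device` for any torch device):
```python
model.to(device)  # Module device placement works on the base class
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)  # params are visible
checkpoint = model.state_dict()  # standard checkpointing
```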


@@ -16,10 +16,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import tqdm
-from pytorch3d.implicitron.models.metrics import (  # noqa
-    RegularizationMetrics,
+from pytorch3d.implicitron.models.metrics import (
     RegularizationMetricsBase,
-    ViewMetrics,
     ViewMetricsBase,
 )
 from pytorch3d.implicitron.tools import image_utils, vis_utils
@@ -67,7 +65,7 @@ logger = logging.getLogger(__name__)
 
 
 @registry.register
-class GenericModel(ImplicitronModelBase, torch.nn.Module):  # pyre-ignore: 13
+class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
     """
     GenericModel is a wrapper for the neural implicit
     rendering and reconstruction pipeline which consists


@@ -22,7 +22,7 @@ from .renderer.base import EvaluationMode
 
 
 @registry.register
-class ModelDBIR(ImplicitronModelBase, torch.nn.Module):
+class ModelDBIR(ImplicitronModelBase):
     """
     A simple depth-based image rendering model.


@@ -218,7 +218,7 @@ class AdaptiveRaySampler(AbstractMaskRaySampler):
 
     def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
         """
-        Returns the adaptivelly calculated near/far planes.
+        Returns the adaptively calculated near/far planes.
         """
         min_depth, max_depth = camera_utils.get_min_max_depth_bounds(
             cameras, self._scene_center, self.scene_extent


@@ -74,6 +74,7 @@ class Stats(object):
     """
     stats logging object useful for gathering statistics of training a deep net in pytorch
     Example:
+        ```
        # init stats structure that logs statistics 'objective' and 'top1e'
        stats = Stats( ('objective','top1e') )
        network = init_net()  # init a pytorch module (=neural network)
@@ -94,6 +95,7 @@ class Stats(object):
        # stores the training plots into '/tmp/epoch_stats.pdf'
        # and plots into a visdom server running at localhost (if running)
        stats.plot_stats(plot_file='/tmp/epoch_stats.pdf')
+        ```
     """
 
     def __init__(


@@ -14,20 +14,22 @@ from visdom import Visdom
 
 logger = logging.getLogger(__name__)
 
 
-def get_visdom_env(cfg):
+def get_visdom_env(visdom_env: str, exp_dir: str) -> str:
     """
     Parse out visdom environment name from the input config.
 
     Args:
-        cfg: The global config file.
+        visdom_env: Name of the visdom environment, could be an empty string.
+        exp_dir: Root experiment directory.
 
     Returns:
-        visdom_env: The name of the visdom environment.
+        visdom_env: The name of the visdom environment. If the given visdom_env is
+            empty, return the name of the bottom directory in exp_dir.
     """
-    if len(cfg.visdom_env) == 0:
-        visdom_env = cfg.exp_dir.split("/")[-1]
+    if len(visdom_env) == 0:
+        visdom_env = exp_dir.split("/")[-1]
     else:
-        visdom_env = cfg.visdom_env
+        visdom_env = visdom_env
     return visdom_env
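An illustrative call with the new explicit-argument signature (the path is made up):
```python
env_name = get_visdom_env("", "/checkpoints/my_experiment")
# -> "my_experiment": with an empty visdom_env, falls back to the
#    bottom directory of exp_dir
```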