extract camera_difficulty_bin_breaks

Summary: As part of removing Task, move the camera difficulty bin breaks from hard-coded values to the top-level configuration.

Reviewed By: davnov134

Differential Revision: D37491040

fbshipit-source-id: f2d6775ebc490f6f75020d13f37f6b588cc07a0b
This commit is contained in:
Jeremy Reizenstein 2022-07-06 07:13:41 -07:00 committed by Facebook GitHub Bot
parent 40fb189c29
commit efb721320a
5 changed files with 46 additions and 16 deletions

View File

@@ -27,3 +27,6 @@ solver_args:
   max_epochs: 3000
   milestones:
   - 1000
+camera_difficulty_bin_breaks:
+- 0.666667
+- 0.833334

View File

@@ -535,7 +535,12 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
                 and epoch % cfg.test_interval == 0
             ):
                 _run_eval(
-                    model, all_train_cameras, dataloaders.test, task, device=device
+                    model,
+                    all_train_cameras,
+                    dataloaders.test,
+                    task,
+                    camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
+                    device=device,
                 )

         assert stats.epoch == epoch, "inconsistent stats!"
@@ -588,7 +593,14 @@ def _eval_and_dump(
     if dataloader is None:
         raise ValueError('DataLoaderMap have to contain the "test" entry for eval!')

-    results = _run_eval(model, all_train_cameras, dataloader, task, device=device)
+    results = _run_eval(
+        model,
+        all_train_cameras,
+        dataloader,
+        task,
+        camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
+        device=device,
+    )

     # add the evaluation epoch to the results
     for r in results:
@@ -615,7 +627,14 @@ def _get_eval_frame_data(frame_data):
     return frame_data_for_eval


-def _run_eval(model, all_train_cameras, loader, task: Task, device):
+def _run_eval(
+    model,
+    all_train_cameras,
+    loader,
+    task: Task,
+    camera_difficulty_bin_breaks: Tuple[float, float],
+    device,
+):
     """
     Run the evaluation loop on the test dataloader
     """
@@ -648,7 +667,7 @@ def _run_eval(model, all_train_cameras, loader, task: Task, device):
     )

     _, category_result = evaluate.summarize_nvs_eval_results(
-        per_batch_eval_results, task
+        per_batch_eval_results, task, camera_difficulty_bin_breaks
    )

    return category_result["results"]
@@ -684,6 +703,7 @@ class ExperimentConfig(Configurable):
     visdom_server: str = "http://127.0.0.1"
     visualize_interval: int = 1000
     clip_grad: float = 0.0
+    camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
     hydra: dict = field(
         default_factory=lambda: {

View File

@@ -375,6 +375,9 @@ visdom_port: 8097
 visdom_server: http://127.0.0.1
 visualize_interval: 1000
 clip_grad: 0.0
+camera_difficulty_bin_breaks:
+- 0.97
+- 0.98
 hydra:
   run:
     dir: .

View File

@@ -153,8 +153,13 @@ def evaluate_dbir_for_category(
         )
     )

+    if task == Task.SINGLE_SEQUENCE:
+        camera_difficulty_bin_breaks = 0.97, 0.98
+    else:
+        camera_difficulty_bin_breaks = 2.0 / 3, 5.0 / 6
     category_result_flat, category_result = summarize_nvs_eval_results(
-        per_batch_eval_results, task
+        per_batch_eval_results, task, camera_difficulty_bin_breaks
    )

    return category_result["results"]

View File

@@ -9,7 +9,7 @@ import copy
 import warnings
 from collections import OrderedDict
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

 import numpy as np
 import torch
@@ -407,19 +407,13 @@ def _reduce_camera_iou_overlap(ious: torch.Tensor, topk: int = 2) -> torch.Tensor:
     return ious.topk(k=min(topk, len(ious) - 1)).values.mean()


-def _get_camera_difficulty_bin_edges(task: Task):
+def _get_camera_difficulty_bin_edges(camera_difficulty_bin_breaks: Tuple[float, float]):
     """
     Get the edges of camera difficulty bins.
     """
     _eps = 1e-5
-    if task == Task.MULTI_SEQUENCE:
-        # TODO: extract those to constants
-        diff_bin_edges = torch.linspace(0.5, 1.0 + _eps, 4)
-        diff_bin_edges[0] = 0.0 - _eps
-    elif task == Task.SINGLE_SEQUENCE:
-        diff_bin_edges = torch.tensor([0.0 - _eps, 0.97, 0.98, 1.0 + _eps]).float()
-    else:
-        raise ValueError(f"No such eval task {task}.")
+    lower, upper = camera_difficulty_bin_breaks
+    diff_bin_edges = torch.tensor([0.0 - _eps, lower, upper, 1.0 + _eps]).float()
     diff_bin_names = ["hard", "medium", "easy"]
     return diff_bin_edges, diff_bin_names
@@ -427,6 +421,7 @@ def _get_camera_difficulty_bin_edges(task: Task):
 def summarize_nvs_eval_results(
     per_batch_eval_results: List[Dict[str, Any]],
     task: Task,
+    camera_difficulty_bin_breaks: Tuple[float, float] = (0.97, 0.98),
 ):
     """
     Compile the per-batch evaluation results `per_batch_eval_results` into
@@ -435,6 +430,8 @@ def summarize_nvs_eval_results(
     Args:
         per_batch_eval_results: Metrics of each per-batch evaluation.
         task: The type of the new-view synthesis task.
+        camera_difficulty_bin_breaks: edge hard-medium and medium-easy

     Returns:
         nvs_results_flat: A flattened dict of all aggregate metrics.
@@ -461,7 +458,9 @@ def summarize_nvs_eval_results(
     # init the result database dict
     results = []

-    diff_bin_edges, diff_bin_names = _get_camera_difficulty_bin_edges(task)
+    diff_bin_edges, diff_bin_names = _get_camera_difficulty_bin_edges(
+        camera_difficulty_bin_breaks
+    )
     n_diff_edges = diff_bin_edges.numel()

     # add per set averages