extract camera_difficulty_bin_breaks

Summary: As part of removing `Task`, move the camera difficulty bin breaks from hard-coded values to the top-level config.

Reviewed By: davnov134

Differential Revision: D37491040

fbshipit-source-id: f2d6775ebc490f6f75020d13f37f6b588cc07a0b
Jeremy Reizenstein 2022-07-06 07:13:41 -07:00 committed by Facebook GitHub Bot
parent 40fb189c29
commit efb721320a
5 changed files with 46 additions and 16 deletions

View File

@@ -27,3 +27,6 @@ solver_args:
   max_epochs: 3000
   milestones:
   - 1000
+camera_difficulty_bin_breaks:
+- 0.666667
+- 0.833334
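(The two values above are rounded forms of 2/3 and 5/6, the interior bin edges that the multi-sequence task previously received implicitly; compare the `2.0 / 3, 5.0 / 6` fallback added to `evaluate_dbir_for_category` below.)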

View File

@@ -535,7 +535,12 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
             and epoch % cfg.test_interval == 0
         ):
             _run_eval(
-                model, all_train_cameras, dataloaders.test, task, device=device
+                model,
+                all_train_cameras,
+                dataloaders.test,
+                task,
+                camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
+                device=device,
             )
         assert stats.epoch == epoch, "inconsistent stats!"
@@ -588,7 +593,14 @@ def _eval_and_dump(
     if dataloader is None:
         raise ValueError('DataLoaderMap have to contain the "test" entry for eval!')
-    results = _run_eval(model, all_train_cameras, dataloader, task, device=device)
+    results = _run_eval(
+        model,
+        all_train_cameras,
+        dataloader,
+        task,
+        camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
+        device=device,
+    )

     # add the evaluation epoch to the results
     for r in results:
@@ -615,7 +627,14 @@ def _get_eval_frame_data(frame_data):
     return frame_data_for_eval


-def _run_eval(model, all_train_cameras, loader, task: Task, device):
+def _run_eval(
+    model,
+    all_train_cameras,
+    loader,
+    task: Task,
+    camera_difficulty_bin_breaks: Tuple[float, float],
+    device,
+):
     """
     Run the evaluation loop on the test dataloader
     """
@@ -648,7 +667,7 @@ def _run_eval(model, all_train_cameras, loader, task: Task, device):
     )
     _, category_result = evaluate.summarize_nvs_eval_results(
-        per_batch_eval_results, task
+        per_batch_eval_results, task, camera_difficulty_bin_breaks
     )

     return category_result["results"]
@@ -684,6 +703,7 @@ class ExperimentConfig(Configurable):
     visdom_server: str = "http://127.0.0.1"
     visualize_interval: int = 1000
     clip_grad: float = 0.0
+    camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
     hydra: dict = field(
         default_factory=lambda: {
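Since `run_training` receives the experiment settings as a `DictConfig`, the new top-level field should be overridable like any other entry. A minimal sketch, assuming only standard OmegaConf behaviour; the one-field config below is a hypothetical stand-in for the full `ExperimentConfig`:

```python
from omegaconf import OmegaConf

# Hypothetical stand-in for the full ExperimentConfig (only the new field).
cfg = OmegaConf.create({"camera_difficulty_bin_breaks": [0.97, 0.98]})

# Dotlist-style override, e.g. switching to the multi-sequence breaks
# used in the repro config of the first file above.
cfg.merge_with_dotlist(["camera_difficulty_bin_breaks=[0.666667,0.833334]"])
print(list(cfg.camera_difficulty_bin_breaks))  # [0.666667, 0.833334]
```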

View File

@@ -375,6 +375,9 @@ visdom_port: 8097
 visdom_server: http://127.0.0.1
 visualize_interval: 1000
 clip_grad: 0.0
+camera_difficulty_bin_breaks:
+- 0.97
+- 0.98
 hydra:
   run:
     dir: .
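(This hunk is the corresponding entry in the dumped full default config; its `0.97, 0.98` values mirror the new `ExperimentConfig` default above.)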

View File

@@ -153,8 +153,13 @@ def evaluate_dbir_for_category(
         )
     )

+    if task == Task.SINGLE_SEQUENCE:
+        camera_difficulty_bin_breaks = 0.97, 0.98
+    else:
+        camera_difficulty_bin_breaks = 2.0 / 3, 5.0 / 6
+
     category_result_flat, category_result = summarize_nvs_eval_results(
-        per_batch_eval_results, task
+        per_batch_eval_results, task, camera_difficulty_bin_breaks
     )

     return category_result["results"]
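The `2.0 / 3, 5.0 / 6` pair reproduces, up to the `_eps` folded into the old endpoint, the interior edges that the removed `Task.MULTI_SEQUENCE` branch computed with `torch.linspace` (see the last file below). A quick standalone check:

```python
import torch

eps = 1e-5
# Old multi-sequence path (removed in the last file of this commit):
old_edges = torch.linspace(0.5, 1.0 + eps, 4)
old_edges[0] = 0.0 - eps
# New path, with the breaks chosen in the else-branch above:
lower, upper = 2.0 / 3, 5.0 / 6
new_edges = torch.tensor([0.0 - eps, lower, upper, 1.0 + eps]).float()
print(torch.allclose(old_edges, new_edges, atol=1e-5))  # True
```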

View File

@@ -9,7 +9,7 @@ import copy
 import warnings
 from collections import OrderedDict
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

 import numpy as np
 import torch
@@ -407,19 +407,13 @@ def _reduce_camera_iou_overlap(ious: torch.Tensor, topk: int = 2) -> torch.Tensor:
     return ious.topk(k=min(topk, len(ious) - 1)).values.mean()


-def _get_camera_difficulty_bin_edges(task: Task):
+def _get_camera_difficulty_bin_edges(camera_difficulty_bin_breaks: Tuple[float, float]):
     """
     Get the edges of camera difficulty bins.
     """
     _eps = 1e-5
-    if task == Task.MULTI_SEQUENCE:
-        # TODO: extract those to constants
-        diff_bin_edges = torch.linspace(0.5, 1.0 + _eps, 4)
-        diff_bin_edges[0] = 0.0 - _eps
-    elif task == Task.SINGLE_SEQUENCE:
-        diff_bin_edges = torch.tensor([0.0 - _eps, 0.97, 0.98, 1.0 + _eps]).float()
-    else:
-        raise ValueError(f"No such eval task {task}.")
+    lower, upper = camera_difficulty_bin_breaks
+    diff_bin_edges = torch.tensor([0.0 - _eps, lower, upper, 1.0 + _eps]).float()
     diff_bin_names = ["hard", "medium", "easy"]
     return diff_bin_edges, diff_bin_names
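How the edges are consumed lies outside this hunk, but conceptually each camera difficulty score in [0, 1] falls into one of the three named bins. A hedged sketch of that binning with `torch.bucketize` (the actual aggregation code may differ):

```python
import torch

eps = 1e-5
lower, upper = 0.97, 0.98  # the single-sequence defaults from this commit
edges = torch.tensor([0.0 - eps, lower, upper, 1.0 + eps])
names = ["hard", "medium", "easy"]

difficulty = torch.tensor([0.5, 0.975, 0.99])
# bucketize returns the index of the first edge >= each value,
# so subtracting 1 yields a 0-based bin index into `names`.
bins = torch.bucketize(difficulty, edges) - 1
print([names[i] for i in bins.tolist()])  # ['hard', 'medium', 'easy']
```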
@@ -427,6 +421,7 @@ def _get_camera_difficulty_bin_edges(task: Task):
 def summarize_nvs_eval_results(
     per_batch_eval_results: List[Dict[str, Any]],
     task: Task,
+    camera_difficulty_bin_breaks: Tuple[float, float] = (0.97, 0.98),
 ):
"""
Compile the per-batch evaluation results `per_batch_eval_results` into
@@ -435,6 +430,8 @@ def summarize_nvs_eval_results(
     Args:
         per_batch_eval_results: Metrics of each per-batch evaluation.
         task: The type of the new-view synthesis task.
+        camera_difficulty_bin_breaks: The boundaries between the hard/medium
+            and medium/easy camera difficulty bins.

     Returns:
         nvs_results_flat: A flattened dict of all aggregate metrics.
@@ -461,7 +458,9 @@ def summarize_nvs_eval_results(
     # init the result database dict
     results = []

-    diff_bin_edges, diff_bin_names = _get_camera_difficulty_bin_edges(task)
+    diff_bin_edges, diff_bin_names = _get_camera_difficulty_bin_edges(
+        camera_difficulty_bin_breaks
+    )
     n_diff_edges = diff_bin_edges.numel()

     # add per set averages