extract camera_difficulty_bin_breaks

Summary: As part of removing Task, move camera difficulty bin breaks from hard code to the top level.

Reviewed By: davnov134

Differential Revision: D37491040

fbshipit-source-id: f2d6775ebc490f6f75020d13f37f6b588cc07a0b
This commit is contained in:
Jeremy Reizenstein
2022-07-06 07:13:41 -07:00
committed by Facebook GitHub Bot
parent 40fb189c29
commit efb721320a
5 changed files with 46 additions and 16 deletions

View File

@@ -153,8 +153,13 @@ def evaluate_dbir_for_category(
)
)
if task == Task.SINGLE_SEQUENCE:
camera_difficulty_bin_breaks = 0.97, 0.98
else:
camera_difficulty_bin_breaks = 2.0 / 3, 5.0 / 6
category_result_flat, category_result = summarize_nvs_eval_results(
per_batch_eval_results, task
per_batch_eval_results, task, camera_difficulty_bin_breaks
)
return category_result["results"]

View File

@@ -9,7 +9,7 @@ import copy
import warnings
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
@@ -407,19 +407,13 @@ def _reduce_camera_iou_overlap(ious: torch.Tensor, topk: int = 2) -> torch.Tenso
return ious.topk(k=min(topk, len(ious) - 1)).values.mean()
def _get_camera_difficulty_bin_edges(task: Task):
def _get_camera_difficulty_bin_edges(camera_difficulty_bin_breaks: Tuple[float, float]):
"""
Get the edges of camera difficulty bins.
"""
_eps = 1e-5
if task == Task.MULTI_SEQUENCE:
# TODO: extract those to constants
diff_bin_edges = torch.linspace(0.5, 1.0 + _eps, 4)
diff_bin_edges[0] = 0.0 - _eps
elif task == Task.SINGLE_SEQUENCE:
diff_bin_edges = torch.tensor([0.0 - _eps, 0.97, 0.98, 1.0 + _eps]).float()
else:
raise ValueError(f"No such eval task {task}.")
lower, upper = camera_difficulty_bin_breaks
diff_bin_edges = torch.tensor([0.0 - _eps, lower, upper, 1.0 + _eps]).float()
diff_bin_names = ["hard", "medium", "easy"]
return diff_bin_edges, diff_bin_names
@@ -427,6 +421,7 @@ def _get_camera_difficulty_bin_edges(task: Task):
def summarize_nvs_eval_results(
per_batch_eval_results: List[Dict[str, Any]],
task: Task,
camera_difficulty_bin_breaks: Tuple[float, float] = (0.97, 0.98),
):
"""
Compile the per-batch evaluation results `per_batch_eval_results` into
@@ -435,6 +430,8 @@ def summarize_nvs_eval_results(
Args:
per_batch_eval_results: Metrics of each per-batch evaluation.
task: The type of the new-view synthesis task.
camera_difficulty_bin_breaks: edge hard-medium and medium-easy
Returns:
nvs_results_flat: A flattened dict of all aggregate metrics.
@@ -461,7 +458,9 @@ def summarize_nvs_eval_results(
# init the result database dict
results = []
diff_bin_edges, diff_bin_names = _get_camera_difficulty_bin_edges(task)
diff_bin_edges, diff_bin_names = _get_camera_difficulty_bin_edges(
camera_difficulty_bin_breaks
)
n_diff_edges = diff_bin_edges.numel()
# add per set averages