mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 14:20:38 +08:00
extract camera_difficulty_bin_breaks
Summary: As part of removing Task, move camera difficulty bin breaks from hard code to the top level. Reviewed By: davnov134 Differential Revision: D37491040 fbshipit-source-id: f2d6775ebc490f6f75020d13f37f6b588cc07a0b
This commit is contained in:
committed by
Facebook GitHub Bot
parent
40fb189c29
commit
efb721320a
@@ -153,8 +153,13 @@ def evaluate_dbir_for_category(
|
||||
)
|
||||
)
|
||||
|
||||
if task == Task.SINGLE_SEQUENCE:
|
||||
camera_difficulty_bin_breaks = 0.97, 0.98
|
||||
else:
|
||||
camera_difficulty_bin_breaks = 2.0 / 3, 5.0 / 6
|
||||
|
||||
category_result_flat, category_result = summarize_nvs_eval_results(
|
||||
per_batch_eval_results, task
|
||||
per_batch_eval_results, task, camera_difficulty_bin_breaks
|
||||
)
|
||||
|
||||
return category_result["results"]
|
||||
|
||||
@@ -9,7 +9,7 @@ import copy
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -407,19 +407,13 @@ def _reduce_camera_iou_overlap(ious: torch.Tensor, topk: int = 2) -> torch.Tenso
|
||||
return ious.topk(k=min(topk, len(ious) - 1)).values.mean()
|
||||
|
||||
|
||||
def _get_camera_difficulty_bin_edges(task: Task):
|
||||
def _get_camera_difficulty_bin_edges(camera_difficulty_bin_breaks: Tuple[float, float]):
|
||||
"""
|
||||
Get the edges of camera difficulty bins.
|
||||
"""
|
||||
_eps = 1e-5
|
||||
if task == Task.MULTI_SEQUENCE:
|
||||
# TODO: extract those to constants
|
||||
diff_bin_edges = torch.linspace(0.5, 1.0 + _eps, 4)
|
||||
diff_bin_edges[0] = 0.0 - _eps
|
||||
elif task == Task.SINGLE_SEQUENCE:
|
||||
diff_bin_edges = torch.tensor([0.0 - _eps, 0.97, 0.98, 1.0 + _eps]).float()
|
||||
else:
|
||||
raise ValueError(f"No such eval task {task}.")
|
||||
lower, upper = camera_difficulty_bin_breaks
|
||||
diff_bin_edges = torch.tensor([0.0 - _eps, lower, upper, 1.0 + _eps]).float()
|
||||
diff_bin_names = ["hard", "medium", "easy"]
|
||||
return diff_bin_edges, diff_bin_names
|
||||
|
||||
@@ -427,6 +421,7 @@ def _get_camera_difficulty_bin_edges(task: Task):
|
||||
def summarize_nvs_eval_results(
|
||||
per_batch_eval_results: List[Dict[str, Any]],
|
||||
task: Task,
|
||||
camera_difficulty_bin_breaks: Tuple[float, float] = (0.97, 0.98),
|
||||
):
|
||||
"""
|
||||
Compile the per-batch evaluation results `per_batch_eval_results` into
|
||||
@@ -435,6 +430,8 @@ def summarize_nvs_eval_results(
|
||||
Args:
|
||||
per_batch_eval_results: Metrics of each per-batch evaluation.
|
||||
task: The type of the new-view synthesis task.
|
||||
camera_difficulty_bin_breaks: edge hard-medium and medium-easy
|
||||
|
||||
|
||||
Returns:
|
||||
nvs_results_flat: A flattened dict of all aggregate metrics.
|
||||
@@ -461,7 +458,9 @@ def summarize_nvs_eval_results(
|
||||
# init the result database dict
|
||||
results = []
|
||||
|
||||
diff_bin_edges, diff_bin_names = _get_camera_difficulty_bin_edges(task)
|
||||
diff_bin_edges, diff_bin_names = _get_camera_difficulty_bin_edges(
|
||||
camera_difficulty_bin_breaks
|
||||
)
|
||||
n_diff_edges = diff_bin_edges.numel()
|
||||
|
||||
# add per set averages
|
||||
|
||||
Reference in New Issue
Block a user