mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
extract camera_difficulty_bin_breaks
Summary: As part of removing Task, move camera difficulty bin breaks from hard code to the top level. Reviewed By: davnov134 Differential Revision: D37491040 fbshipit-source-id: f2d6775ebc490f6f75020d13f37f6b588cc07a0b
This commit is contained in:
parent
40fb189c29
commit
efb721320a
@ -27,3 +27,6 @@ solver_args:
|
|||||||
max_epochs: 3000
|
max_epochs: 3000
|
||||||
milestones:
|
milestones:
|
||||||
- 1000
|
- 1000
|
||||||
|
camera_difficulty_bin_breaks:
|
||||||
|
- 0.666667
|
||||||
|
- 0.833334
|
||||||
|
@ -535,7 +535,12 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
|
|||||||
and epoch % cfg.test_interval == 0
|
and epoch % cfg.test_interval == 0
|
||||||
):
|
):
|
||||||
_run_eval(
|
_run_eval(
|
||||||
model, all_train_cameras, dataloaders.test, task, device=device
|
model,
|
||||||
|
all_train_cameras,
|
||||||
|
dataloaders.test,
|
||||||
|
task,
|
||||||
|
camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert stats.epoch == epoch, "inconsistent stats!"
|
assert stats.epoch == epoch, "inconsistent stats!"
|
||||||
@ -588,7 +593,14 @@ def _eval_and_dump(
|
|||||||
if dataloader is None:
|
if dataloader is None:
|
||||||
raise ValueError('DataLoaderMap have to contain the "test" entry for eval!')
|
raise ValueError('DataLoaderMap have to contain the "test" entry for eval!')
|
||||||
|
|
||||||
results = _run_eval(model, all_train_cameras, dataloader, task, device=device)
|
results = _run_eval(
|
||||||
|
model,
|
||||||
|
all_train_cameras,
|
||||||
|
dataloader,
|
||||||
|
task,
|
||||||
|
camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
# add the evaluation epoch to the results
|
# add the evaluation epoch to the results
|
||||||
for r in results:
|
for r in results:
|
||||||
@ -615,7 +627,14 @@ def _get_eval_frame_data(frame_data):
|
|||||||
return frame_data_for_eval
|
return frame_data_for_eval
|
||||||
|
|
||||||
|
|
||||||
def _run_eval(model, all_train_cameras, loader, task: Task, device):
|
def _run_eval(
|
||||||
|
model,
|
||||||
|
all_train_cameras,
|
||||||
|
loader,
|
||||||
|
task: Task,
|
||||||
|
camera_difficulty_bin_breaks: Tuple[float, float],
|
||||||
|
device,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Run the evaluation loop on the test dataloader
|
Run the evaluation loop on the test dataloader
|
||||||
"""
|
"""
|
||||||
@ -648,7 +667,7 @@ def _run_eval(model, all_train_cameras, loader, task: Task, device):
|
|||||||
)
|
)
|
||||||
|
|
||||||
_, category_result = evaluate.summarize_nvs_eval_results(
|
_, category_result = evaluate.summarize_nvs_eval_results(
|
||||||
per_batch_eval_results, task
|
per_batch_eval_results, task, camera_difficulty_bin_breaks
|
||||||
)
|
)
|
||||||
|
|
||||||
return category_result["results"]
|
return category_result["results"]
|
||||||
@ -684,6 +703,7 @@ class ExperimentConfig(Configurable):
|
|||||||
visdom_server: str = "http://127.0.0.1"
|
visdom_server: str = "http://127.0.0.1"
|
||||||
visualize_interval: int = 1000
|
visualize_interval: int = 1000
|
||||||
clip_grad: float = 0.0
|
clip_grad: float = 0.0
|
||||||
|
camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
|
||||||
|
|
||||||
hydra: dict = field(
|
hydra: dict = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
|
@ -375,6 +375,9 @@ visdom_port: 8097
|
|||||||
visdom_server: http://127.0.0.1
|
visdom_server: http://127.0.0.1
|
||||||
visualize_interval: 1000
|
visualize_interval: 1000
|
||||||
clip_grad: 0.0
|
clip_grad: 0.0
|
||||||
|
camera_difficulty_bin_breaks:
|
||||||
|
- 0.97
|
||||||
|
- 0.98
|
||||||
hydra:
|
hydra:
|
||||||
run:
|
run:
|
||||||
dir: .
|
dir: .
|
||||||
|
@ -153,8 +153,13 @@ def evaluate_dbir_for_category(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if task == Task.SINGLE_SEQUENCE:
|
||||||
|
camera_difficulty_bin_breaks = 0.97, 0.98
|
||||||
|
else:
|
||||||
|
camera_difficulty_bin_breaks = 2.0 / 3, 5.0 / 6
|
||||||
|
|
||||||
category_result_flat, category_result = summarize_nvs_eval_results(
|
category_result_flat, category_result = summarize_nvs_eval_results(
|
||||||
per_batch_eval_results, task
|
per_batch_eval_results, task, camera_difficulty_bin_breaks
|
||||||
)
|
)
|
||||||
|
|
||||||
return category_result["results"]
|
return category_result["results"]
|
||||||
|
@ -9,7 +9,7 @@ import copy
|
|||||||
import warnings
|
import warnings
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -407,19 +407,13 @@ def _reduce_camera_iou_overlap(ious: torch.Tensor, topk: int = 2) -> torch.Tenso
|
|||||||
return ious.topk(k=min(topk, len(ious) - 1)).values.mean()
|
return ious.topk(k=min(topk, len(ious) - 1)).values.mean()
|
||||||
|
|
||||||
|
|
||||||
def _get_camera_difficulty_bin_edges(task: Task):
|
def _get_camera_difficulty_bin_edges(camera_difficulty_bin_breaks: Tuple[float, float]):
|
||||||
"""
|
"""
|
||||||
Get the edges of camera difficulty bins.
|
Get the edges of camera difficulty bins.
|
||||||
"""
|
"""
|
||||||
_eps = 1e-5
|
_eps = 1e-5
|
||||||
if task == Task.MULTI_SEQUENCE:
|
lower, upper = camera_difficulty_bin_breaks
|
||||||
# TODO: extract those to constants
|
diff_bin_edges = torch.tensor([0.0 - _eps, lower, upper, 1.0 + _eps]).float()
|
||||||
diff_bin_edges = torch.linspace(0.5, 1.0 + _eps, 4)
|
|
||||||
diff_bin_edges[0] = 0.0 - _eps
|
|
||||||
elif task == Task.SINGLE_SEQUENCE:
|
|
||||||
diff_bin_edges = torch.tensor([0.0 - _eps, 0.97, 0.98, 1.0 + _eps]).float()
|
|
||||||
else:
|
|
||||||
raise ValueError(f"No such eval task {task}.")
|
|
||||||
diff_bin_names = ["hard", "medium", "easy"]
|
diff_bin_names = ["hard", "medium", "easy"]
|
||||||
return diff_bin_edges, diff_bin_names
|
return diff_bin_edges, diff_bin_names
|
||||||
|
|
||||||
@ -427,6 +421,7 @@ def _get_camera_difficulty_bin_edges(task: Task):
|
|||||||
def summarize_nvs_eval_results(
|
def summarize_nvs_eval_results(
|
||||||
per_batch_eval_results: List[Dict[str, Any]],
|
per_batch_eval_results: List[Dict[str, Any]],
|
||||||
task: Task,
|
task: Task,
|
||||||
|
camera_difficulty_bin_breaks: Tuple[float, float] = (0.97, 0.98),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Compile the per-batch evaluation results `per_batch_eval_results` into
|
Compile the per-batch evaluation results `per_batch_eval_results` into
|
||||||
@ -435,6 +430,8 @@ def summarize_nvs_eval_results(
|
|||||||
Args:
|
Args:
|
||||||
per_batch_eval_results: Metrics of each per-batch evaluation.
|
per_batch_eval_results: Metrics of each per-batch evaluation.
|
||||||
task: The type of the new-view synthesis task.
|
task: The type of the new-view synthesis task.
|
||||||
|
camera_difficulty_bin_breaks: edge hard-medium and medium-easy
|
||||||
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
nvs_results_flat: A flattened dict of all aggregate metrics.
|
nvs_results_flat: A flattened dict of all aggregate metrics.
|
||||||
@ -461,7 +458,9 @@ def summarize_nvs_eval_results(
|
|||||||
# init the result database dict
|
# init the result database dict
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
diff_bin_edges, diff_bin_names = _get_camera_difficulty_bin_edges(task)
|
diff_bin_edges, diff_bin_names = _get_camera_difficulty_bin_edges(
|
||||||
|
camera_difficulty_bin_breaks
|
||||||
|
)
|
||||||
n_diff_edges = diff_bin_edges.numel()
|
n_diff_edges = diff_bin_edges.numel()
|
||||||
|
|
||||||
# add per set averages
|
# add per set averages
|
||||||
|
Loading…
x
Reference in New Issue
Block a user