extract camera_difficulty_bin_breaks

Summary: As part of removing Task, move camera difficulty bin breaks from hard code to the top level.

Reviewed By: davnov134

Differential Revision: D37491040

fbshipit-source-id: f2d6775ebc490f6f75020d13f37f6b588cc07a0b
This commit is contained in:
Jeremy Reizenstein
2022-07-06 07:13:41 -07:00
committed by Facebook GitHub Bot
parent 40fb189c29
commit efb721320a
5 changed files with 46 additions and 16 deletions

View File

@@ -27,3 +27,6 @@ solver_args:
max_epochs: 3000
milestones:
- 1000
camera_difficulty_bin_breaks:
- 0.666667
- 0.833334

View File

@@ -535,7 +535,12 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
and epoch % cfg.test_interval == 0
):
_run_eval(
model, all_train_cameras, dataloaders.test, task, device=device
model,
all_train_cameras,
dataloaders.test,
task,
camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
device=device,
)
assert stats.epoch == epoch, "inconsistent stats!"
@@ -588,7 +593,14 @@ def _eval_and_dump(
if dataloader is None:
raise ValueError('DataLoaderMap have to contain the "test" entry for eval!')
results = _run_eval(model, all_train_cameras, dataloader, task, device=device)
results = _run_eval(
model,
all_train_cameras,
dataloader,
task,
camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
device=device,
)
# add the evaluation epoch to the results
for r in results:
@@ -615,7 +627,14 @@ def _get_eval_frame_data(frame_data):
return frame_data_for_eval
def _run_eval(model, all_train_cameras, loader, task: Task, device):
def _run_eval(
model,
all_train_cameras,
loader,
task: Task,
camera_difficulty_bin_breaks: Tuple[float, float],
device,
):
"""
Run the evaluation loop on the test dataloader
"""
@@ -648,7 +667,7 @@ def _run_eval(model, all_train_cameras, loader, task: Task, device):
)
_, category_result = evaluate.summarize_nvs_eval_results(
per_batch_eval_results, task
per_batch_eval_results, task, camera_difficulty_bin_breaks
)
return category_result["results"]
@@ -684,6 +703,7 @@ class ExperimentConfig(Configurable):
visdom_server: str = "http://127.0.0.1"
visualize_interval: int = 1000
clip_grad: float = 0.0
camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
hydra: dict = field(
default_factory=lambda: {

View File

@@ -375,6 +375,9 @@ visdom_port: 8097
visdom_server: http://127.0.0.1
visualize_interval: 1000
clip_grad: 0.0
camera_difficulty_bin_breaks:
- 0.97
- 0.98
hydra:
run:
dir: .