mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-14 19:36:23 +08:00
extract camera_difficulty_bin_breaks
Summary: As part of removing Task, move camera difficulty bin breaks from hard code to the top level. Reviewed By: davnov134 Differential Revision: D37491040 fbshipit-source-id: f2d6775ebc490f6f75020d13f37f6b588cc07a0b
This commit is contained in:
committed by
Facebook GitHub Bot
parent
40fb189c29
commit
efb721320a
@@ -27,3 +27,6 @@ solver_args:
|
||||
max_epochs: 3000
|
||||
milestones:
|
||||
- 1000
|
||||
camera_difficulty_bin_breaks:
|
||||
- 0.666667
|
||||
- 0.833334
|
||||
|
||||
@@ -535,7 +535,12 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
|
||||
and epoch % cfg.test_interval == 0
|
||||
):
|
||||
_run_eval(
|
||||
model, all_train_cameras, dataloaders.test, task, device=device
|
||||
model,
|
||||
all_train_cameras,
|
||||
dataloaders.test,
|
||||
task,
|
||||
camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
|
||||
device=device,
|
||||
)
|
||||
|
||||
assert stats.epoch == epoch, "inconsistent stats!"
|
||||
@@ -588,7 +593,14 @@ def _eval_and_dump(
|
||||
if dataloader is None:
|
||||
raise ValueError('DataLoaderMap have to contain the "test" entry for eval!')
|
||||
|
||||
results = _run_eval(model, all_train_cameras, dataloader, task, device=device)
|
||||
results = _run_eval(
|
||||
model,
|
||||
all_train_cameras,
|
||||
dataloader,
|
||||
task,
|
||||
camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# add the evaluation epoch to the results
|
||||
for r in results:
|
||||
@@ -615,7 +627,14 @@ def _get_eval_frame_data(frame_data):
|
||||
return frame_data_for_eval
|
||||
|
||||
|
||||
def _run_eval(model, all_train_cameras, loader, task: Task, device):
|
||||
def _run_eval(
|
||||
model,
|
||||
all_train_cameras,
|
||||
loader,
|
||||
task: Task,
|
||||
camera_difficulty_bin_breaks: Tuple[float, float],
|
||||
device,
|
||||
):
|
||||
"""
|
||||
Run the evaluation loop on the test dataloader
|
||||
"""
|
||||
@@ -648,7 +667,7 @@ def _run_eval(model, all_train_cameras, loader, task: Task, device):
|
||||
)
|
||||
|
||||
_, category_result = evaluate.summarize_nvs_eval_results(
|
||||
per_batch_eval_results, task
|
||||
per_batch_eval_results, task, camera_difficulty_bin_breaks
|
||||
)
|
||||
|
||||
return category_result["results"]
|
||||
@@ -684,6 +703,7 @@ class ExperimentConfig(Configurable):
|
||||
visdom_server: str = "http://127.0.0.1"
|
||||
visualize_interval: int = 1000
|
||||
clip_grad: float = 0.0
|
||||
camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
|
||||
|
||||
hydra: dict = field(
|
||||
default_factory=lambda: {
|
||||
|
||||
@@ -375,6 +375,9 @@ visdom_port: 8097
|
||||
visdom_server: http://127.0.0.1
|
||||
visualize_interval: 1000
|
||||
clip_grad: 0.0
|
||||
camera_difficulty_bin_breaks:
|
||||
- 0.97
|
||||
- 0.98
|
||||
hydra:
|
||||
run:
|
||||
dir: .
|
||||
|
||||
Reference in New Issue
Block a user