mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	extract camera_difficulty_bin_breaks
Summary: As part of removing Task, move the camera difficulty bin breaks from hard-coded values to a top-level configuration option. Reviewed By: davnov134 Differential Revision: D37491040 fbshipit-source-id: f2d6775ebc490f6f75020d13f37f6b588cc07a0b
This commit is contained in:
		
							parent
							
								
									40fb189c29
								
							
						
					
					
						commit
						efb721320a
					
				@ -27,3 +27,6 @@ solver_args:
 | 
			
		||||
  max_epochs: 3000
 | 
			
		||||
  milestones:
 | 
			
		||||
  - 1000
 | 
			
		||||
camera_difficulty_bin_breaks:
 | 
			
		||||
  - 0.666667
 | 
			
		||||
  - 0.833334
 | 
			
		||||
 | 
			
		||||
@ -535,7 +535,12 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
 | 
			
		||||
                and epoch % cfg.test_interval == 0
 | 
			
		||||
            ):
 | 
			
		||||
                _run_eval(
 | 
			
		||||
                    model, all_train_cameras, dataloaders.test, task, device=device
 | 
			
		||||
                    model,
 | 
			
		||||
                    all_train_cameras,
 | 
			
		||||
                    dataloaders.test,
 | 
			
		||||
                    task,
 | 
			
		||||
                    camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
 | 
			
		||||
                    device=device,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            assert stats.epoch == epoch, "inconsistent stats!"
 | 
			
		||||
@ -588,7 +593,14 @@ def _eval_and_dump(
 | 
			
		||||
    if dataloader is None:
 | 
			
		||||
        raise ValueError('DataLoaderMap have to contain the "test" entry for eval!')
 | 
			
		||||
 | 
			
		||||
    results = _run_eval(model, all_train_cameras, dataloader, task, device=device)
 | 
			
		||||
    results = _run_eval(
 | 
			
		||||
        model,
 | 
			
		||||
        all_train_cameras,
 | 
			
		||||
        dataloader,
 | 
			
		||||
        task,
 | 
			
		||||
        camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
 | 
			
		||||
        device=device,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # add the evaluation epoch to the results
 | 
			
		||||
    for r in results:
 | 
			
		||||
@ -615,7 +627,14 @@ def _get_eval_frame_data(frame_data):
 | 
			
		||||
    return frame_data_for_eval
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _run_eval(model, all_train_cameras, loader, task: Task, device):
 | 
			
		||||
def _run_eval(
 | 
			
		||||
    model,
 | 
			
		||||
    all_train_cameras,
 | 
			
		||||
    loader,
 | 
			
		||||
    task: Task,
 | 
			
		||||
    camera_difficulty_bin_breaks: Tuple[float, float],
 | 
			
		||||
    device,
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    Run the evaluation loop on the test dataloader
 | 
			
		||||
    """
 | 
			
		||||
@ -648,7 +667,7 @@ def _run_eval(model, all_train_cameras, loader, task: Task, device):
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    _, category_result = evaluate.summarize_nvs_eval_results(
 | 
			
		||||
        per_batch_eval_results, task
 | 
			
		||||
        per_batch_eval_results, task, camera_difficulty_bin_breaks
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    return category_result["results"]
 | 
			
		||||
@ -684,6 +703,7 @@ class ExperimentConfig(Configurable):
 | 
			
		||||
    visdom_server: str = "http://127.0.0.1"
 | 
			
		||||
    visualize_interval: int = 1000
 | 
			
		||||
    clip_grad: float = 0.0
 | 
			
		||||
    camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
 | 
			
		||||
 | 
			
		||||
    hydra: dict = field(
 | 
			
		||||
        default_factory=lambda: {
 | 
			
		||||
 | 
			
		||||
@ -375,6 +375,9 @@ visdom_port: 8097
 | 
			
		||||
visdom_server: http://127.0.0.1
 | 
			
		||||
visualize_interval: 1000
 | 
			
		||||
clip_grad: 0.0
 | 
			
		||||
camera_difficulty_bin_breaks:
 | 
			
		||||
- 0.97
 | 
			
		||||
- 0.98
 | 
			
		||||
hydra:
 | 
			
		||||
  run:
 | 
			
		||||
    dir: .
 | 
			
		||||
 | 
			
		||||
@ -153,8 +153,13 @@ def evaluate_dbir_for_category(
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    if task == Task.SINGLE_SEQUENCE:
 | 
			
		||||
        camera_difficulty_bin_breaks = 0.97, 0.98
 | 
			
		||||
    else:
 | 
			
		||||
        camera_difficulty_bin_breaks = 2.0 / 3, 5.0 / 6
 | 
			
		||||
 | 
			
		||||
    category_result_flat, category_result = summarize_nvs_eval_results(
 | 
			
		||||
        per_batch_eval_results, task
 | 
			
		||||
        per_batch_eval_results, task, camera_difficulty_bin_breaks
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    return category_result["results"]
 | 
			
		||||
 | 
			
		||||
@ -9,7 +9,7 @@ import copy
 | 
			
		||||
import warnings
 | 
			
		||||
from collections import OrderedDict
 | 
			
		||||
from dataclasses import dataclass, field
 | 
			
		||||
from typing import Any, Dict, List, Optional, Sequence, Union
 | 
			
		||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
@ -407,19 +407,13 @@ def _reduce_camera_iou_overlap(ious: torch.Tensor, topk: int = 2) -> torch.Tenso
 | 
			
		||||
    return ious.topk(k=min(topk, len(ious) - 1)).values.mean()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_camera_difficulty_bin_edges(task: Task):
 | 
			
		||||
def _get_camera_difficulty_bin_edges(camera_difficulty_bin_breaks: Tuple[float, float]):
 | 
			
		||||
    """
 | 
			
		||||
    Get the edges of camera difficulty bins.
 | 
			
		||||
    """
 | 
			
		||||
    _eps = 1e-5
 | 
			
		||||
    if task == Task.MULTI_SEQUENCE:
 | 
			
		||||
        # TODO: extract those to constants
 | 
			
		||||
        diff_bin_edges = torch.linspace(0.5, 1.0 + _eps, 4)
 | 
			
		||||
        diff_bin_edges[0] = 0.0 - _eps
 | 
			
		||||
    elif task == Task.SINGLE_SEQUENCE:
 | 
			
		||||
        diff_bin_edges = torch.tensor([0.0 - _eps, 0.97, 0.98, 1.0 + _eps]).float()
 | 
			
		||||
    else:
 | 
			
		||||
        raise ValueError(f"No such eval task {task}.")
 | 
			
		||||
    lower, upper = camera_difficulty_bin_breaks
 | 
			
		||||
    diff_bin_edges = torch.tensor([0.0 - _eps, lower, upper, 1.0 + _eps]).float()
 | 
			
		||||
    diff_bin_names = ["hard", "medium", "easy"]
 | 
			
		||||
    return diff_bin_edges, diff_bin_names
 | 
			
		||||
 | 
			
		||||
@ -427,6 +421,7 @@ def _get_camera_difficulty_bin_edges(task: Task):
 | 
			
		||||
def summarize_nvs_eval_results(
 | 
			
		||||
    per_batch_eval_results: List[Dict[str, Any]],
 | 
			
		||||
    task: Task,
 | 
			
		||||
    camera_difficulty_bin_breaks: Tuple[float, float] = (0.97, 0.98),
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    Compile the per-batch evaluation results `per_batch_eval_results` into
 | 
			
		||||
@ -435,6 +430,8 @@ def summarize_nvs_eval_results(
 | 
			
		||||
    Args:
 | 
			
		||||
        per_batch_eval_results: Metrics of each per-batch evaluation.
 | 
			
		||||
        task: The type of the new-view synthesis task.
 | 
			
		||||
        camera_difficulty_bin_breaks: The two thresholds between the hard/medium and medium/easy bins.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        nvs_results_flat: A flattened dict of all aggregate metrics.
 | 
			
		||||
@ -461,7 +458,9 @@ def summarize_nvs_eval_results(
 | 
			
		||||
    # init the result database dict
 | 
			
		||||
    results = []
 | 
			
		||||
 | 
			
		||||
    diff_bin_edges, diff_bin_names = _get_camera_difficulty_bin_edges(task)
 | 
			
		||||
    diff_bin_edges, diff_bin_names = _get_camera_difficulty_bin_edges(
 | 
			
		||||
        camera_difficulty_bin_breaks
 | 
			
		||||
    )
 | 
			
		||||
    n_diff_edges = diff_bin_edges.numel()
 | 
			
		||||
 | 
			
		||||
    # add per set averages
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user