mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	remove get_task
Summary: Remove the dataset's need to provide the task type. Reviewed By: davnov134, kjchalup Differential Revision: D38314000 fbshipit-source-id: 3805d885b5d4528abdc78c0da03247edb9abf3f7
This commit is contained in:
		
							parent
							
								
									37250a4326
								
							
						
					
					
						commit
						f8bf528043
					
				@ -35,3 +35,4 @@ training_loop_ImplicitronTrainingLoop_args:
 | 
			
		||||
    camera_difficulty_bin_breaks:
 | 
			
		||||
      - 0.666667
 | 
			
		||||
      - 0.833334
 | 
			
		||||
    is_multisequence: true
 | 
			
		||||
 | 
			
		||||
@ -206,7 +206,6 @@ class Experiment(Configurable):  # pyre-ignore: 13
 | 
			
		||||
                val_loader,
 | 
			
		||||
            ) = accelerator.prepare(model, optimizer, train_loader, val_loader)
 | 
			
		||||
 | 
			
		||||
        task = self.data_source.get_task()
 | 
			
		||||
        all_train_cameras = self.data_source.all_train_cameras
 | 
			
		||||
 | 
			
		||||
        # Enter the main training loop.
 | 
			
		||||
@ -223,7 +222,6 @@ class Experiment(Configurable):  # pyre-ignore: 13
 | 
			
		||||
            exp_dir=self.exp_dir,
 | 
			
		||||
            stats=stats,
 | 
			
		||||
            seed=self.seed,
 | 
			
		||||
            task=task,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _check_config_consistent(self) -> None:
 | 
			
		||||
 | 
			
		||||
@ -10,7 +10,6 @@ from typing import Any, Optional
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from accelerate import Accelerator
 | 
			
		||||
from pytorch3d.implicitron.dataset.data_source import Task
 | 
			
		||||
from pytorch3d.implicitron.evaluation.evaluator import EvaluatorBase
 | 
			
		||||
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
 | 
			
		||||
from pytorch3d.implicitron.models.generic_model import EvaluationMode
 | 
			
		||||
@ -101,7 +100,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
 | 
			
		||||
        exp_dir: str,
 | 
			
		||||
        stats: Stats,
 | 
			
		||||
        seed: int,
 | 
			
		||||
        task: Task,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ):
 | 
			
		||||
        """
 | 
			
		||||
@ -123,7 +121,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
 | 
			
		||||
                    epoch=stats.epoch,
 | 
			
		||||
                    exp_dir=exp_dir,
 | 
			
		||||
                    model=model,
 | 
			
		||||
                    task=task,
 | 
			
		||||
                )
 | 
			
		||||
                return
 | 
			
		||||
            else:
 | 
			
		||||
@ -179,7 +176,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
 | 
			
		||||
                        device=device,
 | 
			
		||||
                        dataloader=test_loader,
 | 
			
		||||
                        model=model,
 | 
			
		||||
                        task=task,
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
                assert stats.epoch == epoch, "inconsistent stats!"
 | 
			
		||||
@ -200,7 +196,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
 | 
			
		||||
                    exp_dir=exp_dir,
 | 
			
		||||
                    dataloader=test_loader,
 | 
			
		||||
                    model=model,
 | 
			
		||||
                    task=task,
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
 | 
			
		||||
@ -435,3 +435,4 @@ training_loop_ImplicitronTrainingLoop_args:
 | 
			
		||||
    camera_difficulty_bin_breaks:
 | 
			
		||||
    - 0.97
 | 
			
		||||
    - 0.98
 | 
			
		||||
    is_multisequence: false
 | 
			
		||||
 | 
			
		||||
@ -15,7 +15,7 @@ from pytorch3d.renderer.cameras import CamerasBase
 | 
			
		||||
 | 
			
		||||
from .blender_dataset_map_provider import BlenderDatasetMapProvider  # noqa
 | 
			
		||||
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
 | 
			
		||||
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
 | 
			
		||||
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase
 | 
			
		||||
from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider  # noqa
 | 
			
		||||
from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2  # noqa
 | 
			
		||||
from .llff_dataset_map_provider import LlffDatasetMapProvider  # noqa
 | 
			
		||||
@ -41,9 +41,6 @@ class DataSourceBase(ReplaceableBase):
 | 
			
		||||
        """
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
    def get_task(self) -> Task:
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@registry.register
 | 
			
		||||
class ImplicitronDataSource(DataSourceBase):  # pyre-ignore[13]
 | 
			
		||||
@ -71,9 +68,6 @@ class ImplicitronDataSource(DataSourceBase):  # pyre-ignore[13]
 | 
			
		||||
        dataloaders = self.data_loader_map_provider.get_data_loader_map(datasets)
 | 
			
		||||
        return datasets, dataloaders
 | 
			
		||||
 | 
			
		||||
    def get_task(self) -> Task:
 | 
			
		||||
        return self.dataset_map_provider.get_task()
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def all_train_cameras(self) -> Optional[CamerasBase]:
 | 
			
		||||
        if self._all_train_cameras_cache is None:  # pyre-ignore[16]
 | 
			
		||||
 | 
			
		||||
@ -7,7 +7,6 @@
 | 
			
		||||
import logging
 | 
			
		||||
import os
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from enum import Enum
 | 
			
		||||
from typing import Iterator, Optional
 | 
			
		||||
 | 
			
		||||
from iopath.common.file_io import PathManager
 | 
			
		||||
@ -53,11 +52,6 @@ class DatasetMap:
 | 
			
		||||
            yield self.test
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Task(Enum):
 | 
			
		||||
    SINGLE_SEQUENCE = "singlesequence"
 | 
			
		||||
    MULTI_SEQUENCE = "multisequence"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DatasetMapProviderBase(ReplaceableBase):
 | 
			
		||||
    """
 | 
			
		||||
    Base class for a provider of training / validation and testing
 | 
			
		||||
@ -71,9 +65,6 @@ class DatasetMapProviderBase(ReplaceableBase):
 | 
			
		||||
        """
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
    def get_task(self) -> Task:
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
    def get_all_train_cameras(self) -> Optional[CamerasBase]:
 | 
			
		||||
        """
 | 
			
		||||
        If the data is all for a single scene, returns a list
 | 
			
		||||
 | 
			
		||||
@ -17,12 +17,7 @@ from pytorch3d.implicitron.tools.config import (
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.renderer.cameras import CamerasBase
 | 
			
		||||
 | 
			
		||||
from .dataset_map_provider import (
 | 
			
		||||
    DatasetMap,
 | 
			
		||||
    DatasetMapProviderBase,
 | 
			
		||||
    PathManagerFactory,
 | 
			
		||||
    Task,
 | 
			
		||||
)
 | 
			
		||||
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
 | 
			
		||||
from .json_index_dataset import JsonIndexDataset
 | 
			
		||||
 | 
			
		||||
from .utils import (
 | 
			
		||||
@ -160,7 +155,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
 | 
			
		||||
        # This maps the common names of the dataset subsets ("train"/"val"/"test")
 | 
			
		||||
        # to the names of the subsets in the CO3D dataset.
 | 
			
		||||
        set_names_mapping = _get_co3d_set_names_mapping(
 | 
			
		||||
            self.get_task(),
 | 
			
		||||
            self.task_str,
 | 
			
		||||
            self.test_on_train,
 | 
			
		||||
            self.only_test_set,
 | 
			
		||||
        )
 | 
			
		||||
@ -185,7 +180,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
 | 
			
		||||
            eval_batch_index = json.load(f)
 | 
			
		||||
        restrict_sequence_name = self.restrict_sequence_name
 | 
			
		||||
 | 
			
		||||
        if self.get_task() == Task.SINGLE_SEQUENCE:
 | 
			
		||||
        if self.task_str == "singlesequence":
 | 
			
		||||
            if (
 | 
			
		||||
                self.test_restrict_sequence_id is None
 | 
			
		||||
                or self.test_restrict_sequence_id < 0
 | 
			
		||||
@ -267,13 +262,12 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        return self.dataset_map
 | 
			
		||||
 | 
			
		||||
    def get_task(self) -> Task:
 | 
			
		||||
        return Task(self.task_str)
 | 
			
		||||
 | 
			
		||||
    def get_all_train_cameras(self) -> Optional[CamerasBase]:
 | 
			
		||||
        if Task(self.task_str) == Task.MULTI_SEQUENCE:
 | 
			
		||||
        if self.task_str == "multisequence":
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
        assert self.task_str == "singlesequence"
 | 
			
		||||
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        train_dataset = self.dataset_map.train
 | 
			
		||||
        assert isinstance(train_dataset, JsonIndexDataset)
 | 
			
		||||
@ -281,7 +275,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_co3d_set_names_mapping(
 | 
			
		||||
    task: Task,
 | 
			
		||||
    task_str: str,
 | 
			
		||||
    test_on_train: bool,
 | 
			
		||||
    only_test: bool,
 | 
			
		||||
) -> Dict[str, List[str]]:
 | 
			
		||||
@ -295,7 +289,7 @@ def _get_co3d_set_names_mapping(
 | 
			
		||||
        - val (if not test_on_train)
 | 
			
		||||
        - test (if not test_on_train)
 | 
			
		||||
    """
 | 
			
		||||
    single_seq = task == Task.SINGLE_SEQUENCE
 | 
			
		||||
    single_seq = task_str == "singlesequence"
 | 
			
		||||
 | 
			
		||||
    if only_test:
 | 
			
		||||
        set_names_mapping = {}
 | 
			
		||||
 | 
			
		||||
@ -16,7 +16,6 @@ from pytorch3d.implicitron.dataset.dataset_map_provider import (
 | 
			
		||||
    DatasetMap,
 | 
			
		||||
    DatasetMapProviderBase,
 | 
			
		||||
    PathManagerFactory,
 | 
			
		||||
    Task,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
 | 
			
		||||
from pytorch3d.implicitron.tools.config import (
 | 
			
		||||
@ -335,12 +334,6 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
 | 
			
		||||
        )
 | 
			
		||||
        return category_to_subset_name_list
 | 
			
		||||
 | 
			
		||||
    def get_task(self) -> Task:  # TODO: we plan to get rid of tasks
 | 
			
		||||
        return {
 | 
			
		||||
            "manyview": Task.SINGLE_SEQUENCE,
 | 
			
		||||
            "fewview": Task.MULTI_SEQUENCE,
 | 
			
		||||
        }[self.subset_name.split("_")[0]]
 | 
			
		||||
 | 
			
		||||
    def get_all_train_cameras(self) -> Optional[CamerasBase]:
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        train_dataset = self.dataset_map.train
 | 
			
		||||
 | 
			
		||||
@ -28,12 +28,7 @@ from pytorch3d.renderer import (
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.structures.meshes import Meshes
 | 
			
		||||
 | 
			
		||||
from .dataset_map_provider import (
 | 
			
		||||
    DatasetMap,
 | 
			
		||||
    DatasetMapProviderBase,
 | 
			
		||||
    PathManagerFactory,
 | 
			
		||||
    Task,
 | 
			
		||||
)
 | 
			
		||||
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
 | 
			
		||||
from .single_sequence_dataset import SingleSceneDataset
 | 
			
		||||
from .utils import DATASET_TYPE_KNOWN
 | 
			
		||||
 | 
			
		||||
@ -83,9 +78,6 @@ class RenderedMeshDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        return DatasetMap(train=self.train_dataset, val=None, test=None)
 | 
			
		||||
 | 
			
		||||
    def get_task(self) -> Task:
 | 
			
		||||
        return Task.SINGLE_SEQUENCE
 | 
			
		||||
 | 
			
		||||
    def get_all_train_cameras(self) -> CamerasBase:
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        return self.poses
 | 
			
		||||
 | 
			
		||||
@ -21,12 +21,7 @@ from pytorch3d.implicitron.tools.config import (
 | 
			
		||||
from pytorch3d.renderer import CamerasBase, join_cameras_as_batch, PerspectiveCameras
 | 
			
		||||
 | 
			
		||||
from .dataset_base import DatasetBase, FrameData
 | 
			
		||||
from .dataset_map_provider import (
 | 
			
		||||
    DatasetMap,
 | 
			
		||||
    DatasetMapProviderBase,
 | 
			
		||||
    PathManagerFactory,
 | 
			
		||||
    Task,
 | 
			
		||||
)
 | 
			
		||||
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
 | 
			
		||||
from .utils import DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN
 | 
			
		||||
 | 
			
		||||
_SINGLE_SEQUENCE_NAME: str = "one_sequence"
 | 
			
		||||
@ -159,9 +154,6 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
 | 
			
		||||
            test=self._get_dataset(2, DATASET_TYPE_UNKNOWN, True),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def get_task(self) -> Task:
 | 
			
		||||
        return Task.SINGLE_SEQUENCE
 | 
			
		||||
 | 
			
		||||
    def get_all_train_cameras(self) -> Optional[CamerasBase]:
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        cameras = [self.poses[i] for i in self.i_split[0]]
 | 
			
		||||
 | 
			
		||||
@ -7,11 +7,12 @@
 | 
			
		||||
 | 
			
		||||
import dataclasses
 | 
			
		||||
import os
 | 
			
		||||
from enum import Enum
 | 
			
		||||
from typing import Any, cast, Dict, List, Optional, Tuple
 | 
			
		||||
 | 
			
		||||
import lpips
 | 
			
		||||
import torch
 | 
			
		||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
 | 
			
		||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
 | 
			
		||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
 | 
			
		||||
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
 | 
			
		||||
    CO3D_CATEGORIES,
 | 
			
		||||
@ -27,6 +28,11 @@ from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_
 | 
			
		||||
from tqdm import tqdm
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Task(Enum):
 | 
			
		||||
    SINGLE_SEQUENCE = "singlesequence"
 | 
			
		||||
    MULTI_SEQUENCE = "multisequence"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main() -> None:
 | 
			
		||||
    """
 | 
			
		||||
    Evaluates new view synthesis metrics of a simple depth-based image rendering
 | 
			
		||||
@ -153,11 +159,15 @@ def evaluate_dbir_for_category(
 | 
			
		||||
 | 
			
		||||
    if task == Task.SINGLE_SEQUENCE:
 | 
			
		||||
        camera_difficulty_bin_breaks = 0.97, 0.98
 | 
			
		||||
        multisequence_evaluation = False
 | 
			
		||||
    else:
 | 
			
		||||
        camera_difficulty_bin_breaks = 2.0 / 3, 5.0 / 6
 | 
			
		||||
        multisequence_evaluation = True
 | 
			
		||||
 | 
			
		||||
    category_result_flat, category_result = summarize_nvs_eval_results(
 | 
			
		||||
        per_batch_eval_results, task, camera_difficulty_bin_breaks
 | 
			
		||||
        per_batch_eval_results,
 | 
			
		||||
        camera_difficulty_bin_breaks=camera_difficulty_bin_breaks,
 | 
			
		||||
        is_multisequence=multisequence_evaluation,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    return category_result["results"]
 | 
			
		||||
 | 
			
		||||
@ -14,7 +14,6 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from pytorch3d.implicitron.dataset.data_source import Task
 | 
			
		||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData
 | 
			
		||||
from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
 | 
			
		||||
from pytorch3d.implicitron.models.base_model import ImplicitronRender
 | 
			
		||||
@ -420,16 +419,16 @@ def _get_camera_difficulty_bin_edges(camera_difficulty_bin_breaks: Tuple[float,
 | 
			
		||||
 | 
			
		||||
def summarize_nvs_eval_results(
 | 
			
		||||
    per_batch_eval_results: List[Dict[str, Any]],
 | 
			
		||||
    task: Task,
 | 
			
		||||
    camera_difficulty_bin_breaks: Tuple[float, float] = (0.97, 0.98),
 | 
			
		||||
    is_multisequence: bool,
 | 
			
		||||
    camera_difficulty_bin_breaks: Tuple[float, float],
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    Compile the per-batch evaluation results `per_batch_eval_results` into
 | 
			
		||||
    a set of aggregate metrics. The produced metrics depend on the task.
 | 
			
		||||
    a set of aggregate metrics. The produced metrics depend on is_multisequence.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        per_batch_eval_results: Metrics of each per-batch evaluation.
 | 
			
		||||
        task: The type of the new-view synthesis task.
 | 
			
		||||
        is_multisequence: Whether to evaluate as a multisequence task
 | 
			
		||||
        camera_difficulty_bin_breaks: edge hard-medium and medium-easy
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -439,14 +438,9 @@ def summarize_nvs_eval_results(
 | 
			
		||||
    """
 | 
			
		||||
    n_batches = len(per_batch_eval_results)
 | 
			
		||||
    eval_sets: List[Optional[str]] = []
 | 
			
		||||
    if task == Task.SINGLE_SEQUENCE:
 | 
			
		||||
        eval_sets = [None]
 | 
			
		||||
        # assert n_batches==100
 | 
			
		||||
    elif task == Task.MULTI_SEQUENCE:
 | 
			
		||||
    eval_sets = [None]
 | 
			
		||||
    if is_multisequence:
 | 
			
		||||
        eval_sets = ["train", "test"]
 | 
			
		||||
        # assert n_batches==1000
 | 
			
		||||
    else:
 | 
			
		||||
        raise ValueError(task)
 | 
			
		||||
    batch_sizes = torch.tensor(
 | 
			
		||||
        [r["meta"]["batch_size"] for r in per_batch_eval_results]
 | 
			
		||||
    ).long()
 | 
			
		||||
@ -466,11 +460,9 @@ def summarize_nvs_eval_results(
 | 
			
		||||
    # add per set averages
 | 
			
		||||
    for SET in eval_sets:
 | 
			
		||||
        if SET is None:
 | 
			
		||||
            assert task == Task.SINGLE_SEQUENCE
 | 
			
		||||
            ok_set = torch.ones(n_batches, dtype=torch.bool)
 | 
			
		||||
            set_name = "test"
 | 
			
		||||
        else:
 | 
			
		||||
            assert task == Task.MULTI_SEQUENCE
 | 
			
		||||
            ok_set = is_train == int(SET == "train")
 | 
			
		||||
            set_name = SET
 | 
			
		||||
 | 
			
		||||
@ -495,7 +487,7 @@ def summarize_nvs_eval_results(
 | 
			
		||||
                }
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if task == Task.MULTI_SEQUENCE:
 | 
			
		||||
        if is_multisequence:
 | 
			
		||||
            # split based on n_src_views
 | 
			
		||||
            n_src_views = batch_sizes - 1
 | 
			
		||||
            for n_src in EVAL_N_SRC_VIEWS:
 | 
			
		||||
 | 
			
		||||
@ -16,7 +16,6 @@ import torch
 | 
			
		||||
import tqdm
 | 
			
		||||
 | 
			
		||||
from pytorch3d.implicitron.dataset import utils as ds_utils
 | 
			
		||||
from pytorch3d.implicitron.dataset.data_source import Task
 | 
			
		||||
 | 
			
		||||
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
 | 
			
		||||
from pytorch3d.implicitron.models.base_model import EvaluationMode, ImplicitronModelBase
 | 
			
		||||
@ -57,6 +56,7 @@ class ImplicitronEvaluator(EvaluatorBase):
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
 | 
			
		||||
    is_multisequence: bool = False
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        run_auto_creation(self)
 | 
			
		||||
@ -65,7 +65,6 @@ class ImplicitronEvaluator(EvaluatorBase):
 | 
			
		||||
        self,
 | 
			
		||||
        model: ImplicitronModelBase,
 | 
			
		||||
        dataloader: DataLoader,
 | 
			
		||||
        task: Task,
 | 
			
		||||
        all_train_cameras: Optional[CamerasBase],
 | 
			
		||||
        device: torch.device,
 | 
			
		||||
        dump_to_json: bool = False,
 | 
			
		||||
@ -80,7 +79,6 @@ class ImplicitronEvaluator(EvaluatorBase):
 | 
			
		||||
        Args:
 | 
			
		||||
            model: A (trained) model to evaluate.
 | 
			
		||||
            dataloader: A test dataloader.
 | 
			
		||||
            task: Type of the novel-view synthesis task we're working on.
 | 
			
		||||
            all_train_cameras: Camera instances we used for training.
 | 
			
		||||
            device: A torch device.
 | 
			
		||||
            dump_to_json: If True, will dump the results to a json file.
 | 
			
		||||
@ -122,7 +120,9 @@ class ImplicitronEvaluator(EvaluatorBase):
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        _, category_result = evaluate.summarize_nvs_eval_results(
 | 
			
		||||
            per_batch_eval_results, task, self.camera_difficulty_bin_breaks
 | 
			
		||||
            per_batch_eval_results,
 | 
			
		||||
            self.is_multisequence,
 | 
			
		||||
            self.camera_difficulty_bin_breaks,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        results = category_result["results"]
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user