mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
remove get_task
Summary: Remove the dataset's need to provide the task type. Reviewed By: davnov134, kjchalup Differential Revision: D38314000 fbshipit-source-id: 3805d885b5d4528abdc78c0da03247edb9abf3f7
This commit is contained in:
parent
37250a4326
commit
f8bf528043
@ -35,3 +35,4 @@ training_loop_ImplicitronTrainingLoop_args:
|
|||||||
camera_difficulty_bin_breaks:
|
camera_difficulty_bin_breaks:
|
||||||
- 0.666667
|
- 0.666667
|
||||||
- 0.833334
|
- 0.833334
|
||||||
|
is_multisequence: true
|
||||||
|
@ -206,7 +206,6 @@ class Experiment(Configurable): # pyre-ignore: 13
|
|||||||
val_loader,
|
val_loader,
|
||||||
) = accelerator.prepare(model, optimizer, train_loader, val_loader)
|
) = accelerator.prepare(model, optimizer, train_loader, val_loader)
|
||||||
|
|
||||||
task = self.data_source.get_task()
|
|
||||||
all_train_cameras = self.data_source.all_train_cameras
|
all_train_cameras = self.data_source.all_train_cameras
|
||||||
|
|
||||||
# Enter the main training loop.
|
# Enter the main training loop.
|
||||||
@ -223,7 +222,6 @@ class Experiment(Configurable): # pyre-ignore: 13
|
|||||||
exp_dir=self.exp_dir,
|
exp_dir=self.exp_dir,
|
||||||
stats=stats,
|
stats=stats,
|
||||||
seed=self.seed,
|
seed=self.seed,
|
||||||
task=task,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _check_config_consistent(self) -> None:
|
def _check_config_consistent(self) -> None:
|
||||||
|
@ -10,7 +10,6 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from pytorch3d.implicitron.dataset.data_source import Task
|
|
||||||
from pytorch3d.implicitron.evaluation.evaluator import EvaluatorBase
|
from pytorch3d.implicitron.evaluation.evaluator import EvaluatorBase
|
||||||
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
|
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
|
||||||
from pytorch3d.implicitron.models.generic_model import EvaluationMode
|
from pytorch3d.implicitron.models.generic_model import EvaluationMode
|
||||||
@ -101,7 +100,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
|
|||||||
exp_dir: str,
|
exp_dir: str,
|
||||||
stats: Stats,
|
stats: Stats,
|
||||||
seed: int,
|
seed: int,
|
||||||
task: Task,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -123,7 +121,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
|
|||||||
epoch=stats.epoch,
|
epoch=stats.epoch,
|
||||||
exp_dir=exp_dir,
|
exp_dir=exp_dir,
|
||||||
model=model,
|
model=model,
|
||||||
task=task,
|
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
@ -179,7 +176,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
|
|||||||
device=device,
|
device=device,
|
||||||
dataloader=test_loader,
|
dataloader=test_loader,
|
||||||
model=model,
|
model=model,
|
||||||
task=task,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert stats.epoch == epoch, "inconsistent stats!"
|
assert stats.epoch == epoch, "inconsistent stats!"
|
||||||
@ -200,7 +196,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
|
|||||||
exp_dir=exp_dir,
|
exp_dir=exp_dir,
|
||||||
dataloader=test_loader,
|
dataloader=test_loader,
|
||||||
model=model,
|
model=model,
|
||||||
task=task,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -435,3 +435,4 @@ training_loop_ImplicitronTrainingLoop_args:
|
|||||||
camera_difficulty_bin_breaks:
|
camera_difficulty_bin_breaks:
|
||||||
- 0.97
|
- 0.97
|
||||||
- 0.98
|
- 0.98
|
||||||
|
is_multisequence: false
|
||||||
|
@ -15,7 +15,7 @@ from pytorch3d.renderer.cameras import CamerasBase
|
|||||||
|
|
||||||
from .blender_dataset_map_provider import BlenderDatasetMapProvider # noqa
|
from .blender_dataset_map_provider import BlenderDatasetMapProvider # noqa
|
||||||
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
|
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
|
||||||
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
|
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase
|
||||||
from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa
|
from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa
|
||||||
from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2 # noqa
|
from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2 # noqa
|
||||||
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa
|
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa
|
||||||
@ -41,9 +41,6 @@ class DataSourceBase(ReplaceableBase):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def get_task(self) -> Task:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
|
class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
|
||||||
@ -71,9 +68,6 @@ class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
|
|||||||
dataloaders = self.data_loader_map_provider.get_data_loader_map(datasets)
|
dataloaders = self.data_loader_map_provider.get_data_loader_map(datasets)
|
||||||
return datasets, dataloaders
|
return datasets, dataloaders
|
||||||
|
|
||||||
def get_task(self) -> Task:
|
|
||||||
return self.dataset_map_provider.get_task()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def all_train_cameras(self) -> Optional[CamerasBase]:
|
def all_train_cameras(self) -> Optional[CamerasBase]:
|
||||||
if self._all_train_cameras_cache is None: # pyre-ignore[16]
|
if self._all_train_cameras_cache is None: # pyre-ignore[16]
|
||||||
|
@ -7,7 +7,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
|
||||||
from typing import Iterator, Optional
|
from typing import Iterator, Optional
|
||||||
|
|
||||||
from iopath.common.file_io import PathManager
|
from iopath.common.file_io import PathManager
|
||||||
@ -53,11 +52,6 @@ class DatasetMap:
|
|||||||
yield self.test
|
yield self.test
|
||||||
|
|
||||||
|
|
||||||
class Task(Enum):
|
|
||||||
SINGLE_SEQUENCE = "singlesequence"
|
|
||||||
MULTI_SEQUENCE = "multisequence"
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetMapProviderBase(ReplaceableBase):
|
class DatasetMapProviderBase(ReplaceableBase):
|
||||||
"""
|
"""
|
||||||
Base class for a provider of training / validation and testing
|
Base class for a provider of training / validation and testing
|
||||||
@ -71,9 +65,6 @@ class DatasetMapProviderBase(ReplaceableBase):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def get_task(self) -> Task:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
||||||
"""
|
"""
|
||||||
If the data is all for a single scene, returns a list
|
If the data is all for a single scene, returns a list
|
||||||
|
@ -17,12 +17,7 @@ from pytorch3d.implicitron.tools.config import (
|
|||||||
)
|
)
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
|
|
||||||
from .dataset_map_provider import (
|
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
|
||||||
DatasetMap,
|
|
||||||
DatasetMapProviderBase,
|
|
||||||
PathManagerFactory,
|
|
||||||
Task,
|
|
||||||
)
|
|
||||||
from .json_index_dataset import JsonIndexDataset
|
from .json_index_dataset import JsonIndexDataset
|
||||||
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
@ -160,7 +155,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
# This maps the common names of the dataset subsets ("train"/"val"/"test")
|
# This maps the common names of the dataset subsets ("train"/"val"/"test")
|
||||||
# to the names of the subsets in the CO3D dataset.
|
# to the names of the subsets in the CO3D dataset.
|
||||||
set_names_mapping = _get_co3d_set_names_mapping(
|
set_names_mapping = _get_co3d_set_names_mapping(
|
||||||
self.get_task(),
|
self.task_str,
|
||||||
self.test_on_train,
|
self.test_on_train,
|
||||||
self.only_test_set,
|
self.only_test_set,
|
||||||
)
|
)
|
||||||
@ -185,7 +180,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
eval_batch_index = json.load(f)
|
eval_batch_index = json.load(f)
|
||||||
restrict_sequence_name = self.restrict_sequence_name
|
restrict_sequence_name = self.restrict_sequence_name
|
||||||
|
|
||||||
if self.get_task() == Task.SINGLE_SEQUENCE:
|
if self.task_str == "singlesequence":
|
||||||
if (
|
if (
|
||||||
self.test_restrict_sequence_id is None
|
self.test_restrict_sequence_id is None
|
||||||
or self.test_restrict_sequence_id < 0
|
or self.test_restrict_sequence_id < 0
|
||||||
@ -267,13 +262,12 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
# pyre-ignore[16]
|
# pyre-ignore[16]
|
||||||
return self.dataset_map
|
return self.dataset_map
|
||||||
|
|
||||||
def get_task(self) -> Task:
|
|
||||||
return Task(self.task_str)
|
|
||||||
|
|
||||||
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
||||||
if Task(self.task_str) == Task.MULTI_SEQUENCE:
|
if self.task_str == "multisequence":
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
assert self.task_str == "singlesequence"
|
||||||
|
|
||||||
# pyre-ignore[16]
|
# pyre-ignore[16]
|
||||||
train_dataset = self.dataset_map.train
|
train_dataset = self.dataset_map.train
|
||||||
assert isinstance(train_dataset, JsonIndexDataset)
|
assert isinstance(train_dataset, JsonIndexDataset)
|
||||||
@ -281,7 +275,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
|
|
||||||
|
|
||||||
def _get_co3d_set_names_mapping(
|
def _get_co3d_set_names_mapping(
|
||||||
task: Task,
|
task_str: str,
|
||||||
test_on_train: bool,
|
test_on_train: bool,
|
||||||
only_test: bool,
|
only_test: bool,
|
||||||
) -> Dict[str, List[str]]:
|
) -> Dict[str, List[str]]:
|
||||||
@ -295,7 +289,7 @@ def _get_co3d_set_names_mapping(
|
|||||||
- val (if not test_on_train)
|
- val (if not test_on_train)
|
||||||
- test (if not test_on_train)
|
- test (if not test_on_train)
|
||||||
"""
|
"""
|
||||||
single_seq = task == Task.SINGLE_SEQUENCE
|
single_seq = task_str == "singlesequence"
|
||||||
|
|
||||||
if only_test:
|
if only_test:
|
||||||
set_names_mapping = {}
|
set_names_mapping = {}
|
||||||
|
@ -16,7 +16,6 @@ from pytorch3d.implicitron.dataset.dataset_map_provider import (
|
|||||||
DatasetMap,
|
DatasetMap,
|
||||||
DatasetMapProviderBase,
|
DatasetMapProviderBase,
|
||||||
PathManagerFactory,
|
PathManagerFactory,
|
||||||
Task,
|
|
||||||
)
|
)
|
||||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
||||||
from pytorch3d.implicitron.tools.config import (
|
from pytorch3d.implicitron.tools.config import (
|
||||||
@ -335,12 +334,6 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
)
|
)
|
||||||
return category_to_subset_name_list
|
return category_to_subset_name_list
|
||||||
|
|
||||||
def get_task(self) -> Task: # TODO: we plan to get rid of tasks
|
|
||||||
return {
|
|
||||||
"manyview": Task.SINGLE_SEQUENCE,
|
|
||||||
"fewview": Task.MULTI_SEQUENCE,
|
|
||||||
}[self.subset_name.split("_")[0]]
|
|
||||||
|
|
||||||
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
||||||
# pyre-ignore[16]
|
# pyre-ignore[16]
|
||||||
train_dataset = self.dataset_map.train
|
train_dataset = self.dataset_map.train
|
||||||
|
@ -28,12 +28,7 @@ from pytorch3d.renderer import (
|
|||||||
)
|
)
|
||||||
from pytorch3d.structures.meshes import Meshes
|
from pytorch3d.structures.meshes import Meshes
|
||||||
|
|
||||||
from .dataset_map_provider import (
|
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
|
||||||
DatasetMap,
|
|
||||||
DatasetMapProviderBase,
|
|
||||||
PathManagerFactory,
|
|
||||||
Task,
|
|
||||||
)
|
|
||||||
from .single_sequence_dataset import SingleSceneDataset
|
from .single_sequence_dataset import SingleSceneDataset
|
||||||
from .utils import DATASET_TYPE_KNOWN
|
from .utils import DATASET_TYPE_KNOWN
|
||||||
|
|
||||||
@ -83,9 +78,6 @@ class RenderedMeshDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13
|
|||||||
# pyre-ignore[16]
|
# pyre-ignore[16]
|
||||||
return DatasetMap(train=self.train_dataset, val=None, test=None)
|
return DatasetMap(train=self.train_dataset, val=None, test=None)
|
||||||
|
|
||||||
def get_task(self) -> Task:
|
|
||||||
return Task.SINGLE_SEQUENCE
|
|
||||||
|
|
||||||
def get_all_train_cameras(self) -> CamerasBase:
|
def get_all_train_cameras(self) -> CamerasBase:
|
||||||
# pyre-ignore[16]
|
# pyre-ignore[16]
|
||||||
return self.poses
|
return self.poses
|
||||||
|
@ -21,12 +21,7 @@ from pytorch3d.implicitron.tools.config import (
|
|||||||
from pytorch3d.renderer import CamerasBase, join_cameras_as_batch, PerspectiveCameras
|
from pytorch3d.renderer import CamerasBase, join_cameras_as_batch, PerspectiveCameras
|
||||||
|
|
||||||
from .dataset_base import DatasetBase, FrameData
|
from .dataset_base import DatasetBase, FrameData
|
||||||
from .dataset_map_provider import (
|
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
|
||||||
DatasetMap,
|
|
||||||
DatasetMapProviderBase,
|
|
||||||
PathManagerFactory,
|
|
||||||
Task,
|
|
||||||
)
|
|
||||||
from .utils import DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN
|
from .utils import DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN
|
||||||
|
|
||||||
_SINGLE_SEQUENCE_NAME: str = "one_sequence"
|
_SINGLE_SEQUENCE_NAME: str = "one_sequence"
|
||||||
@ -159,9 +154,6 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
|
|||||||
test=self._get_dataset(2, DATASET_TYPE_UNKNOWN, True),
|
test=self._get_dataset(2, DATASET_TYPE_UNKNOWN, True),
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_task(self) -> Task:
|
|
||||||
return Task.SINGLE_SEQUENCE
|
|
||||||
|
|
||||||
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
||||||
# pyre-ignore[16]
|
# pyre-ignore[16]
|
||||||
cameras = [self.poses[i] for i in self.i_split[0]]
|
cameras = [self.poses[i] for i in self.i_split[0]]
|
||||||
|
@ -7,11 +7,12 @@
|
|||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import os
|
import os
|
||||||
|
from enum import Enum
|
||||||
from typing import Any, cast, Dict, List, Optional, Tuple
|
from typing import Any, cast, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import lpips
|
import lpips
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
|
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
||||||
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
|
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
|
||||||
CO3D_CATEGORIES,
|
CO3D_CATEGORIES,
|
||||||
@ -27,6 +28,11 @@ from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
class Task(Enum):
|
||||||
|
SINGLE_SEQUENCE = "singlesequence"
|
||||||
|
MULTI_SEQUENCE = "multisequence"
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
"""
|
"""
|
||||||
Evaluates new view synthesis metrics of a simple depth-based image rendering
|
Evaluates new view synthesis metrics of a simple depth-based image rendering
|
||||||
@ -153,11 +159,15 @@ def evaluate_dbir_for_category(
|
|||||||
|
|
||||||
if task == Task.SINGLE_SEQUENCE:
|
if task == Task.SINGLE_SEQUENCE:
|
||||||
camera_difficulty_bin_breaks = 0.97, 0.98
|
camera_difficulty_bin_breaks = 0.97, 0.98
|
||||||
|
multisequence_evaluation = False
|
||||||
else:
|
else:
|
||||||
camera_difficulty_bin_breaks = 2.0 / 3, 5.0 / 6
|
camera_difficulty_bin_breaks = 2.0 / 3, 5.0 / 6
|
||||||
|
multisequence_evaluation = True
|
||||||
|
|
||||||
category_result_flat, category_result = summarize_nvs_eval_results(
|
category_result_flat, category_result = summarize_nvs_eval_results(
|
||||||
per_batch_eval_results, task, camera_difficulty_bin_breaks
|
per_batch_eval_results,
|
||||||
|
camera_difficulty_bin_breaks=camera_difficulty_bin_breaks,
|
||||||
|
is_multisequence=multisequence_evaluation,
|
||||||
)
|
)
|
||||||
|
|
||||||
return category_result["results"]
|
return category_result["results"]
|
||||||
|
@ -14,7 +14,6 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from pytorch3d.implicitron.dataset.data_source import Task
|
|
||||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData
|
from pytorch3d.implicitron.dataset.dataset_base import FrameData
|
||||||
from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
|
from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
|
||||||
from pytorch3d.implicitron.models.base_model import ImplicitronRender
|
from pytorch3d.implicitron.models.base_model import ImplicitronRender
|
||||||
@ -420,16 +419,16 @@ def _get_camera_difficulty_bin_edges(camera_difficulty_bin_breaks: Tuple[float,
|
|||||||
|
|
||||||
def summarize_nvs_eval_results(
|
def summarize_nvs_eval_results(
|
||||||
per_batch_eval_results: List[Dict[str, Any]],
|
per_batch_eval_results: List[Dict[str, Any]],
|
||||||
task: Task,
|
is_multisequence: bool,
|
||||||
camera_difficulty_bin_breaks: Tuple[float, float] = (0.97, 0.98),
|
camera_difficulty_bin_breaks: Tuple[float, float],
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Compile the per-batch evaluation results `per_batch_eval_results` into
|
Compile the per-batch evaluation results `per_batch_eval_results` into
|
||||||
a set of aggregate metrics. The produced metrics depend on the task.
|
a set of aggregate metrics. The produced metrics depend on is_multisequence.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
per_batch_eval_results: Metrics of each per-batch evaluation.
|
per_batch_eval_results: Metrics of each per-batch evaluation.
|
||||||
task: The type of the new-view synthesis task.
|
is_multisequence: Whether to evaluate as a multisequence task
|
||||||
camera_difficulty_bin_breaks: edge hard-medium and medium-easy
|
camera_difficulty_bin_breaks: edge hard-medium and medium-easy
|
||||||
|
|
||||||
|
|
||||||
@ -439,14 +438,9 @@ def summarize_nvs_eval_results(
|
|||||||
"""
|
"""
|
||||||
n_batches = len(per_batch_eval_results)
|
n_batches = len(per_batch_eval_results)
|
||||||
eval_sets: List[Optional[str]] = []
|
eval_sets: List[Optional[str]] = []
|
||||||
if task == Task.SINGLE_SEQUENCE:
|
eval_sets = [None]
|
||||||
eval_sets = [None]
|
if is_multisequence:
|
||||||
# assert n_batches==100
|
|
||||||
elif task == Task.MULTI_SEQUENCE:
|
|
||||||
eval_sets = ["train", "test"]
|
eval_sets = ["train", "test"]
|
||||||
# assert n_batches==1000
|
|
||||||
else:
|
|
||||||
raise ValueError(task)
|
|
||||||
batch_sizes = torch.tensor(
|
batch_sizes = torch.tensor(
|
||||||
[r["meta"]["batch_size"] for r in per_batch_eval_results]
|
[r["meta"]["batch_size"] for r in per_batch_eval_results]
|
||||||
).long()
|
).long()
|
||||||
@ -466,11 +460,9 @@ def summarize_nvs_eval_results(
|
|||||||
# add per set averages
|
# add per set averages
|
||||||
for SET in eval_sets:
|
for SET in eval_sets:
|
||||||
if SET is None:
|
if SET is None:
|
||||||
assert task == Task.SINGLE_SEQUENCE
|
|
||||||
ok_set = torch.ones(n_batches, dtype=torch.bool)
|
ok_set = torch.ones(n_batches, dtype=torch.bool)
|
||||||
set_name = "test"
|
set_name = "test"
|
||||||
else:
|
else:
|
||||||
assert task == Task.MULTI_SEQUENCE
|
|
||||||
ok_set = is_train == int(SET == "train")
|
ok_set = is_train == int(SET == "train")
|
||||||
set_name = SET
|
set_name = SET
|
||||||
|
|
||||||
@ -495,7 +487,7 @@ def summarize_nvs_eval_results(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
if task == Task.MULTI_SEQUENCE:
|
if is_multisequence:
|
||||||
# split based on n_src_views
|
# split based on n_src_views
|
||||||
n_src_views = batch_sizes - 1
|
n_src_views = batch_sizes - 1
|
||||||
for n_src in EVAL_N_SRC_VIEWS:
|
for n_src in EVAL_N_SRC_VIEWS:
|
||||||
|
@ -16,7 +16,6 @@ import torch
|
|||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
from pytorch3d.implicitron.dataset import utils as ds_utils
|
from pytorch3d.implicitron.dataset import utils as ds_utils
|
||||||
from pytorch3d.implicitron.dataset.data_source import Task
|
|
||||||
|
|
||||||
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
|
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
|
||||||
from pytorch3d.implicitron.models.base_model import EvaluationMode, ImplicitronModelBase
|
from pytorch3d.implicitron.models.base_model import EvaluationMode, ImplicitronModelBase
|
||||||
@ -57,6 +56,7 @@ class ImplicitronEvaluator(EvaluatorBase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
|
camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
|
||||||
|
is_multisequence: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
@ -65,7 +65,6 @@ class ImplicitronEvaluator(EvaluatorBase):
|
|||||||
self,
|
self,
|
||||||
model: ImplicitronModelBase,
|
model: ImplicitronModelBase,
|
||||||
dataloader: DataLoader,
|
dataloader: DataLoader,
|
||||||
task: Task,
|
|
||||||
all_train_cameras: Optional[CamerasBase],
|
all_train_cameras: Optional[CamerasBase],
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
dump_to_json: bool = False,
|
dump_to_json: bool = False,
|
||||||
@ -80,7 +79,6 @@ class ImplicitronEvaluator(EvaluatorBase):
|
|||||||
Args:
|
Args:
|
||||||
model: A (trained) model to evaluate.
|
model: A (trained) model to evaluate.
|
||||||
dataloader: A test dataloader.
|
dataloader: A test dataloader.
|
||||||
task: Type of the novel-view synthesis task we're working on.
|
|
||||||
all_train_cameras: Camera instances we used for training.
|
all_train_cameras: Camera instances we used for training.
|
||||||
device: A torch device.
|
device: A torch device.
|
||||||
dump_to_json: If True, will dump the results to a json file.
|
dump_to_json: If True, will dump the results to a json file.
|
||||||
@ -122,7 +120,9 @@ class ImplicitronEvaluator(EvaluatorBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
_, category_result = evaluate.summarize_nvs_eval_results(
|
_, category_result = evaluate.summarize_nvs_eval_results(
|
||||||
per_batch_eval_results, task, self.camera_difficulty_bin_breaks
|
per_batch_eval_results,
|
||||||
|
self.is_multisequence,
|
||||||
|
self.camera_difficulty_bin_breaks,
|
||||||
)
|
)
|
||||||
|
|
||||||
results = category_result["results"]
|
results = category_result["results"]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user