Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-08-02 03:42:50 +08:00)
remove get_task
Summary: Remove the dataset's need to provide the task type.

Reviewed By: davnov134, kjchalup

Differential Revision: D38314000

fbshipit-source-id: 3805d885b5d4528abdc78c0da03247edb9abf3f7
This commit is contained in: parent 37250a4326, commit f8bf528043
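
In practice, the commit moves the single- vs multi-sequence switch from a dataset-provided Task enum to an explicit is_multisequence flag supplied by the caller. A minimal sketch of the new call, assuming per-batch results have already been computed; the wrapper function and the argument values below are illustrative, not part of this commit:

from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
    summarize_nvs_eval_results,
)


def summarize(per_batch_eval_results):
    # Before this commit the call was roughly
    # summarize_nvs_eval_results(per_batch_eval_results, task, bin_breaks),
    # with `task` obtained via data_source.get_task(). Now the caller states
    # the evaluation mode directly instead of asking the dataset.
    return summarize_nvs_eval_results(
        per_batch_eval_results,
        is_multisequence=False,  # single-sequence evaluation
        camera_difficulty_bin_breaks=(0.97, 0.98),  # single-sequence breaks used in this diff
    )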
@@ -35,3 +35,4 @@ training_loop_ImplicitronTrainingLoop_args:
 camera_difficulty_bin_breaks:
 - 0.666667
 - 0.833334
+is_multisequence: true
@@ -206,7 +206,6 @@ class Experiment(Configurable): # pyre-ignore: 13
             val_loader,
         ) = accelerator.prepare(model, optimizer, train_loader, val_loader)

-        task = self.data_source.get_task()
         all_train_cameras = self.data_source.all_train_cameras

         # Enter the main training loop.
@@ -223,7 +222,6 @@ class Experiment(Configurable): # pyre-ignore: 13
             exp_dir=self.exp_dir,
             stats=stats,
             seed=self.seed,
-            task=task,
         )

     def _check_config_consistent(self) -> None:
@@ -10,7 +10,6 @@ from typing import Any, Optional

 import torch
 from accelerate import Accelerator
-from pytorch3d.implicitron.dataset.data_source import Task
 from pytorch3d.implicitron.evaluation.evaluator import EvaluatorBase
 from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
 from pytorch3d.implicitron.models.generic_model import EvaluationMode
@@ -101,7 +100,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
         exp_dir: str,
         stats: Stats,
         seed: int,
-        task: Task,
         **kwargs,
     ):
         """
@@ -123,7 +121,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
                     epoch=stats.epoch,
                     exp_dir=exp_dir,
                     model=model,
-                    task=task,
                 )
                 return
             else:
@@ -179,7 +176,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
                     device=device,
                     dataloader=test_loader,
                     model=model,
-                    task=task,
                 )

             assert stats.epoch == epoch, "inconsistent stats!"
@@ -200,7 +196,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
                     exp_dir=exp_dir,
                     dataloader=test_loader,
                     model=model,
-                    task=task,
                 )
             else:
                 raise ValueError(
@@ -435,3 +435,4 @@ training_loop_ImplicitronTrainingLoop_args:
 camera_difficulty_bin_breaks:
 - 0.97
 - 0.98
+is_multisequence: false
@@ -15,7 +15,7 @@ from pytorch3d.renderer.cameras import CamerasBase

 from .blender_dataset_map_provider import BlenderDatasetMapProvider  # noqa
 from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
-from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
+from .dataset_map_provider import DatasetMap, DatasetMapProviderBase
 from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider  # noqa
 from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2  # noqa
 from .llff_dataset_map_provider import LlffDatasetMapProvider  # noqa
@@ -41,9 +41,6 @@ class DataSourceBase(ReplaceableBase):
         """
         raise NotImplementedError()

-    def get_task(self) -> Task:
-        raise NotImplementedError()
-

 @registry.register
 class ImplicitronDataSource(DataSourceBase):  # pyre-ignore[13]
@@ -71,9 +68,6 @@ class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
         dataloaders = self.data_loader_map_provider.get_data_loader_map(datasets)
         return datasets, dataloaders

-    def get_task(self) -> Task:
-        return self.dataset_map_provider.get_task()
-
     @property
     def all_train_cameras(self) -> Optional[CamerasBase]:
         if self._all_train_cameras_cache is None:  # pyre-ignore[16]
@@ -7,7 +7,6 @@
 import logging
 import os
 from dataclasses import dataclass
-from enum import Enum
 from typing import Iterator, Optional

 from iopath.common.file_io import PathManager
@@ -53,11 +52,6 @@ class DatasetMap:
         yield self.test


-class Task(Enum):
-    SINGLE_SEQUENCE = "singlesequence"
-    MULTI_SEQUENCE = "multisequence"
-
-
 class DatasetMapProviderBase(ReplaceableBase):
     """
     Base class for a provider of training / validation and testing
@@ -71,9 +65,6 @@ class DatasetMapProviderBase(ReplaceableBase):
         """
         raise NotImplementedError()

-    def get_task(self) -> Task:
-        raise NotImplementedError()
-
     def get_all_train_cameras(self) -> Optional[CamerasBase]:
         """
         If the data is all for a single scene, returns a list
@@ -17,12 +17,7 @@ from pytorch3d.implicitron.tools.config import (
 )
 from pytorch3d.renderer.cameras import CamerasBase

-from .dataset_map_provider import (
-    DatasetMap,
-    DatasetMapProviderBase,
-    PathManagerFactory,
-    Task,
-)
+from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
 from .json_index_dataset import JsonIndexDataset

 from .utils import (
@@ -160,7 +155,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
         # This maps the common names of the dataset subsets ("train"/"val"/"test")
         # to the names of the subsets in the CO3D dataset.
         set_names_mapping = _get_co3d_set_names_mapping(
-            self.get_task(),
+            self.task_str,
             self.test_on_train,
             self.only_test_set,
         )
@@ -185,7 +180,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
                 eval_batch_index = json.load(f)
             restrict_sequence_name = self.restrict_sequence_name

-        if self.get_task() == Task.SINGLE_SEQUENCE:
+        if self.task_str == "singlesequence":
             if (
                 self.test_restrict_sequence_id is None
                 or self.test_restrict_sequence_id < 0
@@ -267,13 +262,12 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
         # pyre-ignore[16]
         return self.dataset_map

-    def get_task(self) -> Task:
-        return Task(self.task_str)
-
     def get_all_train_cameras(self) -> Optional[CamerasBase]:
-        if Task(self.task_str) == Task.MULTI_SEQUENCE:
+        if self.task_str == "multisequence":
             return None

+        assert self.task_str == "singlesequence"
+
         # pyre-ignore[16]
         train_dataset = self.dataset_map.train
         assert isinstance(train_dataset, JsonIndexDataset)
@@ -281,7 +275,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]


 def _get_co3d_set_names_mapping(
-    task: Task,
+    task_str: str,
     test_on_train: bool,
     only_test: bool,
 ) -> Dict[str, List[str]]:
@@ -295,7 +289,7 @@ def _get_co3d_set_names_mapping(
         - val (if not test_on_train)
         - test (if not test_on_train)
     """
-    single_seq = task == Task.SINGLE_SEQUENCE
+    single_seq = task_str == "singlesequence"

     if only_test:
         set_names_mapping = {}
@@ -16,7 +16,6 @@ from pytorch3d.implicitron.dataset.dataset_map_provider import (
     DatasetMap,
     DatasetMapProviderBase,
     PathManagerFactory,
-    Task,
 )
 from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
 from pytorch3d.implicitron.tools.config import (
@@ -335,12 +334,6 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
         )
         return category_to_subset_name_list

-    def get_task(self) -> Task:  # TODO: we plan to get rid of tasks
-        return {
-            "manyview": Task.SINGLE_SEQUENCE,
-            "fewview": Task.MULTI_SEQUENCE,
-        }[self.subset_name.split("_")[0]]
-
     def get_all_train_cameras(self) -> Optional[CamerasBase]:
         # pyre-ignore[16]
         train_dataset = self.dataset_map.train
@@ -28,12 +28,7 @@ from pytorch3d.renderer import (
 )
 from pytorch3d.structures.meshes import Meshes

-from .dataset_map_provider import (
-    DatasetMap,
-    DatasetMapProviderBase,
-    PathManagerFactory,
-    Task,
-)
+from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
 from .single_sequence_dataset import SingleSceneDataset
 from .utils import DATASET_TYPE_KNOWN

@@ -83,9 +78,6 @@ class RenderedMeshDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13
         # pyre-ignore[16]
         return DatasetMap(train=self.train_dataset, val=None, test=None)

-    def get_task(self) -> Task:
-        return Task.SINGLE_SEQUENCE
-
     def get_all_train_cameras(self) -> CamerasBase:
         # pyre-ignore[16]
         return self.poses
@@ -21,12 +21,7 @@ from pytorch3d.implicitron.tools.config import (
 from pytorch3d.renderer import CamerasBase, join_cameras_as_batch, PerspectiveCameras

 from .dataset_base import DatasetBase, FrameData
-from .dataset_map_provider import (
-    DatasetMap,
-    DatasetMapProviderBase,
-    PathManagerFactory,
-    Task,
-)
+from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
 from .utils import DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN

 _SINGLE_SEQUENCE_NAME: str = "one_sequence"
@@ -159,9 +154,6 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
             test=self._get_dataset(2, DATASET_TYPE_UNKNOWN, True),
         )

-    def get_task(self) -> Task:
-        return Task.SINGLE_SEQUENCE
-
     def get_all_train_cameras(self) -> Optional[CamerasBase]:
         # pyre-ignore[16]
         cameras = [self.poses[i] for i in self.i_split[0]]
@@ -7,11 +7,12 @@

 import dataclasses
 import os
+from enum import Enum
 from typing import Any, cast, Dict, List, Optional, Tuple

 import lpips
 import torch
-from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
+from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
 from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
 from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
     CO3D_CATEGORIES,
@@ -27,6 +28,11 @@ from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_
 from tqdm import tqdm


+class Task(Enum):
+    SINGLE_SEQUENCE = "singlesequence"
+    MULTI_SEQUENCE = "multisequence"
+
+
 def main() -> None:
     """
     Evaluates new view synthesis metrics of a simple depth-based image rendering
@@ -153,11 +159,15 @@ def evaluate_dbir_for_category(

     if task == Task.SINGLE_SEQUENCE:
         camera_difficulty_bin_breaks = 0.97, 0.98
+        multisequence_evaluation = False
     else:
         camera_difficulty_bin_breaks = 2.0 / 3, 5.0 / 6
+        multisequence_evaluation = True

     category_result_flat, category_result = summarize_nvs_eval_results(
-        per_batch_eval_results, task, camera_difficulty_bin_breaks
+        per_batch_eval_results,
+        camera_difficulty_bin_breaks=camera_difficulty_bin_breaks,
+        is_multisequence=multisequence_evaluation,
     )

     return category_result["results"]
@@ -14,7 +14,6 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 import numpy as np
 import torch
 import torch.nn.functional as F
-from pytorch3d.implicitron.dataset.data_source import Task
 from pytorch3d.implicitron.dataset.dataset_base import FrameData
 from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
 from pytorch3d.implicitron.models.base_model import ImplicitronRender
@@ -420,16 +419,16 @@ def _get_camera_difficulty_bin_edges(camera_difficulty_bin_breaks: Tuple[float,

 def summarize_nvs_eval_results(
     per_batch_eval_results: List[Dict[str, Any]],
-    task: Task,
-    camera_difficulty_bin_breaks: Tuple[float, float] = (0.97, 0.98),
+    is_multisequence: bool,
+    camera_difficulty_bin_breaks: Tuple[float, float],
 ):
     """
     Compile the per-batch evaluation results `per_batch_eval_results` into
-    a set of aggregate metrics. The produced metrics depend on the task.
+    a set of aggregate metrics. The produced metrics depend on is_multisequence.

     Args:
         per_batch_eval_results: Metrics of each per-batch evaluation.
-        task: The type of the new-view synthesis task.
+        is_multisequence: Whether to evaluate as a multisequence task
         camera_difficulty_bin_breaks: edge hard-medium and medium-easy


@@ -439,14 +438,9 @@ def summarize_nvs_eval_results(
     """
     n_batches = len(per_batch_eval_results)
-    eval_sets: List[Optional[str]] = []
-    if task == Task.SINGLE_SEQUENCE:
-        eval_sets = [None]
-        # assert n_batches==100
-    elif task == Task.MULTI_SEQUENCE:
+    eval_sets = [None]
+    if is_multisequence:
         eval_sets = ["train", "test"]
         # assert n_batches==1000
-    else:
-        raise ValueError(task)
     batch_sizes = torch.tensor(
         [r["meta"]["batch_size"] for r in per_batch_eval_results]
     ).long()
@@ -466,11 +460,9 @@ def summarize_nvs_eval_results(
     # add per set averages
     for SET in eval_sets:
         if SET is None:
-            assert task == Task.SINGLE_SEQUENCE
             ok_set = torch.ones(n_batches, dtype=torch.bool)
             set_name = "test"
         else:
-            assert task == Task.MULTI_SEQUENCE
             ok_set = is_train == int(SET == "train")
             set_name = SET

@@ -495,7 +487,7 @@ def summarize_nvs_eval_results(
         }
     )

-    if task == Task.MULTI_SEQUENCE:
+    if is_multisequence:
         # split based on n_src_views
         n_src_views = batch_sizes - 1
         for n_src in EVAL_N_SRC_VIEWS:
@@ -16,7 +16,6 @@ import torch
 import tqdm

 from pytorch3d.implicitron.dataset import utils as ds_utils
-from pytorch3d.implicitron.dataset.data_source import Task

 from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
 from pytorch3d.implicitron.models.base_model import EvaluationMode, ImplicitronModelBase
@@ -57,6 +56,7 @@ class ImplicitronEvaluator(EvaluatorBase):
     """

     camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
+    is_multisequence: bool = False

     def __post_init__(self):
         run_auto_creation(self)
@@ -65,7 +65,6 @@ class ImplicitronEvaluator(EvaluatorBase):
         self,
         model: ImplicitronModelBase,
         dataloader: DataLoader,
-        task: Task,
         all_train_cameras: Optional[CamerasBase],
         device: torch.device,
         dump_to_json: bool = False,
@@ -80,7 +79,6 @@ class ImplicitronEvaluator(EvaluatorBase):
         Args:
             model: A (trained) model to evaluate.
             dataloader: A test dataloader.
-            task: Type of the novel-view synthesis task we're working on.
            all_train_cameras: Camera instances we used for training.
             device: A torch device.
             dump_to_json: If True, will dump the results to a json file.
@@ -122,7 +120,9 @@ class ImplicitronEvaluator(EvaluatorBase):
         )

         _, category_result = evaluate.summarize_nvs_eval_results(
-            per_batch_eval_results, task, self.camera_difficulty_bin_breaks
+            per_batch_eval_results,
+            self.is_multisequence,
+            self.camera_difficulty_bin_breaks,
         )

         results = category_result["results"]
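
Taken together, the evaluator hunks above change the calling protocol: ImplicitronEvaluator no longer receives a task argument, and single- vs multi-sequence behaviour is configured through its own is_multisequence field (default False) and the matching yaml keys. A hedged sketch of a post-change call, assuming the edited method is the evaluator's run(), that ImplicitronEvaluator lives in the same evaluator module as the EvaluatorBase imported above, and that the evaluator instance and its inputs are built by the usual Implicitron machinery:

from typing import Optional

import torch
from torch.utils.data import DataLoader

from pytorch3d.implicitron.evaluation.evaluator import ImplicitronEvaluator
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.renderer.cameras import CamerasBase


def evaluate_after_training(
    evaluator: ImplicitronEvaluator,
    model: ImplicitronModelBase,
    test_loader: DataLoader,
    all_train_cameras: Optional[CamerasBase],
    device: torch.device,
) -> None:
    # `task=...` is no longer passed; multi-sequence handling is controlled by
    # the evaluator's `is_multisequence: bool = False` field instead.
    evaluator.run(
        model=model,
        dataloader=test_loader,
        all_train_cameras=all_train_cameras,
        device=device,
    )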