remove get_task

Summary: Remove the dataset's need to provide the task type.

Reviewed By: davnov134, kjchalup

Differential Revision: D38314000

fbshipit-source-id: 3805d885b5d4528abdc78c0da03247edb9abf3f7
This commit is contained in:
Jeremy Reizenstein 2022-08-02 07:55:42 -07:00 committed by Facebook GitHub Bot
parent 37250a4326
commit f8bf528043
13 changed files with 36 additions and 83 deletions

View File

@ -35,3 +35,4 @@ training_loop_ImplicitronTrainingLoop_args:
camera_difficulty_bin_breaks:
- 0.666667
- 0.833334
is_multisequence: true

View File

@ -206,7 +206,6 @@ class Experiment(Configurable): # pyre-ignore: 13
val_loader,
) = accelerator.prepare(model, optimizer, train_loader, val_loader)
task = self.data_source.get_task()
all_train_cameras = self.data_source.all_train_cameras
# Enter the main training loop.
@ -223,7 +222,6 @@ class Experiment(Configurable): # pyre-ignore: 13
exp_dir=self.exp_dir,
stats=stats,
seed=self.seed,
task=task,
)
def _check_config_consistent(self) -> None:

View File

@ -10,7 +10,6 @@ from typing import Any, Optional
import torch
from accelerate import Accelerator
from pytorch3d.implicitron.dataset.data_source import Task
from pytorch3d.implicitron.evaluation.evaluator import EvaluatorBase
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.models.generic_model import EvaluationMode
@ -101,7 +100,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
exp_dir: str,
stats: Stats,
seed: int,
task: Task,
**kwargs,
):
"""
@ -123,7 +121,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
epoch=stats.epoch,
exp_dir=exp_dir,
model=model,
task=task,
)
return
else:
@ -179,7 +176,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
device=device,
dataloader=test_loader,
model=model,
task=task,
)
assert stats.epoch == epoch, "inconsistent stats!"
@ -200,7 +196,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
exp_dir=exp_dir,
dataloader=test_loader,
model=model,
task=task,
)
else:
raise ValueError(

View File

@ -435,3 +435,4 @@ training_loop_ImplicitronTrainingLoop_args:
camera_difficulty_bin_breaks:
- 0.97
- 0.98
is_multisequence: false

View File

@ -15,7 +15,7 @@ from pytorch3d.renderer.cameras import CamerasBase
from .blender_dataset_map_provider import BlenderDatasetMapProvider # noqa
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase
from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa
from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2 # noqa
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa
@ -41,9 +41,6 @@ class DataSourceBase(ReplaceableBase):
"""
raise NotImplementedError()
def get_task(self) -> Task:
raise NotImplementedError()
@registry.register
class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
@ -71,9 +68,6 @@ class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
dataloaders = self.data_loader_map_provider.get_data_loader_map(datasets)
return datasets, dataloaders
def get_task(self) -> Task:
return self.dataset_map_provider.get_task()
@property
def all_train_cameras(self) -> Optional[CamerasBase]:
if self._all_train_cameras_cache is None: # pyre-ignore[16]

View File

@ -7,7 +7,6 @@
import logging
import os
from dataclasses import dataclass
from enum import Enum
from typing import Iterator, Optional
from iopath.common.file_io import PathManager
@ -53,11 +52,6 @@ class DatasetMap:
yield self.test
class Task(Enum):
SINGLE_SEQUENCE = "singlesequence"
MULTI_SEQUENCE = "multisequence"
class DatasetMapProviderBase(ReplaceableBase):
"""
Base class for a provider of training / validation and testing
@ -71,9 +65,6 @@ class DatasetMapProviderBase(ReplaceableBase):
"""
raise NotImplementedError()
def get_task(self) -> Task:
raise NotImplementedError()
def get_all_train_cameras(self) -> Optional[CamerasBase]:
"""
If the data is all for a single scene, returns a list

View File

@ -17,12 +17,7 @@ from pytorch3d.implicitron.tools.config import (
)
from pytorch3d.renderer.cameras import CamerasBase
from .dataset_map_provider import (
DatasetMap,
DatasetMapProviderBase,
PathManagerFactory,
Task,
)
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
from .json_index_dataset import JsonIndexDataset
from .utils import (
@ -160,7 +155,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
# This maps the common names of the dataset subsets ("train"/"val"/"test")
# to the names of the subsets in the CO3D dataset.
set_names_mapping = _get_co3d_set_names_mapping(
self.get_task(),
self.task_str,
self.test_on_train,
self.only_test_set,
)
@ -185,7 +180,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
eval_batch_index = json.load(f)
restrict_sequence_name = self.restrict_sequence_name
if self.get_task() == Task.SINGLE_SEQUENCE:
if self.task_str == "singlesequence":
if (
self.test_restrict_sequence_id is None
or self.test_restrict_sequence_id < 0
@ -267,13 +262,12 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
# pyre-ignore[16]
return self.dataset_map
def get_task(self) -> Task:
return Task(self.task_str)
def get_all_train_cameras(self) -> Optional[CamerasBase]:
if Task(self.task_str) == Task.MULTI_SEQUENCE:
if self.task_str == "multisequence":
return None
assert self.task_str == "singlesequence"
# pyre-ignore[16]
train_dataset = self.dataset_map.train
assert isinstance(train_dataset, JsonIndexDataset)
@ -281,7 +275,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
def _get_co3d_set_names_mapping(
task: Task,
task_str: str,
test_on_train: bool,
only_test: bool,
) -> Dict[str, List[str]]:
@ -295,7 +289,7 @@ def _get_co3d_set_names_mapping(
- val (if not test_on_train)
- test (if not test_on_train)
"""
single_seq = task == Task.SINGLE_SEQUENCE
single_seq = task_str == "singlesequence"
if only_test:
set_names_mapping = {}

View File

@ -16,7 +16,6 @@ from pytorch3d.implicitron.dataset.dataset_map_provider import (
DatasetMap,
DatasetMapProviderBase,
PathManagerFactory,
Task,
)
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.tools.config import (
@ -335,12 +334,6 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
)
return category_to_subset_name_list
def get_task(self) -> Task: # TODO: we plan to get rid of tasks
return {
"manyview": Task.SINGLE_SEQUENCE,
"fewview": Task.MULTI_SEQUENCE,
}[self.subset_name.split("_")[0]]
def get_all_train_cameras(self) -> Optional[CamerasBase]:
# pyre-ignore[16]
train_dataset = self.dataset_map.train

View File

@ -28,12 +28,7 @@ from pytorch3d.renderer import (
)
from pytorch3d.structures.meshes import Meshes
from .dataset_map_provider import (
DatasetMap,
DatasetMapProviderBase,
PathManagerFactory,
Task,
)
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
from .single_sequence_dataset import SingleSceneDataset
from .utils import DATASET_TYPE_KNOWN
@ -83,9 +78,6 @@ class RenderedMeshDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13
# pyre-ignore[16]
return DatasetMap(train=self.train_dataset, val=None, test=None)
def get_task(self) -> Task:
return Task.SINGLE_SEQUENCE
def get_all_train_cameras(self) -> CamerasBase:
# pyre-ignore[16]
return self.poses

View File

@ -21,12 +21,7 @@ from pytorch3d.implicitron.tools.config import (
from pytorch3d.renderer import CamerasBase, join_cameras_as_batch, PerspectiveCameras
from .dataset_base import DatasetBase, FrameData
from .dataset_map_provider import (
DatasetMap,
DatasetMapProviderBase,
PathManagerFactory,
Task,
)
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
from .utils import DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN
_SINGLE_SEQUENCE_NAME: str = "one_sequence"
@ -159,9 +154,6 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
test=self._get_dataset(2, DATASET_TYPE_UNKNOWN, True),
)
def get_task(self) -> Task:
return Task.SINGLE_SEQUENCE
def get_all_train_cameras(self) -> Optional[CamerasBase]:
# pyre-ignore[16]
cameras = [self.poses[i] for i in self.i_split[0]]

View File

@ -7,11 +7,12 @@
import dataclasses
import os
from enum import Enum
from typing import Any, cast, Dict, List, Optional, Tuple
import lpips
import torch
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
CO3D_CATEGORIES,
@ -27,6 +28,11 @@ from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_
from tqdm import tqdm
class Task(Enum):
SINGLE_SEQUENCE = "singlesequence"
MULTI_SEQUENCE = "multisequence"
def main() -> None:
"""
Evaluates new view synthesis metrics of a simple depth-based image rendering
@ -153,11 +159,15 @@ def evaluate_dbir_for_category(
if task == Task.SINGLE_SEQUENCE:
camera_difficulty_bin_breaks = 0.97, 0.98
multisequence_evaluation = False
else:
camera_difficulty_bin_breaks = 2.0 / 3, 5.0 / 6
multisequence_evaluation = True
category_result_flat, category_result = summarize_nvs_eval_results(
per_batch_eval_results, task, camera_difficulty_bin_breaks
per_batch_eval_results,
camera_difficulty_bin_breaks=camera_difficulty_bin_breaks,
is_multisequence=multisequence_evaluation,
)
return category_result["results"]

View File

@ -14,7 +14,6 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from pytorch3d.implicitron.dataset.data_source import Task
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
from pytorch3d.implicitron.models.base_model import ImplicitronRender
@ -420,16 +419,16 @@ def _get_camera_difficulty_bin_edges(camera_difficulty_bin_breaks: Tuple[float,
def summarize_nvs_eval_results(
per_batch_eval_results: List[Dict[str, Any]],
task: Task,
camera_difficulty_bin_breaks: Tuple[float, float] = (0.97, 0.98),
is_multisequence: bool,
camera_difficulty_bin_breaks: Tuple[float, float],
):
"""
Compile the per-batch evaluation results `per_batch_eval_results` into
a set of aggregate metrics. The produced metrics depend on the task.
a set of aggregate metrics. The produced metrics depend on is_multisequence.
Args:
per_batch_eval_results: Metrics of each per-batch evaluation.
task: The type of the new-view synthesis task.
is_multisequence: Whether to evaluate as a multisequence task
camera_difficulty_bin_breaks: edge hard-medium and medium-easy
@ -439,14 +438,9 @@ def summarize_nvs_eval_results(
"""
n_batches = len(per_batch_eval_results)
eval_sets: List[Optional[str]] = []
if task == Task.SINGLE_SEQUENCE:
eval_sets = [None]
# assert n_batches==100
elif task == Task.MULTI_SEQUENCE:
eval_sets = [None]
if is_multisequence:
eval_sets = ["train", "test"]
# assert n_batches==1000
else:
raise ValueError(task)
batch_sizes = torch.tensor(
[r["meta"]["batch_size"] for r in per_batch_eval_results]
).long()
@ -466,11 +460,9 @@ def summarize_nvs_eval_results(
# add per set averages
for SET in eval_sets:
if SET is None:
assert task == Task.SINGLE_SEQUENCE
ok_set = torch.ones(n_batches, dtype=torch.bool)
set_name = "test"
else:
assert task == Task.MULTI_SEQUENCE
ok_set = is_train == int(SET == "train")
set_name = SET
@ -495,7 +487,7 @@ def summarize_nvs_eval_results(
}
)
if task == Task.MULTI_SEQUENCE:
if is_multisequence:
# split based on n_src_views
n_src_views = batch_sizes - 1
for n_src in EVAL_N_SRC_VIEWS:

View File

@ -16,7 +16,6 @@ import torch
import tqdm
from pytorch3d.implicitron.dataset import utils as ds_utils
from pytorch3d.implicitron.dataset.data_source import Task
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
from pytorch3d.implicitron.models.base_model import EvaluationMode, ImplicitronModelBase
@ -57,6 +56,7 @@ class ImplicitronEvaluator(EvaluatorBase):
"""
camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
is_multisequence: bool = False
def __post_init__(self):
run_auto_creation(self)
@ -65,7 +65,6 @@ class ImplicitronEvaluator(EvaluatorBase):
self,
model: ImplicitronModelBase,
dataloader: DataLoader,
task: Task,
all_train_cameras: Optional[CamerasBase],
device: torch.device,
dump_to_json: bool = False,
@ -80,7 +79,6 @@ class ImplicitronEvaluator(EvaluatorBase):
Args:
model: A (trained) model to evaluate.
dataloader: A test dataloader.
task: Type of the novel-view synthesis task we're working on.
all_train_cameras: Camera instances we used for training.
device: A torch device.
dump_to_json: If True, will dump the results to a json file.
@ -122,7 +120,9 @@ class ImplicitronEvaluator(EvaluatorBase):
)
_, category_result = evaluate.summarize_nvs_eval_results(
per_batch_eval_results, task, self.camera_difficulty_bin_breaks
per_batch_eval_results,
self.is_multisequence,
self.camera_difficulty_bin_breaks,
)
results = category_result["results"]