remove get_task

Summary: Remove the dataset's need to provide the task type.

Reviewed By: davnov134, kjchalup

Differential Revision: D38314000

fbshipit-source-id: 3805d885b5d4528abdc78c0da03247edb9abf3f7
This commit is contained in:
Jeremy Reizenstein 2022-08-02 07:55:42 -07:00 committed by Facebook GitHub Bot
parent 37250a4326
commit f8bf528043
13 changed files with 36 additions and 83 deletions

View File

@@ -35,3 +35,4 @@ training_loop_ImplicitronTrainingLoop_args:
camera_difficulty_bin_breaks: camera_difficulty_bin_breaks:
- 0.666667 - 0.666667
- 0.833334 - 0.833334
is_multisequence: true

View File

@@ -206,7 +206,6 @@ class Experiment(Configurable): # pyre-ignore: 13
val_loader, val_loader,
) = accelerator.prepare(model, optimizer, train_loader, val_loader) ) = accelerator.prepare(model, optimizer, train_loader, val_loader)
task = self.data_source.get_task()
all_train_cameras = self.data_source.all_train_cameras all_train_cameras = self.data_source.all_train_cameras
# Enter the main training loop. # Enter the main training loop.
@@ -223,7 +222,6 @@ class Experiment(Configurable): # pyre-ignore: 13
exp_dir=self.exp_dir, exp_dir=self.exp_dir,
stats=stats, stats=stats,
seed=self.seed, seed=self.seed,
task=task,
) )
def _check_config_consistent(self) -> None: def _check_config_consistent(self) -> None:

View File

@@ -10,7 +10,6 @@ from typing import Any, Optional
import torch import torch
from accelerate import Accelerator from accelerate import Accelerator
from pytorch3d.implicitron.dataset.data_source import Task
from pytorch3d.implicitron.evaluation.evaluator import EvaluatorBase from pytorch3d.implicitron.evaluation.evaluator import EvaluatorBase
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.models.generic_model import EvaluationMode from pytorch3d.implicitron.models.generic_model import EvaluationMode
@@ -101,7 +100,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
exp_dir: str, exp_dir: str,
stats: Stats, stats: Stats,
seed: int, seed: int,
task: Task,
**kwargs, **kwargs,
): ):
""" """
@@ -123,7 +121,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
epoch=stats.epoch, epoch=stats.epoch,
exp_dir=exp_dir, exp_dir=exp_dir,
model=model, model=model,
task=task,
) )
return return
else: else:
@@ -179,7 +176,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
device=device, device=device,
dataloader=test_loader, dataloader=test_loader,
model=model, model=model,
task=task,
) )
assert stats.epoch == epoch, "inconsistent stats!" assert stats.epoch == epoch, "inconsistent stats!"
@@ -200,7 +196,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
exp_dir=exp_dir, exp_dir=exp_dir,
dataloader=test_loader, dataloader=test_loader,
model=model, model=model,
task=task,
) )
else: else:
raise ValueError( raise ValueError(

View File

@@ -435,3 +435,4 @@ training_loop_ImplicitronTrainingLoop_args:
camera_difficulty_bin_breaks: camera_difficulty_bin_breaks:
- 0.97 - 0.97
- 0.98 - 0.98
is_multisequence: false

View File

@@ -15,7 +15,7 @@ from pytorch3d.renderer.cameras import CamerasBase
from .blender_dataset_map_provider import BlenderDatasetMapProvider # noqa from .blender_dataset_map_provider import BlenderDatasetMapProvider # noqa
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task from .dataset_map_provider import DatasetMap, DatasetMapProviderBase
from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa
from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2 # noqa from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2 # noqa
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa
@@ -41,9 +41,6 @@ class DataSourceBase(ReplaceableBase):
""" """
raise NotImplementedError() raise NotImplementedError()
def get_task(self) -> Task:
raise NotImplementedError()
@registry.register @registry.register
class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13] class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
@@ -71,9 +68,6 @@ class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
dataloaders = self.data_loader_map_provider.get_data_loader_map(datasets) dataloaders = self.data_loader_map_provider.get_data_loader_map(datasets)
return datasets, dataloaders return datasets, dataloaders
def get_task(self) -> Task:
return self.dataset_map_provider.get_task()
@property @property
def all_train_cameras(self) -> Optional[CamerasBase]: def all_train_cameras(self) -> Optional[CamerasBase]:
if self._all_train_cameras_cache is None: # pyre-ignore[16] if self._all_train_cameras_cache is None: # pyre-ignore[16]

View File

@@ -7,7 +7,6 @@
import logging import logging
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum
from typing import Iterator, Optional from typing import Iterator, Optional
from iopath.common.file_io import PathManager from iopath.common.file_io import PathManager
@@ -53,11 +52,6 @@ class DatasetMap:
yield self.test yield self.test
class Task(Enum):
SINGLE_SEQUENCE = "singlesequence"
MULTI_SEQUENCE = "multisequence"
class DatasetMapProviderBase(ReplaceableBase): class DatasetMapProviderBase(ReplaceableBase):
""" """
Base class for a provider of training / validation and testing Base class for a provider of training / validation and testing
@@ -71,9 +65,6 @@ class DatasetMapProviderBase(ReplaceableBase):
""" """
raise NotImplementedError() raise NotImplementedError()
def get_task(self) -> Task:
raise NotImplementedError()
def get_all_train_cameras(self) -> Optional[CamerasBase]: def get_all_train_cameras(self) -> Optional[CamerasBase]:
""" """
If the data is all for a single scene, returns a list If the data is all for a single scene, returns a list

View File

@@ -17,12 +17,7 @@ from pytorch3d.implicitron.tools.config import (
) )
from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.cameras import CamerasBase
from .dataset_map_provider import ( from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
DatasetMap,
DatasetMapProviderBase,
PathManagerFactory,
Task,
)
from .json_index_dataset import JsonIndexDataset from .json_index_dataset import JsonIndexDataset
from .utils import ( from .utils import (
@@ -160,7 +155,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
# This maps the common names of the dataset subsets ("train"/"val"/"test") # This maps the common names of the dataset subsets ("train"/"val"/"test")
# to the names of the subsets in the CO3D dataset. # to the names of the subsets in the CO3D dataset.
set_names_mapping = _get_co3d_set_names_mapping( set_names_mapping = _get_co3d_set_names_mapping(
self.get_task(), self.task_str,
self.test_on_train, self.test_on_train,
self.only_test_set, self.only_test_set,
) )
@@ -185,7 +180,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
eval_batch_index = json.load(f) eval_batch_index = json.load(f)
restrict_sequence_name = self.restrict_sequence_name restrict_sequence_name = self.restrict_sequence_name
if self.get_task() == Task.SINGLE_SEQUENCE: if self.task_str == "singlesequence":
if ( if (
self.test_restrict_sequence_id is None self.test_restrict_sequence_id is None
or self.test_restrict_sequence_id < 0 or self.test_restrict_sequence_id < 0
@@ -267,13 +262,12 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
# pyre-ignore[16] # pyre-ignore[16]
return self.dataset_map return self.dataset_map
def get_task(self) -> Task:
return Task(self.task_str)
def get_all_train_cameras(self) -> Optional[CamerasBase]: def get_all_train_cameras(self) -> Optional[CamerasBase]:
if Task(self.task_str) == Task.MULTI_SEQUENCE: if self.task_str == "multisequence":
return None return None
assert self.task_str == "singlesequence"
# pyre-ignore[16] # pyre-ignore[16]
train_dataset = self.dataset_map.train train_dataset = self.dataset_map.train
assert isinstance(train_dataset, JsonIndexDataset) assert isinstance(train_dataset, JsonIndexDataset)
@@ -281,7 +275,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
def _get_co3d_set_names_mapping( def _get_co3d_set_names_mapping(
task: Task, task_str: str,
test_on_train: bool, test_on_train: bool,
only_test: bool, only_test: bool,
) -> Dict[str, List[str]]: ) -> Dict[str, List[str]]:
@@ -295,7 +289,7 @@ def _get_co3d_set_names_mapping(
- val (if not test_on_train) - val (if not test_on_train)
- test (if not test_on_train) - test (if not test_on_train)
""" """
single_seq = task == Task.SINGLE_SEQUENCE single_seq = task_str == "singlesequence"
if only_test: if only_test:
set_names_mapping = {} set_names_mapping = {}

View File

@@ -16,7 +16,6 @@ from pytorch3d.implicitron.dataset.dataset_map_provider import (
DatasetMap, DatasetMap,
DatasetMapProviderBase, DatasetMapProviderBase,
PathManagerFactory, PathManagerFactory,
Task,
) )
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.tools.config import ( from pytorch3d.implicitron.tools.config import (
@@ -335,12 +334,6 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
) )
return category_to_subset_name_list return category_to_subset_name_list
def get_task(self) -> Task: # TODO: we plan to get rid of tasks
return {
"manyview": Task.SINGLE_SEQUENCE,
"fewview": Task.MULTI_SEQUENCE,
}[self.subset_name.split("_")[0]]
def get_all_train_cameras(self) -> Optional[CamerasBase]: def get_all_train_cameras(self) -> Optional[CamerasBase]:
# pyre-ignore[16] # pyre-ignore[16]
train_dataset = self.dataset_map.train train_dataset = self.dataset_map.train

View File

@@ -28,12 +28,7 @@ from pytorch3d.renderer import (
) )
from pytorch3d.structures.meshes import Meshes from pytorch3d.structures.meshes import Meshes
from .dataset_map_provider import ( from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
DatasetMap,
DatasetMapProviderBase,
PathManagerFactory,
Task,
)
from .single_sequence_dataset import SingleSceneDataset from .single_sequence_dataset import SingleSceneDataset
from .utils import DATASET_TYPE_KNOWN from .utils import DATASET_TYPE_KNOWN
@@ -83,9 +78,6 @@ class RenderedMeshDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
# pyre-ignore[16] # pyre-ignore[16]
return DatasetMap(train=self.train_dataset, val=None, test=None) return DatasetMap(train=self.train_dataset, val=None, test=None)
def get_task(self) -> Task:
return Task.SINGLE_SEQUENCE
def get_all_train_cameras(self) -> CamerasBase: def get_all_train_cameras(self) -> CamerasBase:
# pyre-ignore[16] # pyre-ignore[16]
return self.poses return self.poses

View File

@@ -21,12 +21,7 @@ from pytorch3d.implicitron.tools.config import (
from pytorch3d.renderer import CamerasBase, join_cameras_as_batch, PerspectiveCameras from pytorch3d.renderer import CamerasBase, join_cameras_as_batch, PerspectiveCameras
from .dataset_base import DatasetBase, FrameData from .dataset_base import DatasetBase, FrameData
from .dataset_map_provider import ( from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
DatasetMap,
DatasetMapProviderBase,
PathManagerFactory,
Task,
)
from .utils import DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN from .utils import DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN
_SINGLE_SEQUENCE_NAME: str = "one_sequence" _SINGLE_SEQUENCE_NAME: str = "one_sequence"
@@ -159,9 +154,6 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
test=self._get_dataset(2, DATASET_TYPE_UNKNOWN, True), test=self._get_dataset(2, DATASET_TYPE_UNKNOWN, True),
) )
def get_task(self) -> Task:
return Task.SINGLE_SEQUENCE
def get_all_train_cameras(self) -> Optional[CamerasBase]: def get_all_train_cameras(self) -> Optional[CamerasBase]:
# pyre-ignore[16] # pyre-ignore[16]
cameras = [self.poses[i] for i in self.i_split[0]] cameras = [self.poses[i] for i in self.i_split[0]]

View File

@@ -7,11 +7,12 @@
import dataclasses import dataclasses
import os import os
from enum import Enum
from typing import Any, cast, Dict, List, Optional, Tuple from typing import Any, cast, Dict, List, Optional, Tuple
import lpips import lpips
import torch import torch
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import ( from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
CO3D_CATEGORIES, CO3D_CATEGORIES,
@@ -27,6 +28,11 @@ from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_
from tqdm import tqdm from tqdm import tqdm
class Task(Enum):
SINGLE_SEQUENCE = "singlesequence"
MULTI_SEQUENCE = "multisequence"
def main() -> None: def main() -> None:
""" """
Evaluates new view synthesis metrics of a simple depth-based image rendering Evaluates new view synthesis metrics of a simple depth-based image rendering
@@ -153,11 +159,15 @@ def evaluate_dbir_for_category(
if task == Task.SINGLE_SEQUENCE: if task == Task.SINGLE_SEQUENCE:
camera_difficulty_bin_breaks = 0.97, 0.98 camera_difficulty_bin_breaks = 0.97, 0.98
multisequence_evaluation = False
else: else:
camera_difficulty_bin_breaks = 2.0 / 3, 5.0 / 6 camera_difficulty_bin_breaks = 2.0 / 3, 5.0 / 6
multisequence_evaluation = True
category_result_flat, category_result = summarize_nvs_eval_results( category_result_flat, category_result = summarize_nvs_eval_results(
per_batch_eval_results, task, camera_difficulty_bin_breaks per_batch_eval_results,
camera_difficulty_bin_breaks=camera_difficulty_bin_breaks,
is_multisequence=multisequence_evaluation,
) )
return category_result["results"] return category_result["results"]

View File

@@ -14,7 +14,6 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from pytorch3d.implicitron.dataset.data_source import Task
from pytorch3d.implicitron.dataset.dataset_base import FrameData from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
from pytorch3d.implicitron.models.base_model import ImplicitronRender from pytorch3d.implicitron.models.base_model import ImplicitronRender
@@ -420,16 +419,16 @@ def _get_camera_difficulty_bin_edges(camera_difficulty_bin_breaks: Tuple[float,
def summarize_nvs_eval_results( def summarize_nvs_eval_results(
per_batch_eval_results: List[Dict[str, Any]], per_batch_eval_results: List[Dict[str, Any]],
task: Task, is_multisequence: bool,
camera_difficulty_bin_breaks: Tuple[float, float] = (0.97, 0.98), camera_difficulty_bin_breaks: Tuple[float, float],
): ):
""" """
Compile the per-batch evaluation results `per_batch_eval_results` into Compile the per-batch evaluation results `per_batch_eval_results` into
a set of aggregate metrics. The produced metrics depend on the task. a set of aggregate metrics. The produced metrics depend on is_multisequence.
Args: Args:
per_batch_eval_results: Metrics of each per-batch evaluation. per_batch_eval_results: Metrics of each per-batch evaluation.
task: The type of the new-view synthesis task. is_multisequence: Whether to evaluate as a multisequence task
camera_difficulty_bin_breaks: edge hard-medium and medium-easy camera_difficulty_bin_breaks: edge hard-medium and medium-easy
@@ -439,14 +438,9 @@ def summarize_nvs_eval_results(
""" """
n_batches = len(per_batch_eval_results) n_batches = len(per_batch_eval_results)
eval_sets: List[Optional[str]] = [] eval_sets: List[Optional[str]] = []
if task == Task.SINGLE_SEQUENCE: eval_sets = [None]
eval_sets = [None] if is_multisequence:
# assert n_batches==100
elif task == Task.MULTI_SEQUENCE:
eval_sets = ["train", "test"] eval_sets = ["train", "test"]
# assert n_batches==1000
else:
raise ValueError(task)
batch_sizes = torch.tensor( batch_sizes = torch.tensor(
[r["meta"]["batch_size"] for r in per_batch_eval_results] [r["meta"]["batch_size"] for r in per_batch_eval_results]
).long() ).long()
@@ -466,11 +460,9 @@ def summarize_nvs_eval_results(
# add per set averages # add per set averages
for SET in eval_sets: for SET in eval_sets:
if SET is None: if SET is None:
assert task == Task.SINGLE_SEQUENCE
ok_set = torch.ones(n_batches, dtype=torch.bool) ok_set = torch.ones(n_batches, dtype=torch.bool)
set_name = "test" set_name = "test"
else: else:
assert task == Task.MULTI_SEQUENCE
ok_set = is_train == int(SET == "train") ok_set = is_train == int(SET == "train")
set_name = SET set_name = SET
@@ -495,7 +487,7 @@ def summarize_nvs_eval_results(
} }
) )
if task == Task.MULTI_SEQUENCE: if is_multisequence:
# split based on n_src_views # split based on n_src_views
n_src_views = batch_sizes - 1 n_src_views = batch_sizes - 1
for n_src in EVAL_N_SRC_VIEWS: for n_src in EVAL_N_SRC_VIEWS:

View File

@@ -16,7 +16,6 @@ import torch
import tqdm import tqdm
from pytorch3d.implicitron.dataset import utils as ds_utils from pytorch3d.implicitron.dataset import utils as ds_utils
from pytorch3d.implicitron.dataset.data_source import Task
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
from pytorch3d.implicitron.models.base_model import EvaluationMode, ImplicitronModelBase from pytorch3d.implicitron.models.base_model import EvaluationMode, ImplicitronModelBase
@@ -57,6 +56,7 @@ class ImplicitronEvaluator(EvaluatorBase):
""" """
camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98 camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
is_multisequence: bool = False
def __post_init__(self): def __post_init__(self):
run_auto_creation(self) run_auto_creation(self)
@@ -65,7 +65,6 @@ class ImplicitronEvaluator(EvaluatorBase):
self, self,
model: ImplicitronModelBase, model: ImplicitronModelBase,
dataloader: DataLoader, dataloader: DataLoader,
task: Task,
all_train_cameras: Optional[CamerasBase], all_train_cameras: Optional[CamerasBase],
device: torch.device, device: torch.device,
dump_to_json: bool = False, dump_to_json: bool = False,
@@ -80,7 +79,6 @@ class ImplicitronEvaluator(EvaluatorBase):
Args: Args:
model: A (trained) model to evaluate. model: A (trained) model to evaluate.
dataloader: A test dataloader. dataloader: A test dataloader.
task: Type of the novel-view synthesis task we're working on.
all_train_cameras: Camera instances we used for training. all_train_cameras: Camera instances we used for training.
device: A torch device. device: A torch device.
dump_to_json: If True, will dump the results to a json file. dump_to_json: If True, will dump the results to a json file.
@@ -122,7 +120,9 @@ class ImplicitronEvaluator(EvaluatorBase):
) )
_, category_result = evaluate.summarize_nvs_eval_results( _, category_result = evaluate.summarize_nvs_eval_results(
per_batch_eval_results, task, self.camera_difficulty_bin_breaks per_batch_eval_results,
self.is_multisequence,
self.camera_difficulty_bin_breaks,
) )
results = category_result["results"] results = category_result["results"]