CO3Dv2 trainer configs
Summary: Adds yaml configs to train selected methods on CO3Dv2. A few more updates:
1) Moved some fields to base classes so that is_multisequence can be checked in experiment.py.
2) Skip loading all train cameras for multisequence datasets; without this, co3d-fewview is untrainable.
3) Fix a bug in the JSON index dataset map provider v2.

Reviewed By: kjchalup

Differential Revision: D38952755

fbshipit-source-id: 3edac6fc8e20775aa70400bd73a0e6d52b091e0c
commit 1163eaab43
parent 03562d87f5
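The new configs are exercised the same way as the existing repro configs: compose one with Hydra and hand the result to the Implicitron Experiment class, as the tests added later in this commit do. Below is a minimal sketch; the import path of the experiment module and the configs directory are placeholders (the tests resolve them via IMPLICITRON_CONFIGS_DIR), so treat this as an illustration rather than the canonical entry point.

# Minimal sketch of launching training from one of the new CO3Dv2 configs.
# Assumptions: the `experiment` import path and CONFIGS_DIR below are placeholders.
from hydra import compose, initialize_config_dir

from projects.implicitron_trainer import experiment  # hypothetical import path

CONFIGS_DIR = "/path/to/implicitron_trainer/configs"  # placeholder

with initialize_config_dir(config_dir=CONFIGS_DIR):
    # repro_singleseq_v2_nerf layers repro_singleseq_nerf.yaml on top of
    # repro_singleseq_co3dv2_base.yaml (see the defaults lists below).
    cfg = compose(config_name="repro_singleseq_v2_nerf", overrides=[])

experiment_runner = experiment.Experiment(**cfg)
experiment.dump_cfg(cfg)
experiment_runner.run()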
@@ -0,0 +1,8 @@
+data_source_ImplicitronDataSource_args:
+  dataset_map_provider_class_type: JsonIndexDatasetMapProviderV2
+  dataset_map_provider_JsonIndexDatasetMapProviderV2_args:
+    category: teddybear
+    subset_name: fewview_dev
+training_loop_ImplicitronTrainingLoop_args:
+  evaluator_ImplicitronEvaluator_args:
+    is_multisequence: true
@@ -0,0 +1,4 @@
+defaults:
+- repro_multiseq_nerf_wce.yaml
+- repro_multiseq_co3dv2_base.yaml
+- _self_
@@ -0,0 +1,4 @@
+defaults:
+- repro_multiseq_nerformer.yaml
+- repro_multiseq_co3dv2_base.yaml
+- _self_
@@ -0,0 +1,4 @@
+defaults:
+- repro_multiseq_srn_ad_hypernet.yaml
+- repro_multiseq_co3dv2_base.yaml
+- _self_
@@ -0,0 +1,4 @@
+defaults:
+- repro_multiseq_srn_wce.yaml
+- repro_multiseq_co3dv2_base.yaml
+- _self_
@@ -0,0 +1,8 @@
+data_source_ImplicitronDataSource_args:
+  dataset_map_provider_class_type: JsonIndexDatasetMapProviderV2
+  dataset_map_provider_JsonIndexDatasetMapProviderV2_args:
+    category: teddybear
+    subset_name: manyview_dev_0
+training_loop_ImplicitronTrainingLoop_args:
+  evaluator_ImplicitronEvaluator_args:
+    is_multisequence: false
@@ -0,0 +1,4 @@
+defaults:
+- repro_singleseq_idr.yaml
+- repro_singleseq_co3dv2_base.yaml
+- _self_
@@ -0,0 +1,4 @@
+defaults:
+- repro_singleseq_nerf.yaml
+- repro_singleseq_co3dv2_base.yaml
+- _self_
@@ -0,0 +1,4 @@
+defaults:
+- repro_singleseq_nerformer.yaml
+- repro_singleseq_co3dv2_base.yaml
+- _self_
@@ -0,0 +1,4 @@
+defaults:
+- repro_singleseq_srn_noharm.yaml
+- repro_singleseq_co3dv2_base.yaml
+- _self_
@@ -207,7 +207,10 @@ class Experiment(Configurable):  # pyre-ignore: 13
             val_loader,
         ) = accelerator.prepare(model, optimizer, train_loader, val_loader)

-        all_train_cameras = self.data_source.all_train_cameras
+        if not self.training_loop.evaluator.is_multisequence:
+            all_train_cameras = self.data_source.all_train_cameras
+        else:
+            all_train_cameras = None

         # Enter the main training loop.
         self.training_loop.run(
@@ -30,6 +30,14 @@ logger = logging.getLogger(__name__)


 class TrainingLoopBase(ReplaceableBase):
+    """
+    Members:
+        evaluator: An EvaluatorBase instance, used to evaluate training results.
+    """
+
+    evaluator: Optional[EvaluatorBase]
+    evaluator_class_type: Optional[str] = "ImplicitronEvaluator"
+
     def run(
         self,
         train_loader: DataLoader,
@@ -58,7 +66,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
     """
     Members:
         eval_only: If True, only run evaluation using the test dataloader.
-        evaluator: An EvaluatorBase instance, used to evaluate training results.
         max_epochs: Train for this many epochs. Note that if the model was
             loaded from a checkpoint, we will restart training at the appropriate
             epoch and run for (max_epochs - checkpoint_epoch) epochs.
@@ -82,8 +89,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]

     # Parameters of the outer training loop.
     eval_only: bool = False
-    evaluator: EvaluatorBase
-    evaluator_class_type: str = "ImplicitronEvaluator"
     max_epochs: int = 1000
     store_checkpoints: bool = True
     store_checkpoints_purge: int = 1
@@ -406,8 +406,13 @@ optimizer_factory_ImplicitronOptimizerFactory_args:
   linear_exponential_lr_milestone: 200
   linear_exponential_start_gamma: 0.1
 training_loop_ImplicitronTrainingLoop_args:
-  eval_only: false
   evaluator_class_type: ImplicitronEvaluator
+  evaluator_ImplicitronEvaluator_args:
+    is_multisequence: false
+    camera_difficulty_bin_breaks:
+    - 0.97
+    - 0.98
+  eval_only: false
   max_epochs: 1000
   store_checkpoints: true
   store_checkpoints_purge: 1
@@ -420,8 +425,3 @@ training_loop_ImplicitronTrainingLoop_args:
   visdom_env: ''
   visdom_port: 8097
   visdom_server: http://127.0.0.1
-  evaluator_ImplicitronEvaluator_args:
-    camera_difficulty_bin_breaks:
-    - 0.97
-    - 0.98
-    is_multisequence: false
@@ -190,6 +190,34 @@ class TestNerfRepro(unittest.TestCase):
             experiment.dump_cfg(cfg)
             experiment_runner.run()

+    @unittest.skip("This test runs nerf training on co3d v2 - manyview.")
+    def test_nerf_co3dv2_manyview(self):
+        # Train NERF
+        if not interactive_testing_requested():
+            return
+        with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)):
+            cfg = compose(
+                config_name="repro_singleseq_v2_nerf",
+                overrides=[],
+            )
+            experiment_runner = experiment.Experiment(**cfg)
+            experiment.dump_cfg(cfg)
+            experiment_runner.run()
+
+    @unittest.skip("This test runs nerformer training on co3d v2 - fewview.")
+    def test_nerformer_co3dv2_fewview(self):
+        # Train NeRFormer
+        if not interactive_testing_requested():
+            return
+        with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)):
+            cfg = compose(
+                config_name="repro_multiseq_v2_nerformer",
+                overrides=[],
+            )
+            experiment_runner = experiment.Experiment(**cfg)
+            experiment.dump_cfg(cfg)
+            experiment_runner.run()
+
     @unittest.skip("This test checks resuming of the NeRF training.")
     def test_nerf_blender_resume(self):
         # Train one train batch of NeRF, then resume for one more batch.
@@ -36,6 +36,7 @@ from pytorch3d.io import IO
 from pytorch3d.renderer.camera_utils import join_cameras_as_batch
 from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
 from pytorch3d.structures.pointclouds import Pointclouds
+from tqdm import tqdm

 from . import types
 from .dataset_base import DatasetBase, FrameData
@@ -338,9 +339,10 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
         """
         Returns the cameras corresponding to all the known frames.
         """
+        logger.info("Loading all train cameras.")
         cameras = []
         # pyre-ignore[16]
-        for frame_idx, frame_annot in enumerate(self.frame_annots):
+        for frame_idx, frame_annot in enumerate(tqdm(self.frame_annots)):
             frame_type = self._get_frame_type(frame_annot)
             if frame_type is None:
                 raise ValueError("subsets not loaded")
@@ -14,6 +14,7 @@ from collections import defaultdict
 from typing import Dict, List, Optional, Tuple, Type, Union

 import numpy as np
+from iopath.common.file_io import PathManager

 from omegaconf import DictConfig
 from pytorch3d.implicitron.dataset.dataset_map_provider import (
@@ -383,12 +384,11 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
         return data

     def _get_available_subset_names(self):
-        path_manager = self.path_manager_factory.get()
-        if path_manager is not None:
-            dataset_root = path_manager.get_local_path(self.dataset_root)
-        else:
-            dataset_root = self.dataset_root
-        return get_available_subset_names(dataset_root, self.category)
+        return get_available_subset_names(
+            self.dataset_root,
+            self.category,
+            path_manager=self.path_manager_factory.get(),
+        )

     def _extend_test_data_with_known_views(
         self,
@@ -425,18 +425,30 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
         return eval_batch_index_out, list(test_subset_mapping_set)


-def get_available_subset_names(dataset_root: str, category: str) -> List[str]:
+def get_available_subset_names(
+    dataset_root: str,
+    category: str,
+    path_manager: Optional[PathManager] = None,
+) -> List[str]:
     """
     Get the available subset names for a given category folder inside a root dataset
     folder `dataset_root`.
     """
     category_dir = os.path.join(dataset_root, category)
-    if not os.path.isdir(category_dir):
+    category_dir_exists = (
+        (path_manager is not None) and path_manager.isdir(category_dir)
+    ) or os.path.isdir(category_dir)
+    if not category_dir_exists:
         raise ValueError(
             f"Looking for dataset files in {category_dir}. "
             + "Please specify a correct dataset_root folder."
         )
-    set_list_jsons = os.listdir(os.path.join(category_dir, "set_lists"))
+
+    set_list_dir = os.path.join(category_dir, "set_lists")
+    set_list_jsons = (os.listdir if path_manager is None else path_manager.ls)(
+        set_list_dir
+    )
+
     return [
         json_file.replace("set_lists_", "").replace(".json", "")
         for json_file in set_list_jsons
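For reference, the extended get_available_subset_names above can be called either on a local dataset root or through an iopath PathManager. A small hedged usage sketch follows; the dataset root is a placeholder and the import path is assumed to match the provider v2 module this hunk edits.

# Hedged usage sketch for get_available_subset_names with the new
# path_manager argument; "/path/to/CO3Dv2" is a placeholder dataset root.
from iopath.common.file_io import PathManager

# Assumed module path for the provider edited in this hunk.
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
    get_available_subset_names,
)

# Local filesystem: without a path_manager, os.path.isdir / os.listdir are used.
subsets = get_available_subset_names("/path/to/CO3Dv2", "teddybear")

# Non-local storage: a PathManager routes the isdir/ls calls through its handlers.
pm = PathManager()
subsets = get_available_subset_names("/path/to/CO3Dv2", "teddybear", path_manager=pm)
print(subsets)  # e.g. ["fewview_dev", "manyview_dev_0", ...]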
@@ -36,6 +36,8 @@ class EvaluatorBase(ReplaceableBase):
         names and their values.
     """

+    is_multisequence: bool = False
+
     def run(
         self, model: ImplicitronModelBase, dataloader: DataLoader, **kwargs
     ) -> Dict[str, Any]:
@@ -56,7 +58,6 @@ class ImplicitronEvaluator(EvaluatorBase):
     """

     camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
-    is_multisequence: bool = False

     def __post_init__(self):
         run_auto_creation(self)