mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-04-30 02:28:56 +08:00
CO3Dv2 trainer configs
Summary: Adds yaml configs to train selected methods on CO3Dv2. Few more updates: 1) moved some fields to base classes so that we can check is_multisequence in experiment.py 2) skip loading all train cameras for multisequence datasets, without this, co3d-fewview is untrainable 3) fix bug in json index dataset provider v2 Reviewed By: kjchalup Differential Revision: D38952755 fbshipit-source-id: 3edac6fc8e20775aa70400bd73a0e6d52b091e0c
This commit is contained in:
committed by
Facebook GitHub Bot
parent
03562d87f5
commit
1163eaab43
@@ -36,6 +36,7 @@ from pytorch3d.io import IO
|
||||
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
||||
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
|
||||
from pytorch3d.structures.pointclouds import Pointclouds
|
||||
from tqdm import tqdm
|
||||
|
||||
from . import types
|
||||
from .dataset_base import DatasetBase, FrameData
|
||||
@@ -338,9 +339,10 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
||||
"""
|
||||
Returns the cameras corresponding to all the known frames.
|
||||
"""
|
||||
logger.info("Loading all train cameras.")
|
||||
cameras = []
|
||||
# pyre-ignore[16]
|
||||
for frame_idx, frame_annot in enumerate(self.frame_annots):
|
||||
for frame_idx, frame_annot in enumerate(tqdm(self.frame_annots)):
|
||||
frame_type = self._get_frame_type(frame_annot)
|
||||
if frame_type is None:
|
||||
raise ValueError("subsets not loaded")
|
||||
|
||||
@@ -14,6 +14,7 @@ from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
from iopath.common.file_io import PathManager
|
||||
|
||||
from omegaconf import DictConfig
|
||||
from pytorch3d.implicitron.dataset.dataset_map_provider import (
|
||||
@@ -383,12 +384,11 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
return data
|
||||
|
||||
def _get_available_subset_names(self):
|
||||
path_manager = self.path_manager_factory.get()
|
||||
if path_manager is not None:
|
||||
dataset_root = path_manager.get_local_path(self.dataset_root)
|
||||
else:
|
||||
dataset_root = self.dataset_root
|
||||
return get_available_subset_names(dataset_root, self.category)
|
||||
return get_available_subset_names(
|
||||
self.dataset_root,
|
||||
self.category,
|
||||
path_manager=self.path_manager_factory.get(),
|
||||
)
|
||||
|
||||
def _extend_test_data_with_known_views(
|
||||
self,
|
||||
@@ -425,18 +425,30 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
return eval_batch_index_out, list(test_subset_mapping_set)
|
||||
|
||||
|
||||
def get_available_subset_names(dataset_root: str, category: str) -> List[str]:
|
||||
def get_available_subset_names(
|
||||
dataset_root: str,
|
||||
category: str,
|
||||
path_manager: Optional[PathManager] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get the available subset names for a given category folder inside a root dataset
|
||||
folder `dataset_root`.
|
||||
"""
|
||||
category_dir = os.path.join(dataset_root, category)
|
||||
if not os.path.isdir(category_dir):
|
||||
category_dir_exists = (
|
||||
(path_manager is not None) and path_manager.isdir(category_dir)
|
||||
) or os.path.isdir(category_dir)
|
||||
if not category_dir_exists:
|
||||
raise ValueError(
|
||||
f"Looking for dataset files in {category_dir}. "
|
||||
+ "Please specify a correct dataset_root folder."
|
||||
)
|
||||
set_list_jsons = os.listdir(os.path.join(category_dir, "set_lists"))
|
||||
|
||||
set_list_dir = os.path.join(category_dir, "set_lists")
|
||||
set_list_jsons = (os.listdir if path_manager is None else path_manager.ls)(
|
||||
set_list_dir
|
||||
)
|
||||
|
||||
return [
|
||||
json_file.replace("set_lists_", "").replace(".json", "")
|
||||
for json_file in set_list_jsons
|
||||
|
||||
@@ -36,6 +36,8 @@ class EvaluatorBase(ReplaceableBase):
|
||||
names and their values.
|
||||
"""
|
||||
|
||||
is_multisequence: bool = False
|
||||
|
||||
def run(
|
||||
self, model: ImplicitronModelBase, dataloader: DataLoader, **kwargs
|
||||
) -> Dict[str, Any]:
|
||||
@@ -56,7 +58,6 @@ class ImplicitronEvaluator(EvaluatorBase):
|
||||
"""
|
||||
|
||||
camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
|
||||
is_multisequence: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
|
||||
Reference in New Issue
Block a user