mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-14 19:36:23 +08:00
get_all_train_cameras
Summary: As part of removing Task, make the dataset code generate the source cameras for itself. There's a small optimization available here, in that the JsonIndexDataset could avoid loading images. Reviewed By: shapovalov Differential Revision: D37313423 fbshipit-source-id: 3e5e0b2aabbf9cc51f10547a3523e98c72ad8755
This commit is contained in:
committed by
Facebook GitHub Bot
parent
771cf8a328
commit
4e87c2b7f1
@@ -66,9 +66,7 @@ from packaging import version
|
||||
from pytorch3d.implicitron.dataset import utils as ds_utils
|
||||
from pytorch3d.implicitron.dataset.data_loader_map_provider import DataLoaderMap
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
|
||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData
|
||||
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
|
||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
||||
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
|
||||
from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
|
||||
from pytorch3d.implicitron.tools import model_io, vis_utils
|
||||
@@ -456,6 +454,7 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
|
||||
datasource = ImplicitronDataSource(**cfg.data_source_args)
|
||||
datasets, dataloaders = datasource.get_datasets_and_dataloaders()
|
||||
task = datasource.get_task()
|
||||
all_train_cameras = datasource.get_all_train_cameras()
|
||||
|
||||
# init the model
|
||||
model, stats, optimizer_state = init_model(cfg)
|
||||
@@ -466,7 +465,16 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
|
||||
|
||||
# only run evaluation on the test dataloader
|
||||
if cfg.eval_only:
|
||||
_eval_and_dump(cfg, task, datasets, dataloaders, model, stats, device=device)
|
||||
_eval_and_dump(
|
||||
cfg,
|
||||
task,
|
||||
all_train_cameras,
|
||||
datasets,
|
||||
dataloaders,
|
||||
model,
|
||||
stats,
|
||||
device=device,
|
||||
)
|
||||
return
|
||||
|
||||
# init the optimizer
|
||||
@@ -528,7 +536,9 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
|
||||
and cfg.test_interval > 0
|
||||
and epoch % cfg.test_interval == 0
|
||||
):
|
||||
_run_eval(model, stats, dataloaders.test, task, device=device)
|
||||
_run_eval(
|
||||
model, all_train_cameras, dataloaders.test, task, device=device
|
||||
)
|
||||
|
||||
assert stats.epoch == epoch, "inconsistent stats!"
|
||||
|
||||
@@ -548,12 +558,22 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
|
||||
logger.info(f"LR change! {cur_lr} -> {new_lr}")
|
||||
|
||||
if cfg.test_when_finished:
|
||||
_eval_and_dump(cfg, task, datasets, dataloaders, model, stats, device=device)
|
||||
_eval_and_dump(
|
||||
cfg,
|
||||
task,
|
||||
all_train_cameras,
|
||||
datasets,
|
||||
dataloaders,
|
||||
model,
|
||||
stats,
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
||||
def _eval_and_dump(
|
||||
cfg,
|
||||
task: Task,
|
||||
all_train_cameras: Optional[CamerasBase],
|
||||
datasets: DatasetMap,
|
||||
dataloaders: DataLoaderMap,
|
||||
model,
|
||||
@@ -570,13 +590,7 @@ def _eval_and_dump(
|
||||
if dataloader is None:
|
||||
raise ValueError('DataLoaderMap have to contain the "test" entry for eval!')
|
||||
|
||||
if task == Task.SINGLE_SEQUENCE:
|
||||
if datasets.train is None:
|
||||
raise ValueError("train dataset must be provided")
|
||||
all_source_cameras = _get_all_source_cameras(datasets.train)
|
||||
else:
|
||||
all_source_cameras = None
|
||||
results = _run_eval(model, all_source_cameras, dataloader, task, device=device)
|
||||
results = _run_eval(model, all_train_cameras, dataloader, task, device=device)
|
||||
|
||||
# add the evaluation epoch to the results
|
||||
for r in results:
|
||||
@@ -603,7 +617,7 @@ def _get_eval_frame_data(frame_data):
|
||||
return frame_data_for_eval
|
||||
|
||||
|
||||
def _run_eval(model, all_source_cameras, loader, task: Task, device):
|
||||
def _run_eval(model, all_train_cameras, loader, task: Task, device):
|
||||
"""
|
||||
Run the evaluation loop on the test dataloader
|
||||
"""
|
||||
@@ -631,7 +645,7 @@ def _run_eval(model, all_source_cameras, loader, task: Task, device):
|
||||
implicitron_render,
|
||||
bg_color="black",
|
||||
lpips_model=lpips_model,
|
||||
source_cameras=all_source_cameras,
|
||||
source_cameras=all_train_cameras,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -642,31 +656,6 @@ def _run_eval(model, all_source_cameras, loader, task: Task, device):
|
||||
return category_result["results"]
|
||||
|
||||
|
||||
def _get_all_source_cameras(
|
||||
dataset: JsonIndexDataset,
|
||||
num_workers: int = 8,
|
||||
) -> CamerasBase:
|
||||
"""
|
||||
Load and return all the source cameras in the training dataset
|
||||
"""
|
||||
|
||||
all_frame_data = next(
|
||||
iter(
|
||||
torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
shuffle=False,
|
||||
batch_size=len(dataset),
|
||||
num_workers=num_workers,
|
||||
collate_fn=FrameData.collate,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
is_source = ds_utils.is_known_frame(all_frame_data.frame_type)
|
||||
source_cameras = all_frame_data.camera[torch.where(is_source)[0]]
|
||||
return source_cameras
|
||||
|
||||
|
||||
def _seed_all_random_engines(seed: int):
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
Reference in New Issue
Block a user