From 69c6d06ed880ff83419a960aaa20de0e5753f9a6 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Fri, 20 May 2022 07:50:30 -0700 Subject: [PATCH] New file for ImplicitronDatasetBase Summary: Separate ImplicitronDatasetBase and FrameData (to be used by all data sources) from ImplicitronDataset (which is specific). Reviewed By: shapovalov Differential Revision: D36413111 fbshipit-source-id: 3725744cde2e08baa11aff4048237ba10c7efbc6 --- projects/implicitron_trainer/experiment.py | 6 +- .../visualize_reconstruction.py | 7 +- .../implicitron/dataset/dataloader_zoo.py | 2 +- pytorch3d/implicitron/dataset/dataset_base.py | 268 ++++++++++++++++++ pytorch3d/implicitron/dataset/dataset_zoo.py | 3 +- .../dataset/implicitron_dataset.py | 254 +---------------- .../dataset/scene_batch_sampler.py | 2 +- pytorch3d/implicitron/dataset/visualize.py | 3 +- pytorch3d/implicitron/eval_demo.py | 7 +- .../evaluation/evaluate_new_view_synthesis.py | 2 +- tests/implicitron/test_batch_sampler.py | 2 +- tests/implicitron/test_evaluation.py | 6 +- 12 files changed, 288 insertions(+), 274 deletions(-) create mode 100644 pytorch3d/implicitron/dataset/dataset_base.py diff --git a/projects/implicitron_trainer/experiment.py b/projects/implicitron_trainer/experiment.py index d042f97e..738d848a 100755 --- a/projects/implicitron_trainer/experiment.py +++ b/projects/implicitron_trainer/experiment.py @@ -66,11 +66,9 @@ from packaging import version from pytorch3d.implicitron.dataset import utils as ds_utils from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task from pytorch3d.implicitron.dataset.dataloader_zoo import Dataloaders +from pytorch3d.implicitron.dataset.dataset_base import FrameData from pytorch3d.implicitron.dataset.dataset_zoo import Datasets -from pytorch3d.implicitron.dataset.implicitron_dataset import ( - FrameData, - ImplicitronDataset, -) +from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel from pytorch3d.implicitron.tools import model_io, vis_utils diff --git a/projects/implicitron_trainer/visualize_reconstruction.py b/projects/implicitron_trainer/visualize_reconstruction.py index 64f95224..66946ef5 100644 --- a/projects/implicitron_trainer/visualize_reconstruction.py +++ b/projects/implicitron_trainer/visualize_reconstruction.py @@ -24,12 +24,9 @@ import torch.nn.functional as Fu from experiment import init_model from omegaconf import OmegaConf from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource +from pytorch3d.implicitron.dataset.dataset_base import FrameData, ImplicitronDatasetBase from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo -from pytorch3d.implicitron.dataset.implicitron_dataset import ( - FrameData, - ImplicitronDataset, - ImplicitronDatasetBase, -) +from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset from pytorch3d.implicitron.dataset.utils import is_train_frame from pytorch3d.implicitron.models.base_model import EvaluationMode from pytorch3d.implicitron.tools.configurable import get_default_args diff --git a/pytorch3d/implicitron/dataset/dataloader_zoo.py b/pytorch3d/implicitron/dataset/dataloader_zoo.py index 1f3bf3c2..bf7376af 100644 --- a/pytorch3d/implicitron/dataset/dataloader_zoo.py +++ b/pytorch3d/implicitron/dataset/dataloader_zoo.py @@ -10,8 +10,8 @@ from typing import Optional, Sequence import torch from pytorch3d.implicitron.tools.config import enable_get_default_args +from .dataset_base import FrameData, ImplicitronDatasetBase from .dataset_zoo import Datasets -from .implicitron_dataset import FrameData, ImplicitronDatasetBase from .scene_batch_sampler import SceneBatchSampler diff --git a/pytorch3d/implicitron/dataset/dataset_base.py b/pytorch3d/implicitron/dataset/dataset_base.py new file mode 100644 index 00000000..9eb98d4f --- /dev/null +++ b/pytorch3d/implicitron/dataset/dataset_base.py @@ -0,0 +1,268 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from collections import defaultdict +from dataclasses import dataclass, field, fields +from typing import ( + Any, + Dict, + Iterable, + Iterator, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) + +import numpy as np +import torch +from pytorch3d.renderer.camera_utils import join_cameras_as_batch +from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras +from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds + + +@dataclass +class FrameData(Mapping[str, Any]): + """ + A type of the elements returned by indexing the dataset object. + It can represent both individual frames and batches of thereof; + in this documentation, the sizes of tensors refer to single frames; + add the first batch dimension for the collation result. + + Args: + frame_number: The number of the frame within its sequence. + 0-based continuous integers. + frame_timestamp: The time elapsed since the start of a sequence in sec. + sequence_name: The unique name of the frame's sequence. + sequence_category: The object category of the sequence. + image_size_hw: The size of the image in pixels; (height, width) tuple. + image_path: The qualified path to the loaded image (with dataset_root). + image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image + of the frame; elements are floats in [0, 1]. + mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image + regions. Regions can be invalid (mask_crop[i,j]=0) in case they + are a result of zero-padding of the image after cropping around + the object bounding box; elements are floats in {0.0, 1.0}. + depth_path: The qualified path to the frame's depth map. + depth_map: A float Tensor of shape `(1, H, W)` holding the depth map + of the frame; values correspond to distances from the camera; + use `depth_mask` and `mask_crop` to filter for valid pixels. + depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the + depth map that are valid for evaluation, they have been checked for + consistency across views; elements are floats in {0.0, 1.0}. + mask_path: A qualified path to the foreground probability mask. + fg_probability: A Tensor of `(1, H, W)` denoting the probability of the + pixels belonging to the captured object; elements are floats + in [0, 1]. + bbox_xywh: The bounding box capturing the object in the + format (x0, y0, width, height). + camera: A PyTorch3D camera object corresponding the frame's viewpoint, + corrected for cropping if it happened. + camera_quality_score: The score proportional to the confidence of the + frame's camera estimation (the higher the more accurate). + point_cloud_quality_score: The score proportional to the accuracy of the + frame's sequence point cloud (the higher the more accurate). + sequence_point_cloud_path: The path to the sequence's point cloud. + sequence_point_cloud: A PyTorch3D Pointclouds object holding the + point cloud corresponding to the frame's sequence. When the object + represents a batch of frames, point clouds may be deduplicated; + see `sequence_point_cloud_idx`. + sequence_point_cloud_idx: Integer indices mapping frame indices to the + corresponding point clouds in `sequence_point_cloud`; to get the + corresponding point cloud to `image_rgb[i]`, use + `sequence_point_cloud[sequence_point_cloud_idx[i]]`. + frame_type: The type of the loaded frame specified in + `subset_lists_file`, if provided. + meta: A dict for storing additional frame information. + """ + + frame_number: Optional[torch.LongTensor] + frame_timestamp: Optional[torch.Tensor] + sequence_name: Union[str, List[str]] + sequence_category: Union[str, List[str]] + image_size_hw: Optional[torch.Tensor] = None + image_path: Union[str, List[str], None] = None + image_rgb: Optional[torch.Tensor] = None + # masks out padding added due to cropping the square bit + mask_crop: Optional[torch.Tensor] = None + depth_path: Union[str, List[str], None] = None + depth_map: Optional[torch.Tensor] = None + depth_mask: Optional[torch.Tensor] = None + mask_path: Union[str, List[str], None] = None + fg_probability: Optional[torch.Tensor] = None + bbox_xywh: Optional[torch.Tensor] = None + camera: Optional[PerspectiveCameras] = None + camera_quality_score: Optional[torch.Tensor] = None + point_cloud_quality_score: Optional[torch.Tensor] = None + sequence_point_cloud_path: Union[str, List[str], None] = None + sequence_point_cloud: Optional[Pointclouds] = None + sequence_point_cloud_idx: Optional[torch.Tensor] = None + frame_type: Union[str, List[str], None] = None # seen | unseen + meta: dict = field(default_factory=lambda: {}) + + def to(self, *args, **kwargs): + new_params = {} + for f in fields(self): + value = getattr(self, f.name) + if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)): + new_params[f.name] = value.to(*args, **kwargs) + else: + new_params[f.name] = value + return type(self)(**new_params) + + def cpu(self): + return self.to(device=torch.device("cpu")) + + def cuda(self): + return self.to(device=torch.device("cuda")) + + # the following functions make sure **frame_data can be passed to functions + def __iter__(self): + for f in fields(self): + yield f.name + + def __getitem__(self, key): + return getattr(self, key) + + def __len__(self): + return len(fields(self)) + + @classmethod + def collate(cls, batch): + """ + Given a list objects `batch` of class `cls`, collates them into a batched + representation suitable for processing with deep networks. + """ + + elem = batch[0] + + if isinstance(elem, cls): + pointcloud_ids = [id(el.sequence_point_cloud) for el in batch] + id_to_idx = defaultdict(list) + for i, pc_id in enumerate(pointcloud_ids): + id_to_idx[pc_id].append(i) + + sequence_point_cloud = [] + sequence_point_cloud_idx = -np.ones((len(batch),)) + for i, ind in enumerate(id_to_idx.values()): + sequence_point_cloud_idx[ind] = i + sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud) + assert (sequence_point_cloud_idx >= 0).all() + + override_fields = { + "sequence_point_cloud": sequence_point_cloud, + "sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(), + } + # note that the pre-collate value of sequence_point_cloud_idx is unused + + collated = {} + for f in fields(elem): + list_values = override_fields.get( + f.name, [getattr(d, f.name) for d in batch] + ) + collated[f.name] = ( + cls.collate(list_values) + if all(list_value is not None for list_value in list_values) + else None + ) + return cls(**collated) + + elif isinstance(elem, Pointclouds): + return join_pointclouds_as_batch(batch) + + elif isinstance(elem, CamerasBase): + # TODO: don't store K; enforce working in NDC space + return join_cameras_as_batch(batch) + else: + return torch.utils.data._utils.collate.default_collate(batch) + + +@dataclass(eq=False) +class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]): + """ + Base class to describe a dataset to be used with Implicitron. + + The dataset is made up of frames, and the frames are grouped into sequences. + Each sequence has a name (a string). + (A sequence could be a video, or a set of images of one scene.) + + This means they have a __getitem__ which returns an instance of a FrameData, + which will describe one frame in one sequence. + """ + + # Maps sequence name to the sequence's global frame indices. + # It is used for the default implementations of some functions in this class. + # Implementations which override them are free to ignore this member. + _seq_to_idx: Dict[str, List[int]] = field(init=False) + + def __len__(self) -> int: + raise NotImplementedError + + def get_frame_numbers_and_timestamps( + self, idxs: Sequence[int] + ) -> List[Tuple[int, float]]: + """ + If the sequences in the dataset are videos rather than + unordered views, then the dataset should override this method to + return the index and timestamp in their videos of the frames whose + indices are given in `idxs`. In addition, + the values in _seq_to_idx should be in ascending order. + If timestamps are absent, they should be replaced with a constant. + + This is used for letting SceneBatchSampler identify consecutive + frames. + + Args: + idx: frame index in self + + Returns: + tuple of + - frame index in video + - timestamp of frame in video + """ + raise ValueError("This dataset does not contain videos.") + + def get_eval_batches(self) -> Optional[List[List[int]]]: + return None + + def sequence_names(self) -> Iterable[str]: + """Returns an iterator over sequence names in the dataset.""" + return self._seq_to_idx.keys() + + def sequence_frames_in_order( + self, seq_name: str + ) -> Iterator[Tuple[float, int, int]]: + """Returns an iterator over the frame indices in a given sequence. + We attempt to first sort by timestamp (if they are available), + then by frame number. + + Args: + seq_name: the name of the sequence. + + Returns: + an iterator over triplets `(timestamp, frame_no, dataset_idx)`, + where `frame_no` is the index within the sequence, and + `dataset_idx` is the index within the dataset. + `None` timestamps are replaced with 0s. + """ + seq_frame_indices = self._seq_to_idx[seq_name] + nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices) + + yield from sorted( + [ + (timestamp, frame_no, idx) + for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps) + ] + ) + + def sequence_indices_in_order(self, seq_name: str) -> Iterator[int]: + """Same as `sequence_frames_in_order` but returns the iterator over + only dataset indices. + """ + for _, _, idx in self.sequence_frames_in_order(seq_name): + yield idx diff --git a/pytorch3d/implicitron/dataset/dataset_zoo.py b/pytorch3d/implicitron/dataset/dataset_zoo.py index a32e0a86..023b538b 100644 --- a/pytorch3d/implicitron/dataset/dataset_zoo.py +++ b/pytorch3d/implicitron/dataset/dataset_zoo.py @@ -13,7 +13,8 @@ from typing import Any, Dict, Iterator, List, Optional, Sequence from iopath.common.file_io import PathManager from pytorch3d.implicitron.tools.config import enable_get_default_args -from .implicitron_dataset import ImplicitronDataset, ImplicitronDatasetBase +from .dataset_base import ImplicitronDatasetBase +from .implicitron_dataset import ImplicitronDataset from .utils import ( DATASET_TYPE_KNOWN, DATASET_TYPE_TEST, diff --git a/pytorch3d/implicitron/dataset/implicitron_dataset.py b/pytorch3d/implicitron/dataset/implicitron_dataset.py index 65fe43a3..d5c158bd 100644 --- a/pytorch3d/implicitron/dataset/implicitron_dataset.py +++ b/pytorch3d/implicitron/dataset/implicitron_dataset.py @@ -13,17 +13,13 @@ import os import random import warnings from collections import defaultdict -from dataclasses import dataclass, field, fields +from dataclasses import dataclass, field from itertools import islice from pathlib import Path from typing import ( - Any, ClassVar, Dict, - Iterable, - Iterator, List, - Mapping, Optional, Sequence, Tuple, @@ -37,258 +33,16 @@ import torch from iopath.common.file_io import PathManager from PIL import Image from pytorch3d.io import IO -from pytorch3d.renderer.camera_utils import join_cameras_as_batch -from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras -from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds +from pytorch3d.renderer.cameras import PerspectiveCameras +from pytorch3d.structures.pointclouds import Pointclouds from . import types +from .dataset_base import FrameData, ImplicitronDatasetBase logger = logging.getLogger(__name__) -@dataclass -class FrameData(Mapping[str, Any]): - """ - A type of the elements returned by indexing the dataset object. - It can represent both individual frames and batches of thereof; - in this documentation, the sizes of tensors refer to single frames; - add the first batch dimension for the collation result. - - Args: - frame_number: The number of the frame within its sequence. - 0-based continuous integers. - frame_timestamp: The time elapsed since the start of a sequence in sec. - sequence_name: The unique name of the frame's sequence. - sequence_category: The object category of the sequence. - image_size_hw: The size of the image in pixels; (height, width) tuple. - image_path: The qualified path to the loaded image (with dataset_root). - image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image - of the frame; elements are floats in [0, 1]. - mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image - regions. Regions can be invalid (mask_crop[i,j]=0) in case they - are a result of zero-padding of the image after cropping around - the object bounding box; elements are floats in {0.0, 1.0}. - depth_path: The qualified path to the frame's depth map. - depth_map: A float Tensor of shape `(1, H, W)` holding the depth map - of the frame; values correspond to distances from the camera; - use `depth_mask` and `mask_crop` to filter for valid pixels. - depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the - depth map that are valid for evaluation, they have been checked for - consistency across views; elements are floats in {0.0, 1.0}. - mask_path: A qualified path to the foreground probability mask. - fg_probability: A Tensor of `(1, H, W)` denoting the probability of the - pixels belonging to the captured object; elements are floats - in [0, 1]. - bbox_xywh: The bounding box capturing the object in the - format (x0, y0, width, height). - camera: A PyTorch3D camera object corresponding the frame's viewpoint, - corrected for cropping if it happened. - camera_quality_score: The score proportional to the confidence of the - frame's camera estimation (the higher the more accurate). - point_cloud_quality_score: The score proportional to the accuracy of the - frame's sequence point cloud (the higher the more accurate). - sequence_point_cloud_path: The path to the sequence's point cloud. - sequence_point_cloud: A PyTorch3D Pointclouds object holding the - point cloud corresponding to the frame's sequence. When the object - represents a batch of frames, point clouds may be deduplicated; - see `sequence_point_cloud_idx`. - sequence_point_cloud_idx: Integer indices mapping frame indices to the - corresponding point clouds in `sequence_point_cloud`; to get the - corresponding point cloud to `image_rgb[i]`, use - `sequence_point_cloud[sequence_point_cloud_idx[i]]`. - frame_type: The type of the loaded frame specified in - `subset_lists_file`, if provided. - meta: A dict for storing additional frame information. - """ - - frame_number: Optional[torch.LongTensor] - frame_timestamp: Optional[torch.Tensor] - sequence_name: Union[str, List[str]] - sequence_category: Union[str, List[str]] - image_size_hw: Optional[torch.Tensor] = None - image_path: Union[str, List[str], None] = None - image_rgb: Optional[torch.Tensor] = None - # masks out padding added due to cropping the square bit - mask_crop: Optional[torch.Tensor] = None - depth_path: Union[str, List[str], None] = None - depth_map: Optional[torch.Tensor] = None - depth_mask: Optional[torch.Tensor] = None - mask_path: Union[str, List[str], None] = None - fg_probability: Optional[torch.Tensor] = None - bbox_xywh: Optional[torch.Tensor] = None - camera: Optional[PerspectiveCameras] = None - camera_quality_score: Optional[torch.Tensor] = None - point_cloud_quality_score: Optional[torch.Tensor] = None - sequence_point_cloud_path: Union[str, List[str], None] = None - sequence_point_cloud: Optional[Pointclouds] = None - sequence_point_cloud_idx: Optional[torch.Tensor] = None - frame_type: Union[str, List[str], None] = None # seen | unseen - meta: dict = field(default_factory=lambda: {}) - - def to(self, *args, **kwargs): - new_params = {} - for f in fields(self): - value = getattr(self, f.name) - if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)): - new_params[f.name] = value.to(*args, **kwargs) - else: - new_params[f.name] = value - return type(self)(**new_params) - - def cpu(self): - return self.to(device=torch.device("cpu")) - - def cuda(self): - return self.to(device=torch.device("cuda")) - - # the following functions make sure **frame_data can be passed to functions - def __iter__(self): - for f in fields(self): - yield f.name - - def __getitem__(self, key): - return getattr(self, key) - - def __len__(self): - return len(fields(self)) - - @classmethod - def collate(cls, batch): - """ - Given a list objects `batch` of class `cls`, collates them into a batched - representation suitable for processing with deep networks. - """ - - elem = batch[0] - - if isinstance(elem, cls): - pointcloud_ids = [id(el.sequence_point_cloud) for el in batch] - id_to_idx = defaultdict(list) - for i, pc_id in enumerate(pointcloud_ids): - id_to_idx[pc_id].append(i) - - sequence_point_cloud = [] - sequence_point_cloud_idx = -np.ones((len(batch),)) - for i, ind in enumerate(id_to_idx.values()): - sequence_point_cloud_idx[ind] = i - sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud) - assert (sequence_point_cloud_idx >= 0).all() - - override_fields = { - "sequence_point_cloud": sequence_point_cloud, - "sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(), - } - # note that the pre-collate value of sequence_point_cloud_idx is unused - - collated = {} - for f in fields(elem): - list_values = override_fields.get( - f.name, [getattr(d, f.name) for d in batch] - ) - collated[f.name] = ( - cls.collate(list_values) - if all(list_value is not None for list_value in list_values) - else None - ) - return cls(**collated) - - elif isinstance(elem, Pointclouds): - return join_pointclouds_as_batch(batch) - - elif isinstance(elem, CamerasBase): - # TODO: don't store K; enforce working in NDC space - return join_cameras_as_batch(batch) - else: - return torch.utils.data._utils.collate.default_collate(batch) - - -@dataclass(eq=False) -class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]): - """ - Base class to describe a dataset to be used with Implicitron. - - The dataset is made up of frames, and the frames are grouped into sequences. - Each sequence has a name (a string). - (A sequence could be a video, or a set of images of one scene.) - - This means they have a __getitem__ which returns an instance of a FrameData, - which will describe one frame in one sequence. - """ - - # Maps sequence name to the sequence's global frame indices. - # It is used for the default implementations of some functions in this class. - # Implementations which override them are free to ignore this member. - _seq_to_idx: Dict[str, List[int]] = field(init=False) - - def __len__(self) -> int: - raise NotImplementedError - - def get_frame_numbers_and_timestamps( - self, idxs: Sequence[int] - ) -> List[Tuple[int, float]]: - """ - If the sequences in the dataset are videos rather than - unordered views, then the dataset should override this method to - return the index and timestamp in their videos of the frames whose - indices are given in `idxs`. In addition, - the values in _seq_to_idx should be in ascending order. - If timestamps are absent, they should be replaced with a constant. - - This is used for letting SceneBatchSampler identify consecutive - frames. - - Args: - idx: frame index in self - - Returns: - tuple of - - frame index in video - - timestamp of frame in video - """ - raise ValueError("This dataset does not contain videos.") - - def get_eval_batches(self) -> Optional[List[List[int]]]: - return None - - def sequence_names(self) -> Iterable[str]: - """Returns an iterator over sequence names in the dataset.""" - return self._seq_to_idx.keys() - - def sequence_frames_in_order( - self, seq_name: str - ) -> Iterator[Tuple[float, int, int]]: - """Returns an iterator over the frame indices in a given sequence. - We attempt to first sort by timestamp (if they are available), - then by frame number. - - Args: - seq_name: the name of the sequence. - - Returns: - an iterator over triplets `(timestamp, frame_no, dataset_idx)`, - where `frame_no` is the index within the sequence, and - `dataset_idx` is the index within the dataset. - `None` timestamps are replaced with 0s. - """ - seq_frame_indices = self._seq_to_idx[seq_name] - nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices) - - yield from sorted( - [ - (timestamp, frame_no, idx) - for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps) - ] - ) - - def sequence_indices_in_order(self, seq_name: str) -> Iterator[int]: - """Same as `sequence_frames_in_order` but returns the iterator over - only dataset indices. - """ - for _, _, idx in self.sequence_frames_in_order(seq_name): - yield idx - - class FrameAnnotsEntry(TypedDict): subset: Optional[str] frame_annotation: types.FrameAnnotation diff --git a/pytorch3d/implicitron/dataset/scene_batch_sampler.py b/pytorch3d/implicitron/dataset/scene_batch_sampler.py index fea6f19d..65973772 100644 --- a/pytorch3d/implicitron/dataset/scene_batch_sampler.py +++ b/pytorch3d/implicitron/dataset/scene_batch_sampler.py @@ -12,7 +12,7 @@ from typing import Iterable, Iterator, List, Sequence, Tuple import numpy as np from torch.utils.data.sampler import Sampler -from .implicitron_dataset import ImplicitronDatasetBase +from .dataset_base import ImplicitronDatasetBase @dataclass(eq=False) # TODO: do we need this if not init from config? diff --git a/pytorch3d/implicitron/dataset/visualize.py b/pytorch3d/implicitron/dataset/visualize.py index 4daa5f45..ebd957c5 100644 --- a/pytorch3d/implicitron/dataset/visualize.py +++ b/pytorch3d/implicitron/dataset/visualize.py @@ -10,7 +10,8 @@ import torch from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud from pytorch3d.structures import Pointclouds -from .implicitron_dataset import FrameData, ImplicitronDataset +from .dataset_base import FrameData +from .implicitron_dataset import ImplicitronDataset def get_implicitron_sequence_pointcloud( diff --git a/pytorch3d/implicitron/eval_demo.py b/pytorch3d/implicitron/eval_demo.py index 9b2e223c..a460ea0f 100644 --- a/pytorch3d/implicitron/eval_demo.py +++ b/pytorch3d/implicitron/eval_demo.py @@ -14,12 +14,9 @@ import torch from iopath.common.file_io import PathManager from pytorch3d.implicitron.dataset.data_source import Task from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo +from pytorch3d.implicitron.dataset.dataset_base import FrameData, ImplicitronDatasetBase from pytorch3d.implicitron.dataset.dataset_zoo import CO3D_CATEGORIES, dataset_zoo -from pytorch3d.implicitron.dataset.implicitron_dataset import ( - FrameData, - ImplicitronDataset, - ImplicitronDatasetBase, -) +from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset from pytorch3d.implicitron.dataset.utils import is_known_frame from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import ( aggregate_nvs_results, diff --git a/pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py b/pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py index c1a16068..a894c6c4 100644 --- a/pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py +++ b/pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py @@ -15,7 +15,7 @@ import numpy as np import torch import torch.nn.functional as F from pytorch3d.implicitron.dataset.data_source import Task -from pytorch3d.implicitron.dataset.implicitron_dataset import FrameData +from pytorch3d.implicitron.dataset.dataset_base import FrameData from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame from pytorch3d.implicitron.models.base_model import ImplicitronRender from pytorch3d.implicitron.tools import vis_utils diff --git a/tests/implicitron/test_batch_sampler.py b/tests/implicitron/test_batch_sampler.py index a7025038..0029783a 100644 --- a/tests/implicitron/test_batch_sampler.py +++ b/tests/implicitron/test_batch_sampler.py @@ -9,7 +9,7 @@ import unittest from collections import defaultdict from dataclasses import dataclass -from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDatasetBase +from pytorch3d.implicitron.dataset.dataset_base import ImplicitronDatasetBase from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler diff --git a/tests/implicitron/test_evaluation.py b/tests/implicitron/test_evaluation.py index 78e8a16f..3731d036 100644 --- a/tests/implicitron/test_evaluation.py +++ b/tests/implicitron/test_evaluation.py @@ -14,10 +14,8 @@ import unittest import lpips import torch -from pytorch3d.implicitron.dataset.implicitron_dataset import ( - FrameData, - ImplicitronDataset, -) +from pytorch3d.implicitron.dataset.dataset_base import FrameData +from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import eval_batch from pytorch3d.implicitron.models.base_model import ImplicitronModelBase from pytorch3d.implicitron.models.generic_model import GenericModel # noqa