New file for ImplicitronDatasetBase

Summary: Move ImplicitronDatasetBase and FrameData, which are shared by all data sources, into a new module dataset_base, separate from ImplicitronDataset (a specific implementation).

Reviewed By: shapovalov

Differential Revision: D36413111

fbshipit-source-id: 3725744cde2e08baa11aff4048237ba10c7efbc6
Jeremy Reizenstein authored 2022-05-20 07:50:30 -07:00, committed by Facebook GitHub Bot
parent 73dc109dba
commit 69c6d06ed8
12 changed files with 288 additions and 274 deletions
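
For downstream code, the change is only a move of import locations; a minimal sketch of the migration, with paths taken from the diffs below:

# Before this commit:
# from pytorch3d.implicitron.dataset.implicitron_dataset import (
#     FrameData,
#     ImplicitronDataset,
#     ImplicitronDatasetBase,
# )

# After this commit:
from pytorch3d.implicitron.dataset.dataset_base import FrameData, ImplicitronDatasetBase
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset

Since implicitron_dataset.py itself now re-imports FrameData and ImplicitronDatasetBase from dataset_base (see its diff below), the old import path keeps working for these two names.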

View File

@@ -66,11 +66,9 @@ from packaging import version
 from pytorch3d.implicitron.dataset import utils as ds_utils
 from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
 from pytorch3d.implicitron.dataset.dataloader_zoo import Dataloaders
+from pytorch3d.implicitron.dataset.dataset_base import FrameData
 from pytorch3d.implicitron.dataset.dataset_zoo import Datasets
-from pytorch3d.implicitron.dataset.implicitron_dataset import (
-    FrameData,
-    ImplicitronDataset,
-)
+from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
 from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
 from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
 from pytorch3d.implicitron.tools import model_io, vis_utils

View File

@@ -24,12 +24,9 @@ import torch.nn.functional as Fu
 from experiment import init_model
 from omegaconf import OmegaConf
 from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
+from pytorch3d.implicitron.dataset.dataset_base import FrameData, ImplicitronDatasetBase
 from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
-from pytorch3d.implicitron.dataset.implicitron_dataset import (
-    FrameData,
-    ImplicitronDataset,
-    ImplicitronDatasetBase,
-)
+from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
 from pytorch3d.implicitron.dataset.utils import is_train_frame
 from pytorch3d.implicitron.models.base_model import EvaluationMode
 from pytorch3d.implicitron.tools.configurable import get_default_args

View File

@@ -10,8 +10,8 @@ from typing import Optional, Sequence
 import torch
 from pytorch3d.implicitron.tools.config import enable_get_default_args

+from .dataset_base import FrameData, ImplicitronDatasetBase
 from .dataset_zoo import Datasets
-from .implicitron_dataset import FrameData, ImplicitronDatasetBase
 from .scene_batch_sampler import SceneBatchSampler

View File

@@ -0,0 +1,268 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections import defaultdict
+from dataclasses import dataclass, field, fields
+from typing import (
+    Any,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
+
+import numpy as np
+import torch
+from pytorch3d.renderer.camera_utils import join_cameras_as_batch
+from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
+from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
+
+
+@dataclass
+class FrameData(Mapping[str, Any]):
+    """
+    A type of the elements returned by indexing the dataset object.
+    It can represent both individual frames and batches thereof;
+    in this documentation, the sizes of tensors refer to single frames;
+    add the first batch dimension for the collation result.
+
+    Args:
+        frame_number: The number of the frame within its sequence.
+            0-based continuous integers.
+        frame_timestamp: The time elapsed since the start of a sequence in sec.
+        sequence_name: The unique name of the frame's sequence.
+        sequence_category: The object category of the sequence.
+        image_size_hw: The size of the image in pixels; (height, width) tuple.
+        image_path: The qualified path to the loaded image (with dataset_root).
+        image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
+            of the frame; elements are floats in [0, 1].
+        mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image
+            regions. Regions can be invalid (mask_crop[i,j]=0) in case they
+            are a result of zero-padding of the image after cropping around
+            the object bounding box; elements are floats in {0.0, 1.0}.
+        depth_path: The qualified path to the frame's depth map.
+        depth_map: A float Tensor of shape `(1, H, W)` holding the depth map
+            of the frame; values correspond to distances from the camera;
+            use `depth_mask` and `mask_crop` to filter for valid pixels.
+        depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the
+            depth map that are valid for evaluation; they have been checked for
+            consistency across views; elements are floats in {0.0, 1.0}.
+        mask_path: A qualified path to the foreground probability mask.
+        fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
+            pixels belonging to the captured object; elements are floats
+            in [0, 1].
+        bbox_xywh: The bounding box capturing the object in the
+            format (x0, y0, width, height).
+        camera: A PyTorch3D camera object corresponding to the frame's
+            viewpoint, corrected for cropping if it happened.
+        camera_quality_score: The score proportional to the confidence of the
+            frame's camera estimation (the higher the more accurate).
+        point_cloud_quality_score: The score proportional to the accuracy of the
+            frame's sequence point cloud (the higher the more accurate).
+        sequence_point_cloud_path: The path to the sequence's point cloud.
+        sequence_point_cloud: A PyTorch3D Pointclouds object holding the
+            point cloud corresponding to the frame's sequence. When the object
+            represents a batch of frames, point clouds may be deduplicated;
+            see `sequence_point_cloud_idx`.
+        sequence_point_cloud_idx: Integer indices mapping frame indices to the
+            corresponding point clouds in `sequence_point_cloud`; to get the
+            corresponding point cloud to `image_rgb[i]`, use
+            `sequence_point_cloud[sequence_point_cloud_idx[i]]`.
+        frame_type: The type of the loaded frame specified in
+            `subset_lists_file`, if provided.
+        meta: A dict for storing additional frame information.
+    """
+
+    frame_number: Optional[torch.LongTensor]
+    frame_timestamp: Optional[torch.Tensor]
+    sequence_name: Union[str, List[str]]
+    sequence_category: Union[str, List[str]]
+    image_size_hw: Optional[torch.Tensor] = None
+    image_path: Union[str, List[str], None] = None
+    image_rgb: Optional[torch.Tensor] = None
+    # masks out padding added due to cropping the square bit
+    mask_crop: Optional[torch.Tensor] = None
+    depth_path: Union[str, List[str], None] = None
+    depth_map: Optional[torch.Tensor] = None
+    depth_mask: Optional[torch.Tensor] = None
+    mask_path: Union[str, List[str], None] = None
+    fg_probability: Optional[torch.Tensor] = None
+    bbox_xywh: Optional[torch.Tensor] = None
+    camera: Optional[PerspectiveCameras] = None
+    camera_quality_score: Optional[torch.Tensor] = None
+    point_cloud_quality_score: Optional[torch.Tensor] = None
+    sequence_point_cloud_path: Union[str, List[str], None] = None
+    sequence_point_cloud: Optional[Pointclouds] = None
+    sequence_point_cloud_idx: Optional[torch.Tensor] = None
+    frame_type: Union[str, List[str], None] = None  # seen | unseen
+    meta: dict = field(default_factory=lambda: {})
+
+    def to(self, *args, **kwargs):
+        new_params = {}
+        for f in fields(self):
+            value = getattr(self, f.name)
+            if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
+                new_params[f.name] = value.to(*args, **kwargs)
+            else:
+                new_params[f.name] = value
+        return type(self)(**new_params)
+
+    def cpu(self):
+        return self.to(device=torch.device("cpu"))
+
+    def cuda(self):
+        return self.to(device=torch.device("cuda"))
+
+    # the following functions make sure **frame_data can be passed to functions
+    def __iter__(self):
+        for f in fields(self):
+            yield f.name
+
+    def __getitem__(self, key):
+        return getattr(self, key)
+
+    def __len__(self):
+        return len(fields(self))
+
+    @classmethod
+    def collate(cls, batch):
+        """
+        Given a list `batch` of objects of class `cls`, collates them into a
+        batched representation suitable for processing with deep networks.
+        """
+        elem = batch[0]
+
+        if isinstance(elem, cls):
+            pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]
+            id_to_idx = defaultdict(list)
+            for i, pc_id in enumerate(pointcloud_ids):
+                id_to_idx[pc_id].append(i)
+
+            sequence_point_cloud = []
+            sequence_point_cloud_idx = -np.ones((len(batch),))
+            for i, ind in enumerate(id_to_idx.values()):
+                sequence_point_cloud_idx[ind] = i
+                sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)
+            assert (sequence_point_cloud_idx >= 0).all()
+
+            override_fields = {
+                "sequence_point_cloud": sequence_point_cloud,
+                "sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),
+            }
+            # note that the pre-collate value of sequence_point_cloud_idx is unused
+
+            collated = {}
+            for f in fields(elem):
+                list_values = override_fields.get(
+                    f.name, [getattr(d, f.name) for d in batch]
+                )
+                collated[f.name] = (
+                    cls.collate(list_values)
+                    if all(list_value is not None for list_value in list_values)
+                    else None
+                )
+            return cls(**collated)
+
+        elif isinstance(elem, Pointclouds):
+            return join_pointclouds_as_batch(batch)
+
+        elif isinstance(elem, CamerasBase):
+            # TODO: don't store K; enforce working in NDC space
+            return join_cameras_as_batch(batch)
+
+        else:
+            return torch.utils.data._utils.collate.default_collate(batch)
+
+
+@dataclass(eq=False)
+class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
+    """
+    Base class to describe a dataset to be used with Implicitron.
+
+    The dataset is made up of frames, and the frames are grouped into sequences.
+    Each sequence has a name (a string).
+    (A sequence could be a video, or a set of images of one scene.)
+
+    Each dataset provides a `__getitem__` which returns an instance of
+    `FrameData`, describing one frame in one sequence.
+    """
+
+    # Maps sequence name to the sequence's global frame indices.
+    # It is used for the default implementations of some functions in this class.
+    # Implementations which override them are free to ignore this member.
+    _seq_to_idx: Dict[str, List[int]] = field(init=False)
+
+    def __len__(self) -> int:
+        raise NotImplementedError
+
+    def get_frame_numbers_and_timestamps(
+        self, idxs: Sequence[int]
+    ) -> List[Tuple[int, float]]:
+        """
+        If the sequences in the dataset are videos rather than
+        unordered views, then the dataset should override this method to
+        return the index and timestamp in their videos of the frames whose
+        indices are given in `idxs`. In addition,
+        the values in _seq_to_idx should be in ascending order.
+        If timestamps are absent, they should be replaced with a constant.
+
+        This is used for letting SceneBatchSampler identify consecutive
+        frames.
+
+        Args:
+            idxs: frame indices in self
+
+        Returns:
+            a list of tuples of
+                - frame index in video
+                - timestamp of frame in video
+        """
+        raise ValueError("This dataset does not contain videos.")
+
+    def get_eval_batches(self) -> Optional[List[List[int]]]:
+        return None
+
+    def sequence_names(self) -> Iterable[str]:
+        """Returns an iterator over sequence names in the dataset."""
+        return self._seq_to_idx.keys()
+
+    def sequence_frames_in_order(
+        self, seq_name: str
+    ) -> Iterator[Tuple[float, int, int]]:
+        """Returns an iterator over the frame indices in a given sequence.
+        We attempt to sort first by timestamp (if available), then by
+        frame number.
+
+        Args:
+            seq_name: the name of the sequence.
+
+        Returns:
+            an iterator over triplets `(timestamp, frame_no, dataset_idx)`,
+            where `frame_no` is the index within the sequence, and
+            `dataset_idx` is the index within the dataset.
+            `None` timestamps are replaced with 0s.
+        """
+        seq_frame_indices = self._seq_to_idx[seq_name]
+        nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices)
+
+        yield from sorted(
+            [
+                (timestamp, frame_no, idx)
+                for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps)
+            ]
+        )
+
+    def sequence_indices_in_order(self, seq_name: str) -> Iterator[int]:
+        """Same as `sequence_frames_in_order` but returns an iterator over
+        only the dataset indices.
+        """
+        for _, _, idx in self.sequence_frames_in_order(seq_name):
+            yield idx
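
Example usage of the moved FrameData (a minimal sketch, not part of the commit; all values are made up):

import torch
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.structures import Pointclouds

# Two frames of one sequence sharing a single Pointclouds object.
pcl = Pointclouds(points=[torch.rand(10, 3)])
frames = [
    FrameData(
        frame_number=torch.tensor([i], dtype=torch.long),
        frame_timestamp=torch.tensor([i / 30.0]),
        sequence_name="seq_0",
        sequence_category="apple",
        image_rgb=torch.rand(3, 8, 8),
        sequence_point_cloud=pcl,
    )
    for i in range(2)
]

batch = FrameData.collate(frames)
assert batch.image_rgb.shape == (2, 3, 8, 8)  # tensors gain a batch dimension
# The shared point cloud is deduplicated: batch.sequence_point_cloud holds a
# single cloud, and batch.sequence_point_cloud_idx maps both frames to entry 0.
batch = batch.cpu()  # to()/cpu()/cuda() move all tensor-like fields together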
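
A minimal subclass of the new base class might look like this (ToyDataset and its contents are hypothetical):

from dataclasses import dataclass
from typing import List, Sequence, Tuple

import torch
from pytorch3d.implicitron.dataset.dataset_base import FrameData, ImplicitronDatasetBase


@dataclass(eq=False)
class ToyDataset(ImplicitronDatasetBase):
    def __post_init__(self) -> None:
        # Maps sequence names to global dataset indices, as the base class expects.
        self._seq_to_idx = {"seq_a": [0, 1, 2], "seq_b": [3, 4, 5]}

    def __len__(self) -> int:
        return 6

    def __getitem__(self, index: int) -> FrameData:
        return FrameData(
            frame_number=torch.tensor([index % 3], dtype=torch.long),
            frame_timestamp=torch.tensor([float(index % 3)]),
            sequence_name="seq_a" if index < 3 else "seq_b",
            sequence_category="toy",
        )

    def get_frame_numbers_and_timestamps(
        self, idxs: Sequence[int]
    ) -> List[Tuple[int, float]]:
        # Pretend each sequence is a video sampled at one frame per second.
        return [(i % 3, float(i % 3)) for i in idxs]


dataset = ToyDataset()
assert list(dataset.sequence_names()) == ["seq_a", "seq_b"]
assert list(dataset.sequence_indices_in_order("seq_b")) == [3, 4, 5]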

View File

@@ -13,7 +13,8 @@ from typing import Any, Dict, Iterator, List, Optional, Sequence
 from iopath.common.file_io import PathManager
 from pytorch3d.implicitron.tools.config import enable_get_default_args

-from .implicitron_dataset import ImplicitronDataset, ImplicitronDatasetBase
+from .dataset_base import ImplicitronDatasetBase
+from .implicitron_dataset import ImplicitronDataset
 from .utils import (
     DATASET_TYPE_KNOWN,
     DATASET_TYPE_TEST,
View File

@@ -13,17 +13,13 @@ import os
 import random
 import warnings
 from collections import defaultdict
-from dataclasses import dataclass, field, fields
+from dataclasses import dataclass, field
 from itertools import islice
 from pathlib import Path
 from typing import (
-    Any,
     ClassVar,
     Dict,
-    Iterable,
-    Iterator,
     List,
-    Mapping,
     Optional,
     Sequence,
     Tuple,
@@ -37,258 +33,16 @@ import torch
 from iopath.common.file_io import PathManager
 from PIL import Image
 from pytorch3d.io import IO
-from pytorch3d.renderer.camera_utils import join_cameras_as_batch
-from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
-from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
+from pytorch3d.renderer.cameras import PerspectiveCameras
+from pytorch3d.structures.pointclouds import Pointclouds

 from . import types
+from .dataset_base import FrameData, ImplicitronDatasetBase

 logger = logging.getLogger(__name__)


-@dataclass
-class FrameData(Mapping[str, Any]):
-    """
-    A type of the elements returned by indexing the dataset object.
-    It can represent both individual frames and batches thereof;
-    in this documentation, the sizes of tensors refer to single frames;
-    add the first batch dimension for the collation result.
-
-    Args:
-        frame_number: The number of the frame within its sequence.
-            0-based continuous integers.
-        frame_timestamp: The time elapsed since the start of a sequence in sec.
-        sequence_name: The unique name of the frame's sequence.
-        sequence_category: The object category of the sequence.
-        image_size_hw: The size of the image in pixels; (height, width) tuple.
-        image_path: The qualified path to the loaded image (with dataset_root).
-        image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
-            of the frame; elements are floats in [0, 1].
-        mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image
-            regions. Regions can be invalid (mask_crop[i,j]=0) in case they
-            are a result of zero-padding of the image after cropping around
-            the object bounding box; elements are floats in {0.0, 1.0}.
-        depth_path: The qualified path to the frame's depth map.
-        depth_map: A float Tensor of shape `(1, H, W)` holding the depth map
-            of the frame; values correspond to distances from the camera;
-            use `depth_mask` and `mask_crop` to filter for valid pixels.
-        depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the
-            depth map that are valid for evaluation; they have been checked for
-            consistency across views; elements are floats in {0.0, 1.0}.
-        mask_path: A qualified path to the foreground probability mask.
-        fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
-            pixels belonging to the captured object; elements are floats
-            in [0, 1].
-        bbox_xywh: The bounding box capturing the object in the
-            format (x0, y0, width, height).
-        camera: A PyTorch3D camera object corresponding to the frame's
-            viewpoint, corrected for cropping if it happened.
-        camera_quality_score: The score proportional to the confidence of the
-            frame's camera estimation (the higher the more accurate).
-        point_cloud_quality_score: The score proportional to the accuracy of the
-            frame's sequence point cloud (the higher the more accurate).
-        sequence_point_cloud_path: The path to the sequence's point cloud.
-        sequence_point_cloud: A PyTorch3D Pointclouds object holding the
-            point cloud corresponding to the frame's sequence. When the object
-            represents a batch of frames, point clouds may be deduplicated;
-            see `sequence_point_cloud_idx`.
-        sequence_point_cloud_idx: Integer indices mapping frame indices to the
-            corresponding point clouds in `sequence_point_cloud`; to get the
-            corresponding point cloud to `image_rgb[i]`, use
-            `sequence_point_cloud[sequence_point_cloud_idx[i]]`.
-        frame_type: The type of the loaded frame specified in
-            `subset_lists_file`, if provided.
-        meta: A dict for storing additional frame information.
-    """
-
-    frame_number: Optional[torch.LongTensor]
-    frame_timestamp: Optional[torch.Tensor]
-    sequence_name: Union[str, List[str]]
-    sequence_category: Union[str, List[str]]
-    image_size_hw: Optional[torch.Tensor] = None
-    image_path: Union[str, List[str], None] = None
-    image_rgb: Optional[torch.Tensor] = None
-    # masks out padding added due to cropping the square bit
-    mask_crop: Optional[torch.Tensor] = None
-    depth_path: Union[str, List[str], None] = None
-    depth_map: Optional[torch.Tensor] = None
-    depth_mask: Optional[torch.Tensor] = None
-    mask_path: Union[str, List[str], None] = None
-    fg_probability: Optional[torch.Tensor] = None
-    bbox_xywh: Optional[torch.Tensor] = None
-    camera: Optional[PerspectiveCameras] = None
-    camera_quality_score: Optional[torch.Tensor] = None
-    point_cloud_quality_score: Optional[torch.Tensor] = None
-    sequence_point_cloud_path: Union[str, List[str], None] = None
-    sequence_point_cloud: Optional[Pointclouds] = None
-    sequence_point_cloud_idx: Optional[torch.Tensor] = None
-    frame_type: Union[str, List[str], None] = None  # seen | unseen
-    meta: dict = field(default_factory=lambda: {})
-
-    def to(self, *args, **kwargs):
-        new_params = {}
-        for f in fields(self):
-            value = getattr(self, f.name)
-            if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
-                new_params[f.name] = value.to(*args, **kwargs)
-            else:
-                new_params[f.name] = value
-        return type(self)(**new_params)
-
-    def cpu(self):
-        return self.to(device=torch.device("cpu"))
-
-    def cuda(self):
-        return self.to(device=torch.device("cuda"))
-
-    # the following functions make sure **frame_data can be passed to functions
-    def __iter__(self):
-        for f in fields(self):
-            yield f.name
-
-    def __getitem__(self, key):
-        return getattr(self, key)
-
-    def __len__(self):
-        return len(fields(self))
-
-    @classmethod
-    def collate(cls, batch):
-        """
-        Given a list `batch` of objects of class `cls`, collates them into a
-        batched representation suitable for processing with deep networks.
-        """
-        elem = batch[0]
-
-        if isinstance(elem, cls):
-            pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]
-            id_to_idx = defaultdict(list)
-            for i, pc_id in enumerate(pointcloud_ids):
-                id_to_idx[pc_id].append(i)
-
-            sequence_point_cloud = []
-            sequence_point_cloud_idx = -np.ones((len(batch),))
-            for i, ind in enumerate(id_to_idx.values()):
-                sequence_point_cloud_idx[ind] = i
-                sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)
-            assert (sequence_point_cloud_idx >= 0).all()
-
-            override_fields = {
-                "sequence_point_cloud": sequence_point_cloud,
-                "sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),
-            }
-            # note that the pre-collate value of sequence_point_cloud_idx is unused
-
-            collated = {}
-            for f in fields(elem):
-                list_values = override_fields.get(
-                    f.name, [getattr(d, f.name) for d in batch]
-                )
-                collated[f.name] = (
-                    cls.collate(list_values)
-                    if all(list_value is not None for list_value in list_values)
-                    else None
-                )
-            return cls(**collated)
-
-        elif isinstance(elem, Pointclouds):
-            return join_pointclouds_as_batch(batch)
-
-        elif isinstance(elem, CamerasBase):
-            # TODO: don't store K; enforce working in NDC space
-            return join_cameras_as_batch(batch)
-
-        else:
-            return torch.utils.data._utils.collate.default_collate(batch)
-
-
-@dataclass(eq=False)
-class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
-    """
-    Base class to describe a dataset to be used with Implicitron.
-
-    The dataset is made up of frames, and the frames are grouped into sequences.
-    Each sequence has a name (a string).
-    (A sequence could be a video, or a set of images of one scene.)
-
-    Each dataset provides a `__getitem__` which returns an instance of
-    `FrameData`, describing one frame in one sequence.
-    """
-
-    # Maps sequence name to the sequence's global frame indices.
-    # It is used for the default implementations of some functions in this class.
-    # Implementations which override them are free to ignore this member.
-    _seq_to_idx: Dict[str, List[int]] = field(init=False)
-
-    def __len__(self) -> int:
-        raise NotImplementedError
-
-    def get_frame_numbers_and_timestamps(
-        self, idxs: Sequence[int]
-    ) -> List[Tuple[int, float]]:
-        """
-        If the sequences in the dataset are videos rather than
-        unordered views, then the dataset should override this method to
-        return the index and timestamp in their videos of the frames whose
-        indices are given in `idxs`. In addition,
-        the values in _seq_to_idx should be in ascending order.
-        If timestamps are absent, they should be replaced with a constant.
-
-        This is used for letting SceneBatchSampler identify consecutive
-        frames.
-
-        Args:
-            idxs: frame indices in self
-
-        Returns:
-            a list of tuples of
-                - frame index in video
-                - timestamp of frame in video
-        """
-        raise ValueError("This dataset does not contain videos.")
-
-    def get_eval_batches(self) -> Optional[List[List[int]]]:
-        return None
-
-    def sequence_names(self) -> Iterable[str]:
-        """Returns an iterator over sequence names in the dataset."""
-        return self._seq_to_idx.keys()
-
-    def sequence_frames_in_order(
-        self, seq_name: str
-    ) -> Iterator[Tuple[float, int, int]]:
-        """Returns an iterator over the frame indices in a given sequence.
-        We attempt to sort first by timestamp (if available), then by
-        frame number.
-
-        Args:
-            seq_name: the name of the sequence.
-
-        Returns:
-            an iterator over triplets `(timestamp, frame_no, dataset_idx)`,
-            where `frame_no` is the index within the sequence, and
-            `dataset_idx` is the index within the dataset.
-            `None` timestamps are replaced with 0s.
-        """
-        seq_frame_indices = self._seq_to_idx[seq_name]
-        nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices)
-
-        yield from sorted(
-            [
-                (timestamp, frame_no, idx)
-                for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps)
-            ]
-        )
-
-    def sequence_indices_in_order(self, seq_name: str) -> Iterator[int]:
-        """Same as `sequence_frames_in_order` but returns an iterator over
-        only the dataset indices.
-        """
-        for _, _, idx in self.sequence_frames_in_order(seq_name):
-            yield idx
-
-
 class FrameAnnotsEntry(TypedDict):
     subset: Optional[str]
     frame_annotation: types.FrameAnnotation

View File

@@ -12,7 +12,7 @@ from typing import Iterable, Iterator, List, Sequence, Tuple
 import numpy as np
 from torch.utils.data.sampler import Sampler

-from .implicitron_dataset import ImplicitronDatasetBase
+from .dataset_base import ImplicitronDatasetBase


 @dataclass(eq=False)  # TODO: do we need this if not init from config?

View File

@@ -10,7 +10,8 @@ import torch
 from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
 from pytorch3d.structures import Pointclouds

-from .implicitron_dataset import FrameData, ImplicitronDataset
+from .dataset_base import FrameData
+from .implicitron_dataset import ImplicitronDataset


 def get_implicitron_sequence_pointcloud(

View File

@@ -14,12 +14,9 @@ import torch
 from iopath.common.file_io import PathManager
 from pytorch3d.implicitron.dataset.data_source import Task
 from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
+from pytorch3d.implicitron.dataset.dataset_base import FrameData, ImplicitronDatasetBase
 from pytorch3d.implicitron.dataset.dataset_zoo import CO3D_CATEGORIES, dataset_zoo
-from pytorch3d.implicitron.dataset.implicitron_dataset import (
-    FrameData,
-    ImplicitronDataset,
-    ImplicitronDatasetBase,
-)
+from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
 from pytorch3d.implicitron.dataset.utils import is_known_frame
 from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
     aggregate_nvs_results,

View File

@@ -15,7 +15,7 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 from pytorch3d.implicitron.dataset.data_source import Task
-from pytorch3d.implicitron.dataset.implicitron_dataset import FrameData
+from pytorch3d.implicitron.dataset.dataset_base import FrameData
 from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
 from pytorch3d.implicitron.models.base_model import ImplicitronRender
 from pytorch3d.implicitron.tools import vis_utils

View File

@@ -9,7 +9,7 @@ import unittest
 from collections import defaultdict
 from dataclasses import dataclass

-from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDatasetBase
+from pytorch3d.implicitron.dataset.dataset_base import ImplicitronDatasetBase
 from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler

View File

@@ -14,10 +14,8 @@ import unittest
 import lpips
 import torch
-from pytorch3d.implicitron.dataset.implicitron_dataset import (
-    FrameData,
-    ImplicitronDataset,
-)
+from pytorch3d.implicitron.dataset.dataset_base import FrameData
+from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
 from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import eval_batch
 from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
 from pytorch3d.implicitron.models.generic_model import GenericModel  # noqa