mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-22 07:10:34 +08:00
Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
297020a4b1 | ||
|
|
062e6c54ae | ||
|
|
c80180c96e | ||
|
|
23cd19fbc7 | ||
|
|
092400f1e7 | ||
|
|
ec87284c4b | ||
|
|
f5a117c74b | ||
|
|
b921efae3e | ||
|
|
c8d6cd427e | ||
|
|
ef5f620263 | ||
|
|
3e3644e534 | ||
|
|
178a7774d4 | ||
|
|
823ab75d27 | ||
|
|
32e1992924 | ||
|
|
7aeedd17a4 | ||
|
|
0e3138eca8 | ||
|
|
1af6bf4768 | ||
|
|
355d6332cb |
@@ -180,30 +180,6 @@ workflows:
|
||||
jobs:
|
||||
# - main:
|
||||
# context: DOCKERHUB_TOKEN
|
||||
- binary_linux_conda:
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu102
|
||||
name: linux_conda_py38_cu102_pyt190
|
||||
python_version: '3.8'
|
||||
pytorch_version: 1.9.0
|
||||
- binary_linux_conda:
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu111
|
||||
name: linux_conda_py38_cu111_pyt190
|
||||
python_version: '3.8'
|
||||
pytorch_version: 1.9.0
|
||||
- binary_linux_conda:
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu102
|
||||
name: linux_conda_py38_cu102_pyt191
|
||||
python_version: '3.8'
|
||||
pytorch_version: 1.9.1
|
||||
- binary_linux_conda:
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu111
|
||||
name: linux_conda_py38_cu111_pyt191
|
||||
python_version: '3.8'
|
||||
pytorch_version: 1.9.1
|
||||
- binary_linux_conda:
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu102
|
||||
@@ -370,29 +346,19 @@ workflows:
|
||||
python_version: '3.8'
|
||||
pytorch_version: 2.0.0
|
||||
- binary_linux_conda:
|
||||
conda_docker_image: pytorch/conda-builder:cuda117
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu102
|
||||
name: linux_conda_py39_cu102_pyt190
|
||||
python_version: '3.9'
|
||||
pytorch_version: 1.9.0
|
||||
cu_version: cu117
|
||||
name: linux_conda_py38_cu117_pyt201
|
||||
python_version: '3.8'
|
||||
pytorch_version: 2.0.1
|
||||
- binary_linux_conda:
|
||||
conda_docker_image: pytorch/conda-builder:cuda118
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu111
|
||||
name: linux_conda_py39_cu111_pyt190
|
||||
python_version: '3.9'
|
||||
pytorch_version: 1.9.0
|
||||
- binary_linux_conda:
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu102
|
||||
name: linux_conda_py39_cu102_pyt191
|
||||
python_version: '3.9'
|
||||
pytorch_version: 1.9.1
|
||||
- binary_linux_conda:
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu111
|
||||
name: linux_conda_py39_cu111_pyt191
|
||||
python_version: '3.9'
|
||||
pytorch_version: 1.9.1
|
||||
cu_version: cu118
|
||||
name: linux_conda_py38_cu118_pyt201
|
||||
python_version: '3.8'
|
||||
pytorch_version: 2.0.1
|
||||
- binary_linux_conda:
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu102
|
||||
@@ -558,6 +524,20 @@ workflows:
|
||||
name: linux_conda_py39_cu118_pyt200
|
||||
python_version: '3.9'
|
||||
pytorch_version: 2.0.0
|
||||
- binary_linux_conda:
|
||||
conda_docker_image: pytorch/conda-builder:cuda117
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu117
|
||||
name: linux_conda_py39_cu117_pyt201
|
||||
python_version: '3.9'
|
||||
pytorch_version: 2.0.1
|
||||
- binary_linux_conda:
|
||||
conda_docker_image: pytorch/conda-builder:cuda118
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu118
|
||||
name: linux_conda_py39_cu118_pyt201
|
||||
python_version: '3.9'
|
||||
pytorch_version: 2.0.1
|
||||
- binary_linux_conda:
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu102
|
||||
@@ -666,6 +646,20 @@ workflows:
|
||||
name: linux_conda_py310_cu118_pyt200
|
||||
python_version: '3.10'
|
||||
pytorch_version: 2.0.0
|
||||
- binary_linux_conda:
|
||||
conda_docker_image: pytorch/conda-builder:cuda117
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu117
|
||||
name: linux_conda_py310_cu117_pyt201
|
||||
python_version: '3.10'
|
||||
pytorch_version: 2.0.1
|
||||
- binary_linux_conda:
|
||||
conda_docker_image: pytorch/conda-builder:cuda118
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu118
|
||||
name: linux_conda_py310_cu118_pyt201
|
||||
python_version: '3.10'
|
||||
pytorch_version: 2.0.1
|
||||
- binary_linux_conda_cuda:
|
||||
name: testrun_conda_cuda_py38_cu102_pyt190
|
||||
context: DOCKERHUB_TOKEN
|
||||
|
||||
@@ -20,8 +20,6 @@ from packaging import version
|
||||
# version of pytorch.
|
||||
# Pytorch 1.4 also supports cuda 10.0 but we no longer build for cuda 10.0 at all.
|
||||
CONDA_CUDA_VERSIONS = {
|
||||
"1.9.0": ["cu102", "cu111"],
|
||||
"1.9.1": ["cu102", "cu111"],
|
||||
"1.10.0": ["cu102", "cu111", "cu113"],
|
||||
"1.10.1": ["cu102", "cu111", "cu113"],
|
||||
"1.10.2": ["cu102", "cu111", "cu113"],
|
||||
@@ -31,6 +29,7 @@ CONDA_CUDA_VERSIONS = {
|
||||
"1.13.0": ["cu116", "cu117"],
|
||||
"1.13.1": ["cu116", "cu117"],
|
||||
"2.0.0": ["cu117", "cu118"],
|
||||
"2.0.1": ["cu117", "cu118"],
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ The core library is written in PyTorch. Several components have underlying imple
|
||||
|
||||
- Linux or macOS or Windows
|
||||
- Python 3.8, 3.9 or 3.10
|
||||
- PyTorch 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2, 1.11.0, 1.12.0, 1.12.1, 1.13.0 or 2.0.0.
|
||||
- PyTorch 1.10.0, 1.10.1, 1.10.2, 1.11.0, 1.12.0, 1.12.1, 1.13.0, 2.0.0 or 2.0.1.
|
||||
- torchvision that matches the PyTorch installation. You can install them together as explained at pytorch.org to make sure of this.
|
||||
- gcc & g++ ≥ 4.9
|
||||
- [fvcore](https://github.com/facebookresearch/fvcore)
|
||||
|
||||
@@ -20,7 +20,8 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import mock
|
||||
import unittest.mock as mock
|
||||
|
||||
from recommonmark.parser import CommonMarkParser
|
||||
from recommonmark.states import DummyStateMachine
|
||||
from sphinx.builders.html import StandaloneHTMLBuilder
|
||||
|
||||
@@ -85,7 +85,7 @@ cameras_ndc = PerspectiveCameras(focal_length=fcl_ndc, principal_point=prp_ndc)
|
||||
# Screen space camera
|
||||
image_size = ((128, 256),) # (h, w)
|
||||
fcl_screen = (76.8,) # fcl_ndc * min(image_size) / 2
|
||||
prp_screen = ((115.2, 48), ) # w / 2 - px_ndc * min(image_size) / 2, h / 2 - py_ndc * min(image_size) / 2
|
||||
prp_screen = ((115.2, 32), ) # w / 2 - px_ndc * min(image_size) / 2, h / 2 - py_ndc * min(image_size) / 2
|
||||
cameras_screen = PerspectiveCameras(focal_length=fcl_screen, principal_point=prp_screen, in_ndc=False, image_size=image_size)
|
||||
```
|
||||
|
||||
|
||||
@@ -4,4 +4,4 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
__version__ = "0.7.3"
|
||||
__version__ = "0.7.4"
|
||||
|
||||
@@ -266,6 +266,8 @@ at::Tensor FaceAreasNormalsBackwardCuda(
|
||||
grad_normals_t{grad_normals, "grad_normals", 4};
|
||||
at::CheckedFrom c = "FaceAreasNormalsBackwardCuda";
|
||||
at::checkAllSameGPU(c, {verts_t, faces_t, grad_areas_t, grad_normals_t});
|
||||
// This is nondeterministic because atomicAdd
|
||||
at::globalContext().alertNotDeterministic("FaceAreasNormalsBackwardCuda");
|
||||
|
||||
// Set the device for the kernel launch based on the device of verts
|
||||
at::cuda::CUDAGuard device_guard(verts.device());
|
||||
|
||||
@@ -130,6 +130,9 @@ std::tuple<at::Tensor, at::Tensor> InterpFaceAttrsBackwardCuda(
|
||||
at::checkAllSameType(
|
||||
c, {barycentric_coords_t, face_attrs_t, grad_pix_attrs_t});
|
||||
|
||||
// This is nondeterministic because atomicAdd
|
||||
at::globalContext().alertNotDeterministic("InterpFaceAttrsBackwardCuda");
|
||||
|
||||
// Set the device for the kernel launch based on the input
|
||||
at::cuda::CUDAGuard device_guard(pix_to_face.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
@@ -534,6 +534,9 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
|
||||
c, {p1_t, p2_t, lengths1_t, lengths2_t, idxs_t, grad_dists_t});
|
||||
at::checkAllSameType(c, {p1_t, p2_t, grad_dists_t});
|
||||
|
||||
// This is nondeterministic because atomicAdd
|
||||
at::globalContext().alertNotDeterministic("KNearestNeighborBackwardCuda");
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(p1.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
@@ -305,6 +305,8 @@ std::tuple<at::Tensor, at::Tensor> DistanceBackwardCuda(
|
||||
at::CheckedFrom c = "DistanceBackwardCuda";
|
||||
at::checkAllSameGPU(c, {objects_t, targets_t, idx_objects_t, grad_dists_t});
|
||||
at::checkAllSameType(c, {objects_t, targets_t, grad_dists_t});
|
||||
// This is nondeterministic because atomicAdd
|
||||
at::globalContext().alertNotDeterministic("DistanceBackwardCuda");
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(objects.device());
|
||||
@@ -624,6 +626,9 @@ std::tuple<at::Tensor, at::Tensor> PointFaceArrayDistanceBackwardCuda(
|
||||
at::CheckedFrom c = "PointFaceArrayDistanceBackwardCuda";
|
||||
at::checkAllSameGPU(c, {points_t, tris_t, grad_dists_t});
|
||||
at::checkAllSameType(c, {points_t, tris_t, grad_dists_t});
|
||||
// This is nondeterministic because atomicAdd
|
||||
at::globalContext().alertNotDeterministic(
|
||||
"PointFaceArrayDistanceBackwardCuda");
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(points.device());
|
||||
@@ -787,6 +792,9 @@ std::tuple<at::Tensor, at::Tensor> PointEdgeArrayDistanceBackwardCuda(
|
||||
at::CheckedFrom c = "PointEdgeArrayDistanceBackwardCuda";
|
||||
at::checkAllSameGPU(c, {points_t, segms_t, grad_dists_t});
|
||||
at::checkAllSameType(c, {points_t, segms_t, grad_dists_t});
|
||||
// This is nondeterministic because atomicAdd
|
||||
at::globalContext().alertNotDeterministic(
|
||||
"PointEdgeArrayDistanceBackwardCuda");
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(points.device());
|
||||
|
||||
@@ -141,6 +141,9 @@ void PointsToVolumesForwardCuda(
|
||||
grid_sizes_t,
|
||||
mask_t});
|
||||
|
||||
// This is nondeterministic because atomicAdd
|
||||
at::globalContext().alertNotDeterministic("PointsToVolumesForwardCuda");
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(points_3d.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
@@ -583,6 +583,9 @@ at::Tensor RasterizeMeshesBackwardCuda(
|
||||
at::checkAllSameType(
|
||||
c, {face_verts_t, grad_zbuf_t, grad_bary_t, grad_dists_t});
|
||||
|
||||
// This is nondeterministic because atomicAdd
|
||||
at::globalContext().alertNotDeterministic("RasterizeMeshesBackwardCuda");
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(face_verts.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
@@ -423,7 +423,8 @@ at::Tensor RasterizePointsBackwardCuda(
|
||||
at::CheckedFrom c = "RasterizePointsBackwardCuda";
|
||||
at::checkAllSameGPU(c, {points_t, idxs_t, grad_zbuf_t, grad_dists_t});
|
||||
at::checkAllSameType(c, {points_t, grad_zbuf_t, grad_dists_t});
|
||||
|
||||
// This is nondeterministic because atomicAdd
|
||||
at::globalContext().alertNotDeterministic("RasterizePointsBackwardCuda");
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(points.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
@@ -155,7 +155,7 @@ at::Tensor FarthestPointSamplingCuda(
|
||||
|
||||
// Max possible threads per block
|
||||
const int MAX_THREADS_PER_BLOCK = 1024;
|
||||
const size_t threads = max(min(1 << points_pow_2, MAX_THREADS_PER_BLOCK), 1);
|
||||
const size_t threads = max(min(1 << points_pow_2, MAX_THREADS_PER_BLOCK), 2);
|
||||
|
||||
// Create the accessors
|
||||
auto points_a = points.packed_accessor64<float, 3, at::RestrictPtrTraits>();
|
||||
@@ -215,10 +215,6 @@ at::Tensor FarthestPointSamplingCuda(
|
||||
FarthestPointSamplingKernel<2><<<threads, threads, shared_mem, stream>>>(
|
||||
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
|
||||
break;
|
||||
case 1:
|
||||
FarthestPointSamplingKernel<1><<<threads, threads, shared_mem, stream>>>(
|
||||
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
|
||||
break;
|
||||
default:
|
||||
FarthestPointSamplingKernel<1024>
|
||||
<<<blocks, threads, shared_mem, stream>>>(
|
||||
|
||||
@@ -450,6 +450,9 @@ class FrameDataBuilderBase(ReplaceableBase, Generic[FrameDataSubtype], ABC):
|
||||
self,
|
||||
frame_annotation: types.FrameAnnotation,
|
||||
sequence_annotation: types.SequenceAnnotation,
|
||||
*,
|
||||
load_blobs: bool = True,
|
||||
**kwargs,
|
||||
) -> FrameDataSubtype:
|
||||
"""An abstract method to build the frame data based on raw frame/sequence
|
||||
annotations, load the binary data and adjust them according to the metadata.
|
||||
@@ -465,8 +468,9 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
Beware that modifications of frame data are done in-place.
|
||||
|
||||
Args:
|
||||
dataset_root: The root folder of the dataset; all the paths in jsons are
|
||||
specified relative to this root (but not json paths themselves).
|
||||
dataset_root: The root folder of the dataset; all paths in frame / sequence
|
||||
annotations are defined w.r.t. this root. Has to be set if any of the
|
||||
load_* flags below is true.
|
||||
load_images: Enable loading the frame RGB data.
|
||||
load_depths: Enable loading the frame depth maps.
|
||||
load_depth_masks: Enable loading the frame depth map masks denoting the
|
||||
@@ -494,7 +498,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
path_manager: Optionally a PathManager for interpreting paths in a special way.
|
||||
"""
|
||||
|
||||
dataset_root: str = ""
|
||||
dataset_root: Optional[str] = None
|
||||
load_images: bool = True
|
||||
load_depths: bool = True
|
||||
load_depth_masks: bool = True
|
||||
@@ -510,11 +514,37 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
box_crop_context: float = 0.3
|
||||
path_manager: Any = None
|
||||
|
||||
def __post_init__(self) -> None:
    """Check that ``dataset_root`` is usable whenever blob data is to be loaded.

    Nothing is checked when all of the ``load_*`` flags are false: in that
    case ``dataset_root`` may stay ``None`` and the file system is not probed.

    Raises:
        ValueError: if any ``load_*`` flag is set and ``dataset_root`` is
            ``None``, or if it is set but is not an existing directory
            (checked with ``path_manager.isdir`` when a path manager is
            given, otherwise with ``os.path.isdir``).
    """
    load_any_blob = (
        self.load_images
        or self.load_depths
        or self.load_depth_masks
        or self.load_masks
        or self.load_point_clouds
    )
    if not load_any_blob:
        # Metadata-only use: dataset_root is allowed to be None here, and
        # os.path.isdir(None) would raise TypeError, so do not probe it.
        return

    if self.dataset_root is None:
        raise ValueError(
            "dataset_root must be set to load any blob data. "
            "Make sure it is set in either FrameDataBuilder or Dataset params."
        )

    if self.path_manager is None:
        dataset_root_exists = os.path.isdir(self.dataset_root)  # pyre-ignore
    else:
        dataset_root_exists = self.path_manager.isdir(self.dataset_root)

    if not dataset_root_exists:
        raise ValueError(
            f"dataset_root is passed but {self.dataset_root} does not exist."
        )
|
||||
|
||||
def build(
|
||||
self,
|
||||
frame_annotation: types.FrameAnnotation,
|
||||
sequence_annotation: types.SequenceAnnotation,
|
||||
*,
|
||||
load_blobs: bool = True,
|
||||
**kwargs,
|
||||
) -> FrameDataSubtype:
|
||||
"""Builds the frame data based on raw frame/sequence annotations, loads the
|
||||
binary data and adjusts them according to the metadata. The processing includes:
|
||||
@@ -555,12 +585,19 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
else None,
|
||||
)
|
||||
|
||||
if load_blobs and self.load_masks and frame_annotation.mask is not None:
|
||||
(
|
||||
frame_data.fg_probability,
|
||||
frame_data.mask_path,
|
||||
frame_data.bbox_xywh,
|
||||
) = self._load_fg_probability(frame_annotation)
|
||||
mask_annotation = frame_annotation.mask
|
||||
if mask_annotation is not None:
|
||||
fg_mask_np: Optional[np.ndarray] = None
|
||||
if load_blobs and self.load_masks:
|
||||
fg_mask_np, mask_path = self._load_fg_probability(frame_annotation)
|
||||
frame_data.mask_path = mask_path
|
||||
frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)
|
||||
|
||||
bbox_xywh = mask_annotation.bounding_box_xywh
|
||||
if bbox_xywh is None and fg_mask_np is not None:
|
||||
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
|
||||
|
||||
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)
|
||||
|
||||
if frame_annotation.image is not None:
|
||||
image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
|
||||
@@ -604,25 +641,16 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
|
||||
def _load_fg_probability(
|
||||
self, entry: types.FrameAnnotation
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[str], Optional[torch.Tensor]]:
|
||||
|
||||
full_path = os.path.join(self.dataset_root, entry.mask.path) # pyre-ignore
|
||||
) -> Tuple[np.ndarray, str]:
|
||||
assert self.dataset_root is not None and entry.mask is not None
|
||||
full_path = os.path.join(self.dataset_root, entry.mask.path)
|
||||
fg_probability = load_mask(self._local_path(full_path))
|
||||
# we can use provided bbox_xywh or calculate it based on mask
|
||||
# saves time to skip bbox calculation
|
||||
# pyre-ignore
|
||||
bbox_xywh = entry.mask.bounding_box_xywh or get_bbox_from_mask(
|
||||
fg_probability, self.box_crop_mask_thr
|
||||
)
|
||||
if fg_probability.shape[-2:] != entry.image.size:
|
||||
raise ValueError(
|
||||
f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!"
|
||||
)
|
||||
return (
|
||||
safe_as_tensor(fg_probability, torch.float),
|
||||
full_path,
|
||||
safe_as_tensor(bbox_xywh, torch.long),
|
||||
)
|
||||
|
||||
return fg_probability, full_path
|
||||
|
||||
def _load_images(
|
||||
self,
|
||||
@@ -650,7 +678,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
fg_probability: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, str, torch.Tensor]:
|
||||
entry_depth = entry.depth
|
||||
assert entry_depth is not None
|
||||
assert self.dataset_root is not None and entry_depth is not None
|
||||
path = os.path.join(self.dataset_root, entry_depth.path)
|
||||
depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)
|
||||
|
||||
@@ -660,6 +688,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
|
||||
if self.load_depth_masks:
|
||||
assert entry_depth.mask_path is not None
|
||||
# pyre-ignore
|
||||
mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
|
||||
depth_mask = load_depth_mask(self._local_path(mask_path))
|
||||
else:
|
||||
@@ -708,6 +737,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
)
|
||||
if path.startswith(unwanted_prefix):
|
||||
path = path[len(unwanted_prefix) :]
|
||||
assert self.dataset_root is not None
|
||||
return os.path.join(self.dataset_root, path)
|
||||
|
||||
def _local_path(self, path: str) -> str:
|
||||
|
||||
@@ -190,6 +190,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
||||
box_crop=self.box_crop,
|
||||
box_crop_mask_thr=self.box_crop_mask_thr,
|
||||
box_crop_context=self.box_crop_context,
|
||||
path_manager=self.path_manager,
|
||||
)
|
||||
logger.info(str(self))
|
||||
|
||||
|
||||
161
pytorch3d/implicitron/dataset/orm_types.py
Normal file
161
pytorch3d/implicitron/dataset/orm_types.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# This functionality requires SQLAlchemy 2.0 or later.
|
||||
|
||||
import math
|
||||
import struct
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from pytorch3d.implicitron.dataset.types import (
|
||||
DepthAnnotation,
|
||||
ImageAnnotation,
|
||||
MaskAnnotation,
|
||||
PointCloudAnnotation,
|
||||
VideoAnnotation,
|
||||
ViewpointAnnotation,
|
||||
)
|
||||
|
||||
from sqlalchemy import LargeBinary
|
||||
from sqlalchemy.orm import (
|
||||
composite,
|
||||
DeclarativeBase,
|
||||
Mapped,
|
||||
mapped_column,
|
||||
MappedAsDataclass,
|
||||
)
|
||||
from sqlalchemy.types import TypeDecorator
|
||||
|
||||
|
||||
# these produce policies to serialize structured types to blobs
|
||||
def ArrayTypeFactory(shape):
    """Build a column type that stores fixed-shape arrays as float32 blobs.

    Args:
        shape: the shape every bound array must have; also used to reshape
            the flat buffer read back from the database.

    Returns:
        A ``TypeDecorator`` subclass backed by ``LargeBinary``.
    """

    class NumpyArrayType(TypeDecorator):
        impl = LargeBinary

        def process_bind_param(self, value, dialect):
            # Python -> DB: None is passed through, arrays become raw bytes.
            if value is None:
                return None
            if value.shape != shape:
                raise ValueError(f"Passed an array of wrong shape: {value.shape}")
            as_float32 = value.astype(np.float32)
            return as_float32.tobytes()

        def process_result_value(self, value, dialect):
            # DB -> Python: None is passed through, bytes become an array.
            if value is None:
                return None
            flat = np.frombuffer(value, dtype=np.float32)
            return flat.reshape(shape)

    return NumpyArrayType
|
||||
|
||||
|
||||
def TupleTypeFactory(dtype=float, shape: Tuple[int, ...] = (2,)):
    """Build a column type that packs (possibly nested) tuples into blobs.

    Args:
        dtype: element type, either ``float`` or ``int``; any other value
            raises ``KeyError``.
        shape: shape of the tuple; with more than one dimension the value is
            flattened before packing and rebuilt as nested tuples on load.

    Returns:
        A ``TypeDecorator`` subclass backed by ``LargeBinary``.
    """
    symbols = {
        float: "f",  # float32
        int: "i",  # int32
    }
    format_symbol = symbols[dtype]

    class TupleType(TypeDecorator):
        impl = LargeBinary
        # one struct code per element
        _format = format_symbol * math.prod(shape)

        def process_bind_param(self, value, _):
            if value is None:
                return None

            flat = value
            if len(shape) > 1:
                flat = np.array(value, dtype=dtype).reshape(-1)

            return struct.pack(TupleType._format, *flat)

        def process_result_value(self, value, _):
            if value is None:
                return None

            unpacked = struct.unpack(TupleType._format, value)
            if len(shape) <= 1:
                return unpacked

            nested = np.array(unpacked, dtype=dtype).reshape(shape).tolist()
            return _rec_totuple(nested)

    return TupleType
|
||||
|
||||
|
||||
def _rec_totuple(t):
    """Recursively turn nested lists into nested tuples.

    Anything that is not a ``list`` (including a tuple) is returned as is.
    """
    if not isinstance(t, list):
        return t

    return tuple(map(_rec_totuple, t))
|
||||
|
||||
|
||||
class Base(MappedAsDataclass, DeclarativeBase):
    """Common declarative base; each subclass is turned into a dataclass."""
|
||||
|
||||
|
||||
class SqlFrameAnnotation(Base):
    """ORM model of one frame's annotations, stored in the ``frame_annots`` table.

    A row is identified by the composite primary key
    ``(sequence_name, frame_number)``. The structured annotations (image,
    depth, mask, viewpoint) are each assembled from several underscore-prefixed
    columns via ``composite``.
    """

    __tablename__ = "frame_annots"

    sequence_name: Mapped[str] = mapped_column(primary_key=True)
    frame_number: Mapped[int] = mapped_column(primary_key=True)
    # indexed column
    frame_timestamp: Mapped[float] = mapped_column(index=True)

    # image path plus its size; the size is packed as two ints into a blob
    image: Mapped[ImageAnnotation] = composite(
        mapped_column("_image_path"),
        mapped_column("_image_size", TupleTypeFactory(int)),
    )

    # all depth columns are nullable
    depth: Mapped[DepthAnnotation] = composite(
        mapped_column("_depth_path", nullable=True),
        mapped_column("_depth_scale_adjustment", nullable=True),
        mapped_column("_depth_mask_path", nullable=True),
    )

    # mask path, indexed mask mass, and a bounding box packed as 4 floats
    mask: Mapped[MaskAnnotation] = composite(
        mapped_column("_mask_path", nullable=True),
        mapped_column("_mask_mass", index=True, nullable=True),
        mapped_column(
            "_mask_bounding_box_xywh",
            TupleTypeFactory(float, shape=(4,)),
            nullable=True,
        ),
    )

    # camera parameters: R is packed as 3x3 floats, T as 3 floats, focal
    # length and principal point as 2 floats each (the factory's default shape)
    viewpoint: Mapped[ViewpointAnnotation] = composite(
        mapped_column(
            "_viewpoint_R", TupleTypeFactory(float, shape=(3, 3)), nullable=True
        ),
        mapped_column(
            "_viewpoint_T", TupleTypeFactory(float, shape=(3,)), nullable=True
        ),
        mapped_column(
            "_viewpoint_focal_length", TupleTypeFactory(float), nullable=True
        ),
        mapped_column(
            "_viewpoint_principal_point", TupleTypeFactory(float), nullable=True
        ),
        mapped_column("_viewpoint_intrinsics_format", nullable=True),
    )
|
||||
|
||||
|
||||
class SqlSequenceAnnotation(Base):
    """ORM model of one sequence's annotations, stored in ``sequence_annots``.

    A row is identified by ``sequence_name``; ``category`` is indexed. The
    video and point-cloud annotations are each assembled from several
    nullable underscore-prefixed columns via ``composite``.
    """

    __tablename__ = "sequence_annots"

    sequence_name: Mapped[str] = mapped_column(primary_key=True)
    category: Mapped[str] = mapped_column(index=True)

    video: Mapped[VideoAnnotation] = composite(
        mapped_column("_video_path", nullable=True),
        mapped_column("_video_length", nullable=True),
    )
    point_cloud: Mapped[PointCloudAnnotation] = composite(
        mapped_column("_point_cloud_path", nullable=True),
        mapped_column("_point_cloud_quality_score", nullable=True),
        mapped_column("_point_cloud_n_points", nullable=True),
    )
    # the bigger the better
    viewpoint_quality_score: Mapped[Optional[float]] = mapped_column(default=None)
|
||||
736
pytorch3d/implicitron/dataset/sql_dataset.py
Normal file
736
pytorch3d/implicitron/dataset/sql_dataset.py
Normal file
@@ -0,0 +1,736 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
ClassVar,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import sqlalchemy as sa
|
||||
import torch
|
||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
|
||||
|
||||
from pytorch3d.implicitron.dataset.frame_data import ( # noqa
|
||||
FrameData,
|
||||
FrameDataBuilder,
|
||||
FrameDataBuilderBase,
|
||||
)
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
registry,
|
||||
ReplaceableBase,
|
||||
run_auto_creation,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .orm_types import SqlFrameAnnotation, SqlSequenceAnnotation
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_SET_LISTS_TABLE: str = "set_lists"
|
||||
|
||||
|
||||
@registry.register
|
||||
class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
"""
|
||||
A dataset with annotations stored as SQLite tables. This is an index-based dataset.
|
||||
The length is returned after all sequence and frame filters are applied (see param
|
||||
definitions below). Indices can either be ordinal in [0, len), or pairs of
|
||||
(sequence_name, frame_number); with the performance of `dataset[i]` and
|
||||
`dataset[sequence_name, frame_number]` being the same. A faster way to get metadata only
|
||||
(without blobs) is `dataset.meta[idx]` indexing; it requires box_crop==False.
|
||||
With ordinal indexing, the sequences are NOT guaranteed to span contiguous index
|
||||
ranges, and frame numbers are NOT guaranteed to be increasing within a sequence.
|
||||
Sequence-aware batch samplers have to use `sequence_[frames|indices]_in_order`
|
||||
iterators, which are efficient.
|
||||
|
||||
This functionality requires SQLAlchemy 2.0 or later.
|
||||
|
||||
Metadata-related args:
|
||||
sqlite_metadata_file: A SQLite file containing frame and sequence annotation
|
||||
tables (mapping to SqlFrameAnnotation and SqlSequenceAnnotation,
|
||||
respectively).
|
||||
dataset_root: A root directory to look for images, masks, etc. It can be
|
||||
alternatively set in `frame_data_builder` args, but this takes precedence.
|
||||
subset_lists_file: A JSON/sqlite file containing the lists of frames
|
||||
corresponding to different subsets (e.g. train/val/test) of the dataset;
|
||||
format: {subset: [(sequence_name, frame_id, file_path)]}. All entries
|
||||
must be present in frame_annotation metadata table.
|
||||
path_manager: a facade for non-POSIX filesystems.
|
||||
subsets: Restrict frames/sequences only to the given list of subsets
|
||||
as defined in subset_lists_file (see above). Applied before all other
|
||||
filters.
|
||||
remove_empty_masks: Removes the frames with no active foreground pixels
|
||||
in the segmentation mask (needs frame_annotation.mask.mass to be set;
|
||||
null values are retained).
|
||||
pick_frames_sql_clause: SQL WHERE clause to constrain frame annotations
|
||||
NOTE: This is a potential security risk! The string is passed to the SQL
|
||||
engine verbatim. Don’t expose it to end users of your application!
|
||||
pick_categories: Restrict the dataset to the given list of categories.
|
||||
pick_sequences: A Sequence of sequence names to restrict the dataset to.
|
||||
exclude_sequences: A Sequence of the names of the sequences to exclude.
|
||||
limit_sequences_to: Limit the dataset to the first `limit_sequences_to`
|
||||
sequences (after other sequence filters have been applied but before
|
||||
frame-based filters).
|
||||
limit_to: Limit the dataset to the first #limit_to frames (after other
|
||||
filters have been applied, except n_frames_per_sequence).
|
||||
n_frames_per_sequence: If > 0, randomly samples `n_frames_per_sequence`
|
||||
frames in each sequence uniformly without replacement if it has
|
||||
more frames than that; applied after other frame-level filters.
|
||||
seed: The seed of the random generator sampling `n_frames_per_sequence`
|
||||
random frames per sequence.
|
||||
"""
|
||||
|
||||
frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation
|
||||
|
||||
sqlite_metadata_file: str = ""
|
||||
dataset_root: Optional[str] = None
|
||||
subset_lists_file: str = ""
|
||||
eval_batches_file: Optional[str] = None
|
||||
path_manager: Any = None
|
||||
subsets: Optional[List[str]] = None
|
||||
remove_empty_masks: bool = True
|
||||
pick_frames_sql_clause: Optional[str] = None
|
||||
pick_categories: Tuple[str, ...] = ()
|
||||
|
||||
pick_sequences: Tuple[str, ...] = ()
|
||||
exclude_sequences: Tuple[str, ...] = ()
|
||||
limit_sequences_to: int = 0
|
||||
limit_to: int = 0
|
||||
n_frames_per_sequence: int = -1
|
||||
seed: int = 0
|
||||
remove_empty_masks_poll_whole_table_threshold: int = 300_000
|
||||
# we set it manually in the constructor
|
||||
# _index: pd.DataFrame = field(init=False)
|
||||
|
||||
frame_data_builder: FrameDataBuilderBase
|
||||
frame_data_builder_class_type: str = "FrameDataBuilder"
|
||||
|
||||
def __post_init__(self) -> None:
    """Validate the configuration, open the SQLite metadata and build the index.

    The steps are: check the SQLAlchemy version and required parameters,
    forward ``dataset_root`` to the frame data builder arguments, create the
    builder and the SQL engine, build the frame index (from subset lists when
    ``subsets`` is given, otherwise from the database), optionally subsample
    frames per sequence, and load the eval batches when a file is configured.

    Raises:
        ImportError: if the installed SQLAlchemy is older than 2.0.
        ValueError: if ``sqlite_metadata_file`` is not set, or if no frames
            remain after filtering.
    """
    # Compare the major version as a number: comparing the raw strings is
    # lexicographic and would order e.g. "10.0" before "2.0".
    if int(sa.__version__.split(".")[0]) < 2:
        raise ImportError("This class requires SQL Alchemy 2.0 or later")

    if not self.sqlite_metadata_file:
        raise ValueError("sqlite_metadata_file must be set")

    if self.dataset_root:
        # The dataset-level root overrides the one in the builder's args.
        frame_builder_type = self.frame_data_builder_class_type
        getattr(self, f"frame_data_builder_{frame_builder_type}_args")[
            "dataset_root"
        ] = self.dataset_root

    run_auto_creation(self)
    self.frame_data_builder.path_manager = self.path_manager

    # pyre-ignore
    self._sql_engine = sa.create_engine(f"sqlite:///{self.sqlite_metadata_file}")

    sequences = self._get_filtered_sequences_if_any()

    if self.subsets:
        index = self._build_index_from_subset_lists(sequences)
    else:
        # TODO: if self.subset_lists_file and not self.subsets, it might be faster to
        # still use the concatenated lists, assuming they cover the whole dataset
        index = self._build_index_from_db(sequences)

    if self.n_frames_per_sequence >= 0:
        index = self._stratified_sample_index(index)

    if len(index) == 0:
        raise ValueError(f"There are no frames in the subsets: {self.subsets}!")

    self._index = index.set_index(["sequence_name", "frame_number"])  # pyre-ignore

    self.eval_batches = None  # pyre-ignore
    if self.eval_batches_file:
        self.eval_batches = self._load_filter_eval_batches()

    logger.info(str(self))
|
||||
|
||||
def __len__(self) -> int:
    """Returns the number of rows (frames) in the in-memory index."""
    # pyre-ignore[16]
    return len(self._index)
|
||||
|
||||
def __getitem__(self, frame_idx: Union[int, Tuple[str, int]]) -> FrameData:
    """
    Fetches FrameData by either iloc in the index or by (sequence, frame_no) pair.

    Delegates to `_get_item` with blob loading switched on.
    """
    return self._get_item(frame_idx, True)
|
||||
|
||||
@property
def meta(self):
    """
    Allows accessing metadata only without loading blobs using `dataset.meta[idx]`.
    Requires box_crop==False, since with cropping enabled the cameras cannot be
    adjusted without loading masks.

    Returns:
        An accessor object; indexing it yields FrameData objects with blob
        fields like `image_rgb` set to None.

    Raises:
        ValueError if dataset.box_crop is set.
    """
    return SqlIndexDataset._MetadataAccessor(self)
|
||||
|
||||
@dataclass
class _MetadataAccessor:
    """
    Thin indexing proxy returned by `SqlIndexDataset.meta`: forwards
    `accessor[idx]` to the dataset's `_get_item` with blob loading disabled.
    """

    # the dataset whose items are accessed
    dataset: "SqlIndexDataset"

    def __getitem__(self, frame_idx: Union[int, Tuple[str, int]]) -> FrameData:
        return self.dataset._get_item(frame_idx, False)
|
||||
|
||||
def _get_item(
    self, frame_idx: Union[int, Tuple[str, int]], load_blobs: bool = True
) -> FrameData:
    """
    Loads a single frame, addressed either by its ordinal position in the
    index or by a `(sequence_name, frame_number[, image_path])` tuple.

    Args:
        frame_idx: an int position in the index, or a tuple starting with
            sequence name and frame number; if a third element is given, it
            is checked against the image path stored in the index.
        load_blobs: forwarded to the frame data builder.

    Returns:
        FrameData built from the frame and sequence annotations, with
        `frame_type` set to the subset stored in the index.

    Raises:
        IndexError: if an int index is past the end of the index, if the
            (sequence, frame) pair is not in the index, or if the supplied
            image path does not match the indexed one.
    """
    if isinstance(frame_idx, int):
        if frame_idx >= len(self._index):
            raise IndexError(f"index {frame_idx} out of range {len(self._index)}")

        seq, frame = self._index.index[frame_idx]
    else:
        seq, frame, *rest = frame_idx
        if (seq, frame) not in self._index.index:
            raise IndexError(
                f"Sequence-frame index {frame_idx} not found; was it filtered out?"
            )

        if rest and rest[0] != self._index.loc[(seq, frame), "_image_path"]:
            raise IndexError(f"Non-matching image path in {frame_idx}.")

    # one query for the frame annotation, one for its sequence annotation
    stmt = sa.select(self.frame_annotations_type).where(
        self.frame_annotations_type.sequence_name == seq,
        self.frame_annotations_type.frame_number
        == int(frame),  # cast from np.int64
    )
    seq_stmt = sa.select(SqlSequenceAnnotation).where(
        SqlSequenceAnnotation.sequence_name == seq
    )
    with Session(self._sql_engine) as session:
        entry = session.scalars(stmt).one()
        seq_metadata = session.scalars(seq_stmt).one()

    # sanity check: the database row must agree with the in-memory index
    assert entry.image.path == self._index.loc[(seq, frame), "_image_path"]

    frame_data = self.frame_data_builder.build(
        entry, seq_metadata, load_blobs=load_blobs
    )

    # The rest of the fields are optional
    frame_data.frame_type = self._get_frame_type(entry)
    return frame_data
|
||||
|
||||
def __str__(self) -> str:
    """Returns a short description containing the number of indexed frames."""
    # pyre-ignore[16]
    return f"SqlIndexDataset #frames={len(self._index)}"
|
||||
|
||||
def sequence_names(self) -> Iterable[str]:
    """
    Returns the unique sequence names in the dataset, i.e. the unique values
    of the `sequence_name` level of the index (an iterable, not a one-shot
    iterator).
    """
    return self._index.index.unique("sequence_name")
|
||||
|
||||
# override
|
||||
def category_to_sequence_names(self) -> Dict[str, List[str]]:
    """
    Queries the sequence annotations and groups the names of the sequences
    present in this dataset by their category.

    Returns:
        A dict mapping each category to the list of its sequence names.
    """
    # we limit results to sequences that have frames after all filters
    is_retained = SqlSequenceAnnotation.sequence_name.in_(self.sequence_names())
    query = sa.select(
        SqlSequenceAnnotation.category, SqlSequenceAnnotation.sequence_name
    ).where(is_retained)

    with self._sql_engine.connect() as conn:
        table = pd.read_sql(query, conn)

    names_per_category = table.groupby("category")["sequence_name"].apply(list)
    return names_per_category.to_dict()
|
||||
|
||||
# override
|
||||
def get_frame_numbers_and_timestamps(
    self, idxs: Sequence[int], subset_filter: Optional[Sequence[str]] = None
) -> List[Tuple[int, float]]:
    """
    Implements the DatasetBase method.

    NOTE: Avoid this function as there are more efficient alternatives such as
    querying `dataset[idx]` directly or getting all sequence frames with
    `sequence_[frames|indices]_in_order`.

    Return the index and timestamp in their videos of the frames whose
    indices are given in `idxs`. They need to belong to the same sequence!
    If timestamps are absent, they are replaced with zeros.
    This is used for letting SceneBatchSampler identify consecutive
    frames.

    Args:
        idxs: a sequence int frame index in the dataset (it can be a slice)
        subset_filter: must remain None

    Returns:
        list of tuples of
            - frame index in video
            - timestamp of frame in video, coalesced with 0s

    Raises:
        NotImplementedError if subset_filter is given.
        ValueError if idxs belong to more than one sequence.
    """

    if subset_filter is not None:
        raise NotImplementedError(
            "Subset filters are not supported in SQL Dataset. "
            "We encourage creating a dataset per subset."
        )

    index_slice, _ = self._get_frame_no_coalesced_ts_by_row_indices(idxs)
    # Select the two columns explicitly, in the documented order, and leave
    # the positional row index out: a bare `itertuples()` would yield
    # (row, frame_timestamp, frame_number) triplets instead of
    # (frame_number, frame_timestamp) pairs.
    pairs = index_slice[["frame_number", "frame_timestamp"]]
    return list(pairs.itertuples(index=False, name=None))
|
||||
|
||||
# override
|
||||
def sequence_frames_in_order(
    self, seq_name: str, subset_filter: Optional[Sequence[str]] = None
) -> Iterator[Tuple[float, int, int]]:
    """
    Overrides the default DatasetBase implementation (we don't use `_seq_to_idx`).
    Returns an iterator over the frame indices in a given sequence.
    We attempt to first sort by timestamp (if they are available),
    then by frame number.

    Args:
        seq_name: the name of the sequence.
        subset_filter: subset names to filter to

    Returns:
        an iterator over triplets `(timestamp, frame_no, dataset_idx)`,
            where `frame_no` is the index within the sequence, and
            `dataset_idx` is the index within the dataset.
            `None` timestamps are replaced with 0s.
    """
    # TODO: implement sort_timestamp_first? (which would matter if the orders
    # of frame numbers and timestamps are different)
    rows = self._index.index.get_loc(seq_name)
    if isinstance(rows, slice):
        # the sequence occupies a contiguous range of index rows
        assert rows.stop is not None, "Unexpected result from pandas"
        rows = range(rows.start or 0, rows.stop, rows.step or 1)
    else:
        # otherwise `rows` is treated as a mask; convert it to row positions
        rows = np.where(rows)[0]

    index_slice, idx = self._get_frame_no_coalesced_ts_by_row_indices(
        rows, seq_name, subset_filter
    )
    # attach the dataset-level row positions as the third tuple element
    index_slice["idx"] = idx

    yield from index_slice.itertuples(index=False)
|
||||
|
||||
# override
|
||||
def get_eval_batches(self) -> Optional[List[Any]]:
    """
    This class does not support eval batches with ordinal indices. You can pass
    eval_batches as a batch_sampler to a data_loader since the dataset supports
    `dataset[seq_name, frame_no]` indexing.

    Returns:
        The value of `self.eval_batches` (None unless eval batches were loaded).
    """
    return self.eval_batches
|
||||
|
||||
# override
|
||||
def join(self, other_datasets: Iterable[DatasetBase]) -> None:
    """
    Joining is not supported by this dataset.

    Raises:
        ValueError: always.
    """
    raise ValueError("Not supported! Preprocess the data by merging them instead.")
|
||||
|
||||
# override
|
||||
@property
def frame_data_type(self) -> Type[FrameData]:
    """Returns the FrameData type reported by the frame data builder."""
    return self.frame_data_builder.frame_data_type
|
||||
|
||||
def is_filtered(self) -> bool:
    """
    Returns `True` in case the dataset has been filtered and thus some frame
    annotations stored on the disk might be missing in the dataset object.
    Does not account for subsets.

    A custom `pick_frames_sql_clause` counts as a filter as well, since it is
    applied as a WHERE criterion when the index is built.

    Returns:
        is_filtered: `True` if the dataset has been filtered, else `False`.
    """
    return bool(
        self.remove_empty_masks
        or self.pick_frames_sql_clause
        or self.limit_to > 0
        or self.limit_sequences_to > 0
        or len(self.pick_sequences) > 0
        or len(self.exclude_sequences) > 0
        or len(self.pick_categories) > 0
        or self.n_frames_per_sequence > 0
    )
|
||||
|
||||
def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
    """
    Resolves the sequence-level filters (categories, picked and excluded
    sequences, sequence limit) into the names of the retained sequences.

    Returns:
        A Series of retained sequence names, or None when no sequence-level
        filter is active (so nothing needs to be filtered by sequence).
    """
    # maximum possible query: WHERE category IN 'self.pick_categories'
    # AND sequence_name IN 'self.pick_sequences'
    # AND sequence_name NOT IN 'self.exclude_sequences'
    # LIMIT 'self.limit_sequence_to'
    conditions = (
        self._get_category_filters()
        + self._get_pick_filters()
        + self._get_exclude_filters()
    )
    has_limit = self.limit_sequences_to > 0

    query = sa.select(SqlSequenceAnnotation.sequence_name)
    if conditions:
        query = query.where(*conditions)

    if has_limit:
        logger.info(
            f"Limiting dataset to first {self.limit_sequences_to} sequences"
        )
        # NOTE: ROWID is SQLite-specific
        query = query.order_by(sa.text("ROWID")).limit(self.limit_sequences_to)

    if not (conditions or has_limit):
        # we will not need to filter by sequences
        return None

    with self._sql_engine.connect() as conn:
        retained = pd.read_sql_query(query, conn)["sequence_name"]
    logger.info("... retained %d sequences" % len(retained))

    return retained
|
||||
|
||||
def _get_category_filters(self) -> List[sa.ColumnElement]:
    """
    Returns a single-element list with the `category IN pick_categories`
    condition, or an empty list if no categories are picked.
    """
    if self.pick_categories:
        logger.info(f"Limiting dataset to categories: {self.pick_categories}")
        category_column = SqlSequenceAnnotation.category
        return [category_column.in_(self.pick_categories)]

    return []
|
||||
|
||||
def _get_pick_filters(self) -> List[sa.ColumnElement]:
    """
    Returns a single-element list with the `sequence_name IN pick_sequences`
    condition, or an empty list if no sequences are picked.
    """
    if self.pick_sequences:
        logger.info(f"Limiting dataset to sequences: {self.pick_sequences}")
        name_column = SqlSequenceAnnotation.sequence_name
        return [name_column.in_(self.pick_sequences)]

    return []
|
||||
|
||||
def _get_exclude_filters(self) -> List[sa.ColumnOperators]:
    """
    Returns a single-element list with the
    `sequence_name NOT IN exclude_sequences` condition, or an empty list if
    no sequences are excluded.
    """
    if self.exclude_sequences:
        logger.info(f"Removing sequences from the dataset: {self.exclude_sequences}")
        name_column = SqlSequenceAnnotation.sequence_name
        return [name_column.notin_(self.exclude_sequences)]

    return []
|
||||
|
||||
def _load_subsets_from_json(self, subset_lists_path: str) -> pd.DataFrame:
|
||||
assert self.subsets is not None
|
||||
with open(subset_lists_path, "r") as f:
|
||||
subset_to_seq_frame = json.load(f)
|
||||
|
||||
seq_frame_list = sum(
|
||||
(
|
||||
[(*row, subset) for row in subset_to_seq_frame[subset]]
|
||||
for subset in self.subsets
|
||||
),
|
||||
[],
|
||||
)
|
||||
index = pd.DataFrame(
|
||||
seq_frame_list,
|
||||
columns=["sequence_name", "frame_number", "_image_path", "subset"],
|
||||
)
|
||||
return index
|
||||
|
||||
def _load_subsets_from_sql(self, subset_lists_path: str) -> pd.DataFrame:
    """
    Loads the frames of the requested subsets from an SQLite subset-lists file.

    Args:
        subset_lists_path: local path of the SQLite database that holds the
            set-lists table.

    Returns:
        A DataFrame with the rows of the set-lists table whose `subset`
        value is in `self.subsets`.
    """
    subsets = self.subsets
    assert subsets is not None
    # we need a new engine since we store the subsets in a separate DB
    engine = sa.create_engine(f"sqlite:///{subset_lists_path}")
    try:
        table = sa.Table(_SET_LISTS_TABLE, sa.MetaData(), autoload_with=engine)
        stmt = sa.select(table).where(table.c.subset.in_(subsets))
        with engine.connect() as connection:
            index = pd.read_sql(stmt, connection)
    finally:
        # the engine is only needed for this one query; release its
        # connection pool instead of leaving it to garbage collection
        engine.dispose()

    return index
|
||||
|
||||
def _build_index_from_subset_lists(
    self, sequences: Optional[pd.Series]
) -> pd.DataFrame:
    """
    Builds the frame index from the subset lists file (JSON or SQLite) and
    applies the sequence filter, the empty-mask filter, the custom SQL clause
    and the frame limit.

    Args:
        sequences: names of the sequences to retain, or None to keep all.

    Returns:
        A DataFrame with `sequence_name` and `frame_number` restored as
        ordinary columns (the index is reset before returning).

    Raises:
        ValueError: if `self.subset_lists_file` is not set.
    """
    if not self.subset_lists_file:
        raise ValueError("Requested subsets but subset_lists_file not given")

    logger.info(f"Loading subset lists from {self.subset_lists_file}.")

    subset_lists_path = self._local_path(self.subset_lists_file)
    # the file extension decides the format: JSON, otherwise SQLite
    if subset_lists_path.lower().endswith(".json"):
        index = self._load_subsets_from_json(subset_lists_path)
    else:
        index = self._load_subsets_from_sql(subset_lists_path)
    index = index.set_index(["sequence_name", "frame_number"])
    logger.info(f" -> loaded {len(index)} samples of {self.subsets}.")

    if sequences is not None:
        logger.info("Applying filtered sequences.")
        sequence_values = index.index.get_level_values("sequence_name")
        index = index.loc[sequence_values.isin(sequences)]
        logger.info(f" -> retained {len(index)} samples.")

    pick_frames_criteria = []
    if self.remove_empty_masks:
        logger.info("Culling samples with empty masks.")

        if len(index) > self.remove_empty_masks_poll_whole_table_threshold:
            # APPROACH 1: find empty masks and drop indices.
            # dev load: 17s / 15 s (3.1M / 500K)
            stmt = sa.select(
                self.frame_annotations_type.sequence_name,
                self.frame_annotations_type.frame_number,
            ).where(self.frame_annotations_type._mask_mass == 0)
            with Session(self._sql_engine) as session:
                to_remove = session.execute(stmt).all()

            # Pandas uses np.int64 for integer types, so we have to cast;
            # we might want to read it to pandas DataFrame directly to avoid the loop
            to_remove = [(seq, np.int64(fr)) for seq, fr in to_remove]
            index.drop(to_remove, errors="ignore", inplace=True)
        else:
            # APPROACH 3: load index into a temp table and join with annotations
            # dev load: 94 s / 23 s (3.1M / 500K)
            pick_frames_criteria.append(
                sa.or_(
                    self.frame_annotations_type._mask_mass.is_(None),
                    self.frame_annotations_type._mask_mass != 0,
                )
            )

    if self.pick_frames_sql_clause:
        logger.info("Applying the custom SQL clause.")
        pick_frames_criteria.append(sa.text(self.pick_frames_sql_clause))

    if pick_frames_criteria:
        index = self._pick_frames_by_criteria(index, pick_frames_criteria)

    logger.info(f" -> retained {len(index)} samples.")

    if self.limit_to > 0:
        logger.info(f"Limiting dataset to first {self.limit_to} frames")
        # "first" refers to the order sorted by (sequence_name, frame_number)
        index = index.sort_index().iloc[: self.limit_to]

    return index.reset_index()
|
||||
|
||||
def _pick_frames_by_criteria(self, index: pd.DataFrame, criteria) -> pd.DataFrame:
    """
    Filters the index rows by SQL criteria evaluated on the frame annotations.

    The index is written into a temporary table, joined with the frame
    annotations on (sequence_name, frame_number), and the rows satisfying
    all `criteria` are read back.

    Args:
        index: the frame index; it is passed to `to_sql` with the default
            arguments, which write the DataFrame index as columns.
        criteria: SQL conditions passed to `.where()`.

    Returns:
        The retained rows as a DataFrame indexed by
        (sequence_name, frame_number).
    """
    IndexTable = self._get_temp_index_table_instance()
    with self._sql_engine.connect() as connection:
        IndexTable.create(connection)
        # we don't let pandas's `to_sql` create the table automatically as
        # the table would be permanent, so we create it and append with pandas
        n_rows = index.to_sql(IndexTable.name, connection, if_exists="append")
        assert n_rows == len(index)
        sa_type = self.frame_annotations_type
        stmt = (
            sa.select(IndexTable)
            .select_from(
                IndexTable.join(
                    self.frame_annotations_type,
                    sa.and_(
                        sa_type.sequence_name == IndexTable.c.sequence_name,
                        sa_type.frame_number == IndexTable.c.frame_number,
                    ),
                )
            )
            .where(*criteria)
        )
        return pd.read_sql_query(stmt, connection).set_index(
            ["sequence_name", "frame_number"]
        )
|
||||
|
||||
def _build_index_from_db(self, sequences: Optional[pd.Series]):
    """
    Builds the frame index directly from the frame annotations table.

    Selects sequence name, frame number and image path (plus a NULL `subset`
    column) and applies the sequence filter, the empty-mask filter, the
    custom SQL clause and the frame limit inside the query.

    Args:
        sequences: names of the sequences to retain, or None to keep all.

    Returns:
        A DataFrame with one row per retained frame.
    """
    logger.info("Loading sequcence-frame index from the database")
    annots = self.frame_annotations_type
    query = sa.select(
        annots.sequence_name,
        annots.frame_number,
        annots._image_path,
        sa.null().label("subset"),
    )

    filters = []
    if sequences is not None:
        logger.info(" applying filtered sequences")
        filters.append(annots.sequence_name.in_(sequences.tolist()))

    if self.remove_empty_masks:
        logger.info(" excluding samples with empty masks")
        mask_mass = annots._mask_mass
        # frames without a recorded mask mass are kept
        filters.append(sa.or_(mask_mass.is_(None), mask_mass != 0))

    if self.pick_frames_sql_clause:
        logger.info(" applying custom SQL clause")
        filters.append(sa.text(self.pick_frames_sql_clause))

    if filters:
        query = query.where(*filters)

    if self.limit_to > 0:
        logger.info(f"Limiting dataset to first {self.limit_to} frames")
        ordered = query.order_by(annots.sequence_name, annots.frame_number)
        query = ordered.limit(self.limit_to)

    with self._sql_engine.connect() as connection:
        index = pd.read_sql_query(query, connection)

    logger.info(f" -> loaded {len(index)} samples.")
    return index
|
||||
|
||||
def _sort_index_(self, index):
    """
    Sorts `index` in place by the `sequence_name` and `frame_number` columns.
    Nothing is returned; the passed DataFrame is modified.
    """
    logger.info("Sorting the index by sequence and frame number.")
    index.sort_values(["sequence_name", "frame_number"], inplace=True)
    logger.info(" -> Done.")
|
||||
|
||||
def _load_filter_eval_batches(self):
    """
    Loads the eval batches from `self.eval_batches_file` (JSON) and keeps only
    the batches consistent with the sequence/category filters.

    A batch is kept or dropped based on the sequence name of its first
    element. Picked categories are translated to sequence names through
    `category_to_sequence_names`; a picked category without any retained
    sequence simply contributes no sequences.

    Returns:
        The list of retained eval batches.

    Raises:
        ValueError: if the eval batches file does not exist.
    """
    assert self.eval_batches_file
    logger.info(f"Loading eval batches from {self.eval_batches_file}")

    if not os.path.isfile(self.eval_batches_file):
        # The batch indices file does not exist.
        # Most probably the user has not specified the root folder.
        raise ValueError(
            f"Looking for dataset json file in {self.eval_batches_file}. "
            + "Please specify a correct dataset_root folder."
        )

    with open(self.eval_batches_file, "r") as f:
        eval_batches = json.load(f)

    # limit the dataset to sequences to allow multiple evaluations in one file
    pick_sequences = set(self.pick_sequences)
    if self.pick_categories:
        cat_to_seq = self.category_to_sequence_names()
        # `.get` avoids a KeyError for categories that have no sequences left
        pick_sequences.update(
            seq for cat in self.pick_categories for seq in cat_to_seq.get(cat, ())
        )

    if pick_sequences:
        old_len = len(eval_batches)
        eval_batches = [b for b in eval_batches if b[0][0] in pick_sequences]
        # `Logger.warn` is a deprecated alias of `Logger.warning`
        logger.warning(
            f"Picked eval batches by sequence/cat: {old_len} -> {len(eval_batches)}"
        )

    if self.exclude_sequences:
        old_len = len(eval_batches)
        exclude_sequences = set(self.exclude_sequences)
        eval_batches = [b for b in eval_batches if b[0][0] not in exclude_sequences]
        logger.warning(
            f"Excluded eval batches by sequence: {old_len} -> {len(eval_batches)}"
        )

    return eval_batches
|
||||
|
||||
def _stratified_sample_index(self, index):
    """
    Randomly samples at most `self.n_frames_per_sequence` frames from every
    sequence of the index.

    The random state of each sequence is derived from its name and
    `self.seed`, so the sample of a sequence does not depend on which other
    sequences are present.

    Args:
        index: DataFrame with a `sequence_name` column.

    Returns:
        The sampled rows, concatenated in the order of the (sorted) sequence
        names; an empty frame if the input has no rows.
    """
    # NOTE this stratified sampling can be done more efficiently in
    # the no-subset case above if it is added to the SQL query.
    # We keep this generic implementation since no-subset case is uncommon

    # Iterate over the groups instead of using `groupby(...).apply(...)`:
    # recent pandas versions deprecate (and later stop) passing the grouping
    # column to the applied function, which both the seed lookup and the
    # returned frame rely on.
    sampled = [
        seq_frames.sample(
            min(len(seq_frames), self.n_frames_per_sequence),
            random_state=_seq_name_to_seed(seq_name) + self.seed,
        )
        for seq_name, seq_frames in index.groupby("sequence_name")
    ]
    # `pd.concat` rejects an empty list, so keep an empty frame as is
    index = pd.concat(sampled) if sampled else index.iloc[:0]
    logger.info(f" -> retained {len(index)} samples aster stratified sampling.")
    return index
|
||||
|
||||
def _get_frame_type(self, entry: SqlFrameAnnotation) -> Optional[str]:
    """
    Returns the `subset` value stored in the index for the frame identified
    by the entry's sequence name and frame number.
    """
    return self._index.loc[(entry.sequence_name, entry.frame_number), "subset"]
|
||||
|
||||
def _get_frame_no_coalesced_ts_by_row_indices(
    self,
    idxs: Sequence[int],
    seq_name: Optional[str] = None,
    subset_filter: Union[Sequence[str], str, None] = None,
) -> Tuple[pd.DataFrame, Sequence[int]]:
    """
    Loads timestamps for given index rows belonging to the same sequence.
    If seq_name is known, it speeds up the computation.

    Args:
        idxs: positions of the rows in the index.
        seq_name: name of the sequence the rows belong to; inferred from the
            rows when None.
        subset_filter: a subset name or a collection of names; when given,
            only rows whose `subset` is among them are kept.

    Returns:
        A tuple of
            - DataFrame with columns `frame_timestamp` (NULLs replaced by 0)
              and `frame_number`, as returned by the database query;
            - the row positions that passed the subset filter.

    Raises:
        ValueError if `idxs` do not all belong to a single sequence, or if
            the number of rows found in the database differs from the number
            of selected index rows.
    """
    index_slice = self._index.iloc[idxs]
    if subset_filter is not None:
        if isinstance(subset_filter, str):
            subset_filter = [subset_filter]
        indicator = index_slice["subset"].isin(subset_filter)
        index_slice = index_slice.loc[indicator]
        # keep the positions in sync with the filtered rows
        idxs = [i for i, isin in zip(idxs, indicator) if isin]

    frames = index_slice.index.get_level_values("frame_number").tolist()
    if seq_name is None:
        seq_name_list = index_slice.index.get_level_values("sequence_name").tolist()
        seq_name_set = set(seq_name_list)
        if len(seq_name_set) > 1:
            raise ValueError("Given indices belong to more than one sequence.")
        elif len(seq_name_set) == 1:
            seq_name = seq_name_list[0]

    coalesced_ts = sa.sql.functions.coalesce(
        self.frame_annotations_type.frame_timestamp, 0
    )
    stmt = sa.select(
        coalesced_ts.label("frame_timestamp"),
        self.frame_annotations_type.frame_number,
    ).where(
        self.frame_annotations_type.sequence_name == seq_name,
        self.frame_annotations_type.frame_number.in_(frames),
    )

    with self._sql_engine.connect() as connection:
        frame_no_ts = pd.read_sql_query(stmt, connection)

    if len(frame_no_ts) != len(index_slice):
        raise ValueError(
            "Not all indices are found in the database; "
            "do they belong to more than one sequence?"
        )

    return frame_no_ts, idxs
|
||||
|
||||
def _local_path(self, path: str) -> str:
|
||||
if self.path_manager is None:
|
||||
return path
|
||||
return self.path_manager.get_local_path(path)
|
||||
|
||||
def _get_temp_index_table_instance(self, table_name: str = "__index"):
    """
    Returns the definition of the temporary index table, registered in the
    metadata of the frame annotations type.

    If a table with this name is already registered in that metadata, the
    registered object is returned instead of defining it again.

    Args:
        table_name: name of the temporary table.

    Returns:
        An `sa.Table` with columns `sequence_name`, `frame_number` (together
        the primary key), `_image_path` and `subset`.
    """
    CachedTable = self.frame_annotations_type.metadata.tables.get(table_name)
    if CachedTable is not None:  # table definition is not idempotent
        return CachedTable

    return sa.Table(
        table_name,
        self.frame_annotations_type.metadata,
        sa.Column("sequence_name", sa.String, primary_key=True),
        sa.Column("frame_number", sa.Integer, primary_key=True),
        sa.Column("_image_path", sa.String),
        sa.Column("subset", sa.String),
        prefixes=["TEMP"],  # NOTE SQLite specific!
    )
|
||||
|
||||
|
||||
def _seq_name_to_seed(seq_name) -> int:
|
||||
"""Generates numbers in [0, 2 ** 28)"""
|
||||
return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest()[:7], 16)
|
||||
|
||||
|
||||
def _safe_as_tensor(data, dtype):
|
||||
return torch.tensor(data, dtype=dtype) if data is not None else None
|
||||
424
pytorch3d/implicitron/dataset/sql_dataset_provider.py
Normal file
424
pytorch3d/implicitron/dataset/sql_dataset_provider.py
Normal file
@@ -0,0 +1,424 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
import numpy as np
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from pytorch3d.implicitron.dataset.dataset_map_provider import (
|
||||
DatasetMap,
|
||||
DatasetMapProviderBase,
|
||||
PathManagerFactory,
|
||||
)
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
expand_args_fields,
|
||||
registry,
|
||||
run_auto_creation,
|
||||
)
|
||||
|
||||
from .sql_dataset import SqlIndexDataset
|
||||
|
||||
|
||||
# Default dataset root, read from the CO3D_SQL_DATASET_ROOT environment
# variable (empty string if the variable is unset).
_CO3D_SQL_DATASET_ROOT: str = os.getenv("CO3D_SQL_DATASET_ROOT", "")

# _NEED_CONTROL is a list of those elements of SqlIndexDataset which
# are not directly specified for it in the config but come from the
# DatasetMapProvider.
_NEED_CONTROL: Tuple[str, ...] = (
    "path_manager",
    "subsets",
    "sqlite_metadata_file",
    "subset_lists_file",
)

logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@registry.register
|
||||
class SqlIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
"""
|
||||
Generates the training, validation, and testing dataset objects for
|
||||
a dataset laid out on disk like SQL-CO3D, with annotations in an SQLite data base.
|
||||
|
||||
The dataset is organized in the filesystem as follows::
|
||||
|
||||
self.dataset_root
|
||||
├── <possible/partition/0>
|
||||
│ ├── <sequence_name_0>
|
||||
│ │ ├── depth_masks
|
||||
│ │ ├── depths
|
||||
│ │ ├── images
|
||||
│ │ ├── masks
|
||||
│ │ └── pointcloud.ply
|
||||
│ ├── <sequence_name_1>
|
||||
│ │ ├── depth_masks
|
||||
│ │ ├── depths
|
||||
│ │ ├── images
|
||||
│ │ ├── masks
|
||||
│ │ └── pointcloud.ply
|
||||
│ ├── ...
|
||||
│ ├── <sequence_name_N>
|
||||
│ ├── set_lists
|
||||
│ ├── <subset_base_name_0>.json
|
||||
│ ├── <subset_base_name_1>.json
|
||||
│ ├── ...
|
||||
│ ├── <subset_base_name_2>.json
|
||||
│ ├── eval_batches
|
||||
│ │ ├── <eval_batches_base_name_0>.json
|
||||
│ │ ├── <eval_batches_base_name_1>.json
|
||||
│ │ ├── ...
|
||||
│ │ ├── <eval_batches_base_name_M>.json
|
||||
│ ├── frame_annotations.jgz
|
||||
│ ├── sequence_annotations.jgz
|
||||
├── <possible/partition/1>
|
||||
├── ...
|
||||
├── <possible/partition/K>
|
||||
├── set_lists
|
||||
├── <subset_base_name_0>.sqlite
|
||||
├── <subset_base_name_1>.sqlite
|
||||
├── ...
|
||||
├── <subset_base_name_2>.sqlite
|
||||
├── eval_batches
|
||||
│ ├── <eval_batches_base_name_0>.json
|
||||
│ ├── <eval_batches_base_name_1>.json
|
||||
│ ├── ...
|
||||
│ ├── <eval_batches_base_name_M>.json
|
||||
|
||||
The dataset contains sequences named `<sequence_name_i>` that may be partitioned by
|
||||
directories such as `<possible/partition/0>` e.g. representing categories but they
|
||||
can also be stored in a flat structure. Each sequence folder contains the list of
|
||||
sequence images, depth maps, foreground masks, and valid-depth masks
|
||||
`images`, `depths`, `masks`, and `depth_masks` respectively. Furthermore,
|
||||
`set_lists/` directories (with partitions or global) store json or sqlite files
|
||||
`<subset_base_name_l>.<ext>`, each describing a certain sequence subset.
|
||||
These subset path conventions are not hard-coded and arbitrary relative path can be
|
||||
specified by setting `self.subset_lists_path` to the relative path w.r.t.
|
||||
dataset root.
|
||||
|
||||
Each `<subset_base_name_l>.json` file contains the following dictionary::
|
||||
|
||||
{
|
||||
"train": [
|
||||
(sequence_name: str, frame_number: int, image_path: str),
|
||||
...
|
||||
],
|
||||
"val": [
|
||||
(sequence_name: str, frame_number: int, image_path: str),
|
||||
...
|
||||
],
|
||||
"test": [
|
||||
(sequence_name: str, frame_number: int, image_path: str),
|
||||
...
|
||||
],
|
||||
}
|
||||
|
||||
defining the list of frames (identified with their `sequence_name` and
|
||||
`frame_number`) in the "train", "val", and "test" subsets of the dataset. In case of
|
||||
SQLite format, `<subset_base_name_l>.sqlite` contains a table with the header::
|
||||
|
||||
| sequence_name | frame_number | image_path | subset |
|
||||
|
||||
Note that `frame_number` can be obtained only from the metadata and
|
||||
does not necessarily correspond to the numeric suffix of the corresponding image
|
||||
file name (e.g. a file `<partition_0>/<sequence_name_0>/images/frame00005.jpg` can
|
||||
have its frame number set to `20`, not 5).
|
||||
|
||||
Each `<eval_batches_base_name_M>.json` file contains a list of evaluation examples
|
||||
in the following form::
|
||||
|
||||
[
|
||||
[ # batch 1
|
||||
(sequence_name: str, frame_number: int, image_path: str),
|
||||
...
|
||||
],
|
||||
[ # batch 2
|
||||
(sequence_name: str, frame_number: int, image_path: str),
|
||||
...
|
||||
],
|
||||
]
|
||||
|
||||
Note that the evaluation examples always come from the `"test"` subset of the dataset.
|
||||
(test frames can repeat across batches). The batches can contain single element,
|
||||
which is typical in case of regular radiance field fitting.
|
||||
|
||||
Args:
|
||||
subset_lists_path: The relative path to the dataset subset definition.
|
||||
For CO3D, these include e.g. "skateboard/set_lists/set_lists_manyview_dev_0.json".
|
||||
By default (None), dataset is not partitioned to subsets (in that case, setting
|
||||
`ignore_subsets` will speed up construction)
|
||||
dataset_root: The root folder of the dataset.
|
||||
metadata_basename: name of the SQL metadata file in dataset_root;
|
||||
not expected to be changed by users
|
||||
test_on_train: Construct validation and test datasets from
|
||||
the training subset; note that in practice, in this
|
||||
case all subset dataset objects will be same
|
||||
only_test_set: Load only the test set. Incompatible with `test_on_train`.
|
||||
ignore_subsets: Don’t filter by subsets in the dataset; note that in this
|
||||
case all subset datasets will be same
|
||||
eval_batch_num_training_frames: Add a certain number of training frames to each
|
||||
eval batch. Useful for evaluating models that require
|
||||
source views as input (e.g. NeRF-WCE / PixelNeRF).
|
||||
dataset_args: Specifies additional arguments to the
|
||||
SqlIndexDataset constructor call.
|
||||
path_manager_factory: (Optional) An object that generates an instance of
|
||||
PathManager that can translate provided file paths.
|
||||
path_manager_factory_class_type: The class type of `path_manager_factory`.
|
||||
"""
|
||||
|
||||
category: Optional[str] = None
|
||||
subset_list_name: Optional[str] = None # TODO: docs
|
||||
# OR
|
||||
subset_lists_path: Optional[str] = None
|
||||
eval_batches_path: Optional[str] = None
|
||||
|
||||
dataset_root: str = _CO3D_SQL_DATASET_ROOT
|
||||
metadata_basename: str = "metadata.sqlite"
|
||||
|
||||
test_on_train: bool = False
|
||||
only_test_set: bool = False
|
||||
ignore_subsets: bool = False
|
||||
train_subsets: Tuple[str, ...] = ("train",)
|
||||
val_subsets: Tuple[str, ...] = ("val",)
|
||||
test_subsets: Tuple[str, ...] = ("test",)
|
||||
|
||||
eval_batch_num_training_frames: int = 0
|
||||
|
||||
# this is a mould that is never constructed, used to build self._dataset_map values
|
||||
dataset_class_type: str = "SqlIndexDataset"
|
||||
dataset: SqlIndexDataset
|
||||
|
||||
path_manager_factory: PathManagerFactory
|
||||
path_manager_factory_class_type: str = "PathManagerFactory"
|
||||
|
||||
def __post_init__(self):
    """
    Validates the provider configuration and constructs the train, val and
    test datasets, which are stored in `self._dataset_map`.

    All datasets share the metadata file, dataset root, subset lists file
    and path manager; they differ in the requested subsets, and only the
    test dataset receives the eval batches file. With `test_on_train`, the
    train dataset object is reused for val and test.

    Raises:
        ValueError: if `only_test_set` and `test_on_train` are both set, if
            the metadata file or the subset lists file cannot be found, or
            if both `subset_lists_path` and `subset_list_name` are set.
    """
    super().__init__()
    run_auto_creation(self)

    if self.only_test_set and self.test_on_train:
        raise ValueError("Cannot have only_test_set and test_on_train")

    if self.ignore_subsets and not self.only_test_set:
        self.test_on_train = True  # no point in loading same data 3 times

    path_manager = self.path_manager_factory.get()

    sqlite_metadata_file = os.path.join(self.dataset_root, self.metadata_basename)
    sqlite_metadata_file = _local_path(path_manager, sqlite_metadata_file)

    if not os.path.isfile(sqlite_metadata_file):
        # The sqlite_metadata_file does not exist.
        # Most probably the user has not specified the root folder.
        raise ValueError(
            f"Looking for frame annotations in {sqlite_metadata_file}."
            + " Please specify a correct dataset_root folder."
            + " Note: By default the root folder is taken from the"
            + " CO3D_SQL_DATASET_ROOT environment variable."
        )

    if self.subset_lists_path and self.subset_list_name:
        raise ValueError(
            "subset_lists_path and subset_list_name cannot be both set"
        )

    subset_lists_file = self._get_lists_file("set_lists")

    # setup the common dataset arguments
    common_dataset_kwargs = {
        **getattr(self, f"dataset_{self.dataset_class_type}_args"),
        "sqlite_metadata_file": sqlite_metadata_file,
        "dataset_root": self.dataset_root,
        "subset_lists_file": subset_lists_file,
        "path_manager": path_manager,
    }

    if self.category:
        logger.info(f"Forcing category filter in the datasets to {self.category}")
        common_dataset_kwargs["pick_categories"] = self.category.split(",")

    # get the used dataset type
    dataset_type: Type[SqlIndexDataset] = registry.get(
        SqlIndexDataset, self.dataset_class_type
    )
    expand_args_fields(dataset_type)

    if subset_lists_file is not None and not os.path.isfile(subset_lists_file):
        # `pick_categories` is an OmegaConf container when it comes from the
        # config but a plain list when forced via `self.category`; only the
        # former can (and needs to) be converted with `OmegaConf.to_object`.
        pick_categories = common_dataset_kwargs.get("pick_categories", ())
        if OmegaConf.is_config(pick_categories):
            pick_categories = OmegaConf.to_object(pick_categories)
        available_subsets = self._get_available_subsets(list(pick_categories))
        # report the file that was actually looked up; `subset_lists_path`
        # is None when the lists are selected through `subset_list_name`
        msg = f"Cannot find subset list file {subset_lists_file}."
        if available_subsets:
            msg += f" Some of the available subsets: {str(available_subsets)}."
        raise ValueError(msg)

    train_dataset = None
    val_dataset = None
    if not self.only_test_set:
        # load the training set
        logger.debug("Constructing train dataset.")
        train_dataset = dataset_type(
            **common_dataset_kwargs, subsets=self._get_subsets(self.train_subsets)
        )
        logger.info(f"Train dataset: {str(train_dataset)}")

    if self.test_on_train:
        assert train_dataset is not None
        val_dataset = test_dataset = train_dataset
    else:
        # load the val and test sets
        if not self.only_test_set:
            # NOTE: this is always loaded in JsonProviderV2
            logger.debug("Extracting val dataset.")
            val_dataset = dataset_type(
                **common_dataset_kwargs, subsets=self._get_subsets(self.val_subsets)
            )
            logger.info(f"Val dataset: {str(val_dataset)}")

        logger.debug("Extracting test dataset.")
        eval_batches_file = self._get_lists_file("eval_batches")
        # the provider-level eval batches file replaces any configured value;
        # `pop` tolerates dataset args that do not carry the key
        common_dataset_kwargs.pop("eval_batches_file", None)
        test_dataset = dataset_type(
            **common_dataset_kwargs,
            subsets=self._get_subsets(self.test_subsets, True),
            eval_batches_file=eval_batches_file,
        )
        logger.info(f"Test dataset: {str(test_dataset)}")

        if (
            eval_batches_file is not None
            and self.eval_batch_num_training_frames > 0
        ):
            self._extend_eval_batches(test_dataset)

    self._dataset_map = DatasetMap(
        train=train_dataset, val=val_dataset, test=test_dataset
    )
|
||||
|
||||
def _get_subsets(self, subsets, is_eval: bool = False):
|
||||
if self.ignore_subsets:
|
||||
return None
|
||||
|
||||
if is_eval and self.eval_batch_num_training_frames > 0:
|
||||
# we will need to have training frames for extended batches
|
||||
return list(subsets) + list(self.train_subsets)
|
||||
|
||||
return subsets
|
||||
|
||||
def _extend_eval_batches(self, test_dataset: SqlIndexDataset) -> None:
|
||||
rng = np.random.default_rng(seed=0)
|
||||
eval_batches = test_dataset.get_eval_batches()
|
||||
if eval_batches is None:
|
||||
raise ValueError("Eval batches were not loaded!")
|
||||
|
||||
for batch in eval_batches:
|
||||
sequence = batch[0][0]
|
||||
seq_frames = list(
|
||||
test_dataset.sequence_frames_in_order(sequence, self.train_subsets)
|
||||
)
|
||||
idx_to_add = rng.permutation(len(seq_frames))[
|
||||
: self.eval_batch_num_training_frames
|
||||
]
|
||||
batch.extend((sequence, seq_frames[a][1]) for a in idx_to_add)
|
||||
|
||||
@classmethod
def dataset_tweak_args(cls, type, args: DictConfig) -> None:
    """
    Hook called by get_default_args.

    Removes from `args` (in place) every key listed in `_NEED_CONTROL`:
    those fields are not exposed on each dataset class but are controlled
    by this provider class instead.
    """
    for controlled_key in _NEED_CONTROL:
        del args[controlled_key]
|
||||
|
||||
def create_dataset(self):
    """
    Intentionally a no-op.

    No `dataset` member of this class is created.
    The dataset(s) live in `self.get_dataset_map`.
    """
|
||||
|
||||
def get_dataset_map(self) -> DatasetMap:
    """Return the `DatasetMap` stored in `self._dataset_map`."""
    dataset_map = self._dataset_map  # pyre-ignore [16]
    return dataset_map
|
||||
|
||||
def _get_available_subsets(self, categories: List[str]):
|
||||
"""
|
||||
Get the available subset names for a given category folder (if given) inside
|
||||
a root dataset folder `dataset_root`.
|
||||
"""
|
||||
path_manager = self.path_manager_factory.get()
|
||||
|
||||
subsets: List[str] = []
|
||||
for prefix in [""] + categories:
|
||||
set_list_dir = os.path.join(self.dataset_root, prefix, "set_lists")
|
||||
if not (
|
||||
(path_manager is not None) and path_manager.isdir(set_list_dir)
|
||||
) and not os.path.isdir(set_list_dir):
|
||||
continue
|
||||
|
||||
set_list_files = (os.listdir if path_manager is None else path_manager.ls)(
|
||||
set_list_dir
|
||||
)
|
||||
subsets.extend(os.path.join(prefix, "set_lists", f) for f in set_list_files)
|
||||
|
||||
return subsets
|
||||
|
||||
def _get_lists_file(self, flavor: str) -> Optional[str]:
|
||||
if flavor == "eval_batches":
|
||||
subset_lists_path = self.eval_batches_path
|
||||
else:
|
||||
subset_lists_path = self.subset_lists_path
|
||||
|
||||
if not subset_lists_path and not self.subset_list_name:
|
||||
return None
|
||||
|
||||
category_elem = ""
|
||||
if self.category and "," not in self.category:
|
||||
# if multiple categories are given, looking for global set lists
|
||||
category_elem = self.category
|
||||
|
||||
subset_lists_path = subset_lists_path or (
|
||||
os.path.join(
|
||||
category_elem, f"{flavor}", f"{flavor}_{self.subset_list_name}"
|
||||
)
|
||||
)
|
||||
|
||||
assert subset_lists_path
|
||||
path_manager = self.path_manager_factory.get()
|
||||
# try absolute path first
|
||||
subset_lists_file = _get_local_path_check_extensions(
|
||||
subset_lists_path, path_manager
|
||||
)
|
||||
if subset_lists_file:
|
||||
return subset_lists_file
|
||||
|
||||
full_path = os.path.join(self.dataset_root, subset_lists_path)
|
||||
subset_lists_file = _get_local_path_check_extensions(full_path, path_manager)
|
||||
|
||||
if not subset_lists_file:
|
||||
raise FileNotFoundError(
|
||||
f"Subset lists path given but not found: {full_path}"
|
||||
)
|
||||
|
||||
return subset_lists_file
|
||||
|
||||
|
||||
def _get_local_path_check_extensions(
    path, path_manager, extensions=("", ".sqlite", ".json")
) -> Optional[str]:
    """
    Find the first existing file among `path + ext` for `ext` in `extensions`.

    Each candidate is converted with `_local_path` before being checked, and
    candidates are tried in the order given; later ones are not touched once
    a match is found.

    Returns:
        The local path of the first candidate that is a file, or None.
    """
    local_candidates = (
        _local_path(path_manager, path + suffix) for suffix in extensions
    )
    return next(
        (candidate for candidate in local_candidates if os.path.isfile(candidate)),
        None,
    )
|
||||
|
||||
|
||||
def _local_path(path_manager, path: str) -> str:
|
||||
if path_manager is None:
|
||||
return path
|
||||
return path_manager.get_local_path(path)
|
||||
189
pytorch3d/implicitron/dataset/train_eval_data_loader_provider.py
Normal file
189
pytorch3d/implicitron/dataset/train_eval_data_loader_provider.py
Normal file
@@ -0,0 +1,189 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
|
||||
DataLoaderMap,
|
||||
SceneBatchSampler,
|
||||
SequenceDataLoaderMapProvider,
|
||||
)
|
||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
|
||||
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
|
||||
from pytorch3d.implicitron.dataset.frame_data import FrameData
|
||||
from pytorch3d.implicitron.tools.config import registry, run_auto_creation
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# TODO: we can merge it with SequenceDataLoaderMapProvider in PyTorch3D
|
||||
# and support both eval_batches protocols
|
||||
@registry.register
class TrainEvalDataLoaderMapProvider(SequenceDataLoaderMapProvider):
    """
    Implementation of DataLoaderMapProviderBase that may use internal eval batches for
    the test dataset. In particular, if the test dataset provides eval batches
    (its `get_eval_batches()` returns a value other than None), they are used as the
    batch sampler of the test data loader; otherwise the test set is treated in the
    same way as train and val, i.e. the parameter `dataset_length_test` is respected.

    NOTE(review): the original text also referred to `eval_batches_relpath` and
    `test_conditioning_type`; neither is referenced by the code of this class, so
    presumably they belong to the dataset provider / the base class. Confirm.

    If conditioning is not required, then the batch size should
    be set as 1, and most of the fields do not matter.

    If conditioning is required, each batch will contain one main
    frame first to predict and the rest of the elements are for
    conditioning.

    If images_per_seq_options is left empty, the conditioning
    frames are picked according to the conditioning type given.
    This does not have regard to the order of frames in a
    scene, or which frames belong to what scene.

    If images_per_seq_options is given, then the conditioning types
    must be SAME and the remaining fields are used.

    Members:
        batch_size: The size of the batch of the data loader.
        num_workers: Number of data-loading threads in each data loader.
        dataset_length_train: The number of batches in a training epoch. Or 0 to mean
            an epoch is the length of the training set.
        dataset_length_val: The number of batches in a validation epoch. Or 0 to mean
            an epoch is the length of the validation set.
        dataset_length_test: used if test_dataset.eval_batches is NOT set. The number of
            batches in a testing epoch. Or 0 to mean an epoch is the length of the test
            set.
        images_per_seq_options: Possible numbers of frames sampled per sequence in a batch.
            If a conditioning_type is KNOWN or TRAIN, then this must be left at its initial
            value. Empty (the default) means that we are not careful about which frames
            come from which scene.
        sample_consecutive_frames: if True, will sample a contiguous interval of frames
            in the sequence. It first sorts the frames by timestamps when available,
            otherwise by frame numbers, finds the connected segments within the sequence
            of sufficient length, then samples a random pivot element among them and
            ideally uses it as a middle of the temporal window, shifting the borders
            where necessary. This strategy mitigates the bias against shorter segments
            and their boundaries.
        consecutive_frames_max_gap: if a number > 0, then used to define the maximum
            difference in frame_number of neighbouring frames when forming connected
            segments; if both this and consecutive_frames_max_gap_seconds are 0s,
            the whole sequence is considered a segment regardless of frame numbers.
        consecutive_frames_max_gap_seconds: if a number > 0.0, then used to define the
            maximum difference in frame_timestamp of neighbouring frames when forming
            connected segments; if both this and consecutive_frames_max_gap are 0s,
            the whole sequence is considered a segment regardless of frame timestamps.
    """

    batch_size: int = 1
    num_workers: int = 0

    dataset_length_train: int = 0
    dataset_length_val: int = 0
    dataset_length_test: int = 0

    images_per_seq_options: Tuple[int, ...] = ()
    sample_consecutive_frames: bool = False
    consecutive_frames_max_gap: int = 0
    consecutive_frames_max_gap_seconds: float = 0.1

    def __post_init__(self) -> None:
        run_auto_creation(self)

    def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
        """
        Returns a collection of data loaders for a given collection of datasets.

        The train and val loaders are always built by `_make_generic_data_loader`.
        The test loader is built by `_make_eval_data_loader` when the test dataset
        exists and reports eval batches, and by `_make_generic_data_loader`
        otherwise. The training dataset is passed along in every generic call.
        """
        train = self._make_generic_data_loader(
            datasets.train,
            self.dataset_length_train,
            datasets.train,
        )

        val = self._make_generic_data_loader(
            datasets.val,
            self.dataset_length_val,
            datasets.train,
        )

        if datasets.test is not None and datasets.test.get_eval_batches() is not None:
            test = self._make_eval_data_loader(datasets.test)
        else:
            test = self._make_generic_data_loader(
                datasets.test,
                self.dataset_length_test,
                datasets.train,
            )

        return DataLoaderMap(train=train, val=val, test=test)

    def _make_eval_data_loader(
        self,
        dataset: Optional[DatasetBase],
    ) -> Optional[DataLoader[FrameData]]:
        """
        Returns a data loader whose batch sampler is the dataset's own eval
        batches, or None if no dataset is given.
        """
        if dataset is None:
            return None

        return DataLoader(
            dataset,
            batch_sampler=dataset.get_eval_batches(),
            **self._get_data_loader_common_kwargs(dataset),
        )

    def _make_generic_data_loader(
        self,
        dataset: Optional[DatasetBase],
        num_batches: int,
        train_dataset: Optional[DatasetBase],
    ) -> Optional[DataLoader[FrameData]]:
        """
        Returns the dataloader for a dataset, or None if no dataset is given.

        Args:
            dataset: the dataset
            num_batches: possible ceiling on number of batches per epoch; in the
                few-view branch a value <= 0 is replaced by `len(dataset)`
            train_dataset: the training dataset; only forwarded to
                `_train_loader`, i.e. when `images_per_seq_options` is empty
                and `batch_size` is not 1
        """
        if dataset is None:
            return None

        data_loader_kwargs = self._get_data_loader_common_kwargs(dataset)

        if len(self.images_per_seq_options) > 0:
            # this is a typical few-view setup
            # conditioning comes from the same subset since subsets are split by seqs
            batch_sampler = SceneBatchSampler(
                dataset,
                self.batch_size,
                num_batches=len(dataset) if num_batches <= 0 else num_batches,
                images_per_seq_options=self.images_per_seq_options,
                sample_consecutive_frames=self.sample_consecutive_frames,
                consecutive_frames_max_gap=self.consecutive_frames_max_gap,
                consecutive_frames_max_gap_seconds=self.consecutive_frames_max_gap_seconds,
            )
            return DataLoader(
                dataset,
                batch_sampler=batch_sampler,
                **data_loader_kwargs,
            )

        if self.batch_size == 1:
            # this is a typical many-view setup (without conditioning)
            return self._simple_loader(dataset, num_batches, data_loader_kwargs)

        # edge case: conditioning on train subset, typical for Nerformer-like many-view
        # there is only one sequence in all datasets, so we condition on another subset
        return self._train_loader(
            dataset, train_dataset, num_batches, data_loader_kwargs
        )

    def _get_data_loader_common_kwargs(self, dataset: DatasetBase) -> Dict[str, Any]:
        """
        Returns the keyword arguments shared by every DataLoader built here:
        the configured number of workers and the collate function of the
        dataset's frame data type.
        """
        return {
            "num_workers": self.num_workers,
            "collate_fn": dataset.frame_data_type.collate,
        }
|
||||
@@ -204,7 +204,7 @@ def _dataclass_list_from_dict_list(dlist, typeannot):
|
||||
# otherwise, we dispatch by the type of the provided annotation to convert to
|
||||
if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
|
||||
# For namedtuple, call the function recursively on the lists of corresponding keys
|
||||
types = cls._field_types.values()
|
||||
types = cls.__annotations__.values()
|
||||
dlist_T = zip(*dlist)
|
||||
res_T = [
|
||||
_dataclass_list_from_dict_list(key_list, tp)
|
||||
@@ -270,7 +270,7 @@ def _dataclass_from_dict(d, typeannot):
|
||||
|
||||
cls = get_origin(typeannot) or typeannot
|
||||
if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
|
||||
types = cls._field_types.values()
|
||||
types = cls.__annotations__.values()
|
||||
return cls(*[_dataclass_from_dict(v, tp) for v, tp in zip(d, types)])
|
||||
elif issubclass(cls, (list, tuple)):
|
||||
types = get_args(typeannot)
|
||||
|
||||
@@ -38,8 +38,8 @@ class _Visualizer:
|
||||
image_render: torch.Tensor
|
||||
image_rgb_masked: torch.Tensor
|
||||
depth_render: torch.Tensor
|
||||
depth_map: torch.Tensor
|
||||
depth_mask: torch.Tensor
|
||||
depth_map: Optional[torch.Tensor]
|
||||
depth_mask: Optional[torch.Tensor]
|
||||
|
||||
visdom_env: str = "eval_debug"
|
||||
|
||||
@@ -75,9 +75,11 @@ class _Visualizer:
|
||||
viz = self._viz
|
||||
viz.images(
|
||||
torch.cat(
|
||||
(
|
||||
make_depth_image(self.depth_render, loss_mask_now),
|
||||
make_depth_image(self.depth_map, loss_mask_now),
|
||||
(make_depth_image(self.depth_render, loss_mask_now),)
|
||||
+ (
|
||||
(make_depth_image(self.depth_map, loss_mask_now),)
|
||||
if self.depth_map is not None
|
||||
else ()
|
||||
),
|
||||
dim=3,
|
||||
),
|
||||
@@ -91,12 +93,13 @@ class _Visualizer:
|
||||
win="depth_abs" + name_postfix + "_mask",
|
||||
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_mask"},
|
||||
)
|
||||
viz.images(
|
||||
self.depth_mask,
|
||||
env=self.visdom_env,
|
||||
win="depth_abs" + name_postfix + "_maskd",
|
||||
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_maskd"},
|
||||
)
|
||||
if self.depth_mask is not None:
|
||||
viz.images(
|
||||
self.depth_mask,
|
||||
env=self.visdom_env,
|
||||
win="depth_abs" + name_postfix + "_maskd",
|
||||
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_maskd"},
|
||||
)
|
||||
|
||||
# show the 3D plot
|
||||
# pyre-fixme[9]: viewpoint_trivial has type `PerspectiveCameras`; used as
|
||||
@@ -104,29 +107,30 @@ class _Visualizer:
|
||||
viewpoint_trivial: PerspectiveCameras = PerspectiveCameras().to(
|
||||
loss_mask_now.device
|
||||
)
|
||||
pcl_pred = get_rgbd_point_cloud(
|
||||
viewpoint_trivial,
|
||||
self.image_render,
|
||||
self.depth_render,
|
||||
# mask_crop,
|
||||
torch.ones_like(self.depth_render),
|
||||
# loss_mask_now,
|
||||
)
|
||||
pcl_gt = get_rgbd_point_cloud(
|
||||
viewpoint_trivial,
|
||||
self.image_rgb_masked,
|
||||
self.depth_map,
|
||||
# mask_crop,
|
||||
torch.ones_like(self.depth_map),
|
||||
# loss_mask_now,
|
||||
)
|
||||
_pcls = {
|
||||
pn: p
|
||||
for pn, p in zip(("pred_depth", "gt_depth"), (pcl_pred, pcl_gt))
|
||||
if int(p.num_points_per_cloud()) > 0
|
||||
"pred_depth": get_rgbd_point_cloud(
|
||||
viewpoint_trivial,
|
||||
self.image_render,
|
||||
self.depth_render,
|
||||
# mask_crop,
|
||||
torch.ones_like(self.depth_render),
|
||||
# loss_mask_now,
|
||||
)
|
||||
}
|
||||
if self.depth_map is not None:
|
||||
_pcls["gt_depth"] = get_rgbd_point_cloud(
|
||||
viewpoint_trivial,
|
||||
self.image_rgb_masked,
|
||||
self.depth_map,
|
||||
# mask_crop,
|
||||
torch.ones_like(self.depth_map),
|
||||
# loss_mask_now,
|
||||
)
|
||||
|
||||
_pcls = {pn: p for pn, p in _pcls.items() if int(p.num_points_per_cloud()) > 0}
|
||||
|
||||
plotlyplot = plot_scene(
|
||||
{f"pcl{name_postfix}": _pcls},
|
||||
{f"pcl{name_postfix}": _pcls}, # pyre-ignore
|
||||
camera_scale=1.0,
|
||||
pointcloud_max_points=10000,
|
||||
pointcloud_marker_size=1,
|
||||
@@ -277,10 +281,10 @@ def eval_batch(
|
||||
image_render=image_render,
|
||||
image_rgb_masked=image_rgb_masked,
|
||||
depth_render=cloned_render["depth_render"],
|
||||
# pyre-fixme[6]: Expected `Tensor` for 4th param but got
|
||||
# `Optional[torch.Tensor]`.
|
||||
depth_map=frame_data.depth_map,
|
||||
depth_mask=frame_data.depth_mask[:1],
|
||||
depth_mask=frame_data.depth_mask[:1]
|
||||
if frame_data.depth_mask is not None
|
||||
else None,
|
||||
visdom_env=visualize_visdom_env,
|
||||
)
|
||||
|
||||
|
||||
@@ -57,6 +57,7 @@ class ImplicitronEvaluator(EvaluatorBase):
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
|
||||
# pyre-fixme[14]: `run` overrides method defined in `EvaluatorBase` inconsistently.
|
||||
def run(
|
||||
self,
|
||||
model: ImplicitronModelBase,
|
||||
|
||||
@@ -360,7 +360,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
||||
and source images, which will be used for intersecting with target rays.
|
||||
fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch of
|
||||
foreground masks.
|
||||
mask_crop: A binary tensor of shape `(B, 1, H, W)` deonting valid
|
||||
mask_crop: A binary tensor of shape `(B, 1, H, W)` denoting valid
|
||||
regions in the input images (i.e. regions that do not correspond
|
||||
to, e.g., zero-padding). When the `RaySampler`'s sampling mode is set to
|
||||
"mask_sample", rays will be sampled in the non zero regions.
|
||||
|
||||
@@ -41,6 +41,8 @@ class ModelDBIR(ImplicitronModelBase):
|
||||
bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
|
||||
max_points: int = -1
|
||||
|
||||
# pyre-fixme[14]: `forward` overrides method defined in `ImplicitronModelBase`
|
||||
# inconsistently.
|
||||
def forward(
|
||||
self,
|
||||
*, # force keyword-only arguments
|
||||
|
||||
@@ -111,6 +111,8 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
|
||||
"minimum": lambda curr, acc: torch.minimum(curr, acc),
|
||||
}[self.weight_function_type]
|
||||
|
||||
# pyre-fixme[14]: `forward` overrides method defined in `RaymarcherBase`
|
||||
# inconsistently.
|
||||
def forward(
|
||||
self,
|
||||
rays_densities: torch.Tensor,
|
||||
|
||||
@@ -393,7 +393,7 @@ class _GLTFLoader:
|
||||
attributes = primitive["attributes"]
|
||||
vertex_colors = self._get_primitive_attribute(attributes, "COLOR_0", np.float32)
|
||||
if vertex_colors is not None:
|
||||
return TexturesVertex(torch.from_numpy(vertex_colors))
|
||||
return TexturesVertex([torch.from_numpy(vertex_colors)])
|
||||
|
||||
vertex_texcoords_0 = self._get_primitive_attribute(
|
||||
attributes, "TEXCOORD_0", np.float32
|
||||
@@ -559,12 +559,26 @@ class _GLTFWriter:
|
||||
meshes = defaultdict(list)
|
||||
# pyre-fixme[6]: Incompatible parameter type
|
||||
meshes["name"] = "Node-Mesh"
|
||||
primitives = {
|
||||
"attributes": {"POSITION": 0, "TEXCOORD_0": 2},
|
||||
"indices": 1,
|
||||
"material": 0, # default material
|
||||
"mode": _PrimitiveMode.TRIANGLES,
|
||||
}
|
||||
if isinstance(self.mesh.textures, TexturesVertex):
|
||||
primitives = {
|
||||
"attributes": {"POSITION": 0, "COLOR_0": 2},
|
||||
"indices": 1,
|
||||
"mode": _PrimitiveMode.TRIANGLES,
|
||||
}
|
||||
elif isinstance(self.mesh.textures, TexturesUV):
|
||||
primitives = {
|
||||
"attributes": {"POSITION": 0, "TEXCOORD_0": 2},
|
||||
"indices": 1,
|
||||
"mode": _PrimitiveMode.TRIANGLES,
|
||||
"material": 0,
|
||||
}
|
||||
else:
|
||||
primitives = {
|
||||
"attributes": {"POSITION": 0},
|
||||
"indices": 1,
|
||||
"mode": _PrimitiveMode.TRIANGLES,
|
||||
}
|
||||
|
||||
meshes["primitives"].append(primitives)
|
||||
self._json_data["meshes"].append(meshes)
|
||||
|
||||
@@ -610,6 +624,14 @@ class _GLTFWriter:
|
||||
element_min = list(map(float, np.min(data, axis=0)))
|
||||
element_max = list(map(float, np.max(data, axis=0)))
|
||||
byte_per_element = 2 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]]
|
||||
elif key == "texvertices":
|
||||
component_type = _ComponentType.FLOAT
|
||||
data = self.mesh.textures.verts_features_list()[0].cpu().numpy()
|
||||
element_type = "VEC3"
|
||||
buffer_view = 2
|
||||
element_min = list(map(float, np.min(data, axis=0)))
|
||||
element_max = list(map(float, np.max(data, axis=0)))
|
||||
byte_per_element = 3 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]]
|
||||
elif key == "indices":
|
||||
component_type = _ComponentType.UNSIGNED_SHORT
|
||||
data = (
|
||||
@@ -646,8 +668,10 @@ class _GLTFWriter:
|
||||
return (byte_length, data)
|
||||
|
||||
def _write_bufferview(self, key: str, **kwargs):
|
||||
if key not in ["positions", "texcoords", "indices"]:
|
||||
raise ValueError("key must be one of positions, texcoords or indices")
|
||||
if key not in ["positions", "texcoords", "texvertices", "indices"]:
|
||||
raise ValueError(
|
||||
"key must be one of positions, texcoords, texvertices or indices"
|
||||
)
|
||||
|
||||
bufferview = {
|
||||
"name": "bufferView_%s" % key,
|
||||
@@ -661,6 +685,10 @@ class _GLTFWriter:
|
||||
byte_per_element = 2 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]]
|
||||
target = _TargetType.ARRAY_BUFFER
|
||||
bufferview["byteStride"] = int(byte_per_element)
|
||||
elif key == "texvertices":
|
||||
byte_per_element = 3 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]]
|
||||
target = _TargetType.ELEMENT_ARRAY_BUFFER
|
||||
bufferview["byteStride"] = int(byte_per_element)
|
||||
elif key == "indices":
|
||||
byte_per_element = (
|
||||
3 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.UNSIGNED_SHORT]]
|
||||
@@ -701,12 +729,15 @@ class _GLTFWriter:
|
||||
pos_byte, pos_data = self._write_accessor_json("positions")
|
||||
idx_byte, idx_data = self._write_accessor_json("indices")
|
||||
include_textures = False
|
||||
if (
|
||||
self.mesh.textures is not None
|
||||
and self.mesh.textures.verts_uvs_list()[0] is not None
|
||||
):
|
||||
tex_byte, tex_data = self._write_accessor_json("texcoords")
|
||||
include_textures = True
|
||||
if self.mesh.textures is not None:
|
||||
if hasattr(self.mesh.textures, "verts_features_list"):
|
||||
tex_byte, tex_data = self._write_accessor_json("texvertices")
|
||||
include_textures = True
|
||||
texcoords = False
|
||||
elif self.mesh.textures.verts_uvs_list()[0] is not None:
|
||||
tex_byte, tex_data = self._write_accessor_json("texcoords")
|
||||
include_textures = True
|
||||
texcoords = True
|
||||
|
||||
# bufferViews for positions, texture coords and indices
|
||||
byte_offset = 0
|
||||
@@ -717,17 +748,19 @@ class _GLTFWriter:
|
||||
byte_offset += idx_byte
|
||||
|
||||
if include_textures:
|
||||
self._write_bufferview(
|
||||
"texcoords", byte_length=tex_byte, offset=byte_offset
|
||||
)
|
||||
if texcoords:
|
||||
self._write_bufferview(
|
||||
"texcoords", byte_length=tex_byte, offset=byte_offset
|
||||
)
|
||||
else:
|
||||
self._write_bufferview(
|
||||
"texvertices", byte_length=tex_byte, offset=byte_offset
|
||||
)
|
||||
byte_offset += tex_byte
|
||||
|
||||
# image bufferView
|
||||
include_image = False
|
||||
if (
|
||||
self.mesh.textures is not None
|
||||
and self.mesh.textures.maps_list()[0] is not None
|
||||
):
|
||||
if self.mesh.textures is not None and hasattr(self.mesh.textures, "maps_list"):
|
||||
include_image = True
|
||||
image_byte, image_data = self._write_image_buffer(offset=byte_offset)
|
||||
byte_offset += image_byte
|
||||
|
||||
@@ -684,6 +684,8 @@ def save_obj(
|
||||
decimal_places: Optional[int] = None,
|
||||
path_manager: Optional[PathManager] = None,
|
||||
*,
|
||||
normals: Optional[torch.Tensor] = None,
|
||||
faces_normals_idx: Optional[torch.Tensor] = None,
|
||||
verts_uvs: Optional[torch.Tensor] = None,
|
||||
faces_uvs: Optional[torch.Tensor] = None,
|
||||
texture_map: Optional[torch.Tensor] = None,
|
||||
@@ -698,6 +700,10 @@ def save_obj(
|
||||
decimal_places: Number of decimal places for saving.
|
||||
path_manager: Optional PathManager for interpreting f if
|
||||
it is a str.
|
||||
normals: FloatTensor of shape (V, 3) giving normals for faces_normals_idx
|
||||
to index into.
|
||||
faces_normals_idx: LongTensor of shape (F, 3) giving the index into
|
||||
normals for each vertex in the face.
|
||||
verts_uvs: FloatTensor of shape (V, 2) giving the uv coordinate per vertex.
|
||||
faces_uvs: LongTensor of shape (F, 3) giving the index into verts_uvs for
|
||||
each vertex in the face.
|
||||
@@ -713,6 +719,22 @@ def save_obj(
|
||||
message = "'faces' should either be empty or of shape (num_faces, 3)."
|
||||
raise ValueError(message)
|
||||
|
||||
if (normals is None) != (faces_normals_idx is None):
|
||||
message = "'normals' and 'faces_normals_idx' must both be None or neither."
|
||||
raise ValueError(message)
|
||||
|
||||
if faces_normals_idx is not None and (
|
||||
faces_normals_idx.dim() != 2 or faces_normals_idx.size(1) != 3
|
||||
):
|
||||
message = (
|
||||
"'faces_normals_idx' should either be empty or of shape (num_faces, 3)."
|
||||
)
|
||||
raise ValueError(message)
|
||||
|
||||
if normals is not None and (normals.dim() != 2 or normals.size(1) != 3):
|
||||
message = "'normals' should either be empty or of shape (num_verts, 3)."
|
||||
raise ValueError(message)
|
||||
|
||||
if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3):
|
||||
message = "'faces_uvs' should either be empty or of shape (num_faces, 3)."
|
||||
raise ValueError(message)
|
||||
@@ -742,9 +764,12 @@ def save_obj(
|
||||
verts,
|
||||
faces,
|
||||
decimal_places,
|
||||
normals=normals,
|
||||
faces_normals_idx=faces_normals_idx,
|
||||
verts_uvs=verts_uvs,
|
||||
faces_uvs=faces_uvs,
|
||||
save_texture=save_texture,
|
||||
save_normals=normals is not None,
|
||||
)
|
||||
|
||||
# Save the .mtl and .png files associated with the texture
|
||||
@@ -777,9 +802,12 @@ def _save(
|
||||
faces,
|
||||
decimal_places: Optional[int] = None,
|
||||
*,
|
||||
normals: Optional[torch.Tensor] = None,
|
||||
faces_normals_idx: Optional[torch.Tensor] = None,
|
||||
verts_uvs: Optional[torch.Tensor] = None,
|
||||
faces_uvs: Optional[torch.Tensor] = None,
|
||||
save_texture: bool = False,
|
||||
save_normals: bool = False,
|
||||
) -> None:
|
||||
|
||||
if len(verts) and (verts.dim() != 2 or verts.size(1) != 3):
|
||||
@@ -798,18 +826,26 @@ def _save(
|
||||
|
||||
lines = ""
|
||||
|
||||
if len(verts):
|
||||
if decimal_places is None:
|
||||
float_str = "%f"
|
||||
else:
|
||||
float_str = "%" + ".%df" % decimal_places
|
||||
if decimal_places is None:
|
||||
float_str = "%f"
|
||||
else:
|
||||
float_str = "%" + ".%df" % decimal_places
|
||||
|
||||
if len(verts):
|
||||
V, D = verts.shape
|
||||
for i in range(V):
|
||||
vert = [float_str % verts[i, j] for j in range(D)]
|
||||
lines += "v %s\n" % " ".join(vert)
|
||||
|
||||
if save_normals:
|
||||
assert normals is not None
|
||||
assert faces_normals_idx is not None
|
||||
lines += _write_normals(normals, faces_normals_idx, float_str)
|
||||
|
||||
if save_texture:
|
||||
assert faces_uvs is not None
|
||||
assert verts_uvs is not None
|
||||
|
||||
if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3):
|
||||
message = "'faces_uvs' should either be empty or of shape (num_faces, 3)."
|
||||
raise ValueError(message)
|
||||
@@ -818,7 +854,6 @@ def _save(
|
||||
message = "'verts_uvs' should either be empty or of shape (num_verts, 2)."
|
||||
raise ValueError(message)
|
||||
|
||||
# pyre-fixme[16] # undefined attribute cpu
|
||||
verts_uvs, faces_uvs = verts_uvs.cpu(), faces_uvs.cpu()
|
||||
|
||||
# Save verts uvs after verts
|
||||
@@ -828,25 +863,77 @@ def _save(
|
||||
uv = [float_str % verts_uvs[i, j] for j in range(uD)]
|
||||
lines += "vt %s\n" % " ".join(uv)
|
||||
|
||||
f.write(lines)
|
||||
|
||||
if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0):
|
||||
warnings.warn("Faces have invalid indices")
|
||||
|
||||
if len(faces):
|
||||
F, P = faces.shape
|
||||
for i in range(F):
|
||||
if save_texture:
|
||||
# Format faces as {verts_idx}/{verts_uvs_idx}
|
||||
_write_faces(
|
||||
f,
|
||||
faces,
|
||||
faces_uvs if save_texture else None,
|
||||
faces_normals_idx if save_normals else None,
|
||||
)
|
||||
|
||||
|
||||
def _write_normals(
|
||||
normals: torch.Tensor, faces_normals_idx: torch.Tensor, float_str: str
|
||||
) -> str:
|
||||
if faces_normals_idx.dim() != 2 or faces_normals_idx.size(1) != 3:
|
||||
message = (
|
||||
"'faces_normals_idx' should either be empty or of shape (num_faces, 3)."
|
||||
)
|
||||
raise ValueError(message)
|
||||
|
||||
if normals.dim() != 2 or normals.size(1) != 3:
|
||||
message = "'normals' should either be empty or of shape (num_verts, 3)."
|
||||
raise ValueError(message)
|
||||
|
||||
normals, faces_normals_idx = normals.cpu(), faces_normals_idx.cpu()
|
||||
|
||||
lines = []
|
||||
V, D = normals.shape
|
||||
for i in range(V):
|
||||
normal = [float_str % normals[i, j] for j in range(D)]
|
||||
lines.append("vn %s\n" % " ".join(normal))
|
||||
return "".join(lines)
|
||||
|
||||
|
||||
def _write_faces(
|
||||
f,
|
||||
faces: torch.Tensor,
|
||||
faces_uvs: Optional[torch.Tensor],
|
||||
faces_normals_idx: Optional[torch.Tensor],
|
||||
) -> None:
|
||||
F, P = faces.shape
|
||||
for i in range(F):
|
||||
if faces_normals_idx is not None:
|
||||
if faces_uvs is not None:
|
||||
# Format faces as {verts_idx}/{verts_uvs_idx}/{verts_normals_idx}
|
||||
face = [
|
||||
"%d/%d" % (faces[i, j] + 1, faces_uvs[i, j] + 1) for j in range(P)
|
||||
"%d/%d/%d"
|
||||
% (
|
||||
faces[i, j] + 1,
|
||||
faces_uvs[i, j] + 1,
|
||||
faces_normals_idx[i, j] + 1,
|
||||
)
|
||||
for j in range(P)
|
||||
]
|
||||
else:
|
||||
face = ["%d" % (faces[i, j] + 1) for j in range(P)]
|
||||
# Format faces as {verts_idx}//{verts_normals_idx}
|
||||
face = [
|
||||
"%d//%d" % (faces[i, j] + 1, faces_normals_idx[i, j] + 1)
|
||||
for j in range(P)
|
||||
]
|
||||
elif faces_uvs is not None:
|
||||
# Format faces as {verts_idx}/{verts_uvs_idx}
|
||||
face = ["%d/%d" % (faces[i, j] + 1, faces_uvs[i, j] + 1) for j in range(P)]
|
||||
else:
|
||||
face = ["%d" % (faces[i, j] + 1) for j in range(P)]
|
||||
|
||||
if i + 1 < F:
|
||||
lines += "f %s\n" % " ".join(face)
|
||||
|
||||
elif i + 1 == F:
|
||||
# No newline at the end of the file.
|
||||
lines += "f %s" % " ".join(face)
|
||||
|
||||
f.write(lines)
|
||||
if i + 1 < F:
|
||||
f.write("f %s\n" % " ".join(face))
|
||||
else:
|
||||
# No newline at the end of the file.
|
||||
f.write("f %s" % " ".join(face))
|
||||
|
||||
@@ -375,14 +375,14 @@ class CamerasBase(TensorProperties):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_znear(self):
|
||||
return self.znear if hasattr(self, "znear") else None
|
||||
return getattr(self, "znear", None)
|
||||
|
||||
def get_image_size(self):
|
||||
"""
|
||||
Returns the image size, if provided, expected in the form of (height, width)
|
||||
The image size is used for conversion of projected points to screen coordinates.
|
||||
"""
|
||||
return self.image_size if hasattr(self, "image_size") else None
|
||||
return getattr(self, "image_size", None)
|
||||
|
||||
def __getitem__(
|
||||
self, index: Union[int, List[int], torch.BoolTensor, torch.LongTensor]
|
||||
|
||||
@@ -109,17 +109,11 @@ class MultinomialRaysampler(torch.nn.Module):
|
||||
self._stratified_sampling = stratified_sampling
|
||||
|
||||
# get the initial grid of image xy coords
|
||||
_xy_grid = torch.stack(
|
||||
tuple(
|
||||
reversed(
|
||||
meshgrid_ij(
|
||||
torch.linspace(min_y, max_y, image_height, dtype=torch.float32),
|
||||
torch.linspace(min_x, max_x, image_width, dtype=torch.float32),
|
||||
)
|
||||
)
|
||||
),
|
||||
dim=-1,
|
||||
y, x = meshgrid_ij(
|
||||
torch.linspace(min_y, max_y, image_height, dtype=torch.float32),
|
||||
torch.linspace(min_x, max_x, image_width, dtype=torch.float32),
|
||||
)
|
||||
_xy_grid = torch.stack([x, y], dim=-1)
|
||||
|
||||
self.register_buffer("_xy_grid", _xy_grid, persistent=False)
|
||||
|
||||
|
||||
@@ -491,6 +491,8 @@ class TexturesAtlas(TexturesBase):
|
||||
new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"]
|
||||
return new_tex
|
||||
|
||||
# pyre-fixme[14]: `sample_textures` overrides method defined in `TexturesBase`
|
||||
# inconsistently.
|
||||
def sample_textures(self, fragments, **kwargs) -> torch.Tensor:
|
||||
"""
|
||||
This is similar to a nearest neighbor sampling and involves a
|
||||
@@ -927,6 +929,8 @@ class TexturesUV(TexturesBase):
|
||||
new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"]
|
||||
return new_tex
|
||||
|
||||
# pyre-fixme[14]: `sample_textures` overrides method defined in `TexturesBase`
|
||||
# inconsistently.
|
||||
def sample_textures(self, fragments, **kwargs) -> torch.Tensor:
|
||||
"""
|
||||
Interpolate a 2D texture map using uv vertex texture coordinates for each
|
||||
@@ -1450,6 +1454,8 @@ class TexturesVertex(TexturesBase):
|
||||
new_tex._num_verts_per_mesh = new_props["_num_verts_per_mesh"]
|
||||
return new_tex
|
||||
|
||||
# pyre-fixme[14]: `sample_textures` overrides method defined in `TexturesBase`
|
||||
# inconsistently.
|
||||
def sample_textures(self, fragments, faces_packed=None) -> torch.Tensor:
|
||||
"""
|
||||
Determine the color for each rasterized face. Interpolate the colors for
|
||||
|
||||
1
setup.py
1
setup.py
@@ -164,6 +164,7 @@ setup(
|
||||
"tqdm>4.29.0",
|
||||
"matplotlib",
|
||||
"accelerate",
|
||||
"sqlalchemy>=2.0",
|
||||
],
|
||||
},
|
||||
entry_points={
|
||||
|
||||
1
tests/implicitron/data/sql_dataset/set_lists_100.json
Normal file
1
tests/implicitron/data/sql_dataset/set_lists_100.json
Normal file
File diff suppressed because one or more lines are too long
BIN
tests/implicitron/data/sql_dataset/sql_dataset_100.sqlite
Normal file
BIN
tests/implicitron/data/sql_dataset/sql_dataset_100.sqlite
Normal file
Binary file not shown.
246
tests/implicitron/test_co3d_sql.py
Normal file
246
tests/implicitron/test_co3d_sql.py
Normal file
@@ -0,0 +1,246 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch3d.implicitron.dataset.data_loader_map_provider import ( # noqa
|
||||
SequenceDataLoaderMapProvider,
|
||||
SimpleDataLoaderMapProvider,
|
||||
)
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||
from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset # noqa
|
||||
from pytorch3d.implicitron.dataset.sql_dataset_provider import ( # noqa
|
||||
SqlIndexDatasetMapProvider,
|
||||
)
|
||||
from pytorch3d.implicitron.dataset.train_eval_data_loader_provider import (
|
||||
TrainEvalDataLoaderMapProvider,
|
||||
)
|
||||
from pytorch3d.implicitron.tools.config import get_default_args
|
||||
|
||||
# Route DEBUG-level records of the SQL dataset module through a stream
# handler while these tests run.
sh = logging.StreamHandler()
logger = logging.getLogger("pytorch3d.implicitron.dataset.sql_dataset")
logger.setLevel(logging.DEBUG)
logger.addHandler(sh)

# Taken from the environment; defaults to an empty (falsy) string when the
# variable is not set.
_CO3D_SQL_DATASET_ROOT: str = os.environ.get("CO3D_SQL_DATASET_ROOT", "")
|
||||
|
||||
|
||||
@unittest.skipUnless(_CO3D_SQL_DATASET_ROOT, "Run only if CO3D is available")
class TestCo3dSqlDataSource(unittest.TestCase):
    """
    Tests of ImplicitronDataSource configured with SqlIndexDatasetMapProvider.

    The whole class is skipped when _CO3D_SQL_DATASET_ROOT is empty.
    """

    # Data loader map providers that the subset tests iterate over.
    _DATA_LOADER_TYPES = (
        "SimpleDataLoaderMapProvider",
        "SequenceDataLoaderMapProvider",
        "TrainEvalDataLoaderMapProvider",
    )

    def _make_args(self, data_loader_type=None):
        """
        Build default ImplicitronDataSource args that select the SQL provider.

        Args:
            data_loader_type: optional name to store in
                `data_loader_map_provider_class_type`; left at its default
                when None.

        Returns:
            Tuple (args, provider_args), where provider_args is the
            SqlIndexDatasetMapProvider section of args.
        """
        args = get_default_args(ImplicitronDataSource)
        args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
        if data_loader_type is not None:
            args.data_loader_map_provider_class_type = data_loader_type
        provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
        return args, provider_args

    def _check_loaders(self, data_loaders, train, val, test):
        """
        Assert the loader lengths, then inspect the first item of each split.

        Only the first item yielded by each loader is checked; a loader that
        yields nothing triggers no per-frame assertion.
        """
        self.assertEqual(len(data_loaders.train), train)
        self.assertEqual(len(data_loaders.val), val)
        self.assertEqual(len(data_loaders.test), test)
        for split in ["train", "val", "test"]:
            for frame in data_loaders[split]:
                self.assertEqual(frame.frame_type, [split])
                # check loading blobs
                self.assertEqual(frame.image_rgb.shape[-1], 800)
                break

    def test_no_subsets(self):
        args, provider_args = self._make_args("TrainEvalDataLoaderMapProvider")
        provider_args.ignore_subsets = True

        dataset_args = provider_args.dataset_SqlIndexDataset_args
        dataset_args.pick_categories = ["skateboard"]
        dataset_args.limit_sequences_to = 1

        data_source = ImplicitronDataSource(**args)
        self.assertIsInstance(
            data_source.data_loader_map_provider, TrainEvalDataLoaderMapProvider
        )
        _, data_loaders = data_source.get_datasets_and_dataloaders()
        self.assertEqual(len(data_loaders.train), 202)
        for frame in data_loaders.train:
            self.assertIsNone(frame.frame_type)
            self.assertEqual(frame.image_rgb.shape[-1], 800)  # check loading blobs
            break

    def test_subsets(self):
        args, provider_args = self._make_args()
        provider_args.subset_lists_path = (
            "skateboard/set_lists/set_lists_manyview_dev_0.json"
        )
        # this will naturally limit to one sequence (no need to limit by cat/sequence)

        dataset_args = provider_args.dataset_SqlIndexDataset_args
        dataset_args.remove_empty_masks = True

        for sampler_type in self._DATA_LOADER_TYPES:
            args.data_loader_map_provider_class_type = sampler_type
            data_source = ImplicitronDataSource(**args)
            _, data_loaders = data_source.get_datasets_and_dataloaders()
            self._check_loaders(data_loaders, train=102, val=100, test=100)

    def test_sql_subsets(self):
        args, provider_args = self._make_args()
        provider_args.subset_lists_path = "set_lists/set_lists_manyview_dev_0.sqlite"

        dataset_args = provider_args.dataset_SqlIndexDataset_args
        dataset_args.remove_empty_masks = True
        dataset_args.pick_categories = ["skateboard"]

        for sampler_type in self._DATA_LOADER_TYPES:
            args.data_loader_map_provider_class_type = sampler_type
            data_source = ImplicitronDataSource(**args)
            _, data_loaders = data_source.get_datasets_and_dataloaders()
            self._check_loaders(data_loaders, train=102, val=100, test=100)

    @unittest.skip("It takes 75 seconds; skipping by default")
    def test_huge_subsets(self):
        args, provider_args = self._make_args("TrainEvalDataLoaderMapProvider")
        provider_args.subset_lists_path = "set_lists/set_lists_fewview_dev.sqlite"

        dataset_args = provider_args.dataset_SqlIndexDataset_args
        dataset_args.remove_empty_masks = True

        data_source = ImplicitronDataSource(**args)
        _, data_loaders = data_source.get_datasets_and_dataloaders()
        self._check_loaders(data_loaders, train=3158974, val=518417, test=518417)

    def test_broken_subsets(self):
        args, provider_args = self._make_args("TrainEvalDataLoaderMapProvider")
        provider_args.subset_lists_path = "et_non_est"
        provider_args.dataset_SqlIndexDataset_args.pick_categories = ["skateboard"]
        with self.assertRaises(FileNotFoundError) as err:
            ImplicitronDataSource(**args)

        # check the hint text
        self.assertIn("Subset lists path given but not found", str(err.exception))

    def test_eval_batches(self):
        args, provider_args = self._make_args("TrainEvalDataLoaderMapProvider")
        provider_args.subset_lists_path = "set_lists/set_lists_manyview_dev_0.sqlite"
        provider_args.eval_batches_path = (
            "skateboard/eval_batches/eval_batches_manyview_dev_0.json"
        )

        dataset_args = provider_args.dataset_SqlIndexDataset_args
        dataset_args.remove_empty_masks = True
        dataset_args.pick_categories = ["skateboard"]

        data_source = ImplicitronDataSource(**args)
        _, data_loaders = data_source.get_datasets_and_dataloaders()
        self._check_loaders(data_loaders, train=102, val=100, test=50)

    def test_eval_batches_from_subset_list_name(self):
        args, provider_args = self._make_args("TrainEvalDataLoaderMapProvider")
        provider_args.subset_list_name = "manyview_dev_0"
        provider_args.category = "skateboard"

        dataset_args = provider_args.dataset_SqlIndexDataset_args
        dataset_args.remove_empty_masks = True

        data_source = ImplicitronDataSource(**args)
        dataset, data_loaders = data_source.get_datasets_and_dataloaders()
        self.assertListEqual(list(dataset.train.pick_categories), ["skateboard"])
        self._check_loaders(data_loaders, train=102, val=100, test=50)

    def test_frame_access(self):
        args, provider_args = self._make_args("TrainEvalDataLoaderMapProvider")
        provider_args.subset_lists_path = "set_lists/set_lists_manyview_dev_0.sqlite"

        dataset_args = provider_args.dataset_SqlIndexDataset_args
        dataset_args.remove_empty_masks = True
        dataset_args.pick_categories = ["skateboard"]
        frame_builder_args = dataset_args.frame_data_builder_FrameDataBuilder_args
        frame_builder_args.load_point_clouds = True
        frame_builder_args.box_crop = False  # required for .meta

        data_source = ImplicitronDataSource(**args)
        dataset_map, _ = data_source.get_datasets_and_dataloaders()
        dataset = dataset_map["train"]

        # Exercise both an integer index and a (sequence name, frame number) index.
        for idx in [10, ("245_26182_52130", 22)]:
            example_meta = dataset.meta[idx]
            example = dataset[idx]

            # The .meta view is expected to carry no blobs...
            self.assertIsNone(example_meta.image_rgb)
            self.assertIsNone(example_meta.fg_probability)
            self.assertIsNone(example_meta.depth_map)
            self.assertIsNone(example_meta.sequence_point_cloud)
            self.assertIsNotNone(example_meta.camera)

            # ...while the regular item has all of them loaded.
            self.assertIsNotNone(example.image_rgb)
            self.assertIsNotNone(example.fg_probability)
            self.assertIsNotNone(example.depth_map)
            self.assertIsNotNone(example.sequence_point_cloud)
            self.assertIsNotNone(example.camera)

            # Both views must agree on the metadata and on the camera.
            self.assertEqual(example_meta.sequence_name, example.sequence_name)
            self.assertEqual(example_meta.frame_number, example.frame_number)
            self.assertEqual(example_meta.frame_timestamp, example.frame_timestamp)
            self.assertEqual(example_meta.sequence_category, example.sequence_category)
            torch.testing.assert_close(example_meta.camera.R, example.camera.R)
            torch.testing.assert_close(example_meta.camera.T, example.camera.T)
            torch.testing.assert_close(
                example_meta.camera.focal_length, example.camera.focal_length
            )
            torch.testing.assert_close(
                example_meta.camera.principal_point, example.camera.principal_point
            )
|
||||
@@ -4,6 +4,7 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import unittest
|
||||
import unittest.mock
|
||||
@@ -18,6 +19,7 @@ from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
||||
from pytorch3d.implicitron.tools.config import get_default_args
|
||||
from tests.common_testing import get_tests_dir
|
||||
from tests.implicitron.common_resources import get_skateboard_data
|
||||
|
||||
DATA_DIR = get_tests_dir() / "implicitron/data"
|
||||
DEBUG: bool = False
|
||||
@@ -28,6 +30,12 @@ class TestDataSource(unittest.TestCase):
|
||||
self.maxDiff = None
|
||||
torch.manual_seed(42)
|
||||
|
||||
stack = contextlib.ExitStack()
|
||||
self.dataset_root, self.path_manager = stack.enter_context(
|
||||
get_skateboard_data()
|
||||
)
|
||||
self.addCleanup(stack.close)
|
||||
|
||||
def _test_omegaconf_generic_failure(self):
|
||||
# OmegaConf possible bug - this is why we need _GenericWorkaround
|
||||
from dataclasses import dataclass
|
||||
@@ -56,12 +64,14 @@ class TestDataSource(unittest.TestCase):
|
||||
get_default_args(JsonIndexDataset)
|
||||
|
||||
def test_one(self):
|
||||
with unittest.mock.patch.dict(os.environ, {"CO3D_DATASET_ROOT": ""}):
|
||||
cfg = get_default_args(ImplicitronDataSource)
|
||||
yaml = OmegaConf.to_yaml(cfg, sort_keys=False)
|
||||
if DEBUG:
|
||||
(DATA_DIR / "data_source.yaml").write_text(yaml)
|
||||
self.assertEqual(yaml, (DATA_DIR / "data_source.yaml").read_text())
|
||||
cfg = get_default_args(ImplicitronDataSource)
|
||||
# making the test invariant to env variables
|
||||
cfg.dataset_map_provider_JsonIndexDatasetMapProvider_args.dataset_root = ""
|
||||
cfg.dataset_map_provider_JsonIndexDatasetMapProviderV2_args.dataset_root = ""
|
||||
yaml = OmegaConf.to_yaml(cfg, sort_keys=False)
|
||||
if DEBUG:
|
||||
(DATA_DIR / "data_source.yaml").write_text(yaml)
|
||||
self.assertEqual(yaml, (DATA_DIR / "data_source.yaml").read_text())
|
||||
|
||||
def test_default(self):
|
||||
if os.environ.get("INSIDE_RE_WORKER") is not None:
|
||||
@@ -73,7 +83,7 @@ class TestDataSource(unittest.TestCase):
|
||||
dataset_args.test_restrict_sequence_id = 0
|
||||
dataset_args.n_frames_per_sequence = -1
|
||||
|
||||
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
||||
dataset_args.dataset_root = self.dataset_root
|
||||
|
||||
data_source = ImplicitronDataSource(**args)
|
||||
self.assertIsInstance(
|
||||
@@ -96,7 +106,7 @@ class TestDataSource(unittest.TestCase):
|
||||
dataset_args.test_restrict_sequence_id = 0
|
||||
dataset_args.n_frames_per_sequence = -1
|
||||
|
||||
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
||||
dataset_args.dataset_root = self.dataset_root
|
||||
|
||||
data_source = ImplicitronDataSource(**args)
|
||||
self.assertIsInstance(
|
||||
|
||||
@@ -26,6 +26,8 @@ from tests.common_testing import interactive_testing_requested
|
||||
|
||||
from .common_resources import get_skateboard_data
|
||||
|
||||
VISDOM_PORT = int(os.environ.get("VISDOM_PORT", 8097))
|
||||
|
||||
|
||||
class TestDatasetVisualize(unittest.TestCase):
|
||||
def setUp(self):
|
||||
@@ -77,7 +79,7 @@ class TestDatasetVisualize(unittest.TestCase):
|
||||
for k, dataset in self.datasets.items()
|
||||
}
|
||||
)
|
||||
self.visdom = Visdom()
|
||||
self.visdom = Visdom(port=VISDOM_PORT)
|
||||
if not self.visdom.check_connection():
|
||||
print("Visdom server not running! Disabling visdom visualizations.")
|
||||
self.visdom = None
|
||||
|
||||
230
tests/implicitron/test_extending_orm_types.py
Normal file
230
tests/implicitron/test_extending_orm_types.py
Normal file
@@ -0,0 +1,230 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import ClassVar, Optional, Type
|
||||
|
||||
import pandas as pd
|
||||
import pkg_resources
|
||||
import sqlalchemy as sa
|
||||
|
||||
from pytorch3d.implicitron.dataset import types
|
||||
from pytorch3d.implicitron.dataset.frame_data import FrameData, GenericFrameDataBuilder
|
||||
from pytorch3d.implicitron.dataset.orm_types import (
|
||||
SqlFrameAnnotation,
|
||||
SqlSequenceAnnotation,
|
||||
)
|
||||
from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset
|
||||
from pytorch3d.implicitron.dataset.utils import GenericWorkaround
|
||||
from pytorch3d.implicitron.tools.config import registry
|
||||
from sqlalchemy.orm import composite, Mapped, mapped_column, Session
|
||||
|
||||
# Frame-data-builder arguments used by the tests below: empty dataset root,
# every load_* flag and box_crop set to False.
NO_BLOBS_KWARGS = dict(
    dataset_root="",
    load_images=False,
    load_depths=False,
    load_masks=False,
    load_depth_masks=False,
    box_crop=False,
)

# Location of the SQLite fixture, resolved relative to this module.
DATASET_ROOT = pkg_resources.resource_filename(__name__, "data/sql_dataset")
METADATA_FILE = os.path.join(DATASET_ROOT, "sql_dataset_100.sqlite")

# Route DEBUG-level records of the SQL dataset module through a stream
# handler while these tests run.
sh = logging.StreamHandler()
logger = logging.getLogger("pytorch3d.implicitron.dataset.sql_dataset")
logger.setLevel(logging.DEBUG)
logger.addHandler(sh)
|
||||
|
||||
|
||||
@dataclasses.dataclass
class MagneticFieldAnnotation:
    """
    Example custom annotation: a path plus an optional average flux density.
    """

    path: str
    average_flux_density: Optional[float] = dataclasses.field(default=None)
|
||||
|
||||
|
||||
class ExtendedSqlFrameAnnotation(SqlFrameAnnotation):
    """
    SqlFrameAnnotation with two extra attributes, used to test ORM extension.
    """

    # Optional integer attribute; None unless a value is given.
    num_dogs: Mapped[Optional[int]] = mapped_column(default=None)

    # Composite attribute backed by two nullable columns named
    # `_magnetic_field_path` and `_magnetic_field_average_flux_density`.
    magnetic_field: Mapped[MagneticFieldAnnotation] = composite(
        mapped_column("_magnetic_field_path", nullable=True),
        mapped_column("_magnetic_field_average_flux_density", nullable=True),
        default_factory=lambda: None,
    )
|
||||
|
||||
|
||||
class ExtendedSqlIndexDataset(SqlIndexDataset):
    """
    SqlIndexDataset that declares ExtendedSqlFrameAnnotation as its frame
    annotation type.
    """

    frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = (
        ExtendedSqlFrameAnnotation
    )
|
||||
|
||||
|
||||
class CanineFrameData(FrameData):
    """
    FrameData with two additional optional attributes, both None by default.
    """

    num_dogs: Optional[int] = None
    magnetic_field_average_flux_density: Optional[float] = None
|
||||
|
||||
|
||||
@registry.register
class CanineFrameDataBuilder(
    GenericWorkaround, GenericFrameDataBuilder[CanineFrameData]
):
    """
    A concrete class to build an extended FrameData object
    """

    frame_data_type: ClassVar[Type[FrameData]] = CanineFrameData

    def build(
        self,
        frame_annotation: ExtendedSqlFrameAnnotation,
        sequence_annotation: types.SequenceAnnotation,
        load_blobs: bool = True,
    ) -> CanineFrameData:
        """
        Build the frame data via the parent builder, then copy over the
        extended annotation fields.

        Args:
            frame_annotation: extended frame annotation; its `num_dogs` and
                `magnetic_field.average_flux_density` are read.
            sequence_annotation: forwarded unchanged to the parent `build`.
            load_blobs: forwarded unchanged to the parent `build`.

        Returns:
            The object returned by the parent `build`, with `num_dogs` and
            `magnetic_field_average_flux_density` set.
        """
        frame_data = super().build(frame_annotation, sequence_annotation, load_blobs)
        num_dogs = frame_annotation.num_dogs
        # Fall back to 101 only when no value is stored: a plain
        # `num_dogs or 101` would also overwrite a legitimate count of 0.
        frame_data.num_dogs = 101 if num_dogs is None else num_dogs
        frame_data.magnetic_field_average_flux_density = (
            frame_annotation.magnetic_field.average_flux_density
        )
        return frame_data
|
||||
|
||||
|
||||
class CanineSqlIndexDataset(SqlIndexDataset):
    """
    SqlIndexDataset that declares ExtendedSqlFrameAnnotation as its frame
    annotation type and names CanineFrameDataBuilder as its frame data builder.
    """

    frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = (
        ExtendedSqlFrameAnnotation
    )

    frame_data_builder_class_type: str = "CanineFrameDataBuilder"
|
||||
|
||||
|
||||
class TestExtendingOrmTypes(unittest.TestCase):
    """
    Tests datasets whose frame annotation (and frame data) types are extended
    with extra fields, using a temporary database with the extended schema.
    """

    def setUp(self):
        # create a temporary copy of the DB with an extended schema
        engine = sa.create_engine(f"sqlite:///{METADATA_FILE}")
        with Session(engine) as session:
            extended_annots = [
                ExtendedSqlFrameAnnotation(
                    **{
                        k: v
                        for k, v in frame_annot.__dict__.items()
                        if not k.startswith("_")  # remove mapped fields and SA metadata
                    }
                )
                for frame_annot in session.scalars(sa.select(SqlFrameAnnotation))
            ]
            seq_annots = session.scalars(
                sa.select(SqlSequenceAnnotation),
                execution_options={"prebuffer_rows": True},
            )
            session.expunge_all()

        self._temp_db = tempfile.NamedTemporaryFile(delete=False)
        # Registered immediately: unlike tearDown, cleanup functions also run
        # when a later step of setUp raises, so the file is never left behind.
        self.addCleanup(self._remove_temp_db)

        engine_ext = sa.create_engine(f"sqlite:///{self._temp_db.name}")
        ExtendedSqlFrameAnnotation.metadata.create_all(engine_ext, checkfirst=True)
        with Session(engine_ext, expire_on_commit=False) as session_ext:
            session_ext.add_all(extended_annots)
            for instance in seq_annots:
                session_ext.merge(instance)
            session_ext.commit()

        # check the setup is correct
        with engine_ext.connect() as connection_ext:
            df = pd.read_sql_query(
                sa.select(ExtendedSqlFrameAnnotation), connection_ext
            )
            self.assertEqual(len(df), 100)
            self.assertIn("_magnetic_field_average_flux_density", df.columns)

            df_seq = pd.read_sql_query(sa.select(SqlSequenceAnnotation), connection_ext)
            self.assertEqual(len(df_seq), 10)

    def _remove_temp_db(self):
        """Close the temporary database file and delete it from disk."""
        self._temp_db.close()
        os.remove(self._temp_db.name)

    def _check_items_consecutive(self, dataset, check_item=None):
        """
        Walk the dataset by integer index and assert that frames of each
        sequence follow one another with frame numbers increasing by one.

        Args:
            dataset: object supporting `len()` and integer indexing.
            check_item: optional callable invoked on every item before the
                ordering assertions.
        """
        past_sequences = set()
        last_frame_number = -1
        last_sequence = ""
        for i in range(len(dataset)):
            item = dataset[i]
            if check_item is not None:
                check_item(item)

            if item.frame_number == 0:
                # Frame 0 opens a sequence that must not have been seen before.
                self.assertNotIn(item.sequence_name, past_sequences)
                past_sequences.add(item.sequence_name)
                last_sequence = item.sequence_name
            else:
                self.assertEqual(item.sequence_name, last_sequence)
                self.assertEqual(item.frame_number, last_frame_number + 1)

            last_frame_number = item.frame_number

    def test_basic(self, sequence="cat1_seq2", frame_number=4):
        dataset = ExtendedSqlIndexDataset(
            sqlite_metadata_file=self._temp_db.name,
            remove_empty_masks=False,
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 100)

        # check the items are consecutive
        self._check_items_consecutive(dataset)

        # test indexing
        with self.assertRaises(IndexError):
            dataset[len(dataset) + 1]

        # test sequence-frame indexing
        item = dataset[sequence, frame_number]
        self.assertEqual(item.sequence_name, sequence)
        self.assertEqual(item.frame_number, frame_number)

        with self.assertRaises(IndexError):
            dataset[sequence, 13]

    def test_extending_frame_data(self, sequence="cat1_seq2", frame_number=4):
        dataset = CanineSqlIndexDataset(
            sqlite_metadata_file=self._temp_db.name,
            remove_empty_masks=False,
            frame_data_builder_CanineFrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 100)

        def check_extended_fields(item):
            self.assertIsInstance(item, CanineFrameData)
            self.assertEqual(item.num_dogs, 101)
            self.assertIsNone(item.magnetic_field_average_flux_density)

        # check the items are consecutive
        self._check_items_consecutive(dataset, check_item=check_extended_fields)

        # test indexing
        with self.assertRaises(IndexError):
            dataset[len(dataset) + 1]

        # test sequence-frame indexing
        item = dataset[sequence, frame_number]
        self.assertIsInstance(item, CanineFrameData)
        self.assertEqual(item.sequence_name, sequence)
        self.assertEqual(item.frame_number, frame_number)
        self.assertEqual(item.num_dogs, 101)

        with self.assertRaises(IndexError):
            dataset[sequence, 13]
|
||||
@@ -17,6 +17,7 @@ from pytorch3d.implicitron.dataset import types
|
||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData
|
||||
from pytorch3d.implicitron.dataset.frame_data import FrameDataBuilder
|
||||
from pytorch3d.implicitron.dataset.utils import (
|
||||
get_bbox_from_mask,
|
||||
load_16big_png_depth,
|
||||
load_1bit_png_mask,
|
||||
load_depth,
|
||||
@@ -107,11 +108,14 @@ class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
self.frame_data.effective_image_size_hw = self.frame_data.image_size_hw
|
||||
|
||||
(
|
||||
self.frame_data.fg_probability,
|
||||
self.frame_data.mask_path,
|
||||
self.frame_data.bbox_xywh,
|
||||
) = self.frame_data_builder._load_fg_probability(self.frame_annotation)
|
||||
fg_mask_np, mask_path = self.frame_data_builder._load_fg_probability(
|
||||
self.frame_annotation
|
||||
)
|
||||
self.frame_data.mask_path = mask_path
|
||||
self.frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)
|
||||
mask_thr = self.frame_data_builder.box_crop_mask_thr
|
||||
bbox_xywh = get_bbox_from_mask(fg_mask_np, mask_thr)
|
||||
self.frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.long)
|
||||
|
||||
self.assertIsNotNone(self.frame_data.mask_path)
|
||||
self.assertTrue(torch.is_tensor(self.frame_data.fg_probability))
|
||||
|
||||
37
tests/implicitron/test_orm_types.py
Normal file
37
tests/implicitron/test_orm_types.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from pytorch3d.implicitron.dataset.orm_types import TupleTypeFactory
|
||||
|
||||
|
||||
class TestOrmTypes(unittest.TestCase):
    """
    Round-trip tests (process_bind_param followed by process_result_value)
    for the types produced by TupleTypeFactory.
    """

    def test_tuple_serialization_none(self):
        tuple_type = TupleTypeFactory()()
        bound = tuple_type.process_bind_param(None, None)
        self.assertIsNone(bound)
        restored = tuple_type.process_result_value(bound, None)
        self.assertIsNone(restored)

    def test_tuple_serialization_1d(self):
        for input_tuple in [(1, 2, 3), (4.5, 6.7)]:
            element_type = type(input_tuple[0])
            shape = (len(input_tuple),)
            tuple_type = TupleTypeFactory(element_type, shape)()
            bound = tuple_type.process_bind_param(input_tuple, None)
            restored = tuple_type.process_result_value(bound, None)
            # the element type must survive the round trip
            self.assertEqual(type(restored[0]), element_type)
            np.testing.assert_almost_equal(restored, input_tuple, decimal=6)

    def test_tuple_serialization_2d(self):
        input_tuple = ((1.0, 2.0, 3.0), (4.5, 5.5, 6.6))
        element_type = type(input_tuple[0][0])
        tuple_type = TupleTypeFactory(element_type, (2, 3))()
        bound = tuple_type.process_bind_param(input_tuple, None)
        restored = tuple_type.process_result_value(bound, None)
        self.assertEqual(type(restored[0][0]), element_type)
        # we use float32 to serialise
        np.testing.assert_almost_equal(restored, input_tuple, decimal=6)
|
||||
522
tests/implicitron/test_sql_dataset.py
Normal file
522
tests/implicitron/test_sql_dataset.py
Normal file
@@ -0,0 +1,522 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from collections import Counter
|
||||
|
||||
import pkg_resources
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset
|
||||
|
||||
# Frame-data-builder arguments used by the tests below: empty dataset root,
# every load_* flag and box_crop set to False.
NO_BLOBS_KWARGS = dict(
    dataset_root="",
    load_images=False,
    load_depths=False,
    load_masks=False,
    load_depth_masks=False,
    box_crop=False,
)

# Route DEBUG-level records of the SQL dataset module through a stream
# handler while these tests run.
sh = logging.StreamHandler()
logger = logging.getLogger("pytorch3d.implicitron.dataset.sql_dataset")
logger.setLevel(logging.DEBUG)
logger.addHandler(sh)


# Locations of the fixture files, resolved relative to this module.
DATASET_ROOT = pkg_resources.resource_filename(__name__, "data/sql_dataset")
METADATA_FILE = os.path.join(DATASET_ROOT, "sql_dataset_100.sqlite")
SET_LIST_FILE = os.path.join(DATASET_ROOT, "set_lists_100.json")
|
||||
|
||||
|
||||
class TestSqlDataset(unittest.TestCase):
|
||||
def test_basic(self, sequence="cat1_seq2", frame_number=4):
    """
    Load the full fixture without blobs and check its size, the ordering
    of its items, and both integer and (sequence, frame) indexing.
    """
    dataset = SqlIndexDataset(
        sqlite_metadata_file=METADATA_FILE,
        remove_empty_masks=False,
        frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
    )

    self.assertEqual(len(dataset), 100)

    # check the items are consecutive
    seen_sequences = set()
    prev_frame_number = -1
    current_sequence = ""
    for index in range(len(dataset)):
        entry = dataset[index]

        if item_starts_sequence := (entry.frame_number == 0):
            # frame 0 opens a sequence that must not have been seen before
            self.assertNotIn(entry.sequence_name, seen_sequences)
            seen_sequences.add(entry.sequence_name)
            current_sequence = entry.sequence_name
        if not item_starts_sequence:
            self.assertEqual(entry.sequence_name, current_sequence)
            self.assertEqual(entry.frame_number, prev_frame_number + 1)

        prev_frame_number = entry.frame_number

    # test indexing
    with self.assertRaises(IndexError):
        dataset[len(dataset) + 1]

    # test sequence-frame indexing
    picked = dataset[sequence, frame_number]
    self.assertEqual(picked.sequence_name, sequence)
    self.assertEqual(picked.frame_number, frame_number)

    with self.assertRaises(IndexError):
        dataset[sequence, 13]
|
||||
|
||||
def test_filter_empty_masks(self):
    """With remove_empty_masks=True the fixture yields 78 items."""
    dataset_kwargs = {
        "sqlite_metadata_file": METADATA_FILE,
        "remove_empty_masks": True,
        "frame_data_builder_FrameDataBuilder_args": NO_BLOBS_KWARGS,
    }
    dataset = SqlIndexDataset(**dataset_kwargs)

    self.assertEqual(len(dataset), 78)
|
||||
|
||||
def test_pick_frames_sql_clause(self):
    """Compare a mask-mass SQL clause against the remove_empty_masks flag."""
    shared_kwargs = {
        "sqlite_metadata_file": METADATA_FILE,
        "frame_data_builder_FrameDataBuilder_args": NO_BLOBS_KWARGS,
    }

    dataset_no_empty_masks = SqlIndexDataset(
        remove_empty_masks=True, **shared_kwargs
    )
    dataset_by_clause = SqlIndexDataset(
        remove_empty_masks=False,
        pick_frames_sql_clause="_mask_mass IS NULL OR _mask_mass > 0",
        **shared_kwargs,
    )

    # Both ways of filtering must give the same frames in the same order.
    n_items = len(dataset_by_clause)
    self.assertEqual(n_items, len(dataset_no_empty_masks))
    for position in range(n_items):
        expected_item = dataset_no_empty_masks[position]
        actual_item = dataset_by_clause[position]
        self.assertEqual(expected_item.image_path, actual_item.image_path)

    # The flag and a custom criterion can be combined.
    dataset_ts = SqlIndexDataset(
        remove_empty_masks=True,
        pick_frames_sql_clause="frame_timestamp < 0.15",
        **shared_kwargs,
    )
    self.assertEqual(len(dataset_ts), 19)
|
||||
|
||||
def test_limit_categories(self, category="cat0"):
    """Restricting to one category keeps 50 frames, all of that category.

    Args:
        category: The only category passed to ``pick_categories``.
    """
    dataset = SqlIndexDataset(
        sqlite_metadata_file=METADATA_FILE,
        remove_empty_masks=False,
        pick_categories=[category],
        frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
    )
    n_items = len(dataset)
    self.assertEqual(n_items, 50)

    for position in range(n_items):
        entry = dataset[position]
        self.assertEqual(entry.sequence_category, category)
|
||||
|
||||
def test_limit_sequences(self, num_sequences=3):
    """limit_sequences_to keeps that many sequences of 10 frames each.

    Args:
        num_sequences: Value passed as ``limit_sequences_to``.
    """
    dataset = SqlIndexDataset(
        sqlite_metadata_file=METADATA_FILE,
        remove_empty_masks=False,
        limit_sequences_to=num_sequences,
        frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
    )
    n_items = len(dataset)
    self.assertEqual(n_items, 10 * num_sequences)

    unique_seqs = set()
    for position in range(n_items):
        name = dataset[position].sequence_name
        # A non-string sequence name is unwrapped to its first element.
        if not isinstance(name, str):
            name = name[0]
        unique_seqs.add(name)

    self.assertEqual(len(unique_seqs), num_sequences)
|
||||
|
||||
def test_pick_exclude_sequencess(self, sequence="cat1_seq2"):
    """Exercise pick_sequences and exclude_sequences with one sequence.

    Args:
        sequence: Name of the sequence to pick, then to exclude.
    """
    shared_kwargs = {
        "sqlite_metadata_file": METADATA_FILE,
        "remove_empty_masks": False,
        "frame_data_builder_FrameDataBuilder_args": NO_BLOBS_KWARGS,
    }

    # Picking: only the 10 frames of the requested sequence remain.
    picked = SqlIndexDataset(pick_sequences=[sequence], **shared_kwargs)
    self.assertEqual(len(picked), 10)
    picked_names = {picked[pos].sequence_name for pos in range(len(picked))}
    self.assertCountEqual(picked_names, {sequence})

    first_frame = picked[sequence, 0]
    self.assertEqual(first_frame.sequence_name, sequence)
    self.assertEqual(first_frame.frame_number, 0)

    # Excluding: the other 90 frames remain and the sequence is gone.
    remaining = SqlIndexDataset(exclude_sequences=[sequence], **shared_kwargs)
    self.assertEqual(len(remaining), 90)
    remaining_names = {
        remaining[pos].sequence_name for pos in range(len(remaining))
    }
    self.assertNotIn(sequence, remaining_names)

    with self.assertRaises(IndexError):
        remaining[sequence, 0]
|
||||
|
||||
def test_limit_frames(self, num_frames=13):
    """limit_to caps the number of frames, or is a no-op when not binding.

    Args:
        num_frames: Value passed as ``limit_to`` for the binding case.
    """
    shared_kwargs = {
        "sqlite_metadata_file": METADATA_FILE,
        "remove_empty_masks": False,
        "frame_data_builder_FrameDataBuilder_args": NO_BLOBS_KWARGS,
    }

    capped = SqlIndexDataset(limit_to=num_frames, **shared_kwargs)
    self.assertEqual(len(capped), num_frames)
    capped_names = {capped[pos].sequence_name for pos in range(len(capped))}
    self.assertEqual(len(capped_names), 2)

    # A limit above the dataset size leaves all 100 frames in place.
    uncapped = SqlIndexDataset(limit_to=1000, **shared_kwargs)
    self.assertEqual(len(uncapped), 100)
|
||||
|
||||
def test_limit_frames_per_sequence(self, num_frames=2):
    """n_frames_per_sequence caps the number of frames kept per sequence.

    Args:
        num_frames: Value passed as ``n_frames_per_sequence`` for the
            binding case; every expectation below is derived from it.
    """
    dataset = SqlIndexDataset(
        sqlite_metadata_file=METADATA_FILE,
        remove_empty_masks=False,
        n_frames_per_sequence=num_frames,
        frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
    )

    self.assertEqual(len(dataset), num_frames * 10)
    seq_counts = Counter(dataset[i].sequence_name for i in range(len(dataset)))
    self.assertEqual(len(seq_counts), 10)
    # Every sequence keeps exactly num_frames frames. The expectation uses
    # the parameter (not a literal) so it agrees with the length check above.
    self.assertCountEqual(set(seq_counts.values()), {num_frames})

    with self.assertRaises(IndexError):
        dataset[next(iter(seq_counts)), num_frames + 1]

    # test when the limit is not binding
    dataset = SqlIndexDataset(
        sqlite_metadata_file=METADATA_FILE,
        remove_empty_masks=False,
        n_frames_per_sequence=13,
        frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
    )
    self.assertEqual(len(dataset), 100)
|
||||
|
||||
def test_filter_medley(self):
    """Combine all filters; expect 6 frames of cat1_seq1, 4 of cat1_seq2."""
    filter_kwargs = {
        "remove_empty_masks": True,
        "pick_categories": ["cat1"],
        # drops "cat1_seq0", leaving "cat1_seq1" and the following ones
        "exclude_sequences": ["cat1_seq0"],
        # leaves "cat1_seq1" and "cat1_seq2"
        "limit_sequences_to": 2,
        # leaves all of "cat1_seq1" and 4 frames of "cat1_seq2"
        "limit_to": 14,
        # trims "cat1_seq1" down to 6 frames
        "n_frames_per_sequence": 6,
    }
    dataset = SqlIndexDataset(
        sqlite_metadata_file=METADATA_FILE,
        frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        **filter_kwargs,
    )

    frames_per_sequence = Counter(
        dataset[position].sequence_name for position in range(len(dataset))
    )
    self.assertCountEqual(
        frames_per_sequence.keys(), ["cat1_seq1", "cat1_seq2"]
    )
    self.assertEqual(frames_per_sequence["cat1_seq1"], 6)
    self.assertEqual(frames_per_sequence["cat1_seq2"], 4)
|
||||
|
||||
def test_subsets_trivial(self):
    """Selecting both subsets keeps all 100 frames in consecutive order."""
    dataset = SqlIndexDataset(
        sqlite_metadata_file=METADATA_FILE,
        remove_empty_masks=False,
        subset_lists_file=SET_LIST_FILE,
        limit_to=100,  # force sorting
        subsets=["train", "test"],
        frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
    )
    n_items = len(dataset)
    self.assertEqual(n_items, 100)

    # Frames must come grouped by sequence with consecutive frame numbers;
    # frame 0 opens a sequence that has not been seen before.
    seen_sequences = set()
    prev_frame_number = -1
    current_sequence = ""
    for position in range(n_items):
        entry = dataset[position]

        if entry.frame_number != 0:
            self.assertEqual(entry.sequence_name, current_sequence)
            self.assertEqual(entry.frame_number, prev_frame_number + 1)
        else:
            self.assertNotIn(entry.sequence_name, seen_sequences)
            seen_sequences.add(entry.sequence_name)
            current_sequence = entry.sequence_name

        prev_frame_number = entry.frame_number
|
||||
|
||||
def test_subsets_filter_empty_masks(self):
    """Empty-mask removal combined with subset lists leaves 78 frames."""
    # Tested separately because, per the original author's note, this path
    # uses quite different logic based on `df.drop()`.
    build_kwargs = {
        "sqlite_metadata_file": METADATA_FILE,
        "remove_empty_masks": True,
        "subset_lists_file": SET_LIST_FILE,
        "subsets": ["train", "test"],
        "frame_data_builder_FrameDataBuilder_args": NO_BLOBS_KWARGS,
    }
    filtered = SqlIndexDataset(**build_kwargs)

    self.assertEqual(len(filtered), 78)
|
||||
|
||||
def test_subsets_pick_frames_sql_clause(self):
    """Subset-list variant of the SQL-clause vs. remove_empty_masks check."""
    shared_kwargs = {
        "sqlite_metadata_file": METADATA_FILE,
        "subset_lists_file": SET_LIST_FILE,
        "subsets": ["train", "test"],
        "frame_data_builder_FrameDataBuilder_args": NO_BLOBS_KWARGS,
    }

    dataset_no_empty_masks = SqlIndexDataset(
        remove_empty_masks=True, **shared_kwargs
    )
    dataset_by_clause = SqlIndexDataset(
        remove_empty_masks=False,
        pick_frames_sql_clause="_mask_mass IS NULL OR _mask_mass > 0",
        **shared_kwargs,
    )

    # Both ways of filtering must give the same frames in the same order.
    n_items = len(dataset_by_clause)
    self.assertEqual(n_items, len(dataset_no_empty_masks))
    for position in range(n_items):
        expected_item = dataset_no_empty_masks[position]
        actual_item = dataset_by_clause[position]
        self.assertEqual(expected_item.image_path, actual_item.image_path)

    # The flag and a custom criterion can be combined.
    dataset_ts = SqlIndexDataset(
        remove_empty_masks=True,
        pick_frames_sql_clause="frame_timestamp < 0.15",
        **shared_kwargs,
    )

    self.assertEqual(len(dataset_ts), 19)
|
||||
|
||||
def test_single_subset(self):
    """The train subset alone holds 50 frames, two frame numbers apart."""
    dataset = SqlIndexDataset(
        sqlite_metadata_file=METADATA_FILE,
        remove_empty_masks=False,
        subset_lists_file=SET_LIST_FILE,
        subsets=["train"],
        frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
    )
    n_items = len(dataset)
    self.assertEqual(n_items, 50)

    with self.assertRaises(IndexError):
        dataset[51]

    # Within a sequence the frame numbers advance in steps of 2; a frame
    # number below 2 marks the first frame of a not-yet-seen sequence.
    seen_sequences = set()
    prev_frame_number = -1
    last_sequence = ""
    for position in range(n_items):
        entry = dataset[position]

        if entry.frame_number >= 2:
            self.assertEqual(entry.sequence_name, last_sequence)
            self.assertEqual(entry.frame_number, prev_frame_number + 2)
        else:
            self.assertNotIn(entry.sequence_name, seen_sequences)
            seen_sequences.add(entry.sequence_name)
            last_sequence = entry.sequence_name

        prev_frame_number = entry.frame_number

    # For the last sequence, frame 0 is addressable while frame 1 is not.
    first_frame = dataset[last_sequence, 0]
    self.assertEqual(first_frame.sequence_name, last_sequence)

    with self.assertRaises(IndexError):
        dataset[last_sequence, 1]
|
||||
|
||||
def test_subset_with_filters(self):
    """Train subset plus all filters: 3 frames of cat1_seq1, 2 of cat1_seq2."""
    filter_kwargs = {
        "remove_empty_masks": True,
        "subset_lists_file": SET_LIST_FILE,
        "subsets": ["train"],
        "pick_categories": ["cat1"],
        # drops "cat1_seq0", leaving "cat1_seq1" and the following ones
        "exclude_sequences": ["cat1_seq0"],
        # leaves "cat1_seq1" and "cat1_seq2"
        "limit_sequences_to": 2,
        # leaves the full train set of "cat1_seq1" and 2 frames of "cat1_seq2"
        "limit_to": 7,
        # trims "cat1_seq1" down to 3 frames
        "n_frames_per_sequence": 3,
    }
    dataset = SqlIndexDataset(
        sqlite_metadata_file=METADATA_FILE,
        frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        **filter_kwargs,
    )

    # result: 3 frames survive from cat1_seq1 and 2 from cat1_seq2
    frames_per_sequence = Counter(
        dataset[position].sequence_name for position in range(len(dataset))
    )
    self.assertCountEqual(
        frames_per_sequence.keys(), ["cat1_seq1", "cat1_seq2"]
    )
    self.assertEqual(frames_per_sequence["cat1_seq1"], 3)
    self.assertEqual(frames_per_sequence["cat1_seq2"], 2)
|
||||
|
||||
def test_visitor(self):
    """Check sequence_frames_in_order and the legacy timestamp accessor.

    The first pass verifies that visiting every sequence in order yields
    dataset indices 0, 1, 2, ... with non-decreasing timestamps inside each
    sequence. The second pass checks that
    ``get_frame_numbers_and_timestamps`` returns one entry per requested
    index and raises ``ValueError`` for indices spanning two sequences.
    """
    dataset_sorted = SqlIndexDataset(
        sqlite_metadata_file=METADATA_FILE,
        remove_empty_masks=False,
        frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
    )

    # Materialise the names: they are iterated twice below, which would
    # silently skip the second pass if a one-shot iterator were returned.
    sequences = list(dataset_sorted.sequence_names())
    i = 0
    for seq in sequences:
        last_ts = float("-Inf")
        for ts, _, idx in dataset_sorted.sequence_frames_in_order(seq):
            self.assertEqual(i, idx)
            i += 1
            self.assertGreaterEqual(ts, last_ts)
            last_ts = ts

    # test legacy visitor
    old_indices = None
    for seq in sequences:
        # get_loc is expected to give a slice of rows for the sequence here.
        rows = dataset_sorted._index.index.get_loc(seq)
        indices = list(range(rows.start or 0, rows.stop, rows.step or 1))
        fn_ts_list = dataset_sorted.get_frame_numbers_and_timestamps(indices)
        self.assertEqual(len(fn_ts_list), len(indices))

        if old_indices:
            # check raising if we ask for multiple sequences
            with self.assertRaises(ValueError):
                dataset_sorted.get_frame_numbers_and_timestamps(
                    indices + old_indices
                )

        old_indices = indices
|
||||
|
||||
def test_visitor_subsets(self):
    """Visit frames per sequence, with and without a subset filter."""
    dataset = SqlIndexDataset(
        sqlite_metadata_file=METADATA_FILE,
        remove_empty_masks=False,
        limit_to=100,  # force sorting
        subset_lists_file=SET_LIST_FILE,
        subsets=["train", "test"],
        frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
    )

    expected_idx = 0
    for seq_name in dataset.sequence_names():
        # All frames of the sequence: 10 of them, carrying consecutive
        # dataset indices and non-decreasing timestamps.
        all_frames = list(dataset.sequence_frames_in_order(seq_name))
        self.assertEqual(len(all_frames), 10)
        prev_ts = float("-Inf")
        for timestamp, _, frame_idx in all_frames:
            self.assertEqual(expected_idx, frame_idx)
            expected_idx += 1
            self.assertGreaterEqual(timestamp, prev_ts)
            prev_ts = timestamp

        # Train frames only: 5 of them, still ordered by timestamp.
        train_frames = list(dataset.sequence_frames_in_order(seq_name, "train"))
        self.assertEqual(len(train_frames), 5)
        prev_ts = float("-Inf")
        for timestamp, _, _ in train_frames:
            self.assertGreaterEqual(timestamp, prev_ts)
            prev_ts = timestamp
|
||||
|
||||
def test_category_to_sequence_names(self):
    """The category mapping has 2 categories, with 5 sequences in cat1."""
    dataset = SqlIndexDataset(
        sqlite_metadata_file=METADATA_FILE,
        remove_empty_masks=False,
        subset_lists_file=SET_LIST_FILE,
        subsets=["train", "test"],
        frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
    )

    mapping = dataset.category_to_sequence_names()
    self.assertEqual(len(mapping), 2)
    self.assertIn("cat1", mapping)
    self.assertEqual(len(mapping["cat1"]), 5)

    # The override must agree with the implementation next in the MRO.
    parent_view = super(SqlIndexDataset, dataset)
    mapping_base = parent_view.category_to_sequence_names()
    self.assertDictEqual(mapping, mapping_base)
|
||||
|
||||
def test_category_to_sequence_names_filters(self):
    """With cat1_seq0 excluded, cat1 maps to 4 sequences instead of 5."""
    dataset = SqlIndexDataset(
        sqlite_metadata_file=METADATA_FILE,
        remove_empty_masks=True,
        subset_lists_file=SET_LIST_FILE,
        exclude_sequences=["cat1_seq0"],
        subsets=["train", "test"],
        frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
    )

    mapping = dataset.category_to_sequence_names()
    self.assertEqual(len(mapping), 2)
    self.assertIn("cat1", mapping)
    self.assertEqual(len(mapping["cat1"]), 4)  # one sequence was excluded

    # The override must agree with the implementation next in the MRO.
    parent_view = super(SqlIndexDataset, dataset)
    mapping_base = parent_view.category_to_sequence_names()
    self.assertDictEqual(mapping, mapping_base)
|
||||
|
||||
def test_meta_access(self):
    """dataset.meta[idx] must match dataset[idx] on metadata and camera."""
    dataset = SqlIndexDataset(
        sqlite_metadata_file=METADATA_FILE,
        remove_empty_masks=False,
        subset_lists_file=SET_LIST_FILE,
        subsets=["train"],
        frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
    )

    self.assertEqual(len(dataset), 50)

    scalar_fields = (
        "sequence_name",
        "frame_number",
        "frame_timestamp",
        "sequence_category",
    )
    camera_fields = ("R", "T", "focal_length", "principal_point")

    # Cover both an integer index and a (sequence, frame) index.
    for idx in [10, ("cat0_seq2", 2)]:
        example_meta = dataset.meta[idx]
        example = dataset[idx]

        for field in scalar_fields:
            self.assertEqual(getattr(example_meta, field), getattr(example, field))

        for field in camera_fields:
            torch.testing.assert_close(
                getattr(example_meta.camera, field),
                getattr(example.camera, field),
            )
|
||||
|
||||
def test_meta_access_no_blobs(self):
    """dataset.meta returns no blob fields but still provides a camera."""
    builder_args = {
        "dataset_root": ".",
        "box_crop": False,  # required by blob-less accessor
    }
    dataset = SqlIndexDataset(
        sqlite_metadata_file=METADATA_FILE,
        remove_empty_masks=False,
        subset_lists_file=SET_LIST_FILE,
        subsets=["train"],
        frame_data_builder_FrameDataBuilder_args=builder_args,
    )

    blob_fields = (
        "image_rgb",
        "fg_probability",
        "depth_map",
        "sequence_point_cloud",
    )
    for field in blob_fields:
        self.assertIsNone(getattr(dataset.meta[0], field))
    self.assertIsNotNone(dataset.meta[0].camera)
|
||||
@@ -120,6 +120,7 @@ class TestMeshGltfIO(TestCaseMixin, unittest.TestCase):
|
||||
The scene is "already lit", i.e. the textures reflect the lighting
|
||||
already, so we want to render them with full ambient light.
|
||||
"""
|
||||
|
||||
self.skipTest("Data not available")
|
||||
|
||||
glb = DATA_DIR / "apartment_1.glb"
|
||||
@@ -266,3 +267,117 @@ class TestMeshGltfIO(TestCaseMixin, unittest.TestCase):
|
||||
expected = np.array(f)
|
||||
|
||||
self.assertClose(image, expected)
|
||||
|
||||
def test_load_save_load_cow_texturesvertex(self):
    """
    Load the cow as converted to a single mesh in a glb file, give it gray
    per-vertex colors, save it to a glb file, and check that the reloaded
    mesh has the same geometry and renders to the same image.
    """

    glb = DATA_DIR / "cow.glb"
    self.assertTrue(glb.is_file())
    device = torch.device("cuda:0")
    mesh = _load(glb, device=device, include_textures=False)
    self.assertEqual(len(mesh), 1)
    self.assertIsNone(mesh.textures)

    self.assertEqual(mesh.faces_packed().shape, (5856, 3))
    self.assertEqual(mesh.verts_packed().shape, (3225, 3))
    mesh_obj = _load(TUTORIAL_DATA_DIR / "cow_mesh/cow.obj")
    self.assertClose(mesh.get_bounding_boxes().cpu(), mesh_obj.get_bounding_boxes())

    mesh.textures = TexturesVertex(0.5 * torch.ones_like(mesh.verts_padded()))

    image = _render(mesh, "cow_gray")

    with Image.open(DATA_DIR / "glb_cow_gray.png") as f:
        expected = np.array(f)

    self.assertClose(image, expected)

    # save the mesh to a glb file
    glb = DATA_DIR / "cow_write_texturesvertex.glb"
    _write(mesh, glb)

    # reload the mesh glb file saved in TexturesVertex format
    self.assertTrue(glb.is_file())
    mesh_dash = _load(glb, device=device)
    self.assertEqual(len(mesh_dash), 1)

    self.assertEqual(mesh_dash.faces_packed().shape, (5856, 3))
    self.assertEqual(mesh_dash.verts_packed().shape, (3225, 3))
    self.assertEqual(mesh_dash.textures.verts_features_list()[0].shape, (3225, 3))

    # check the re-rendered image with expected; render the RELOADED mesh,
    # otherwise the save/load round trip is not exercised by this check
    image_dash = _render(mesh_dash, "cow_gray_texturesvertex")
    self.assertClose(image_dash, expected)
|
||||
|
||||
def test_save_toy(self):
    """
    Construct a simple mesh and save it to a glb file in TexturesVertex mode.
    """

    example = {}
    example["POSITION"] = torch.tensor(
        [
            [
                [0.0, 0.0, 0.0],
                [-1.0, 0.0, 0.0],
                [-1.0, 0.0, 1.0],
                [0.0, 0.0, 1.0],
                [0.0, 1.0, 0.0],
                [-1.0, 1.0, 0.0],
                [-1.0, 1.0, 1.0],
                [0.0, 1.0, 1.0],
            ]
        ]
    )
    example["indices"] = torch.tensor(
        [
            [
                [1, 4, 2],
                [4, 3, 2],
                [3, 7, 2],
                [7, 6, 2],
                [3, 4, 7],
                [4, 8, 7],
                [8, 5, 7],
                [5, 6, 7],
                [5, 2, 6],
                [5, 1, 2],
                [1, 5, 4],
                [5, 8, 4],
            ]
        ]
    )
    # The face list above is written 1-based; shift it to 0-based indices.
    example["indices"] -= 1
    # Every one of the 8 vertices gets the color (1, 0, 0).
    example["COLOR_0"] = torch.tensor([[[1.0, 0.0, 0.0]] * 8])
    # Unused sketch of material properties, kept for reference:
    # example['prop'] = {'material':
    #                     {'pbrMetallicRoughness':
    #                         {'baseColorFactor':
    #                             torch.tensor([[0.7, 0.7, 1, 0.5]]),
    #                          'metallicFactor': torch.tensor([1]),
    #                          'roughnessFactor': torch.tensor([0.1])},
    #                      'alphaMode': 'BLEND',
    #                      'doubleSided': True}}

    texture = TexturesVertex(example["COLOR_0"])
    mesh = Meshes(
        verts=example["POSITION"], faces=example["indices"], textures=texture
    )

    glb = DATA_DIR / "example_write_texturesvertex.glb"
    # Drop any file left over from an earlier run so that the checks below
    # can only pass if this run actually wrote the output.
    if glb.is_file():
        glb.unlink()
    _write(mesh, glb)

    self.assertTrue(glb.is_file())
    self.assertGreater(glb.stat().st_size, 0)
|
||||
|
||||
@@ -532,8 +532,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
"f 4 2 1",
|
||||
]
|
||||
)
|
||||
actual_file = open(Path(f.name), "r")
|
||||
self.assertEqual(actual_file.read(), expected_file)
|
||||
self.assertEqual(Path(f.name).read_text(), expected_file)
|
||||
|
||||
def test_load_mtl(self):
|
||||
obj_filename = "cow_mesh/cow.obj"
|
||||
@@ -895,6 +894,67 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
with self.assertRaisesRegex(ValueError, "same type of texture"):
|
||||
join_meshes_as_batch([mesh_atlas, mesh_rgb, mesh_atlas])
|
||||
|
||||
def test_save_obj_with_normal(self):
    """save_obj with normals writes v, vn and `f v//vn` lines at 2 decimals."""
    # Expected file content: 4 vertices, 8 normals, and 4 faces whose
    # corners use the `vertex//normal` form with 1-based indices.
    expected_obj_file = "\n".join(
        [
            "v 0.01 0.20 0.30",
            "v 0.20 0.03 0.41",
            "v 0.30 0.40 0.05",
            "v 0.60 0.70 0.80",
            "vn 0.02 0.50 0.73",
            "vn 0.30 0.03 0.36",
            "vn 0.32 0.12 0.47",
            "vn 0.36 0.17 0.90",
            "vn 0.40 0.70 0.19",
            "vn 1.00 0.00 0.00",
            "vn 0.00 1.00 0.00",
            "vn 0.00 0.00 1.00",
            "f 1//1 3//2 2//3",
            "f 1//3 2//4 3//5",
            "f 4//5 3//6 2//7",
            "f 4//7 2//8 1//1",
        ]
    )

    mesh_inputs = {
        "verts": torch.tensor(
            [
                [0.01, 0.2, 0.301],
                [0.2, 0.03, 0.408],
                [0.3, 0.4, 0.05],
                [0.6, 0.7, 0.8],
            ],
            dtype=torch.float32,
        ),
        "faces": torch.tensor(
            [[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
        ),
        "normals": torch.tensor(
            [
                [0.02, 0.5, 0.73],
                [0.3, 0.03, 0.361],
                [0.32, 0.12, 0.47],
                [0.36, 0.17, 0.9],
                [0.40, 0.7, 0.19],
                [1.0, 0.00, 0.000],
                [0.00, 1.00, 0.00],
                [0.00, 0.00, 1.0],
            ],
            dtype=torch.float32,
        ),
        "faces_normals_idx": torch.tensor(
            [[0, 1, 2], [2, 3, 4], [4, 5, 6], [6, 7, 0]], dtype=torch.int64
        ),
    }

    with TemporaryDirectory() as temp_dir:
        obj_file = os.path.join(temp_dir, "mesh.obj")
        save_obj(
            obj_file,
            mesh_inputs["verts"],
            mesh_inputs["faces"],
            decimal_places=2,
            normals=mesh_inputs["normals"],
            faces_normals_idx=mesh_inputs["faces_normals_idx"],
        )

        # Check the obj file is saved correctly
        with open(obj_file, "r") as saved:
            written_text = saved.read()
        self.assertEqual(written_text, expected_obj_file)
|
||||
|
||||
def test_save_obj_with_texture(self):
|
||||
verts = torch.tensor(
|
||||
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
|
||||
@@ -950,13 +1010,96 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
|
||||
|
||||
# Check the obj file is saved correctly
|
||||
actual_file = open(obj_file, "r")
|
||||
self.assertEqual(actual_file.read(), expected_obj_file)
|
||||
with open(obj_file, "r") as actual_file:
|
||||
self.assertEqual(actual_file.read(), expected_obj_file)
|
||||
|
||||
# Check the mtl file is saved correctly
|
||||
mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
|
||||
mtl_file = open(mtl_file_name, "r")
|
||||
self.assertEqual(mtl_file.read(), expected_mtl_file)
|
||||
with open(mtl_file_name, "r") as mtl_file:
|
||||
self.assertEqual(mtl_file.read(), expected_mtl_file)
|
||||
|
||||
# Check the texture image file is saved correctly
|
||||
texture_image = load_rgb_image("mesh.png", temp_dir)
|
||||
self.assertClose(texture_image, texture_map)
|
||||
|
||||
def test_save_obj_with_normal_and_texture(self):
|
||||
verts = torch.tensor(
|
||||
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
faces = torch.tensor(
|
||||
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
|
||||
)
|
||||
normals = torch.tensor(
|
||||
[
|
||||
[0.02, 0.5, 0.73],
|
||||
[0.3, 0.03, 0.361],
|
||||
[0.32, 0.12, 0.47],
|
||||
[0.36, 0.17, 0.9],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
faces_normals_idx = faces
|
||||
verts_uvs = torch.tensor(
|
||||
[[0.02, 0.5], [0.3, 0.03], [0.32, 0.12], [0.36, 0.17]],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
faces_uvs = faces
|
||||
texture_map = torch.randint(size=(2, 2, 3), high=255) / 255.0
|
||||
|
||||
with TemporaryDirectory() as temp_dir:
|
||||
obj_file = os.path.join(temp_dir, "mesh.obj")
|
||||
save_obj(
|
||||
obj_file,
|
||||
verts,
|
||||
faces,
|
||||
decimal_places=2,
|
||||
normals=normals,
|
||||
faces_normals_idx=faces_normals_idx,
|
||||
verts_uvs=verts_uvs,
|
||||
faces_uvs=faces_uvs,
|
||||
texture_map=texture_map,
|
||||
)
|
||||
|
||||
expected_obj_file = "\n".join(
|
||||
[
|
||||
"",
|
||||
"mtllib mesh.mtl",
|
||||
"usemtl mesh",
|
||||
"",
|
||||
"v 0.01 0.20 0.30",
|
||||
"v 0.20 0.03 0.41",
|
||||
"v 0.30 0.40 0.05",
|
||||
"v 0.60 0.70 0.80",
|
||||
"vn 0.02 0.50 0.73",
|
||||
"vn 0.30 0.03 0.36",
|
||||
"vn 0.32 0.12 0.47",
|
||||
"vn 0.36 0.17 0.90",
|
||||
"vt 0.02 0.50",
|
||||
"vt 0.30 0.03",
|
||||
"vt 0.32 0.12",
|
||||
"vt 0.36 0.17",
|
||||
"f 1/1/1 3/3/3 2/2/2",
|
||||
"f 1/1/1 2/2/2 3/3/3",
|
||||
"f 4/4/4 3/3/3 2/2/2",
|
||||
"f 4/4/4 2/2/2 1/1/1",
|
||||
]
|
||||
)
|
||||
expected_mtl_file = "\n".join(["newmtl mesh", "map_Kd mesh.png", ""])
|
||||
|
||||
# Check there are only 3 files in the temp dir
|
||||
tempfiles = ["mesh.obj", "mesh.png", "mesh.mtl"]
|
||||
tempfiles_dir = os.listdir(temp_dir)
|
||||
self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
|
||||
|
||||
# Check the obj file is saved correctly
|
||||
with open(obj_file, "r") as actual_file:
|
||||
self.assertEqual(actual_file.read(), expected_obj_file)
|
||||
|
||||
# Check the mtl file is saved correctly
|
||||
mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
|
||||
with open(mtl_file_name, "r") as mtl_file:
|
||||
self.assertEqual(mtl_file.read(), expected_mtl_file)
|
||||
|
||||
# Check the texture image file is saved correctly
|
||||
texture_image = load_rgb_image("mesh.png", temp_dir)
|
||||
@@ -1013,8 +1156,8 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(tempfiles, tempfiles_dir)
|
||||
|
||||
# Check the obj file is saved correctly
|
||||
actual_file = open(obj_file, "r")
|
||||
self.assertEqual(actual_file.read(), expected_obj_file)
|
||||
with open(obj_file, "r") as actual_file:
|
||||
self.assertEqual(actual_file.read(), expected_obj_file)
|
||||
|
||||
obj_file = StringIO()
|
||||
with self.assertRaises(ValueError):
|
||||
@@ -1100,13 +1243,13 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
|
||||
|
||||
# Check the obj file is saved correctly
|
||||
actual_file = open(obj_file, "r")
|
||||
self.assertEqual(actual_file.read(), expected_obj_file)
|
||||
with open(obj_file, "r") as actual_file:
|
||||
self.assertEqual(actual_file.read(), expected_obj_file)
|
||||
|
||||
# Check the mtl file is saved correctly
|
||||
mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
|
||||
mtl_file = open(mtl_file_name, "r")
|
||||
self.assertEqual(mtl_file.read(), expected_mtl_file)
|
||||
with open(mtl_file_name, "r") as mtl_file:
|
||||
self.assertEqual(mtl_file.read(), expected_mtl_file)
|
||||
|
||||
# Check the texture image file is saved correctly
|
||||
texture_image = load_rgb_image("mesh.png", temp_dir)
|
||||
|
||||
Reference in New Issue
Block a user