18 Commits

Author SHA1 Message Date
Jeremy Reizenstein
297020a4b1 version 0.7.4
Summary: version number

Reviewed By: shapovalov

Differential Revision: D45704549

fbshipit-source-id: d63867f305b07c30ed9ea104f1494d23710fdbb7
2023-05-10 04:42:58 -07:00
Jeremy Reizenstein
062e6c54ae builds for PyTorch 2.0.1; drop 1.9
Summary: Drop support for PyTorch 1.9.0 and 1.9.1.

Reviewed By: shapovalov

Differential Revision: D45704329

fbshipit-source-id: c0fe3ecf6a1eb9bcd4163785c0cb4bf4f5060f50
2023-05-10 02:38:47 -07:00
Roman Shapovalov
c80180c96e Fix: FrameDataBuilder working with PathManager
Summary: In refactoring, we lost path manager here, which broke manifold storage. Fixing this.

Reviewed By: bottler

Differential Revision: D45574940

fbshipit-source-id: 579349eaa654215a09e057be57b56b46769c986a
2023-05-09 04:56:39 -07:00
Jason Fried
23cd19fbc7 typing.NamedTuple.field_types removed in favor of __annotations__
Summary:
typing.NamedTuple was simplified in 3.10
These two fields were the same in 3.8, so this should be a no-op

#buildmore

Reviewed By: bottler

Differential Revision: D45373526

fbshipit-source-id: 2b26156f5f65b7be335133e9e705730f7254260d
2023-05-08 13:53:16 -07:00
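
As a quick illustration of the change described above, the field types of a `typing.NamedTuple` subclass can be read from `__annotations__`. The `Point` class below is hypothetical and is not taken from the PyTorch3D sources.

```python
from typing import NamedTuple


class Point(NamedTuple):
    """Hypothetical example type, used only for illustration."""

    x: int
    y: float


# __annotations__ maps each field name to its declared type.
field_types = dict(Point.__annotations__)
field_names = Point._fields
```
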
dhb
092400f1e7 allow saving vertex normal in save_obj (#1511)
Summary:
Although we can load per-vertex normals in `load_obj`, saving per-vertex normals is not supported in `save_obj`.

This patch fixes this by allowing passing per-vertex normal data in `save_obj`:
``` python
def save_obj(
    f: PathOrStr,
    verts,
    faces,
    decimal_places: Optional[int] = None,
    path_manager: Optional[PathManager] = None,
    *,
    verts_normals: Optional[torch.Tensor] = None,
    faces_normals: Optional[torch.Tensor] = None,
    verts_uvs: Optional[torch.Tensor] = None,
    faces_uvs: Optional[torch.Tensor] = None,
    texture_map: Optional[torch.Tensor] = None,
) -> None:
    """
    Save a mesh to an .obj file.

    Args:
        f: File (str or path) to which the mesh should be written.
        verts: FloatTensor of shape (V, 3) giving vertex coordinates.
        faces: LongTensor of shape (F, 3) giving faces.
        decimal_places: Number of decimal places for saving.
        path_manager: Optional PathManager for interpreting f if
            it is a str.
        verts_normals: FloatTensor of shape (V, 3) giving the normal per vertex.
        faces_normals: LongTensor of shape (F, 3) giving the index into verts_normals
            for each vertex in the face.
        verts_uvs: FloatTensor of shape (V, 2) giving the uv coordinate per vertex.
        faces_uvs: LongTensor of shape (F, 3) giving the index into verts_uvs for
            each vertex in the face.
        texture_map: FloatTensor of shape (H, W, 3) representing the texture map
            for the mesh which will be saved as an image. The values are expected
            to be in the range [0, 1].
    """
```

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1511

Reviewed By: shapovalov

Differential Revision: D45086045

Pulled By: bottler

fbshipit-source-id: 666efb0d2c302df6cf9f2f6601d83a07856bf32f
2023-05-07 06:32:02 -07:00
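
To make the file-level effect concrete, here is a minimal pure-Python sketch of how per-vertex normals appear in the OBJ text format: `vn` records hold the normals, and each face corner uses the `v//vn` form with 1-based indices. The helper `format_obj_with_normals` is hypothetical and is not PyTorch3D's implementation; it assumes plain Python sequences rather than tensors.

```python
from typing import List, Sequence


def format_obj_with_normals(
    verts: Sequence[Sequence[float]],
    faces: Sequence[Sequence[int]],
    verts_normals: Sequence[Sequence[float]],
    faces_normals: Sequence[Sequence[int]],
    decimal_places: int = 4,
) -> str:
    """Hypothetical helper: render vertices, normals and faces as OBJ text."""
    fmt = "%." + str(decimal_places) + "f"
    lines: List[str] = []
    for v in verts:
        lines.append("v " + " ".join(fmt % c for c in v))
    for n in verts_normals:
        lines.append("vn " + " ".join(fmt % c for c in n))
    for face, face_n in zip(faces, faces_normals):
        # OBJ indices are 1-based; "v//vn" leaves the texture index empty.
        corners = ["%d//%d" % (vi + 1, ni + 1) for vi, ni in zip(face, face_n)]
        lines.append("f " + " ".join(corners))
    return "\n".join(lines) + "\n"


obj_text = format_obj_with_normals(
    verts=[(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0)],
    faces=[(0, 1, 2)],
    verts_normals=[(0.0, 0.0, 1.0), (0.0, 0.0, 1.0), (0.0, 0.0, 1.0)],
    faces_normals=[(0, 1, 2)],
    decimal_places=1,
)
```

The 0-based indices in `faces` and `faces_normals` follow the tensor convention in the docstring above; the shift to 1-based happens only when the text is written.
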
generatedunixname89002005287564
ec87284c4b Replace third-party mock with unittest.mock] vision/fair
Reviewed By: bottler

Differential Revision: D45600232

fbshipit-source-id: f41b95c6fca86d241666b54755a128cd33f6dd32
2023-05-05 09:36:30 -07:00
Xiao Xuan
f5a117c74b fix: correct typo in cameras.md (#1501)
Summary:
If my understanding is right, prp_screen[1] should be 32 rather than 48.

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1501

Reviewed By: shapovalov

Differential Revision: D45044406

Pulled By: bottler

fbshipit-source-id: 7dd93312db4986f4701e642ba82d94333466b921
2023-05-05 08:13:39 -07:00
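
The corrected value can be checked against the formula given in the comment in `cameras.md`. The sketch below assumes `fcl_ndc = 1.2` and `prp_ndc = (0.2, 0.5)`; these NDC values are not visible in this diff, so treat them as assumptions.

```python
# Assumed NDC parameters (not visible in this diff).
fcl_ndc = 1.2
px_ndc, py_ndc = 0.2, 0.5

image_h, image_w = 128, 256
half_min_side = min(image_h, image_w) / 2

# Formulas from the comments in cameras.md.
fcl_screen = fcl_ndc * half_min_side
prp_screen = (
    image_w / 2 - px_ndc * half_min_side,
    image_h / 2 - py_ndc * half_min_side,
)
```
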
Jeremy Reizenstein
b921efae3e CUB usage fix for sample_farthest_points
Summary: Fix for https://github.com/facebookresearch/pytorch3d/issues/1529

Reviewed By: shapovalov

Differential Revision: D45569211

fbshipit-source-id: 8c485f26cd409cafac53d4d982a03cde81a1d853
2023-05-05 05:59:14 -07:00
Roman Shapovalov
c8d6cd427e Fix test_data_source in OSS.
Summary: Import generic path; avoiding incorrect path patching.

Reviewed By: bottler

Differential Revision: D45573976

fbshipit-source-id: e6ff4d759deb936e3b636defa1e0851fb0127b46
2023-05-05 02:05:50 -07:00
Jeremy Reizenstein
ef5f620263 nondeterminism warnings
Summary: Mark CUDA functions that use atomicAdd as nondeterministic via alertNotDeterministic, as xformers does.

Reviewed By: shapovalov

Differential Revision: D44541873

fbshipit-source-id: 2c23160591cd9026fcd4972998d1bc90adba1356
2023-05-04 12:50:41 -07:00
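
The reason accumulation with `atomicAdd` is flagged is that the order in which terms are added can vary between runs, and floating-point addition is not associative. The pure-Python sketch below illustrates only that underlying effect; it does not call PyTorch or CUDA.

```python
# The same three terms, summed in two different orders.
left_to_right = (0.1 + 0.2) + 0.3
right_to_left = 0.1 + (0.2 + 0.3)

# The results differ in the last bits, which is why an accumulation
# whose order varies between runs is not bitwise reproducible.
orders_agree = left_to_right == right_to_left
difference = abs(left_to_right - right_to_left)
```
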
Roman Shapovalov
3e3644e534 More tests for SQL Dataset
Summary:
I forgot to include these tests in D45086611 when transferring the code from the pixar_replay repo.

They test the new ORM types used in the SQL dataset and are specific to SQLAlchemy 2.0.

An important test for extending types is a proof of concept for the generality of the SQL Dataset. The idea is to extend FrameAnnotation and FrameData in parallel.

Reviewed By: bottler

Differential Revision: D45529284

fbshipit-source-id: 2a634e518f580c312602107c85fc320db43abcf5
2023-05-04 03:32:27 -07:00
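
The ORM types in this changeset store fixed-size tuples as binary blobs. The sketch below shows the underlying round trip with the standard `struct` module for a 4-element float32 tuple such as a bounding box. The helper names `pack_floats` and `unpack_floats` are hypothetical, and the SQLAlchemy `TypeDecorator` wrapper is omitted.

```python
import struct
from typing import Tuple


def pack_floats(values: Tuple[float, ...]) -> bytes:
    """Hypothetical helper: serialize floats as native float32 values."""
    return struct.pack("f" * len(values), *values)


def unpack_floats(blob: bytes, count: int) -> Tuple[float, ...]:
    """Hypothetical helper: inverse of pack_floats for a known length."""
    return struct.unpack("f" * count, blob)


bbox_xywh = (10.0, 20.0, 64.0, 32.0)
blob = pack_floats(bbox_xywh)
restored = unpack_floats(blob, len(bbox_xywh))
```

The example values are exactly representable in float32, so the round trip is lossless here; arbitrary Python floats would be rounded to float32 precision.
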
Ilia Vitsnudel
178a7774d4 Adding save mesh into glb file in TexturesVertex format
Summary:
Added a suite of functions and code additions to the experimental_gltf_io.py file to enable saving Meshes in TexturesVertex format into a .glb file.
Also added a test to test_io_gltf.py to check the functionality with the test described in the Test Plan.

Reviewed By: bottler

Differential Revision: D44969144

fbshipit-source-id: 9ce815a1584b510442fa36cc4dbc8d41cc3786d5
2023-05-01 00:41:47 -07:00
Emilien Garreau
823ab75d27 Simplify _xy_grid computation in raysampling
Summary: Remove the need for tuple and reversed in the raysampling xy_grid computation.

Reviewed By: bottler

Differential Revision: D45269342

fbshipit-source-id: d0e4c0923b9a2cca674b35e8d64862043a0eab3b
2023-04-27 03:07:37 -07:00
Roman Shapovalov
32e1992924 SQL Index Dataset
Summary:
Moving SQL dataset to PyTorch3D. It has been extensively tested in pixar_replay.

It requires SQLAlchemy 2.0, which is not supported in fbcode. So I exclude the sources and tests that depend on it from buck TARGETS.

Reviewed By: bottler

Differential Revision: D45086611

fbshipit-source-id: 0285f03e5824c0478c70ad13731525bb5ec7deef
2023-04-25 09:56:15 -07:00
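
Since the dataset depends on SQLAlchemy 2.0 or later, callers may want to check the installed version up front. One caveat is that comparing version strings directly is lexicographic, so `"10.0" < "2.0"` is true. The sketch below compares numeric components instead. The helper `version_at_least` is hypothetical and assumes version strings that start with a numeric `major.minor`.

```python
from typing import Tuple


def version_at_least(version: str, minimum: Tuple[int, int]) -> bool:
    """Hypothetical helper: compare the leading major.minor numerically."""
    major, minor = version.split(".")[:2]
    return (int(major), int(minor)) >= minimum


# Plain string comparison goes character by character, so "1" < "2" decides it.
lexicographic_pitfall = "10.0" < "2.0"
```
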
Roman Shapovalov
7aeedd17a4 When bounding boxes are cached in metadata, don’t crash on load_masks=False
Summary:
We currently support caching bounding boxes in MaskAnnotation. If present, they are not re-computed from the mask. However, the masks need to be loaded for the bbox to be set.

This diff fixes that. Even if load_masks / load_blobs are unset, the bounding box can be picked up from the metadata.

Reviewed By: bottler

Differential Revision: D45144918

fbshipit-source-id: 8a2e2c115e96070b6fcdc29cbe57e1cee606ddcd
2023-04-20 07:28:45 -07:00
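
The fallback order described above can be sketched as follows: use the cached bounding box when the metadata provides one, and compute it from the mask only if the mask was actually loaded. Both helpers are hypothetical and work on nested Python lists; they are not PyTorch3D's `get_bbox_from_mask`, and the default threshold is an arbitrary assumption.

```python
from typing import List, Optional, Tuple

BBox = Tuple[int, int, int, int]  # (x, y, width, height)


def bbox_from_mask(mask: List[List[float]], threshold: float) -> Optional[BBox]:
    """Hypothetical helper: tight box around pixels above the threshold."""
    rows = [y for y, row in enumerate(mask) if any(p > threshold for p in row)]
    cols = [x for row in mask for x, p in enumerate(row) if p > threshold]
    if not rows:
        return None
    x0, y0 = min(cols), min(rows)
    return (x0, y0, max(cols) - x0 + 1, max(rows) - y0 + 1)


def resolve_bbox(
    cached: Optional[BBox],
    mask: Optional[List[List[float]]],
    threshold: float = 0.4,
) -> Optional[BBox]:
    """Prefer the cached box; fall back to the mask only if it was loaded."""
    if cached is not None:
        return cached
    if mask is not None:
        return bbox_from_mask(mask, threshold)
    return None


example_mask = [
    [0.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 1.0, 0.0],
    [0.0, 1.0, 1.0, 0.0],
]
from_cache = resolve_bbox((5, 6, 7, 8), None)  # mask not loaded
from_mask = resolve_bbox(None, example_mask)  # no cached box
```
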
Roman Shapovalov
0e3138eca8 Optional ground-truth depth maps in visualiser
Summary: The code does not crash if depth map/mask are not given.

Reviewed By: bottler

Differential Revision: D45082985

fbshipit-source-id: 3610d8beb4ac897fbbe52f56a6dd012a6365b89b
2023-04-18 07:00:17 -07:00
Richard Barnes
1af6bf4768 Replace hasattr with getattr in vision/fair/pytorch3d/pytorch3d/renderer/cameras.py
Summary:
The pattern
```
X.Y if hasattr(X, "Y") else Z
```
can be replaced with
```
getattr(X, "Y", Z)
```

The [getattr](https://www.w3schools.com/python/ref_func_getattr.asp) function gives more succinct code than the [hasattr](https://www.w3schools.com/python/ref_func_hasattr.asp) function. Please use it when appropriate.

**This diff is very low risk. Green tests indicate that you can safely Accept & Ship.**

Reviewed By: bottler

Differential Revision: D44886893

fbshipit-source-id: 86ba23e837217e1ebd64bf8e27d286257894839e
2023-04-14 04:24:54 -07:00
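
A small self-contained check of the equivalence claimed above. The `Camera` class and the two accessor functions are hypothetical and exist only for this illustration.

```python
class Camera:
    """Hypothetical class, used only for illustration."""

    def __init__(self, with_size: bool) -> None:
        if with_size:
            self.image_size = (128, 256)


def size_with_hasattr(obj):
    # Original pattern: test for the attribute, then read it.
    return obj.image_size if hasattr(obj, "image_size") else None


def size_with_getattr(obj):
    # Replacement pattern: a single lookup with a default.
    return getattr(obj, "image_size", None)


with_attr = Camera(with_size=True)
without_attr = Camera(with_size=False)
```

One difference to keep in mind: the default passed to `getattr` is always evaluated, whereas the `else` branch of the conditional expression is evaluated only when the attribute is missing. The two forms are interchangeable when the default is a cheap, side-effect-free value.
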
generatedunixname89002005307016
355d6332cb upgrade pyre version in fbcode/vision - batch 2
Differential Revision: D44881859

fbshipit-source-id: 4ed410724a14d580f811c1288f51a71ce8fb0c9a
2023-04-11 17:15:12 -07:00
43 changed files with 3191 additions and 197 deletions

View File

@@ -180,30 +180,6 @@ workflows:
jobs:
# - main:
# context: DOCKERHUB_TOKEN
- binary_linux_conda:
context: DOCKERHUB_TOKEN
cu_version: cu102
name: linux_conda_py38_cu102_pyt190
python_version: '3.8'
pytorch_version: 1.9.0
- binary_linux_conda:
context: DOCKERHUB_TOKEN
cu_version: cu111
name: linux_conda_py38_cu111_pyt190
python_version: '3.8'
pytorch_version: 1.9.0
- binary_linux_conda:
context: DOCKERHUB_TOKEN
cu_version: cu102
name: linux_conda_py38_cu102_pyt191
python_version: '3.8'
pytorch_version: 1.9.1
- binary_linux_conda:
context: DOCKERHUB_TOKEN
cu_version: cu111
name: linux_conda_py38_cu111_pyt191
python_version: '3.8'
pytorch_version: 1.9.1
- binary_linux_conda:
context: DOCKERHUB_TOKEN
cu_version: cu102
@@ -370,29 +346,19 @@ workflows:
python_version: '3.8'
pytorch_version: 2.0.0
- binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda117
context: DOCKERHUB_TOKEN
cu_version: cu102
name: linux_conda_py39_cu102_pyt190
python_version: '3.9'
pytorch_version: 1.9.0
cu_version: cu117
name: linux_conda_py38_cu117_pyt201
python_version: '3.8'
pytorch_version: 2.0.1
- binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda118
context: DOCKERHUB_TOKEN
cu_version: cu111
name: linux_conda_py39_cu111_pyt190
python_version: '3.9'
pytorch_version: 1.9.0
- binary_linux_conda:
context: DOCKERHUB_TOKEN
cu_version: cu102
name: linux_conda_py39_cu102_pyt191
python_version: '3.9'
pytorch_version: 1.9.1
- binary_linux_conda:
context: DOCKERHUB_TOKEN
cu_version: cu111
name: linux_conda_py39_cu111_pyt191
python_version: '3.9'
pytorch_version: 1.9.1
cu_version: cu118
name: linux_conda_py38_cu118_pyt201
python_version: '3.8'
pytorch_version: 2.0.1
- binary_linux_conda:
context: DOCKERHUB_TOKEN
cu_version: cu102
@@ -558,6 +524,20 @@ workflows:
name: linux_conda_py39_cu118_pyt200
python_version: '3.9'
pytorch_version: 2.0.0
- binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda117
context: DOCKERHUB_TOKEN
cu_version: cu117
name: linux_conda_py39_cu117_pyt201
python_version: '3.9'
pytorch_version: 2.0.1
- binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda118
context: DOCKERHUB_TOKEN
cu_version: cu118
name: linux_conda_py39_cu118_pyt201
python_version: '3.9'
pytorch_version: 2.0.1
- binary_linux_conda:
context: DOCKERHUB_TOKEN
cu_version: cu102
@@ -666,6 +646,20 @@ workflows:
name: linux_conda_py310_cu118_pyt200
python_version: '3.10'
pytorch_version: 2.0.0
- binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda117
context: DOCKERHUB_TOKEN
cu_version: cu117
name: linux_conda_py310_cu117_pyt201
python_version: '3.10'
pytorch_version: 2.0.1
- binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda118
context: DOCKERHUB_TOKEN
cu_version: cu118
name: linux_conda_py310_cu118_pyt201
python_version: '3.10'
pytorch_version: 2.0.1
- binary_linux_conda_cuda:
name: testrun_conda_cuda_py38_cu102_pyt190
context: DOCKERHUB_TOKEN

View File

@@ -20,8 +20,6 @@ from packaging import version
# version of pytorch.
# Pytorch 1.4 also supports cuda 10.0 but we no longer build for cuda 10.0 at all.
CONDA_CUDA_VERSIONS = {
"1.9.0": ["cu102", "cu111"],
"1.9.1": ["cu102", "cu111"],
"1.10.0": ["cu102", "cu111", "cu113"],
"1.10.1": ["cu102", "cu111", "cu113"],
"1.10.2": ["cu102", "cu111", "cu113"],
@@ -31,6 +29,7 @@ CONDA_CUDA_VERSIONS = {
"1.13.0": ["cu116", "cu117"],
"1.13.1": ["cu116", "cu117"],
"2.0.0": ["cu117", "cu118"],
"2.0.1": ["cu117", "cu118"],
}

View File

@@ -9,7 +9,7 @@ The core library is written in PyTorch. Several components have underlying imple
- Linux or macOS or Windows
- Python 3.8, 3.9 or 3.10
- PyTorch 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2, 1.11.0, 1.12.0, 1.12.1, 1.13.0 or 2.0.0.
- PyTorch 1.10.0, 1.10.1, 1.10.2, 1.11.0, 1.12.0, 1.12.1, 1.13.0, 2.0.0 or 2.0.1.
- torchvision that matches the PyTorch installation. You can install them together as explained at pytorch.org to make sure of this.
- gcc & g++ ≥ 4.9
- [fvcore](https://github.com/facebookresearch/fvcore)

View File

@@ -20,7 +20,8 @@
import os
import sys
import mock
import unittest.mock as mock
from recommonmark.parser import CommonMarkParser
from recommonmark.states import DummyStateMachine
from sphinx.builders.html import StandaloneHTMLBuilder

View File

@@ -85,7 +85,7 @@ cameras_ndc = PerspectiveCameras(focal_length=fcl_ndc, principal_point=prp_ndc)
# Screen space camera
image_size = ((128, 256),) # (h, w)
fcl_screen = (76.8,) # fcl_ndc * min(image_size) / 2
prp_screen = ((115.2, 48), ) # w / 2 - px_ndc * min(image_size) / 2, h / 2 - py_ndc * min(image_size) / 2
prp_screen = ((115.2, 32), ) # w / 2 - px_ndc * min(image_size) / 2, h / 2 - py_ndc * min(image_size) / 2
cameras_screen = PerspectiveCameras(focal_length=fcl_screen, principal_point=prp_screen, in_ndc=False, image_size=image_size)
```

View File

@@ -4,4 +4,4 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
__version__ = "0.7.3"
__version__ = "0.7.4"

View File

@@ -266,6 +266,8 @@ at::Tensor FaceAreasNormalsBackwardCuda(
grad_normals_t{grad_normals, "grad_normals", 4};
at::CheckedFrom c = "FaceAreasNormalsBackwardCuda";
at::checkAllSameGPU(c, {verts_t, faces_t, grad_areas_t, grad_normals_t});
// This is nondeterministic because atomicAdd
at::globalContext().alertNotDeterministic("FaceAreasNormalsBackwardCuda");
// Set the device for the kernel launch based on the device of verts
at::cuda::CUDAGuard device_guard(verts.device());

View File

@@ -130,6 +130,9 @@ std::tuple<at::Tensor, at::Tensor> InterpFaceAttrsBackwardCuda(
at::checkAllSameType(
c, {barycentric_coords_t, face_attrs_t, grad_pix_attrs_t});
// This is nondeterministic because atomicAdd
at::globalContext().alertNotDeterministic("InterpFaceAttrsBackwardCuda");
// Set the device for the kernel launch based on the input
at::cuda::CUDAGuard device_guard(pix_to_face.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

View File

@@ -534,6 +534,9 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
c, {p1_t, p2_t, lengths1_t, lengths2_t, idxs_t, grad_dists_t});
at::checkAllSameType(c, {p1_t, p2_t, grad_dists_t});
// This is nondeterministic because atomicAdd
at::globalContext().alertNotDeterministic("KNearestNeighborBackwardCuda");
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(p1.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

View File

@@ -305,6 +305,8 @@ std::tuple<at::Tensor, at::Tensor> DistanceBackwardCuda(
at::CheckedFrom c = "DistanceBackwardCuda";
at::checkAllSameGPU(c, {objects_t, targets_t, idx_objects_t, grad_dists_t});
at::checkAllSameType(c, {objects_t, targets_t, grad_dists_t});
// This is nondeterministic because atomicAdd
at::globalContext().alertNotDeterministic("DistanceBackwardCuda");
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(objects.device());
@@ -624,6 +626,9 @@ std::tuple<at::Tensor, at::Tensor> PointFaceArrayDistanceBackwardCuda(
at::CheckedFrom c = "PointFaceArrayDistanceBackwardCuda";
at::checkAllSameGPU(c, {points_t, tris_t, grad_dists_t});
at::checkAllSameType(c, {points_t, tris_t, grad_dists_t});
// This is nondeterministic because atomicAdd
at::globalContext().alertNotDeterministic(
"PointFaceArrayDistanceBackwardCuda");
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
@@ -787,6 +792,9 @@ std::tuple<at::Tensor, at::Tensor> PointEdgeArrayDistanceBackwardCuda(
at::CheckedFrom c = "PointEdgeArrayDistanceBackwardCuda";
at::checkAllSameGPU(c, {points_t, segms_t, grad_dists_t});
at::checkAllSameType(c, {points_t, segms_t, grad_dists_t});
// This is nondeterministic because atomicAdd
at::globalContext().alertNotDeterministic(
"PointEdgeArrayDistanceBackwardCuda");
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());

View File

@@ -141,6 +141,9 @@ void PointsToVolumesForwardCuda(
grid_sizes_t,
mask_t});
// This is nondeterministic because atomicAdd
at::globalContext().alertNotDeterministic("PointsToVolumesForwardCuda");
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points_3d.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

View File

@@ -583,6 +583,9 @@ at::Tensor RasterizeMeshesBackwardCuda(
at::checkAllSameType(
c, {face_verts_t, grad_zbuf_t, grad_bary_t, grad_dists_t});
// This is nondeterministic because atomicAdd
at::globalContext().alertNotDeterministic("RasterizeMeshesBackwardCuda");
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(face_verts.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

View File

@@ -423,7 +423,8 @@ at::Tensor RasterizePointsBackwardCuda(
at::CheckedFrom c = "RasterizePointsBackwardCuda";
at::checkAllSameGPU(c, {points_t, idxs_t, grad_zbuf_t, grad_dists_t});
at::checkAllSameType(c, {points_t, grad_zbuf_t, grad_dists_t});
// This is nondeterministic because atomicAdd
at::globalContext().alertNotDeterministic("RasterizePointsBackwardCuda");
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

View File

@@ -155,7 +155,7 @@ at::Tensor FarthestPointSamplingCuda(
// Max possible threads per block
const int MAX_THREADS_PER_BLOCK = 1024;
const size_t threads = max(min(1 << points_pow_2, MAX_THREADS_PER_BLOCK), 1);
const size_t threads = max(min(1 << points_pow_2, MAX_THREADS_PER_BLOCK), 2);
// Create the accessors
auto points_a = points.packed_accessor64<float, 3, at::RestrictPtrTraits>();
@@ -215,10 +215,6 @@ at::Tensor FarthestPointSamplingCuda(
FarthestPointSamplingKernel<2><<<threads, threads, shared_mem, stream>>>(
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
break;
case 1:
FarthestPointSamplingKernel<1><<<threads, threads, shared_mem, stream>>>(
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
break;
default:
FarthestPointSamplingKernel<1024>
<<<blocks, threads, shared_mem, stream>>>(

View File

@@ -450,6 +450,9 @@ class FrameDataBuilderBase(ReplaceableBase, Generic[FrameDataSubtype], ABC):
self,
frame_annotation: types.FrameAnnotation,
sequence_annotation: types.SequenceAnnotation,
*,
load_blobs: bool = True,
**kwargs,
) -> FrameDataSubtype:
"""An abstract method to build the frame data based on raw frame/sequence
annotations, load the binary data and adjust them according to the metadata.
@@ -465,8 +468,9 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
Beware that modifications of frame data are done in-place.
Args:
dataset_root: The root folder of the dataset; all the paths in jsons are
specified relative to this root (but not json paths themselves).
dataset_root: The root folder of the dataset; all paths in frame / sequence
annotations are defined w.r.t. this root. Has to be set if any of the
load_* flags below is true.
load_images: Enable loading the frame RGB data.
load_depths: Enable loading the frame depth maps.
load_depth_masks: Enable loading the frame depth map masks denoting the
@@ -494,7 +498,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
path_manager: Optionally a PathManager for interpreting paths in a special way.
"""
dataset_root: str = ""
dataset_root: Optional[str] = None
load_images: bool = True
load_depths: bool = True
load_depth_masks: bool = True
@@ -510,11 +514,37 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
box_crop_context: float = 0.3
path_manager: Any = None
def __post_init__(self) -> None:
load_any_blob = (
self.load_images
or self.load_depths
or self.load_depth_masks
or self.load_masks
or self.load_point_clouds
)
if load_any_blob and self.dataset_root is None:
raise ValueError(
"dataset_root must be set to load any blob data. "
"Make sure it is set in either FrameDataBuilder or Dataset params."
)
if self.path_manager is None:
dataset_root_exists = os.path.isdir(self.dataset_root) # pyre-ignore
else:
dataset_root_exists = self.path_manager.isdir(self.dataset_root)
if load_any_blob and not dataset_root_exists:
raise ValueError(
f"dataset_root is passed but {self.dataset_root} does not exist."
)
def build(
self,
frame_annotation: types.FrameAnnotation,
sequence_annotation: types.SequenceAnnotation,
*,
load_blobs: bool = True,
**kwargs,
) -> FrameDataSubtype:
"""Builds the frame data based on raw frame/sequence annotations, loads the
binary data and adjust them according to the metadata. The processing includes:
@@ -555,12 +585,19 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
else None,
)
if load_blobs and self.load_masks and frame_annotation.mask is not None:
(
frame_data.fg_probability,
frame_data.mask_path,
frame_data.bbox_xywh,
) = self._load_fg_probability(frame_annotation)
mask_annotation = frame_annotation.mask
if mask_annotation is not None:
fg_mask_np: Optional[np.ndarray] = None
if load_blobs and self.load_masks:
fg_mask_np, mask_path = self._load_fg_probability(frame_annotation)
frame_data.mask_path = mask_path
frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)
bbox_xywh = mask_annotation.bounding_box_xywh
if bbox_xywh is None and fg_mask_np is not None:
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)
if frame_annotation.image is not None:
image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
@@ -604,25 +641,16 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
def _load_fg_probability(
self, entry: types.FrameAnnotation
) -> Tuple[Optional[torch.Tensor], Optional[str], Optional[torch.Tensor]]:
full_path = os.path.join(self.dataset_root, entry.mask.path) # pyre-ignore
) -> Tuple[np.ndarray, str]:
assert self.dataset_root is not None and entry.mask is not None
full_path = os.path.join(self.dataset_root, entry.mask.path)
fg_probability = load_mask(self._local_path(full_path))
# we can use provided bbox_xywh or calculate it based on mask
# saves time to skip bbox calculation
# pyre-ignore
bbox_xywh = entry.mask.bounding_box_xywh or get_bbox_from_mask(
fg_probability, self.box_crop_mask_thr
)
if fg_probability.shape[-2:] != entry.image.size:
raise ValueError(
f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!"
)
return (
safe_as_tensor(fg_probability, torch.float),
full_path,
safe_as_tensor(bbox_xywh, torch.long),
)
return fg_probability, full_path
def _load_images(
self,
@@ -650,7 +678,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
fg_probability: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, str, torch.Tensor]:
entry_depth = entry.depth
assert entry_depth is not None
assert self.dataset_root is not None and entry_depth is not None
path = os.path.join(self.dataset_root, entry_depth.path)
depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)
@@ -660,6 +688,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
if self.load_depth_masks:
assert entry_depth.mask_path is not None
# pyre-ignore
mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
depth_mask = load_depth_mask(self._local_path(mask_path))
else:
@@ -708,6 +737,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
)
if path.startswith(unwanted_prefix):
path = path[len(unwanted_prefix) :]
assert self.dataset_root is not None
return os.path.join(self.dataset_root, path)
def _local_path(self, path: str) -> str:

View File

@@ -190,6 +190,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
box_crop=self.box_crop,
box_crop_mask_thr=self.box_crop_mask_thr,
box_crop_context=self.box_crop_context,
path_manager=self.path_manager,
)
logger.info(str(self))

View File

@@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# This functionality requires SQLAlchemy 2.0 or later.
import math
import struct
from typing import Optional, Tuple
import numpy as np
from pytorch3d.implicitron.dataset.types import (
DepthAnnotation,
ImageAnnotation,
MaskAnnotation,
PointCloudAnnotation,
VideoAnnotation,
ViewpointAnnotation,
)
from sqlalchemy import LargeBinary
from sqlalchemy.orm import (
composite,
DeclarativeBase,
Mapped,
mapped_column,
MappedAsDataclass,
)
from sqlalchemy.types import TypeDecorator
# these produce policies to serialize structured types to blobs
def ArrayTypeFactory(shape):
class NumpyArrayType(TypeDecorator):
impl = LargeBinary
def process_bind_param(self, value, dialect):
if value is not None:
if value.shape != shape:
raise ValueError(f"Passed an array of wrong shape: {value.shape}")
return value.astype(np.float32).tobytes()
return None
def process_result_value(self, value, dialect):
if value is not None:
return np.frombuffer(value, dtype=np.float32).reshape(shape)
return None
return NumpyArrayType
def TupleTypeFactory(dtype=float, shape: Tuple[int, ...] = (2,)):
format_symbol = {
float: "f", # float32
int: "i", # int32
}[dtype]
class TupleType(TypeDecorator):
impl = LargeBinary
_format = format_symbol * math.prod(shape)
def process_bind_param(self, value, _):
if value is None:
return None
if len(shape) > 1:
value = np.array(value, dtype=dtype).reshape(-1)
return struct.pack(TupleType._format, *value)
def process_result_value(self, value, _):
if value is None:
return None
loaded = struct.unpack(TupleType._format, value)
if len(shape) > 1:
loaded = _rec_totuple(
np.array(loaded, dtype=dtype).reshape(shape).tolist()
)
return loaded
return TupleType
def _rec_totuple(t):
if isinstance(t, list):
return tuple(_rec_totuple(x) for x in t)
return t
class Base(MappedAsDataclass, DeclarativeBase):
"""subclasses will be converted to dataclasses"""
class SqlFrameAnnotation(Base):
__tablename__ = "frame_annots"
sequence_name: Mapped[str] = mapped_column(primary_key=True)
frame_number: Mapped[int] = mapped_column(primary_key=True)
frame_timestamp: Mapped[float] = mapped_column(index=True)
image: Mapped[ImageAnnotation] = composite(
mapped_column("_image_path"),
mapped_column("_image_size", TupleTypeFactory(int)),
)
depth: Mapped[DepthAnnotation] = composite(
mapped_column("_depth_path", nullable=True),
mapped_column("_depth_scale_adjustment", nullable=True),
mapped_column("_depth_mask_path", nullable=True),
)
mask: Mapped[MaskAnnotation] = composite(
mapped_column("_mask_path", nullable=True),
mapped_column("_mask_mass", index=True, nullable=True),
mapped_column(
"_mask_bounding_box_xywh",
TupleTypeFactory(float, shape=(4,)),
nullable=True,
),
)
viewpoint: Mapped[ViewpointAnnotation] = composite(
mapped_column(
"_viewpoint_R", TupleTypeFactory(float, shape=(3, 3)), nullable=True
),
mapped_column(
"_viewpoint_T", TupleTypeFactory(float, shape=(3,)), nullable=True
),
mapped_column(
"_viewpoint_focal_length", TupleTypeFactory(float), nullable=True
),
mapped_column(
"_viewpoint_principal_point", TupleTypeFactory(float), nullable=True
),
mapped_column("_viewpoint_intrinsics_format", nullable=True),
)
class SqlSequenceAnnotation(Base):
__tablename__ = "sequence_annots"
sequence_name: Mapped[str] = mapped_column(primary_key=True)
category: Mapped[str] = mapped_column(index=True)
video: Mapped[VideoAnnotation] = composite(
mapped_column("_video_path", nullable=True),
mapped_column("_video_length", nullable=True),
)
point_cloud: Mapped[PointCloudAnnotation] = composite(
mapped_column("_point_cloud_path", nullable=True),
mapped_column("_point_cloud_quality_score", nullable=True),
mapped_column("_point_cloud_n_points", nullable=True),
)
# the bigger the better
viewpoint_quality_score: Mapped[Optional[float]] = mapped_column(default=None)

View File

@@ -0,0 +1,736 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import hashlib
import json
import logging
import os
from dataclasses import dataclass
from typing import (
Any,
ClassVar,
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
import numpy as np
import pandas as pd
import sqlalchemy as sa
import torch
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
from pytorch3d.implicitron.dataset.frame_data import ( # noqa
FrameData,
FrameDataBuilder,
FrameDataBuilderBase,
)
from pytorch3d.implicitron.tools.config import (
registry,
ReplaceableBase,
run_auto_creation,
)
from sqlalchemy.orm import Session
from .orm_types import SqlFrameAnnotation, SqlSequenceAnnotation
logger = logging.getLogger(__name__)
_SET_LISTS_TABLE: str = "set_lists"
@registry.register
class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
"""
A dataset with annotations stored as SQLite tables. This is an index-based dataset.
The length is returned after all sequence and frame filters are applied (see param
definitions below). Indices can either be ordinal in [0, len), or pairs of
(sequence_name, frame_number); with the performance of `dataset[i]` and
`dataset[sequence_name, frame_number]` being the same. A faster way to get metadata only
(without blobs) is `dataset.meta[idx]` indexing; it requires box_crop==False.
With ordinal indexing, the sequences are NOT guaranteed to span contiguous index
ranges, and frame numbers are NOT guaranteed to be increasing within a sequence.
Sequence-aware batch samplers have to use `sequence_[frames|indices]_in_order`
iterators, which are efficient.
This functionality requires SQLAlchemy 2.0 or later.
Metadata-related args:
sqlite_metadata_file: A SQLite file containing frame and sequence annotation
tables (mapping to SqlFrameAnnotation and SqlSequenceAnnotation,
respectively).
dataset_root: A root directory to look for images, masks, etc. It can be
alternatively set in `frame_data_builder` args, but this takes precedence.
subset_lists_file: A JSON/sqlite file containing the lists of frames
corresponding to different subsets (e.g. train/val/test) of the dataset;
format: {subset: [(sequence_name, frame_id, file_path)]}. All entries
must be present in frame_annotation metadata table.
path_manager: a facade for non-POSIX filesystems.
subsets: Restrict frames/sequences only to the given list of subsets
as defined in subset_lists_file (see above). Applied before all other
filters.
remove_empty_masks: Removes the frames with no active foreground pixels
in the segmentation mask (needs frame_annotation.mask.mass to be set;
null values are retained).
pick_frames_sql_clause: SQL WHERE clause to constrain frame annotations
NOTE: This is a potential security risk! The string is passed to the SQL
engine verbatim. Don't expose it to end users of your application!
pick_categories: Restrict the dataset to the given list of categories.
pick_sequences: A Sequence of sequence names to restrict the dataset to.
exclude_sequences: A Sequence of the names of the sequences to exclude.
limit_sequences_to: Limit the dataset to the first `limit_sequences_to`
sequences (after other sequence filters have been applied but before
frame-based filters).
limit_to: Limit the dataset to the first #limit_to frames (after other
filters have been applied, except n_frames_per_sequence).
n_frames_per_sequence: If > 0, randomly samples `n_frames_per_sequence`
frames in each sequence uniformly without replacement if it has
more frames than that; applied after other frame-level filters.
seed: The seed of the random generator sampling `n_frames_per_sequence`
random frames per sequence.
"""
frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation
sqlite_metadata_file: str = ""
dataset_root: Optional[str] = None
subset_lists_file: str = ""
eval_batches_file: Optional[str] = None
path_manager: Any = None
subsets: Optional[List[str]] = None
remove_empty_masks: bool = True
pick_frames_sql_clause: Optional[str] = None
pick_categories: Tuple[str, ...] = ()
pick_sequences: Tuple[str, ...] = ()
exclude_sequences: Tuple[str, ...] = ()
limit_sequences_to: int = 0
limit_to: int = 0
n_frames_per_sequence: int = -1
seed: int = 0
remove_empty_masks_poll_whole_table_threshold: int = 300_000
# we set it manually in the constructor
# _index: pd.DataFrame = field(init=False)
frame_data_builder: FrameDataBuilderBase
frame_data_builder_class_type: str = "FrameDataBuilder"
def __post_init__(self) -> None:
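# compare the major version numerically; plain string comparison misorders e.g. "10.0"
if int(sa.__version__.split(".")[0]) < 2:
raise ImportError("This class requires SQLAlchemy 2.0 or later")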
if not self.sqlite_metadata_file:
raise ValueError("sqlite_metadata_file must be set")
if self.dataset_root:
frame_builder_type = self.frame_data_builder_class_type
getattr(self, f"frame_data_builder_{frame_builder_type}_args")[
"dataset_root"
] = self.dataset_root
run_auto_creation(self)
self.frame_data_builder.path_manager = self.path_manager
# pyre-ignore
self._sql_engine = sa.create_engine(f"sqlite:///{self.sqlite_metadata_file}")
sequences = self._get_filtered_sequences_if_any()
if self.subsets:
index = self._build_index_from_subset_lists(sequences)
else:
# TODO: if self.subset_lists_file and not self.subsets, it might be faster to
# still use the concatenated lists, assuming they cover the whole dataset
index = self._build_index_from_db(sequences)
if self.n_frames_per_sequence > 0:
index = self._stratified_sample_index(index)
if len(index) == 0:
raise ValueError(f"There are no frames in the subsets: {self.subsets}!")
self._index = index.set_index(["sequence_name", "frame_number"]) # pyre-ignore
self.eval_batches = None # pyre-ignore
if self.eval_batches_file:
self.eval_batches = self._load_filter_eval_batches()
logger.info(str(self))
def __len__(self) -> int:
# pyre-ignore[16]
return len(self._index)
def __getitem__(self, frame_idx: Union[int, Tuple[str, int]]) -> FrameData:
"""
Fetches FrameData by either iloc in the index or by (sequence, frame_no) pair
"""
return self._get_item(frame_idx, True)
@property
def meta(self):
"""
Allows accessing metadata only without loading blobs using `dataset.meta[idx]`.
Requires box_crop==False, since with box_crop enabled, cameras cannot be
adjusted without loading masks.
Returns:
FrameData objects with blob fields like `image_rgb` set to None.
Raises:
ValueError if dataset.box_crop is set.
"""
return SqlIndexDataset._MetadataAccessor(self)
@dataclass
class _MetadataAccessor:
dataset: "SqlIndexDataset"
def __getitem__(self, frame_idx: Union[int, Tuple[str, int]]) -> FrameData:
return self.dataset._get_item(frame_idx, False)
def _get_item(
self, frame_idx: Union[int, Tuple[str, int]], load_blobs: bool = True
) -> FrameData:
if isinstance(frame_idx, int):
if frame_idx >= len(self._index):
raise IndexError(f"index {frame_idx} out of range {len(self._index)}")
seq, frame = self._index.index[frame_idx]
else:
seq, frame, *rest = frame_idx
if (seq, frame) not in self._index.index:
raise IndexError(
f"Sequence-frame index {frame_idx} not found; was it filtered out?"
)
if rest and rest[0] != self._index.loc[(seq, frame), "_image_path"]:
raise IndexError(f"Non-matching image path in {frame_idx}.")
stmt = sa.select(self.frame_annotations_type).where(
self.frame_annotations_type.sequence_name == seq,
self.frame_annotations_type.frame_number
== int(frame), # cast from np.int64
)
seq_stmt = sa.select(SqlSequenceAnnotation).where(
SqlSequenceAnnotation.sequence_name == seq
)
with Session(self._sql_engine) as session:
entry = session.scalars(stmt).one()
seq_metadata = session.scalars(seq_stmt).one()
assert entry.image.path == self._index.loc[(seq, frame), "_image_path"]
frame_data = self.frame_data_builder.build(
entry, seq_metadata, load_blobs=load_blobs
)
# The rest of the fields are optional
frame_data.frame_type = self._get_frame_type(entry)
return frame_data
def __str__(self) -> str:
# pyre-ignore[16]
return f"SqlIndexDataset #frames={len(self._index)}"
def sequence_names(self) -> Iterable[str]:
"""Returns an iterator over sequence names in the dataset."""
return self._index.index.unique("sequence_name")
# override
def category_to_sequence_names(self) -> Dict[str, List[str]]:
stmt = sa.select(
SqlSequenceAnnotation.category, SqlSequenceAnnotation.sequence_name
).where( # we limit results to sequences that have frames after all filters
SqlSequenceAnnotation.sequence_name.in_(self.sequence_names())
)
with self._sql_engine.connect() as connection:
cat_to_seqs = pd.read_sql(stmt, connection)
return cat_to_seqs.groupby("category")["sequence_name"].apply(list).to_dict()
# override
def get_frame_numbers_and_timestamps(
self, idxs: Sequence[int], subset_filter: Optional[Sequence[str]] = None
) -> List[Tuple[int, float]]:
"""
Implements the DatasetBase method.
NOTE: Avoid this function as there are more efficient alternatives such as
querying `dataset[idx]` directly or getting all sequence frames with
`sequence_[frames|indices]_in_order`.
Return the index and timestamp in their videos of the frames whose
indices are given in `idxs`. They need to belong to the same sequence!
If timestamps are absent, they are replaced with zeros.
This is used for letting SceneBatchSampler identify consecutive
frames.
Args:
idxs: a sequence of integer frame indices in the dataset (it can be a slice)
subset_filter: must remain None
Returns:
list of tuples of
- frame index in video
- timestamp of frame in video, coalesced with 0s
Raises:
ValueError if idxs belong to more than one sequence.
"""
if subset_filter is not None:
raise NotImplementedError(
"Subset filters are not supported in SQL Dataset. "
"We encourage creating a dataset per subset."
)
index_slice, _ = self._get_frame_no_coalesced_ts_by_row_indices(idxs)
# alternatively, we can use `.values.tolist()`, which may be faster
# but returns a list of lists
return list(index_slice.itertuples())
# override
def sequence_frames_in_order(
self, seq_name: str, subset_filter: Optional[Sequence[str]] = None
) -> Iterator[Tuple[float, int, int]]:
"""
Overrides the default DatasetBase implementation (we don't use `_seq_to_idx`).
Returns an iterator over the frame indices in a given sequence.
We attempt to first sort by timestamp (if they are available),
then by frame number.
Args:
seq_name: the name of the sequence.
subset_filter: subset names to filter to
Returns:
an iterator over triplets `(timestamp, frame_no, dataset_idx)`,
where `frame_no` is the index within the sequence, and
`dataset_idx` is the index within the dataset.
`None` timestamps are replaced with 0s.
"""
# TODO: implement sort_timestamp_first? (which would matter if the orders
# of frame numbers and timestamps are different)
rows = self._index.index.get_loc(seq_name)
if isinstance(rows, slice):
assert rows.stop is not None, "Unexpected result from pandas"
rows = range(rows.start or 0, rows.stop, rows.step or 1)
else:
rows = np.where(rows)[0]
index_slice, idx = self._get_frame_no_coalesced_ts_by_row_indices(
rows, seq_name, subset_filter
)
index_slice["idx"] = idx
yield from index_slice.itertuples(index=False)
# override
def get_eval_batches(self) -> Optional[List[Any]]:
"""
This class does not support eval batches with ordinal indices. You can pass
eval_batches as a batch_sampler to a data_loader since the dataset supports
`dataset[seq_name, frame_no]` indexing.
"""
return self.eval_batches
# override
def join(self, other_datasets: Iterable[DatasetBase]) -> None:
raise ValueError("Not supported! Preprocess the data by merging them instead.")
# override
@property
def frame_data_type(self) -> Type[FrameData]:
return self.frame_data_builder.frame_data_type
def is_filtered(self) -> bool:
"""
Returns `True` in case the dataset has been filtered and thus some frame
annotations stored on the disk might be missing in the dataset object.
Does not account for subsets.
Returns:
is_filtered: `True` if the dataset has been filtered, else `False`.
"""
return (
self.remove_empty_masks
or self.limit_to > 0
or self.limit_sequences_to > 0
or len(self.pick_sequences) > 0
or len(self.exclude_sequences) > 0
or len(self.pick_categories) > 0
or self.n_frames_per_sequence > 0
)
def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
# maximum possible query: WHERE category IN 'self.pick_categories'
# AND sequence_name IN 'self.pick_sequences'
# AND sequence_name NOT IN 'self.exclude_sequences'
# LIMIT 'self.limit_sequences_to'
stmt = sa.select(SqlSequenceAnnotation.sequence_name)
where_conditions = [
*self._get_category_filters(),
*self._get_pick_filters(),
*self._get_exclude_filters(),
]
if where_conditions:
stmt = stmt.where(*where_conditions)
if self.limit_sequences_to > 0:
logger.info(
f"Limiting dataset to first {self.limit_sequences_to} sequences"
)
# NOTE: ROWID is SQLite-specific
stmt = stmt.order_by(sa.text("ROWID")).limit(self.limit_sequences_to)
if not where_conditions and self.limit_sequences_to <= 0:
# we will not need to filter by sequences
return None
with self._sql_engine.connect() as connection:
sequences = pd.read_sql_query(stmt, connection)["sequence_name"]
logger.info("... retained %d sequences" % len(sequences))
return sequences
def _get_category_filters(self) -> List[sa.ColumnElement]:
if not self.pick_categories:
return []
logger.info(f"Limiting dataset to categories: {self.pick_categories}")
return [SqlSequenceAnnotation.category.in_(self.pick_categories)]
def _get_pick_filters(self) -> List[sa.ColumnElement]:
if not self.pick_sequences:
return []
logger.info(f"Limiting dataset to sequences: {self.pick_sequences}")
return [SqlSequenceAnnotation.sequence_name.in_(self.pick_sequences)]
def _get_exclude_filters(self) -> List[sa.ColumnOperators]:
if not self.exclude_sequences:
return []
logger.info(f"Removing sequences from the dataset: {self.exclude_sequences}")
return [SqlSequenceAnnotation.sequence_name.notin_(self.exclude_sequences)]
def _load_subsets_from_json(self, subset_lists_path: str) -> pd.DataFrame:
assert self.subsets is not None
with open(subset_lists_path, "r") as f:
subset_to_seq_frame = json.load(f)
seq_frame_list = sum(
(
[(*row, subset) for row in subset_to_seq_frame[subset]]
for subset in self.subsets
),
[],
)
index = pd.DataFrame(
seq_frame_list,
columns=["sequence_name", "frame_number", "_image_path", "subset"],
)
return index
def _load_subsets_from_sql(self, subset_lists_path: str) -> pd.DataFrame:
subsets = self.subsets
assert subsets is not None
# we need a new engine since we store the subsets in a separate DB
engine = sa.create_engine(f"sqlite:///{subset_lists_path}")
table = sa.Table(_SET_LISTS_TABLE, sa.MetaData(), autoload_with=engine)
stmt = sa.select(table).where(table.c.subset.in_(subsets))
with engine.connect() as connection:
index = pd.read_sql(stmt, connection)
return index
def _build_index_from_subset_lists(
self, sequences: Optional[pd.Series]
) -> pd.DataFrame:
if not self.subset_lists_file:
raise ValueError("Requested subsets but subset_lists_file not given")
logger.info(f"Loading subset lists from {self.subset_lists_file}.")
subset_lists_path = self._local_path(self.subset_lists_file)
if subset_lists_path.lower().endswith(".json"):
index = self._load_subsets_from_json(subset_lists_path)
else:
index = self._load_subsets_from_sql(subset_lists_path)
index = index.set_index(["sequence_name", "frame_number"])
logger.info(f" -> loaded {len(index)} samples of {self.subsets}.")
if sequences is not None:
logger.info("Applying filtered sequences.")
sequence_values = index.index.get_level_values("sequence_name")
index = index.loc[sequence_values.isin(sequences)]
logger.info(f" -> retained {len(index)} samples.")
pick_frames_criteria = []
if self.remove_empty_masks:
logger.info("Culling samples with empty masks.")
if len(index) > self.remove_empty_masks_poll_whole_table_threshold:
# APPROACH 1: find empty masks and drop indices.
# dev load: 17s / 15 s (3.1M / 500K)
stmt = sa.select(
self.frame_annotations_type.sequence_name,
self.frame_annotations_type.frame_number,
).where(self.frame_annotations_type._mask_mass == 0)
with Session(self._sql_engine) as session:
to_remove = session.execute(stmt).all()
# Pandas uses np.int64 for integer types, so we have to cast;
# we might want to read it into a pandas DataFrame directly to avoid the loop
to_remove = [(seq, np.int64(fr)) for seq, fr in to_remove]
index.drop(to_remove, errors="ignore", inplace=True)
else:
# APPROACH 3: load index into a temp table and join with annotations
# dev load: 94 s / 23 s (3.1M / 500K)
pick_frames_criteria.append(
sa.or_(
self.frame_annotations_type._mask_mass.is_(None),
self.frame_annotations_type._mask_mass != 0,
)
)
if self.pick_frames_sql_clause:
logger.info("Applying the custom SQL clause.")
pick_frames_criteria.append(sa.text(self.pick_frames_sql_clause))
if pick_frames_criteria:
index = self._pick_frames_by_criteria(index, pick_frames_criteria)
logger.info(f" -> retained {len(index)} samples.")
if self.limit_to > 0:
logger.info(f"Limiting dataset to first {self.limit_to} frames")
index = index.sort_index().iloc[: self.limit_to]
return index.reset_index()
def _pick_frames_by_criteria(self, index: pd.DataFrame, criteria) -> pd.DataFrame:
IndexTable = self._get_temp_index_table_instance()
with self._sql_engine.connect() as connection:
IndexTable.create(connection)
# we don't let pandas's `to_sql` create the table automatically as
# the table would be permanent, so we create it and append with pandas
n_rows = index.to_sql(IndexTable.name, connection, if_exists="append")
assert n_rows == len(index)
sa_type = self.frame_annotations_type
stmt = (
sa.select(IndexTable)
.select_from(
IndexTable.join(
self.frame_annotations_type,
sa.and_(
sa_type.sequence_name == IndexTable.c.sequence_name,
sa_type.frame_number == IndexTable.c.frame_number,
),
)
)
.where(*criteria)
)
return pd.read_sql_query(stmt, connection).set_index(
["sequence_name", "frame_number"]
)
def _build_index_from_db(self, sequences: Optional[pd.Series]):
logger.info("Loading sequence-frame index from the database")
stmt = sa.select(
self.frame_annotations_type.sequence_name,
self.frame_annotations_type.frame_number,
self.frame_annotations_type._image_path,
sa.null().label("subset"),
)
where_conditions = []
if sequences is not None:
logger.info(" applying filtered sequences")
where_conditions.append(
self.frame_annotations_type.sequence_name.in_(sequences.tolist())
)
if self.remove_empty_masks:
logger.info(" excluding samples with empty masks")
where_conditions.append(
sa.or_(
self.frame_annotations_type._mask_mass.is_(None),
self.frame_annotations_type._mask_mass != 0,
)
)
if self.pick_frames_sql_clause:
logger.info(" applying custom SQL clause")
where_conditions.append(sa.text(self.pick_frames_sql_clause))
if where_conditions:
stmt = stmt.where(*where_conditions)
if self.limit_to > 0:
logger.info(f"Limiting dataset to first {self.limit_to} frames")
stmt = stmt.order_by(
self.frame_annotations_type.sequence_name,
self.frame_annotations_type.frame_number,
).limit(self.limit_to)
with self._sql_engine.connect() as connection:
index = pd.read_sql_query(stmt, connection)
logger.info(f" -> loaded {len(index)} samples.")
return index
def _sort_index_(self, index):
logger.info("Sorting the index by sequence and frame number.")
index.sort_values(["sequence_name", "frame_number"], inplace=True)
logger.info(" -> Done.")
def _load_filter_eval_batches(self):
assert self.eval_batches_file
logger.info(f"Loading eval batches from {self.eval_batches_file}")
if not os.path.isfile(self.eval_batches_file):
# The batch indices file does not exist.
# Most probably the user has not specified the root folder.
raise ValueError(
f"Looking for dataset json file in {self.eval_batches_file}. "
+ "Please specify a correct dataset_root folder."
)
with open(self.eval_batches_file, "r") as f:
eval_batches = json.load(f)
# limit the dataset to sequences to allow multiple evaluations in one file
pick_sequences = set(self.pick_sequences)
if self.pick_categories:
cat_to_seq = self.category_to_sequence_names()
pick_sequences.update(
seq for cat in self.pick_categories for seq in cat_to_seq[cat]
)
if pick_sequences:
old_len = len(eval_batches)
eval_batches = [b for b in eval_batches if b[0][0] in pick_sequences]
logger.warning(
f"Picked eval batches by sequence/cat: {old_len} -> {len(eval_batches)}"
)
if self.exclude_sequences:
old_len = len(eval_batches)
exclude_sequences = set(self.exclude_sequences)
eval_batches = [b for b in eval_batches if b[0][0] not in exclude_sequences]
logger.warning(
f"Excluded eval batches by sequence: {old_len} -> {len(eval_batches)}"
)
return eval_batches
def _stratified_sample_index(self, index):
# NOTE this stratified sampling can be done more efficiently in
# the no-subset case above if it is added to the SQL query.
# We keep this generic implementation since no-subset case is uncommon
index = index.groupby("sequence_name", group_keys=False).apply(
lambda seq_frames: seq_frames.sample(
min(len(seq_frames), self.n_frames_per_sequence),
random_state=(
_seq_name_to_seed(seq_frames.iloc[0]["sequence_name"]) + self.seed
),
)
)
logger.info(f" -> retained {len(index)} samples after stratified sampling.")
return index
def _get_frame_type(self, entry: SqlFrameAnnotation) -> Optional[str]:
return self._index.loc[(entry.sequence_name, entry.frame_number), "subset"]
def _get_frame_no_coalesced_ts_by_row_indices(
self,
idxs: Sequence[int],
seq_name: Optional[str] = None,
subset_filter: Union[Sequence[str], str, None] = None,
) -> Tuple[pd.DataFrame, Sequence[int]]:
"""
Loads timestamps for given index rows belonging to the same sequence.
If seq_name is known, it speeds up the computation.
Raises ValueError if `idxs` do not all belong to a single sequence.
"""
index_slice = self._index.iloc[idxs]
if subset_filter is not None:
if isinstance(subset_filter, str):
subset_filter = [subset_filter]
indicator = index_slice["subset"].isin(subset_filter)
index_slice = index_slice.loc[indicator]
idxs = [i for i, isin in zip(idxs, indicator) if isin]
frames = index_slice.index.get_level_values("frame_number").tolist()
if seq_name is None:
seq_name_list = index_slice.index.get_level_values("sequence_name").tolist()
seq_name_set = set(seq_name_list)
if len(seq_name_set) > 1:
raise ValueError("Given indices belong to more than one sequence.")
elif len(seq_name_set) == 1:
seq_name = seq_name_list[0]
coalesced_ts = sa.sql.functions.coalesce(
self.frame_annotations_type.frame_timestamp, 0
)
stmt = sa.select(
coalesced_ts.label("frame_timestamp"),
self.frame_annotations_type.frame_number,
).where(
self.frame_annotations_type.sequence_name == seq_name,
self.frame_annotations_type.frame_number.in_(frames),
)
with self._sql_engine.connect() as connection:
frame_no_ts = pd.read_sql_query(stmt, connection)
if len(frame_no_ts) != len(index_slice):
raise ValueError(
"Not all indices are found in the database; "
"do they belong to more than one sequence?"
)
return frame_no_ts, idxs
def _local_path(self, path: str) -> str:
if self.path_manager is None:
return path
return self.path_manager.get_local_path(path)
def _get_temp_index_table_instance(self, table_name: str = "__index"):
CachedTable = self.frame_annotations_type.metadata.tables.get(table_name)
if CachedTable is not None: # table definition is not idempotent
return CachedTable
return sa.Table(
table_name,
self.frame_annotations_type.metadata,
sa.Column("sequence_name", sa.String, primary_key=True),
sa.Column("frame_number", sa.Integer, primary_key=True),
sa.Column("_image_path", sa.String),
sa.Column("subset", sa.String),
prefixes=["TEMP"], # NOTE SQLite specific!
)
def _seq_name_to_seed(seq_name) -> int:
"""Generates numbers in [0, 2 ** 28)"""
return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest()[:7], 16)
def _safe_as_tensor(data, dtype):
return torch.tensor(data, dtype=dtype) if data is not None else None
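
The per-sequence seeding used by `_stratified_sample_index` and `_seq_name_to_seed` can be illustrated without pandas or a database. The following is a minimal sketch, not the library's implementation: `sample_frames_per_sequence` and the toy `frames` mapping are hypothetical names introduced here, and Python's `random.Random` stands in for the pandas sampler, so the selected frames will differ from what `SqlIndexDataset` picks. It only shows the idea that each sequence derives its seed from a hash of its name plus a global seed, so the sample drawn for one sequence does not depend on which other sequences are present.

```python
import hashlib
import random
from typing import Dict, List


def seq_name_to_seed(seq_name: str) -> int:
    """Map a sequence name to a stable integer in [0, 2 ** 28)."""
    return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest()[:7], 16)


def sample_frames_per_sequence(
    frames: Dict[str, List[int]], n_frames_per_sequence: int, seed: int = 0
) -> Dict[str, List[int]]:
    """Hypothetical helper: sample up to n frames per sequence, without replacement."""
    sampled = {}
    for seq_name, frame_numbers in frames.items():
        # The RNG depends only on the sequence name and the global seed.
        rng = random.Random(seq_name_to_seed(seq_name) + seed)
        k = min(len(frame_numbers), n_frames_per_sequence)
        sampled[seq_name] = sorted(rng.sample(frame_numbers, k))
    return sampled


# Toy data (assumed for illustration): frame numbers per sequence.
frames = {"seq_a": list(range(10)), "seq_b": [0, 1]}
full = sample_frames_per_sequence(frames, n_frames_per_sequence=3)
only_a = sample_frames_per_sequence({"seq_a": frames["seq_a"]}, 3)
```

Here `full["seq_a"]` and `only_a["seq_a"]` are identical, because dropping `seq_b` does not change the seed used for `seq_a`; `seq_b` keeps both of its frames since it has fewer than three.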

@@ -0,0 +1,424 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from typing import List, Optional, Tuple, Type
import numpy as np
from omegaconf import DictConfig, OmegaConf
from pytorch3d.implicitron.dataset.dataset_map_provider import (
DatasetMap,
DatasetMapProviderBase,
PathManagerFactory,
)
from pytorch3d.implicitron.tools.config import (
expand_args_fields,
registry,
run_auto_creation,
)
from .sql_dataset import SqlIndexDataset
_CO3D_SQL_DATASET_ROOT: str = os.getenv("CO3D_SQL_DATASET_ROOT", "")
# _NEED_CONTROL is a list of those elements of SqlIndexDataset which
# are not directly specified for it in the config but come from the
# DatasetMapProvider.
_NEED_CONTROL: Tuple[str, ...] = (
"path_manager",
"subsets",
"sqlite_metadata_file",
"subset_lists_file",
)
logger = logging.getLogger(__name__)
@registry.register
class SqlIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
"""
Generates the training, validation, and testing dataset objects for
a dataset laid out on disk like SQL-CO3D, with annotations in an SQLite database.
The dataset is organized in the filesystem as follows::
self.dataset_root
├── <possible/partition/0>
│ ├── <sequence_name_0>
│ │ ├── depth_masks
│ │ ├── depths
│ │ ├── images
│ │ ├── masks
│ │ └── pointcloud.ply
│ ├── <sequence_name_1>
│ │ ├── depth_masks
│ │ ├── depths
│ │ ├── images
│ │ ├── masks
│ │ └── pointcloud.ply
│ ├── ...
│ ├── <sequence_name_N>
│ ├── set_lists
│ ├── <subset_base_name_0>.json
│ ├── <subset_base_name_1>.json
│ ├── ...
│ ├── <subset_base_name_2>.json
│ ├── eval_batches
│ │ ├── <eval_batches_base_name_0>.json
│ │ ├── <eval_batches_base_name_1>.json
│ │ ├── ...
│ │ ├── <eval_batches_base_name_M>.json
│ ├── frame_annotations.jgz
│ ├── sequence_annotations.jgz
├── <possible/partition/1>
├── ...
├── <possible/partition/K>
├── set_lists
├── <subset_base_name_0>.sqlite
├── <subset_base_name_1>.sqlite
├── ...
├── <subset_base_name_2>.sqlite
├── eval_batches
│ ├── <eval_batches_base_name_0>.json
│ ├── <eval_batches_base_name_1>.json
│ ├── ...
│ ├── <eval_batches_base_name_M>.json
The dataset contains sequences named `<sequence_name_i>` that may be partitioned by
directories such as `<possible/partition/0>` e.g. representing categories but they
can also be stored in a flat structure. Each sequence folder contains the list of
sequence images, depth maps, foreground masks, and valid-depth masks
`images`, `depths`, `masks`, and `depth_masks` respectively. Furthermore,
`set_lists/` directories (with partitions or global) store json or sqlite files
`<subset_base_name_l>.<ext>`, each describing a certain sequence subset.
These subset path conventions are not hard-coded and an arbitrary relative path can be
specified by setting `self.subset_lists_path` to the relative path w.r.t.
dataset root.
Each `<subset_base_name_l>.json` file contains the following dictionary::
{
"train": [
(sequence_name: str, frame_number: int, image_path: str),
...
],
"val": [
(sequence_name: str, frame_number: int, image_path: str),
...
],
"test": [
(sequence_name: str, frame_number: int, image_path: str),
...
],
}
defining the list of frames (identified with their `sequence_name` and
`frame_number`) in the "train", "val", and "test" subsets of the dataset. In case of
SQLite format, `<subset_base_name_l>.sqlite` contains a table with the header::
| sequence_name | frame_number | image_path | subset |
Note that `frame_number` can be obtained only from the metadata and
does not necessarily correspond to the numeric suffix of the corresponding image
file name (e.g. a file `<partition_0>/<sequence_name_0>/images/frame00005.jpg` can
have its frame number set to `20`, not 5).
Each `<eval_batches_base_name_M>.json` file contains a list of evaluation examples
in the following form::
[
[ # batch 1
(sequence_name: str, frame_number: int, image_path: str),
...
],
[ # batch 2
(sequence_name: str, frame_number: int, image_path: str),
...
],
]
Note that the evaluation examples always come from the `"test"` subset of the dataset
(test frames can repeat across batches). The batches can contain a single element,
which is typical in case of regular radiance field fitting.
Args:
subset_lists_path: The relative path to the dataset subset definition.
For CO3D, these include e.g. "skateboard/set_lists/set_lists_manyview_dev_0.json".
By default (None), dataset is not partitioned to subsets (in that case, setting
`ignore_subsets` will speed up construction)
dataset_root: The root folder of the dataset.
metadata_basename: name of the SQL metadata file in dataset_root;
not expected to be changed by users
test_on_train: Construct validation and test datasets from
the training subset; note that in practice, in this
case all subset dataset objects will be same
only_test_set: Load only the test set. Incompatible with `test_on_train`.
ignore_subsets: Don't filter by subsets in the dataset; note that in this
case all subset datasets will be same
eval_batch_num_training_frames: Add a certain number of training frames to each
eval batch. Useful for evaluating models that require
source views as input (e.g. NeRF-WCE / PixelNeRF).
dataset_args: Specifies additional arguments to the
SqlIndexDataset constructor call.
path_manager_factory: (Optional) An object that generates an instance of
PathManager that can translate provided file paths.
path_manager_factory_class_type: The class type of `path_manager_factory`.
"""
category: Optional[str] = None
subset_list_name: Optional[str] = None # TODO: docs
# OR
subset_lists_path: Optional[str] = None
eval_batches_path: Optional[str] = None
dataset_root: str = _CO3D_SQL_DATASET_ROOT
metadata_basename: str = "metadata.sqlite"
test_on_train: bool = False
only_test_set: bool = False
ignore_subsets: bool = False
train_subsets: Tuple[str, ...] = ("train",)
val_subsets: Tuple[str, ...] = ("val",)
test_subsets: Tuple[str, ...] = ("test",)
eval_batch_num_training_frames: int = 0
# this is a mould that is never constructed, used to build self._dataset_map values
dataset_class_type: str = "SqlIndexDataset"
dataset: SqlIndexDataset
path_manager_factory: PathManagerFactory
path_manager_factory_class_type: str = "PathManagerFactory"
def __post_init__(self):
super().__init__()
run_auto_creation(self)
if self.only_test_set and self.test_on_train:
raise ValueError("Cannot have only_test_set and test_on_train")
if self.ignore_subsets and not self.only_test_set:
self.test_on_train = True # no point in loading same data 3 times
path_manager = self.path_manager_factory.get()
sqlite_metadata_file = os.path.join(self.dataset_root, self.metadata_basename)
sqlite_metadata_file = _local_path(path_manager, sqlite_metadata_file)
if not os.path.isfile(sqlite_metadata_file):
# The sqlite_metadata_file does not exist.
# Most probably the user has not specified the root folder.
raise ValueError(
f"Looking for frame annotations in {sqlite_metadata_file}."
+ " Please specify a correct dataset_root folder."
+ " Note: By default the root folder is taken from the"
+ " CO3D_SQL_DATASET_ROOT environment variable."
)
if self.subset_lists_path and self.subset_list_name:
raise ValueError(
"subset_lists_path and subset_list_name cannot both be set"
)
subset_lists_file = self._get_lists_file("set_lists")
# setup the common dataset arguments
common_dataset_kwargs = {
**getattr(self, f"dataset_{self.dataset_class_type}_args"),
"sqlite_metadata_file": sqlite_metadata_file,
"dataset_root": self.dataset_root,
"subset_lists_file": subset_lists_file,
"path_manager": path_manager,
}
if self.category:
logger.info(f"Forcing category filter in the datasets to {self.category}")
common_dataset_kwargs["pick_categories"] = self.category.split(",")
# get the used dataset type
dataset_type: Type[SqlIndexDataset] = registry.get(
SqlIndexDataset, self.dataset_class_type
)
expand_args_fields(dataset_type)
if subset_lists_file is not None and not os.path.isfile(subset_lists_file):
available_subsets = self._get_available_subsets(
OmegaConf.to_object(common_dataset_kwargs["pick_categories"])
)
msg = f"Cannot find subset list file {self.subset_lists_path}."
if available_subsets:
msg += f" Some of the available subsets: {str(available_subsets)}."
raise ValueError(msg)
train_dataset = None
val_dataset = None
if not self.only_test_set:
# load the training set
logger.debug("Constructing train dataset.")
train_dataset = dataset_type(
**common_dataset_kwargs, subsets=self._get_subsets(self.train_subsets)
)
logger.info(f"Train dataset: {str(train_dataset)}")
if self.test_on_train:
assert train_dataset is not None
val_dataset = test_dataset = train_dataset
else:
# load the val and test sets
if not self.only_test_set:
# NOTE: this is always loaded in JsonProviderV2
logger.debug("Extracting val dataset.")
val_dataset = dataset_type(
**common_dataset_kwargs, subsets=self._get_subsets(self.val_subsets)
)
logger.info(f"Val dataset: {str(val_dataset)}")
logger.debug("Extracting test dataset.")
eval_batches_file = self._get_lists_file("eval_batches")
del common_dataset_kwargs["eval_batches_file"]
test_dataset = dataset_type(
**common_dataset_kwargs,
subsets=self._get_subsets(self.test_subsets, True),
eval_batches_file=eval_batches_file,
)
logger.info(f"Test dataset: {str(test_dataset)}")
if (
eval_batches_file is not None
and self.eval_batch_num_training_frames > 0
):
self._extend_eval_batches(test_dataset)
self._dataset_map = DatasetMap(
train=train_dataset, val=val_dataset, test=test_dataset
)
def _get_subsets(self, subsets, is_eval: bool = False):
if self.ignore_subsets:
return None
if is_eval and self.eval_batch_num_training_frames > 0:
# we will need to have training frames for extended batches
return list(subsets) + list(self.train_subsets)
return subsets
def _extend_eval_batches(self, test_dataset: SqlIndexDataset) -> None:
rng = np.random.default_rng(seed=0)
eval_batches = test_dataset.get_eval_batches()
if eval_batches is None:
raise ValueError("Eval batches were not loaded!")
for batch in eval_batches:
sequence = batch[0][0]
seq_frames = list(
test_dataset.sequence_frames_in_order(sequence, self.train_subsets)
)
idx_to_add = rng.permutation(len(seq_frames))[
: self.eval_batch_num_training_frames
]
batch.extend((sequence, seq_frames[a][1]) for a in idx_to_add)
@classmethod
def dataset_tweak_args(cls, type, args: DictConfig) -> None:
"""
Called by get_default_args.
Certain fields are not exposed on each dataset class
but rather are controlled by this provider class.
"""
for key in _NEED_CONTROL:
del args[key]
def create_dataset(self):
# No `dataset` member of this class is created.
# The dataset(s) live in `self.get_dataset_map`.
pass
def get_dataset_map(self) -> DatasetMap:
return self._dataset_map # pyre-ignore [16]
def _get_available_subsets(self, categories: List[str]):
"""
Get the available subset names for a given category folder (if given) inside
a root dataset folder `dataset_root`.
"""
path_manager = self.path_manager_factory.get()
subsets: List[str] = []
for prefix in [""] + categories:
set_list_dir = os.path.join(self.dataset_root, prefix, "set_lists")
if not (
(path_manager is not None) and path_manager.isdir(set_list_dir)
) and not os.path.isdir(set_list_dir):
continue
set_list_files = (os.listdir if path_manager is None else path_manager.ls)(
set_list_dir
)
subsets.extend(os.path.join(prefix, "set_lists", f) for f in set_list_files)
return subsets
def _get_lists_file(self, flavor: str) -> Optional[str]:
if flavor == "eval_batches":
subset_lists_path = self.eval_batches_path
else:
subset_lists_path = self.subset_lists_path
if not subset_lists_path and not self.subset_list_name:
return None
category_elem = ""
if self.category and "," not in self.category:
# if multiple categories are given, looking for global set lists
category_elem = self.category
subset_lists_path = subset_lists_path or (
os.path.join(
category_elem, f"{flavor}", f"{flavor}_{self.subset_list_name}"
)
)
assert subset_lists_path
path_manager = self.path_manager_factory.get()
# try absolute path first
subset_lists_file = _get_local_path_check_extensions(
subset_lists_path, path_manager
)
if subset_lists_file:
return subset_lists_file
full_path = os.path.join(self.dataset_root, subset_lists_path)
subset_lists_file = _get_local_path_check_extensions(full_path, path_manager)
if not subset_lists_file:
raise FileNotFoundError(
f"Subset lists path given but not found: {full_path}"
)
return subset_lists_file
def _get_local_path_check_extensions(
path, path_manager, extensions=("", ".sqlite", ".json")
) -> Optional[str]:
for ext in extensions:
local = _local_path(path_manager, path + ext)
if os.path.isfile(local):
return local
return None
def _local_path(path_manager, path: str) -> str:
if path_manager is None:
return path
return path_manager.get_local_path(path)
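
The extension probing in `_get_local_path_check_extensions` can be illustrated with a small standalone sketch. It assumes plain local paths with no path manager, and `find_with_extensions` is a hypothetical name used only for illustration, not part of PyTorch3D.

```python
import os
import tempfile
from typing import Optional, Sequence


def find_with_extensions(
    path: str, extensions: Sequence[str] = ("", ".sqlite", ".json")
) -> Optional[str]:
    """Return the first existing file among path + ext, or None if none exists."""
    for ext in extensions:
        candidate = path + ext
        if os.path.isfile(candidate):
            return candidate
    return None


# Demo: only the ".json" variant exists, so the bare name resolves to it.
with tempfile.TemporaryDirectory() as root:
    target = os.path.join(root, "set_lists_demo.json")
    with open(target, "w") as fh:
        fh.write("{}")
    found = find_with_extensions(os.path.join(root, "set_lists_demo"))
    missing = find_with_extensions(os.path.join(root, "no_such_list"))
    found_matches = found == target
```

Because the empty extension is tried first, a path that already names an existing file is returned unchanged.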


@@ -0,0 +1,189 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Any, Dict, Optional, Tuple
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
DataLoaderMap,
SceneBatchSampler,
SequenceDataLoaderMapProvider,
)
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
from pytorch3d.implicitron.dataset.frame_data import FrameData
from pytorch3d.implicitron.tools.config import registry, run_auto_creation
from torch.utils.data import DataLoader
logger = logging.getLogger(__name__)
# TODO: we can merge it with SequenceDataLoaderMapProvider in PyTorch3D
# and support both eval_batches protocols
@registry.register
class TrainEvalDataLoaderMapProvider(SequenceDataLoaderMapProvider):
"""
Implementation of DataLoaderMapProviderBase that may use internal eval batches for
the test dataset. In particular, if `eval_batches_relpath` is set, it loads
eval batches from that json file, otherwise test set is treated in the same way as
train and val, i.e. the parameters `dataset_length_test` and `test_conditioning_type`
are respected.
If conditioning is not required, then the batch size should
be set as 1, and most of the fields do not matter.
If conditioning is required, each batch will contain one main
frame first to predict and the rest of the elements are for
conditioning.
If images_per_seq_options is left empty, the conditioning
frames are picked according to the conditioning type given.
This does not have regard to the order of frames in a
scene, or which frames belong to what scene.
If images_per_seq_options is given, then the conditioning types
must be SAME and the remaining fields are used.
Members:
batch_size: The size of the batch of the data loader.
num_workers: Number of data-loading threads in each data loader.
dataset_length_train: The number of batches in a training epoch. Or 0 to mean
an epoch is the length of the training set.
dataset_length_val: The number of batches in a validation epoch. Or 0 to mean
an epoch is the length of the validation set.
dataset_length_test: used if test_dataset.eval_batches is NOT set. The number of
batches in a testing epoch. Or 0 to mean an epoch is the length of the test
set.
images_per_seq_options: Possible numbers of frames sampled per sequence in a batch.
If a conditioning_type is KNOWN or TRAIN, then this must be left at its initial
value. Empty (the default) means that we are not careful about which frames
come from which scene.
sample_consecutive_frames: if True, will sample a contiguous interval of frames
in the sequence. It first sorts the frames by timestamps when available,
otherwise by frame numbers, finds the connected segments within the sequence
of sufficient length, then samples a random pivot element among them and
ideally uses it as a middle of the temporal window, shifting the borders
where necessary. This strategy mitigates the bias against shorter segments
and their boundaries.
consecutive_frames_max_gap: if a number > 0, then used to define the maximum
difference in frame_number of neighbouring frames when forming connected
segments; if both this and consecutive_frames_max_gap_seconds are 0s,
the whole sequence is considered a segment regardless of frame numbers.
consecutive_frames_max_gap_seconds: if a number > 0.0, then used to define the
maximum difference in frame_timestamp of neighbouring frames when forming
connected segments; if both this and consecutive_frames_max_gap are 0s,
the whole sequence is considered a segment regardless of frame timestamps.
"""
batch_size: int = 1
num_workers: int = 0
dataset_length_train: int = 0
dataset_length_val: int = 0
dataset_length_test: int = 0
images_per_seq_options: Tuple[int, ...] = ()
sample_consecutive_frames: bool = False
consecutive_frames_max_gap: int = 0
consecutive_frames_max_gap_seconds: float = 0.1
def __post_init__(self):
run_auto_creation(self)
def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
"""
Returns a collection of data loaders for a given collection of datasets.
"""
train = self._make_generic_data_loader(
datasets.train,
self.dataset_length_train,
datasets.train,
)
val = self._make_generic_data_loader(
datasets.val,
self.dataset_length_val,
datasets.train,
)
if datasets.test is not None and datasets.test.get_eval_batches() is not None:
test = self._make_eval_data_loader(datasets.test)
else:
test = self._make_generic_data_loader(
datasets.test,
self.dataset_length_test,
datasets.train,
)
return DataLoaderMap(train=train, val=val, test=test)
def _make_eval_data_loader(
self,
dataset: Optional[DatasetBase],
) -> Optional[DataLoader[FrameData]]:
if dataset is None:
return None
return DataLoader(
dataset,
batch_sampler=dataset.get_eval_batches(),
**self._get_data_loader_common_kwargs(dataset),
)
def _make_generic_data_loader(
self,
dataset: Optional[DatasetBase],
num_batches: int,
train_dataset: Optional[DatasetBase],
) -> Optional[DataLoader[FrameData]]:
"""
Returns the dataloader for a dataset.
Args:
dataset: the dataset
num_batches: possible ceiling on number of batches per epoch
train_dataset: the training dataset, used if conditioning_type==TRAIN
conditioning_type: source for padding of batches
"""
if dataset is None:
return None
data_loader_kwargs = self._get_data_loader_common_kwargs(dataset)
if len(self.images_per_seq_options) > 0:
# this is a typical few-view setup
# conditioning comes from the same subset since subsets are split by seqs
batch_sampler = SceneBatchSampler(
dataset,
self.batch_size,
num_batches=len(dataset) if num_batches <= 0 else num_batches,
images_per_seq_options=self.images_per_seq_options,
sample_consecutive_frames=self.sample_consecutive_frames,
consecutive_frames_max_gap=self.consecutive_frames_max_gap,
consecutive_frames_max_gap_seconds=self.consecutive_frames_max_gap_seconds,
)
return DataLoader(
dataset,
batch_sampler=batch_sampler,
**data_loader_kwargs,
)
if self.batch_size == 1:
# this is a typical many-view setup (without conditioning)
return self._simple_loader(dataset, num_batches, data_loader_kwargs)
# edge case: conditioning on train subset, typical for Nerformer-like many-view
# there is only one sequence in all datasets, so we condition on another subset
return self._train_loader(
dataset, train_dataset, num_batches, data_loader_kwargs
)
def _get_data_loader_common_kwargs(self, dataset: DatasetBase) -> Dict[str, Any]:
return {
"num_workers": self.num_workers,
"collate_fn": dataset.frame_data_type.collate,
}
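
The branching in `get_data_loader_map` and `_make_generic_data_loader` can be summarised with a minimal sketch. The function `choose_loader_kind` and the returned labels are hypothetical; they only mirror the order of the checks shown in the diff and do not build real data loaders.

```python
from typing import Sequence


def choose_loader_kind(
    has_eval_batches: bool,
    images_per_seq_options: Sequence[int],
    batch_size: int,
    is_test: bool = False,
) -> str:
    """Mirror the order of checks used to pick a data loader (illustrative labels)."""
    if is_test and has_eval_batches:
        # Test set with stored eval batches: use them as the batch sampler.
        return "eval_batches"
    if len(images_per_seq_options) > 0:
        # Few-view setup: conditioning frames come from the same subset.
        return "scene_batch_sampler"
    if batch_size == 1:
        # Many-view setup without conditioning.
        return "simple"
    # Edge case: condition on the train subset.
    return "train_conditioned"
```

For example, `choose_loader_kind(False, (), 1)` selects the simple loader, while a non-empty `images_per_seq_options` takes precedence over the batch size.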


@@ -204,7 +204,7 @@ def _dataclass_list_from_dict_list(dlist, typeannot):
# otherwise, we dispatch by the type of the provided annotation to convert to
if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
# For namedtuple, call the function recursively on the lists of corresponding keys
types = cls._field_types.values()
types = cls.__annotations__.values()
dlist_T = zip(*dlist)
res_T = [
_dataclass_list_from_dict_list(key_list, tp)
@@ -270,7 +270,7 @@ def _dataclass_from_dict(d, typeannot):
cls = get_origin(typeannot) or typeannot
if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
types = cls._field_types.values()
types = cls.__annotations__.values()
return cls(*[_dataclass_from_dict(v, tp) for v, tp in zip(d, types)])
elif issubclass(cls, (list, tuple)):
types = get_args(typeannot)
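
As a small illustration of the change above, the field types of a `typing.NamedTuple` class can be read from `__annotations__`. The `Point` class is a made-up example and does not come from the codebase.

```python
from typing import NamedTuple


class Point(NamedTuple):
    x: int
    y: float


# Field types in declaration order, as used when converting values field by field.
field_types = list(Point.__annotations__.values())

# Rebuild an instance by applying each field's type to the matching raw value.
rebuilt = Point(*[tp(v) for v, tp in zip(("3", "4.5"), field_types)])
```

The annotation order matches `Point._fields`, which is what makes zipping values with types safe here.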


@@ -38,8 +38,8 @@ class _Visualizer:
image_render: torch.Tensor
image_rgb_masked: torch.Tensor
depth_render: torch.Tensor
depth_map: torch.Tensor
depth_mask: torch.Tensor
depth_map: Optional[torch.Tensor]
depth_mask: Optional[torch.Tensor]
visdom_env: str = "eval_debug"
@@ -75,9 +75,11 @@ class _Visualizer:
viz = self._viz
viz.images(
torch.cat(
(
make_depth_image(self.depth_render, loss_mask_now),
make_depth_image(self.depth_map, loss_mask_now),
(make_depth_image(self.depth_render, loss_mask_now),)
+ (
(make_depth_image(self.depth_map, loss_mask_now),)
if self.depth_map is not None
else ()
),
dim=3,
),
@@ -91,12 +93,13 @@ class _Visualizer:
win="depth_abs" + name_postfix + "_mask",
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_mask"},
)
viz.images(
self.depth_mask,
env=self.visdom_env,
win="depth_abs" + name_postfix + "_maskd",
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_maskd"},
)
if self.depth_mask is not None:
viz.images(
self.depth_mask,
env=self.visdom_env,
win="depth_abs" + name_postfix + "_maskd",
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_maskd"},
)
# show the 3D plot
# pyre-fixme[9]: viewpoint_trivial has type `PerspectiveCameras`; used as
@@ -104,29 +107,30 @@ class _Visualizer:
viewpoint_trivial: PerspectiveCameras = PerspectiveCameras().to(
loss_mask_now.device
)
pcl_pred = get_rgbd_point_cloud(
viewpoint_trivial,
self.image_render,
self.depth_render,
# mask_crop,
torch.ones_like(self.depth_render),
# loss_mask_now,
)
pcl_gt = get_rgbd_point_cloud(
viewpoint_trivial,
self.image_rgb_masked,
self.depth_map,
# mask_crop,
torch.ones_like(self.depth_map),
# loss_mask_now,
)
_pcls = {
pn: p
for pn, p in zip(("pred_depth", "gt_depth"), (pcl_pred, pcl_gt))
if int(p.num_points_per_cloud()) > 0
"pred_depth": get_rgbd_point_cloud(
viewpoint_trivial,
self.image_render,
self.depth_render,
# mask_crop,
torch.ones_like(self.depth_render),
# loss_mask_now,
)
}
if self.depth_map is not None:
_pcls["gt_depth"] = get_rgbd_point_cloud(
viewpoint_trivial,
self.image_rgb_masked,
self.depth_map,
# mask_crop,
torch.ones_like(self.depth_map),
# loss_mask_now,
)
_pcls = {pn: p for pn, p in _pcls.items() if int(p.num_points_per_cloud()) > 0}
plotlyplot = plot_scene(
{f"pcl{name_postfix}": _pcls},
{f"pcl{name_postfix}": _pcls}, # pyre-ignore
camera_scale=1.0,
pointcloud_max_points=10000,
pointcloud_marker_size=1,
@@ -277,10 +281,10 @@ def eval_batch(
image_render=image_render,
image_rgb_masked=image_rgb_masked,
depth_render=cloned_render["depth_render"],
# pyre-fixme[6]: Expected `Tensor` for 4th param but got
# `Optional[torch.Tensor]`.
depth_map=frame_data.depth_map,
depth_mask=frame_data.depth_mask[:1],
depth_mask=frame_data.depth_mask[:1]
if frame_data.depth_mask is not None
else None,
visdom_env=visualize_visdom_env,
)
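
The visualizer change builds its list of depth images by concatenating tuples, adding the ground-truth entry only when it exists. A minimal sketch of that pattern follows, with strings standing in for tensors; `collect_present` is a hypothetical helper.

```python
from typing import Optional, Tuple


def collect_present(first: str, second: Optional[str]) -> Tuple[str, ...]:
    """Always keep `first`; append `second` only when it is not None."""
    return (first,) + ((second,) if second is not None else ())
```

The result always has at least one element, so a later concatenation over it never receives an empty sequence.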


@@ -57,6 +57,7 @@ class ImplicitronEvaluator(EvaluatorBase):
def __post_init__(self):
run_auto_creation(self)
# pyre-fixme[14]: `run` overrides method defined in `EvaluatorBase` inconsistently.
def run(
self,
model: ImplicitronModelBase,


@@ -360,7 +360,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
and source images, which will be used for intersecting with target rays.
fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch of
foreground masks.
mask_crop: A binary tensor of shape `(B, 1, H, W)` deonting valid
mask_crop: A binary tensor of shape `(B, 1, H, W)` denoting valid
regions in the input images (i.e. regions that do not correspond
to, e.g., zero-padding). When the `RaySampler`'s sampling mode is set to
"mask_sample", rays will be sampled in the non zero regions.


@@ -41,6 +41,8 @@ class ModelDBIR(ImplicitronModelBase):
bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
max_points: int = -1
# pyre-fixme[14]: `forward` overrides method defined in `ImplicitronModelBase`
# inconsistently.
def forward(
self,
*, # force keyword-only arguments


@@ -111,6 +111,8 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
"minimum": lambda curr, acc: torch.minimum(curr, acc),
}[self.weight_function_type]
# pyre-fixme[14]: `forward` overrides method defined in `RaymarcherBase`
# inconsistently.
def forward(
self,
rays_densities: torch.Tensor,


@@ -393,7 +393,7 @@ class _GLTFLoader:
attributes = primitive["attributes"]
vertex_colors = self._get_primitive_attribute(attributes, "COLOR_0", np.float32)
if vertex_colors is not None:
return TexturesVertex(torch.from_numpy(vertex_colors))
return TexturesVertex([torch.from_numpy(vertex_colors)])
vertex_texcoords_0 = self._get_primitive_attribute(
attributes, "TEXCOORD_0", np.float32
@@ -559,12 +559,26 @@ class _GLTFWriter:
meshes = defaultdict(list)
# pyre-fixme[6]: Incompatible parameter type
meshes["name"] = "Node-Mesh"
primitives = {
"attributes": {"POSITION": 0, "TEXCOORD_0": 2},
"indices": 1,
"material": 0, # default material
"mode": _PrimitiveMode.TRIANGLES,
}
if isinstance(self.mesh.textures, TexturesVertex):
primitives = {
"attributes": {"POSITION": 0, "COLOR_0": 2},
"indices": 1,
"mode": _PrimitiveMode.TRIANGLES,
}
elif isinstance(self.mesh.textures, TexturesUV):
primitives = {
"attributes": {"POSITION": 0, "TEXCOORD_0": 2},
"indices": 1,
"mode": _PrimitiveMode.TRIANGLES,
"material": 0,
}
else:
primitives = {
"attributes": {"POSITION": 0},
"indices": 1,
"mode": _PrimitiveMode.TRIANGLES,
}
meshes["primitives"].append(primitives)
self._json_data["meshes"].append(meshes)
@@ -610,6 +624,14 @@ class _GLTFWriter:
element_min = list(map(float, np.min(data, axis=0)))
element_max = list(map(float, np.max(data, axis=0)))
byte_per_element = 2 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]]
elif key == "texvertices":
component_type = _ComponentType.FLOAT
data = self.mesh.textures.verts_features_list()[0].cpu().numpy()
element_type = "VEC3"
buffer_view = 2
element_min = list(map(float, np.min(data, axis=0)))
element_max = list(map(float, np.max(data, axis=0)))
byte_per_element = 3 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]]
elif key == "indices":
component_type = _ComponentType.UNSIGNED_SHORT
data = (
@@ -646,8 +668,10 @@ class _GLTFWriter:
return (byte_length, data)
def _write_bufferview(self, key: str, **kwargs):
if key not in ["positions", "texcoords", "indices"]:
raise ValueError("key must be one of positions, texcoords or indices")
if key not in ["positions", "texcoords", "texvertices", "indices"]:
raise ValueError(
"key must be one of positions, texcoords, texvertices or indices"
)
bufferview = {
"name": "bufferView_%s" % key,
@@ -661,6 +685,10 @@ class _GLTFWriter:
byte_per_element = 2 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]]
target = _TargetType.ARRAY_BUFFER
bufferview["byteStride"] = int(byte_per_element)
elif key == "texvertices":
byte_per_element = 3 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]]
target = _TargetType.ELEMENT_ARRAY_BUFFER
bufferview["byteStride"] = int(byte_per_element)
elif key == "indices":
byte_per_element = (
3 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.UNSIGNED_SHORT]]
@@ -701,12 +729,15 @@ class _GLTFWriter:
pos_byte, pos_data = self._write_accessor_json("positions")
idx_byte, idx_data = self._write_accessor_json("indices")
include_textures = False
if (
self.mesh.textures is not None
and self.mesh.textures.verts_uvs_list()[0] is not None
):
tex_byte, tex_data = self._write_accessor_json("texcoords")
include_textures = True
if self.mesh.textures is not None:
if hasattr(self.mesh.textures, "verts_features_list"):
tex_byte, tex_data = self._write_accessor_json("texvertices")
include_textures = True
texcoords = False
elif self.mesh.textures.verts_uvs_list()[0] is not None:
tex_byte, tex_data = self._write_accessor_json("texcoords")
include_textures = True
texcoords = True
# bufferViews for positions, texture coords and indices
byte_offset = 0
@@ -717,17 +748,19 @@ class _GLTFWriter:
byte_offset += idx_byte
if include_textures:
self._write_bufferview(
"texcoords", byte_length=tex_byte, offset=byte_offset
)
if texcoords:
self._write_bufferview(
"texcoords", byte_length=tex_byte, offset=byte_offset
)
else:
self._write_bufferview(
"texvertices", byte_length=tex_byte, offset=byte_offset
)
byte_offset += tex_byte
# image bufferView
include_image = False
if (
self.mesh.textures is not None
and self.mesh.textures.maps_list()[0] is not None
):
if self.mesh.textures is not None and hasattr(self.mesh.textures, "maps_list"):
include_image = True
image_byte, image_data = self._write_image_buffer(offset=byte_offset)
byte_offset += image_byte
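
The three primitive layouts chosen by the writer can be sketched as one function. This is an illustration only: `build_primitive` and the string `texture_kind` are hypothetical stand-ins for the `isinstance` checks on `TexturesVertex` and `TexturesUV`, and the accessor indices are copied from the diff.

```python
from typing import Any, Dict

TRIANGLES = 4  # glTF primitive mode value for a triangle list


def build_primitive(texture_kind: str) -> Dict[str, Any]:
    """Return a glTF primitive dict for 'vertex', 'uv', or any other texture kind."""
    primitive: Dict[str, Any] = {
        "attributes": {"POSITION": 0},
        "indices": 1,
        "mode": TRIANGLES,
    }
    if texture_kind == "vertex":
        # Per-vertex colors go in COLOR_0; no material is referenced.
        primitive["attributes"]["COLOR_0"] = 2
    elif texture_kind == "uv":
        # UV coordinates go in TEXCOORD_0 and need the default material.
        primitive["attributes"]["TEXCOORD_0"] = 2
        primitive["material"] = 0
    return primitive
```

Only the UV case references a material, since vertex colors are carried directly as an attribute.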


@@ -684,6 +684,8 @@ def save_obj(
decimal_places: Optional[int] = None,
path_manager: Optional[PathManager] = None,
*,
normals: Optional[torch.Tensor] = None,
faces_normals_idx: Optional[torch.Tensor] = None,
verts_uvs: Optional[torch.Tensor] = None,
faces_uvs: Optional[torch.Tensor] = None,
texture_map: Optional[torch.Tensor] = None,
@@ -698,6 +700,10 @@ def save_obj(
decimal_places: Number of decimal places for saving.
path_manager: Optional PathManager for interpreting f if
it is a str.
normals: FloatTensor of shape (V, 3) giving normals for faces_normals_idx
to index into.
faces_normals_idx: LongTensor of shape (F, 3) giving the index into
normals for each vertex in the face.
verts_uvs: FloatTensor of shape (V, 2) giving the uv coordinate per vertex.
faces_uvs: LongTensor of shape (F, 3) giving the index into verts_uvs for
each vertex in the face.
@@ -713,6 +719,22 @@ def save_obj(
message = "'faces' should either be empty or of shape (num_faces, 3)."
raise ValueError(message)
if (normals is None) != (faces_normals_idx is None):
message = "'normals' and 'faces_normals_idx' must both be None or neither."
raise ValueError(message)
if faces_normals_idx is not None and (
faces_normals_idx.dim() != 2 or faces_normals_idx.size(1) != 3
):
message = (
"'faces_normals_idx' should either be empty or of shape (num_faces, 3)."
)
raise ValueError(message)
if normals is not None and (normals.dim() != 2 or normals.size(1) != 3):
message = "'normals' should either be empty or of shape (num_verts, 3)."
raise ValueError(message)
if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3):
message = "'faces_uvs' should either be empty or of shape (num_faces, 3)."
raise ValueError(message)
@@ -742,9 +764,12 @@ def save_obj(
verts,
faces,
decimal_places,
normals=normals,
faces_normals_idx=faces_normals_idx,
verts_uvs=verts_uvs,
faces_uvs=faces_uvs,
save_texture=save_texture,
save_normals=normals is not None,
)
# Save the .mtl and .png files associated with the texture
@@ -777,9 +802,12 @@ def _save(
faces,
decimal_places: Optional[int] = None,
*,
normals: Optional[torch.Tensor] = None,
faces_normals_idx: Optional[torch.Tensor] = None,
verts_uvs: Optional[torch.Tensor] = None,
faces_uvs: Optional[torch.Tensor] = None,
save_texture: bool = False,
save_normals: bool = False,
) -> None:
if len(verts) and (verts.dim() != 2 or verts.size(1) != 3):
@@ -798,18 +826,26 @@ def _save(
lines = ""
if len(verts):
if decimal_places is None:
float_str = "%f"
else:
float_str = "%" + ".%df" % decimal_places
if decimal_places is None:
float_str = "%f"
else:
float_str = "%" + ".%df" % decimal_places
if len(verts):
V, D = verts.shape
for i in range(V):
vert = [float_str % verts[i, j] for j in range(D)]
lines += "v %s\n" % " ".join(vert)
if save_normals:
assert normals is not None
assert faces_normals_idx is not None
lines += _write_normals(normals, faces_normals_idx, float_str)
if save_texture:
assert faces_uvs is not None
assert verts_uvs is not None
if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3):
message = "'faces_uvs' should either be empty or of shape (num_faces, 3)."
raise ValueError(message)
@@ -818,7 +854,6 @@ def _save(
message = "'verts_uvs' should either be empty or of shape (num_verts, 2)."
raise ValueError(message)
# pyre-fixme[16] # undefined attribute cpu
verts_uvs, faces_uvs = verts_uvs.cpu(), faces_uvs.cpu()
# Save verts uvs after verts
@@ -828,25 +863,77 @@ def _save(
uv = [float_str % verts_uvs[i, j] for j in range(uD)]
lines += "vt %s\n" % " ".join(uv)
f.write(lines)
if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0):
warnings.warn("Faces have invalid indices")
if len(faces):
F, P = faces.shape
for i in range(F):
if save_texture:
# Format faces as {verts_idx}/{verts_uvs_idx}
_write_faces(
f,
faces,
faces_uvs if save_texture else None,
faces_normals_idx if save_normals else None,
)
def _write_normals(
normals: torch.Tensor, faces_normals_idx: torch.Tensor, float_str: str
) -> str:
if faces_normals_idx.dim() != 2 or faces_normals_idx.size(1) != 3:
message = (
"'faces_normals_idx' should either be empty or of shape (num_faces, 3)."
)
raise ValueError(message)
if normals.dim() != 2 or normals.size(1) != 3:
message = "'normals' should either be empty or of shape (num_verts, 3)."
raise ValueError(message)
normals, faces_normals_idx = normals.cpu(), faces_normals_idx.cpu()
lines = []
V, D = normals.shape
for i in range(V):
normal = [float_str % normals[i, j] for j in range(D)]
lines.append("vn %s\n" % " ".join(normal))
return "".join(lines)
def _write_faces(
f,
faces: torch.Tensor,
faces_uvs: Optional[torch.Tensor],
faces_normals_idx: Optional[torch.Tensor],
) -> None:
F, P = faces.shape
for i in range(F):
if faces_normals_idx is not None:
if faces_uvs is not None:
# Format faces as {verts_idx}/{verts_uvs_idx}/{verts_normals_idx}
face = [
"%d/%d" % (faces[i, j] + 1, faces_uvs[i, j] + 1) for j in range(P)
"%d/%d/%d"
% (
faces[i, j] + 1,
faces_uvs[i, j] + 1,
faces_normals_idx[i, j] + 1,
)
for j in range(P)
]
else:
face = ["%d" % (faces[i, j] + 1) for j in range(P)]
# Format faces as {verts_idx}//{verts_normals_idx}
face = [
"%d//%d" % (faces[i, j] + 1, faces_normals_idx[i, j] + 1)
for j in range(P)
]
elif faces_uvs is not None:
# Format faces as {verts_idx}/{verts_uvs_idx}
face = ["%d/%d" % (faces[i, j] + 1, faces_uvs[i, j] + 1) for j in range(P)]
else:
face = ["%d" % (faces[i, j] + 1) for j in range(P)]
if i + 1 < F:
lines += "f %s\n" % " ".join(face)
elif i + 1 == F:
# No newline at the end of the file.
lines += "f %s" % " ".join(face)
f.write(lines)
if i + 1 < F:
f.write("f %s\n" % " ".join(face))
else:
# No newline at the end of the file.
f.write("f %s" % " ".join(face))
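
The four face formats handled by `_write_faces` can be shown with a standalone sketch that works on plain integer sequences instead of tensors. `format_face` is a hypothetical helper; like the code above, it converts 0-based indices to the 1-based indices used in .obj files.

```python
from typing import List, Optional, Sequence


def format_face(
    vert_idx: Sequence[int],
    uv_idx: Optional[Sequence[int]] = None,
    normal_idx: Optional[Sequence[int]] = None,
) -> str:
    """Format one face line from 0-based indices, choosing v, v/vt, v//vn or v/vt/vn."""
    parts: List[str] = []
    for j, v in enumerate(vert_idx):
        if normal_idx is not None and uv_idx is not None:
            parts.append("%d/%d/%d" % (v + 1, uv_idx[j] + 1, normal_idx[j] + 1))
        elif normal_idx is not None:
            # Two slashes: the texture-coordinate slot is left empty.
            parts.append("%d//%d" % (v + 1, normal_idx[j] + 1))
        elif uv_idx is not None:
            parts.append("%d/%d" % (v + 1, uv_idx[j] + 1))
        else:
            parts.append("%d" % (v + 1))
    return "f " + " ".join(parts)
```

For instance, `format_face([0, 1, 2], normal_idx=[0, 0, 0])` yields a line in which every vertex refers to the first normal.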


@@ -375,14 +375,14 @@ class CamerasBase(TensorProperties):
raise NotImplementedError()
def get_znear(self):
return self.znear if hasattr(self, "znear") else None
return getattr(self, "znear", None)
def get_image_size(self):
"""
Returns the image size, if provided, expected in the form of (height, width)
The image size is used for conversion of projected points to screen coordinates.
"""
return self.image_size if hasattr(self, "image_size") else None
return getattr(self, "image_size", None)
def __getitem__(
self, index: Union[int, List[int], torch.BoolTensor, torch.LongTensor]
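
The two accessor styles in this hunk behave the same for a missing or present attribute. A minimal sketch with a made-up `Camera` class (not the real `CamerasBase`) shows the equivalence:

```python
class Camera:
    """Toy class: `znear` is set only when a value is passed in."""

    def __init__(self, znear=None):
        if znear is not None:
            self.znear = znear

    def get_znear_old(self):
        # Previous form: explicit hasattr check.
        return self.znear if hasattr(self, "znear") else None

    def get_znear_new(self):
        # New form: getattr with a default value.
        return getattr(self, "znear", None)


with_znear = Camera(znear=0.1)
without_znear = Camera()
```

The `getattr` form performs a single attribute lookup and reads more directly.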


@@ -109,17 +109,11 @@ class MultinomialRaysampler(torch.nn.Module):
self._stratified_sampling = stratified_sampling
# get the initial grid of image xy coords
_xy_grid = torch.stack(
tuple(
reversed(
meshgrid_ij(
torch.linspace(min_y, max_y, image_height, dtype=torch.float32),
torch.linspace(min_x, max_x, image_width, dtype=torch.float32),
)
)
),
dim=-1,
y, x = meshgrid_ij(
torch.linspace(min_y, max_y, image_height, dtype=torch.float32),
torch.linspace(min_x, max_x, image_width, dtype=torch.float32),
)
_xy_grid = torch.stack([x, y], dim=-1)
self.register_buffer("_xy_grid", _xy_grid, persistent=False)
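
The refactoring above replaces `reversed(...)` with explicit unpacking into `y, x`. The sketch below checks that both forms give the same (x, y) ordering, using plain nested lists as a stand-in for tensors; `meshgrid_ij` and `stack_last` here are simplified, hypothetical re-implementations, not the PyTorch3D or PyTorch functions.

```python
from typing import List, Sequence, Tuple

Grid = List[List[float]]


def meshgrid_ij(ys: Sequence[float], xs: Sequence[float]) -> Tuple[Grid, Grid]:
    """List-based 'ij' meshgrid: both outputs have shape (len(ys), len(xs))."""
    grid_y = [[y for _ in xs] for y in ys]
    grid_x = [[x for x in xs] for _ in ys]
    return grid_y, grid_x


def stack_last(a: Grid, b: Grid) -> List[List[Tuple[float, float]]]:
    """Pair up elements of two grids, like stacking along a new last dimension."""
    return [list(zip(row_a, row_b)) for row_a, row_b in zip(a, b)]


ys, xs = [0.0, 1.0], [10.0, 20.0, 30.0]

# Old form: reverse the (y, x) pair before stacking.
old_grid = stack_last(*tuple(reversed(meshgrid_ij(ys, xs))))

# New form: unpack explicitly and stack as [x, y].
y, x = meshgrid_ij(ys, xs)
new_grid = stack_last(x, y)
```

Each grid cell holds an (x, y) pair, with rows indexed by y and columns by x.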


@@ -491,6 +491,8 @@ class TexturesAtlas(TexturesBase):
new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"]
return new_tex
# pyre-fixme[14]: `sample_textures` overrides method defined in `TexturesBase`
# inconsistently.
def sample_textures(self, fragments, **kwargs) -> torch.Tensor:
"""
This is similar to a nearest neighbor sampling and involves a
@@ -927,6 +929,8 @@ class TexturesUV(TexturesBase):
new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"]
return new_tex
# pyre-fixme[14]: `sample_textures` overrides method defined in `TexturesBase`
# inconsistently.
def sample_textures(self, fragments, **kwargs) -> torch.Tensor:
"""
Interpolate a 2D texture map using uv vertex texture coordinates for each
@@ -1450,6 +1454,8 @@ class TexturesVertex(TexturesBase):
new_tex._num_verts_per_mesh = new_props["_num_verts_per_mesh"]
return new_tex
# pyre-fixme[14]: `sample_textures` overrides method defined in `TexturesBase`
# inconsistently.
def sample_textures(self, fragments, faces_packed=None) -> torch.Tensor:
"""
Determine the color for each rasterized face. Interpolate the colors for


@@ -164,6 +164,7 @@ setup(
"tqdm>4.29.0",
"matplotlib",
"accelerate",
"sqlalchemy>=2.0",
],
},
entry_points={

File diff suppressed because one or more lines are too long


@@ -0,0 +1,246 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import unittest
import torch
from pytorch3d.implicitron.dataset.data_loader_map_provider import ( # noqa
SequenceDataLoaderMapProvider,
SimpleDataLoaderMapProvider,
)
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset # noqa
from pytorch3d.implicitron.dataset.sql_dataset_provider import ( # noqa
SqlIndexDatasetMapProvider,
)
from pytorch3d.implicitron.dataset.train_eval_data_loader_provider import (
TrainEvalDataLoaderMapProvider,
)
from pytorch3d.implicitron.tools.config import get_default_args
logger = logging.getLogger("pytorch3d.implicitron.dataset.sql_dataset")
sh = logging.StreamHandler()
logger.addHandler(sh)
logger.setLevel(logging.DEBUG)
_CO3D_SQL_DATASET_ROOT: str = os.getenv("CO3D_SQL_DATASET_ROOT", "")
@unittest.skipUnless(_CO3D_SQL_DATASET_ROOT, "Run only if CO3D is available")
class TestCo3dSqlDataSource(unittest.TestCase):
def test_no_subsets(self):
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
provider_args.ignore_subsets = True
dataset_args = provider_args.dataset_SqlIndexDataset_args
dataset_args.pick_categories = ["skateboard"]
dataset_args.limit_sequences_to = 1
data_source = ImplicitronDataSource(**args)
self.assertIsInstance(
data_source.data_loader_map_provider, TrainEvalDataLoaderMapProvider
)
_, data_loaders = data_source.get_datasets_and_dataloaders()
self.assertEqual(len(data_loaders.train), 202)
for frame in data_loaders.train:
self.assertIsNone(frame.frame_type)
self.assertEqual(frame.image_rgb.shape[-1], 800) # check loading blobs
break
def test_subsets(self):
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
provider_args.subset_lists_path = (
"skateboard/set_lists/set_lists_manyview_dev_0.json"
)
# this will naturally limit to one sequence (no need to limit by cat/sequence)
dataset_args = provider_args.dataset_SqlIndexDataset_args
dataset_args.remove_empty_masks = True
for sampler_type in [
"SimpleDataLoaderMapProvider",
"SequenceDataLoaderMapProvider",
"TrainEvalDataLoaderMapProvider",
]:
args.data_loader_map_provider_class_type = sampler_type
data_source = ImplicitronDataSource(**args)
_, data_loaders = data_source.get_datasets_and_dataloaders()
self.assertEqual(len(data_loaders.train), 102)
self.assertEqual(len(data_loaders.val), 100)
self.assertEqual(len(data_loaders.test), 100)
for split in ["train", "val", "test"]:
for frame in data_loaders[split]:
self.assertEqual(frame.frame_type, [split])
# check loading blobs
self.assertEqual(frame.image_rgb.shape[-1], 800)
break
def test_sql_subsets(self):
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
provider_args.subset_lists_path = "set_lists/set_lists_manyview_dev_0.sqlite"
dataset_args = provider_args.dataset_SqlIndexDataset_args
dataset_args.remove_empty_masks = True
dataset_args.pick_categories = ["skateboard"]
for sampler_type in [
"SimpleDataLoaderMapProvider",
"SequenceDataLoaderMapProvider",
"TrainEvalDataLoaderMapProvider",
]:
args.data_loader_map_provider_class_type = sampler_type
data_source = ImplicitronDataSource(**args)
_, data_loaders = data_source.get_datasets_and_dataloaders()
self.assertEqual(len(data_loaders.train), 102)
self.assertEqual(len(data_loaders.val), 100)
self.assertEqual(len(data_loaders.test), 100)
for split in ["train", "val", "test"]:
for frame in data_loaders[split]:
self.assertEqual(frame.frame_type, [split])
self.assertEqual(
frame.image_rgb.shape[-1], 800
) # check loading blobs
break
@unittest.skip("It takes 75 seconds; skipping by default")
def test_huge_subsets(self):
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
provider_args.subset_lists_path = "set_lists/set_lists_fewview_dev.sqlite"
dataset_args = provider_args.dataset_SqlIndexDataset_args
dataset_args.remove_empty_masks = True
data_source = ImplicitronDataSource(**args)
_, data_loaders = data_source.get_datasets_and_dataloaders()
self.assertEqual(len(data_loaders.train), 3158974)
self.assertEqual(len(data_loaders.val), 518417)
self.assertEqual(len(data_loaders.test), 518417)
for split in ["train", "val", "test"]:
for frame in data_loaders[split]:
self.assertEqual(frame.frame_type, [split])
self.assertEqual(frame.image_rgb.shape[-1], 800) # check loading blobs
break
def test_broken_subsets(self):
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
provider_args.subset_lists_path = "et_non_est"
provider_args.dataset_SqlIndexDataset_args.pick_categories = ["skateboard"]
with self.assertRaises(FileNotFoundError) as err:
ImplicitronDataSource(**args)
# check the hint text
self.assertIn("Subset lists path given but not found", str(err.exception))
def test_eval_batches(self):
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
provider_args.subset_lists_path = "set_lists/set_lists_manyview_dev_0.sqlite"
provider_args.eval_batches_path = (
"skateboard/eval_batches/eval_batches_manyview_dev_0.json"
)
dataset_args = provider_args.dataset_SqlIndexDataset_args
dataset_args.remove_empty_masks = True
dataset_args.pick_categories = ["skateboard"]
data_source = ImplicitronDataSource(**args)
_, data_loaders = data_source.get_datasets_and_dataloaders()
self.assertEqual(len(data_loaders.train), 102)
self.assertEqual(len(data_loaders.val), 100)
self.assertEqual(len(data_loaders.test), 50)
for split in ["train", "val", "test"]:
for frame in data_loaders[split]:
self.assertEqual(frame.frame_type, [split])
self.assertEqual(frame.image_rgb.shape[-1], 800) # check loading blobs
break
def test_eval_batches_from_subset_list_name(self):
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
provider_args.subset_list_name = "manyview_dev_0"
provider_args.category = "skateboard"
dataset_args = provider_args.dataset_SqlIndexDataset_args
dataset_args.remove_empty_masks = True
data_source = ImplicitronDataSource(**args)
dataset, data_loaders = data_source.get_datasets_and_dataloaders()
self.assertListEqual(list(dataset.train.pick_categories), ["skateboard"])
self.assertEqual(len(data_loaders.train), 102)
self.assertEqual(len(data_loaders.val), 100)
self.assertEqual(len(data_loaders.test), 50)
for split in ["train", "val", "test"]:
for frame in data_loaders[split]:
self.assertEqual(frame.frame_type, [split])
self.assertEqual(frame.image_rgb.shape[-1], 800) # check loading blobs
break
def test_frame_access(self):
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
provider_args.subset_lists_path = "set_lists/set_lists_manyview_dev_0.sqlite"
dataset_args = provider_args.dataset_SqlIndexDataset_args
dataset_args.remove_empty_masks = True
dataset_args.pick_categories = ["skateboard"]
frame_builder_args = dataset_args.frame_data_builder_FrameDataBuilder_args
frame_builder_args.load_point_clouds = True
frame_builder_args.box_crop = False # required for .meta
data_source = ImplicitronDataSource(**args)
dataset_map, _ = data_source.get_datasets_and_dataloaders()
dataset = dataset_map["train"]
for idx in [10, ("245_26182_52130", 22)]:
example_meta = dataset.meta[idx]
example = dataset[idx]
self.assertIsNone(example_meta.image_rgb)
self.assertIsNone(example_meta.fg_probability)
self.assertIsNone(example_meta.depth_map)
self.assertIsNone(example_meta.sequence_point_cloud)
self.assertIsNotNone(example_meta.camera)
self.assertIsNotNone(example.image_rgb)
self.assertIsNotNone(example.fg_probability)
self.assertIsNotNone(example.depth_map)
self.assertIsNotNone(example.sequence_point_cloud)
self.assertIsNotNone(example.camera)
self.assertEqual(example_meta.sequence_name, example.sequence_name)
self.assertEqual(example_meta.frame_number, example.frame_number)
self.assertEqual(example_meta.frame_timestamp, example.frame_timestamp)
self.assertEqual(example_meta.sequence_category, example.sequence_category)
torch.testing.assert_close(example_meta.camera.R, example.camera.R)
torch.testing.assert_close(example_meta.camera.T, example.camera.T)
torch.testing.assert_close(
example_meta.camera.focal_length, example.camera.focal_length
)
torch.testing.assert_close(
example_meta.camera.principal_point, example.camera.principal_point
)
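
The `test_frame_access` case above expects `dataset.meta[idx]` to return the same metadata and camera as `dataset[idx]` while leaving the heavy fields (`image_rgb`, `depth_map`, and so on) as `None`. A minimal sketch of that accessor pattern follows. `ToyDataset`, `ToyFrame` and `_MetaView` are hypothetical names used for illustration only; this is not the `SqlIndexDataset` implementation.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

Index = Union[int, Tuple[str, int]]


@dataclass
class ToyFrame:
    sequence_name: str
    frame_number: int
    image_rgb: Optional[List[float]] = None  # stands in for a heavy blob


class ToyDataset:
    """Hypothetical dataset supporting int and (sequence, frame) indexing."""

    def __init__(self, rows: List[Tuple[str, int]]) -> None:
        self._rows = rows
        self._by_key: Dict[Tuple[str, int], int] = {
            key: i for i, key in enumerate(rows)
        }

    def __len__(self) -> int:
        return len(self._rows)

    def _build(self, idx: Index, load_blobs: bool) -> ToyFrame:
        if isinstance(idx, tuple):
            if idx not in self._by_key:
                raise IndexError(f"no frame {idx}")
            idx = self._by_key[idx]
        seq, num = self._rows[idx]  # raises IndexError when out of range
        blob = [0.0, 0.5, 1.0] if load_blobs else None
        return ToyFrame(seq, num, blob)

    def __getitem__(self, idx: Index) -> ToyFrame:
        return self._build(idx, load_blobs=True)

    @property
    def meta(self) -> "_MetaView":
        # Same indexing interface, but blobs are never loaded.
        return _MetaView(self)


class _MetaView:
    def __init__(self, dataset: ToyDataset) -> None:
        self._dataset = dataset

    def __getitem__(self, idx: Index) -> ToyFrame:
        return self._dataset._build(idx, load_blobs=False)
```

Routing both accessors through one builder keeps the metadata identical by construction, which is the property the assertions above check.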


@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import os
import unittest
import unittest.mock
@@ -18,6 +19,7 @@ from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.tools.config import get_default_args
from tests.common_testing import get_tests_dir
from tests.implicitron.common_resources import get_skateboard_data
DATA_DIR = get_tests_dir() / "implicitron/data"
DEBUG: bool = False
@@ -28,6 +30,12 @@ class TestDataSource(unittest.TestCase):
self.maxDiff = None
torch.manual_seed(42)
stack = contextlib.ExitStack()
self.dataset_root, self.path_manager = stack.enter_context(
get_skateboard_data()
)
self.addCleanup(stack.close)
def _test_omegaconf_generic_failure(self):
# OmegaConf possible bug - this is why we need _GenericWorkaround
from dataclasses import dataclass
@@ -56,12 +64,14 @@ class TestDataSource(unittest.TestCase):
get_default_args(JsonIndexDataset)
def test_one(self):
-with unittest.mock.patch.dict(os.environ, {"CO3D_DATASET_ROOT": ""}):
-cfg = get_default_args(ImplicitronDataSource)
-yaml = OmegaConf.to_yaml(cfg, sort_keys=False)
-if DEBUG:
-(DATA_DIR / "data_source.yaml").write_text(yaml)
-self.assertEqual(yaml, (DATA_DIR / "data_source.yaml").read_text())
+cfg = get_default_args(ImplicitronDataSource)
+# making the test invariant to env variables
+cfg.dataset_map_provider_JsonIndexDatasetMapProvider_args.dataset_root = ""
+cfg.dataset_map_provider_JsonIndexDatasetMapProviderV2_args.dataset_root = ""
+yaml = OmegaConf.to_yaml(cfg, sort_keys=False)
+if DEBUG:
+(DATA_DIR / "data_source.yaml").write_text(yaml)
+self.assertEqual(yaml, (DATA_DIR / "data_source.yaml").read_text())
def test_default(self):
if os.environ.get("INSIDE_RE_WORKER") is not None:
@@ -73,7 +83,7 @@ class TestDataSource(unittest.TestCase):
dataset_args.test_restrict_sequence_id = 0
dataset_args.n_frames_per_sequence = -1
-dataset_args.dataset_root = "manifold://co3d/tree/extracted"
+dataset_args.dataset_root = self.dataset_root
data_source = ImplicitronDataSource(**args)
self.assertIsInstance(
@@ -96,7 +106,7 @@ class TestDataSource(unittest.TestCase):
dataset_args.test_restrict_sequence_id = 0
dataset_args.n_frames_per_sequence = -1
-dataset_args.dataset_root = "manifold://co3d/tree/extracted"
+dataset_args.dataset_root = self.dataset_root
data_source = ImplicitronDataSource(**args)
self.assertIsInstance(
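
The `setUp` change in this file enters a context manager through `contextlib.ExitStack` and registers `stack.close` with `addCleanup`. A minimal, self-contained sketch of that pattern is below. `temporary_data_root` is a hypothetical stand-in for a resource helper such as `get_skateboard_data`, whose real behaviour is not shown in this diff.

```python
import contextlib
import os
import tempfile
import unittest


@contextlib.contextmanager
def temporary_data_root():
    """Hypothetical resource helper: yields a directory that exists only inside the block."""
    with tempfile.TemporaryDirectory() as root:
        yield root


class ExampleTest(unittest.TestCase):
    def setUp(self):
        stack = contextlib.ExitStack()
        self.dataset_root = stack.enter_context(temporary_data_root())
        # Cleanups registered here still run if a later line of setUp raises;
        # tearDown would be skipped in that case.
        self.addCleanup(stack.close)

    def test_root_exists(self):
        self.assertTrue(os.path.isdir(self.dataset_root))


def run_example() -> unittest.TestResult:
    """Run the example test case and return its result object."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(ExampleTest)
    return unittest.TextTestRunner(verbosity=0).run(suite)
```

Using `ExitStack` lets a resource be held open for the whole test without wrapping the test body in a `with` block.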


@@ -26,6 +26,8 @@ from tests.common_testing import interactive_testing_requested
from .common_resources import get_skateboard_data
VISDOM_PORT = int(os.environ.get("VISDOM_PORT", 8097))
class TestDatasetVisualize(unittest.TestCase):
def setUp(self):
@@ -77,7 +79,7 @@ class TestDatasetVisualize(unittest.TestCase):
for k, dataset in self.datasets.items()
}
)
-self.visdom = Visdom()
+self.visdom = Visdom(port=VISDOM_PORT)
if not self.visdom.check_connection():
print("Visdom server not running! Disabling visdom visualizations.")
self.visdom = None
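
The `VISDOM_PORT` constant above reads an optional port override from the environment, falling back to 8097. A small sketch of the same lookup as a testable function, assuming the hypothetical helper name `visdom_port_from_env`:

```python
import os
from typing import Mapping, Optional

DEFAULT_VISDOM_PORT = 8097  # fallback value used in the diff above


def visdom_port_from_env(environ: Optional[Mapping[str, str]] = None) -> int:
    """Return the port from VISDOM_PORT if set, else the default.

    Environment values are strings, so the result is converted with int();
    a non-numeric value raises ValueError.
    """
    environ = os.environ if environ is None else environ
    return int(environ.get("VISDOM_PORT", DEFAULT_VISDOM_PORT))
```

Passing the mapping explicitly makes the lookup easy to test without touching the real process environment.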


@@ -0,0 +1,230 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import dataclasses
import logging
import os
import tempfile
import unittest
from typing import ClassVar, Optional, Type
import pandas as pd
import pkg_resources
import sqlalchemy as sa
from pytorch3d.implicitron.dataset import types
from pytorch3d.implicitron.dataset.frame_data import FrameData, GenericFrameDataBuilder
from pytorch3d.implicitron.dataset.orm_types import (
SqlFrameAnnotation,
SqlSequenceAnnotation,
)
from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset
from pytorch3d.implicitron.dataset.utils import GenericWorkaround
from pytorch3d.implicitron.tools.config import registry
from sqlalchemy.orm import composite, Mapped, mapped_column, Session
NO_BLOBS_KWARGS = {
"dataset_root": "",
"load_images": False,
"load_depths": False,
"load_masks": False,
"load_depth_masks": False,
"box_crop": False,
}
DATASET_ROOT = pkg_resources.resource_filename(__name__, "data/sql_dataset")
METADATA_FILE = os.path.join(DATASET_ROOT, "sql_dataset_100.sqlite")
logger = logging.getLogger("pytorch3d.implicitron.dataset.sql_dataset")
sh = logging.StreamHandler()
logger.addHandler(sh)
logger.setLevel(logging.DEBUG)
@dataclasses.dataclass
class MagneticFieldAnnotation:
path: str
average_flux_density: Optional[float] = None
class ExtendedSqlFrameAnnotation(SqlFrameAnnotation):
num_dogs: Mapped[Optional[int]] = mapped_column(default=None)
magnetic_field: Mapped[MagneticFieldAnnotation] = composite(
mapped_column("_magnetic_field_path", nullable=True),
mapped_column("_magnetic_field_average_flux_density", nullable=True),
default_factory=lambda: None,
)
class ExtendedSqlIndexDataset(SqlIndexDataset):
frame_annotations_type: ClassVar[
Type[SqlFrameAnnotation]
] = ExtendedSqlFrameAnnotation
class CanineFrameData(FrameData):
num_dogs: Optional[int] = None
magnetic_field_average_flux_density: Optional[float] = None
@registry.register
class CanineFrameDataBuilder(
GenericWorkaround, GenericFrameDataBuilder[CanineFrameData]
):
"""
A concrete class to build an extended FrameData object
"""
frame_data_type: ClassVar[Type[FrameData]] = CanineFrameData
def build(
self,
frame_annotation: ExtendedSqlFrameAnnotation,
sequence_annotation: types.SequenceAnnotation,
load_blobs: bool = True,
) -> CanineFrameData:
frame_data = super().build(frame_annotation, sequence_annotation, load_blobs)
frame_data.num_dogs = frame_annotation.num_dogs or 101
frame_data.magnetic_field_average_flux_density = (
frame_annotation.magnetic_field.average_flux_density
)
return frame_data
class CanineSqlIndexDataset(SqlIndexDataset):
frame_annotations_type: ClassVar[
Type[SqlFrameAnnotation]
] = ExtendedSqlFrameAnnotation
frame_data_builder_class_type: str = "CanineFrameDataBuilder"
class TestExtendingOrmTypes(unittest.TestCase):
def setUp(self):
# create a temporary copy of the DB with an extended schema
engine = sa.create_engine(f"sqlite:///{METADATA_FILE}")
with Session(engine) as session:
extended_annots = [
ExtendedSqlFrameAnnotation(
**{
k: v
for k, v in frame_annot.__dict__.items()
if not k.startswith("_") # remove mapped fields and SA metadata
}
)
for frame_annot in session.scalars(sa.select(SqlFrameAnnotation))
]
seq_annots = session.scalars(
sa.select(SqlSequenceAnnotation),
execution_options={"prebuffer_rows": True},
)
session.expunge_all()
self._temp_db = tempfile.NamedTemporaryFile(delete=False)
engine_ext = sa.create_engine(f"sqlite:///{self._temp_db.name}")
ExtendedSqlFrameAnnotation.metadata.create_all(engine_ext, checkfirst=True)
with Session(engine_ext, expire_on_commit=False) as session_ext:
session_ext.add_all(extended_annots)
for instance in seq_annots:
session_ext.merge(instance)
session_ext.commit()
# check the setup is correct
with engine_ext.connect() as connection_ext:
df = pd.read_sql_query(
sa.select(ExtendedSqlFrameAnnotation), connection_ext
)
self.assertEqual(len(df), 100)
self.assertIn("_magnetic_field_average_flux_density", df.columns)
df_seq = pd.read_sql_query(sa.select(SqlSequenceAnnotation), connection_ext)
self.assertEqual(len(df_seq), 10)
def tearDown(self):
self._temp_db.close()
os.remove(self._temp_db.name)
def test_basic(self, sequence="cat1_seq2", frame_number=4):
dataset = ExtendedSqlIndexDataset(
sqlite_metadata_file=self._temp_db.name,
remove_empty_masks=False,
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 100)
# check the items are consecutive
past_sequences = set()
last_frame_number = -1
last_sequence = ""
for i in range(len(dataset)):
item = dataset[i]
if item.frame_number == 0:
self.assertNotIn(item.sequence_name, past_sequences)
past_sequences.add(item.sequence_name)
last_sequence = item.sequence_name
else:
self.assertEqual(item.sequence_name, last_sequence)
self.assertEqual(item.frame_number, last_frame_number + 1)
last_frame_number = item.frame_number
# test indexing
with self.assertRaises(IndexError):
dataset[len(dataset) + 1]
# test sequence-frame indexing
item = dataset[sequence, frame_number]
self.assertEqual(item.sequence_name, sequence)
self.assertEqual(item.frame_number, frame_number)
with self.assertRaises(IndexError):
dataset[sequence, 13]
def test_extending_frame_data(self, sequence="cat1_seq2", frame_number=4):
dataset = CanineSqlIndexDataset(
sqlite_metadata_file=self._temp_db.name,
remove_empty_masks=False,
frame_data_builder_CanineFrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 100)
# check the items are consecutive
past_sequences = set()
last_frame_number = -1
last_sequence = ""
for i in range(len(dataset)):
item = dataset[i]
self.assertIsInstance(item, CanineFrameData)
self.assertEqual(item.num_dogs, 101)
self.assertIsNone(item.magnetic_field_average_flux_density)
if item.frame_number == 0:
self.assertNotIn(item.sequence_name, past_sequences)
past_sequences.add(item.sequence_name)
last_sequence = item.sequence_name
else:
self.assertEqual(item.sequence_name, last_sequence)
self.assertEqual(item.frame_number, last_frame_number + 1)
last_frame_number = item.frame_number
# test indexing
with self.assertRaises(IndexError):
dataset[len(dataset) + 1]
# test sequence-frame indexing
item = dataset[sequence, frame_number]
self.assertIsInstance(item, CanineFrameData)
self.assertEqual(item.sequence_name, sequence)
self.assertEqual(item.frame_number, frame_number)
self.assertEqual(item.num_dogs, 101)
with self.assertRaises(IndexError):
dataset[sequence, 13]
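
Both tests in this file repeat the same "items are consecutive" loop: a frame numbered 0 must open a sequence not seen before, and every other frame must continue the current sequence with the next frame number. A sketch of that check as a standalone function is below; `frames_are_consecutive` is a hypothetical helper, not part of the test suite.

```python
from typing import Iterable, Optional, Set, Tuple


def frames_are_consecutive(frames: Iterable[Tuple[str, int]]) -> bool:
    """Check (sequence_name, frame_number) pairs for the ordering the tests expect."""
    seen: Set[str] = set()
    last_sequence: Optional[str] = None
    last_frame_number = -1
    for sequence_name, frame_number in frames:
        if frame_number == 0:
            # A new sequence starts; it must not have appeared earlier.
            if sequence_name in seen:
                return False
            seen.add(sequence_name)
            last_sequence = sequence_name
        elif sequence_name != last_sequence or frame_number != last_frame_number + 1:
            return False
        last_frame_number = frame_number
    return True
```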


@@ -17,6 +17,7 @@ from pytorch3d.implicitron.dataset import types
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.frame_data import FrameDataBuilder
from pytorch3d.implicitron.dataset.utils import (
get_bbox_from_mask,
load_16big_png_depth,
load_1bit_png_mask,
load_depth,
@@ -107,11 +108,14 @@ class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase):
)
self.frame_data.effective_image_size_hw = self.frame_data.image_size_hw
-(
-self.frame_data.fg_probability,
-self.frame_data.mask_path,
-self.frame_data.bbox_xywh,
-) = self.frame_data_builder._load_fg_probability(self.frame_annotation)
+fg_mask_np, mask_path = self.frame_data_builder._load_fg_probability(
+self.frame_annotation
+)
+self.frame_data.mask_path = mask_path
+self.frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)
+mask_thr = self.frame_data_builder.box_crop_mask_thr
+bbox_xywh = get_bbox_from_mask(fg_mask_np, mask_thr)
+self.frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.long)
self.assertIsNotNone(self.frame_data.mask_path)
self.assertTrue(torch.is_tensor(self.frame_data.fg_probability))
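
After this change the test derives the bounding box from the mask itself by calling `get_bbox_from_mask(fg_mask_np, mask_thr)`. The sketch below illustrates the general idea of a thresholded mask-to-box computation in plain Python. It is a simplified stand-in under stated assumptions, not the PyTorch3D function: `bbox_xywh_from_mask` is a hypothetical name, the mask is a nested list, and the width/height convention (number of covered columns/rows) is an assumption.

```python
from typing import Optional, Sequence, Tuple


def bbox_xywh_from_mask(
    mask: Sequence[Sequence[float]], threshold: float
) -> Optional[Tuple[int, int, int, int]]:
    """Tightest (x, y, width, height) box around entries above `threshold`."""
    # Rows that contain at least one foreground value.
    ys = [y for y, row in enumerate(mask) if any(v > threshold for v in row)]
    # Columns that contain at least one foreground value.
    n_cols = len(mask[0]) if mask else 0
    xs = [x for x in range(n_cols) if any(row[x] > threshold for row in mask)]
    if not xs or not ys:
        return None  # nothing above the threshold
    x0, y0 = xs[0], ys[0]
    return x0, y0, xs[-1] - x0 + 1, ys[-1] - y0 + 1
```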


@@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import numpy as np
from pytorch3d.implicitron.dataset.orm_types import TupleTypeFactory
class TestOrmTypes(unittest.TestCase):
def test_tuple_serialization_none(self):
ttype = TupleTypeFactory()()
output = ttype.process_bind_param(None, None)
self.assertIsNone(output)
output = ttype.process_result_value(output, None)
self.assertIsNone(output)
def test_tuple_serialization_1d(self):
for input_tuple in [(1, 2, 3), (4.5, 6.7)]:
ttype = TupleTypeFactory(type(input_tuple[0]), (len(input_tuple),))()
output = ttype.process_bind_param(input_tuple, None)
input_hat = ttype.process_result_value(output, None)
self.assertEqual(type(input_hat[0]), type(input_tuple[0]))
np.testing.assert_almost_equal(input_hat, input_tuple, decimal=6)
def test_tuple_serialization_2d(self):
input_tuple = ((1.0, 2.0, 3.0), (4.5, 5.5, 6.6))
ttype = TupleTypeFactory(type(input_tuple[0][0]), (2, 3))()
output = ttype.process_bind_param(input_tuple, None)
input_hat = ttype.process_result_value(output, None)
self.assertEqual(type(input_hat[0][0]), type(input_tuple[0][0]))
# we use float32 to serialise
np.testing.assert_almost_equal(input_hat, input_tuple, decimal=6)
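
The comment "we use float32 to serialise" explains why the round trip is compared with `assert_almost_equal` rather than exact equality. The sketch below shows the effect with the standard `struct` module. It only illustrates float32 rounding; how `TupleTypeFactory` actually encodes its values is not shown in this diff, and `pack_float32` / `unpack_float32` are hypothetical helpers.

```python
import struct
from typing import Sequence, Tuple


def pack_float32(values: Sequence[float]) -> bytes:
    """Encode values as little-endian 32-bit floats (4 bytes each)."""
    return struct.pack(f"<{len(values)}f", *values)


def unpack_float32(blob: bytes) -> Tuple[float, ...]:
    """Decode a buffer produced by pack_float32 back into Python floats."""
    return struct.unpack(f"<{len(blob) // 4}f", blob)
```

Values such as 4.5 are exactly representable in float32 and survive unchanged, while 6.6 comes back slightly perturbed, hence the tolerance in the test.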


@@ -0,0 +1,522 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import unittest
from collections import Counter
import pkg_resources
import torch
from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset
NO_BLOBS_KWARGS = {
"dataset_root": "",
"load_images": False,
"load_depths": False,
"load_masks": False,
"load_depth_masks": False,
"box_crop": False,
}
logger = logging.getLogger("pytorch3d.implicitron.dataset.sql_dataset")
sh = logging.StreamHandler()
logger.addHandler(sh)
logger.setLevel(logging.DEBUG)
DATASET_ROOT = pkg_resources.resource_filename(__name__, "data/sql_dataset")
METADATA_FILE = os.path.join(DATASET_ROOT, "sql_dataset_100.sqlite")
SET_LIST_FILE = os.path.join(DATASET_ROOT, "set_lists_100.json")
class TestSqlDataset(unittest.TestCase):
def test_basic(self, sequence="cat1_seq2", frame_number=4):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 100)
# check the items are consecutive
past_sequences = set()
last_frame_number = -1
last_sequence = ""
for i in range(len(dataset)):
item = dataset[i]
if item.frame_number == 0:
self.assertNotIn(item.sequence_name, past_sequences)
past_sequences.add(item.sequence_name)
last_sequence = item.sequence_name
else:
self.assertEqual(item.sequence_name, last_sequence)
self.assertEqual(item.frame_number, last_frame_number + 1)
last_frame_number = item.frame_number
# test indexing
with self.assertRaises(IndexError):
dataset[len(dataset) + 1]
# test sequence-frame indexing
item = dataset[sequence, frame_number]
self.assertEqual(item.sequence_name, sequence)
self.assertEqual(item.frame_number, frame_number)
with self.assertRaises(IndexError):
dataset[sequence, 13]
def test_filter_empty_masks(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=True,
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 78)
def test_pick_frames_sql_clause(self):
dataset_no_empty_masks = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=True,
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
pick_frames_sql_clause="_mask_mass IS NULL OR _mask_mass > 0",
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
# check the datasets are equal
self.assertEqual(len(dataset), len(dataset_no_empty_masks))
for i in range(len(dataset)):
item_nem = dataset_no_empty_masks[i]
item = dataset[i]
self.assertEqual(item_nem.image_path, item.image_path)
# remove_empty_masks together with the custom criterion
dataset_ts = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=True,
pick_frames_sql_clause="frame_timestamp < 0.15",
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset_ts), 19)
def test_limit_categories(self, category="cat0"):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
pick_categories=[category],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 50)
for i in range(len(dataset)):
self.assertEqual(dataset[i].sequence_category, category)
def test_limit_sequences(self, num_sequences=3):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
limit_sequences_to=num_sequences,
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 10 * num_sequences)
def delist(sequence_name):
return sequence_name if isinstance(sequence_name, str) else sequence_name[0]
unique_seqs = {delist(dataset[i].sequence_name) for i in range(len(dataset))}
self.assertEqual(len(unique_seqs), num_sequences)
def test_pick_exclude_sequences(self, sequence="cat1_seq2"):
# pick sequence
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
pick_sequences=[sequence],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 10)
unique_seqs = {dataset[i].sequence_name for i in range(len(dataset))}
self.assertCountEqual(unique_seqs, {sequence})
item = dataset[sequence, 0]
self.assertEqual(item.sequence_name, sequence)
self.assertEqual(item.frame_number, 0)
# exclude sequence
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
exclude_sequences=[sequence],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 90)
unique_seqs = {dataset[i].sequence_name for i in range(len(dataset))}
self.assertNotIn(sequence, unique_seqs)
with self.assertRaises(IndexError):
dataset[sequence, 0]
def test_limit_frames(self, num_frames=13):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
limit_to=num_frames,
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), num_frames)
unique_seqs = {dataset[i].sequence_name for i in range(len(dataset))}
self.assertEqual(len(unique_seqs), 2)
# test when the limit is not binding
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
limit_to=1000,
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 100)
def test_limit_frames_per_sequence(self, num_frames=2):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
n_frames_per_sequence=num_frames,
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), num_frames * 10)
seq_counts = Counter(dataset[i].sequence_name for i in range(len(dataset)))
self.assertEqual(len(seq_counts), 10)
self.assertCountEqual(
set(seq_counts.values()), {2}
) # all counts are num_frames
with self.assertRaises(IndexError):
dataset[next(iter(seq_counts)), num_frames + 1]
# test when the limit is not binding
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
n_frames_per_sequence=13,
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 100)
def test_filter_medley(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=True,
pick_categories=["cat1"],
exclude_sequences=["cat1_seq0"], # retaining "cat1_seq1" and on
limit_sequences_to=2, # retaining "cat1_seq1" and "cat1_seq2"
limit_to=14, # retaining full "cat1_seq1" and 4 from "cat1_seq2"
n_frames_per_sequence=6, # cutting "cat1_seq1" to 6 frames
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
# result: preserved 6 frames from cat1_seq1 and 4 from cat1_seq2
seq_counts = Counter(dataset[i].sequence_name for i in range(len(dataset)))
self.assertCountEqual(seq_counts.keys(), ["cat1_seq1", "cat1_seq2"])
self.assertEqual(seq_counts["cat1_seq1"], 6)
self.assertEqual(seq_counts["cat1_seq2"], 4)
def test_subsets_trivial(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
subset_lists_file=SET_LIST_FILE,
limit_to=100, # force sorting
subsets=["train", "test"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 100)
# check the items are consecutive
past_sequences = set()
last_frame_number = -1
last_sequence = ""
for i in range(len(dataset)):
item = dataset[i]
if item.frame_number == 0:
self.assertNotIn(item.sequence_name, past_sequences)
past_sequences.add(item.sequence_name)
last_sequence = item.sequence_name
else:
self.assertEqual(item.sequence_name, last_sequence)
self.assertEqual(item.frame_number, last_frame_number + 1)
last_frame_number = item.frame_number
def test_subsets_filter_empty_masks(self):
# we need to test this case as it uses quite different logic with `df.drop()`
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=True,
subset_lists_file=SET_LIST_FILE,
subsets=["train", "test"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 78)
def test_subsets_pick_frames_sql_clause(self):
dataset_no_empty_masks = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=True,
subset_lists_file=SET_LIST_FILE,
subsets=["train", "test"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
pick_frames_sql_clause="_mask_mass IS NULL OR _mask_mass > 0",
subset_lists_file=SET_LIST_FILE,
subsets=["train", "test"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
# check the datasets are equal
self.assertEqual(len(dataset), len(dataset_no_empty_masks))
for i in range(len(dataset)):
item_nem = dataset_no_empty_masks[i]
item = dataset[i]
self.assertEqual(item_nem.image_path, item.image_path)
# remove_empty_masks together with the custom criterion
dataset_ts = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=True,
pick_frames_sql_clause="frame_timestamp < 0.15",
subset_lists_file=SET_LIST_FILE,
subsets=["train", "test"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset_ts), 19)
def test_single_subset(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
subset_lists_file=SET_LIST_FILE,
subsets=["train"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 50)
with self.assertRaises(IndexError):
dataset[51]
# check the items are consecutive
past_sequences = set()
last_frame_number = -1
last_sequence = ""
for i in range(len(dataset)):
item = dataset[i]
if item.frame_number < 2:
self.assertNotIn(item.sequence_name, past_sequences)
past_sequences.add(item.sequence_name)
last_sequence = item.sequence_name
else:
self.assertEqual(item.sequence_name, last_sequence)
self.assertEqual(item.frame_number, last_frame_number + 2)
last_frame_number = item.frame_number
item = dataset[last_sequence, 0]
self.assertEqual(item.sequence_name, last_sequence)
with self.assertRaises(IndexError):
dataset[last_sequence, 1]
def test_subset_with_filters(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=True,
subset_lists_file=SET_LIST_FILE,
subsets=["train"],
pick_categories=["cat1"],
exclude_sequences=["cat1_seq0"], # retaining "cat1_seq1" and on
limit_sequences_to=2, # retaining "cat1_seq1" and "cat1_seq2"
limit_to=7, # retaining full train set of "cat1_seq1" and 2 from "cat1_seq2"
n_frames_per_sequence=3, # cutting "cat1_seq1" to 3 frames
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
# result: preserved 3 frames from cat1_seq1 and 2 from cat1_seq2
seq_counts = Counter(dataset[i].sequence_name for i in range(len(dataset)))
self.assertCountEqual(seq_counts.keys(), ["cat1_seq1", "cat1_seq2"])
self.assertEqual(seq_counts["cat1_seq1"], 3)
self.assertEqual(seq_counts["cat1_seq2"], 2)
def test_visitor(self):
dataset_sorted = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
sequences = dataset_sorted.sequence_names()
i = 0
for seq in sequences:
last_ts = float("-Inf")
for ts, _, idx in dataset_sorted.sequence_frames_in_order(seq):
self.assertEqual(i, idx)
i += 1
self.assertGreaterEqual(ts, last_ts)
last_ts = ts
# test legacy visitor
old_indices = None
for seq in sequences:
last_ts = float("-Inf")
rows = dataset_sorted._index.index.get_loc(seq)
indices = list(range(rows.start or 0, rows.stop, rows.step or 1))
fn_ts_list = dataset_sorted.get_frame_numbers_and_timestamps(indices)
self.assertEqual(len(fn_ts_list), len(indices))
if old_indices:
# check raising if we ask for multiple sequences
with self.assertRaises(ValueError):
dataset_sorted.get_frame_numbers_and_timestamps(
indices + old_indices
)
old_indices = indices
def test_visitor_subsets(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
limit_to=100, # force sorting
subset_lists_file=SET_LIST_FILE,
subsets=["train", "test"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
sequences = dataset.sequence_names()
i = 0
for seq in sequences:
last_ts = float("-Inf")
seq_frames = list(dataset.sequence_frames_in_order(seq))
self.assertEqual(len(seq_frames), 10)
for ts, _, idx in seq_frames:
self.assertEqual(i, idx)
i += 1
self.assertGreaterEqual(ts, last_ts)
last_ts = ts
last_ts = float("-Inf")
train_frames = list(dataset.sequence_frames_in_order(seq, "train"))
self.assertEqual(len(train_frames), 5)
for ts, _, _ in train_frames:
self.assertGreaterEqual(ts, last_ts)
last_ts = ts
def test_category_to_sequence_names(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
subset_lists_file=SET_LIST_FILE,
subsets=["train", "test"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
cat_to_seqs = dataset.category_to_sequence_names()
self.assertEqual(len(cat_to_seqs), 2)
self.assertIn("cat1", cat_to_seqs)
self.assertEqual(len(cat_to_seqs["cat1"]), 5)
# check that override preserves the behavior
cat_to_seqs_base = super(SqlIndexDataset, dataset).category_to_sequence_names()
self.assertDictEqual(cat_to_seqs, cat_to_seqs_base)
def test_category_to_sequence_names_filters(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=True,
subset_lists_file=SET_LIST_FILE,
exclude_sequences=["cat1_seq0"],
subsets=["train", "test"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
cat_to_seqs = dataset.category_to_sequence_names()
self.assertEqual(len(cat_to_seqs), 2)
self.assertIn("cat1", cat_to_seqs)
self.assertEqual(len(cat_to_seqs["cat1"]), 4) # minus one
# check that override preserves the behavior
cat_to_seqs_base = super(SqlIndexDataset, dataset).category_to_sequence_names()
self.assertDictEqual(cat_to_seqs, cat_to_seqs_base)
def test_meta_access(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
subset_lists_file=SET_LIST_FILE,
subsets=["train"],
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
)
self.assertEqual(len(dataset), 50)
for idx in [10, ("cat0_seq2", 2)]:
example_meta = dataset.meta[idx]
example = dataset[idx]
self.assertEqual(example_meta.sequence_name, example.sequence_name)
self.assertEqual(example_meta.frame_number, example.frame_number)
self.assertEqual(example_meta.frame_timestamp, example.frame_timestamp)
self.assertEqual(example_meta.sequence_category, example.sequence_category)
torch.testing.assert_close(example_meta.camera.R, example.camera.R)
torch.testing.assert_close(example_meta.camera.T, example.camera.T)
torch.testing.assert_close(
example_meta.camera.focal_length, example.camera.focal_length
)
torch.testing.assert_close(
example_meta.camera.principal_point, example.camera.principal_point
)
def test_meta_access_no_blobs(self):
dataset = SqlIndexDataset(
sqlite_metadata_file=METADATA_FILE,
remove_empty_masks=False,
subset_lists_file=SET_LIST_FILE,
subsets=["train"],
frame_data_builder_FrameDataBuilder_args={
"dataset_root": ".",
"box_crop": False, # required by blob-less accessor
},
)
self.assertIsNone(dataset.meta[0].image_rgb)
self.assertIsNone(dataset.meta[0].fg_probability)
self.assertIsNone(dataset.meta[0].depth_map)
self.assertIsNone(dataset.meta[0].sequence_point_cloud)
self.assertIsNotNone(dataset.meta[0].camera)
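
`test_pick_frames_sql_clause` above asserts that the clause `_mask_mass IS NULL OR _mask_mass > 0` selects the same frames as `remove_empty_masks=True`. The sketch below uses an in-memory SQLite database to show why the `IS NULL` branch matters: a comparison against `NULL` is not true in SQL, so `_mask_mass > 0` alone would also drop frames without a recorded mask mass. The `frames` table and the helper names are hypothetical; only the column name `_mask_mass` comes from the tests.

```python
import sqlite3
from typing import Optional


def build_toy_db() -> sqlite3.Connection:
    """Create an in-memory table with zero, positive and missing mask masses."""
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE frames ("
        "sequence_name TEXT, frame_number INTEGER, _mask_mass REAL)"
    )
    conn.executemany(
        "INSERT INTO frames VALUES (?, ?, ?)",
        [
            ("seq0", 0, 12.5),
            ("seq0", 1, 0.0),   # empty mask
            ("seq0", 2, None),  # mask mass not recorded
            ("seq1", 0, 3.0),
        ],
    )
    return conn


def count_frames(conn: sqlite3.Connection, where_clause: Optional[str] = None) -> int:
    """Count rows, optionally restricted by a raw WHERE clause.

    The clause is interpolated verbatim, so it must come from trusted input.
    """
    query = "SELECT COUNT(*) FROM frames"
    if where_clause:
        query += f" WHERE {where_clause}"
    return conn.execute(query).fetchone()[0]
```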


@@ -120,6 +120,7 @@ class TestMeshGltfIO(TestCaseMixin, unittest.TestCase):
The scene is "already lit", i.e. the textures reflect the lighting
already, so we want to render them with full ambient light.
"""
self.skipTest("Data not available")
glb = DATA_DIR / "apartment_1.glb"
@@ -266,3 +267,117 @@ class TestMeshGltfIO(TestCaseMixin, unittest.TestCase):
expected = np.array(f)
self.assertClose(image, expected)
def test_load_save_load_cow_texturesvertex(self):
"""
Load the cow as converted to a single mesh in a glb file and then save it to a glb file.
"""
glb = DATA_DIR / "cow.glb"
self.assertTrue(glb.is_file())
device = torch.device("cuda:0")
mesh = _load(glb, device=device, include_textures=False)
self.assertEqual(len(mesh), 1)
self.assertIsNone(mesh.textures)
self.assertEqual(mesh.faces_packed().shape, (5856, 3))
self.assertEqual(mesh.verts_packed().shape, (3225, 3))
mesh_obj = _load(TUTORIAL_DATA_DIR / "cow_mesh/cow.obj")
self.assertClose(mesh.get_bounding_boxes().cpu(), mesh_obj.get_bounding_boxes())
mesh.textures = TexturesVertex(0.5 * torch.ones_like(mesh.verts_padded()))
image = _render(mesh, "cow_gray")
with Image.open(DATA_DIR / "glb_cow_gray.png") as f:
expected = np.array(f)
self.assertClose(image, expected)
# save the mesh to a glb file
glb = DATA_DIR / "cow_write_texturesvertex.glb"
_write(mesh, glb)
# reload the mesh glb file saved in TexturesVertex format
glb = DATA_DIR / "cow_write_texturesvertex.glb"
self.assertTrue(glb.is_file())
mesh_dash = _load(glb, device=device)
self.assertEqual(len(mesh_dash), 1)
self.assertEqual(mesh_dash.faces_packed().shape, (5856, 3))
self.assertEqual(mesh_dash.verts_packed().shape, (3225, 3))
self.assertEqual(mesh_dash.textures.verts_features_list()[0].shape, (3225, 3))
# check the re-rendered image with expected
image_dash = _render(mesh, "cow_gray_texturesvertex")
self.assertClose(image_dash, expected)
def test_save_toy(self):
"""
Construct a simple mesh and save it to a glb file in TexturesVertex mode.
"""
example = {}
example["POSITION"] = torch.tensor(
[
[
[0.0, 0.0, 0.0],
[-1.0, 0.0, 0.0],
[-1.0, 0.0, 1.0],
[0.0, 0.0, 1.0],
[0.0, 1.0, 0.0],
[-1.0, 1.0, 0.0],
[-1.0, 1.0, 1.0],
[0.0, 1.0, 1.0],
]
]
)
example["indices"] = torch.tensor(
[
[
[1, 4, 2],
[4, 3, 2],
[3, 7, 2],
[7, 6, 2],
[3, 4, 7],
[4, 8, 7],
[8, 5, 7],
[5, 6, 7],
[5, 2, 6],
[5, 1, 2],
[1, 5, 4],
[5, 8, 4],
]
]
)
example["indices"] -= 1
example["COLOR_0"] = torch.tensor(
[
[
[1.0, 0.0, 0.0],
[1.0, 0.0, 0.0],
[1.0, 0.0, 0.0],
[1.0, 0.0, 0.0],
[1.0, 0.0, 0.0],
[1.0, 0.0, 0.0],
[1.0, 0.0, 0.0],
[1.0, 0.0, 0.0],
]
]
)
# example['prop'] = {'material':
# {'pbrMetallicRoughness':
# {'baseColorFactor':
# torch.tensor([[0.7, 0.7, 1, 0.5]]),
# 'metallicFactor': torch.tensor([1]),
# 'roughnessFactor': torch.tensor([0.1])},
# 'alphaMode': 'BLEND',
# 'doubleSided': True}}
texture = TexturesVertex(example["COLOR_0"])
mesh = Meshes(
verts=example["POSITION"], faces=example["indices"], textures=texture
)
glb = DATA_DIR / "example_write_texturesvertex.glb"
_write(mesh, glb)

@@ -532,8 +532,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"f 4 2 1",
]
)
-actual_file = open(Path(f.name), "r")
-self.assertEqual(actual_file.read(), expected_file)
+self.assertEqual(Path(f.name).read_text(), expected_file)
def test_load_mtl(self):
obj_filename = "cow_mesh/cow.obj"
@@ -895,6 +894,67 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
with self.assertRaisesRegex(ValueError, "same type of texture"):
join_meshes_as_batch([mesh_atlas, mesh_rgb, mesh_atlas])
def test_save_obj_with_normal(self):
verts = torch.tensor(
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
dtype=torch.float32,
)
faces = torch.tensor(
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
)
normals = torch.tensor(
[
[0.02, 0.5, 0.73],
[0.3, 0.03, 0.361],
[0.32, 0.12, 0.47],
[0.36, 0.17, 0.9],
[0.40, 0.7, 0.19],
[1.0, 0.00, 0.000],
[0.00, 1.00, 0.00],
[0.00, 0.00, 1.0],
],
dtype=torch.float32,
)
faces_normals_idx = torch.tensor(
[[0, 1, 2], [2, 3, 4], [4, 5, 6], [6, 7, 0]], dtype=torch.int64
)
with TemporaryDirectory() as temp_dir:
obj_file = os.path.join(temp_dir, "mesh.obj")
save_obj(
obj_file,
verts,
faces,
decimal_places=2,
normals=normals,
faces_normals_idx=faces_normals_idx,
)
expected_obj_file = "\n".join(
[
"v 0.01 0.20 0.30",
"v 0.20 0.03 0.41",
"v 0.30 0.40 0.05",
"v 0.60 0.70 0.80",
"vn 0.02 0.50 0.73",
"vn 0.30 0.03 0.36",
"vn 0.32 0.12 0.47",
"vn 0.36 0.17 0.90",
"vn 0.40 0.70 0.19",
"vn 1.00 0.00 0.00",
"vn 0.00 1.00 0.00",
"vn 0.00 0.00 1.00",
"f 1//1 3//2 2//3",
"f 1//3 2//4 3//5",
"f 4//5 3//6 2//7",
"f 4//7 2//8 1//1",
]
)
# Check the obj file is saved correctly
with open(obj_file, "r") as actual_file:
self.assertEqual(actual_file.read(), expected_obj_file)
def test_save_obj_with_texture(self):
verts = torch.tensor(
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
@@ -950,13 +1010,96 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
# Check the obj file is saved correctly
-actual_file = open(obj_file, "r")
-self.assertEqual(actual_file.read(), expected_obj_file)
+with open(obj_file, "r") as actual_file:
+    self.assertEqual(actual_file.read(), expected_obj_file)
# Check the mtl file is saved correctly
mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
-mtl_file = open(mtl_file_name, "r")
-self.assertEqual(mtl_file.read(), expected_mtl_file)
+with open(mtl_file_name, "r") as mtl_file:
+    self.assertEqual(mtl_file.read(), expected_mtl_file)
# Check the texture image file is saved correctly
texture_image = load_rgb_image("mesh.png", temp_dir)
self.assertClose(texture_image, texture_map)
def test_save_obj_with_normal_and_texture(self):
verts = torch.tensor(
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
dtype=torch.float32,
)
faces = torch.tensor(
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
)
normals = torch.tensor(
[
[0.02, 0.5, 0.73],
[0.3, 0.03, 0.361],
[0.32, 0.12, 0.47],
[0.36, 0.17, 0.9],
],
dtype=torch.float32,
)
faces_normals_idx = faces
verts_uvs = torch.tensor(
[[0.02, 0.5], [0.3, 0.03], [0.32, 0.12], [0.36, 0.17]],
dtype=torch.float32,
)
faces_uvs = faces
texture_map = torch.randint(size=(2, 2, 3), high=255) / 255.0
with TemporaryDirectory() as temp_dir:
obj_file = os.path.join(temp_dir, "mesh.obj")
save_obj(
obj_file,
verts,
faces,
decimal_places=2,
normals=normals,
faces_normals_idx=faces_normals_idx,
verts_uvs=verts_uvs,
faces_uvs=faces_uvs,
texture_map=texture_map,
)
expected_obj_file = "\n".join(
[
"",
"mtllib mesh.mtl",
"usemtl mesh",
"",
"v 0.01 0.20 0.30",
"v 0.20 0.03 0.41",
"v 0.30 0.40 0.05",
"v 0.60 0.70 0.80",
"vn 0.02 0.50 0.73",
"vn 0.30 0.03 0.36",
"vn 0.32 0.12 0.47",
"vn 0.36 0.17 0.90",
"vt 0.02 0.50",
"vt 0.30 0.03",
"vt 0.32 0.12",
"vt 0.36 0.17",
"f 1/1/1 3/3/3 2/2/2",
"f 1/1/1 2/2/2 3/3/3",
"f 4/4/4 3/3/3 2/2/2",
"f 4/4/4 2/2/2 1/1/1",
]
)
expected_mtl_file = "\n".join(["newmtl mesh", "map_Kd mesh.png", ""])
# Check there are only 3 files in the temp dir
tempfiles = ["mesh.obj", "mesh.png", "mesh.mtl"]
tempfiles_dir = os.listdir(temp_dir)
self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
# Check the obj file is saved correctly
with open(obj_file, "r") as actual_file:
self.assertEqual(actual_file.read(), expected_obj_file)
# Check the mtl file is saved correctly
mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
with open(mtl_file_name, "r") as mtl_file:
self.assertEqual(mtl_file.read(), expected_mtl_file)
# Check the texture image file is saved correctly
texture_image = load_rgb_image("mesh.png", temp_dir)
@@ -1013,8 +1156,8 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
self.assertEqual(tempfiles, tempfiles_dir)
# Check the obj file is saved correctly
-actual_file = open(obj_file, "r")
-self.assertEqual(actual_file.read(), expected_obj_file)
+with open(obj_file, "r") as actual_file:
+    self.assertEqual(actual_file.read(), expected_obj_file)
obj_file = StringIO()
with self.assertRaises(ValueError):
@@ -1100,13 +1243,13 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
# Check the obj file is saved correctly
-actual_file = open(obj_file, "r")
-self.assertEqual(actual_file.read(), expected_obj_file)
+with open(obj_file, "r") as actual_file:
+    self.assertEqual(actual_file.read(), expected_obj_file)
# Check the mtl file is saved correctly
mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
-mtl_file = open(mtl_file_name, "r")
-self.assertEqual(mtl_file.read(), expected_mtl_file)
+with open(mtl_file_name, "r") as mtl_file:
+    self.assertEqual(mtl_file.read(), expected_mtl_file)
# Check the texture image file is saved correctly
texture_image = load_rgb_image("mesh.png", temp_dir)