mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-07-31 10:52:50 +08:00
Summary: Converts the directory specified to use the Ruff formatter in pyfmt. (ruff_dog) If this diff causes merge conflicts when rebasing, please run `hg status -n -0 --change . -I '**/*.{py,pyi}' | xargs -0 arc pyfmt` on your diff, and amend any changes before rebasing onto latest. That should help reduce or eliminate any merge conflicts. (allow-large-files) Reviewed By: bottler. Differential Revision: D66472063. fbshipit-source-id: 35841cb397e4f8e066e2159550d2f56b403b1bef
249 lines
9.8 KiB
Python
249 lines
9.8 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
import json
|
|
import os
|
|
import random
|
|
import tempfile
|
|
import unittest
|
|
from typing import List
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
import torchvision
|
|
from PIL import Image
|
|
from pytorch3d.implicitron.dataset.frame_data import FrameData
|
|
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
|
|
JsonIndexDatasetMapProviderV2,
|
|
)
|
|
from pytorch3d.implicitron.dataset.types import (
|
|
dump_dataclass_jgzip,
|
|
FrameAnnotation,
|
|
ImageAnnotation,
|
|
MaskAnnotation,
|
|
SequenceAnnotation,
|
|
)
|
|
from pytorch3d.implicitron.tools.config import expand_args_fields
|
|
from tests.common_testing import interactive_testing_requested
|
|
|
|
from .common_resources import CO3DV2_MANIFOLD_PATH
|
|
|
|
|
|
class TestJsonIndexDatasetProviderV2(unittest.TestCase):
    """Tests JsonIndexDatasetMapProviderV2 on a small randomly generated dataset."""

    def test_random_dataset(self):
        """Generate a random dataset on disk and check what the provider loads.

        For each single category and for the comma-joined multi-category
        setting, verifies the category -> sequence-name mapping, the dataset
        lengths, the number of eval batches, the size of the test batches
        yielded by a DataLoader, and the category -> subset-name list.
        """
        # store random frame annotations
        expand_args_fields(JsonIndexDatasetMapProviderV2)
        categories = ["A", "B"]
        subset_name = "test"
        eval_batch_size = 5
        n_frames = 8 * 3
        n_sequences = 5
        n_eval_batches = 10
        with tempfile.TemporaryDirectory() as tmpd:
            _make_random_json_dataset_map_provider_v2_data(
                tmpd,
                categories,
                eval_batch_size=eval_batch_size,
                n_frames=n_frames,
                n_sequences=n_sequences,
                n_eval_batches=n_eval_batches,
                subset_name=subset_name,
            )
            for n_known_frames_for_test in [0, 2]:
                # One provider per single category plus one for the
                # comma-joined list of all categories.
                dataset_providers = {
                    category: JsonIndexDatasetMapProviderV2(
                        category=category,
                        subset_name=subset_name,
                        dataset_root=tmpd,
                        n_known_frames_for_test=n_known_frames_for_test,
                    )
                    for category in [*categories, ",".join(sorted(categories))]
                }
                for category, dataset_provider in dataset_providers.items():
                    is_multicategory = "," in category
                    dataset_map = dataset_provider.get_dataset_map()
                    for set_ in ["train", "val", "test"]:
                        dataset = getattr(dataset_map, set_)

                        cat2seq = dataset.category_to_sequence_names()
                        self.assertEqual(",".join(sorted(cat2seq.keys())), category)

                        if not (n_known_frames_for_test != 0 and set_ == "test"):
                            # check the lengths only in case we do not have the
                            # n_known_frames_for_test set
                            expected_dataset_len = n_frames * n_sequences // 3
                            if is_multicategory:
                                # multicategory json index dataset, sum the
                                # lengths of category-specific ones
                                expected_dataset_len = sum(
                                    len(
                                        getattr(
                                            dataset_providers[c].get_dataset_map(),
                                            set_,
                                        )
                                    )
                                    for c in categories
                                )
                                self.assertEqual(
                                    sum(len(s) for s in cat2seq.values()),
                                    n_sequences * len(categories),
                                )
                                self.assertEqual(len(cat2seq), len(categories))
                            else:
                                self.assertEqual(
                                    len(cat2seq[category]),
                                    n_sequences,
                                )
                                self.assertEqual(len(cat2seq), 1)
                            self.assertEqual(len(dataset), expected_dataset_len)

                        if set_ == "test":
                            # check the number of eval batches
                            expected_n_eval_batches = n_eval_batches
                            if is_multicategory:
                                expected_n_eval_batches *= len(categories)
                            self.assertEqual(
                                len(dataset.get_eval_batches()),
                                expected_n_eval_batches,
                            )

                        if set_ in ["train", "val"]:
                            dataloader = torch.utils.data.DataLoader(
                                dataset,
                                batch_size=3,
                                shuffle=True,
                                collate_fn=FrameData.collate,
                            )
                        else:
                            # the test split is iterated through its
                            # predefined eval batches
                            dataloader = torch.utils.data.DataLoader(
                                dataset,
                                batch_sampler=dataset.get_eval_batches(),
                                collate_fn=FrameData.collate,
                            )
                        for batch in dataloader:
                            if set_ == "test":
                                self.assertEqual(
                                    batch.image_rgb.shape[0],
                                    n_known_frames_for_test + eval_batch_size,
                                )

                    # Checked per provider so that the assertion does not rely
                    # on a loop variable outliving its loop.
                    self.assertEqual(
                        dataset_provider.get_category_to_subset_name_list(),
                        {c: [subset_name] for c in categories},
                    )
|
|
|
|
|
|
def _make_random_json_dataset_map_provider_v2_data(
    root: str,
    categories: List[str],
    n_frames: int = 8,
    n_sequences: int = 5,
    n_eval_batches: int = 10,
    H: int = 50,
    W: int = 30,
    subset_name: str = "test",
    eval_batch_size: int = 5,
):
    """Write a random dataset under `root` for JsonIndexDatasetMapProviderV2 tests.

    For every category the following is created below `<root>/<category>/`:
        - `<sequence>/images/frameXXXXX.jpg`: random RGB images of size H x W,
        - `<sequence>/masks/frameXXXXX.png`: masks whose lower-right quadrant
          is 255 and the rest 0,
        - `frame_annotations.jgz` and `sequence_annotations.jgz`,
        - `set_lists/set_lists_<subset_name>.json`: every third entry of the
          flat per-category frame list goes to train / val / test,
        - `eval_batches/eval_batches_<subset_name>.json`: `n_eval_batches`
          random samples of `eval_batch_size` test frames each.
    Finally `<root>/category_to_subset_name_list.json` maps each category to
    `[subset_name]`.

    Args:
        root: Directory to write to; created if missing.
        categories: Category names; sequences are named `<category><index>`.
        n_frames: Number of frames per sequence.
        n_sequences: Number of sequences per category.
        n_eval_batches: Number of eval batches per category.
        H: Image / mask height.
        W: Image / mask width.
        subset_name: Suffix used in the set-list and eval-batch file names.
        eval_batch_size: Number of frames in each eval batch.

    Raises:
        ValueError: from `random.sample` if `eval_batch_size` exceeds the
            number of test frames of a category.
    """
    os.makedirs(root, exist_ok=True)
    category_to_subset_list = {}
    for category in categories:
        frame_annotations = []
        sequence_annotations = []
        frame_index = []
        for seq_i in range(n_sequences):
            seq_name = category + str(seq_i)

            # The directories only depend on the sequence, so create them once
            # instead of once per frame.
            imdir = os.path.join(root, category, seq_name, "images")
            maskdir = os.path.join(root, category, seq_name, "masks")
            os.makedirs(imdir, exist_ok=True)
            os.makedirs(maskdir, exist_ok=True)

            for i in range(n_frames):
                # generate and store image
                img_path = os.path.join(imdir, f"frame{i:05d}.jpg")
                img = torch.rand(3, H, W)
                torchvision.utils.save_image(img, img_path)

                # generate and store mask
                mask_path = os.path.join(maskdir, f"frame{i:05d}.png")
                mask = np.zeros((H, W))
                mask[H // 2 :, W // 2 :] = 1
                Image.fromarray(
                    (mask * 255.0).astype(np.uint8),
                    mode="L",
                ).save(mask_path)

                # Annotation paths are stored relative to the dataset root.
                # relpath only strips the leading root, unlike a substring
                # replace, which would also drop later matching components.
                fa = FrameAnnotation(
                    sequence_name=seq_name,
                    frame_number=i,
                    frame_timestamp=float(i),
                    image=ImageAnnotation(
                        path=os.path.relpath(img_path, root),
                        size=list(img.shape[-2:]),
                    ),
                    mask=MaskAnnotation(
                        path=os.path.relpath(mask_path, root),
                        mass=mask.sum().item(),
                    ),
                )
                frame_annotations.append(fa)
                frame_index.append((seq_name, i, fa.image.path))

            sequence_annotations.append(
                SequenceAnnotation(
                    sequence_name=seq_name,
                    category=category,
                )
            )

        dump_dataclass_jgzip(
            os.path.join(root, category, "frame_annotations.jgz"),
            frame_annotations,
        )
        dump_dataclass_jgzip(
            os.path.join(root, category, "sequence_annotations.jgz"),
            sequence_annotations,
        )

        # Interleaved split: entries 0, 1, 2 (mod 3) of the flat frame list go
        # to train, val and test respectively.
        test_frame_index = frame_index[2::3]

        set_list = {
            "train": frame_index[0::3],
            "val": frame_index[1::3],
            "test": test_frame_index,
        }
        set_lists_dir = os.path.join(root, category, "set_lists")
        os.makedirs(set_lists_dir, exist_ok=True)
        set_list_file = os.path.join(set_lists_dir, f"set_lists_{subset_name}.json")
        with open(set_list_file, "w") as f:
            json.dump(set_list, f)

        # Each eval batch is a random draw (without repetition inside a batch)
        # from the test frames.
        eval_batches = [
            random.sample(test_frame_index, eval_batch_size)
            for _ in range(n_eval_batches)
        ]

        eval_b_dir = os.path.join(root, category, "eval_batches")
        os.makedirs(eval_b_dir, exist_ok=True)
        eval_b_file = os.path.join(eval_b_dir, f"eval_batches_{subset_name}.json")
        with open(eval_b_file, "w") as f:
            json.dump(eval_batches, f)

        category_to_subset_list[category] = [subset_name]

    with open(os.path.join(root, "category_to_subset_name_list.json"), "w") as f:
        json.dump(category_to_subset_list, f)
|
|
|
|
|
|
class TestCo3dv2(unittest.TestCase):
    """Smoke test of JsonIndexDatasetMapProviderV2 on data at CO3DV2_MANIFOLD_PATH."""

    def test_simple(self):
        """Build a provider for the "apple" category and load one train frame.

        Only runs when interactive testing is requested; otherwise the test
        is reported as skipped.
        """
        if not interactive_testing_requested():
            # Report a skip rather than a silent pass so that test output
            # shows this check did not actually run.
            self.skipTest("interactive testing was not requested")

        # Same preparation step as in test_random_dataset, so that this test
        # does not depend on another test having run first.
        expand_args_fields(JsonIndexDatasetMapProviderV2)
        dataset_provider = JsonIndexDatasetMapProviderV2(
            category="apple",
            subset_name="manyview_dev_0",
            dataset_root=CO3DV2_MANIFOLD_PATH,
            dataset_JsonIndexDataset_args={"load_point_clouds": True},
        )
        # Indexing the train split is the check itself: it must not raise.
        # The returned frame is intentionally discarded.
        dataset_provider.get_dataset_map().train[0]
|