CO3Dv2 multi-category extension

Summary:
Allows loading of multiple categories.
Multiple categories are provided in a comma-separated list of category names.

Reviewed By: bottler, shapovalov

Differential Revision: D40803297

fbshipit-source-id: 863938be3aa6ffefe9e563aede4a2e9e66aeeaa8
This commit is contained in:
David Novotny
2022-11-02 13:55:25 -07:00
committed by Facebook GitHub Bot
parent c54e048666
commit e4a3298149
9 changed files with 272 additions and 25 deletions

View File

@@ -11,17 +11,20 @@ from dataclasses import dataclass
from itertools import product
import numpy as np
import torch
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
DoublePoolBatchSampler,
)
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
@dataclass
class MockFrameAnnotation:
frame_number: int
sequence_name: str = "sequence"
frame_timestamp: float = 0.0
@@ -41,6 +44,9 @@ class MockDataset(DatasetBase):
self.frame_annots = [
{"frame_annotation": MockFrameAnnotation(no)} for no in frame_nos * num_seq
]
for seq_name, idx in self._seq_to_idx.items():
for i in idx:
self.frame_annots[i]["frame_annotation"].sequence_name = seq_name
def get_frame_numbers_and_timestamps(self, idxs):
out = []
@@ -51,6 +57,16 @@ class MockDataset(DatasetBase):
)
return out
def __getitem__(self, index: int):
fa = self.frame_annots[index]["frame_annotation"]
fd = FrameData(
sequence_name=fa.sequence_name,
sequence_category="default_category",
frame_number=torch.LongTensor([fa.frame_number]),
frame_timestamp=torch.LongTensor([fa.frame_timestamp]),
)
return fd
class TestSceneBatchSampler(unittest.TestCase):
def setUp(self):