mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-21 14:50:36 +08:00
CO3Dv2 multi-category extension
Summary: Allows loading of multiple categories. Multiple categories are provided in a comma-separated list of category names. Reviewed By: bottler, shapovalov Differential Revision: D40803297 fbshipit-source-id: 863938be3aa6ffefe9e563aede4a2e9e66aeeaa8
This commit is contained in:
committed by
Facebook GitHub Bot
parent
c54e048666
commit
e4a3298149
@@ -11,17 +11,20 @@ from dataclasses import dataclass
|
||||
from itertools import product
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
|
||||
DoublePoolBatchSampler,
|
||||
)
|
||||
|
||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
|
||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
|
||||
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockFrameAnnotation:
|
||||
frame_number: int
|
||||
sequence_name: str = "sequence"
|
||||
frame_timestamp: float = 0.0
|
||||
|
||||
|
||||
@@ -41,6 +44,9 @@ class MockDataset(DatasetBase):
|
||||
self.frame_annots = [
|
||||
{"frame_annotation": MockFrameAnnotation(no)} for no in frame_nos * num_seq
|
||||
]
|
||||
for seq_name, idx in self._seq_to_idx.items():
|
||||
for i in idx:
|
||||
self.frame_annots[i]["frame_annotation"].sequence_name = seq_name
|
||||
|
||||
def get_frame_numbers_and_timestamps(self, idxs):
|
||||
out = []
|
||||
@@ -51,6 +57,16 @@ class MockDataset(DatasetBase):
|
||||
)
|
||||
return out
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
fa = self.frame_annots[index]["frame_annotation"]
|
||||
fd = FrameData(
|
||||
sequence_name=fa.sequence_name,
|
||||
sequence_category="default_category",
|
||||
frame_number=torch.LongTensor([fa.frame_number]),
|
||||
frame_timestamp=torch.LongTensor([fa.frame_timestamp]),
|
||||
)
|
||||
return fd
|
||||
|
||||
|
||||
class TestSceneBatchSampler(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
||||
Reference in New Issue
Block a user