Fix: Correct concatenation of datasets in train conditioning

Summary: ChainDataset is an iterable-style dataset, so it does not work with a custom batch sampler; ConcatDataset is map-style and does.

Reviewed By: bottler

Differential Revision: D42742315

fbshipit-source-id: 40a715c8d24abe72cb2777634247d7467f628564
Roman Shapovalov 2023-01-26 03:00:46 -08:00 committed by Facebook GitHub Bot
parent 11959e0b24
commit 3239594f78

@@ -12,7 +12,7 @@ import torch
 from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
 from torch.utils.data import (
     BatchSampler,
-    ChainDataset,
+    ConcatDataset,
     DataLoader,
     RandomSampler,
     Sampler,
@@ -482,7 +482,7 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
             num_batches=num_batches,
         )
         return DataLoader(
-            ChainDataset([dataset, train_dataset]),
+            ConcatDataset([dataset, train_dataset]),
             batch_sampler=sampler,
             **data_loader_kwargs,
        )
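
For context, a minimal standalone sketch of the behavior the fix relies on (toy TensorDataset objects stand in for the Implicitron datasets; this is not the actual SequenceDataLoaderMapProvider setup): DataLoader raises an error when batch_sampler is combined with an iterable-style dataset, which is what ChainDataset produces, whereas ConcatDataset stays map-style and accepts a custom BatchSampler over its concatenated indices.

import torch
from torch.utils.data import (
    BatchSampler,
    ConcatDataset,
    DataLoader,
    SequentialSampler,
    TensorDataset,
)

# Two toy map-style datasets standing in for `dataset` and `train_dataset`.
ds_a = TensorDataset(torch.arange(4))
ds_b = TensorDataset(torch.arange(4, 10))

# ConcatDataset remains map-style (__len__/__getitem__), so a custom
# batch sampler over its global indices is allowed.
combined = ConcatDataset([ds_a, ds_b])
batch_sampler = BatchSampler(
    SequentialSampler(combined), batch_size=3, drop_last=False
)
loader = DataLoader(combined, batch_sampler=batch_sampler)

for batch in loader:
    print(batch)  # batches are drawn across both underlying datasets

# ChainDataset([ds_a, ds_b]) would instead be an IterableDataset; passing it
# to DataLoader together with batch_sampler raises a ValueError.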