Fix: Correct concatenation of datasets in train conditioning

Summary: ChainDataset is iterable, and it toes not go along with a custom batch sampler.

Reviewed By: bottler

Differential Revision: D42742315

fbshipit-source-id: 40a715c8d24abe72cb2777634247d7467f628564
This commit is contained in:
Roman Shapovalov 2023-01-26 03:00:46 -08:00 committed by Facebook GitHub Bot
parent 11959e0b24
commit 3239594f78

View File

@ -12,7 +12,7 @@ import torch
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from torch.utils.data import ( from torch.utils.data import (
BatchSampler, BatchSampler,
ChainDataset, ConcatDataset,
DataLoader, DataLoader,
RandomSampler, RandomSampler,
Sampler, Sampler,
@ -482,7 +482,7 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
num_batches=num_batches, num_batches=num_batches,
) )
return DataLoader( return DataLoader(
ChainDataset([dataset, train_dataset]), ConcatDataset([dataset, train_dataset]),
batch_sampler=sampler, batch_sampler=sampler,
**data_loader_kwargs, **data_loader_kwargs,
) )