mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Fix: Correct concatenation of datasets in train conditioning
Summary: ChainDataset is iterable, and it toes not go along with a custom batch sampler. Reviewed By: bottler Differential Revision: D42742315 fbshipit-source-id: 40a715c8d24abe72cb2777634247d7467f628564
This commit is contained in:
parent
11959e0b24
commit
3239594f78
@ -12,7 +12,7 @@ import torch
|
||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||
from torch.utils.data import (
|
||||
BatchSampler,
|
||||
ChainDataset,
|
||||
ConcatDataset,
|
||||
DataLoader,
|
||||
RandomSampler,
|
||||
Sampler,
|
||||
@ -482,7 +482,7 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
|
||||
num_batches=num_batches,
|
||||
)
|
||||
return DataLoader(
|
||||
ChainDataset([dataset, train_dataset]),
|
||||
ConcatDataset([dataset, train_dataset]),
|
||||
batch_sampler=sampler,
|
||||
**data_loader_kwargs,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user