mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-03 04:12:48 +08:00
load whole dataset in train loop
Summary: Loads the whole dataset and moves it to the device and sends it to for sampling to enable full dataset heterogeneous raysampling. Reviewed By: bottler Differential Revision: D39263009 fbshipit-source-id: c527537dfc5f50116849656c9e171e868f6845b1
This commit is contained in:
parent
c311a4cbb9
commit
37bd280d19
@ -222,6 +222,7 @@ class Experiment(Configurable): # pyre-ignore: 13
|
|||||||
train_loader=train_loader,
|
train_loader=train_loader,
|
||||||
val_loader=val_loader,
|
val_loader=val_loader,
|
||||||
test_loader=test_loader,
|
test_loader=test_loader,
|
||||||
|
# pyre-ignore[6]
|
||||||
train_dataset=datasets.train,
|
train_dataset=datasets.train,
|
||||||
model=model,
|
model=model,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
|
@ -22,7 +22,7 @@ from pytorch3d.implicitron.tools.config import (
|
|||||||
)
|
)
|
||||||
from pytorch3d.implicitron.tools.stats import Stats
|
from pytorch3d.implicitron.tools.stats import Stats
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
from .utils import seed_all_random_engines
|
from .utils import seed_all_random_engines
|
||||||
|
|
||||||
@ -44,6 +44,7 @@ class TrainingLoopBase(ReplaceableBase):
|
|||||||
train_loader: DataLoader,
|
train_loader: DataLoader,
|
||||||
val_loader: Optional[DataLoader],
|
val_loader: Optional[DataLoader],
|
||||||
test_loader: Optional[DataLoader],
|
test_loader: Optional[DataLoader],
|
||||||
|
train_dataset: Dataset,
|
||||||
model: ImplicitronModelBase,
|
model: ImplicitronModelBase,
|
||||||
optimizer: torch.optim.Optimizer,
|
optimizer: torch.optim.Optimizer,
|
||||||
scheduler: Any,
|
scheduler: Any,
|
||||||
@ -116,6 +117,7 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
|
|||||||
train_loader: DataLoader,
|
train_loader: DataLoader,
|
||||||
val_loader: Optional[DataLoader],
|
val_loader: Optional[DataLoader],
|
||||||
test_loader: Optional[DataLoader],
|
test_loader: Optional[DataLoader],
|
||||||
|
train_dataset: Dataset,
|
||||||
model: ImplicitronModelBase,
|
model: ImplicitronModelBase,
|
||||||
optimizer: torch.optim.Optimizer,
|
optimizer: torch.optim.Optimizer,
|
||||||
scheduler: Any,
|
scheduler: Any,
|
||||||
|
@ -389,7 +389,8 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
|||||||
)
|
)
|
||||||
|
|
||||||
# (1) Sample rendering rays with the ray sampler.
|
# (1) Sample rendering rays with the ray sampler.
|
||||||
ray_bundle: ImplicitronRayBundle = self.raysampler( # pyre-fixme[29]
|
# pyre-ignore[29]
|
||||||
|
ray_bundle: ImplicitronRayBundle = self.raysampler(
|
||||||
target_cameras,
|
target_cameras,
|
||||||
evaluation_mode,
|
evaluation_mode,
|
||||||
mask=mask_crop[:n_targets]
|
mask=mask_crop[:n_targets]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user