diff --git a/projects/implicitron_trainer/experiment.py b/projects/implicitron_trainer/experiment.py index 7795a18e..1b355f26 100755 --- a/projects/implicitron_trainer/experiment.py +++ b/projects/implicitron_trainer/experiment.py @@ -222,6 +222,7 @@ class Experiment(Configurable): # pyre-ignore: 13 train_loader=train_loader, val_loader=val_loader, test_loader=test_loader, + # pyre-ignore[6] train_dataset=datasets.train, model=model, optimizer=optimizer, diff --git a/projects/implicitron_trainer/impl/training_loop.py b/projects/implicitron_trainer/impl/training_loop.py index 31824909..1cafc38b 100644 --- a/projects/implicitron_trainer/impl/training_loop.py +++ b/projects/implicitron_trainer/impl/training_loop.py @@ -22,7 +22,7 @@ from pytorch3d.implicitron.tools.config import ( ) from pytorch3d.implicitron.tools.stats import Stats from pytorch3d.renderer.cameras import CamerasBase -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset from .utils import seed_all_random_engines @@ -44,6 +44,7 @@ class TrainingLoopBase(ReplaceableBase): train_loader: DataLoader, val_loader: Optional[DataLoader], test_loader: Optional[DataLoader], + train_dataset: Dataset, model: ImplicitronModelBase, optimizer: torch.optim.Optimizer, scheduler: Any, @@ -116,6 +117,7 @@ class ImplicitronTrainingLoop(TrainingLoopBase): train_loader: DataLoader, val_loader: Optional[DataLoader], test_loader: Optional[DataLoader], + train_dataset: Dataset, model: ImplicitronModelBase, optimizer: torch.optim.Optimizer, scheduler: Any, diff --git a/pytorch3d/implicitron/models/generic_model.py b/pytorch3d/implicitron/models/generic_model.py index d2a4248c..228bbcec 100644 --- a/pytorch3d/implicitron/models/generic_model.py +++ b/pytorch3d/implicitron/models/generic_model.py @@ -389,7 +389,8 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 ) # (1) Sample rendering rays with the ray sampler. - ray_bundle: ImplicitronRayBundle = self.raysampler( # pyre-fixme[29] + # pyre-ignore[29] + ray_bundle: ImplicitronRayBundle = self.raysampler( target_cameras, evaluation_mode, mask=mask_crop[:n_targets]