mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	load whole dataset in train loop
Summary: Loads the whole dataset and moves it to the device and sends it to for sampling to enable full dataset heterogeneous raysampling. Reviewed By: bottler Differential Revision: D39263009 fbshipit-source-id: c527537dfc5f50116849656c9e171e868f6845b1
This commit is contained in:
		
							parent
							
								
									c311a4cbb9
								
							
						
					
					
						commit
						37bd280d19
					
				@ -222,6 +222,7 @@ class Experiment(Configurable):  # pyre-ignore: 13
 | 
			
		||||
            train_loader=train_loader,
 | 
			
		||||
            val_loader=val_loader,
 | 
			
		||||
            test_loader=test_loader,
 | 
			
		||||
            # pyre-ignore[6]
 | 
			
		||||
            train_dataset=datasets.train,
 | 
			
		||||
            model=model,
 | 
			
		||||
            optimizer=optimizer,
 | 
			
		||||
 | 
			
		||||
@ -22,7 +22,7 @@ from pytorch3d.implicitron.tools.config import (
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.implicitron.tools.stats import Stats
 | 
			
		||||
from pytorch3d.renderer.cameras import CamerasBase
 | 
			
		||||
from torch.utils.data import DataLoader
 | 
			
		||||
from torch.utils.data import DataLoader, Dataset
 | 
			
		||||
 | 
			
		||||
from .utils import seed_all_random_engines
 | 
			
		||||
 | 
			
		||||
@ -44,6 +44,7 @@ class TrainingLoopBase(ReplaceableBase):
 | 
			
		||||
        train_loader: DataLoader,
 | 
			
		||||
        val_loader: Optional[DataLoader],
 | 
			
		||||
        test_loader: Optional[DataLoader],
 | 
			
		||||
        train_dataset: Dataset,
 | 
			
		||||
        model: ImplicitronModelBase,
 | 
			
		||||
        optimizer: torch.optim.Optimizer,
 | 
			
		||||
        scheduler: Any,
 | 
			
		||||
@ -116,6 +117,7 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
 | 
			
		||||
        train_loader: DataLoader,
 | 
			
		||||
        val_loader: Optional[DataLoader],
 | 
			
		||||
        test_loader: Optional[DataLoader],
 | 
			
		||||
        train_dataset: Dataset,
 | 
			
		||||
        model: ImplicitronModelBase,
 | 
			
		||||
        optimizer: torch.optim.Optimizer,
 | 
			
		||||
        scheduler: Any,
 | 
			
		||||
 | 
			
		||||
@ -389,7 +389,8 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # (1) Sample rendering rays with the ray sampler.
 | 
			
		||||
        ray_bundle: ImplicitronRayBundle = self.raysampler(  # pyre-fixme[29]
 | 
			
		||||
        # pyre-ignore[29]
 | 
			
		||||
        ray_bundle: ImplicitronRayBundle = self.raysampler(
 | 
			
		||||
            target_cameras,
 | 
			
		||||
            evaluation_mode,
 | 
			
		||||
            mask=mask_crop[:n_targets]
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user