diff --git a/projects/implicitron_trainer/experiment.py b/projects/implicitron_trainer/experiment.py
index 04765205..42cc0df1 100755
--- a/projects/implicitron_trainer/experiment.py
+++ b/projects/implicitron_trainer/experiment.py
@@ -97,13 +97,16 @@ try:
 except ModuleNotFoundError:
     pass

+no_accelerate = os.environ.get("PYTORCH3D_NO_ACCELERATE") is not None
+

 def init_model(
+    *,
     cfg: DictConfig,
+    accelerator: Optional[Accelerator] = None,
     force_load: bool = False,
     clear_stats: bool = False,
     load_model_only: bool = False,
-    accelerator: Accelerator = None,
 ) -> Tuple[GenericModel, Stats, Optional[Dict[str, Any]]]:
     """
     Returns an instance of `GenericModel`.
@@ -166,7 +169,7 @@ def init_model(
             logger.info(" -> resuming")

             map_location = None
-            if not accelerator.is_local_main_process:
+            if accelerator is not None and not accelerator.is_local_main_process:
                 map_location = {
                     "cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
                 }
@@ -228,13 +231,14 @@ def trainvalidate(
     loader,
     optimizer,
     validation: bool,
+    *,
+    accelerator: Optional[Accelerator],
+    device: torch.device,
     bp_var: str = "objective",
     metric_print_interval: int = 5,
     visualize_interval: int = 100,
     visdom_env_root: str = "trainvalidate",
     clip_grad: float = 0.0,
-    device: str = "cuda:0",
-    accelerator: Accelerator = None,
     **kwargs,
 ) -> None:
     """
@@ -286,7 +290,7 @@ def trainvalidate(
         last_iter = it == n_batches - 1

         # move to gpu where possible (in place)
-        net_input = net_input.to(accelerator.device)
+        net_input = net_input.to(device)

         # run the forward pass
         if not validation:
@@ -313,7 +317,7 @@

         # visualize results
         if (
-            accelerator.is_local_main_process
+            (accelerator is None or accelerator.is_local_main_process)
             and visualize_interval > 0
             and it % visualize_interval == 0
         ):
@@ -331,7 +335,10 @@
             loss = preds[bp_var]
             assert torch.isfinite(loss).all(), "Non-finite loss!"
             # backprop
-            accelerator.backward(loss)
+            if accelerator is None:
+                loss.backward()
+            else:
+                accelerator.backward(loss)
             if clip_grad > 0.0:
                 # Optionally clip the gradient norms.
                 total_norm = torch.nn.utils.clip_grad_norm(
@@ -353,15 +360,16 @@ def run_training(cfg: DictConfig) -> None:
     """

     # Initialize the accelerator
-    accelerator = Accelerator(device_placement=False)
-    logger.info(accelerator.state)
+    if no_accelerate:
+        accelerator = None
+        device = torch.device("cuda:0")
+    else:
+        accelerator = Accelerator(device_placement=False)
+        logger.info(accelerator.state)
+        device = accelerator.device

-    device = accelerator.device
     logger.info(f"Running experiment on device: {device}")

-    if accelerator.is_local_main_process:
-        logger.info(OmegaConf.to_yaml(cfg))
-
     # set the debug mode
     if cfg.detect_anomaly:
         logger.info("Anomaly detection!")
@@ -386,11 +394,11 @@
     all_train_cameras = datasource.get_all_train_cameras()

     # init the model
-    model, stats, optimizer_state = init_model(cfg, accelerator=accelerator)
+    model, stats, optimizer_state = init_model(cfg=cfg, accelerator=accelerator)
     start_epoch = stats.epoch + 1

     # move model to gpu
-    model.to(accelerator.device)
+    model.to(device)

     # only run evaluation on the test dataloader
     if cfg.eval_only:
@@ -403,7 +411,6 @@
             model,
             stats,
             device=device,
-            accelerator=accelerator,
         )
         return

@@ -422,12 +429,15 @@
     # Wrap all modules in the distributed library
     # Note: we don't pass the scheduler to prepare as it
     # doesn't need to be stepped at each optimizer step
-    (
-        model,
-        optimizer,
-        train_loader,
-        val_loader,
-    ) = accelerator.prepare(model, optimizer, dataloaders.train, dataloaders.val)
+    train_loader = dataloaders.train
+    val_loader = dataloaders.val
+    if accelerator is not None:
+        (
+            model,
+            optimizer,
+            train_loader,
+            val_loader,
+        ) = accelerator.prepare(model, optimizer, train_loader, val_loader)

     past_scheduler_lrs = []
     # loop through epochs
@@ -485,19 +495,22 @@
                 task,
                 camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
                 device=device,
-                accelerator=accelerator,
             )
         assert stats.epoch == epoch, "inconsistent stats!"

         # delete previous models if required
         # save model only on the main process
-        if cfg.store_checkpoints and accelerator.is_local_main_process:
+        if cfg.store_checkpoints and (
+            accelerator is None or accelerator.is_local_main_process
+        ):
             if cfg.store_checkpoints_purge > 0:
                 for prev_epoch in range(epoch - cfg.store_checkpoints_purge):
                     model_io.purge_epoch(cfg.exp_dir, prev_epoch)
             outfile = model_io.get_checkpoint(cfg.exp_dir, epoch)
-            unwrapped_model = accelerator.unwrap_model(model)
+            unwrapped_model = (
+                model if accelerator is None else accelerator.unwrap_model(model)
+            )
             model_io.safe_save_model(
                 unwrapped_model, stats, outfile, optimizer=optimizer
             )

@@ -530,7 +543,6 @@ def _eval_and_dump(
     model,
     stats,
     device,
-    accelerator: Accelerator = None,
 ) -> None:
     """
     Run the evaluation loop with the test data loader and
@@ -549,7 +561,6 @@ def _eval_and_dump(
         task,
         camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
         device=device,
-        accelerator=accelerator,
     )

     # add the evaluation epoch to the results
@@ -584,20 +595,19 @@
     task: Task,
     camera_difficulty_bin_breaks: Tuple[float, float],
     device,
-    accelerator: Accelerator = None,
 ):
     """
     Run the evaluation loop on the test dataloader
     """
     lpips_model = lpips.LPIPS(net="vgg")
-    lpips_model = lpips_model.to(accelerator.device)
+    lpips_model = lpips_model.to(device)

     model.eval()

     per_batch_eval_results = []
     logger.info("Evaluating model ...")
     for frame_data in tqdm.tqdm(loader):
-        frame_data = frame_data.to(accelerator.device)
+        frame_data = frame_data.to(device)

         # mask out the unknown images so that the model does not see them
         frame_data_for_eval = _get_eval_frame_data(frame_data)
diff --git a/projects/implicitron_trainer/visualize_reconstruction.py b/projects/implicitron_trainer/visualize_reconstruction.py
index cfec4373..83c10358 100644
--- a/projects/implicitron_trainer/visualize_reconstruction.py
+++ b/projects/implicitron_trainer/visualize_reconstruction.py
@@ -344,7 +344,7 @@ def export_scenes(
     os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_idx)

     # Load the previously trained model
-    model, _, _ = init_model(cfg=config, force_load=True, load_model_only=True)
+    model, _, _ = init_model(cfg=config, force_load=True, load_model_only=True)
     model.cuda()
     model.eval()
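For reference, a minimal sketch of the accelerator-optional pattern this patch threads through the trainer, assuming only the public torch and accelerate APIs (the backward_step helper is illustrative, not part of the patch):

import os
from typing import Optional

import torch
from accelerate import Accelerator

# Read the opt-out flag once at import time, as the patch does: the
# variable must be exported before experiment.py starts.
no_accelerate = os.environ.get("PYTORCH3D_NO_ACCELERATE") is not None

if no_accelerate:
    # Plain single-process path: no Accelerate state, fixed device.
    accelerator: Optional[Accelerator] = None
    device = torch.device("cuda:0")
else:
    # Distributed path: Accelerate owns process state; device placement
    # stays manual (device_placement=False), matching run_training().
    accelerator = Accelerator(device_placement=False)
    device = accelerator.device


def backward_step(loss: torch.Tensor) -> None:
    # Every accelerator use is guarded by a None check, mirroring
    # trainvalidate(): fall back to vanilla autograd when disabled.
    if accelerator is None:
        loss.backward()
    else:
        accelerator.backward(loss)

Guarding each accelerator use with a None check keeps a single code path for single-process and distributed runs, and passing an explicit device removes the remaining dependencies on accelerator.device. The new keyword-only parameters (the bare *) force call sites to name their arguments, which is why the init_model call in visualize_reconstruction.py changes as well.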