mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	render_flyaround bugfix
Summary: Fixes a bug which would crash render_flyaround anytime visualize_preds_keys is adjusted Reviewed By: shapovalov Differential Revision: D41124462 fbshipit-source-id: 127045a91a055909f8bd56c8af81afac02c00f60
This commit is contained in:
		
							parent
							
								
									35f8cb9430
								
							
						
					
					
						commit
						94f321fa3d
					
				@ -10,7 +10,17 @@ import logging
 | 
			
		||||
import math
 | 
			
		||||
import os
 | 
			
		||||
import random
 | 
			
		||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
 | 
			
		||||
from typing import (
 | 
			
		||||
    Any,
 | 
			
		||||
    Dict,
 | 
			
		||||
    Iterable,
 | 
			
		||||
    List,
 | 
			
		||||
    Optional,
 | 
			
		||||
    Sequence,
 | 
			
		||||
    Tuple,
 | 
			
		||||
    TYPE_CHECKING,
 | 
			
		||||
    Union,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
@ -180,7 +190,7 @@ def render_flyaround(
 | 
			
		||||
            preds.update(net_input)  # merge everything into one big dict
 | 
			
		||||
 | 
			
		||||
            # Render the predictions to images
 | 
			
		||||
            rendered_pred = _images_from_preds(preds)
 | 
			
		||||
            rendered_pred = _images_from_preds(preds, extract_keys=visualize_preds_keys)
 | 
			
		||||
            preds_total.append(rendered_pred)
 | 
			
		||||
 | 
			
		||||
            # show the preds every 5% of the export iterations
 | 
			
		||||
@ -223,9 +233,9 @@ def _load_whole_dataset(
 | 
			
		||||
    return next(iter(load_all_dataloader))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _images_from_preds(preds: Dict[str, Any]) -> Dict[str, torch.Tensor]:
 | 
			
		||||
    imout = {}
 | 
			
		||||
    for k in (
 | 
			
		||||
def _images_from_preds(
 | 
			
		||||
    preds: Dict[str, Any],
 | 
			
		||||
    extract_keys: Iterable[str] = (
 | 
			
		||||
        "image_rgb",
 | 
			
		||||
        "images_render",
 | 
			
		||||
        "fg_probability",
 | 
			
		||||
@ -233,7 +243,10 @@ def _images_from_preds(preds: Dict[str, Any]) -> Dict[str, torch.Tensor]:
 | 
			
		||||
        "depths_render",
 | 
			
		||||
        "depth_map",
 | 
			
		||||
        "_all_source_images",
 | 
			
		||||
    ):
 | 
			
		||||
    ),
 | 
			
		||||
) -> Dict[str, torch.Tensor]:
 | 
			
		||||
    imout = {}
 | 
			
		||||
    for k in extract_keys:
 | 
			
		||||
        if k == "_all_source_images" and "image_rgb" in preds:
 | 
			
		||||
            src_ims = preds["image_rgb"][1:].cpu().detach().clone()
 | 
			
		||||
            v = _stack_images(src_ims, None)[None]
 | 
			
		||||
@ -343,6 +356,9 @@ def _generate_prediction_videos(
 | 
			
		||||
    # init a video writer for each predicted key
 | 
			
		||||
    vws = {}
 | 
			
		||||
    for k in predicted_keys:
 | 
			
		||||
        if k not in preds[0]:
 | 
			
		||||
            logger.warn(f"Cannot generate video for prediction key '{k}'")
 | 
			
		||||
            continue
 | 
			
		||||
        cache_dir = (
 | 
			
		||||
            None
 | 
			
		||||
            if video_frames_dir is None
 | 
			
		||||
@ -355,13 +371,15 @@ def _generate_prediction_videos(
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    for rendered_pred in tqdm(preds):
 | 
			
		||||
        for k in predicted_keys:
 | 
			
		||||
        for k in vws:
 | 
			
		||||
            vws[k].write_frame(
 | 
			
		||||
                rendered_pred[k][0].clip(0.0, 1.0).detach().cpu().numpy(),
 | 
			
		||||
                resize=resize,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    for k in predicted_keys:
 | 
			
		||||
        if k not in vws:
 | 
			
		||||
            continue
 | 
			
		||||
        vws[k].get_video()
 | 
			
		||||
        logger.info(f"Generated {vws[k].out_path}.")
 | 
			
		||||
        if viz is not None:
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user