mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Consider the first frame as target ignoring subset labels in evaluator
Summary:
Aligning the logic with the official CO3Dv2 evaluation: 92283c4368/co3d/dataset/utils.py (L7)
This will make the evaluator work with the datasets that do not define known/unseen subsets.
Reviewed By: bottler
Differential Revision: D42803136
fbshipit-source-id: cfac389eab010c32d2e33b40fc7f6ed845c327ef
			
			
This commit is contained in:
		
							parent
							
								
									9540c29023
								
							
						
					
					
						commit
						a7256e4034
					
				@ -15,7 +15,7 @@ import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData
 | 
			
		||||
from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
 | 
			
		||||
from pytorch3d.implicitron.dataset.utils import is_train_frame
 | 
			
		||||
from pytorch3d.implicitron.models.base_model import ImplicitronRender
 | 
			
		||||
from pytorch3d.implicitron.tools import vis_utils
 | 
			
		||||
from pytorch3d.implicitron.tools.image_utils import mask_background
 | 
			
		||||
 | 
			
		||||
@ -149,14 +149,11 @@ def _dump_to_json(
 | 
			
		||||
 | 
			
		||||
def _get_eval_frame_data(frame_data: Any) -> Any:
 | 
			
		||||
    """
 | 
			
		||||
    Masks the unknown image data to make sure we cannot use it at model evaluation time.
 | 
			
		||||
    Masks the target image data to make sure we cannot use it at model evaluation
 | 
			
		||||
    time. Assumes the first batch element is target, the rest are source.
 | 
			
		||||
    """
 | 
			
		||||
    frame_data_for_eval = copy.deepcopy(frame_data)
 | 
			
		||||
    is_known = ds_utils.is_known_frame(frame_data.frame_type).type_as(
 | 
			
		||||
        frame_data.image_rgb
 | 
			
		||||
    )[:, None, None, None]
 | 
			
		||||
    for k in ("image_rgb", "depth_map", "fg_probability", "mask_crop"):
 | 
			
		||||
        value = getattr(frame_data_for_eval, k)
 | 
			
		||||
        value_masked = value.clone() * is_known if value is not None else None
 | 
			
		||||
        setattr(frame_data_for_eval, k, value_masked)
 | 
			
		||||
        value[0].zero_()
 | 
			
		||||
    return frame_data_for_eval
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user