Loosening the checks in eval script for CO3Dv2 style eval

Summary:
V2 dataset does not have the concept of known/unseen frames. Test-time conditining is done with train-set frames, which violates the previous check.

Also fixing a corner case in VideoWriter.

Reviewed By: bottler

Differential Revision: D42706976

fbshipit-source-id: d43be3dd3060d18cb9f46d5dcf6252d9f084110f
This commit is contained in:
Roman Shapovalov 2023-01-26 03:00:46 -08:00 committed by Facebook GitHub Bot
parent 9dc28f5dd5
commit 54eb76d48c
2 changed files with 8 additions and 12 deletions

View File

@ -219,17 +219,10 @@ def eval_batch(
frame_type = [frame_type] frame_type = [frame_type]
is_train = is_train_frame(frame_type) is_train = is_train_frame(frame_type)
if not (is_train[0] == is_train).all(): if len(is_train) > 1 and (is_train[1] != is_train[1:]).any():
raise ValueError("All frames in the eval batch have to be either train/test.")
# pyre-fixme[16]: `Optional` has no attribute `device`.
is_known = is_known_frame(frame_type, device=frame_data.image_rgb.device)
if not ((is_known[1:] == 1).all() and (is_known[0] == 0).all()):
raise ValueError( raise ValueError(
"For evaluation the first element of the batch has to be" "All (conditioning) frames in the eval batch have to be either train/test."
+ " a target view while the rest should be source views." )
) # TODO: do we need to enforce this?
for k in [ for k in [
"depth_map", "depth_map",
@ -362,7 +355,7 @@ def eval_batch(
results["meta"] = { results["meta"] = {
# store the size of the batch (corresponds to n_src_views+1) # store the size of the batch (corresponds to n_src_views+1)
"batch_size": int(is_known.numel()), "batch_size": len(frame_type),
# store the type of the target frame # store the type of the target frame
# pyre-fixme[16]: `None` has no attribute `__getitem__`. # pyre-fixme[16]: `None` has no attribute `__getitem__`.
"frame_type": str(frame_data.frame_type[0]), "frame_type": str(frame_data.frame_type[0]),

View File

@ -124,8 +124,11 @@ class VideoWriter:
quiet: If `True`, suppresses logging messages. quiet: If `True`, suppresses logging messages.
Returns: Returns:
video_path: The path to the generated video. video_path: The path to the generated video if any frames were added.
Otherwise returns an empty string.
""" """
if self.frame_num == 0:
return ""
regexp = os.path.join(self.cache_dir, self.regexp) regexp = os.path.join(self.cache_dir, self.regexp)