FrameAnnotation.meta, Optional in _dataclass_from_dict

Summary: Allow extra data in a FrameAnnotation. Therefore allow Optional[T] systematically in _dataclass_from_dict

Reviewed By: davnov134

Differential Revision: D36442691

fbshipit-source-id: ba70f6491574c08b0d9c9acb63f35514d29de214
This commit is contained in:
Jeremy Reizenstein 2022-05-17 08:16:29 -07:00 committed by Facebook GitHub Bot
parent f36b11fe49
commit f632c423ef
2 changed files with 35 additions and 6 deletions

View File

@ -9,7 +9,7 @@ import dataclasses
import gzip
import json
from dataclasses import dataclass, Field, MISSING
from typing import Any, cast, IO, Optional, Tuple, Type, TypeVar, Union
from typing import Any, cast, Dict, IO, Optional, Tuple, Type, TypeVar, Union
import numpy as np
from pytorch3d.common.datatypes import get_args, get_origin
@ -80,6 +80,7 @@ class FrameAnnotation:
depth: Optional[DepthAnnotation] = None
mask: Optional[MaskAnnotation] = None
viewpoint: Optional[ViewpointAnnotation] = None
meta: Optional[Dict[str, Any]] = None
@dataclass
@ -169,9 +170,11 @@ def _dataclass_list_from_dict_list(dlist, typeannot):
cls = get_origin(typeannot) or typeannot
if typeannot is Any:
return dlist
if all(obj is None for obj in dlist): # 1st recursion base: all None nodes
return dlist
elif any(obj is None for obj in dlist):
if any(obj is None for obj in dlist):
# filter out Nones and recurse on the resulting list
idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None]
idx, notnone = zip(*idx_notnone)
@ -180,8 +183,13 @@ def _dataclass_list_from_dict_list(dlist, typeannot):
for i, obj in zip(idx, converted):
res[i] = obj
return res
is_optional, contained_type = _resolve_optional(typeannot)
if is_optional:
return _dataclass_list_from_dict_list(dlist, contained_type)
# otherwise, we dispatch by the type of the provided annotation to convert to
elif issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
# For namedtuple, call the function recursively on the lists of corresponding keys
types = cls._field_types.values()
dlist_T = zip(*dlist)
@ -240,10 +248,15 @@ def _dataclass_list_from_dict_list(dlist, typeannot):
def _dataclass_from_dict(d, typeannot):
cls = get_origin(typeannot) or typeannot
if d is None:
if d is None or typeannot is Any:
return d
elif issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
is_optional, contained_type = _resolve_optional(typeannot)
if is_optional:
# an Optional not set to None, just use the contents of the Optional.
return _dataclass_from_dict(d, contained_type)
cls = get_origin(typeannot) or typeannot
if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
types = cls._field_types.values()
return cls(*[_dataclass_from_dict(v, tp) for v, tp in zip(d, types)])
elif issubclass(cls, (list, tuple)):
@ -315,3 +328,15 @@ def load_dataclass_jgzip(outfile, cls):
"""
with gzip.GzipFile(outfile, "rb") as f:
return load_dataclass(cast(IO, f), cls, binary=True)
def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
"""Check whether `type_` is equivalent to `typing.Optional[T]` for some T."""
if get_origin(type_) is Union:
args = get_args(type_)
if len(args) == 2 and args[1] == type(None): # noqa E721
return True, args[0]
if type_ is Any:
return True, Any
return False, type_

View File

@ -85,6 +85,10 @@ class TestDatasetTypes(unittest.TestCase):
self._compare_with_scalar([dct], List[FrameAnnotation])
self._compare_with_scalar({"k": dct}, Dict[str, FrameAnnotation])
dct2 = dct.copy()
dct2["meta"] = {"d": 76}
self._compare_with_scalar(dct2, FrameAnnotation)
def _compare_with_scalar(self, obj, typeannot, repeat=3):
input = [obj] * 3
vect_output = types._dataclass_list_from_dict_list(input, typeannot)