diff --git a/pytorch3d/implicitron/dataset/types.py b/pytorch3d/implicitron/dataset/types.py index d1dd0a48..497b91ee 100644 --- a/pytorch3d/implicitron/dataset/types.py +++ b/pytorch3d/implicitron/dataset/types.py @@ -9,7 +9,7 @@ import dataclasses import gzip import json from dataclasses import dataclass, Field, MISSING -from typing import Any, cast, IO, Optional, Tuple, Type, TypeVar, Union +from typing import Any, cast, Dict, IO, Optional, Tuple, Type, TypeVar, Union import numpy as np from pytorch3d.common.datatypes import get_args, get_origin @@ -80,6 +80,7 @@ class FrameAnnotation: depth: Optional[DepthAnnotation] = None mask: Optional[MaskAnnotation] = None viewpoint: Optional[ViewpointAnnotation] = None + meta: Optional[Dict[str, Any]] = None @dataclass @@ -169,9 +170,11 @@ def _dataclass_list_from_dict_list(dlist, typeannot): cls = get_origin(typeannot) or typeannot + if typeannot is Any: + return dlist if all(obj is None for obj in dlist): # 1st recursion base: all None nodes return dlist - elif any(obj is None for obj in dlist): + if any(obj is None for obj in dlist): # filter out Nones and recurse on the resulting list idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None] idx, notnone = zip(*idx_notnone) @@ -180,8 +183,13 @@ def _dataclass_list_from_dict_list(dlist, typeannot): for i, obj in zip(idx, converted): res[i] = obj return res + + is_optional, contained_type = _resolve_optional(typeannot) + if is_optional: + return _dataclass_list_from_dict_list(dlist, contained_type) + # otherwise, we dispatch by the type of the provided annotation to convert to - elif issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple + if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple # For namedtuple, call the function recursively on the lists of corresponding keys types = cls._field_types.values() dlist_T = zip(*dlist) @@ -240,10 +248,15 @@ def _dataclass_list_from_dict_list(dlist, typeannot): def _dataclass_from_dict(d, typeannot): - cls = get_origin(typeannot) or typeannot - if d is None: + if d is None or typeannot is Any: return d - elif issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple + is_optional, contained_type = _resolve_optional(typeannot) + if is_optional: + # an Optional not set to None, just use the contents of the Optional. + return _dataclass_from_dict(d, contained_type) + + cls = get_origin(typeannot) or typeannot + if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple types = cls._field_types.values() return cls(*[_dataclass_from_dict(v, tp) for v, tp in zip(d, types)]) elif issubclass(cls, (list, tuple)): @@ -315,3 +328,15 @@ def load_dataclass_jgzip(outfile, cls): """ with gzip.GzipFile(outfile, "rb") as f: return load_dataclass(cast(IO, f), cls, binary=True) + + +def _resolve_optional(type_: Any) -> Tuple[bool, Any]: + """Check whether `type_` is equivalent to `typing.Optional[T]` for some T.""" + if get_origin(type_) is Union: + args = get_args(type_) + if len(args) == 2 and args[1] == type(None): # noqa E721 + return True, args[0] + if type_ is Any: + return True, Any + + return False, type_ diff --git a/tests/implicitron/test_types.py b/tests/implicitron/test_types.py index 91338edc..56352b93 100644 --- a/tests/implicitron/test_types.py +++ b/tests/implicitron/test_types.py @@ -85,6 +85,10 @@ class TestDatasetTypes(unittest.TestCase): self._compare_with_scalar([dct], List[FrameAnnotation]) self._compare_with_scalar({"k": dct}, Dict[str, FrameAnnotation]) + dct2 = dct.copy() + dct2["meta"] = {"d": 76} + self._compare_with_scalar(dct2, FrameAnnotation) + def _compare_with_scalar(self, obj, typeannot, repeat=3): input = [obj] * 3 vect_output = types._dataclass_list_from_dict_list(input, typeannot)