mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
FrameAnnotation.meta, Optional in _dataclass_from_dict
Summary: Allow extra data in a FrameAnnotation. Therefore allow Optional[T] systematically in _dataclass_from_dict Reviewed By: davnov134 Differential Revision: D36442691 fbshipit-source-id: ba70f6491574c08b0d9c9acb63f35514d29de214
This commit is contained in:
parent
f36b11fe49
commit
f632c423ef
@ -9,7 +9,7 @@ import dataclasses
|
||||
import gzip
|
||||
import json
|
||||
from dataclasses import dataclass, Field, MISSING
|
||||
from typing import Any, cast, IO, Optional, Tuple, Type, TypeVar, Union
|
||||
from typing import Any, cast, Dict, IO, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
from pytorch3d.common.datatypes import get_args, get_origin
|
||||
@ -80,6 +80,7 @@ class FrameAnnotation:
|
||||
depth: Optional[DepthAnnotation] = None
|
||||
mask: Optional[MaskAnnotation] = None
|
||||
viewpoint: Optional[ViewpointAnnotation] = None
|
||||
meta: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -169,9 +170,11 @@ def _dataclass_list_from_dict_list(dlist, typeannot):
|
||||
|
||||
cls = get_origin(typeannot) or typeannot
|
||||
|
||||
if typeannot is Any:
|
||||
return dlist
|
||||
if all(obj is None for obj in dlist): # 1st recursion base: all None nodes
|
||||
return dlist
|
||||
elif any(obj is None for obj in dlist):
|
||||
if any(obj is None for obj in dlist):
|
||||
# filter out Nones and recurse on the resulting list
|
||||
idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None]
|
||||
idx, notnone = zip(*idx_notnone)
|
||||
@ -180,8 +183,13 @@ def _dataclass_list_from_dict_list(dlist, typeannot):
|
||||
for i, obj in zip(idx, converted):
|
||||
res[i] = obj
|
||||
return res
|
||||
|
||||
is_optional, contained_type = _resolve_optional(typeannot)
|
||||
if is_optional:
|
||||
return _dataclass_list_from_dict_list(dlist, contained_type)
|
||||
|
||||
# otherwise, we dispatch by the type of the provided annotation to convert to
|
||||
elif issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
|
||||
if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
|
||||
# For namedtuple, call the function recursively on the lists of corresponding keys
|
||||
types = cls._field_types.values()
|
||||
dlist_T = zip(*dlist)
|
||||
@ -240,10 +248,15 @@ def _dataclass_list_from_dict_list(dlist, typeannot):
|
||||
|
||||
|
||||
def _dataclass_from_dict(d, typeannot):
|
||||
cls = get_origin(typeannot) or typeannot
|
||||
if d is None:
|
||||
if d is None or typeannot is Any:
|
||||
return d
|
||||
elif issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
|
||||
is_optional, contained_type = _resolve_optional(typeannot)
|
||||
if is_optional:
|
||||
# an Optional not set to None, just use the contents of the Optional.
|
||||
return _dataclass_from_dict(d, contained_type)
|
||||
|
||||
cls = get_origin(typeannot) or typeannot
|
||||
if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
|
||||
types = cls._field_types.values()
|
||||
return cls(*[_dataclass_from_dict(v, tp) for v, tp in zip(d, types)])
|
||||
elif issubclass(cls, (list, tuple)):
|
||||
@ -315,3 +328,15 @@ def load_dataclass_jgzip(outfile, cls):
|
||||
"""
|
||||
with gzip.GzipFile(outfile, "rb") as f:
|
||||
return load_dataclass(cast(IO, f), cls, binary=True)
|
||||
|
||||
|
||||
def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
|
||||
"""Check whether `type_` is equivalent to `typing.Optional[T]` for some T."""
|
||||
if get_origin(type_) is Union:
|
||||
args = get_args(type_)
|
||||
if len(args) == 2 and args[1] == type(None): # noqa E721
|
||||
return True, args[0]
|
||||
if type_ is Any:
|
||||
return True, Any
|
||||
|
||||
return False, type_
|
||||
|
@ -85,6 +85,10 @@ class TestDatasetTypes(unittest.TestCase):
|
||||
self._compare_with_scalar([dct], List[FrameAnnotation])
|
||||
self._compare_with_scalar({"k": dct}, Dict[str, FrameAnnotation])
|
||||
|
||||
dct2 = dct.copy()
|
||||
dct2["meta"] = {"d": 76}
|
||||
self._compare_with_scalar(dct2, FrameAnnotation)
|
||||
|
||||
def _compare_with_scalar(self, obj, typeannot, repeat=3):
|
||||
input = [obj] * 3
|
||||
vect_output = types._dataclass_list_from_dict_list(input, typeannot)
|
||||
|
Loading…
x
Reference in New Issue
Block a user