mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
FrameAnnotation.meta, Optional in _dataclass_from_dict
Summary: Allow extra data in a FrameAnnotation. Therefore allow Optional[T] systematically in _dataclass_from_dict Reviewed By: davnov134 Differential Revision: D36442691 fbshipit-source-id: ba70f6491574c08b0d9c9acb63f35514d29de214
This commit is contained in:
parent
f36b11fe49
commit
f632c423ef
@ -9,7 +9,7 @@ import dataclasses
|
|||||||
import gzip
|
import gzip
|
||||||
import json
|
import json
|
||||||
from dataclasses import dataclass, Field, MISSING
|
from dataclasses import dataclass, Field, MISSING
|
||||||
from typing import Any, cast, IO, Optional, Tuple, Type, TypeVar, Union
|
from typing import Any, cast, Dict, IO, Optional, Tuple, Type, TypeVar, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pytorch3d.common.datatypes import get_args, get_origin
|
from pytorch3d.common.datatypes import get_args, get_origin
|
||||||
@ -80,6 +80,7 @@ class FrameAnnotation:
|
|||||||
depth: Optional[DepthAnnotation] = None
|
depth: Optional[DepthAnnotation] = None
|
||||||
mask: Optional[MaskAnnotation] = None
|
mask: Optional[MaskAnnotation] = None
|
||||||
viewpoint: Optional[ViewpointAnnotation] = None
|
viewpoint: Optional[ViewpointAnnotation] = None
|
||||||
|
meta: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -169,9 +170,11 @@ def _dataclass_list_from_dict_list(dlist, typeannot):
|
|||||||
|
|
||||||
cls = get_origin(typeannot) or typeannot
|
cls = get_origin(typeannot) or typeannot
|
||||||
|
|
||||||
|
if typeannot is Any:
|
||||||
|
return dlist
|
||||||
if all(obj is None for obj in dlist): # 1st recursion base: all None nodes
|
if all(obj is None for obj in dlist): # 1st recursion base: all None nodes
|
||||||
return dlist
|
return dlist
|
||||||
elif any(obj is None for obj in dlist):
|
if any(obj is None for obj in dlist):
|
||||||
# filter out Nones and recurse on the resulting list
|
# filter out Nones and recurse on the resulting list
|
||||||
idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None]
|
idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None]
|
||||||
idx, notnone = zip(*idx_notnone)
|
idx, notnone = zip(*idx_notnone)
|
||||||
@ -180,8 +183,13 @@ def _dataclass_list_from_dict_list(dlist, typeannot):
|
|||||||
for i, obj in zip(idx, converted):
|
for i, obj in zip(idx, converted):
|
||||||
res[i] = obj
|
res[i] = obj
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
is_optional, contained_type = _resolve_optional(typeannot)
|
||||||
|
if is_optional:
|
||||||
|
return _dataclass_list_from_dict_list(dlist, contained_type)
|
||||||
|
|
||||||
# otherwise, we dispatch by the type of the provided annotation to convert to
|
# otherwise, we dispatch by the type of the provided annotation to convert to
|
||||||
elif issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
|
if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
|
||||||
# For namedtuple, call the function recursively on the lists of corresponding keys
|
# For namedtuple, call the function recursively on the lists of corresponding keys
|
||||||
types = cls._field_types.values()
|
types = cls._field_types.values()
|
||||||
dlist_T = zip(*dlist)
|
dlist_T = zip(*dlist)
|
||||||
@ -240,10 +248,15 @@ def _dataclass_list_from_dict_list(dlist, typeannot):
|
|||||||
|
|
||||||
|
|
||||||
def _dataclass_from_dict(d, typeannot):
|
def _dataclass_from_dict(d, typeannot):
|
||||||
cls = get_origin(typeannot) or typeannot
|
if d is None or typeannot is Any:
|
||||||
if d is None:
|
|
||||||
return d
|
return d
|
||||||
elif issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
|
is_optional, contained_type = _resolve_optional(typeannot)
|
||||||
|
if is_optional:
|
||||||
|
# an Optional not set to None, just use the contents of the Optional.
|
||||||
|
return _dataclass_from_dict(d, contained_type)
|
||||||
|
|
||||||
|
cls = get_origin(typeannot) or typeannot
|
||||||
|
if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
|
||||||
types = cls._field_types.values()
|
types = cls._field_types.values()
|
||||||
return cls(*[_dataclass_from_dict(v, tp) for v, tp in zip(d, types)])
|
return cls(*[_dataclass_from_dict(v, tp) for v, tp in zip(d, types)])
|
||||||
elif issubclass(cls, (list, tuple)):
|
elif issubclass(cls, (list, tuple)):
|
||||||
@ -315,3 +328,15 @@ def load_dataclass_jgzip(outfile, cls):
|
|||||||
"""
|
"""
|
||||||
with gzip.GzipFile(outfile, "rb") as f:
|
with gzip.GzipFile(outfile, "rb") as f:
|
||||||
return load_dataclass(cast(IO, f), cls, binary=True)
|
return load_dataclass(cast(IO, f), cls, binary=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
|
||||||
|
"""Check whether `type_` is equivalent to `typing.Optional[T]` for some T."""
|
||||||
|
if get_origin(type_) is Union:
|
||||||
|
args = get_args(type_)
|
||||||
|
if len(args) == 2 and args[1] == type(None): # noqa E721
|
||||||
|
return True, args[0]
|
||||||
|
if type_ is Any:
|
||||||
|
return True, Any
|
||||||
|
|
||||||
|
return False, type_
|
||||||
|
@ -85,6 +85,10 @@ class TestDatasetTypes(unittest.TestCase):
|
|||||||
self._compare_with_scalar([dct], List[FrameAnnotation])
|
self._compare_with_scalar([dct], List[FrameAnnotation])
|
||||||
self._compare_with_scalar({"k": dct}, Dict[str, FrameAnnotation])
|
self._compare_with_scalar({"k": dct}, Dict[str, FrameAnnotation])
|
||||||
|
|
||||||
|
dct2 = dct.copy()
|
||||||
|
dct2["meta"] = {"d": 76}
|
||||||
|
self._compare_with_scalar(dct2, FrameAnnotation)
|
||||||
|
|
||||||
def _compare_with_scalar(self, obj, typeannot, repeat=3):
|
def _compare_with_scalar(self, obj, typeannot, repeat=3):
|
||||||
input = [obj] * 3
|
input = [obj] * 3
|
||||||
vect_output = types._dataclass_list_from_dict_list(input, typeannot)
|
vect_output = types._dataclass_list_from_dict_list(input, typeannot)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user