mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Optional ReplaceableBase
Summary: Allow things like `renderer:Optional[BaseRenderer]` in configurables. Reviewed By: davnov134 Differential Revision: D35118339 fbshipit-source-id: 1219321b2817ed4b26fe924c6d6f73887095c985
This commit is contained in:
parent
e332f9ffa4
commit
21262e38c7
@ -4,6 +4,7 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import sys
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
@ -56,3 +57,20 @@ def get_device(x, device: Optional[Device] = None) -> torch.device:
|
||||
|
||||
# Default device is cpu
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
# Provide get_origin and get_args even in Python 3.7.
|
||||
|
||||
if sys.version_info >= (3, 8, 0):
|
||||
from typing import get_args, get_origin
|
||||
elif sys.version_info >= (3, 7, 0):
|
||||
|
||||
def get_origin(cls): # pragma: no cover
|
||||
return getattr(cls, "__origin__", None)
|
||||
|
||||
def get_args(cls): # pragma: no cover
|
||||
return getattr(cls, "__args__", None)
|
||||
|
||||
|
||||
else:
|
||||
raise ImportError("This module requires Python 3.7+")
|
||||
|
@ -8,31 +8,15 @@
|
||||
import dataclasses
|
||||
import gzip
|
||||
import json
|
||||
import sys
|
||||
from dataclasses import MISSING, Field, dataclass
|
||||
from typing import IO, Any, Optional, Tuple, Type, TypeVar, Union, cast
|
||||
|
||||
import numpy as np
|
||||
from pytorch3d.common.datatypes import get_args, get_origin
|
||||
|
||||
|
||||
_X = TypeVar("_X")
|
||||
|
||||
|
||||
if sys.version_info >= (3, 8, 0):
|
||||
from typing import get_args, get_origin
|
||||
elif sys.version_info >= (3, 7, 0):
|
||||
|
||||
def get_origin(cls):
|
||||
return getattr(cls, "__origin__", None)
|
||||
|
||||
def get_args(cls):
|
||||
return getattr(cls, "__args__", None)
|
||||
|
||||
|
||||
else:
|
||||
raise ImportError("This module requires Python 3.7+")
|
||||
|
||||
|
||||
TF3 = Tuple[float, float, float]
|
||||
|
||||
|
||||
|
@ -10,9 +10,10 @@ import itertools
|
||||
import warnings
|
||||
from collections import Counter, defaultdict
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, cast
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf, open_dict
|
||||
from pytorch3d.common.datatypes import get_args, get_origin
|
||||
|
||||
|
||||
"""
|
||||
@ -97,6 +98,8 @@ you can give a base class and the implementation is looked up by name in the glo
|
||||
class B(Configurable):
|
||||
a: A
|
||||
a_class_type: str = "A2"
|
||||
b: Optional[A]
|
||||
b_class_type: Optional[str] = "A2"
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
@ -124,6 +127,13 @@ will expand to
|
||||
a_A2_args: DictConfig = dataclasses.field(
|
||||
default_factory=lambda: DictConfig({"k": 1, "m": 3}
|
||||
)
|
||||
b_class_type: Optional[str] = "A2"
|
||||
b_A1_args: DictConfig = dataclasses.field(
|
||||
default_factory=lambda: DictConfig({"k": 1, "m": 3}
|
||||
)
|
||||
b_A2_args: DictConfig = dataclasses.field(
|
||||
default_factory=lambda: DictConfig({"k": 1, "m": 3}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.a_class_type == "A1":
|
||||
@ -133,6 +143,15 @@ will expand to
|
||||
else:
|
||||
raise ValueError(...)
|
||||
|
||||
if self.b_class_type is None:
|
||||
self.b = None
|
||||
elif self.b_class_type == "A1":
|
||||
self.b = A1(**self.b_A1_args)
|
||||
elif self.b_class_type == "A2":
|
||||
self.b = A2(**self.b_A2_args)
|
||||
else:
|
||||
raise ValueError(...)
|
||||
|
||||
3. Aside from these classes, the members of these classes should be things
|
||||
which DictConfig is happy with: e.g. (bool, int, str, None, float) and what
|
||||
can be built from them with DictConfigs and lists of them.
|
||||
@ -324,7 +343,19 @@ class _Registry:
|
||||
registry = _Registry()
|
||||
|
||||
|
||||
def _default_create(name: str, type_: Type, pluggable: bool) -> Callable[[Any], None]:
|
||||
class _ProcessType(Enum):
|
||||
"""
|
||||
Type of member which gets rewritten by expand_args_fields.
|
||||
"""
|
||||
|
||||
CONFIGURABLE = 1
|
||||
REPLACEABLE = 2
|
||||
OPTIONAL_REPLACEABLE = 3
|
||||
|
||||
|
||||
def _default_create(
|
||||
name: str, type_: Type, process_type: _ProcessType
|
||||
) -> Callable[[Any], None]:
|
||||
"""
|
||||
Return the default creation function for a member. This is a function which
|
||||
could be called in __post_init__ to initialise the member, and will be called
|
||||
@ -332,8 +363,8 @@ def _default_create(name: str, type_: Type, pluggable: bool) -> Callable[[Any],
|
||||
|
||||
Args:
|
||||
name: name of the member
|
||||
type_: declared type of the member
|
||||
pluggable: True if the member's declared type inherits ReplaceableBase,
|
||||
type_: type of the member (with any Optional removed)
|
||||
process_type: Shows whether member's declared type inherits ReplaceableBase,
|
||||
in which case the actual type to be created is decided at
|
||||
runtime.
|
||||
|
||||
@ -349,6 +380,10 @@ def _default_create(name: str, type_: Type, pluggable: bool) -> Callable[[Any],
|
||||
|
||||
def inner_pluggable(self):
|
||||
type_name = getattr(self, name + TYPE_SUFFIX)
|
||||
if type_name is None:
|
||||
setattr(self, name, None)
|
||||
return
|
||||
|
||||
chosen_class = registry.get(type_, type_name)
|
||||
if self._known_implementations.get(type_name, chosen_class) is not chosen_class:
|
||||
# If this warning is raised, it means that a new definition of
|
||||
@ -362,7 +397,7 @@ def _default_create(name: str, type_: Type, pluggable: bool) -> Callable[[Any],
|
||||
args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}")
|
||||
setattr(self, name, chosen_class(**args))
|
||||
|
||||
return inner_pluggable if pluggable else inner
|
||||
return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable
|
||||
|
||||
|
||||
def run_auto_creation(self: Any) -> None:
|
||||
@ -499,7 +534,7 @@ def expand_args_fields(
|
||||
|
||||
The transformations this function makes, before the concluding
|
||||
dataclasses.dataclass, are as follows. if X is a base class with registered
|
||||
subclasses Y and Z, replace
|
||||
subclasses Y and Z, replace a class member
|
||||
|
||||
x: X
|
||||
|
||||
@ -518,7 +553,32 @@ def expand_args_fields(
|
||||
)
|
||||
x_class_type: str = "UNDEFAULTED"
|
||||
|
||||
without adding the optional things if they are already there.
|
||||
without adding the optional attributes if they are already there.
|
||||
|
||||
Similarly, replace
|
||||
|
||||
x: Optional[X]
|
||||
|
||||
and optionally
|
||||
|
||||
x_class_type: Optional[str] = "Y"
|
||||
def create_x(self):...
|
||||
|
||||
with
|
||||
|
||||
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
|
||||
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
|
||||
def create_x(self):
|
||||
if self.x_class_type is None:
|
||||
self.x = None
|
||||
return
|
||||
|
||||
self.x = registry.get(X, self.x_class_type)(
|
||||
**self.getattr(f"x_{self.x_class_type}_args)
|
||||
)
|
||||
x_class_type: Optional[str] = "UNDEFAULTED"
|
||||
|
||||
without adding the optional attributes if they are already there.
|
||||
|
||||
Similarly, if X is a subclass of Configurable,
|
||||
|
||||
@ -587,26 +647,21 @@ def expand_args_fields(
|
||||
if "_processed_members" in base.__dict__:
|
||||
processed_members.update(base._processed_members)
|
||||
|
||||
to_process: List[Tuple[str, Type, bool]] = []
|
||||
to_process: List[Tuple[str, Type, _ProcessType]] = []
|
||||
if "__annotations__" in some_class.__dict__:
|
||||
for name, type_ in some_class.__annotations__.items():
|
||||
if not isinstance(type_, type):
|
||||
# type_ could be something like typing.Tuple
|
||||
underlying_and_process_type = _get_type_to_process(type_)
|
||||
if underlying_and_process_type is None:
|
||||
continue
|
||||
if (
|
||||
issubclass(type_, ReplaceableBase)
|
||||
and ReplaceableBase in type_.__bases__
|
||||
):
|
||||
to_process.append((name, type_, True))
|
||||
elif issubclass(type_, Configurable):
|
||||
to_process.append((name, type_, False))
|
||||
underlying_type, process_type = underlying_and_process_type
|
||||
to_process.append((name, underlying_type, process_type))
|
||||
|
||||
for name, type_, pluggable in to_process:
|
||||
for name, underlying_type, process_type in to_process:
|
||||
_process_member(
|
||||
name=name,
|
||||
type_=type_,
|
||||
pluggable=pluggable,
|
||||
some_class=cast(type, some_class),
|
||||
type_=underlying_type,
|
||||
process_type=process_type,
|
||||
some_class=some_class,
|
||||
creation_functions=creation_functions,
|
||||
_do_not_process=_do_not_process,
|
||||
known_implementations=known_implementations,
|
||||
@ -641,11 +696,39 @@ def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
|
||||
return dataclasses.field(default_factory=create)
|
||||
|
||||
|
||||
def _get_type_to_process(type_) -> Optional[Tuple[Type, _ProcessType]]:
|
||||
"""
|
||||
If a member is annotated as `type_`, and that should expanded in
|
||||
expand_args_fields, return how it should be expanded.
|
||||
"""
|
||||
if get_origin(type_) == Union:
|
||||
# We look for Optional[X] which is a Union of X with None.
|
||||
args = get_args(type_)
|
||||
if len(args) != 2 or all(a is not type(None) for a in args): # noqa: E721
|
||||
return
|
||||
underlying = args[0] if args[1] is type(None) else args[1] # noqa: E721
|
||||
if (
|
||||
issubclass(underlying, ReplaceableBase)
|
||||
and ReplaceableBase in underlying.__bases__
|
||||
):
|
||||
return underlying, _ProcessType.OPTIONAL_REPLACEABLE
|
||||
|
||||
if not isinstance(type_, type):
|
||||
# e.g. any other Union or Tuple
|
||||
return
|
||||
|
||||
if issubclass(type_, ReplaceableBase) and ReplaceableBase in type_.__bases__:
|
||||
return type_, _ProcessType.REPLACEABLE
|
||||
|
||||
if issubclass(type_, Configurable):
|
||||
return type_, _ProcessType.CONFIGURABLE
|
||||
|
||||
|
||||
def _process_member(
|
||||
*,
|
||||
name: str,
|
||||
type_: Type,
|
||||
pluggable: bool,
|
||||
process_type: _ProcessType,
|
||||
some_class: Type,
|
||||
creation_functions: List[str],
|
||||
_do_not_process: Tuple[type, ...],
|
||||
@ -656,8 +739,8 @@ def _process_member(
|
||||
|
||||
Args:
|
||||
name: member name
|
||||
type_: member declared type
|
||||
plugglable: whether member has dynamic type
|
||||
type_: member type (with Optional removed if needed)
|
||||
process_type: whether member has dynamic type
|
||||
some_class: (MODIFIED IN PLACE) the class being processed
|
||||
creation_functions: (MODIFIED IN PLACE) the names of the create functions
|
||||
_do_not_process: as for expand_args_fields.
|
||||
@ -668,10 +751,13 @@ def _process_member(
|
||||
# there are non-defaulted standard class members.
|
||||
del some_class.__annotations__[name]
|
||||
|
||||
if pluggable:
|
||||
if process_type != _ProcessType.CONFIGURABLE:
|
||||
type_name = name + TYPE_SUFFIX
|
||||
if type_name not in some_class.__annotations__:
|
||||
some_class.__annotations__[type_name] = str
|
||||
if process_type == _ProcessType.OPTIONAL_REPLACEABLE:
|
||||
some_class.__annotations__[type_name] = Optional[str]
|
||||
else:
|
||||
some_class.__annotations__[type_name] = str
|
||||
setattr(some_class, type_name, "UNDEFAULTED")
|
||||
|
||||
for derived_type in registry.get_all(type_):
|
||||
@ -720,7 +806,7 @@ def _process_member(
|
||||
setattr(
|
||||
some_class,
|
||||
creation_function_name,
|
||||
_default_create(name, type_, pluggable),
|
||||
_default_create(name, type_, process_type),
|
||||
)
|
||||
creation_functions.append(creation_function_name)
|
||||
|
||||
@ -743,7 +829,10 @@ def remove_unused_components(dict_: DictConfig) -> None:
|
||||
args_keys = [key for key in keys if key.endswith(ARGS_SUFFIX)]
|
||||
for replaceable in replaceables:
|
||||
selected_type = dict_[replaceable + TYPE_SUFFIX]
|
||||
expect = replaceable + "_" + selected_type + ARGS_SUFFIX
|
||||
if selected_type is None:
|
||||
expect = ""
|
||||
else:
|
||||
expect = replaceable + "_" + selected_type + ARGS_SUFFIX
|
||||
with open_dict(dict_):
|
||||
for key in args_keys:
|
||||
if key.startswith(replaceable + "_") and key != expect:
|
||||
|
@ -14,7 +14,9 @@ from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
Configurable,
|
||||
ReplaceableBase,
|
||||
_get_type_to_process,
|
||||
_is_actually_dataclass,
|
||||
_ProcessType,
|
||||
_Registry,
|
||||
expand_args_fields,
|
||||
get_default_args,
|
||||
@ -94,6 +96,19 @@ class TestConfig(unittest.TestCase):
|
||||
self.assertFalse(_is_actually_dataclass(B))
|
||||
self.assertTrue(is_dataclass(B))
|
||||
|
||||
def test_get_type_to_process(self):
|
||||
gt = _get_type_to_process
|
||||
self.assertIsNone(gt(int))
|
||||
self.assertEqual(gt(Fruit), (Fruit, _ProcessType.REPLACEABLE))
|
||||
self.assertEqual(
|
||||
gt(Optional[Fruit]), (Fruit, _ProcessType.OPTIONAL_REPLACEABLE)
|
||||
)
|
||||
self.assertEqual(gt(MainTest), (MainTest, _ProcessType.CONFIGURABLE))
|
||||
self.assertIsNone(gt(Optional[int]))
|
||||
self.assertIsNone(gt(Optional[MainTest]))
|
||||
self.assertIsNone(gt(Tuple[Fruit]))
|
||||
self.assertIsNone(gt(Tuple[Fruit, Animal]))
|
||||
|
||||
def test_simple_replacement(self):
|
||||
struct = get_default_args(MainTest)
|
||||
struct.n_ids = 9780
|
||||
@ -247,6 +262,7 @@ class TestConfig(unittest.TestCase):
|
||||
self.assertEqual(container.fruit_Pear_args.n_pips, 13)
|
||||
|
||||
def test_inheritance(self):
|
||||
# Also exercises optional replaceables
|
||||
class FruitBowl(ReplaceableBase):
|
||||
main_fruit: Fruit
|
||||
main_fruit_class_type: str = "Orange"
|
||||
@ -255,8 +271,10 @@ class TestConfig(unittest.TestCase):
|
||||
raise ValueError("This doesn't get called")
|
||||
|
||||
class LargeFruitBowl(FruitBowl):
|
||||
extra_fruit: Fruit
|
||||
extra_fruit: Optional[Fruit]
|
||||
extra_fruit_class_type: str = "Kiwi"
|
||||
no_fruit: Optional[Fruit]
|
||||
no_fruit_class_type: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
@ -267,6 +285,22 @@ class TestConfig(unittest.TestCase):
|
||||
large = LargeFruitBowl(**large_args)
|
||||
self.assertIsInstance(large.main_fruit, Orange)
|
||||
self.assertIsInstance(large.extra_fruit, Kiwi)
|
||||
self.assertIsNone(large.no_fruit)
|
||||
self.assertIn("no_fruit_Kiwi_args", large_args)
|
||||
|
||||
remove_unused_components(large_args)
|
||||
large2 = LargeFruitBowl(**large_args)
|
||||
self.assertIsInstance(large2.main_fruit, Orange)
|
||||
self.assertIsInstance(large2.extra_fruit, Kiwi)
|
||||
self.assertIsNone(large2.no_fruit)
|
||||
needed_args = [
|
||||
"extra_fruit_Kiwi_args",
|
||||
"extra_fruit_class_type",
|
||||
"main_fruit_Orange_args",
|
||||
"main_fruit_class_type",
|
||||
"no_fruit_class_type",
|
||||
]
|
||||
self.assertEqual(sorted(large_args.keys()), needed_args)
|
||||
|
||||
def test_inheritance2(self):
|
||||
# This is a case where a class could contain an instance
|
||||
|
Loading…
x
Reference in New Issue
Block a user