Optional ReplaceableBase

Summary: Allow things like `renderer:Optional[BaseRenderer]` in configurables.

Reviewed By: davnov134

Differential Revision: D35118339

fbshipit-source-id: 1219321b2817ed4b26fe924c6d6f73887095c985
This commit is contained in:
Jeremy Reizenstein 2022-03-29 08:43:46 -07:00 committed by Facebook GitHub Bot
parent e332f9ffa4
commit 21262e38c7
4 changed files with 171 additions and 46 deletions

View File

@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import sys
from typing import Optional, Union from typing import Optional, Union
import torch import torch
@ -56,3 +57,20 @@ def get_device(x, device: Optional[Device] = None) -> torch.device:
# Default device is cpu # Default device is cpu
return torch.device("cpu") return torch.device("cpu")
# Provide get_origin and get_args even in Python 3.7.
if sys.version_info >= (3, 8, 0):
from typing import get_args, get_origin
elif sys.version_info >= (3, 7, 0):
def get_origin(cls): # pragma: no cover
return getattr(cls, "__origin__", None)
def get_args(cls): # pragma: no cover
return getattr(cls, "__args__", None)
else:
raise ImportError("This module requires Python 3.7+")

View File

@ -8,31 +8,15 @@
import dataclasses import dataclasses
import gzip import gzip
import json import json
import sys
from dataclasses import MISSING, Field, dataclass from dataclasses import MISSING, Field, dataclass
from typing import IO, Any, Optional, Tuple, Type, TypeVar, Union, cast from typing import IO, Any, Optional, Tuple, Type, TypeVar, Union, cast
import numpy as np import numpy as np
from pytorch3d.common.datatypes import get_args, get_origin
_X = TypeVar("_X") _X = TypeVar("_X")
if sys.version_info >= (3, 8, 0):
from typing import get_args, get_origin
elif sys.version_info >= (3, 7, 0):
def get_origin(cls):
return getattr(cls, "__origin__", None)
def get_args(cls):
return getattr(cls, "__args__", None)
else:
raise ImportError("This module requires Python 3.7+")
TF3 = Tuple[float, float, float] TF3 = Tuple[float, float, float]

View File

@ -10,9 +10,10 @@ import itertools
import warnings import warnings
from collections import Counter, defaultdict from collections import Counter, defaultdict
from enum import Enum from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, cast from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
from omegaconf import DictConfig, OmegaConf, open_dict from omegaconf import DictConfig, OmegaConf, open_dict
from pytorch3d.common.datatypes import get_args, get_origin
""" """
@ -97,6 +98,8 @@ you can give a base class and the implementation is looked up by name in the glo
class B(Configurable): class B(Configurable):
a: A a: A
a_class_type: str = "A2" a_class_type: str = "A2"
b: Optional[A]
b_class_type: Optional[str] = "A2"
def __post_init__(self): def __post_init__(self):
run_auto_creation(self) run_auto_creation(self)
@ -124,6 +127,13 @@ will expand to
a_A2_args: DictConfig = dataclasses.field( a_A2_args: DictConfig = dataclasses.field(
default_factory=lambda: DictConfig({"k": 1, "m": 3} default_factory=lambda: DictConfig({"k": 1, "m": 3}
) )
b_class_type: Optional[str] = "A2"
b_A1_args: DictConfig = dataclasses.field(
default_factory=lambda: DictConfig({"k": 1, "m": 3}
)
b_A2_args: DictConfig = dataclasses.field(
default_factory=lambda: DictConfig({"k": 1, "m": 3}
)
def __post_init__(self): def __post_init__(self):
if self.a_class_type == "A1": if self.a_class_type == "A1":
@ -133,6 +143,15 @@ will expand to
else: else:
raise ValueError(...) raise ValueError(...)
if self.b_class_type is None:
self.b = None
elif self.b_class_type == "A1":
self.b = A1(**self.b_A1_args)
elif self.b_class_type == "A2":
self.b = A2(**self.b_A2_args)
else:
raise ValueError(...)
3. Aside from these classes, the members of these classes should be things 3. Aside from these classes, the members of these classes should be things
which DictConfig is happy with: e.g. (bool, int, str, None, float) and what which DictConfig is happy with: e.g. (bool, int, str, None, float) and what
can be built from them with DictConfigs and lists of them. can be built from them with DictConfigs and lists of them.
@ -324,7 +343,19 @@ class _Registry:
registry = _Registry() registry = _Registry()
def _default_create(name: str, type_: Type, pluggable: bool) -> Callable[[Any], None]: class _ProcessType(Enum):
"""
Type of member which gets rewritten by expand_args_fields.
"""
CONFIGURABLE = 1
REPLACEABLE = 2
OPTIONAL_REPLACEABLE = 3
def _default_create(
name: str, type_: Type, process_type: _ProcessType
) -> Callable[[Any], None]:
""" """
Return the default creation function for a member. This is a function which Return the default creation function for a member. This is a function which
could be called in __post_init__ to initialise the member, and will be called could be called in __post_init__ to initialise the member, and will be called
@ -332,8 +363,8 @@ def _default_create(name: str, type_: Type, pluggable: bool) -> Callable[[Any],
Args: Args:
name: name of the member name: name of the member
type_: declared type of the member type_: type of the member (with any Optional removed)
pluggable: True if the member's declared type inherits ReplaceableBase, process_type: Shows whether member's declared type inherits ReplaceableBase,
in which case the actual type to be created is decided at in which case the actual type to be created is decided at
runtime. runtime.
@ -349,6 +380,10 @@ def _default_create(name: str, type_: Type, pluggable: bool) -> Callable[[Any],
def inner_pluggable(self): def inner_pluggable(self):
type_name = getattr(self, name + TYPE_SUFFIX) type_name = getattr(self, name + TYPE_SUFFIX)
if type_name is None:
setattr(self, name, None)
return
chosen_class = registry.get(type_, type_name) chosen_class = registry.get(type_, type_name)
if self._known_implementations.get(type_name, chosen_class) is not chosen_class: if self._known_implementations.get(type_name, chosen_class) is not chosen_class:
# If this warning is raised, it means that a new definition of # If this warning is raised, it means that a new definition of
@ -362,7 +397,7 @@ def _default_create(name: str, type_: Type, pluggable: bool) -> Callable[[Any],
args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}") args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}")
setattr(self, name, chosen_class(**args)) setattr(self, name, chosen_class(**args))
return inner_pluggable if pluggable else inner return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable
def run_auto_creation(self: Any) -> None: def run_auto_creation(self: Any) -> None:
@ -499,7 +534,7 @@ def expand_args_fields(
The transformations this function makes, before the concluding The transformations this function makes, before the concluding
dataclasses.dataclass, are as follows. if X is a base class with registered dataclasses.dataclass, are as follows. if X is a base class with registered
subclasses Y and Z, replace subclasses Y and Z, replace a class member
x: X x: X
@ -518,7 +553,32 @@ def expand_args_fields(
) )
x_class_type: str = "UNDEFAULTED" x_class_type: str = "UNDEFAULTED"
without adding the optional things if they are already there. without adding the optional attributes if they are already there.
Similarly, replace
x: Optional[X]
and optionally
x_class_type: Optional[str] = "Y"
def create_x(self):...
with
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
def create_x(self):
if self.x_class_type is None:
self.x = None
return
self.x = registry.get(X, self.x_class_type)(
**self.getattr(f"x_{self.x_class_type}_args)
)
x_class_type: Optional[str] = "UNDEFAULTED"
without adding the optional attributes if they are already there.
Similarly, if X is a subclass of Configurable, Similarly, if X is a subclass of Configurable,
@ -587,26 +647,21 @@ def expand_args_fields(
if "_processed_members" in base.__dict__: if "_processed_members" in base.__dict__:
processed_members.update(base._processed_members) processed_members.update(base._processed_members)
to_process: List[Tuple[str, Type, bool]] = [] to_process: List[Tuple[str, Type, _ProcessType]] = []
if "__annotations__" in some_class.__dict__: if "__annotations__" in some_class.__dict__:
for name, type_ in some_class.__annotations__.items(): for name, type_ in some_class.__annotations__.items():
if not isinstance(type_, type): underlying_and_process_type = _get_type_to_process(type_)
# type_ could be something like typing.Tuple if underlying_and_process_type is None:
continue continue
if ( underlying_type, process_type = underlying_and_process_type
issubclass(type_, ReplaceableBase) to_process.append((name, underlying_type, process_type))
and ReplaceableBase in type_.__bases__
):
to_process.append((name, type_, True))
elif issubclass(type_, Configurable):
to_process.append((name, type_, False))
for name, type_, pluggable in to_process: for name, underlying_type, process_type in to_process:
_process_member( _process_member(
name=name, name=name,
type_=type_, type_=underlying_type,
pluggable=pluggable, process_type=process_type,
some_class=cast(type, some_class), some_class=some_class,
creation_functions=creation_functions, creation_functions=creation_functions,
_do_not_process=_do_not_process, _do_not_process=_do_not_process,
known_implementations=known_implementations, known_implementations=known_implementations,
@ -641,11 +696,39 @@ def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
return dataclasses.field(default_factory=create) return dataclasses.field(default_factory=create)
def _get_type_to_process(type_) -> Optional[Tuple[Type, _ProcessType]]:
"""
If a member is annotated as `type_`, and that should expanded in
expand_args_fields, return how it should be expanded.
"""
if get_origin(type_) == Union:
# We look for Optional[X] which is a Union of X with None.
args = get_args(type_)
if len(args) != 2 or all(a is not type(None) for a in args): # noqa: E721
return
underlying = args[0] if args[1] is type(None) else args[1] # noqa: E721
if (
issubclass(underlying, ReplaceableBase)
and ReplaceableBase in underlying.__bases__
):
return underlying, _ProcessType.OPTIONAL_REPLACEABLE
if not isinstance(type_, type):
# e.g. any other Union or Tuple
return
if issubclass(type_, ReplaceableBase) and ReplaceableBase in type_.__bases__:
return type_, _ProcessType.REPLACEABLE
if issubclass(type_, Configurable):
return type_, _ProcessType.CONFIGURABLE
def _process_member( def _process_member(
*, *,
name: str, name: str,
type_: Type, type_: Type,
pluggable: bool, process_type: _ProcessType,
some_class: Type, some_class: Type,
creation_functions: List[str], creation_functions: List[str],
_do_not_process: Tuple[type, ...], _do_not_process: Tuple[type, ...],
@ -656,8 +739,8 @@ def _process_member(
Args: Args:
name: member name name: member name
type_: member declared type type_: member type (with Optional removed if needed)
plugglable: whether member has dynamic type process_type: whether member has dynamic type
some_class: (MODIFIED IN PLACE) the class being processed some_class: (MODIFIED IN PLACE) the class being processed
creation_functions: (MODIFIED IN PLACE) the names of the create functions creation_functions: (MODIFIED IN PLACE) the names of the create functions
_do_not_process: as for expand_args_fields. _do_not_process: as for expand_args_fields.
@ -668,9 +751,12 @@ def _process_member(
# there are non-defaulted standard class members. # there are non-defaulted standard class members.
del some_class.__annotations__[name] del some_class.__annotations__[name]
if pluggable: if process_type != _ProcessType.CONFIGURABLE:
type_name = name + TYPE_SUFFIX type_name = name + TYPE_SUFFIX
if type_name not in some_class.__annotations__: if type_name not in some_class.__annotations__:
if process_type == _ProcessType.OPTIONAL_REPLACEABLE:
some_class.__annotations__[type_name] = Optional[str]
else:
some_class.__annotations__[type_name] = str some_class.__annotations__[type_name] = str
setattr(some_class, type_name, "UNDEFAULTED") setattr(some_class, type_name, "UNDEFAULTED")
@ -720,7 +806,7 @@ def _process_member(
setattr( setattr(
some_class, some_class,
creation_function_name, creation_function_name,
_default_create(name, type_, pluggable), _default_create(name, type_, process_type),
) )
creation_functions.append(creation_function_name) creation_functions.append(creation_function_name)
@ -743,6 +829,9 @@ def remove_unused_components(dict_: DictConfig) -> None:
args_keys = [key for key in keys if key.endswith(ARGS_SUFFIX)] args_keys = [key for key in keys if key.endswith(ARGS_SUFFIX)]
for replaceable in replaceables: for replaceable in replaceables:
selected_type = dict_[replaceable + TYPE_SUFFIX] selected_type = dict_[replaceable + TYPE_SUFFIX]
if selected_type is None:
expect = ""
else:
expect = replaceable + "_" + selected_type + ARGS_SUFFIX expect = replaceable + "_" + selected_type + ARGS_SUFFIX
with open_dict(dict_): with open_dict(dict_):
for key in args_keys: for key in args_keys:

View File

@ -14,7 +14,9 @@ from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
from pytorch3d.implicitron.tools.config import ( from pytorch3d.implicitron.tools.config import (
Configurable, Configurable,
ReplaceableBase, ReplaceableBase,
_get_type_to_process,
_is_actually_dataclass, _is_actually_dataclass,
_ProcessType,
_Registry, _Registry,
expand_args_fields, expand_args_fields,
get_default_args, get_default_args,
@ -94,6 +96,19 @@ class TestConfig(unittest.TestCase):
self.assertFalse(_is_actually_dataclass(B)) self.assertFalse(_is_actually_dataclass(B))
self.assertTrue(is_dataclass(B)) self.assertTrue(is_dataclass(B))
def test_get_type_to_process(self):
gt = _get_type_to_process
self.assertIsNone(gt(int))
self.assertEqual(gt(Fruit), (Fruit, _ProcessType.REPLACEABLE))
self.assertEqual(
gt(Optional[Fruit]), (Fruit, _ProcessType.OPTIONAL_REPLACEABLE)
)
self.assertEqual(gt(MainTest), (MainTest, _ProcessType.CONFIGURABLE))
self.assertIsNone(gt(Optional[int]))
self.assertIsNone(gt(Optional[MainTest]))
self.assertIsNone(gt(Tuple[Fruit]))
self.assertIsNone(gt(Tuple[Fruit, Animal]))
def test_simple_replacement(self): def test_simple_replacement(self):
struct = get_default_args(MainTest) struct = get_default_args(MainTest)
struct.n_ids = 9780 struct.n_ids = 9780
@ -247,6 +262,7 @@ class TestConfig(unittest.TestCase):
self.assertEqual(container.fruit_Pear_args.n_pips, 13) self.assertEqual(container.fruit_Pear_args.n_pips, 13)
def test_inheritance(self): def test_inheritance(self):
# Also exercises optional replaceables
class FruitBowl(ReplaceableBase): class FruitBowl(ReplaceableBase):
main_fruit: Fruit main_fruit: Fruit
main_fruit_class_type: str = "Orange" main_fruit_class_type: str = "Orange"
@ -255,8 +271,10 @@ class TestConfig(unittest.TestCase):
raise ValueError("This doesn't get called") raise ValueError("This doesn't get called")
class LargeFruitBowl(FruitBowl): class LargeFruitBowl(FruitBowl):
extra_fruit: Fruit extra_fruit: Optional[Fruit]
extra_fruit_class_type: str = "Kiwi" extra_fruit_class_type: str = "Kiwi"
no_fruit: Optional[Fruit]
no_fruit_class_type: Optional[str] = None
def __post_init__(self): def __post_init__(self):
run_auto_creation(self) run_auto_creation(self)
@ -267,6 +285,22 @@ class TestConfig(unittest.TestCase):
large = LargeFruitBowl(**large_args) large = LargeFruitBowl(**large_args)
self.assertIsInstance(large.main_fruit, Orange) self.assertIsInstance(large.main_fruit, Orange)
self.assertIsInstance(large.extra_fruit, Kiwi) self.assertIsInstance(large.extra_fruit, Kiwi)
self.assertIsNone(large.no_fruit)
self.assertIn("no_fruit_Kiwi_args", large_args)
remove_unused_components(large_args)
large2 = LargeFruitBowl(**large_args)
self.assertIsInstance(large2.main_fruit, Orange)
self.assertIsInstance(large2.extra_fruit, Kiwi)
self.assertIsNone(large2.no_fruit)
needed_args = [
"extra_fruit_Kiwi_args",
"extra_fruit_class_type",
"main_fruit_Orange_args",
"main_fruit_class_type",
"no_fruit_class_type",
]
self.assertEqual(sorted(large_args.keys()), needed_args)
def test_inheritance2(self): def test_inheritance2(self):
# This is a case where a class could contain an instance # This is a case where a class could contain an instance