Optional ReplaceableBase

Summary: Allow things like `renderer:Optional[BaseRenderer]` in configurables.

Reviewed By: davnov134

Differential Revision: D35118339

fbshipit-source-id: 1219321b2817ed4b26fe924c6d6f73887095c985
This commit is contained in:
Jeremy Reizenstein 2022-03-29 08:43:46 -07:00 committed by Facebook GitHub Bot
parent e332f9ffa4
commit 21262e38c7
4 changed files with 171 additions and 46 deletions

View File

@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import sys
from typing import Optional, Union
import torch
@ -56,3 +57,20 @@ def get_device(x, device: Optional[Device] = None) -> torch.device:
# Default device is cpu
return torch.device("cpu")
# Provide get_origin and get_args even in Python 3.7.
if sys.version_info >= (3, 8, 0):
from typing import get_args, get_origin
elif sys.version_info >= (3, 7, 0):
def get_origin(cls): # pragma: no cover
return getattr(cls, "__origin__", None)
def get_args(cls): # pragma: no cover
return getattr(cls, "__args__", None)
else:
raise ImportError("This module requires Python 3.7+")

View File

@ -8,31 +8,15 @@
import dataclasses
import gzip
import json
import sys
from dataclasses import MISSING, Field, dataclass
from typing import IO, Any, Optional, Tuple, Type, TypeVar, Union, cast
import numpy as np
from pytorch3d.common.datatypes import get_args, get_origin
_X = TypeVar("_X")
if sys.version_info >= (3, 8, 0):
from typing import get_args, get_origin
elif sys.version_info >= (3, 7, 0):
def get_origin(cls):
return getattr(cls, "__origin__", None)
def get_args(cls):
return getattr(cls, "__args__", None)
else:
raise ImportError("This module requires Python 3.7+")
TF3 = Tuple[float, float, float]

View File

@ -10,9 +10,10 @@ import itertools
import warnings
from collections import Counter, defaultdict
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, cast
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
from omegaconf import DictConfig, OmegaConf, open_dict
from pytorch3d.common.datatypes import get_args, get_origin
"""
@ -97,6 +98,8 @@ you can give a base class and the implementation is looked up by name in the glo
class B(Configurable):
a: A
a_class_type: str = "A2"
b: Optional[A]
b_class_type: Optional[str] = "A2"
def __post_init__(self):
run_auto_creation(self)
@ -124,6 +127,13 @@ will expand to
a_A2_args: DictConfig = dataclasses.field(
default_factory=lambda: DictConfig({"k": 1, "m": 3}
)
b_class_type: Optional[str] = "A2"
b_A1_args: DictConfig = dataclasses.field(
default_factory=lambda: DictConfig({"k": 1, "m": 3}
)
b_A2_args: DictConfig = dataclasses.field(
default_factory=lambda: DictConfig({"k": 1, "m": 3}
)
def __post_init__(self):
if self.a_class_type == "A1":
@ -133,6 +143,15 @@ will expand to
else:
raise ValueError(...)
if self.b_class_type is None:
self.b = None
elif self.b_class_type == "A1":
self.b = A1(**self.b_A1_args)
elif self.b_class_type == "A2":
self.b = A2(**self.b_A2_args)
else:
raise ValueError(...)
3. Aside from these classes, the members of these classes should be things
which DictConfig is happy with: e.g. (bool, int, str, None, float) and what
can be built from them with DictConfigs and lists of them.
@ -324,7 +343,19 @@ class _Registry:
registry = _Registry()
def _default_create(name: str, type_: Type, pluggable: bool) -> Callable[[Any], None]:
class _ProcessType(Enum):
"""
Type of member which gets rewritten by expand_args_fields.
"""
CONFIGURABLE = 1
REPLACEABLE = 2
OPTIONAL_REPLACEABLE = 3
def _default_create(
name: str, type_: Type, process_type: _ProcessType
) -> Callable[[Any], None]:
"""
Return the default creation function for a member. This is a function which
could be called in __post_init__ to initialise the member, and will be called
@ -332,8 +363,8 @@ def _default_create(name: str, type_: Type, pluggable: bool) -> Callable[[Any],
Args:
name: name of the member
type_: declared type of the member
pluggable: True if the member's declared type inherits ReplaceableBase,
type_: type of the member (with any Optional removed)
process_type: Shows whether member's declared type inherits ReplaceableBase,
in which case the actual type to be created is decided at
runtime.
@ -349,6 +380,10 @@ def _default_create(name: str, type_: Type, pluggable: bool) -> Callable[[Any],
def inner_pluggable(self):
type_name = getattr(self, name + TYPE_SUFFIX)
if type_name is None:
setattr(self, name, None)
return
chosen_class = registry.get(type_, type_name)
if self._known_implementations.get(type_name, chosen_class) is not chosen_class:
# If this warning is raised, it means that a new definition of
@ -362,7 +397,7 @@ def _default_create(name: str, type_: Type, pluggable: bool) -> Callable[[Any],
args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}")
setattr(self, name, chosen_class(**args))
return inner_pluggable if pluggable else inner
return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable
def run_auto_creation(self: Any) -> None:
@ -499,7 +534,7 @@ def expand_args_fields(
The transformations this function makes, before the concluding
dataclasses.dataclass, are as follows. if X is a base class with registered
subclasses Y and Z, replace
subclasses Y and Z, replace a class member
x: X
@ -518,7 +553,32 @@ def expand_args_fields(
)
x_class_type: str = "UNDEFAULTED"
without adding the optional things if they are already there.
without adding the optional attributes if they are already there.
Similarly, replace
x: Optional[X]
and optionally
x_class_type: Optional[str] = "Y"
def create_x(self):...
with
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
def create_x(self):
if self.x_class_type is None:
self.x = None
return
self.x = registry.get(X, self.x_class_type)(
**self.getattr(f"x_{self.x_class_type}_args)
)
x_class_type: Optional[str] = "UNDEFAULTED"
without adding the optional attributes if they are already there.
Similarly, if X is a subclass of Configurable,
@ -587,26 +647,21 @@ def expand_args_fields(
if "_processed_members" in base.__dict__:
processed_members.update(base._processed_members)
to_process: List[Tuple[str, Type, bool]] = []
to_process: List[Tuple[str, Type, _ProcessType]] = []
if "__annotations__" in some_class.__dict__:
for name, type_ in some_class.__annotations__.items():
if not isinstance(type_, type):
# type_ could be something like typing.Tuple
underlying_and_process_type = _get_type_to_process(type_)
if underlying_and_process_type is None:
continue
if (
issubclass(type_, ReplaceableBase)
and ReplaceableBase in type_.__bases__
):
to_process.append((name, type_, True))
elif issubclass(type_, Configurable):
to_process.append((name, type_, False))
underlying_type, process_type = underlying_and_process_type
to_process.append((name, underlying_type, process_type))
for name, type_, pluggable in to_process:
for name, underlying_type, process_type in to_process:
_process_member(
name=name,
type_=type_,
pluggable=pluggable,
some_class=cast(type, some_class),
type_=underlying_type,
process_type=process_type,
some_class=some_class,
creation_functions=creation_functions,
_do_not_process=_do_not_process,
known_implementations=known_implementations,
@ -641,11 +696,39 @@ def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
return dataclasses.field(default_factory=create)
def _get_type_to_process(type_) -> Optional[Tuple[Type, _ProcessType]]:
"""
If a member is annotated as `type_`, and that should expanded in
expand_args_fields, return how it should be expanded.
"""
if get_origin(type_) == Union:
# We look for Optional[X] which is a Union of X with None.
args = get_args(type_)
if len(args) != 2 or all(a is not type(None) for a in args): # noqa: E721
return
underlying = args[0] if args[1] is type(None) else args[1] # noqa: E721
if (
issubclass(underlying, ReplaceableBase)
and ReplaceableBase in underlying.__bases__
):
return underlying, _ProcessType.OPTIONAL_REPLACEABLE
if not isinstance(type_, type):
# e.g. any other Union or Tuple
return
if issubclass(type_, ReplaceableBase) and ReplaceableBase in type_.__bases__:
return type_, _ProcessType.REPLACEABLE
if issubclass(type_, Configurable):
return type_, _ProcessType.CONFIGURABLE
def _process_member(
*,
name: str,
type_: Type,
pluggable: bool,
process_type: _ProcessType,
some_class: Type,
creation_functions: List[str],
_do_not_process: Tuple[type, ...],
@ -656,8 +739,8 @@ def _process_member(
Args:
name: member name
type_: member declared type
plugglable: whether member has dynamic type
type_: member type (with Optional removed if needed)
process_type: whether member has dynamic type
some_class: (MODIFIED IN PLACE) the class being processed
creation_functions: (MODIFIED IN PLACE) the names of the create functions
_do_not_process: as for expand_args_fields.
@ -668,10 +751,13 @@ def _process_member(
# there are non-defaulted standard class members.
del some_class.__annotations__[name]
if pluggable:
if process_type != _ProcessType.CONFIGURABLE:
type_name = name + TYPE_SUFFIX
if type_name not in some_class.__annotations__:
some_class.__annotations__[type_name] = str
if process_type == _ProcessType.OPTIONAL_REPLACEABLE:
some_class.__annotations__[type_name] = Optional[str]
else:
some_class.__annotations__[type_name] = str
setattr(some_class, type_name, "UNDEFAULTED")
for derived_type in registry.get_all(type_):
@ -720,7 +806,7 @@ def _process_member(
setattr(
some_class,
creation_function_name,
_default_create(name, type_, pluggable),
_default_create(name, type_, process_type),
)
creation_functions.append(creation_function_name)
@ -743,7 +829,10 @@ def remove_unused_components(dict_: DictConfig) -> None:
args_keys = [key for key in keys if key.endswith(ARGS_SUFFIX)]
for replaceable in replaceables:
selected_type = dict_[replaceable + TYPE_SUFFIX]
expect = replaceable + "_" + selected_type + ARGS_SUFFIX
if selected_type is None:
expect = ""
else:
expect = replaceable + "_" + selected_type + ARGS_SUFFIX
with open_dict(dict_):
for key in args_keys:
if key.startswith(replaceable + "_") and key != expect:

View File

@ -14,7 +14,9 @@ from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
from pytorch3d.implicitron.tools.config import (
Configurable,
ReplaceableBase,
_get_type_to_process,
_is_actually_dataclass,
_ProcessType,
_Registry,
expand_args_fields,
get_default_args,
@ -94,6 +96,19 @@ class TestConfig(unittest.TestCase):
self.assertFalse(_is_actually_dataclass(B))
self.assertTrue(is_dataclass(B))
def test_get_type_to_process(self):
gt = _get_type_to_process
self.assertIsNone(gt(int))
self.assertEqual(gt(Fruit), (Fruit, _ProcessType.REPLACEABLE))
self.assertEqual(
gt(Optional[Fruit]), (Fruit, _ProcessType.OPTIONAL_REPLACEABLE)
)
self.assertEqual(gt(MainTest), (MainTest, _ProcessType.CONFIGURABLE))
self.assertIsNone(gt(Optional[int]))
self.assertIsNone(gt(Optional[MainTest]))
self.assertIsNone(gt(Tuple[Fruit]))
self.assertIsNone(gt(Tuple[Fruit, Animal]))
def test_simple_replacement(self):
struct = get_default_args(MainTest)
struct.n_ids = 9780
@ -247,6 +262,7 @@ class TestConfig(unittest.TestCase):
self.assertEqual(container.fruit_Pear_args.n_pips, 13)
def test_inheritance(self):
# Also exercises optional replaceables
class FruitBowl(ReplaceableBase):
main_fruit: Fruit
main_fruit_class_type: str = "Orange"
@ -255,8 +271,10 @@ class TestConfig(unittest.TestCase):
raise ValueError("This doesn't get called")
class LargeFruitBowl(FruitBowl):
extra_fruit: Fruit
extra_fruit: Optional[Fruit]
extra_fruit_class_type: str = "Kiwi"
no_fruit: Optional[Fruit]
no_fruit_class_type: Optional[str] = None
def __post_init__(self):
run_auto_creation(self)
@ -267,6 +285,22 @@ class TestConfig(unittest.TestCase):
large = LargeFruitBowl(**large_args)
self.assertIsInstance(large.main_fruit, Orange)
self.assertIsInstance(large.extra_fruit, Kiwi)
self.assertIsNone(large.no_fruit)
self.assertIn("no_fruit_Kiwi_args", large_args)
remove_unused_components(large_args)
large2 = LargeFruitBowl(**large_args)
self.assertIsInstance(large2.main_fruit, Orange)
self.assertIsInstance(large2.extra_fruit, Kiwi)
self.assertIsNone(large2.no_fruit)
needed_args = [
"extra_fruit_Kiwi_args",
"extra_fruit_class_type",
"main_fruit_Orange_args",
"main_fruit_class_type",
"no_fruit_class_type",
]
self.assertEqual(sorted(large_args.keys()), needed_args)
def test_inheritance2(self):
# This is a case where a class could contain an instance