From 21262e38c7c064f246ec098f7bd53d273497a95b Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Tue, 29 Mar 2022 08:43:46 -0700 Subject: [PATCH] Optional ReplaceableBase Summary: Allow things like `renderer:Optional[BaseRenderer]` in configurables. Reviewed By: davnov134 Differential Revision: D35118339 fbshipit-source-id: 1219321b2817ed4b26fe924c6d6f73887095c985 --- pytorch3d/common/datatypes.py | 18 +++ pytorch3d/implicitron/dataset/types.py | 18 +-- pytorch3d/implicitron/tools/config.py | 145 ++++++++++++++++++++----- tests/implicitron/test_config.py | 36 +++++- 4 files changed, 171 insertions(+), 46 deletions(-) diff --git a/pytorch3d/common/datatypes.py b/pytorch3d/common/datatypes.py index 0a9d14d5..489d60ce 100644 --- a/pytorch3d/common/datatypes.py +++ b/pytorch3d/common/datatypes.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import sys from typing import Optional, Union import torch @@ -56,3 +57,20 @@ def get_device(x, device: Optional[Device] = None) -> torch.device: # Default device is cpu return torch.device("cpu") + + +# Provide get_origin and get_args even in Python 3.7. + +if sys.version_info >= (3, 8, 0): + from typing import get_args, get_origin +elif sys.version_info >= (3, 7, 0): + + def get_origin(cls): # pragma: no cover + return getattr(cls, "__origin__", None) + + def get_args(cls): # pragma: no cover + return getattr(cls, "__args__", None) + + +else: + raise ImportError("This module requires Python 3.7+") diff --git a/pytorch3d/implicitron/dataset/types.py b/pytorch3d/implicitron/dataset/types.py index 1264dfb9..cf744b99 100644 --- a/pytorch3d/implicitron/dataset/types.py +++ b/pytorch3d/implicitron/dataset/types.py @@ -8,31 +8,15 @@ import dataclasses import gzip import json -import sys from dataclasses import MISSING, Field, dataclass from typing import IO, Any, Optional, Tuple, Type, TypeVar, Union, cast import numpy as np +from pytorch3d.common.datatypes import get_args, get_origin _X = TypeVar("_X") - -if sys.version_info >= (3, 8, 0): - from typing import get_args, get_origin -elif sys.version_info >= (3, 7, 0): - - def get_origin(cls): - return getattr(cls, "__origin__", None) - - def get_args(cls): - return getattr(cls, "__args__", None) - - -else: - raise ImportError("This module requires Python 3.7+") - - TF3 = Tuple[float, float, float] diff --git a/pytorch3d/implicitron/tools/config.py b/pytorch3d/implicitron/tools/config.py index c867caee..c75b6871 100644 --- a/pytorch3d/implicitron/tools/config.py +++ b/pytorch3d/implicitron/tools/config.py @@ -10,9 +10,10 @@ import itertools import warnings from collections import Counter, defaultdict from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, cast +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union from omegaconf import DictConfig, OmegaConf, open_dict +from pytorch3d.common.datatypes import get_args, get_origin """ @@ -97,6 +98,8 @@ you can give a base class and the implementation is looked up by name in the glo class B(Configurable): a: A a_class_type: str = "A2" + b: Optional[A] + b_class_type: Optional[str] = "A2" def __post_init__(self): run_auto_creation(self) @@ -124,6 +127,13 @@ will expand to a_A2_args: DictConfig = dataclasses.field( default_factory=lambda: DictConfig({"k": 1, "m": 3} ) + b_class_type: Optional[str] = "A2" + b_A1_args: DictConfig = dataclasses.field( + default_factory=lambda: DictConfig({"k": 1, "m": 3} + ) + b_A2_args: DictConfig = dataclasses.field( + default_factory=lambda: DictConfig({"k": 1, "m": 3} + ) def __post_init__(self): if self.a_class_type == "A1": @@ -133,6 +143,15 @@ will expand to else: raise ValueError(...) + if self.b_class_type is None: + self.b = None + elif self.b_class_type == "A1": + self.b = A1(**self.b_A1_args) + elif self.b_class_type == "A2": + self.b = A2(**self.b_A2_args) + else: + raise ValueError(...) + 3. Aside from these classes, the members of these classes should be things which DictConfig is happy with: e.g. (bool, int, str, None, float) and what can be built from them with DictConfigs and lists of them. @@ -324,7 +343,19 @@ class _Registry: registry = _Registry() -def _default_create(name: str, type_: Type, pluggable: bool) -> Callable[[Any], None]: +class _ProcessType(Enum): + """ + Type of member which gets rewritten by expand_args_fields. + """ + + CONFIGURABLE = 1 + REPLACEABLE = 2 + OPTIONAL_REPLACEABLE = 3 + + +def _default_create( + name: str, type_: Type, process_type: _ProcessType +) -> Callable[[Any], None]: """ Return the default creation function for a member. This is a function which could be called in __post_init__ to initialise the member, and will be called @@ -332,8 +363,8 @@ def _default_create(name: str, type_: Type, pluggable: bool) -> Callable[[Any], Args: name: name of the member - type_: declared type of the member - pluggable: True if the member's declared type inherits ReplaceableBase, + type_: type of the member (with any Optional removed) + process_type: Shows whether member's declared type inherits ReplaceableBase, in which case the actual type to be created is decided at runtime. @@ -349,6 +380,10 @@ def _default_create(name: str, type_: Type, pluggable: bool) -> Callable[[Any], def inner_pluggable(self): type_name = getattr(self, name + TYPE_SUFFIX) + if type_name is None: + setattr(self, name, None) + return + chosen_class = registry.get(type_, type_name) if self._known_implementations.get(type_name, chosen_class) is not chosen_class: # If this warning is raised, it means that a new definition of @@ -362,7 +397,7 @@ def _default_create(name: str, type_: Type, pluggable: bool) -> Callable[[Any], args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}") setattr(self, name, chosen_class(**args)) - return inner_pluggable if pluggable else inner + return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable def run_auto_creation(self: Any) -> None: @@ -499,7 +534,7 @@ def expand_args_fields( The transformations this function makes, before the concluding dataclasses.dataclass, are as follows. if X is a base class with registered - subclasses Y and Z, replace + subclasses Y and Z, replace a class member x: X @@ -518,7 +553,32 @@ def expand_args_fields( ) x_class_type: str = "UNDEFAULTED" - without adding the optional things if they are already there. + without adding the optional attributes if they are already there. + + Similarly, replace + + x: Optional[X] + + and optionally + + x_class_type: Optional[str] = "Y" + def create_x(self):... + + with + + x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig()) + x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig()) + def create_x(self): + if self.x_class_type is None: + self.x = None + return + + self.x = registry.get(X, self.x_class_type)( + **self.getattr(f"x_{self.x_class_type}_args) + ) + x_class_type: Optional[str] = "UNDEFAULTED" + + without adding the optional attributes if they are already there. Similarly, if X is a subclass of Configurable, @@ -587,26 +647,21 @@ def expand_args_fields( if "_processed_members" in base.__dict__: processed_members.update(base._processed_members) - to_process: List[Tuple[str, Type, bool]] = [] + to_process: List[Tuple[str, Type, _ProcessType]] = [] if "__annotations__" in some_class.__dict__: for name, type_ in some_class.__annotations__.items(): - if not isinstance(type_, type): - # type_ could be something like typing.Tuple + underlying_and_process_type = _get_type_to_process(type_) + if underlying_and_process_type is None: continue - if ( - issubclass(type_, ReplaceableBase) - and ReplaceableBase in type_.__bases__ - ): - to_process.append((name, type_, True)) - elif issubclass(type_, Configurable): - to_process.append((name, type_, False)) + underlying_type, process_type = underlying_and_process_type + to_process.append((name, underlying_type, process_type)) - for name, type_, pluggable in to_process: + for name, underlying_type, process_type in to_process: _process_member( name=name, - type_=type_, - pluggable=pluggable, - some_class=cast(type, some_class), + type_=underlying_type, + process_type=process_type, + some_class=some_class, creation_functions=creation_functions, _do_not_process=_do_not_process, known_implementations=known_implementations, @@ -641,11 +696,39 @@ def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()): return dataclasses.field(default_factory=create) +def _get_type_to_process(type_) -> Optional[Tuple[Type, _ProcessType]]: + """ + If a member is annotated as `type_`, and that should expanded in + expand_args_fields, return how it should be expanded. + """ + if get_origin(type_) == Union: + # We look for Optional[X] which is a Union of X with None. + args = get_args(type_) + if len(args) != 2 or all(a is not type(None) for a in args): # noqa: E721 + return + underlying = args[0] if args[1] is type(None) else args[1] # noqa: E721 + if ( + issubclass(underlying, ReplaceableBase) + and ReplaceableBase in underlying.__bases__ + ): + return underlying, _ProcessType.OPTIONAL_REPLACEABLE + + if not isinstance(type_, type): + # e.g. any other Union or Tuple + return + + if issubclass(type_, ReplaceableBase) and ReplaceableBase in type_.__bases__: + return type_, _ProcessType.REPLACEABLE + + if issubclass(type_, Configurable): + return type_, _ProcessType.CONFIGURABLE + + def _process_member( *, name: str, type_: Type, - pluggable: bool, + process_type: _ProcessType, some_class: Type, creation_functions: List[str], _do_not_process: Tuple[type, ...], @@ -656,8 +739,8 @@ def _process_member( Args: name: member name - type_: member declared type - plugglable: whether member has dynamic type + type_: member type (with Optional removed if needed) + process_type: whether member has dynamic type some_class: (MODIFIED IN PLACE) the class being processed creation_functions: (MODIFIED IN PLACE) the names of the create functions _do_not_process: as for expand_args_fields. @@ -668,10 +751,13 @@ def _process_member( # there are non-defaulted standard class members. del some_class.__annotations__[name] - if pluggable: + if process_type != _ProcessType.CONFIGURABLE: type_name = name + TYPE_SUFFIX if type_name not in some_class.__annotations__: - some_class.__annotations__[type_name] = str + if process_type == _ProcessType.OPTIONAL_REPLACEABLE: + some_class.__annotations__[type_name] = Optional[str] + else: + some_class.__annotations__[type_name] = str setattr(some_class, type_name, "UNDEFAULTED") for derived_type in registry.get_all(type_): @@ -720,7 +806,7 @@ def _process_member( setattr( some_class, creation_function_name, - _default_create(name, type_, pluggable), + _default_create(name, type_, process_type), ) creation_functions.append(creation_function_name) @@ -743,7 +829,10 @@ def remove_unused_components(dict_: DictConfig) -> None: args_keys = [key for key in keys if key.endswith(ARGS_SUFFIX)] for replaceable in replaceables: selected_type = dict_[replaceable + TYPE_SUFFIX] - expect = replaceable + "_" + selected_type + ARGS_SUFFIX + if selected_type is None: + expect = "" + else: + expect = replaceable + "_" + selected_type + ARGS_SUFFIX with open_dict(dict_): for key in args_keys: if key.startswith(replaceable + "_") and key != expect: diff --git a/tests/implicitron/test_config.py b/tests/implicitron/test_config.py index dcc98848..8fe5aafd 100644 --- a/tests/implicitron/test_config.py +++ b/tests/implicitron/test_config.py @@ -14,7 +14,9 @@ from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError from pytorch3d.implicitron.tools.config import ( Configurable, ReplaceableBase, + _get_type_to_process, _is_actually_dataclass, + _ProcessType, _Registry, expand_args_fields, get_default_args, @@ -94,6 +96,19 @@ class TestConfig(unittest.TestCase): self.assertFalse(_is_actually_dataclass(B)) self.assertTrue(is_dataclass(B)) + def test_get_type_to_process(self): + gt = _get_type_to_process + self.assertIsNone(gt(int)) + self.assertEqual(gt(Fruit), (Fruit, _ProcessType.REPLACEABLE)) + self.assertEqual( + gt(Optional[Fruit]), (Fruit, _ProcessType.OPTIONAL_REPLACEABLE) + ) + self.assertEqual(gt(MainTest), (MainTest, _ProcessType.CONFIGURABLE)) + self.assertIsNone(gt(Optional[int])) + self.assertIsNone(gt(Optional[MainTest])) + self.assertIsNone(gt(Tuple[Fruit])) + self.assertIsNone(gt(Tuple[Fruit, Animal])) + def test_simple_replacement(self): struct = get_default_args(MainTest) struct.n_ids = 9780 @@ -247,6 +262,7 @@ class TestConfig(unittest.TestCase): self.assertEqual(container.fruit_Pear_args.n_pips, 13) def test_inheritance(self): + # Also exercises optional replaceables class FruitBowl(ReplaceableBase): main_fruit: Fruit main_fruit_class_type: str = "Orange" @@ -255,8 +271,10 @@ class TestConfig(unittest.TestCase): raise ValueError("This doesn't get called") class LargeFruitBowl(FruitBowl): - extra_fruit: Fruit + extra_fruit: Optional[Fruit] extra_fruit_class_type: str = "Kiwi" + no_fruit: Optional[Fruit] + no_fruit_class_type: Optional[str] = None def __post_init__(self): run_auto_creation(self) @@ -267,6 +285,22 @@ class TestConfig(unittest.TestCase): large = LargeFruitBowl(**large_args) self.assertIsInstance(large.main_fruit, Orange) self.assertIsInstance(large.extra_fruit, Kiwi) + self.assertIsNone(large.no_fruit) + self.assertIn("no_fruit_Kiwi_args", large_args) + + remove_unused_components(large_args) + large2 = LargeFruitBowl(**large_args) + self.assertIsInstance(large2.main_fruit, Orange) + self.assertIsInstance(large2.extra_fruit, Kiwi) + self.assertIsNone(large2.no_fruit) + needed_args = [ + "extra_fruit_Kiwi_args", + "extra_fruit_class_type", + "main_fruit_Orange_args", + "main_fruit_class_type", + "no_fruit_class_type", + ] + self.assertEqual(sorted(large_args.keys()), needed_args) def test_inheritance2(self): # This is a case where a class could contain an instance