make expand_args_fields optional

Summary: Call expand_args_field when instantiating an object.

Reviewed By: shapovalov

Differential Revision: D39541931

fbshipit-source-id: de8e1038927ff0112463394412d5d8c26c4a1e17
This commit is contained in:
Jeremy Reizenstein 2022-09-22 08:36:09 -07:00 committed by Facebook GitHub Bot
parent 209c160a20
commit d6a197be36
4 changed files with 467 additions and 485 deletions

File diff suppressed because it is too large Load Diff

View File

@ -147,7 +147,7 @@
"from pytorch3d.implicitron.models.generic_model import GenericModel\n",
"from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase\n",
"from pytorch3d.implicitron.models.renderer.base import EvaluationMode\n",
"from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args, registry, remove_unused_components\n",
"from pytorch3d.implicitron.tools.config import get_default_args, registry, remove_unused_components\n",
"from pytorch3d.renderer import RayBundle\n",
"from pytorch3d.renderer.implicit.renderer import VolumeSampler\n",
"from pytorch3d.structures import Volumes\n",
@ -245,17 +245,6 @@
"!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow_texture.png"
]
},
{
"cell_type": "markdown",
"metadata": {
"customInput": null,
"originalKey": "2a976be8-01bf-4a1c-a6e7-61d5d08c3dbd",
"showInput": false
},
"source": [
"If we want to instantiate one of Implicitron's configurable objects, such as `RenderedMeshDatasetMapProvider`, without using the OmegaConf initialisation (get_default_args), we need to call `expand_args_fields` on the class first."
]
},
{
"cell_type": "code",
"execution_count": null,
@ -272,7 +261,6 @@
},
"outputs": [],
"source": [
"expand_args_fields(RenderedMeshDatasetMapProvider)\n",
"cow_provider = RenderedMeshDatasetMapProvider(\n",
" data_file=\"data/cow_mesh/cow.obj\",\n",
" use_point_light=False,\n",
@ -468,7 +456,6 @@
" gm = GenericModel(**cfg)\n",
"else:\n",
" # constructing GenericModel directly\n",
" expand_args_fields(GenericModel)\n",
" gm = GenericModel(\n",
" implicit_function_class_type=\"MyVolumes\",\n",
" render_image_height=output_resolution,\n",

View File

@ -167,12 +167,6 @@ thing as the default for a member of another configured class,
"""
_unprocessed_warning: str = (
" must be processed before it can be used."
+ " This is done by calling expand_args_fields "
+ "or get_default_args on it."
)
TYPE_SUFFIX: str = "_class_type"
ARGS_SUFFIX: str = "_args"
ENABLED_SUFFIX: str = "_enabled"
@ -183,39 +177,42 @@ TWEAK_SUFFIX: str = "_tweak_args"
class ReplaceableBase:
"""
Base class for dataclass-style classes which
can be stored in the registry.
Base class for a class (a "replaceable") which is a base class for
dataclass-style implementations. The implementations can be stored
in the registry. They get expanded into dataclasses with expand_args_fields.
This expansion is delayed.
"""
def __new__(cls, *args, **kwargs):
"""
This function only exists to raise a
warning if class construction is attempted
without processing.
These classes should be expanded only when needed (because processing
fixes the list of replaceable subclasses of members of the class). It
is safer if users expand the classes explicitly. But if the class gets
instantiated when it hasn't been processed, we expand it here.
"""
obj = super().__new__(cls)
if cls is not ReplaceableBase and not _is_actually_dataclass(cls):
warnings.warn(cls.__name__ + _unprocessed_warning)
expand_args_fields(cls)
return obj
class Configurable:
"""
This indicates a class which is not ReplaceableBase
but still needs to be
Base class for dataclass-style classes which are not replaceable. These get
expanded into a dataclass with expand_args_fields.
This expansion is delayed.
"""
def __new__(cls, *args, **kwargs):
"""
This function only exists to raise a
warning if class construction is attempted
without processing.
These classes should be expanded only when needed (because processing
fixes the list of replaceable subclasses of members of the class). It
is safer if users expand the classes explicitly. But if the class gets
instantiated when it hasn't been processed, we expand it here.
"""
obj = super().__new__(cls)
if cls is not Configurable and not _is_actually_dataclass(cls):
warnings.warn(cls.__name__ + _unprocessed_warning)
expand_args_fields(cls)
return obj

View File

@ -315,6 +315,9 @@ class TestConfig(unittest.TestCase):
]
self.assertEqual(sorted(large_args.keys()), needed_args)
with self.assertRaisesRegex(ValueError, "NotAFruit has not been registered."):
LargeFruitBowl(extra_fruit_class_type="NotAFruit")
def test_inheritance2(self):
# This is a case where a class could contain an instance
# of a subclass, which is ignored.
@ -564,16 +567,19 @@ class TestConfig(unittest.TestCase):
def test_unprocessed(self):
# behavior of Configurable classes which need processing in __new__,
class Unprocessed(Configurable):
class UnprocessedConfigurable(Configurable):
a: int = 9
class UnprocessedReplaceable(ReplaceableBase):
a: int = 1
a: int = 9
with self.assertWarnsRegex(UserWarning, "must be processed"):
Unprocessed()
with self.assertWarnsRegex(UserWarning, "must be processed"):
UnprocessedReplaceable()
for Unprocessed in [UnprocessedConfigurable, UnprocessedReplaceable]:
self.assertFalse(_is_actually_dataclass(Unprocessed))
unprocessed = Unprocessed()
self.assertTrue(_is_actually_dataclass(Unprocessed))
self.assertTrue(isinstance(unprocessed, Unprocessed))
self.assertEqual(unprocessed.a, 9)
def test_enum(self):
# Test that enum values are kept, i.e. that OmegaConf's runtime checks