make expand_args_fields optional

Summary: Call expand_args_field when instantiating an object.

Reviewed By: shapovalov

Differential Revision: D39541931

fbshipit-source-id: de8e1038927ff0112463394412d5d8c26c4a1e17
This commit is contained in:
Jeremy Reizenstein 2022-09-22 08:36:09 -07:00 committed by Facebook GitHub Bot
parent 209c160a20
commit d6a197be36
4 changed files with 467 additions and 485 deletions

File diff suppressed because it is too large Load Diff

View File

@ -147,7 +147,7 @@
"from pytorch3d.implicitron.models.generic_model import GenericModel\n", "from pytorch3d.implicitron.models.generic_model import GenericModel\n",
"from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase\n", "from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase\n",
"from pytorch3d.implicitron.models.renderer.base import EvaluationMode\n", "from pytorch3d.implicitron.models.renderer.base import EvaluationMode\n",
"from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args, registry, remove_unused_components\n", "from pytorch3d.implicitron.tools.config import get_default_args, registry, remove_unused_components\n",
"from pytorch3d.renderer import RayBundle\n", "from pytorch3d.renderer import RayBundle\n",
"from pytorch3d.renderer.implicit.renderer import VolumeSampler\n", "from pytorch3d.renderer.implicit.renderer import VolumeSampler\n",
"from pytorch3d.structures import Volumes\n", "from pytorch3d.structures import Volumes\n",
@ -245,17 +245,6 @@
"!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow_texture.png" "!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow_texture.png"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {
"customInput": null,
"originalKey": "2a976be8-01bf-4a1c-a6e7-61d5d08c3dbd",
"showInput": false
},
"source": [
"If we want to instantiate one of Implicitron's configurable objects, such as `RenderedMeshDatasetMapProvider`, without using the OmegaConf initialisation (get_default_args), we need to call `expand_args_fields` on the class first."
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
@ -272,7 +261,6 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"expand_args_fields(RenderedMeshDatasetMapProvider)\n",
"cow_provider = RenderedMeshDatasetMapProvider(\n", "cow_provider = RenderedMeshDatasetMapProvider(\n",
" data_file=\"data/cow_mesh/cow.obj\",\n", " data_file=\"data/cow_mesh/cow.obj\",\n",
" use_point_light=False,\n", " use_point_light=False,\n",
@ -468,7 +456,6 @@
" gm = GenericModel(**cfg)\n", " gm = GenericModel(**cfg)\n",
"else:\n", "else:\n",
" # constructing GenericModel directly\n", " # constructing GenericModel directly\n",
" expand_args_fields(GenericModel)\n",
" gm = GenericModel(\n", " gm = GenericModel(\n",
" implicit_function_class_type=\"MyVolumes\",\n", " implicit_function_class_type=\"MyVolumes\",\n",
" render_image_height=output_resolution,\n", " render_image_height=output_resolution,\n",

View File

@ -167,12 +167,6 @@ thing as the default for a member of another configured class,
""" """
_unprocessed_warning: str = (
" must be processed before it can be used."
+ " This is done by calling expand_args_fields "
+ "or get_default_args on it."
)
TYPE_SUFFIX: str = "_class_type" TYPE_SUFFIX: str = "_class_type"
ARGS_SUFFIX: str = "_args" ARGS_SUFFIX: str = "_args"
ENABLED_SUFFIX: str = "_enabled" ENABLED_SUFFIX: str = "_enabled"
@ -183,39 +177,42 @@ TWEAK_SUFFIX: str = "_tweak_args"
class ReplaceableBase: class ReplaceableBase:
""" """
Base class for dataclass-style classes which Base class for a class (a "replaceable") which is a base class for
can be stored in the registry. dataclass-style implementations. The implementations can be stored
in the registry. They get expanded into dataclasses with expand_args_fields.
This expansion is delayed.
""" """
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
""" """
This function only exists to raise a These classes should be expanded only when needed (because processing
warning if class construction is attempted fixes the list of replaceable subclasses of members of the class). It
without processing. is safer if users expand the classes explicitly. But if the class gets
instantiated when it hasn't been processed, we expand it here.
""" """
obj = super().__new__(cls) obj = super().__new__(cls)
if cls is not ReplaceableBase and not _is_actually_dataclass(cls): if cls is not ReplaceableBase and not _is_actually_dataclass(cls):
warnings.warn(cls.__name__ + _unprocessed_warning) expand_args_fields(cls)
return obj return obj
class Configurable: class Configurable:
""" """
This indicates a class which is not ReplaceableBase Base class for dataclass-style classes which are not replaceable. These get
but still needs to be
expanded into a dataclass with expand_args_fields. expanded into a dataclass with expand_args_fields.
This expansion is delayed. This expansion is delayed.
""" """
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
""" """
This function only exists to raise a These classes should be expanded only when needed (because processing
warning if class construction is attempted fixes the list of replaceable subclasses of members of the class). It
without processing. is safer if users expand the classes explicitly. But if the class gets
instantiated when it hasn't been processed, we expand it here.
""" """
obj = super().__new__(cls) obj = super().__new__(cls)
if cls is not Configurable and not _is_actually_dataclass(cls): if cls is not Configurable and not _is_actually_dataclass(cls):
warnings.warn(cls.__name__ + _unprocessed_warning) expand_args_fields(cls)
return obj return obj

View File

@ -315,6 +315,9 @@ class TestConfig(unittest.TestCase):
] ]
self.assertEqual(sorted(large_args.keys()), needed_args) self.assertEqual(sorted(large_args.keys()), needed_args)
with self.assertRaisesRegex(ValueError, "NotAFruit has not been registered."):
LargeFruitBowl(extra_fruit_class_type="NotAFruit")
def test_inheritance2(self): def test_inheritance2(self):
# This is a case where a class could contain an instance # This is a case where a class could contain an instance
# of a subclass, which is ignored. # of a subclass, which is ignored.
@ -564,16 +567,19 @@ class TestConfig(unittest.TestCase):
def test_unprocessed(self): def test_unprocessed(self):
# behavior of Configurable classes which need processing in __new__, # behavior of Configurable classes which need processing in __new__,
class Unprocessed(Configurable): class UnprocessedConfigurable(Configurable):
a: int = 9 a: int = 9
class UnprocessedReplaceable(ReplaceableBase): class UnprocessedReplaceable(ReplaceableBase):
a: int = 1 a: int = 9
with self.assertWarnsRegex(UserWarning, "must be processed"): for Unprocessed in [UnprocessedConfigurable, UnprocessedReplaceable]:
Unprocessed()
with self.assertWarnsRegex(UserWarning, "must be processed"): self.assertFalse(_is_actually_dataclass(Unprocessed))
UnprocessedReplaceable() unprocessed = Unprocessed()
self.assertTrue(_is_actually_dataclass(Unprocessed))
self.assertTrue(isinstance(unprocessed, Unprocessed))
self.assertEqual(unprocessed.a, 9)
def test_enum(self): def test_enum(self):
# Test that enum values are kept, i.e. that OmegaConf's runtime checks # Test that enum values are kept, i.e. that OmegaConf's runtime checks