mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-01 03:12:49 +08:00
make expand_args_fields optional
Summary: Call expand_args_field when instantiating an object. Reviewed By: shapovalov Differential Revision: D39541931 fbshipit-source-id: de8e1038927ff0112463394412d5d8c26c4a1e17
This commit is contained in:
parent
209c160a20
commit
d6a197be36
File diff suppressed because it is too large
Load Diff
@ -147,7 +147,7 @@
|
||||
"from pytorch3d.implicitron.models.generic_model import GenericModel\n",
|
||||
"from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase\n",
|
||||
"from pytorch3d.implicitron.models.renderer.base import EvaluationMode\n",
|
||||
"from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args, registry, remove_unused_components\n",
|
||||
"from pytorch3d.implicitron.tools.config import get_default_args, registry, remove_unused_components\n",
|
||||
"from pytorch3d.renderer import RayBundle\n",
|
||||
"from pytorch3d.renderer.implicit.renderer import VolumeSampler\n",
|
||||
"from pytorch3d.structures import Volumes\n",
|
||||
@ -245,17 +245,6 @@
|
||||
"!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow_texture.png"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"customInput": null,
|
||||
"originalKey": "2a976be8-01bf-4a1c-a6e7-61d5d08c3dbd",
|
||||
"showInput": false
|
||||
},
|
||||
"source": [
|
||||
"If we want to instantiate one of Implicitron's configurable objects, such as `RenderedMeshDatasetMapProvider`, without using the OmegaConf initialisation (get_default_args), we need to call `expand_args_fields` on the class first."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@ -272,7 +261,6 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"expand_args_fields(RenderedMeshDatasetMapProvider)\n",
|
||||
"cow_provider = RenderedMeshDatasetMapProvider(\n",
|
||||
" data_file=\"data/cow_mesh/cow.obj\",\n",
|
||||
" use_point_light=False,\n",
|
||||
@ -468,7 +456,6 @@
|
||||
" gm = GenericModel(**cfg)\n",
|
||||
"else:\n",
|
||||
" # constructing GenericModel directly\n",
|
||||
" expand_args_fields(GenericModel)\n",
|
||||
" gm = GenericModel(\n",
|
||||
" implicit_function_class_type=\"MyVolumes\",\n",
|
||||
" render_image_height=output_resolution,\n",
|
||||
|
@ -167,12 +167,6 @@ thing as the default for a member of another configured class,
|
||||
"""
|
||||
|
||||
|
||||
_unprocessed_warning: str = (
|
||||
" must be processed before it can be used."
|
||||
+ " This is done by calling expand_args_fields "
|
||||
+ "or get_default_args on it."
|
||||
)
|
||||
|
||||
TYPE_SUFFIX: str = "_class_type"
|
||||
ARGS_SUFFIX: str = "_args"
|
||||
ENABLED_SUFFIX: str = "_enabled"
|
||||
@ -183,39 +177,42 @@ TWEAK_SUFFIX: str = "_tweak_args"
|
||||
|
||||
class ReplaceableBase:
|
||||
"""
|
||||
Base class for dataclass-style classes which
|
||||
can be stored in the registry.
|
||||
Base class for a class (a "replaceable") which is a base class for
|
||||
dataclass-style implementations. The implementations can be stored
|
||||
in the registry. They get expanded into dataclasses with expand_args_fields.
|
||||
This expansion is delayed.
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
"""
|
||||
This function only exists to raise a
|
||||
warning if class construction is attempted
|
||||
without processing.
|
||||
These classes should be expanded only when needed (because processing
|
||||
fixes the list of replaceable subclasses of members of the class). It
|
||||
is safer if users expand the classes explicitly. But if the class gets
|
||||
instantiated when it hasn't been processed, we expand it here.
|
||||
"""
|
||||
obj = super().__new__(cls)
|
||||
if cls is not ReplaceableBase and not _is_actually_dataclass(cls):
|
||||
warnings.warn(cls.__name__ + _unprocessed_warning)
|
||||
expand_args_fields(cls)
|
||||
return obj
|
||||
|
||||
|
||||
class Configurable:
|
||||
"""
|
||||
This indicates a class which is not ReplaceableBase
|
||||
but still needs to be
|
||||
Base class for dataclass-style classes which are not replaceable. These get
|
||||
expanded into a dataclass with expand_args_fields.
|
||||
This expansion is delayed.
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
"""
|
||||
This function only exists to raise a
|
||||
warning if class construction is attempted
|
||||
without processing.
|
||||
These classes should be expanded only when needed (because processing
|
||||
fixes the list of replaceable subclasses of members of the class). It
|
||||
is safer if users expand the classes explicitly. But if the class gets
|
||||
instantiated when it hasn't been processed, we expand it here.
|
||||
"""
|
||||
obj = super().__new__(cls)
|
||||
if cls is not Configurable and not _is_actually_dataclass(cls):
|
||||
warnings.warn(cls.__name__ + _unprocessed_warning)
|
||||
expand_args_fields(cls)
|
||||
return obj
|
||||
|
||||
|
||||
|
@ -315,6 +315,9 @@ class TestConfig(unittest.TestCase):
|
||||
]
|
||||
self.assertEqual(sorted(large_args.keys()), needed_args)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "NotAFruit has not been registered."):
|
||||
LargeFruitBowl(extra_fruit_class_type="NotAFruit")
|
||||
|
||||
def test_inheritance2(self):
|
||||
# This is a case where a class could contain an instance
|
||||
# of a subclass, which is ignored.
|
||||
@ -564,16 +567,19 @@ class TestConfig(unittest.TestCase):
|
||||
|
||||
def test_unprocessed(self):
|
||||
# behavior of Configurable classes which need processing in __new__,
|
||||
class Unprocessed(Configurable):
|
||||
class UnprocessedConfigurable(Configurable):
|
||||
a: int = 9
|
||||
|
||||
class UnprocessedReplaceable(ReplaceableBase):
|
||||
a: int = 1
|
||||
a: int = 9
|
||||
|
||||
with self.assertWarnsRegex(UserWarning, "must be processed"):
|
||||
Unprocessed()
|
||||
with self.assertWarnsRegex(UserWarning, "must be processed"):
|
||||
UnprocessedReplaceable()
|
||||
for Unprocessed in [UnprocessedConfigurable, UnprocessedReplaceable]:
|
||||
|
||||
self.assertFalse(_is_actually_dataclass(Unprocessed))
|
||||
unprocessed = Unprocessed()
|
||||
self.assertTrue(_is_actually_dataclass(Unprocessed))
|
||||
self.assertTrue(isinstance(unprocessed, Unprocessed))
|
||||
self.assertEqual(unprocessed.a, 9)
|
||||
|
||||
def test_enum(self):
|
||||
# Test that enum values are kept, i.e. that OmegaConf's runtime checks
|
||||
|
Loading…
x
Reference in New Issue
Block a user