mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	make expand_args_fields optional
Summary: Call expand_args_field when instantiating an object. Reviewed By: shapovalov Differential Revision: D39541931 fbshipit-source-id: de8e1038927ff0112463394412d5d8c26c4a1e17
This commit is contained in:
		
							parent
							
								
									209c160a20
								
							
						
					
					
						commit
						d6a197be36
					
				
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@ -147,7 +147,7 @@
 | 
			
		||||
        "from pytorch3d.implicitron.models.generic_model import GenericModel\n",
 | 
			
		||||
        "from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase\n",
 | 
			
		||||
        "from pytorch3d.implicitron.models.renderer.base import EvaluationMode\n",
 | 
			
		||||
        "from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args, registry, remove_unused_components\n",
 | 
			
		||||
        "from pytorch3d.implicitron.tools.config import get_default_args, registry, remove_unused_components\n",
 | 
			
		||||
        "from pytorch3d.renderer import RayBundle\n",
 | 
			
		||||
        "from pytorch3d.renderer.implicit.renderer import VolumeSampler\n",
 | 
			
		||||
        "from pytorch3d.structures import Volumes\n",
 | 
			
		||||
@ -245,17 +245,6 @@
 | 
			
		||||
        "!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow_texture.png"
 | 
			
		||||
      ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "markdown",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "customInput": null,
 | 
			
		||||
        "originalKey": "2a976be8-01bf-4a1c-a6e7-61d5d08c3dbd",
 | 
			
		||||
        "showInput": false
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "If we want to instantiate one of Implicitron's configurable objects, such as `RenderedMeshDatasetMapProvider`, without using the OmegaConf initialisation (get_default_args), we need to call `expand_args_fields` on the class first."
 | 
			
		||||
      ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "code",
 | 
			
		||||
      "execution_count": null,
 | 
			
		||||
@ -272,7 +261,6 @@
 | 
			
		||||
      },
 | 
			
		||||
      "outputs": [],
 | 
			
		||||
      "source": [
 | 
			
		||||
        "expand_args_fields(RenderedMeshDatasetMapProvider)\n",
 | 
			
		||||
        "cow_provider = RenderedMeshDatasetMapProvider(\n",
 | 
			
		||||
        "    data_file=\"data/cow_mesh/cow.obj\",\n",
 | 
			
		||||
        "    use_point_light=False,\n",
 | 
			
		||||
@ -468,7 +456,6 @@
 | 
			
		||||
        "    gm = GenericModel(**cfg)\n",
 | 
			
		||||
        "else:\n",
 | 
			
		||||
        "    # constructing GenericModel directly\n",
 | 
			
		||||
        "    expand_args_fields(GenericModel)\n",
 | 
			
		||||
        "    gm = GenericModel(\n",
 | 
			
		||||
        "        implicit_function_class_type=\"MyVolumes\",\n",
 | 
			
		||||
        "        render_image_height=output_resolution,\n",
 | 
			
		||||
 | 
			
		||||
@ -167,12 +167,6 @@ thing as the default for a member of another configured class,
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_unprocessed_warning: str = (
 | 
			
		||||
    " must be processed before it can be used."
 | 
			
		||||
    + " This is done by calling expand_args_fields "
 | 
			
		||||
    + "or get_default_args on it."
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
TYPE_SUFFIX: str = "_class_type"
 | 
			
		||||
ARGS_SUFFIX: str = "_args"
 | 
			
		||||
ENABLED_SUFFIX: str = "_enabled"
 | 
			
		||||
@ -183,39 +177,42 @@ TWEAK_SUFFIX: str = "_tweak_args"
 | 
			
		||||
 | 
			
		||||
class ReplaceableBase:
 | 
			
		||||
    """
 | 
			
		||||
    Base class for dataclass-style classes which
 | 
			
		||||
    can be stored in the registry.
 | 
			
		||||
    Base class for a class (a "replaceable") which is a base class for
 | 
			
		||||
    dataclass-style implementations. The implementations can be stored
 | 
			
		||||
    in the registry. They get expanded into dataclasses with expand_args_fields.
 | 
			
		||||
    This expansion is delayed.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __new__(cls, *args, **kwargs):
 | 
			
		||||
        """
 | 
			
		||||
        This function only exists to raise a
 | 
			
		||||
        warning if class construction is attempted
 | 
			
		||||
        without processing.
 | 
			
		||||
        These classes should be expanded only when needed (because processing
 | 
			
		||||
        fixes the list of replaceable subclasses of members of the class). It
 | 
			
		||||
        is safer if users expand the classes explicitly. But if the class gets
 | 
			
		||||
        instantiated when it hasn't been processed, we expand it here.
 | 
			
		||||
        """
 | 
			
		||||
        obj = super().__new__(cls)
 | 
			
		||||
        if cls is not ReplaceableBase and not _is_actually_dataclass(cls):
 | 
			
		||||
            warnings.warn(cls.__name__ + _unprocessed_warning)
 | 
			
		||||
            expand_args_fields(cls)
 | 
			
		||||
        return obj
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Configurable:
 | 
			
		||||
    """
 | 
			
		||||
    This indicates a class which is not ReplaceableBase
 | 
			
		||||
    but still needs to be
 | 
			
		||||
    Base class for dataclass-style classes which are not replaceable. These get
 | 
			
		||||
    expanded into a dataclass with expand_args_fields.
 | 
			
		||||
    This expansion is delayed.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __new__(cls, *args, **kwargs):
 | 
			
		||||
        """
 | 
			
		||||
        This function only exists to raise a
 | 
			
		||||
        warning if class construction is attempted
 | 
			
		||||
        without processing.
 | 
			
		||||
        These classes should be expanded only when needed (because processing
 | 
			
		||||
        fixes the list of replaceable subclasses of members of the class). It
 | 
			
		||||
        is safer if users expand the classes explicitly. But if the class gets
 | 
			
		||||
        instantiated when it hasn't been processed, we expand it here.
 | 
			
		||||
        """
 | 
			
		||||
        obj = super().__new__(cls)
 | 
			
		||||
        if cls is not Configurable and not _is_actually_dataclass(cls):
 | 
			
		||||
            warnings.warn(cls.__name__ + _unprocessed_warning)
 | 
			
		||||
            expand_args_fields(cls)
 | 
			
		||||
        return obj
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -315,6 +315,9 @@ class TestConfig(unittest.TestCase):
 | 
			
		||||
        ]
 | 
			
		||||
        self.assertEqual(sorted(large_args.keys()), needed_args)
 | 
			
		||||
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, "NotAFruit has not been registered."):
 | 
			
		||||
            LargeFruitBowl(extra_fruit_class_type="NotAFruit")
 | 
			
		||||
 | 
			
		||||
    def test_inheritance2(self):
 | 
			
		||||
        # This is a case where a class could contain an instance
 | 
			
		||||
        # of a subclass, which is ignored.
 | 
			
		||||
@ -564,16 +567,19 @@ class TestConfig(unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
    def test_unprocessed(self):
 | 
			
		||||
        # behavior of Configurable classes which need processing in __new__,
 | 
			
		||||
        class Unprocessed(Configurable):
 | 
			
		||||
        class UnprocessedConfigurable(Configurable):
 | 
			
		||||
            a: int = 9
 | 
			
		||||
 | 
			
		||||
        class UnprocessedReplaceable(ReplaceableBase):
 | 
			
		||||
            a: int = 1
 | 
			
		||||
            a: int = 9
 | 
			
		||||
 | 
			
		||||
        with self.assertWarnsRegex(UserWarning, "must be processed"):
 | 
			
		||||
            Unprocessed()
 | 
			
		||||
        with self.assertWarnsRegex(UserWarning, "must be processed"):
 | 
			
		||||
            UnprocessedReplaceable()
 | 
			
		||||
        for Unprocessed in [UnprocessedConfigurable, UnprocessedReplaceable]:
 | 
			
		||||
 | 
			
		||||
            self.assertFalse(_is_actually_dataclass(Unprocessed))
 | 
			
		||||
            unprocessed = Unprocessed()
 | 
			
		||||
            self.assertTrue(_is_actually_dataclass(Unprocessed))
 | 
			
		||||
            self.assertTrue(isinstance(unprocessed, Unprocessed))
 | 
			
		||||
            self.assertEqual(unprocessed.a, 9)
 | 
			
		||||
 | 
			
		||||
    def test_enum(self):
 | 
			
		||||
        # Test that enum values are kept, i.e. that OmegaConf's runtime checks
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user