Hard population of registry system with pre_expand

Summary: Provide an extension point pre_expand to let a configurable class A make sure another class B is registered before A is expanded. This reduces top level imports.

Reviewed By: bottler

Differential Revision: D44504122

fbshipit-source-id: c418bebbe6d33862d239be592d9751378eee3a62
This commit is contained in:
Dejan Kovachev
2023-03-31 07:44:38 -07:00
committed by Facebook GitHub Bot
parent 813e941de5
commit c759fc560f
5 changed files with 117 additions and 27 deletions

View File

@@ -10,6 +10,7 @@ import unittest
from dataclasses import dataclass, field, is_dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
from unittest.mock import Mock
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
from pytorch3d.implicitron.tools.config import (
@@ -805,6 +806,39 @@ class TestConfig(unittest.TestCase):
self.assertEqual(control_args, ["Orange", "Orange", True, True])
def test_pre_expand(self):
# Check that the precreate method of a class is called once before
# when expand_args_fields is called on the class.
class A(Configurable):
n: int = 9
@classmethod
def pre_expand(cls):
pass
A.pre_expand = Mock()
expand_args_fields(A)
A.pre_expand.assert_called()
def test_pre_expand_replaceable(self):
# Check that the precreate method of a class is called once before
# when expand_args_fields is called on the class.
class A(ReplaceableBase):
pass
@classmethod
def pre_expand(cls):
pass
class A1(A):
n: 9
A.pre_expand = Mock()
expand_args_fields(A1)
A.pre_expand.assert_called()
@dataclass(eq=False)
class MockDataclass: