From d6a197be3662cdfa57a34e3134fea1bb04eb1614 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Thu, 22 Sep 2022 08:36:09 -0700 Subject: [PATCH] make expand_args_fields optional Summary: Call expand_args_field when instantiating an object. Reviewed By: shapovalov Differential Revision: D39541931 fbshipit-source-id: de8e1038927ff0112463394412d5d8c26c4a1e17 --- .../tutorials/implicitron_config_system.ipynb | 886 +++++++++--------- docs/tutorials/implicitron_volumes.ipynb | 15 +- pytorch3d/implicitron/tools/config.py | 33 +- tests/implicitron/test_config.py | 18 +- 4 files changed, 467 insertions(+), 485 deletions(-) diff --git a/docs/tutorials/implicitron_config_system.ipynb b/docs/tutorials/implicitron_config_system.ipynb index 365acf19..9d69259e 100644 --- a/docs/tutorials/implicitron_config_system.ipynb +++ b/docs/tutorials/implicitron_config_system.ipynb @@ -1,79 +1,38 @@ { - "metadata": { - "dataExplorerConfig": {}, - "bento_stylesheets": { - "bento/extensions/flow/main.css": true, - "bento/extensions/kernel_selector/main.css": true, - "bento/extensions/kernel_ui/main.css": true, - "bento/extensions/new_kernel/main.css": true, - "bento/extensions/system_usage/main.css": true, - "bento/extensions/theme/main.css": true - }, - "kernelspec": { - "display_name": "pytorch3d", - "language": "python", - "name": "bento_kernel_pytorch3d", - "metadata": { - "kernel_name": "bento_kernel_pytorch3d", - "nightly_builds": true, - "fbpkg_supported": true, - "cinder_runtime": false, - "is_prebuilt": true - } - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" - }, - "last_server_session_id": "d6b46f14-cee7-44c1-8c51-39a38a4ea4c2", - "last_kernel_id": "90755407-3729-46f4-ab67-ff2cb1daa5cb", - "last_base_url": "https://9177.od.fbinfra.net:443/", - "last_msg_id": "f61034eb-826226915ad9548ffbe495ba_6317", - "captumWidgetMessage": {}, - "outputWidgetContext": {} - }, - "nbformat": 4, - "nbformat_minor": 2, "cells": [ { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "f0af2d90-cb21-4ab4-b4cb-0fd00dbfb77b", - "showInput": true, "customInput": null, - "customOutput": null + "customOutput": null, + "originalKey": "f0af2d90-cb21-4ab4-b4cb-0fd00dbfb77b", + "showInput": true }, + "outputs": [], "source": [ "# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved." - ], - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { + "customInput": null, "originalKey": "4e15bfa2-5404-40d0-98b6-eb2732c8b72b", - "showInput": false, - "customInput": null + "showInput": false }, "source": [ "# Implicitron's config system" - ], - "attachments": {} + ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { + "customInput": null, "originalKey": "287be985-423d-42e0-a2af-1e8c585e723c", - "showInput": false, - "customInput": null + "showInput": false }, "source": [ "Implicitron's components are all based on a unified hierarchical configuration system. \n", @@ -83,17 +42,15 @@ "\n", "The file which defines this system is [here](https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/implicitron/tools/config.py) in the PyTorch3D repo.\n", "The Implicitron volumes tutorial contains a simple example of using the config system.\n", - "This tutorial provides detailed hands-on experience in using and modifying Implicitron's configurable components.\n", - "" - ], - "attachments": {} + "This tutorial provides detailed hands-on experience in using and modifying Implicitron's configurable components.\n" + ] }, { "cell_type": "markdown", "metadata": { + "customInput": null, "originalKey": "fde300a2-99cb-4d52-9d5b-4464a2083e0b", - "showInput": false, - "customInput": null + "showInput": false }, "source": [ "## 0. Install and import modules\n", @@ -103,12 +60,14 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "ad6e94a7-e114-43d3-b038-a5210c7d34c9", - "showInput": true, "customInput": null, - "customOutput": null + "customOutput": null, + "originalKey": "ad6e94a7-e114-43d3-b038-a5210c7d34c9", + "showInput": true }, + "outputs": [], "source": [ "import os\n", "import sys\n", @@ -135,48 +94,48 @@ " !tar xzf 1.10.0.tar.gz\n", " os.environ[\"CUB_HOME\"] = os.getcwd() + \"/cub-1.10.0\"\n", " !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'" - ], - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { + "customInput": null, "originalKey": "609896c0-9e2e-4716-b074-b565f0170e32", - "showInput": false, - "customInput": null + "showInput": false }, "source": [ "Ensure omegaconf is installed. If not, run this cell. (It should not be necessary to restart the runtime.)" - ], - "attachments": {} + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "d1c1851e-b9f2-4236-93c3-19aa4d63041c", - "showInput": true, "customInput": null, - "customOutput": null + "customOutput": null, + "originalKey": "d1c1851e-b9f2-4236-93c3-19aa4d63041c", + "showInput": true }, + "outputs": [], "source": [ "!pip install omegaconf" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "collapsed": false, - "originalKey": "5ac7ef23-b74c-46b2-b8d3-799524d7ba4f", "code_folding": [], - "hidden_ranges": [], - "requestMsgId": "5ac7ef23-b74c-46b2-b8d3-799524d7ba4f", + "collapsed": false, "customOutput": null, "executionStartTime": 1659465468717, - "executionStopTime": 1659465468738 + "executionStopTime": 1659465468738, + "hidden_ranges": [], + "originalKey": "5ac7ef23-b74c-46b2-b8d3-799524d7ba4f", + "requestMsgId": "5ac7ef23-b74c-46b2-b8d3-799524d7ba4f" }, + "outputs": [], "source": [ "from dataclasses import dataclass\n", "from typing import Optional, Tuple\n", @@ -191,16 +150,14 @@ " registry,\n", " run_auto_creation,\n", ")" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", "metadata": { + "customInput": null, "originalKey": "a638bf90-eb6b-424d-b53d-eae11954a717", - "showInput": false, - "customInput": null + "showInput": false }, "source": [ "## 1. Introducing dataclasses \n", @@ -210,16 +167,18 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "71eaad5e-e198-492e-8610-24b0da9dd4ae", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "71eaad5e-e198-492e-8610-24b0da9dd4ae", + "customInput": null, "customOutput": null, "executionStartTime": 1659454972732, - "executionStopTime": 1659454972739 + "executionStopTime": 1659454972739, + "originalKey": "71eaad5e-e198-492e-8610-24b0da9dd4ae", + "requestMsgId": "71eaad5e-e198-492e-8610-24b0da9dd4ae", + "showInput": true }, + "outputs": [], "source": [ "@dataclass\n", "class MyDataclass:\n", @@ -230,230 +189,228 @@ " def __post_init__(self):\n", " print(f\"created with a = {self.a}\")\n", " self.d = 2 * self.b" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "83202a18-a3d3-44ec-a62d-b3360a302645", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "83202a18-a3d3-44ec-a62d-b3360a302645", + "customInput": null, "customOutput": null, "executionStartTime": 1659454973051, - "executionStopTime": 1659454973077 + "executionStopTime": 1659454973077, + "originalKey": "83202a18-a3d3-44ec-a62d-b3360a302645", + "requestMsgId": "83202a18-a3d3-44ec-a62d-b3360a302645", + "showInput": true }, + "outputs": [], "source": [ "my_dataclass_instance = MyDataclass(a=18)\n", "assert my_dataclass_instance.d == 16" - ], - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { + "customInput": null, "originalKey": "b67ccb9f-dc6c-4994-9b99-b5a1bcfebd70", - "showInput": false, - "customInput": null + "showInput": false }, "source": [ "👷 Note that the `dataclass` decorator here is function which modifies the definition of the class itself.\n", "It runs immediately after the definition.\n", "Our config system requires that implicitron library code contains classes whose modified versions need to be aware of user-defined implementations.\n", - "Therefore we need the modification of the class to be delayed. We don't use a decorator.\n", - "" - ], - "attachments": {} + "Therefore we need the modification of the class to be delayed. We don't use a decorator.\n" + ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { + "customInput": null, "originalKey": "3e90f664-99df-4387-9c45-a1ad7939ef3a", - "showInput": false, - "customInput": null + "showInput": false }, "source": [ "## 2. Introducing omegaconf and OmegaConf.structured\n", "\n", "The [omegaconf](https://github.com/omry/omegaconf/) library provides a DictConfig class which is like a `dict` with str keys, but with extra features for ease-of-use as a configuration system." - ], - "attachments": {} + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "81c73c9b-27ee-4aab-b55e-fb0dd67fe174", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "81c73c9b-27ee-4aab-b55e-fb0dd67fe174", + "customInput": null, "customOutput": null, "executionStartTime": 1659451341683, - "executionStopTime": 1659451341690 + "executionStopTime": 1659451341690, + "originalKey": "81c73c9b-27ee-4aab-b55e-fb0dd67fe174", + "requestMsgId": "81c73c9b-27ee-4aab-b55e-fb0dd67fe174", + "showInput": true }, + "outputs": [], "source": [ "dc = DictConfig({\"a\": 2, \"b\": True, \"c\": None, \"d\": \"hello\"})\n", "assert dc.a == dc[\"a\"] == 2" - ], - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { + "customInput": null, "originalKey": "3b5b76a9-4b76-4784-96ff-2a1212e48e48", - "showInput": false, - "customInput": null + "showInput": false }, "source": [ "OmegaConf has a serialization to and from yaml. The [Hydra](https://hydra.cc/) library relies on this for its configuration files." - ], - "attachments": {} + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "d7a25ec1-caea-46bc-a1da-4b1f040c4b61", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "d7a25ec1-caea-46bc-a1da-4b1f040c4b61", + "customInput": null, "customOutput": null, "executionStartTime": 1659451411835, - "executionStopTime": 1659451411936 + "executionStopTime": 1659451411936, + "originalKey": "d7a25ec1-caea-46bc-a1da-4b1f040c4b61", + "requestMsgId": "d7a25ec1-caea-46bc-a1da-4b1f040c4b61", + "showInput": true }, + "outputs": [], "source": [ "print(OmegaConf.to_yaml(dc))\n", "assert OmegaConf.create(OmegaConf.to_yaml(dc)) == dc" - ], - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { + "customInput": null, "originalKey": "777fecdd-8bf6-4fd8-827b-cb8af5477fa8", - "showInput": false, - "customInput": null + "showInput": false }, "source": [ "OmegaConf.structured provides a DictConfig from a dataclass or instance of a dataclass. Unlike a normal DictConfig, it is type-checked and only known keys can be added." - ], - "attachments": {} + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "de36efb4-0b08-4fb8-bb3a-be1b2c0cd162", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "de36efb4-0b08-4fb8-bb3a-be1b2c0cd162", + "customInput": null, "customOutput": null, "executionStartTime": 1659455098879, - "executionStopTime": 1659455098900 + "executionStopTime": 1659455098900, + "originalKey": "de36efb4-0b08-4fb8-bb3a-be1b2c0cd162", + "requestMsgId": "de36efb4-0b08-4fb8-bb3a-be1b2c0cd162", + "showInput": true }, + "outputs": [], "source": [ "structured = OmegaConf.structured(MyDataclass)\n", "assert isinstance(structured, DictConfig)\n", "print(structured)\n", "print()\n", "print(OmegaConf.to_yaml(structured))" - ], - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { + "customInput": null, "originalKey": "be4446da-e536-4139-9ba3-37669a5b5e61", - "showInput": false, - "customInput": null + "showInput": false }, "source": [ "`structured` knows it is missing a value for `a`." - ], - "attachments": {} + ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { + "customInput": null, "originalKey": "864811e8-1a75-4932-a85e-f681b0541ae9", - "showInput": false, - "customInput": null + "showInput": false }, "source": [ "Such an object has members compatible with the dataclass, so an initialisation can be performed as follows." - ], - "attachments": {} + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "eb88aaa0-c22f-4ffb-813a-ca957b490acb", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "eb88aaa0-c22f-4ffb-813a-ca957b490acb", + "customInput": null, "customOutput": null, "executionStartTime": 1659455580491, - "executionStopTime": 1659455580501 + "executionStopTime": 1659455580501, + "originalKey": "eb88aaa0-c22f-4ffb-813a-ca957b490acb", + "requestMsgId": "eb88aaa0-c22f-4ffb-813a-ca957b490acb", + "showInput": true }, + "outputs": [], "source": [ "structured.a = 21\n", "my_dataclass_instance2 = MyDataclass(**structured)\n", "print(my_dataclass_instance2)" - ], - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { + "customInput": null, "originalKey": "2d08c81c-9d18-4de9-8464-0da2d89f94f3", - "showInput": false, - "customInput": null + "showInput": false }, "source": [ "You can also call OmegaConf.structured on an instance." - ], - "attachments": {} + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "5e469bac-32a4-475d-9c09-8b64ba3f2155", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "5e469bac-32a4-475d-9c09-8b64ba3f2155", + "customInput": null, "customOutput": null, "executionStartTime": 1659455594700, - "executionStopTime": 1659455594737 + "executionStopTime": 1659455594737, + "originalKey": "5e469bac-32a4-475d-9c09-8b64ba3f2155", + "requestMsgId": "5e469bac-32a4-475d-9c09-8b64ba3f2155", + "showInput": true }, + "outputs": [], "source": [ "structured_from_instance = OmegaConf.structured(my_dataclass_instance)\n", "my_dataclass_instance3 = MyDataclass(**structured_from_instance)\n", "print(my_dataclass_instance3)" - ], - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { - "originalKey": "2ed559e3-8552-465a-938f-30c72a321184", - "showInput": false, - "customInput": null, "collapsed": false, - "requestMsgId": "2ed559e3-8552-465a-938f-30c72a321184", + "customInput": null, "customOutput": null, "executionStartTime": 1659452594203, - "executionStopTime": 1659452594333 + "executionStopTime": 1659452594333, + "originalKey": "2ed559e3-8552-465a-938f-30c72a321184", + "requestMsgId": "2ed559e3-8552-465a-938f-30c72a321184", + "showInput": false }, "source": [ "## 3. Our approach to OmegaConf.structured\n", @@ -461,21 +418,22 @@ "We provide functions which are equivalent to `OmegaConf.structured` but support more features. \n", "To achieve the above using our functions, the following is used.\n", "Note that we indicate configurable classes using a special base class `Configurable`, not a decorator." - ], - "attachments": {} + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "9888afbd-e617-4596-ab7a-fc1073f58656", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "9888afbd-e617-4596-ab7a-fc1073f58656", + "customInput": null, "customOutput": null, "executionStartTime": 1659454053323, - "executionStopTime": 1659454061629 + "executionStopTime": 1659454061629, + "originalKey": "9888afbd-e617-4596-ab7a-fc1073f58656", + "requestMsgId": "9888afbd-e617-4596-ab7a-fc1073f58656", + "showInput": true }, + "outputs": [], "source": [ "class MyConfigurable(Configurable):\n", " a: int\n", @@ -485,105 +443,105 @@ " def __post_init__(self):\n", " print(f\"created with a = {self.a}\")\n", " self.d = 2 * self.b" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "e43155b4-3da5-4df1-a2f5-da1d0369eec9", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "e43155b4-3da5-4df1-a2f5-da1d0369eec9", + "customInput": null, "customOutput": null, "executionStartTime": 1659454784912, - "executionStopTime": 1659454784928 + "executionStopTime": 1659454784928, + "originalKey": "e43155b4-3da5-4df1-a2f5-da1d0369eec9", + "requestMsgId": "e43155b4-3da5-4df1-a2f5-da1d0369eec9", + "showInput": true }, + "outputs": [], "source": [ - "# expand_args_fields must be called on an object before it is instantiated.\n", - "# A warning is raised if this is missed, but it is possible to not notice the warning.\n", - "# It modifies the class like @dataclass\n", + "# The expand_args_fields function modifies the class like @dataclasses.dataclass.\n", + "# If it has not been called on a Configurable object before it has been instantiated, it will\n", + "# be called automatically.\n", "expand_args_fields(MyConfigurable)\n", "my_configurable_instance = MyConfigurable(a=18)\n", "assert my_configurable_instance.d == 16" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "96eaae18-dce4-4ee1-b451-1466fea51b9f", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "96eaae18-dce4-4ee1-b451-1466fea51b9f", + "customInput": null, "customOutput": null, "executionStartTime": 1659460669541, - "executionStopTime": 1659460669566 + "executionStopTime": 1659460669566, + "originalKey": "96eaae18-dce4-4ee1-b451-1466fea51b9f", + "requestMsgId": "96eaae18-dce4-4ee1-b451-1466fea51b9f", + "showInput": true }, + "outputs": [], "source": [ - "# get_default_args calls expand_args_fields automatically\n", + "# get_default_args also calls expand_args_fields automatically\n", "our_structured = get_default_args(MyConfigurable)\n", "assert isinstance(our_structured, DictConfig)\n", "print(OmegaConf.to_yaml(our_structured))" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "359f7925-68de-42cd-bd34-79a099b1c210", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "359f7925-68de-42cd-bd34-79a099b1c210", + "customInput": null, "customOutput": null, "executionStartTime": 1659460454020, - "executionStopTime": 1659460454032 + "executionStopTime": 1659460454032, + "originalKey": "359f7925-68de-42cd-bd34-79a099b1c210", + "requestMsgId": "359f7925-68de-42cd-bd34-79a099b1c210", + "showInput": true }, + "outputs": [], "source": [ "our_structured.a = 21\n", "print(MyConfigurable(**our_structured))" - ], - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { - "originalKey": "eac7d385-9365-4098-acf9-4f0a0dbdcb85", - "showInput": false, - "customInput": null, "collapsed": false, - "requestMsgId": "eac7d385-9365-4098-acf9-4f0a0dbdcb85", + "customInput": null, "customOutput": null, "executionStartTime": 1659460599142, - "executionStopTime": 1659460599149 + "executionStopTime": 1659460599149, + "originalKey": "eac7d385-9365-4098-acf9-4f0a0dbdcb85", + "requestMsgId": "eac7d385-9365-4098-acf9-4f0a0dbdcb85", + "showInput": false }, "source": [ "## 4. First enhancement: nested types 🪺\n", "\n", "Our system allows Configurable classes to contain each other. \n", "One thing to remember: add a call to `run_auto_creation` in `__post_init__`." - ], - "attachments": {} + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "9bd70ee5-4ec1-4021-bce5-9638b5088c0a", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "9bd70ee5-4ec1-4021-bce5-9638b5088c0a", + "customInput": null, "customOutput": null, "executionStartTime": 1659465752418, - "executionStopTime": 1659465752976 + "executionStopTime": 1659465752976, + "originalKey": "9bd70ee5-4ec1-4021-bce5-9638b5088c0a", + "requestMsgId": "9bd70ee5-4ec1-4021-bce5-9638b5088c0a", + "showInput": true }, + "outputs": [], "source": [ "class Inner(Configurable):\n", " a: int = 8\n", @@ -598,77 +556,76 @@ "\n", " def __post_init__(self):\n", " run_auto_creation(self)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "9f2b9f98-b54b-46cc-9b02-9e902cb279e7", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "9f2b9f98-b54b-46cc-9b02-9e902cb279e7", + "customInput": null, "customOutput": null, "executionStartTime": 1659465762326, - "executionStopTime": 1659465762339 + "executionStopTime": 1659465762339, + "originalKey": "9f2b9f98-b54b-46cc-9b02-9e902cb279e7", + "requestMsgId": "9f2b9f98-b54b-46cc-9b02-9e902cb279e7", + "showInput": true }, + "outputs": [], "source": [ "outer_dc = get_default_args(Outer)\n", "print(OmegaConf.to_yaml(outer_dc))" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "0254204b-8c7a-4d40-bba6-5132185f63d7", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "0254204b-8c7a-4d40-bba6-5132185f63d7", + "customInput": null, "customOutput": null, "executionStartTime": 1659465772894, - "executionStopTime": 1659465772911 + "executionStopTime": 1659465772911, + "originalKey": "0254204b-8c7a-4d40-bba6-5132185f63d7", + "requestMsgId": "0254204b-8c7a-4d40-bba6-5132185f63d7", + "showInput": true }, + "outputs": [], "source": [ "outer = Outer(**outer_dc)\n", "assert isinstance(outer, Outer)\n", "assert isinstance(outer.inner, Inner)\n", "print(vars(outer))\n", "print(outer.inner)" - ], - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { + "customInput": null, "originalKey": "44a78c13-ec92-4a87-808a-c4674b320c22", - "showInput": false, - "customInput": null + "showInput": false }, "source": [ "Note how inner_args is an extra member of outer. `run_auto_creation(self)` is equivalent to\n", "```\n", " self.inner = Inner(**self.inner_args)\n", "```" - ], - "attachments": {} + ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { - "originalKey": "af0ec78b-7888-4b0d-9346-63d970d43293", - "showInput": false, - "customInput": null, "collapsed": false, - "requestMsgId": "af0ec78b-7888-4b0d-9346-63d970d43293", + "customInput": null, "customOutput": null, "executionStartTime": 1659461071129, - "executionStopTime": 1659461071137 + "executionStopTime": 1659461071137, + "originalKey": "af0ec78b-7888-4b0d-9346-63d970d43293", + "requestMsgId": "af0ec78b-7888-4b0d-9346-63d970d43293", + "showInput": false }, "source": [ "## 5. Second enhancement: pluggable/replaceable components 🔌\n", @@ -681,21 +638,22 @@ "\n", "A configurable class (i.e. a class which uses our system, i.e. a child of `Configurable` or `ReplaceableBase`) which contains a ReplaceableBase must also \n", "contain a corresponding class_type field of type `str` which indicates which concrete child class to use." - ], - "attachments": {} + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "f2898703-d147-4394-978e-fc7f1f559395", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "f2898703-d147-4394-978e-fc7f1f559395", + "customInput": null, "customOutput": null, "executionStartTime": 1659463453457, - "executionStopTime": 1659463453467 + "executionStopTime": 1659463453467, + "originalKey": "f2898703-d147-4394-978e-fc7f1f559395", + "requestMsgId": "f2898703-d147-4394-978e-fc7f1f559395", + "showInput": true }, + "outputs": [], "source": [ "class InnerBase(ReplaceableBase):\n", " def say_something(self):\n", @@ -717,22 +675,22 @@ "\n", " def say_something(self):\n", " print(\"hello from an Inner2\")" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "6f171599-51ee-440f-82d7-a59f84d24624", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "6f171599-51ee-440f-82d7-a59f84d24624", + "customInput": null, "customOutput": null, "executionStartTime": 1659463453514, - "executionStopTime": 1659463453592 + "executionStopTime": 1659463453592, + "originalKey": "6f171599-51ee-440f-82d7-a59f84d24624", + "requestMsgId": "6f171599-51ee-440f-82d7-a59f84d24624", + "showInput": true }, + "outputs": [], "source": [ "class Out(Configurable):\n", " inner: InnerBase\n", @@ -744,128 +702,128 @@ "\n", " def talk(self):\n", " self.inner.say_something()" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "7abaecec-96e6-44df-8c8d-69c36a14b913", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "7abaecec-96e6-44df-8c8d-69c36a14b913", + "customInput": null, "customOutput": null, "executionStartTime": 1659463191360, - "executionStopTime": 1659463191428 + "executionStopTime": 1659463191428, + "originalKey": "7abaecec-96e6-44df-8c8d-69c36a14b913", + "requestMsgId": "7abaecec-96e6-44df-8c8d-69c36a14b913", + "showInput": true }, + "outputs": [], "source": [ "Out_dc = get_default_args(Out)\n", "print(OmegaConf.to_yaml(Out_dc))" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "c82dc2ca-ba8f-4a44-aed3-43f6b52ec28c", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "c82dc2ca-ba8f-4a44-aed3-43f6b52ec28c", + "customInput": null, "customOutput": null, "executionStartTime": 1659463192717, - "executionStopTime": 1659463192754 + "executionStopTime": 1659463192754, + "originalKey": "c82dc2ca-ba8f-4a44-aed3-43f6b52ec28c", + "requestMsgId": "c82dc2ca-ba8f-4a44-aed3-43f6b52ec28c", + "showInput": true }, + "outputs": [], "source": [ "Out_dc.inner_class_type = \"Inner2\"\n", "out = Out(**Out_dc)\n", "print(out.inner)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "aa0e1b04-963a-4724-81b7-5748b598b541", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "aa0e1b04-963a-4724-81b7-5748b598b541", + "customInput": null, "customOutput": null, "executionStartTime": 1659463193751, - "executionStopTime": 1659463193791 + "executionStopTime": 1659463193791, + "originalKey": "aa0e1b04-963a-4724-81b7-5748b598b541", + "requestMsgId": "aa0e1b04-963a-4724-81b7-5748b598b541", + "showInput": true }, + "outputs": [], "source": [ "out.talk()" - ], - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { + "customInput": null, "originalKey": "4f78a56c-39cd-4563-a97e-041e5f360f6b", - "showInput": false, - "customInput": null + "showInput": false }, "source": [ "Note in this case there are many `args` members. It is usually fine to ignore them in the code. They are needed for the config." - ], - "attachments": {} + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "ce7069d5-a813-4286-a7cd-6ff40362105a", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "ce7069d5-a813-4286-a7cd-6ff40362105a", + "customInput": null, "customOutput": null, "executionStartTime": 1659462145294, - "executionStopTime": 1659462145307 + "executionStopTime": 1659462145307, + "originalKey": "ce7069d5-a813-4286-a7cd-6ff40362105a", + "requestMsgId": "ce7069d5-a813-4286-a7cd-6ff40362105a", + "showInput": true }, + "outputs": [], "source": [ "print(vars(out))" - ], - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { - "originalKey": "c7f051ff-c264-4b89-80dc-36cf179aafaf", - "showInput": false, - "customInput": null, "collapsed": false, - "requestMsgId": "c7f051ff-c264-4b89-80dc-36cf179aafaf", + "customInput": null, "customOutput": null, "executionStartTime": 1659462231114, - "executionStopTime": 1659462231130 + "executionStopTime": 1659462231130, + "originalKey": "c7f051ff-c264-4b89-80dc-36cf179aafaf", + "requestMsgId": "c7f051ff-c264-4b89-80dc-36cf179aafaf", + "showInput": false }, "source": [ "## 6. Example with torch.nn.Module 🔥\n", "Typically in implicitron, we use this system in combination with [`Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html)s. \n", "Note in this case it is necessary to call `Module.__init__` explicitly in `__post_init__`." - ], - "attachments": {} + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "42d210d6-09e0-4daf-8ccb-411d30f268f4", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "42d210d6-09e0-4daf-8ccb-411d30f268f4", + "customInput": null, "customOutput": null, "executionStartTime": 1659462645018, - "executionStopTime": 1659462645037 + "executionStopTime": 1659462645037, + "originalKey": "42d210d6-09e0-4daf-8ccb-411d30f268f4", + "requestMsgId": "42d210d6-09e0-4daf-8ccb-411d30f268f4", + "showInput": true }, + "outputs": [], "source": [ "class MyLinear(torch.nn.Module, Configurable):\n", " d_in: int = 2\n", @@ -877,101 +835,100 @@ "\n", " def forward(self, x):\n", " return self.linear.forward(x)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "546781fe-5b95-4e48-9cb5-34a634a31313", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "546781fe-5b95-4e48-9cb5-34a634a31313", + "customInput": null, "customOutput": null, "executionStartTime": 1659462692309, - "executionStopTime": 1659462692346 + "executionStopTime": 1659462692346, + "originalKey": "546781fe-5b95-4e48-9cb5-34a634a31313", + "requestMsgId": "546781fe-5b95-4e48-9cb5-34a634a31313", + "showInput": true }, + "outputs": [], "source": [ - "expand_args_fields(MyLinear)\n", "my_linear = MyLinear()\n", "input = torch.zeros(2)\n", "output = my_linear(input)\n", "print(\"output shape:\", output.shape)" - ], - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { - "originalKey": "b6cb71e1-1d54-4e89-a422-0a70772c5c03", - "showInput": false, - "customInput": null, "collapsed": false, - "requestMsgId": "b6cb71e1-1d54-4e89-a422-0a70772c5c03", + "customInput": null, "customOutput": null, "executionStartTime": 1659462738302, - "executionStopTime": 1659462738419 + "executionStopTime": 1659462738419, + "originalKey": "b6cb71e1-1d54-4e89-a422-0a70772c5c03", + "requestMsgId": "b6cb71e1-1d54-4e89-a422-0a70772c5c03", + "showInput": false }, "source": [ "`my_linear` has all the usual features of a Module.\n", "E.g. it can be saved and loaded with `torch.save` and `torch.load`.\n", "It has parameters:" - ], - "attachments": {} + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "47e8c53e-2d2c-4b41-8aa3-65aa3ea8a7d3", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "47e8c53e-2d2c-4b41-8aa3-65aa3ea8a7d3", + "customInput": null, "customOutput": null, "executionStartTime": 1659462821485, - "executionStopTime": 1659462821501 + "executionStopTime": 1659462821501, + "originalKey": "47e8c53e-2d2c-4b41-8aa3-65aa3ea8a7d3", + "requestMsgId": "47e8c53e-2d2c-4b41-8aa3-65aa3ea8a7d3", + "showInput": true }, + "outputs": [], "source": [ "for name, value in my_linear.named_parameters():\n", " print(name, value.shape)" - ], - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { - "originalKey": "a01f0ea7-55f2-4af9-8e81-45dddf40f13b", - "showInput": false, - "customInput": null, "collapsed": false, - "requestMsgId": "a01f0ea7-55f2-4af9-8e81-45dddf40f13b", + "customInput": null, "customOutput": null, "executionStartTime": 1659463222379, - "executionStopTime": 1659463222409 + "executionStopTime": 1659463222409, + "originalKey": "a01f0ea7-55f2-4af9-8e81-45dddf40f13b", + "requestMsgId": "a01f0ea7-55f2-4af9-8e81-45dddf40f13b", + "showInput": false }, "source": [ "## 7. Example of implementing your own pluggable component \n", "Let's say I am using a library with `Out` like in section **5** but I want to implement my own child of InnerBase. \n", "All I need to do is register its definition, but I need to do this before expand_args_fields is explicitly or implicitly called on Out." - ], - "attachments": {} + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "d9635511-a52b-43d5-8dae-d5c1a3dd9157", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "d9635511-a52b-43d5-8dae-d5c1a3dd9157", + "customInput": null, "customOutput": null, "executionStartTime": 1659463694644, - "executionStopTime": 1659463694653 + "executionStopTime": 1659463694653, + "originalKey": "d9635511-a52b-43d5-8dae-d5c1a3dd9157", + "requestMsgId": "d9635511-a52b-43d5-8dae-d5c1a3dd9157", + "showInput": true }, + "outputs": [], "source": [ "@registry.register\n", "class UserImplementedInner(InnerBase):\n", @@ -979,16 +936,15 @@ "\n", " def say_something(self):\n", " print(\"hello from the user\")" - ], - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { + "customInput": null, "originalKey": "f1511aa2-56b8-4ed0-a453-17e2bbfeefe7", - "showInput": false, - "customInput": null + "showInput": false }, "source": [ "At this point, we need to redefine the class Out. \n", @@ -997,21 +953,22 @@ "\n", "If you are running experiments from a script, the thing to remember here is that you must import your own modules, which register your own implementations,\n", "before you *use* the library classes." - ], - "attachments": {} + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "c7bb5a6e-682b-4eb0-a214-e0f5990b9406", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "c7bb5a6e-682b-4eb0-a214-e0f5990b9406", + "customInput": null, "customOutput": null, "executionStartTime": 1659463745967, - "executionStopTime": 1659463745986 + "executionStopTime": 1659463745986, + "originalKey": "c7bb5a6e-682b-4eb0-a214-e0f5990b9406", + "requestMsgId": "c7bb5a6e-682b-4eb0-a214-e0f5990b9406", + "showInput": true }, + "outputs": [], "source": [ "class Out(Configurable):\n", " inner: InnerBase\n", @@ -1023,61 +980,60 @@ "\n", " def talk(self):\n", " self.inner.say_something()" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "b6ecdc86-4b7b-47c6-9f45-a7e557c94979", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "b6ecdc86-4b7b-47c6-9f45-a7e557c94979", + "customInput": null, "customOutput": null, "executionStartTime": 1659463747398, - "executionStopTime": 1659463747431 + "executionStopTime": 1659463747431, + "originalKey": "b6ecdc86-4b7b-47c6-9f45-a7e557c94979", + "requestMsgId": "b6ecdc86-4b7b-47c6-9f45-a7e557c94979", + "showInput": true }, + "outputs": [], "source": [ - "expand_args_fields(Out)\n", "out2 = Out(inner_class_type=\"UserImplementedInner\")\n", "print(out2.inner)" - ], - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { - "originalKey": "c7fe0df3-da13-40b8-9b06-6b1f37f37bb9", - "showInput": false, - "customInput": null, "collapsed": false, - "requestMsgId": "c7fe0df3-da13-40b8-9b06-6b1f37f37bb9", + "customInput": null, "customOutput": null, "executionStartTime": 1659464033633, - "executionStopTime": 1659464033643 + "executionStopTime": 1659464033643, + "originalKey": "c7fe0df3-da13-40b8-9b06-6b1f37f37bb9", + "requestMsgId": "c7fe0df3-da13-40b8-9b06-6b1f37f37bb9", + "showInput": false }, "source": [ "## 8: Example of making a subcomponent pluggable\n", "\n", "Let's look what needs to happen if we have a subcomponent which we make pluggable, to allow users to supply their own." - ], - "attachments": {} + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "e37227b2-6897-4033-8560-9f2040abdeeb", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "e37227b2-6897-4033-8560-9f2040abdeeb", + "customInput": null, "customOutput": null, "executionStartTime": 1659464709922, - "executionStopTime": 1659464709933 + "executionStopTime": 1659464709933, + "originalKey": "e37227b2-6897-4033-8560-9f2040abdeeb", + "requestMsgId": "e37227b2-6897-4033-8560-9f2040abdeeb", + "showInput": true }, + "outputs": [], "source": [ "class SubComponent(Configurable):\n", " x: float = 0.25\n", @@ -1097,55 +1053,54 @@ " for _ in range(self.repeats):\n", " a = self.subcomponent.apply(a)\n", " return a" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "cab4c121-350e-443f-9a49-bd542a9735a2", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "cab4c121-350e-443f-9a49-bd542a9735a2", + "customInput": null, "customOutput": null, "executionStartTime": 1659464710339, - "executionStopTime": 1659464710459 + "executionStopTime": 1659464710459, + "originalKey": "cab4c121-350e-443f-9a49-bd542a9735a2", + "requestMsgId": "cab4c121-350e-443f-9a49-bd542a9735a2", + "showInput": true }, + "outputs": [], "source": [ - "expand_args_fields(LargeComponent)\n", "large_component = LargeComponent()\n", "assert large_component.apply(3) == 4\n", "print(OmegaConf.to_yaml(LargeComponent))" - ], - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { + "customInput": null, "originalKey": "be60323a-badf-46e4-a259-72cae1391028", - "showInput": false, - "customInput": null + "showInput": false }, "source": [ "Made generic:" - ], - "attachments": {} + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "fc0d8cdb-4627-4427-b92a-17ac1c1b37b8", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "fc0d8cdb-4627-4427-b92a-17ac1c1b37b8", + "customInput": null, "customOutput": null, "executionStartTime": 1659464717226, - "executionStopTime": 1659464717261 + "executionStopTime": 1659464717261, + "originalKey": "fc0d8cdb-4627-4427-b92a-17ac1c1b37b8", + "requestMsgId": "fc0d8cdb-4627-4427-b92a-17ac1c1b37b8", + "showInput": true }, + "outputs": [], "source": [ "class SubComponentBase(ReplaceableBase):\n", " def apply(self, a: float) -> float:\n", @@ -1172,42 +1127,40 @@ " for _ in range(self.repeats):\n", " a = self.subcomponent.apply(a)\n", " return a" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "originalKey": "bbc3d321-6b49-4356-be75-1a173b1fc3a5", - "showInput": true, - "customInput": null, "collapsed": false, - "requestMsgId": "bbc3d321-6b49-4356-be75-1a173b1fc3a5", + "customInput": null, "customOutput": null, "executionStartTime": 1659464725473, - "executionStopTime": 1659464725587 + "executionStopTime": 1659464725587, + "originalKey": "bbc3d321-6b49-4356-be75-1a173b1fc3a5", + "requestMsgId": "bbc3d321-6b49-4356-be75-1a173b1fc3a5", + "showInput": true }, + "outputs": [], "source": [ - "expand_args_fields(LargeComponent)\n", "large_component = LargeComponent()\n", "assert large_component.apply(3) == 4\n", "print(OmegaConf.to_yaml(LargeComponent))" - ], - "execution_count": null, - "outputs": [] + ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { - "originalKey": "5115453a-1d96-4022-97e7-46433e6dcf60", - "showInput": false, - "customInput": null, "collapsed": false, - "requestMsgId": "5115453a-1d96-4022-97e7-46433e6dcf60", + "customInput": null, "customOutput": null, "executionStartTime": 1659464672680, - "executionStopTime": 1659464673231 + "executionStopTime": 1659464673231, + "originalKey": "5115453a-1d96-4022-97e7-46433e6dcf60", + "requestMsgId": "5115453a-1d96-4022-97e7-46433e6dcf60", + "showInput": false }, "source": [ "The following things had to change:\n", @@ -1215,26 +1168,25 @@ "* SubComponent gained a `@registry.register` decoration and had its base class changed to the new one.\n", "* `subcomponent_class_type` was added as a member of the outer class.\n", "* In any saved configuration yaml files, the key `subcomponent_args` had to be changed to `subcomponent_SubComponent_args`." - ], - "attachments": {} + ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": { - "originalKey": "0739269e-5c0e-4551-b06f-f4aab386ba54", - "showInput": false, - "customInput": null, "collapsed": false, - "requestMsgId": "0739269e-5c0e-4551-b06f-f4aab386ba54", + "customInput": null, "customOutput": null, "executionStartTime": 1659462041307, - "executionStopTime": 1659462041637 + "executionStopTime": 1659462041637, + "originalKey": "0739269e-5c0e-4551-b06f-f4aab386ba54", + "requestMsgId": "0739269e-5c0e-4551-b06f-f4aab386ba54", + "showInput": false }, "source": [ "## Appendix: gotchas ⚠️\n", "\n", "* Omitting to define `__post_init__` or not calling `run_auto_creation` in it.\n", - "* Using a configurable class without calling get_default_args or expand_args_fields on it.\n", "* Omitting a type annotation on a field. For example, writing \n", "```\n", " subcomponent_class_type = \"SubComponent\"\n", @@ -1243,10 +1195,50 @@ "```\n", " subcomponent_class_type: str = \"SubComponent\"\n", "```\n", - "\n", - "" - ], - "attachments": {} + "\n" + ] } - ] + ], + "metadata": { + "bento_stylesheets": { + "bento/extensions/flow/main.css": true, + "bento/extensions/kernel_selector/main.css": true, + "bento/extensions/kernel_ui/main.css": true, + "bento/extensions/new_kernel/main.css": true, + "bento/extensions/system_usage/main.css": true, + "bento/extensions/theme/main.css": true + }, + "captumWidgetMessage": {}, + "dataExplorerConfig": {}, + "kernelspec": { + "display_name": "pytorch3d", + "language": "python", + "metadata": { + "cinder_runtime": false, + "fbpkg_supported": true, + "is_prebuilt": true, + "kernel_name": "bento_kernel_pytorch3d", + "nightly_builds": true + }, + "name": "bento_kernel_pytorch3d" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + }, + "last_base_url": "https://9177.od.fbinfra.net:443/", + "last_kernel_id": "90755407-3729-46f4-ab67-ff2cb1daa5cb", + "last_msg_id": "f61034eb-826226915ad9548ffbe495ba_6317", + "last_server_session_id": "d6b46f14-cee7-44c1-8c51-39a38a4ea4c2", + "outputWidgetContext": {} + }, + "nbformat": 4, + "nbformat_minor": 2 } diff --git a/docs/tutorials/implicitron_volumes.ipynb b/docs/tutorials/implicitron_volumes.ipynb index 69a364d4..605edae6 100644 --- a/docs/tutorials/implicitron_volumes.ipynb +++ b/docs/tutorials/implicitron_volumes.ipynb @@ -147,7 +147,7 @@ "from pytorch3d.implicitron.models.generic_model import GenericModel\n", "from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase\n", "from pytorch3d.implicitron.models.renderer.base import EvaluationMode\n", - "from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args, registry, remove_unused_components\n", + "from pytorch3d.implicitron.tools.config import get_default_args, registry, remove_unused_components\n", "from pytorch3d.renderer import RayBundle\n", "from pytorch3d.renderer.implicit.renderer import VolumeSampler\n", "from pytorch3d.structures import Volumes\n", @@ -245,17 +245,6 @@ "!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow_texture.png" ] }, - { - "cell_type": "markdown", - "metadata": { - "customInput": null, - "originalKey": "2a976be8-01bf-4a1c-a6e7-61d5d08c3dbd", - "showInput": false - }, - "source": [ - "If we want to instantiate one of Implicitron's configurable objects, such as `RenderedMeshDatasetMapProvider`, without using the OmegaConf initialisation (get_default_args), we need to call `expand_args_fields` on the class first." - ] - }, { "cell_type": "code", "execution_count": null, @@ -272,7 +261,6 @@ }, "outputs": [], "source": [ - "expand_args_fields(RenderedMeshDatasetMapProvider)\n", "cow_provider = RenderedMeshDatasetMapProvider(\n", " data_file=\"data/cow_mesh/cow.obj\",\n", " use_point_light=False,\n", @@ -468,7 +456,6 @@ " gm = GenericModel(**cfg)\n", "else:\n", " # constructing GenericModel directly\n", - " expand_args_fields(GenericModel)\n", " gm = GenericModel(\n", " implicit_function_class_type=\"MyVolumes\",\n", " render_image_height=output_resolution,\n", diff --git a/pytorch3d/implicitron/tools/config.py b/pytorch3d/implicitron/tools/config.py index 50193662..c777ea74 100644 --- a/pytorch3d/implicitron/tools/config.py +++ b/pytorch3d/implicitron/tools/config.py @@ -167,12 +167,6 @@ thing as the default for a member of another configured class, """ -_unprocessed_warning: str = ( - " must be processed before it can be used." - + " This is done by calling expand_args_fields " - + "or get_default_args on it." -) - TYPE_SUFFIX: str = "_class_type" ARGS_SUFFIX: str = "_args" ENABLED_SUFFIX: str = "_enabled" @@ -183,39 +177,42 @@ TWEAK_SUFFIX: str = "_tweak_args" class ReplaceableBase: """ - Base class for dataclass-style classes which - can be stored in the registry. + Base class for a class (a "replaceable") which is a base class for + dataclass-style implementations. The implementations can be stored + in the registry. They get expanded into dataclasses with expand_args_fields. + This expansion is delayed. """ def __new__(cls, *args, **kwargs): """ - This function only exists to raise a - warning if class construction is attempted - without processing. + These classes should be expanded only when needed (because processing + fixes the list of replaceable subclasses of members of the class). It + is safer if users expand the classes explicitly. But if the class gets + instantiated when it hasn't been processed, we expand it here. """ obj = super().__new__(cls) if cls is not ReplaceableBase and not _is_actually_dataclass(cls): - warnings.warn(cls.__name__ + _unprocessed_warning) + expand_args_fields(cls) return obj class Configurable: """ - This indicates a class which is not ReplaceableBase - but still needs to be + Base class for dataclass-style classes which are not replaceable. These get expanded into a dataclass with expand_args_fields. This expansion is delayed. """ def __new__(cls, *args, **kwargs): """ - This function only exists to raise a - warning if class construction is attempted - without processing. + These classes should be expanded only when needed (because processing + fixes the list of replaceable subclasses of members of the class). It + is safer if users expand the classes explicitly. But if the class gets + instantiated when it hasn't been processed, we expand it here. """ obj = super().__new__(cls) if cls is not Configurable and not _is_actually_dataclass(cls): - warnings.warn(cls.__name__ + _unprocessed_warning) + expand_args_fields(cls) return obj diff --git a/tests/implicitron/test_config.py b/tests/implicitron/test_config.py index 374677cf..ed1e0696 100644 --- a/tests/implicitron/test_config.py +++ b/tests/implicitron/test_config.py @@ -315,6 +315,9 @@ class TestConfig(unittest.TestCase): ] self.assertEqual(sorted(large_args.keys()), needed_args) + with self.assertRaisesRegex(ValueError, "NotAFruit has not been registered."): + LargeFruitBowl(extra_fruit_class_type="NotAFruit") + def test_inheritance2(self): # This is a case where a class could contain an instance # of a subclass, which is ignored. @@ -564,16 +567,19 @@ class TestConfig(unittest.TestCase): def test_unprocessed(self): # behavior of Configurable classes which need processing in __new__, - class Unprocessed(Configurable): + class UnprocessedConfigurable(Configurable): a: int = 9 class UnprocessedReplaceable(ReplaceableBase): - a: int = 1 + a: int = 9 - with self.assertWarnsRegex(UserWarning, "must be processed"): - Unprocessed() - with self.assertWarnsRegex(UserWarning, "must be processed"): - UnprocessedReplaceable() + for Unprocessed in [UnprocessedConfigurable, UnprocessedReplaceable]: + + self.assertFalse(_is_actually_dataclass(Unprocessed)) + unprocessed = Unprocessed() + self.assertTrue(_is_actually_dataclass(Unprocessed)) + self.assertTrue(isinstance(unprocessed, Unprocessed)) + self.assertEqual(unprocessed.a, 9) def test_enum(self): # Test that enum values are kept, i.e. that OmegaConf's runtime checks