From b28754f8e1e0d5852a529e737d628fb2e7e2bd96 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Thu, 4 Aug 2022 13:24:15 -0700 Subject: [PATCH] config system tutorial Summary: new Reviewed By: kjchalup Differential Revision: D38425731 fbshipit-source-id: 0fd8f524df6b29ceb8c7c9a674022412c1efc3b5 --- .../tutorials/implicitron_config_system.ipynb | 1252 +++++++++++++++++ 1 file changed, 1252 insertions(+) create mode 100644 docs/tutorials/implicitron_config_system.ipynb diff --git a/docs/tutorials/implicitron_config_system.ipynb b/docs/tutorials/implicitron_config_system.ipynb new file mode 100644 index 00000000..365acf19 --- /dev/null +++ b/docs/tutorials/implicitron_config_system.ipynb @@ -0,0 +1,1252 @@ +{ + "metadata": { + "dataExplorerConfig": {}, + "bento_stylesheets": { + "bento/extensions/flow/main.css": true, + "bento/extensions/kernel_selector/main.css": true, + "bento/extensions/kernel_ui/main.css": true, + "bento/extensions/new_kernel/main.css": true, + "bento/extensions/system_usage/main.css": true, + "bento/extensions/theme/main.css": true + }, + "kernelspec": { + "display_name": "pytorch3d", + "language": "python", + "name": "bento_kernel_pytorch3d", + "metadata": { + "kernel_name": "bento_kernel_pytorch3d", + "nightly_builds": true, + "fbpkg_supported": true, + "cinder_runtime": false, + "is_prebuilt": true + } + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + }, + "last_server_session_id": "d6b46f14-cee7-44c1-8c51-39a38a4ea4c2", + "last_kernel_id": "90755407-3729-46f4-ab67-ff2cb1daa5cb", + "last_base_url": "https://9177.od.fbinfra.net:443/", + "last_msg_id": "f61034eb-826226915ad9548ffbe495ba_6317", + "captumWidgetMessage": {}, + "outputWidgetContext": {} + }, + "nbformat": 4, + "nbformat_minor": 2, + "cells": [ + { + "cell_type": "code", + "metadata": { + "originalKey": "f0af2d90-cb21-4ab4-b4cb-0fd00dbfb77b", + "showInput": true, + "customInput": null, + "customOutput": null + }, + "source": [ + "# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved." + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "4e15bfa2-5404-40d0-98b6-eb2732c8b72b", + "showInput": false, + "customInput": null + }, + "source": [ + "# Implicitron's config system" + ], + "attachments": {} + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "287be985-423d-42e0-a2af-1e8c585e723c", + "showInput": false, + "customInput": null + }, + "source": [ + "Implicitron's components are all based on a unified hierarchical configuration system. \n", + "This allows configurable variables and all defaults to be defined separately for each new component.\n", + "All configs relevant to an experiment are then automatically composed into a single configuration file that fully specifies the experiment.\n", + "An especially important feature is extension points where users can insert their own sub-classes of Implicitron's base components.\n", + "\n", + "The file which defines this system is [here](https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/implicitron/tools/config.py) in the PyTorch3D repo.\n", + "The Implicitron volumes tutorial contains a simple example of using the config system.\n", + "This tutorial provides detailed hands-on experience in using and modifying Implicitron's configurable components.\n", + "" + ], + "attachments": {} + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "fde300a2-99cb-4d52-9d5b-4464a2083e0b", + "showInput": false, + "customInput": null + }, + "source": [ + "## 0. Install and import modules\n", + "\n", + "Ensure `torch` and `torchvision` are installed. If `pytorch3d` is not installed, install it using the following cell:" + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "ad6e94a7-e114-43d3-b038-a5210c7d34c9", + "showInput": true, + "customInput": null, + "customOutput": null + }, + "source": [ + "import os\n", + "import sys\n", + "import torch\n", + "need_pytorch3d=False\n", + "try:\n", + " import pytorch3d\n", + "except ModuleNotFoundError:\n", + " need_pytorch3d=True\n", + "if need_pytorch3d:\n", + " if torch.__version__.startswith(\"1.12.\") and sys.platform.startswith(\"linux\"):\n", + " # We try to install PyTorch3D via a released wheel.\n", + " pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n", + " version_str=\"\".join([\n", + " f\"py3{sys.version_info.minor}_cu\",\n", + " torch.version.cuda.replace(\".\",\"\"),\n", + " f\"_pyt{pyt_version_str}\"\n", + " ])\n", + " !pip install fvcore iopath\n", + " !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n", + " else:\n", + " # We try to install PyTorch3D from source.\n", + " !curl -LO https://github.com/NVIDIA/cub/archive/1.10.0.tar.gz\n", + " !tar xzf 1.10.0.tar.gz\n", + " os.environ[\"CUB_HOME\"] = os.getcwd() + \"/cub-1.10.0\"\n", + " !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "609896c0-9e2e-4716-b074-b565f0170e32", + "showInput": false, + "customInput": null + }, + "source": [ + "Ensure omegaconf is installed. If not, run this cell. (It should not be necessary to restart the runtime.)" + ], + "attachments": {} + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "d1c1851e-b9f2-4236-93c3-19aa4d63041c", + "showInput": true, + "customInput": null, + "customOutput": null + }, + "source": [ + "!pip install omegaconf" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "collapsed": false, + "originalKey": "5ac7ef23-b74c-46b2-b8d3-799524d7ba4f", + "code_folding": [], + "hidden_ranges": [], + "requestMsgId": "5ac7ef23-b74c-46b2-b8d3-799524d7ba4f", + "customOutput": null, + "executionStartTime": 1659465468717, + "executionStopTime": 1659465468738 + }, + "source": [ + "from dataclasses import dataclass\n", + "from typing import Optional, Tuple\n", + "\n", + "import torch\n", + "from omegaconf import DictConfig, OmegaConf\n", + "from pytorch3d.implicitron.tools.config import (\n", + " Configurable,\n", + " ReplaceableBase,\n", + " expand_args_fields,\n", + " get_default_args,\n", + " registry,\n", + " run_auto_creation,\n", + ")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "a638bf90-eb6b-424d-b53d-eae11954a717", + "showInput": false, + "customInput": null + }, + "source": [ + "## 1. Introducing dataclasses \n", + "\n", + "[Type hints](https://docs.python.org/3/library/typing.html) give a taxonomy of types in Python. [Dataclasses](https://docs.python.org/3/library/dataclasses.html) let you create a class based on a list of members which have names, types and possibly default values. The `__init__` function is created automatically, and calls a `__post_init__` function if present as a final step. For example" + ] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "71eaad5e-e198-492e-8610-24b0da9dd4ae", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "71eaad5e-e198-492e-8610-24b0da9dd4ae", + "customOutput": null, + "executionStartTime": 1659454972732, + "executionStopTime": 1659454972739 + }, + "source": [ + "@dataclass\n", + "class MyDataclass:\n", + " a: int\n", + " b: int = 8\n", + " c: Optional[Tuple[int, ...]] = None\n", + "\n", + " def __post_init__(self):\n", + " print(f\"created with a = {self.a}\")\n", + " self.d = 2 * self.b" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "83202a18-a3d3-44ec-a62d-b3360a302645", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "83202a18-a3d3-44ec-a62d-b3360a302645", + "customOutput": null, + "executionStartTime": 1659454973051, + "executionStopTime": 1659454973077 + }, + "source": [ + "my_dataclass_instance = MyDataclass(a=18)\n", + "assert my_dataclass_instance.d == 16" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "b67ccb9f-dc6c-4994-9b99-b5a1bcfebd70", + "showInput": false, + "customInput": null + }, + "source": [ + "👷 Note that the `dataclass` decorator here is function which modifies the definition of the class itself.\n", + "It runs immediately after the definition.\n", + "Our config system requires that implicitron library code contains classes whose modified versions need to be aware of user-defined implementations.\n", + "Therefore we need the modification of the class to be delayed. We don't use a decorator.\n", + "" + ], + "attachments": {} + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "3e90f664-99df-4387-9c45-a1ad7939ef3a", + "showInput": false, + "customInput": null + }, + "source": [ + "## 2. Introducing omegaconf and OmegaConf.structured\n", + "\n", + "The [omegaconf](https://github.com/omry/omegaconf/) library provides a DictConfig class which is like a `dict` with str keys, but with extra features for ease-of-use as a configuration system." + ], + "attachments": {} + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "81c73c9b-27ee-4aab-b55e-fb0dd67fe174", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "81c73c9b-27ee-4aab-b55e-fb0dd67fe174", + "customOutput": null, + "executionStartTime": 1659451341683, + "executionStopTime": 1659451341690 + }, + "source": [ + "dc = DictConfig({\"a\": 2, \"b\": True, \"c\": None, \"d\": \"hello\"})\n", + "assert dc.a == dc[\"a\"] == 2" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "3b5b76a9-4b76-4784-96ff-2a1212e48e48", + "showInput": false, + "customInput": null + }, + "source": [ + "OmegaConf has a serialization to and from yaml. The [Hydra](https://hydra.cc/) library relies on this for its configuration files." + ], + "attachments": {} + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "d7a25ec1-caea-46bc-a1da-4b1f040c4b61", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "d7a25ec1-caea-46bc-a1da-4b1f040c4b61", + "customOutput": null, + "executionStartTime": 1659451411835, + "executionStopTime": 1659451411936 + }, + "source": [ + "print(OmegaConf.to_yaml(dc))\n", + "assert OmegaConf.create(OmegaConf.to_yaml(dc)) == dc" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "777fecdd-8bf6-4fd8-827b-cb8af5477fa8", + "showInput": false, + "customInput": null + }, + "source": [ + "OmegaConf.structured provides a DictConfig from a dataclass or instance of a dataclass. Unlike a normal DictConfig, it is type-checked and only known keys can be added." + ], + "attachments": {} + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "de36efb4-0b08-4fb8-bb3a-be1b2c0cd162", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "de36efb4-0b08-4fb8-bb3a-be1b2c0cd162", + "customOutput": null, + "executionStartTime": 1659455098879, + "executionStopTime": 1659455098900 + }, + "source": [ + "structured = OmegaConf.structured(MyDataclass)\n", + "assert isinstance(structured, DictConfig)\n", + "print(structured)\n", + "print()\n", + "print(OmegaConf.to_yaml(structured))" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "be4446da-e536-4139-9ba3-37669a5b5e61", + "showInput": false, + "customInput": null + }, + "source": [ + "`structured` knows it is missing a value for `a`." + ], + "attachments": {} + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "864811e8-1a75-4932-a85e-f681b0541ae9", + "showInput": false, + "customInput": null + }, + "source": [ + "Such an object has members compatible with the dataclass, so an initialisation can be performed as follows." + ], + "attachments": {} + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "eb88aaa0-c22f-4ffb-813a-ca957b490acb", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "eb88aaa0-c22f-4ffb-813a-ca957b490acb", + "customOutput": null, + "executionStartTime": 1659455580491, + "executionStopTime": 1659455580501 + }, + "source": [ + "structured.a = 21\n", + "my_dataclass_instance2 = MyDataclass(**structured)\n", + "print(my_dataclass_instance2)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "2d08c81c-9d18-4de9-8464-0da2d89f94f3", + "showInput": false, + "customInput": null + }, + "source": [ + "You can also call OmegaConf.structured on an instance." + ], + "attachments": {} + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "5e469bac-32a4-475d-9c09-8b64ba3f2155", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "5e469bac-32a4-475d-9c09-8b64ba3f2155", + "customOutput": null, + "executionStartTime": 1659455594700, + "executionStopTime": 1659455594737 + }, + "source": [ + "structured_from_instance = OmegaConf.structured(my_dataclass_instance)\n", + "my_dataclass_instance3 = MyDataclass(**structured_from_instance)\n", + "print(my_dataclass_instance3)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "2ed559e3-8552-465a-938f-30c72a321184", + "showInput": false, + "customInput": null, + "collapsed": false, + "requestMsgId": "2ed559e3-8552-465a-938f-30c72a321184", + "customOutput": null, + "executionStartTime": 1659452594203, + "executionStopTime": 1659452594333 + }, + "source": [ + "## 3. Our approach to OmegaConf.structured\n", + "\n", + "We provide functions which are equivalent to `OmegaConf.structured` but support more features. \n", + "To achieve the above using our functions, the following is used.\n", + "Note that we indicate configurable classes using a special base class `Configurable`, not a decorator." + ], + "attachments": {} + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "9888afbd-e617-4596-ab7a-fc1073f58656", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "9888afbd-e617-4596-ab7a-fc1073f58656", + "customOutput": null, + "executionStartTime": 1659454053323, + "executionStopTime": 1659454061629 + }, + "source": [ + "class MyConfigurable(Configurable):\n", + " a: int\n", + " b: int = 8\n", + " c: Optional[Tuple[int, ...]] = None\n", + "\n", + " def __post_init__(self):\n", + " print(f\"created with a = {self.a}\")\n", + " self.d = 2 * self.b" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "e43155b4-3da5-4df1-a2f5-da1d0369eec9", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "e43155b4-3da5-4df1-a2f5-da1d0369eec9", + "customOutput": null, + "executionStartTime": 1659454784912, + "executionStopTime": 1659454784928 + }, + "source": [ + "# expand_args_fields must be called on an object before it is instantiated.\n", + "# A warning is raised if this is missed, but it is possible to not notice the warning.\n", + "# It modifies the class like @dataclass\n", + "expand_args_fields(MyConfigurable)\n", + "my_configurable_instance = MyConfigurable(a=18)\n", + "assert my_configurable_instance.d == 16" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "96eaae18-dce4-4ee1-b451-1466fea51b9f", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "96eaae18-dce4-4ee1-b451-1466fea51b9f", + "customOutput": null, + "executionStartTime": 1659460669541, + "executionStopTime": 1659460669566 + }, + "source": [ + "# get_default_args calls expand_args_fields automatically\n", + "our_structured = get_default_args(MyConfigurable)\n", + "assert isinstance(our_structured, DictConfig)\n", + "print(OmegaConf.to_yaml(our_structured))" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "359f7925-68de-42cd-bd34-79a099b1c210", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "359f7925-68de-42cd-bd34-79a099b1c210", + "customOutput": null, + "executionStartTime": 1659460454020, + "executionStopTime": 1659460454032 + }, + "source": [ + "our_structured.a = 21\n", + "print(MyConfigurable(**our_structured))" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "eac7d385-9365-4098-acf9-4f0a0dbdcb85", + "showInput": false, + "customInput": null, + "collapsed": false, + "requestMsgId": "eac7d385-9365-4098-acf9-4f0a0dbdcb85", + "customOutput": null, + "executionStartTime": 1659460599142, + "executionStopTime": 1659460599149 + }, + "source": [ + "## 4. First enhancement: nested types 🪺\n", + "\n", + "Our system allows Configurable classes to contain each other. \n", + "One thing to remember: add a call to `run_auto_creation` in `__post_init__`." + ], + "attachments": {} + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "9bd70ee5-4ec1-4021-bce5-9638b5088c0a", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "9bd70ee5-4ec1-4021-bce5-9638b5088c0a", + "customOutput": null, + "executionStartTime": 1659465752418, + "executionStopTime": 1659465752976 + }, + "source": [ + "class Inner(Configurable):\n", + " a: int = 8\n", + " b: bool = True\n", + " c: Tuple[int, ...] = (2, 3, 4, 6)\n", + "\n", + "\n", + "class Outer(Configurable):\n", + " inner: Inner\n", + " x: str = \"hello\"\n", + " xx: bool = False\n", + "\n", + " def __post_init__(self):\n", + " run_auto_creation(self)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "9f2b9f98-b54b-46cc-9b02-9e902cb279e7", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "9f2b9f98-b54b-46cc-9b02-9e902cb279e7", + "customOutput": null, + "executionStartTime": 1659465762326, + "executionStopTime": 1659465762339 + }, + "source": [ + "outer_dc = get_default_args(Outer)\n", + "print(OmegaConf.to_yaml(outer_dc))" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "0254204b-8c7a-4d40-bba6-5132185f63d7", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "0254204b-8c7a-4d40-bba6-5132185f63d7", + "customOutput": null, + "executionStartTime": 1659465772894, + "executionStopTime": 1659465772911 + }, + "source": [ + "outer = Outer(**outer_dc)\n", + "assert isinstance(outer, Outer)\n", + "assert isinstance(outer.inner, Inner)\n", + "print(vars(outer))\n", + "print(outer.inner)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "44a78c13-ec92-4a87-808a-c4674b320c22", + "showInput": false, + "customInput": null + }, + "source": [ + "Note how inner_args is an extra member of outer. `run_auto_creation(self)` is equivalent to\n", + "```\n", + " self.inner = Inner(**self.inner_args)\n", + "```" + ], + "attachments": {} + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "af0ec78b-7888-4b0d-9346-63d970d43293", + "showInput": false, + "customInput": null, + "collapsed": false, + "requestMsgId": "af0ec78b-7888-4b0d-9346-63d970d43293", + "customOutput": null, + "executionStartTime": 1659461071129, + "executionStopTime": 1659461071137 + }, + "source": [ + "## 5. Second enhancement: pluggable/replaceable components 🔌\n", + "\n", + "If a class uses `ReplaceableBase` as a base class instead of `Configurable`, we call it a replaceable.\n", + "It indicates that it is designed for child classes to use in its place.\n", + "We might use `NotImplementedError` to indicate functionality which subclasses are expected to implement.\n", + "The system maintains a global `registry` containing subclasses of each ReplaceableBase.\n", + "The subclasses register themselves with it with a decorator.\n", + "\n", + "A configurable class (i.e. a class which uses our system, i.e. a child of `Configurable` or `ReplaceableBase`) which contains a ReplaceableBase must also \n", + "contain a corresponding class_type field of type `str` which indicates which concrete child class to use." + ], + "attachments": {} + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "f2898703-d147-4394-978e-fc7f1f559395", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "f2898703-d147-4394-978e-fc7f1f559395", + "customOutput": null, + "executionStartTime": 1659463453457, + "executionStopTime": 1659463453467 + }, + "source": [ + "class InnerBase(ReplaceableBase):\n", + " def say_something(self):\n", + " raise NotImplementedError\n", + "\n", + "\n", + "@registry.register\n", + "class Inner1(InnerBase):\n", + " a: int = 1\n", + " b: str = \"h\"\n", + "\n", + " def say_something(self):\n", + " print(\"hello from an Inner1\")\n", + "\n", + "\n", + "@registry.register\n", + "class Inner2(InnerBase):\n", + " a: int = 2\n", + "\n", + " def say_something(self):\n", + " print(\"hello from an Inner2\")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "6f171599-51ee-440f-82d7-a59f84d24624", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "6f171599-51ee-440f-82d7-a59f84d24624", + "customOutput": null, + "executionStartTime": 1659463453514, + "executionStopTime": 1659463453592 + }, + "source": [ + "class Out(Configurable):\n", + " inner: InnerBase\n", + " inner_class_type: str = \"Inner1\"\n", + " x: int = 19\n", + "\n", + " def __post_init__(self):\n", + " run_auto_creation(self)\n", + "\n", + " def talk(self):\n", + " self.inner.say_something()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "7abaecec-96e6-44df-8c8d-69c36a14b913", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "7abaecec-96e6-44df-8c8d-69c36a14b913", + "customOutput": null, + "executionStartTime": 1659463191360, + "executionStopTime": 1659463191428 + }, + "source": [ + "Out_dc = get_default_args(Out)\n", + "print(OmegaConf.to_yaml(Out_dc))" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "c82dc2ca-ba8f-4a44-aed3-43f6b52ec28c", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "c82dc2ca-ba8f-4a44-aed3-43f6b52ec28c", + "customOutput": null, + "executionStartTime": 1659463192717, + "executionStopTime": 1659463192754 + }, + "source": [ + "Out_dc.inner_class_type = \"Inner2\"\n", + "out = Out(**Out_dc)\n", + "print(out.inner)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "aa0e1b04-963a-4724-81b7-5748b598b541", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "aa0e1b04-963a-4724-81b7-5748b598b541", + "customOutput": null, + "executionStartTime": 1659463193751, + "executionStopTime": 1659463193791 + }, + "source": [ + "out.talk()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "4f78a56c-39cd-4563-a97e-041e5f360f6b", + "showInput": false, + "customInput": null + }, + "source": [ + "Note in this case there are many `args` members. It is usually fine to ignore them in the code. They are needed for the config." + ], + "attachments": {} + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "ce7069d5-a813-4286-a7cd-6ff40362105a", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "ce7069d5-a813-4286-a7cd-6ff40362105a", + "customOutput": null, + "executionStartTime": 1659462145294, + "executionStopTime": 1659462145307 + }, + "source": [ + "print(vars(out))" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "c7f051ff-c264-4b89-80dc-36cf179aafaf", + "showInput": false, + "customInput": null, + "collapsed": false, + "requestMsgId": "c7f051ff-c264-4b89-80dc-36cf179aafaf", + "customOutput": null, + "executionStartTime": 1659462231114, + "executionStopTime": 1659462231130 + }, + "source": [ + "## 6. Example with torch.nn.Module 🔥\n", + "Typically in implicitron, we use this system in combination with [`Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html)s. \n", + "Note in this case it is necessary to call `Module.__init__` explicitly in `__post_init__`." + ], + "attachments": {} + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "42d210d6-09e0-4daf-8ccb-411d30f268f4", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "42d210d6-09e0-4daf-8ccb-411d30f268f4", + "customOutput": null, + "executionStartTime": 1659462645018, + "executionStopTime": 1659462645037 + }, + "source": [ + "class MyLinear(torch.nn.Module, Configurable):\n", + " d_in: int = 2\n", + " d_out: int = 200\n", + "\n", + " def __post_init__(self):\n", + " super().__init__()\n", + " self.linear = torch.nn.Linear(in_features=self.d_in, out_features=self.d_out)\n", + "\n", + " def forward(self, x):\n", + " return self.linear.forward(x)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "546781fe-5b95-4e48-9cb5-34a634a31313", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "546781fe-5b95-4e48-9cb5-34a634a31313", + "customOutput": null, + "executionStartTime": 1659462692309, + "executionStopTime": 1659462692346 + }, + "source": [ + "expand_args_fields(MyLinear)\n", + "my_linear = MyLinear()\n", + "input = torch.zeros(2)\n", + "output = my_linear(input)\n", + "print(\"output shape:\", output.shape)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "b6cb71e1-1d54-4e89-a422-0a70772c5c03", + "showInput": false, + "customInput": null, + "collapsed": false, + "requestMsgId": "b6cb71e1-1d54-4e89-a422-0a70772c5c03", + "customOutput": null, + "executionStartTime": 1659462738302, + "executionStopTime": 1659462738419 + }, + "source": [ + "`my_linear` has all the usual features of a Module.\n", + "E.g. it can be saved and loaded with `torch.save` and `torch.load`.\n", + "It has parameters:" + ], + "attachments": {} + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "47e8c53e-2d2c-4b41-8aa3-65aa3ea8a7d3", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "47e8c53e-2d2c-4b41-8aa3-65aa3ea8a7d3", + "customOutput": null, + "executionStartTime": 1659462821485, + "executionStopTime": 1659462821501 + }, + "source": [ + "for name, value in my_linear.named_parameters():\n", + " print(name, value.shape)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "a01f0ea7-55f2-4af9-8e81-45dddf40f13b", + "showInput": false, + "customInput": null, + "collapsed": false, + "requestMsgId": "a01f0ea7-55f2-4af9-8e81-45dddf40f13b", + "customOutput": null, + "executionStartTime": 1659463222379, + "executionStopTime": 1659463222409 + }, + "source": [ + "## 7. Example of implementing your own pluggable component \n", + "Let's say I am using a library with `Out` like in section **5** but I want to implement my own child of InnerBase. \n", + "All I need to do is register its definition, but I need to do this before expand_args_fields is explicitly or implicitly called on Out." + ], + "attachments": {} + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "d9635511-a52b-43d5-8dae-d5c1a3dd9157", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "d9635511-a52b-43d5-8dae-d5c1a3dd9157", + "customOutput": null, + "executionStartTime": 1659463694644, + "executionStopTime": 1659463694653 + }, + "source": [ + "@registry.register\n", + "class UserImplementedInner(InnerBase):\n", + " a: int = 200\n", + "\n", + " def say_something(self):\n", + " print(\"hello from the user\")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "f1511aa2-56b8-4ed0-a453-17e2bbfeefe7", + "showInput": false, + "customInput": null + }, + "source": [ + "At this point, we need to redefine the class Out. \n", + "Otherwise if it has already been expanded without UserImplementedInner, then the following would not work,\n", + "because the implementations known to a class are fixed when it is expanded.\n", + "\n", + "If you are running experiments from a script, the thing to remember here is that you must import your own modules, which register your own implementations,\n", + "before you *use* the library classes." + ], + "attachments": {} + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "c7bb5a6e-682b-4eb0-a214-e0f5990b9406", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "c7bb5a6e-682b-4eb0-a214-e0f5990b9406", + "customOutput": null, + "executionStartTime": 1659463745967, + "executionStopTime": 1659463745986 + }, + "source": [ + "class Out(Configurable):\n", + " inner: InnerBase\n", + " inner_class_type: str = \"Inner1\"\n", + " x: int = 19\n", + "\n", + " def __post_init__(self):\n", + " run_auto_creation(self)\n", + "\n", + " def talk(self):\n", + " self.inner.say_something()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "b6ecdc86-4b7b-47c6-9f45-a7e557c94979", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "b6ecdc86-4b7b-47c6-9f45-a7e557c94979", + "customOutput": null, + "executionStartTime": 1659463747398, + "executionStopTime": 1659463747431 + }, + "source": [ + "expand_args_fields(Out)\n", + "out2 = Out(inner_class_type=\"UserImplementedInner\")\n", + "print(out2.inner)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "c7fe0df3-da13-40b8-9b06-6b1f37f37bb9", + "showInput": false, + "customInput": null, + "collapsed": false, + "requestMsgId": "c7fe0df3-da13-40b8-9b06-6b1f37f37bb9", + "customOutput": null, + "executionStartTime": 1659464033633, + "executionStopTime": 1659464033643 + }, + "source": [ + "## 8: Example of making a subcomponent pluggable\n", + "\n", + "Let's look what needs to happen if we have a subcomponent which we make pluggable, to allow users to supply their own." + ], + "attachments": {} + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "e37227b2-6897-4033-8560-9f2040abdeeb", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "e37227b2-6897-4033-8560-9f2040abdeeb", + "customOutput": null, + "executionStartTime": 1659464709922, + "executionStopTime": 1659464709933 + }, + "source": [ + "class SubComponent(Configurable):\n", + " x: float = 0.25\n", + "\n", + " def apply(self, a: float) -> float:\n", + " return a + self.x\n", + "\n", + "\n", + "class LargeComponent(Configurable):\n", + " repeats: int = 4\n", + " subcomponent: SubComponent\n", + "\n", + " def __post_init__(self):\n", + " run_auto_creation(self)\n", + "\n", + " def apply(self, a: float) -> float:\n", + " for _ in range(self.repeats):\n", + " a = self.subcomponent.apply(a)\n", + " return a" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "cab4c121-350e-443f-9a49-bd542a9735a2", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "cab4c121-350e-443f-9a49-bd542a9735a2", + "customOutput": null, + "executionStartTime": 1659464710339, + "executionStopTime": 1659464710459 + }, + "source": [ + "expand_args_fields(LargeComponent)\n", + "large_component = LargeComponent()\n", + "assert large_component.apply(3) == 4\n", + "print(OmegaConf.to_yaml(LargeComponent))" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "be60323a-badf-46e4-a259-72cae1391028", + "showInput": false, + "customInput": null + }, + "source": [ + "Made generic:" + ], + "attachments": {} + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "fc0d8cdb-4627-4427-b92a-17ac1c1b37b8", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "fc0d8cdb-4627-4427-b92a-17ac1c1b37b8", + "customOutput": null, + "executionStartTime": 1659464717226, + "executionStopTime": 1659464717261 + }, + "source": [ + "class SubComponentBase(ReplaceableBase):\n", + " def apply(self, a: float) -> float:\n", + " raise NotImplementedError\n", + "\n", + "\n", + "@registry.register\n", + "class SubComponent(SubComponentBase):\n", + " x: float = 0.25\n", + "\n", + " def apply(self, a: float) -> float:\n", + " return a + self.x\n", + "\n", + "\n", + "class LargeComponent(Configurable):\n", + " repeats: int = 4\n", + " subcomponent: SubComponentBase\n", + " subcomponent_class_type: str = \"SubComponent\"\n", + "\n", + " def __post_init__(self):\n", + " run_auto_creation(self)\n", + "\n", + " def apply(self, a: float) -> float:\n", + " for _ in range(self.repeats):\n", + " a = self.subcomponent.apply(a)\n", + " return a" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "bbc3d321-6b49-4356-be75-1a173b1fc3a5", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "bbc3d321-6b49-4356-be75-1a173b1fc3a5", + "customOutput": null, + "executionStartTime": 1659464725473, + "executionStopTime": 1659464725587 + }, + "source": [ + "expand_args_fields(LargeComponent)\n", + "large_component = LargeComponent()\n", + "assert large_component.apply(3) == 4\n", + "print(OmegaConf.to_yaml(LargeComponent))" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "5115453a-1d96-4022-97e7-46433e6dcf60", + "showInput": false, + "customInput": null, + "collapsed": false, + "requestMsgId": "5115453a-1d96-4022-97e7-46433e6dcf60", + "customOutput": null, + "executionStartTime": 1659464672680, + "executionStopTime": 1659464673231 + }, + "source": [ + "The following things had to change:\n", + "* The base class SubComponentBase was defined.\n", + "* SubComponent gained a `@registry.register` decoration and had its base class changed to the new one.\n", + "* `subcomponent_class_type` was added as a member of the outer class.\n", + "* In any saved configuration yaml files, the key `subcomponent_args` had to be changed to `subcomponent_SubComponent_args`." + ], + "attachments": {} + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "0739269e-5c0e-4551-b06f-f4aab386ba54", + "showInput": false, + "customInput": null, + "collapsed": false, + "requestMsgId": "0739269e-5c0e-4551-b06f-f4aab386ba54", + "customOutput": null, + "executionStartTime": 1659462041307, + "executionStopTime": 1659462041637 + }, + "source": [ + "## Appendix: gotchas ⚠️\n", + "\n", + "* Omitting to define `__post_init__` or not calling `run_auto_creation` in it.\n", + "* Using a configurable class without calling get_default_args or expand_args_fields on it.\n", + "* Omitting a type annotation on a field. For example, writing \n", + "```\n", + " subcomponent_class_type = \"SubComponent\"\n", + "```\n", + "instead of \n", + "```\n", + " subcomponent_class_type: str = \"SubComponent\"\n", + "```\n", + "\n", + "" + ], + "attachments": {} + } + ] +}