mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	volumes tutorial
Summary: new Reviewed By: kjchalup Differential Revision: D38425739 fbshipit-source-id: 23e7a1dd01c034b231511acecf40ea0deb9a067f
This commit is contained in:
		
							parent
							
								
									326716781b
								
							
						
					
					
						commit
						89d8eebc5a
					
				
							
								
								
									
										919
									
								
								docs/tutorials/implicitron_volumes.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										919
									
								
								docs/tutorials/implicitron_volumes.ipynb
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,919 @@
 | 
			
		||||
{
 | 
			
		||||
  "metadata": {
 | 
			
		||||
    "dataExplorerConfig": {},
 | 
			
		||||
    "bento_stylesheets": {
 | 
			
		||||
      "bento/extensions/flow/main.css": true,
 | 
			
		||||
      "bento/extensions/kernel_selector/main.css": true,
 | 
			
		||||
      "bento/extensions/kernel_ui/main.css": true,
 | 
			
		||||
      "bento/extensions/new_kernel/main.css": true,
 | 
			
		||||
      "bento/extensions/system_usage/main.css": true,
 | 
			
		||||
      "bento/extensions/theme/main.css": true
 | 
			
		||||
    },
 | 
			
		||||
    "kernelspec": {
 | 
			
		||||
      "display_name": "pytorch3d",
 | 
			
		||||
      "language": "python",
 | 
			
		||||
      "name": "bento_kernel_pytorch3d",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "kernel_name": "bento_kernel_pytorch3d",
 | 
			
		||||
        "nightly_builds": true,
 | 
			
		||||
        "fbpkg_supported": true,
 | 
			
		||||
        "cinder_runtime": false,
 | 
			
		||||
        "is_prebuilt": true
 | 
			
		||||
      }
 | 
			
		||||
    },
 | 
			
		||||
    "language_info": {
 | 
			
		||||
      "codemirror_mode": {
 | 
			
		||||
        "name": "ipython",
 | 
			
		||||
        "version": 3
 | 
			
		||||
      },
 | 
			
		||||
      "file_extension": ".py",
 | 
			
		||||
      "mimetype": "text/x-python",
 | 
			
		||||
      "name": "python",
 | 
			
		||||
      "nbconvert_exporter": "python",
 | 
			
		||||
      "pygments_lexer": "ipython3"
 | 
			
		||||
    },
 | 
			
		||||
    "last_server_session_id": "2944b203-9ea8-4c0e-9634-645dfea5f26b",
 | 
			
		||||
    "last_kernel_id": "bb33cd83-7924-489a-8bd8-2d9d62eb0126",
 | 
			
		||||
    "last_base_url": "https://9177.od.fbinfra.net:443/",
 | 
			
		||||
    "last_msg_id": "99f7088e-d22b355b859660479ef0574e_5743",
 | 
			
		||||
    "captumWidgetMessage": {},
 | 
			
		||||
    "outputWidgetContext": {}
 | 
			
		||||
  },
 | 
			
		||||
  "nbformat": 4,
 | 
			
		||||
  "nbformat_minor": 2,
 | 
			
		||||
  "cells": [
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "code",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "d38652e8-200a-413c-a36a-f4d349b78a9d",
 | 
			
		||||
        "showInput": true,
 | 
			
		||||
        "customInput": null,
 | 
			
		||||
        "collapsed": false,
 | 
			
		||||
        "requestMsgId": "641de8aa-0e42-4446-9304-c160a2d226bf",
 | 
			
		||||
        "customOutput": null,
 | 
			
		||||
        "executionStartTime": 1659619824914,
 | 
			
		||||
        "executionStopTime": 1659619825485
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved."
 | 
			
		||||
      ],
 | 
			
		||||
      "execution_count": null,
 | 
			
		||||
      "outputs": []
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "markdown",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "a48a9dcf-e80f-474b-a0c4-2c9a765b15c5",
 | 
			
		||||
        "showInput": false,
 | 
			
		||||
        "customInput": null
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "# A simple model using Implicitron\n",
 | 
			
		||||
        "\n",
 | 
			
		||||
        "In this demo, we use the VolumeRenderer from PyTorch3D as a custom implicit function in Implicitron. We will see\n",
 | 
			
		||||
        "* some of the main objects in Implicitron\n",
 | 
			
		||||
        "* how to plug in a custom part of a model"
 | 
			
		||||
      ],
 | 
			
		||||
      "attachments": {}
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "markdown",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "51337c0e-ad27-4b75-ad6a-737dca5d7b95",
 | 
			
		||||
        "showInput": false,
 | 
			
		||||
        "customInput": null
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "## 0. Install and import modules\n",
 | 
			
		||||
        "\n",
 | 
			
		||||
        "Ensure `torch` and `torchvision` are installed. If `pytorch3d` is not installed, install it using the following cell:\n",
 | 
			
		||||
        ""
 | 
			
		||||
      ],
 | 
			
		||||
      "attachments": {}
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "code",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "76f1ecd4-6b73-4214-81b0-118ef8d86872",
 | 
			
		||||
        "showInput": true,
 | 
			
		||||
        "customInput": null,
 | 
			
		||||
        "collapsed": false,
 | 
			
		||||
        "requestMsgId": "deb6a860-6923-4227-abef-d31388b5142d",
 | 
			
		||||
        "customOutput": null,
 | 
			
		||||
        "executionStartTime": 1659619898147,
 | 
			
		||||
        "executionStopTime": 1659619898274
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "import os\n",
 | 
			
		||||
        "import sys\n",
 | 
			
		||||
        "import torch\n",
 | 
			
		||||
        "need_pytorch3d=False\n",
 | 
			
		||||
        "try:\n",
 | 
			
		||||
        "    import pytorch3d\n",
 | 
			
		||||
        "except ModuleNotFoundError:\n",
 | 
			
		||||
        "    need_pytorch3d=True\n",
 | 
			
		||||
        "if need_pytorch3d:\n",
 | 
			
		||||
        "    if torch.__version__.startswith(\"1.12.\") and sys.platform.startswith(\"linux\"):\n",
 | 
			
		||||
        "        # We try to install PyTorch3D via a released wheel.\n",
 | 
			
		||||
        "        pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
 | 
			
		||||
        "        version_str=\"\".join([\n",
 | 
			
		||||
        "            f\"py3{sys.version_info.minor}_cu\",\n",
 | 
			
		||||
        "            torch.version.cuda.replace(\".\",\"\"),\n",
 | 
			
		||||
        "            f\"_pyt{pyt_version_str}\"\n",
 | 
			
		||||
        "        ])\n",
 | 
			
		||||
        "        !pip install fvcore iopath\n",
 | 
			
		||||
        "        !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
 | 
			
		||||
        "    else:\n",
 | 
			
		||||
        "        # We try to install PyTorch3D from source.\n",
 | 
			
		||||
        "        !curl -LO https://github.com/NVIDIA/cub/archive/1.10.0.tar.gz\n",
 | 
			
		||||
        "        !tar xzf 1.10.0.tar.gz\n",
 | 
			
		||||
        "        os.environ[\"CUB_HOME\"] = os.getcwd() + \"/cub-1.10.0\"\n",
 | 
			
		||||
        "        !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
 | 
			
		||||
      ],
 | 
			
		||||
      "execution_count": null,
 | 
			
		||||
      "outputs": []
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "markdown",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "2c1020e6-eb4a-4644-9719-9147500d8e4f",
 | 
			
		||||
        "showInput": false,
 | 
			
		||||
        "customInput": null
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "Ensure omegaconf is installed. If not, run this cell. (It should not be necessary to restart the runtime.)"
 | 
			
		||||
      ],
 | 
			
		||||
      "attachments": {}
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "code",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "9e751931-a38d-44c9-9ff1-ac2f7d3a3f99",
 | 
			
		||||
        "showInput": true,
 | 
			
		||||
        "customInput": null,
 | 
			
		||||
        "customOutput": null
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "!pip install omegaconf"
 | 
			
		||||
      ],
 | 
			
		||||
      "execution_count": null,
 | 
			
		||||
      "outputs": []
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "code",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "collapsed": false,
 | 
			
		||||
        "originalKey": "86807e4a-1675-4520-a033-c7af85b233ec",
 | 
			
		||||
        "requestMsgId": "880a7e20-4a90-4b37-a5eb-bccc0b23cac6",
 | 
			
		||||
        "customOutput": null,
 | 
			
		||||
        "executionStartTime": 1659612480556,
 | 
			
		||||
        "executionStopTime": 1659612480644,
 | 
			
		||||
        "code_folding": [],
 | 
			
		||||
        "hidden_ranges": []
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "import logging\n",
 | 
			
		||||
        "from typing import Tuple\n",
 | 
			
		||||
        "\n",
 | 
			
		||||
        "import matplotlib.animation as animation\n",
 | 
			
		||||
        "import matplotlib.pyplot as plt\n",
 | 
			
		||||
        "import numpy as np\n",
 | 
			
		||||
        "import torch\n",
 | 
			
		||||
        "import tqdm\n",
 | 
			
		||||
        "from IPython.display import HTML\n",
 | 
			
		||||
        "from omegaconf import OmegaConf\n",
 | 
			
		||||
        "from PIL import Image\n",
 | 
			
		||||
        "from pytorch3d.implicitron.dataset.dataset_base import FrameData\n",
 | 
			
		||||
        "from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider\n",
 | 
			
		||||
        "from pytorch3d.implicitron.models.generic_model import GenericModel\n",
 | 
			
		||||
        "from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase\n",
 | 
			
		||||
        "from pytorch3d.implicitron.models.renderer.base import EvaluationMode\n",
 | 
			
		||||
        "from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args, registry, remove_unused_components\n",
 | 
			
		||||
        "from pytorch3d.renderer import RayBundle\n",
 | 
			
		||||
        "from pytorch3d.renderer.implicit.renderer import VolumeSampler\n",
 | 
			
		||||
        "from pytorch3d.structures import Volumes\n",
 | 
			
		||||
        "from pytorch3d.vis.plotly_vis import plot_batch_individually, plot_scene"
 | 
			
		||||
      ],
 | 
			
		||||
      "execution_count": null,
 | 
			
		||||
      "outputs": []
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "code",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "b2d9f5bd-a9d4-4f78-b21e-92f2658e0fe9",
 | 
			
		||||
        "showInput": true,
 | 
			
		||||
        "customInput": null,
 | 
			
		||||
        "collapsed": false,
 | 
			
		||||
        "requestMsgId": "7e43e623-4030-438b-af4e-b96170c9a052",
 | 
			
		||||
        "customOutput": null,
 | 
			
		||||
        "executionStartTime": 1659610929375,
 | 
			
		||||
        "executionStopTime": 1659610929383,
 | 
			
		||||
        "code_folding": [],
 | 
			
		||||
        "hidden_ranges": []
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "output_resolution = 80"
 | 
			
		||||
      ],
 | 
			
		||||
      "execution_count": null,
 | 
			
		||||
      "outputs": []
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "code",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "0b0c2087-4c86-4c57-b0ee-6f48a70a9c78",
 | 
			
		||||
        "showInput": true,
 | 
			
		||||
        "customInput": null,
 | 
			
		||||
        "collapsed": false,
 | 
			
		||||
        "requestMsgId": "46883aad-f00b-4fd4-ac17-eec0b2ac272a",
 | 
			
		||||
        "customOutput": null,
 | 
			
		||||
        "executionStartTime": 1659610930042,
 | 
			
		||||
        "executionStopTime": 1659610930050,
 | 
			
		||||
        "code_folding": [],
 | 
			
		||||
        "hidden_ranges": []
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "torch.set_printoptions(sci_mode=False)"
 | 
			
		||||
      ],
 | 
			
		||||
      "execution_count": null,
 | 
			
		||||
      "outputs": []
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "markdown",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "37809d0d-b02e-42df-85b6-cdd038373653",
 | 
			
		||||
        "showInput": false,
 | 
			
		||||
        "customInput": null
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "## 1. Load renders of a mesh (the cow mesh) as a dataset\n",
 | 
			
		||||
        "\n",
 | 
			
		||||
        "A dataset's train, val and test parts in Implicitron are represented as a `dataset_map`, and provided by an implementation of `DatasetMapProvider`. \n",
 | 
			
		||||
        "`RenderedMeshDatasetMapProvider` is one which generates a single-scene dataset with only a train component by taking a mesh and rendering it.\n",
 | 
			
		||||
        "We use it with the cow mesh."
 | 
			
		||||
      ],
 | 
			
		||||
      "attachments": {}
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "markdown",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "cc68cb9c-b8bf-4e9e-bef1-2cfafdf6caa2",
 | 
			
		||||
        "showInput": false,
 | 
			
		||||
        "customInput": null,
 | 
			
		||||
        "collapsed": false,
 | 
			
		||||
        "requestMsgId": "398cfcae-5d43-4b6f-9c75-db3d297364d4",
 | 
			
		||||
        "customOutput": null,
 | 
			
		||||
        "executionStartTime": 1659620739780,
 | 
			
		||||
        "executionStopTime": 1659620739914
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "If running this notebook using **Google Colab**, run the following cell to fetch the mesh obj and texture files and save it at the path data/cow_mesh.\n",
 | 
			
		||||
        "If running locally, the data is already available at the correct path."
 | 
			
		||||
      ],
 | 
			
		||||
      "attachments": {}
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "code",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "2c55e002-a885-4169-8fdc-af9078b05968",
 | 
			
		||||
        "showInput": true,
 | 
			
		||||
        "customInput": null,
 | 
			
		||||
        "customOutput": null
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "!mkdir -p data/cow_mesh\n",
 | 
			
		||||
        "!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.obj\n",
 | 
			
		||||
        "!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.mtl\n",
 | 
			
		||||
        "!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow_texture.png"
 | 
			
		||||
      ],
 | 
			
		||||
      "execution_count": null,
 | 
			
		||||
      "outputs": []
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "markdown",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "2a976be8-01bf-4a1c-a6e7-61d5d08c3dbd",
 | 
			
		||||
        "showInput": false,
 | 
			
		||||
        "customInput": null
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "If we want to instantiate one of Implicitron's configurable objects, such as `RenderedMeshDatasetMapProvider`, without using the OmegaConf initialisation (get_default_args), we need to call `expand_args_fields` on the class first."
 | 
			
		||||
      ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "code",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "eb77aaec-048c-40bd-bd69-0e66b6ab60b1",
 | 
			
		||||
        "showInput": true,
 | 
			
		||||
        "collapsed": false,
 | 
			
		||||
        "requestMsgId": "09b9975c-ff86-41c9-b4a9-975d23afc562",
 | 
			
		||||
        "customOutput": null,
 | 
			
		||||
        "executionStartTime": 1659621652237,
 | 
			
		||||
        "executionStopTime": 1659621652903,
 | 
			
		||||
        "code_folding": [],
 | 
			
		||||
        "hidden_ranges": []
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "expand_args_fields(RenderedMeshDatasetMapProvider)\n",
 | 
			
		||||
        "cow_provider = RenderedMeshDatasetMapProvider(\n",
 | 
			
		||||
        "    data_file=\"data/cow_mesh/cow.obj\",\n",
 | 
			
		||||
        "    use_point_light=False,\n",
 | 
			
		||||
        "    resolution=output_resolution,\n",
 | 
			
		||||
        ")"
 | 
			
		||||
      ],
 | 
			
		||||
      "execution_count": null,
 | 
			
		||||
      "outputs": []
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "code",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "8210e15b-da48-4306-a49a-41c4e7e7d42f",
 | 
			
		||||
        "showInput": true,
 | 
			
		||||
        "customInput": null,
 | 
			
		||||
        "collapsed": false,
 | 
			
		||||
        "requestMsgId": "c243edd2-a106-4fba-8471-dfa4f99a2088",
 | 
			
		||||
        "customOutput": null,
 | 
			
		||||
        "executionStartTime": 1659610966145,
 | 
			
		||||
        "executionStopTime": 1659610966255,
 | 
			
		||||
        "code_folding": [],
 | 
			
		||||
        "hidden_ranges": []
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "dataset_map = cow_provider.get_dataset_map()\n",
 | 
			
		||||
        "tr_cameras = [training_frame.camera for training_frame in dataset_map.train]"
 | 
			
		||||
      ],
 | 
			
		||||
      "execution_count": null,
 | 
			
		||||
      "outputs": []
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "code",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "458d72ad-d9a7-4f13-b5b7-90d2aec61c16",
 | 
			
		||||
        "showInput": true,
 | 
			
		||||
        "customInput": null,
 | 
			
		||||
        "collapsed": false,
 | 
			
		||||
        "requestMsgId": "7f9431f3-8717-4d89-a7fe-1420dd0e00c4",
 | 
			
		||||
        "customOutput": null,
 | 
			
		||||
        "executionStartTime": 1659610967703,
 | 
			
		||||
        "executionStopTime": 1659610967848,
 | 
			
		||||
        "code_folding": [],
 | 
			
		||||
        "hidden_ranges": []
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "# The cameras are all in the XZ plane, in a circle about 2.7 from the origin\n",
 | 
			
		||||
        "centers = torch.cat([i.get_camera_center() for i in tr_cameras])\n",
 | 
			
		||||
        "print(centers.min(0).values)\n",
 | 
			
		||||
        "print(centers.max(0).values)"
 | 
			
		||||
      ],
 | 
			
		||||
      "execution_count": null,
 | 
			
		||||
      "outputs": []
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "code",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "931e712b-b141-437a-97fb-dc2a07ce3458",
 | 
			
		||||
        "showInput": true,
 | 
			
		||||
        "customInput": null,
 | 
			
		||||
        "collapsed": false,
 | 
			
		||||
        "requestMsgId": "931e712b-b141-437a-97fb-dc2a07ce3458",
 | 
			
		||||
        "customOutput": null,
 | 
			
		||||
        "executionStartTime": 1659552920194,
 | 
			
		||||
        "executionStopTime": 1659552923122,
 | 
			
		||||
        "code_folding": [],
 | 
			
		||||
        "hidden_ranges": []
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "# visualization of the cameras\n",
 | 
			
		||||
        "plot = plot_scene({\"k\": {i: camera for i, camera in enumerate(tr_cameras)}}, camera_scale=0.25)\n",
 | 
			
		||||
        "plot.layout.scene.aspectmode = \"data\"\n",
 | 
			
		||||
        "plot"
 | 
			
		||||
      ],
 | 
			
		||||
      "execution_count": null,
 | 
			
		||||
      "outputs": []
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "markdown",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "afa9c02d-f76b-4f68-83e9-9733c615406b",
 | 
			
		||||
        "showInput": false,
 | 
			
		||||
        "customInput": null
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "## 2. Custom implicit function 🧊\n",
 | 
			
		||||
        "\n",
 | 
			
		||||
        "At the core of neural rendering methods are functions of spatial coordinates called implicit functions, which are used in some kind of rendering process.\n",
 | 
			
		||||
        "(Often those functions can additionally take other data as well, such as view direction.)\n",
 | 
			
		||||
        "A common rendering process is ray marching over densities and colors provided by an implicit function.\n",
 | 
			
		||||
        "In our case, taking samples from a 3D volume grid is a very simple function of spatial coordinates. \n",
 | 
			
		||||
        "\n",
 | 
			
		||||
        "Here we define our own implicit function, which uses PyTorch3D's existing functionality for sampling from a volume grid.\n",
 | 
			
		||||
        "We do this by subclassing `ImplicitFunctionBase`.\n",
 | 
			
		||||
        "We need to register our subclass with a special decorator.\n",
 | 
			
		||||
        "We use Python's dataclass annotations for configuring the module."
 | 
			
		||||
      ],
 | 
			
		||||
      "attachments": {}
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "code",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "61b55043-dc52-4de7-992e-e2195edd2123",
 | 
			
		||||
        "showInput": true,
 | 
			
		||||
        "customInput": null,
 | 
			
		||||
        "collapsed": false,
 | 
			
		||||
        "requestMsgId": "dfaace3c-098c-4ffe-9240-6a7ae0ff271e",
 | 
			
		||||
        "customOutput": null,
 | 
			
		||||
        "executionStartTime": 1659613575850,
 | 
			
		||||
        "executionStopTime": 1659613575940,
 | 
			
		||||
        "code_folding": [],
 | 
			
		||||
        "hidden_ranges": []
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "@registry.register\n",
 | 
			
		||||
        "class MyVolumes(ImplicitFunctionBase, torch.nn.Module):\n",
 | 
			
		||||
        "    grid_resolution: int = 50  # common HWD of volumes, the number of voxels in each direction\n",
 | 
			
		||||
        "    extent: float = 1.0  # In world coordinates, the volume occupies is [-extent, extent] along each axis\n",
 | 
			
		||||
        "\n",
 | 
			
		||||
        "    def __post_init__(self):\n",
 | 
			
		||||
        "        # We have to call this explicitly if there are other base classes like Module\n",
 | 
			
		||||
        "        super().__init__()\n",
 | 
			
		||||
        "\n",
 | 
			
		||||
        "        # We define parameters like other torch.nn.Module objects.\n",
 | 
			
		||||
        "        # In this case, both our parameter tensors are trainable; they govern the contents of the volume grid.\n",
 | 
			
		||||
        "        density = torch.full((self.grid_resolution, self.grid_resolution, self.grid_resolution), -2.0)\n",
 | 
			
		||||
        "        self.density = torch.nn.Parameter(density)\n",
 | 
			
		||||
        "        color = torch.full((3, self.grid_resolution, self.grid_resolution, self.grid_resolution), 0.0)\n",
 | 
			
		||||
        "        self.color = torch.nn.Parameter(color)\n",
 | 
			
		||||
        "        self.density_activation = torch.nn.Softplus()\n",
 | 
			
		||||
        "\n",
 | 
			
		||||
        "    def forward(\n",
 | 
			
		||||
        "        self,\n",
 | 
			
		||||
        "        ray_bundle: RayBundle,\n",
 | 
			
		||||
        "        fun_viewpool=None,\n",
 | 
			
		||||
        "        global_code=None,\n",
 | 
			
		||||
        "    ):\n",
 | 
			
		||||
        "        densities = self.density_activation(self.density[None, None])\n",
 | 
			
		||||
        "        voxel_size = 2.0 * float(self.extent) / self.grid_resolution\n",
 | 
			
		||||
        "        features = self.color.sigmoid()[None]\n",
 | 
			
		||||
        "\n",
 | 
			
		||||
        "        # Like other PyTorch3D structures, the actual Volumes object should only exist as long\n",
 | 
			
		||||
        "        # as one iteration of training. It is local to this function.\n",
 | 
			
		||||
        "\n",
 | 
			
		||||
        "        volume = Volumes(densities=densities, features=features, voxel_size=voxel_size)\n",
 | 
			
		||||
        "        sampler = VolumeSampler(volumes=volume)\n",
 | 
			
		||||
        "        densities, features = sampler(ray_bundle)\n",
 | 
			
		||||
        "\n",
 | 
			
		||||
        "        # When an implicit function is used for raymarching, i.e. for MultiPassEmissionAbsorptionRenderer,\n",
 | 
			
		||||
        "        # it must return (densities, features, an auxiliary tuple)\n",
 | 
			
		||||
        "        return densities, features, {}\n",
 | 
			
		||||
        ""
 | 
			
		||||
      ],
 | 
			
		||||
      "execution_count": null,
 | 
			
		||||
      "outputs": []
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "markdown",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "abaf2cd6-1b68-400e-a142-8fb9f49953f3",
 | 
			
		||||
        "showInput": false,
 | 
			
		||||
        "customInput": null
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "## 3. Construct the model object.\n",
 | 
			
		||||
        "\n",
 | 
			
		||||
        "The main model object in PyTorch3D is `GenericModel`, which has pluggable components for the major steps, including the renderer and the implicit function(s).\n",
 | 
			
		||||
        "There are two ways to construct it which are equivalent here."
 | 
			
		||||
      ],
 | 
			
		||||
      "attachments": {}
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "code",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "f26c3dce-fbae-4592-bd0e-e4a8abc57c2c",
 | 
			
		||||
        "showInput": true,
 | 
			
		||||
        "customInput": null,
 | 
			
		||||
        "collapsed": false,
 | 
			
		||||
        "requestMsgId": "9213687e-1caf-46a8-a4e5-a9c531530092",
 | 
			
		||||
        "customOutput": null,
 | 
			
		||||
        "executionStartTime": 1659621267561,
 | 
			
		||||
        "executionStopTime": 1659621267938
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "CONSTRUCT_MODEL_FROM_CONFIG = True\n",
 | 
			
		||||
        "if CONSTRUCT_MODEL_FROM_CONFIG:\n",
 | 
			
		||||
        "    # Via a DictConfig - this is how our training loop with hydra works\n",
 | 
			
		||||
        "    cfg = get_default_args(GenericModel)\n",
 | 
			
		||||
        "    cfg.implicit_function_class_type = \"MyVolumes\"\n",
 | 
			
		||||
        "    cfg.render_image_height=output_resolution\n",
 | 
			
		||||
        "    cfg.render_image_width=output_resolution\n",
 | 
			
		||||
        "    cfg.loss_weights={\"loss_rgb_huber\": 1.0}\n",
 | 
			
		||||
        "    cfg.tqdm_trigger_threshold=19000\n",
 | 
			
		||||
        "    cfg.raysampler_AdaptiveRaySampler_args.scene_extent= 4.0\n",
 | 
			
		||||
        "    gm = GenericModel(**cfg)\n",
 | 
			
		||||
        "else:\n",
 | 
			
		||||
        "    # constructing GenericModel directly\n",
 | 
			
		||||
        "    expand_args_fields(GenericModel)\n",
 | 
			
		||||
        "    gm = GenericModel(\n",
 | 
			
		||||
        "        implicit_function_class_type=\"MyVolumes\",\n",
 | 
			
		||||
        "        render_image_height=output_resolution,\n",
 | 
			
		||||
        "        render_image_width=output_resolution,\n",
 | 
			
		||||
        "        loss_weights={\"loss_rgb_huber\": 1.0},\n",
 | 
			
		||||
        "        tqdm_trigger_threshold=19000,\n",
 | 
			
		||||
        "        raysampler_AdaptiveRaySampler_args = {\"scene_extent\": 4.0}\n",
 | 
			
		||||
        "    )\n",
 | 
			
		||||
        "\n",
 | 
			
		||||
        "    # In this case we can get the equivalent DictConfig cfg object to the way gm is configured as follows\n",
 | 
			
		||||
        "    cfg = OmegaConf.structured(gm)\n",
 | 
			
		||||
        ""
 | 
			
		||||
      ],
 | 
			
		||||
      "execution_count": null,
 | 
			
		||||
      "outputs": []
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "markdown",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "4e659f7d-ce66-4999-83de-005eb09d7705",
 | 
			
		||||
        "showInput": false,
 | 
			
		||||
        "customInput": null,
 | 
			
		||||
        "collapsed": false,
 | 
			
		||||
        "requestMsgId": "7b815b2b-cf19-44d0-ae89-76fde6df35ec",
 | 
			
		||||
        "customOutput": null,
 | 
			
		||||
        "executionStartTime": 1659611214689,
 | 
			
		||||
        "executionStopTime": 1659611214748,
 | 
			
		||||
        "code_folding": [],
 | 
			
		||||
        "hidden_ranges": []
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        " The default renderer is an emission-absorbtion raymarcher. We keep that default."
 | 
			
		||||
      ],
 | 
			
		||||
      "attachments": {}
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "code",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "d37ae488-c57c-44d3-9def-825dc1a6495b",
 | 
			
		||||
        "showInput": true,
 | 
			
		||||
        "customInput": null,
 | 
			
		||||
        "collapsed": false,
 | 
			
		||||
        "requestMsgId": "71143ec1-730f-4876-8a14-e46eea9d6dd1",
 | 
			
		||||
        "customOutput": null,
 | 
			
		||||
        "executionStartTime": 1659621268007,
 | 
			
		||||
        "executionStopTime": 1659621268190,
 | 
			
		||||
        "code_folding": [],
 | 
			
		||||
        "hidden_ranges": []
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "# We can display the configuration in use as follows.\n",
 | 
			
		||||
        "remove_unused_components(cfg)\n",
 | 
			
		||||
        "yaml = OmegaConf.to_yaml(cfg, sort_keys=False)\n",
 | 
			
		||||
        "%page -r yaml"
 | 
			
		||||
      ],
 | 
			
		||||
      "execution_count": null,
 | 
			
		||||
      "outputs": []
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "code",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "52e53179-3c6e-4c1f-a38a-3a6d803687bb",
 | 
			
		||||
        "showInput": true,
 | 
			
		||||
        "customInput": null,
 | 
			
		||||
        "collapsed": false,
 | 
			
		||||
        "requestMsgId": "05de9bc3-3f74-4a6f-851c-9ec919b59506",
 | 
			
		||||
        "customOutput": null,
 | 
			
		||||
        "executionStartTime": 1659621268727,
 | 
			
		||||
        "executionStopTime": 1659621268776,
 | 
			
		||||
        "code_folding": [],
 | 
			
		||||
        "hidden_ranges": []
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "device = torch.device(\"cuda:0\")\n",
 | 
			
		||||
        "gm.to(device)\n",
 | 
			
		||||
        "assert next(gm.parameters()).is_cuda"
 | 
			
		||||
      ],
 | 
			
		||||
      "execution_count": null,
 | 
			
		||||
      "outputs": []
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "markdown",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "528a7d53-c645-49c2-9021-09adbb18cd23",
 | 
			
		||||
        "showInput": false,
 | 
			
		||||
        "customInput": null
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "## 4. train the model "
 | 
			
		||||
      ]
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "code",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "953280bd-3161-42ba-8dcb-0c8ef2d5cc25",
 | 
			
		||||
        "showInput": true,
 | 
			
		||||
        "customInput": null,
 | 
			
		||||
        "collapsed": false,
 | 
			
		||||
        "requestMsgId": "9bba424b-7bfd-4e5a-9d79-ae316e20bab0",
 | 
			
		||||
        "customOutput": null,
 | 
			
		||||
        "executionStartTime": 1659621270236,
 | 
			
		||||
        "executionStopTime": 1659621270446,
 | 
			
		||||
        "code_folding": [],
 | 
			
		||||
        "hidden_ranges": []
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "train_data_collated = [FrameData.collate([frame.to(device)]) for frame in dataset_map.train]"
 | 
			
		||||
      ],
 | 
			
		||||
      "execution_count": null,
 | 
			
		||||
      "outputs": []
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "code",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "2fcf07f0-0c28-49c7-8c76-1c9a9d810167",
 | 
			
		||||
        "showInput": true,
 | 
			
		||||
        "customInput": null,
 | 
			
		||||
        "collapsed": false,
 | 
			
		||||
        "requestMsgId": "821deb43-6084-4ece-83c3-dee214562c47",
 | 
			
		||||
        "customOutput": null,
 | 
			
		||||
        "executionStartTime": 1659621270815,
 | 
			
		||||
        "executionStopTime": 1659621270948,
 | 
			
		||||
        "code_folding": [],
 | 
			
		||||
        "hidden_ranges": []
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "gm.train()\n",
 | 
			
		||||
        "optimizer = torch.optim.Adam(gm.parameters(), lr=0.1)"
 | 
			
		||||
      ],
 | 
			
		||||
      "execution_count": null,
 | 
			
		||||
      "outputs": []
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "code",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "105099f7-ed0c-4e7f-a976-61a93fd0a8fe",
 | 
			
		||||
        "showInput": true,
 | 
			
		||||
        "collapsed": false,
 | 
			
		||||
        "requestMsgId": "0c87c108-83e3-4129-ad02-85e0140f1368",
 | 
			
		||||
        "customOutput": null,
 | 
			
		||||
        "executionStartTime": 1659621271875,
 | 
			
		||||
        "executionStopTime": 1659621298146,
 | 
			
		||||
        "code_folding": [],
 | 
			
		||||
        "hidden_ranges": []
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "iterator = tqdm.tqdm(range(2000))\n",
 | 
			
		||||
        "for n_batch in iterator:\n",
 | 
			
		||||
        "    optimizer.zero_grad()\n",
 | 
			
		||||
        "\n",
 | 
			
		||||
        "    frame = train_data_collated[n_batch % len(dataset_map.train)]\n",
 | 
			
		||||
        "    out = gm(**frame, evaluation_mode=EvaluationMode.TRAINING)\n",
 | 
			
		||||
        "    out[\"objective\"].backward()\n",
 | 
			
		||||
        "    if n_batch % 100 == 0:\n",
 | 
			
		||||
        "        iterator.set_postfix_str(f\"loss: {float(out['objective']):.5f}\")\n",
 | 
			
		||||
        "    optimizer.step()"
 | 
			
		||||
      ],
 | 
			
		||||
      "execution_count": null,
 | 
			
		||||
      "outputs": []
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "markdown",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "e3cd494a-536b-48bc-8290-c048118c82eb",
 | 
			
		||||
        "showInput": false,
 | 
			
		||||
        "customInput": null,
 | 
			
		||||
        "collapsed": false,
 | 
			
		||||
        "requestMsgId": "e3cd494a-536b-48bc-8290-c048118c82eb",
 | 
			
		||||
        "customOutput": null,
 | 
			
		||||
        "executionStartTime": 1659535024768,
 | 
			
		||||
        "executionStopTime": 1659535024906
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "## 5. Evaluate the module\n",
 | 
			
		||||
        "\n",
 | 
			
		||||
        "We generate complete images from all the viewpoints to see how they look."
 | 
			
		||||
      ],
 | 
			
		||||
      "attachments": {}
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "code",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "fbe1b2ea-cc24-4b20-a2d7-0249185e34a5",
 | 
			
		||||
        "showInput": true,
 | 
			
		||||
        "customInput": null,
 | 
			
		||||
        "collapsed": false,
 | 
			
		||||
        "requestMsgId": "771ef1f8-5eee-4932-9e81-33604bf0512a",
 | 
			
		||||
        "customOutput": null,
 | 
			
		||||
        "executionStartTime": 1659621299859,
 | 
			
		||||
        "executionStopTime": 1659621311133,
 | 
			
		||||
        "code_folding": [],
 | 
			
		||||
        "hidden_ranges": []
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "def to_numpy_image(image):\n",
 | 
			
		||||
        "    # Takes an image of shape (C, H, W) in [0,1], where C=3 or 1\n",
 | 
			
		||||
        "    # to a numpy uint image of shape (H, W, 3)\n",
 | 
			
		||||
        "    return (image * 255).to(torch.uint8).permute(1, 2, 0).detach().cpu().expand(-1, -1, 3).numpy()\n",
 | 
			
		||||
        "def resize_image(image):\n",
 | 
			
		||||
        "    # Takes images of shape (B, C, H, W) to (B, C, output_resolution, output_resolution)\n",
 | 
			
		||||
        "    return torch.nn.functional.interpolate(image, size=(output_resolution, output_resolution))\n",
 | 
			
		||||
        "\n",
 | 
			
		||||
        "gm.eval()\n",
 | 
			
		||||
        "images = []\n",
 | 
			
		||||
        "expected = []\n",
 | 
			
		||||
        "masks = []\n",
 | 
			
		||||
        "masks_expected = []\n",
 | 
			
		||||
        "for frame in tqdm.tqdm(train_data_collated):\n",
 | 
			
		||||
        "    with torch.no_grad():\n",
 | 
			
		||||
        "        out = gm(**frame, evaluation_mode=EvaluationMode.EVALUATION)\n",
 | 
			
		||||
        "\n",
 | 
			
		||||
        "    image_rgb = to_numpy_image(out[\"images_render\"][0])\n",
 | 
			
		||||
        "    mask = to_numpy_image(out[\"masks_render\"][0])\n",
 | 
			
		||||
        "    expd = to_numpy_image(resize_image(frame.image_rgb)[0])\n",
 | 
			
		||||
        "    mask_expected = to_numpy_image(resize_image(frame.fg_probability)[0])\n",
 | 
			
		||||
        "\n",
 | 
			
		||||
        "    images.append(image_rgb)\n",
 | 
			
		||||
        "    masks.append(mask)\n",
 | 
			
		||||
        "    expected.append(expd)\n",
 | 
			
		||||
        "    masks_expected.append(mask_expected)"
 | 
			
		||||
      ],
 | 
			
		||||
      "execution_count": null,
 | 
			
		||||
      "outputs": []
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "markdown",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "24953039-9780-40fd-bd81-5d63e9f40069",
 | 
			
		||||
        "showInput": false,
 | 
			
		||||
        "customInput": null,
 | 
			
		||||
        "collapsed": false,
 | 
			
		||||
        "requestMsgId": "7af895a3-dfe4-4c28-ac3b-4ff0fbb40c7f",
 | 
			
		||||
        "customOutput": null,
 | 
			
		||||
        "executionStartTime": 1659614622542,
 | 
			
		||||
        "executionStopTime": 1659614622757
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "We draw a grid showing predicted image and expected image, followed by predicted mask and expected mask, from each viewpoint. \n",
 | 
			
		||||
        "This is a grid of four rows of images, wrapped in to several large rows, i.e..\n",
 | 
			
		||||
        "<small><center>\n",
 | 
			
		||||
        "```\n",
 | 
			
		||||
        "┌────────┬────────┐           ┌────────┐\n",
 | 
			
		||||
        "│pred    │pred    │           │pred    │\n",
 | 
			
		||||
        "│image   │image   │           │image   │\n",
 | 
			
		||||
        "│1       │2       │           │n       │\n",
 | 
			
		||||
        "├────────┼────────┤           ├────────┤\n",
 | 
			
		||||
        "│expected│expected│           │expected│\n",
 | 
			
		||||
        "│image   │image   │  ...      │image   │\n",
 | 
			
		||||
        "│1       │2       │           │n       │\n",
 | 
			
		||||
        "├────────┼────────┤           ├────────┤\n",
 | 
			
		||||
        "│pred    │pred    │           │pred    │\n",
 | 
			
		||||
        "│mask    │mask    │           │mask    │\n",
 | 
			
		||||
        "│1       │2       │           │n       │\n",
 | 
			
		||||
        "├────────┼────────┤           ├────────┤\n",
 | 
			
		||||
        "│expected│expected│           │expected│\n",
 | 
			
		||||
        "│mask    │mask    │           │mask    │\n",
 | 
			
		||||
        "│1       │2       │           │n       │\n",
 | 
			
		||||
        "├────────┼────────┤           ├────────┤\n",
 | 
			
		||||
        "│pred    │pred    │           │pred    │\n",
 | 
			
		||||
        "│image   │image   │           │image   │\n",
 | 
			
		||||
        "│n+1     │n+1     │           │2n      │\n",
 | 
			
		||||
        "├────────┼────────┤           ├────────┤\n",
 | 
			
		||||
        "│expected│expected│           │expected│\n",
 | 
			
		||||
        "│image   │image   │  ...      │image   │\n",
 | 
			
		||||
        "│n+1     │n+2     │           │2n      │\n",
 | 
			
		||||
        "├────────┼────────┤           ├────────┤\n",
 | 
			
		||||
        "│pred    │pred    │           │pred    │\n",
 | 
			
		||||
        "│mask    │mask    │           │mask    │\n",
 | 
			
		||||
        "│n+1     │n+2     │           │2n      │\n",
 | 
			
		||||
        "├────────┼────────┤           ├────────┤\n",
 | 
			
		||||
        "│expected│expected│           │expected│\n",
 | 
			
		||||
        "│mask    │mask    │           │mask    │\n",
 | 
			
		||||
        "│n+1     │n+2     │           │2n      │\n",
 | 
			
		||||
        "└────────┴────────┘           └────────┘\n",
 | 
			
		||||
        "           ...\n",
 | 
			
		||||
        "```\n",
 | 
			
		||||
        "</center></small>"
 | 
			
		||||
      ],
 | 
			
		||||
      "attachments": {}
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "code",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "c488a34a-e46d-4649-93fb-4b1bb5a0e439",
 | 
			
		||||
        "showInput": true,
 | 
			
		||||
        "customInput": null,
 | 
			
		||||
        "collapsed": false,
 | 
			
		||||
        "requestMsgId": "4221e632-fca1-4fe5-b2e3-f92c37aa40e4",
 | 
			
		||||
        "customOutput": null,
 | 
			
		||||
        "executionStartTime": 1659621313894,
 | 
			
		||||
        "executionStopTime": 1659621314042,
 | 
			
		||||
        "code_folding": [],
 | 
			
		||||
        "hidden_ranges": []
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "images_to_display = [images.copy(), expected.copy(), masks.copy(), masks_expected.copy()]\n",
 | 
			
		||||
        "n_rows = 4\n",
 | 
			
		||||
        "n_images = len(images)\n",
 | 
			
		||||
        "blank_image = images[0] * 0\n",
 | 
			
		||||
        "n_per_row = 1+(n_images-1)//n_rows\n",
 | 
			
		||||
        "for _ in range(n_per_row*n_rows - n_images):\n",
 | 
			
		||||
        "    for group in images_to_display:\n",
 | 
			
		||||
        "        group.append(blank_image)\n",
 | 
			
		||||
        "\n",
 | 
			
		||||
        "images_to_display_listed = [[[i] for i in j] for j in images_to_display]\n",
 | 
			
		||||
        "split = []\n",
 | 
			
		||||
        "for row in range(n_rows):\n",
 | 
			
		||||
        "    for group in images_to_display_listed:\n",
 | 
			
		||||
        "        split.append(group[row*n_per_row:(row+1)*n_per_row])  \n",
 | 
			
		||||
        "\n",
 | 
			
		||||
        "Image.fromarray(np.block(split))\n",
 | 
			
		||||
        ""
 | 
			
		||||
      ],
 | 
			
		||||
      "execution_count": null,
 | 
			
		||||
      "outputs": []
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "code",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "49eab9e1-4fe2-4fbe-b4f3-7b6953340170",
 | 
			
		||||
        "showInput": true,
 | 
			
		||||
        "customInput": null,
 | 
			
		||||
        "collapsed": false,
 | 
			
		||||
        "requestMsgId": "85b402ad-f903-431f-a13e-c2d697e869bb",
 | 
			
		||||
        "customOutput": null,
 | 
			
		||||
        "executionStartTime": 1659621323795,
 | 
			
		||||
        "executionStopTime": 1659621323820,
 | 
			
		||||
        "code_folding": [],
 | 
			
		||||
        "hidden_ranges": []
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "# Print the maximum channel intensity in the first image.\n",
 | 
			
		||||
        "print(images[1].max()/255)"
 | 
			
		||||
      ],
 | 
			
		||||
      "execution_count": null,
 | 
			
		||||
      "outputs": []
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "code",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "137d2c43-d39d-4266-ac5e-2b714da5e0ee",
 | 
			
		||||
        "showInput": true,
 | 
			
		||||
        "customInput": null,
 | 
			
		||||
        "customOutput": null,
 | 
			
		||||
        "code_folding": [],
 | 
			
		||||
        "hidden_ranges": [],
 | 
			
		||||
        "collapsed": false,
 | 
			
		||||
        "requestMsgId": "8e27ec57-c2d6-4ae0-be69-b63b6af929ff",
 | 
			
		||||
        "executionStartTime": 1659621408642,
 | 
			
		||||
        "executionStopTime": 1659621409559
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "plt.ioff()\n",
 | 
			
		||||
        "fig, ax = plt.subplots(figsize=(3,3))\n",
 | 
			
		||||
        "\n",
 | 
			
		||||
        "ax.grid(None)\n",
 | 
			
		||||
        "ims = [[ax.imshow(im, animated=True)] for im in images]\n",
 | 
			
		||||
        "ani = animation.ArtistAnimation(fig, ims, interval=80, blit=True)\n",
 | 
			
		||||
        "ani_html = ani.to_jshtml()\n",
 | 
			
		||||
        ""
 | 
			
		||||
      ],
 | 
			
		||||
      "execution_count": null,
 | 
			
		||||
      "outputs": []
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "code",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "783e70d6-7cf1-4d76-a126-ba11ffc2f5be",
 | 
			
		||||
        "showInput": true,
 | 
			
		||||
        "customInput": null,
 | 
			
		||||
        "collapsed": false,
 | 
			
		||||
        "requestMsgId": "b6843506-c5fa-4508-80fc-8ecae51a934a",
 | 
			
		||||
        "customOutput": null,
 | 
			
		||||
        "executionStartTime": 1659621409620,
 | 
			
		||||
        "executionStopTime": 1659621409725
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "HTML(ani_html)"
 | 
			
		||||
      ],
 | 
			
		||||
      "execution_count": null,
 | 
			
		||||
      "outputs": []
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "cell_type": "code",
 | 
			
		||||
      "metadata": {
 | 
			
		||||
        "originalKey": "0286c350-2362-4f47-8181-2fc2ba51cfcf",
 | 
			
		||||
        "showInput": true,
 | 
			
		||||
        "customInput": null,
 | 
			
		||||
        "collapsed": false,
 | 
			
		||||
        "requestMsgId": "976f4db9-d4c7-466c-bcfd-218234400226",
 | 
			
		||||
        "customOutput": null,
 | 
			
		||||
        "executionStartTime": 1659614670081,
 | 
			
		||||
        "executionStopTime": 1659614670168
 | 
			
		||||
      },
 | 
			
		||||
      "source": [
 | 
			
		||||
        "# If you want to see the output of the model with the volume forced to opaque white, run this and re-evaluate\n",
 | 
			
		||||
        "# with torch.no_grad():\n",
 | 
			
		||||
        "#      gm._implicit_functions[0]._fn.density.fill_(9.0)\n",
 | 
			
		||||
        "#      gm._implicit_functions[0]._fn.color.fill_(9.0)\n",
 | 
			
		||||
        ""
 | 
			
		||||
      ],
 | 
			
		||||
      "execution_count": null,
 | 
			
		||||
      "outputs": []
 | 
			
		||||
    }
 | 
			
		||||
  ]
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user