From 89d8eebc5a57e639220eaf4caad20f73886d88e3 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Thu, 4 Aug 2022 13:19:38 -0700 Subject: [PATCH] volumes tutorial Summary: new Reviewed By: kjchalup Differential Revision: D38425739 fbshipit-source-id: 23e7a1dd01c034b231511acecf40ea0deb9a067f --- docs/tutorials/implicitron_volumes.ipynb | 919 +++++++++++++++++++++++ 1 file changed, 919 insertions(+) create mode 100644 docs/tutorials/implicitron_volumes.ipynb diff --git a/docs/tutorials/implicitron_volumes.ipynb b/docs/tutorials/implicitron_volumes.ipynb new file mode 100644 index 00000000..752334db --- /dev/null +++ b/docs/tutorials/implicitron_volumes.ipynb @@ -0,0 +1,919 @@ +{ + "metadata": { + "dataExplorerConfig": {}, + "bento_stylesheets": { + "bento/extensions/flow/main.css": true, + "bento/extensions/kernel_selector/main.css": true, + "bento/extensions/kernel_ui/main.css": true, + "bento/extensions/new_kernel/main.css": true, + "bento/extensions/system_usage/main.css": true, + "bento/extensions/theme/main.css": true + }, + "kernelspec": { + "display_name": "pytorch3d", + "language": "python", + "name": "bento_kernel_pytorch3d", + "metadata": { + "kernel_name": "bento_kernel_pytorch3d", + "nightly_builds": true, + "fbpkg_supported": true, + "cinder_runtime": false, + "is_prebuilt": true + } + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + }, + "last_server_session_id": "2944b203-9ea8-4c0e-9634-645dfea5f26b", + "last_kernel_id": "bb33cd83-7924-489a-8bd8-2d9d62eb0126", + "last_base_url": "https://9177.od.fbinfra.net:443/", + "last_msg_id": "99f7088e-d22b355b859660479ef0574e_5743", + "captumWidgetMessage": {}, + "outputWidgetContext": {} + }, + "nbformat": 4, + "nbformat_minor": 2, + "cells": [ + { + "cell_type": "code", + "metadata": { + "originalKey": "d38652e8-200a-413c-a36a-f4d349b78a9d", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "641de8aa-0e42-4446-9304-c160a2d226bf", + "customOutput": null, + "executionStartTime": 1659619824914, + "executionStopTime": 1659619825485 + }, + "source": [ + "# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved." + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "a48a9dcf-e80f-474b-a0c4-2c9a765b15c5", + "showInput": false, + "customInput": null + }, + "source": [ + "# A simple model using Implicitron\n", + "\n", + "In this demo, we use the VolumeRenderer from PyTorch3D as a custom implicit function in Implicitron. We will see\n", + "* some of the main objects in Implicitron\n", + "* how to plug in a custom part of a model" + ], + "attachments": {} + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "51337c0e-ad27-4b75-ad6a-737dca5d7b95", + "showInput": false, + "customInput": null + }, + "source": [ + "## 0. Install and import modules\n", + "\n", + "Ensure `torch` and `torchvision` are installed. 
If `pytorch3d` is not installed, install it using the following cell:\n", + "" + ], + "attachments": {} + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "76f1ecd4-6b73-4214-81b0-118ef8d86872", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "deb6a860-6923-4227-abef-d31388b5142d", + "customOutput": null, + "executionStartTime": 1659619898147, + "executionStopTime": 1659619898274 + }, + "source": [ + "import os\n", + "import sys\n", + "import torch\n", + "need_pytorch3d=False\n", + "try:\n", + " import pytorch3d\n", + "except ModuleNotFoundError:\n", + " need_pytorch3d=True\n", + "if need_pytorch3d:\n", + " if torch.__version__.startswith(\"1.12.\") and sys.platform.startswith(\"linux\"):\n", + " # We try to install PyTorch3D via a released wheel.\n", + " pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n", + " version_str=\"\".join([\n", + " f\"py3{sys.version_info.minor}_cu\",\n", + " torch.version.cuda.replace(\".\",\"\"),\n", + " f\"_pyt{pyt_version_str}\"\n", + " ])\n", + " !pip install fvcore iopath\n", + " !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n", + " else:\n", + " # We try to install PyTorch3D from source.\n", + " !curl -LO https://github.com/NVIDIA/cub/archive/1.10.0.tar.gz\n", + " !tar xzf 1.10.0.tar.gz\n", + " os.environ[\"CUB_HOME\"] = os.getcwd() + \"/cub-1.10.0\"\n", + " !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "2c1020e6-eb4a-4644-9719-9147500d8e4f", + "showInput": false, + "customInput": null + }, + "source": [ + "Ensure omegaconf is installed. If not, run this cell. 
(It should not be necessary to restart the runtime.)" + ], + "attachments": {} + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "9e751931-a38d-44c9-9ff1-ac2f7d3a3f99", + "showInput": true, + "customInput": null, + "customOutput": null + }, + "source": [ + "!pip install omegaconf" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "collapsed": false, + "originalKey": "86807e4a-1675-4520-a033-c7af85b233ec", + "requestMsgId": "880a7e20-4a90-4b37-a5eb-bccc0b23cac6", + "customOutput": null, + "executionStartTime": 1659612480556, + "executionStopTime": 1659612480644, + "code_folding": [], + "hidden_ranges": [] + }, + "source": [ + "import logging\n", + "from typing import Tuple\n", + "\n", + "import matplotlib.animation as animation\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "import tqdm\n", + "from IPython.display import HTML\n", + "from omegaconf import OmegaConf\n", + "from PIL import Image\n", + "from pytorch3d.implicitron.dataset.dataset_base import FrameData\n", + "from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider\n", + "from pytorch3d.implicitron.models.generic_model import GenericModel\n", + "from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase\n", + "from pytorch3d.implicitron.models.renderer.base import EvaluationMode\n", + "from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args, registry, remove_unused_components\n", + "from pytorch3d.renderer import RayBundle\n", + "from pytorch3d.renderer.implicit.renderer import VolumeSampler\n", + "from pytorch3d.structures import Volumes\n", + "from pytorch3d.vis.plotly_vis import plot_batch_individually, plot_scene" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "b2d9f5bd-a9d4-4f78-b21e-92f2658e0fe9", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "7e43e623-4030-438b-af4e-b96170c9a052", + "customOutput": null, + "executionStartTime": 1659610929375, + "executionStopTime": 1659610929383, + "code_folding": [], + "hidden_ranges": [] + }, + "source": [ + "output_resolution = 80" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "0b0c2087-4c86-4c57-b0ee-6f48a70a9c78", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "46883aad-f00b-4fd4-ac17-eec0b2ac272a", + "customOutput": null, + "executionStartTime": 1659610930042, + "executionStopTime": 1659610930050, + "code_folding": [], + "hidden_ranges": [] + }, + "source": [ + "torch.set_printoptions(sci_mode=False)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "37809d0d-b02e-42df-85b6-cdd038373653", + "showInput": false, + "customInput": null + }, + "source": [ + "## 1. Load renders of a mesh (the cow mesh) as a dataset\n", + "\n", + "A dataset's train, val and test parts in Implicitron are represented as a `dataset_map`, and provided by an implementation of `DatasetMapProvider`. \n", + "`RenderedMeshDatasetMapProvider` is one which generates a single-scene dataset with only a train component by taking a mesh and rendering it.\n", + "We use it with the cow mesh." 
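+ ],
+ "attachments": {}
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To make the structure concrete before we build anything, here is a sketch of how a dataset map is typically consumed. It is illustrative only: it assumes the `cow_provider` object which we construct a few cells below.\n",
+ "\n",
+ "```python\n",
+ "dataset_map = cow_provider.get_dataset_map()  # the same call we make below\n",
+ "frame = next(iter(dataset_map.train))         # one FrameData object\n",
+ "print(frame.image_rgb.shape)                  # the rendered image for this frame\n",
+ "print(frame.camera.get_camera_center())      # the camera it was rendered from\n",
+ "```\n",
+ "\n",
+ "The `val` and `test` members carry no data here, since this provider generates a dataset with only a train component."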
+ ],
+ "attachments": {}
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "cc68cb9c-b8bf-4e9e-bef1-2cfafdf6caa2",
+ "showInput": false,
+ "customInput": null,
+ "collapsed": false,
+ "requestMsgId": "398cfcae-5d43-4b6f-9c75-db3d297364d4",
+ "customOutput": null,
+ "executionStartTime": 1659620739780,
+ "executionStopTime": 1659620739914
+ },
+ "source": [
+ "If running this notebook using **Google Colab**, run the following cell to fetch the mesh obj, mtl and texture files and save them at the path data/cow_mesh.\n",
+ "If running locally, the data is already available at the correct path."
+ ],
+ "attachments": {}
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "2c55e002-a885-4169-8fdc-af9078b05968",
+ "showInput": true,
+ "customInput": null,
+ "customOutput": null
+ },
+ "source": [
+ "!mkdir -p data/cow_mesh\n",
+ "!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.obj\n",
+ "!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.mtl\n",
+ "!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow_texture.png"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "2a976be8-01bf-4a1c-a6e7-61d5d08c3dbd",
+ "showInput": false,
+ "customInput": null
+ },
+ "source": [
+ "If we want to instantiate one of Implicitron's configurable objects, such as `RenderedMeshDatasetMapProvider`, without using the OmegaConf initialisation (`get_default_args`), we need to call `expand_args_fields` on the class first."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "eb77aaec-048c-40bd-bd69-0e66b6ab60b1",
+ "showInput": true,
+ "collapsed": false,
+ "requestMsgId": "09b9975c-ff86-41c9-b4a9-975d23afc562",
+ "customOutput": null,
+ "executionStartTime": 1659621652237,
+ "executionStopTime": 1659621652903,
+ "code_folding": [],
+ "hidden_ranges": []
+ },
+ "source": [
+ "expand_args_fields(RenderedMeshDatasetMapProvider)\n",
+ "cow_provider = RenderedMeshDatasetMapProvider(\n",
+ "    data_file=\"data/cow_mesh/cow.obj\",\n",
+ "    use_point_light=False,\n",
+ "    resolution=output_resolution,\n",
+ ")"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "8210e15b-da48-4306-a49a-41c4e7e7d42f",
+ "showInput": true,
+ "customInput": null,
+ "collapsed": false,
+ "requestMsgId": "c243edd2-a106-4fba-8471-dfa4f99a2088",
+ "customOutput": null,
+ "executionStartTime": 1659610966145,
+ "executionStopTime": 1659610966255,
+ "code_folding": [],
+ "hidden_ranges": []
+ },
+ "source": [
+ "dataset_map = cow_provider.get_dataset_map()\n",
+ "tr_cameras = [training_frame.camera for training_frame in dataset_map.train]"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "458d72ad-d9a7-4f13-b5b7-90d2aec61c16",
+ "showInput": true,
+ "customInput": null,
+ "collapsed": false,
+ "requestMsgId": "7f9431f3-8717-4d89-a7fe-1420dd0e00c4",
+ "customOutput": null,
+ "executionStartTime": 1659610967703,
+ "executionStopTime": 1659610967848,
+ "code_folding": [],
+ "hidden_ranges": []
+ },
+ "source": [
+ "# The cameras are all in the XZ plane, in a circle about 2.7 from the origin\n",
+ "centers = torch.cat([i.get_camera_center() for i in tr_cameras])\n",
+ "print(centers.min(0).values)\n",
+ "print(centers.max(0).values)"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
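+ },
+ "source": [
+ "# A quick sanity check of the comment above (a sketch, not part of the original\n",
+ "# tutorial flow): if the cameras lie on a circle of radius about 2.7 in the XZ\n",
+ "# plane, the center norms should all be about 2.7 and the Y coordinates should\n",
+ "# be roughly constant.\n",
+ "radii = centers.norm(dim=1)\n",
+ "print(radii.min(), radii.max())  # both approximately 2.7\n",
+ "print(centers[:, 1].std())  # close to zero\n",
+ ""
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {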
"originalKey": "931e712b-b141-437a-97fb-dc2a07ce3458", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "931e712b-b141-437a-97fb-dc2a07ce3458", + "customOutput": null, + "executionStartTime": 1659552920194, + "executionStopTime": 1659552923122, + "code_folding": [], + "hidden_ranges": [] + }, + "source": [ + "# visualization of the cameras\n", + "plot = plot_scene({\"k\": {i: camera for i, camera in enumerate(tr_cameras)}}, camera_scale=0.25)\n", + "plot.layout.scene.aspectmode = \"data\"\n", + "plot" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "originalKey": "afa9c02d-f76b-4f68-83e9-9733c615406b", + "showInput": false, + "customInput": null + }, + "source": [ + "## 2. Custom implicit function 🧊\n", + "\n", + "At the core of neural rendering methods are functions of spatial coordinates called implicit functions, which are used in some kind of rendering process.\n", + "(Often those functions can additionally take other data as well, such as view direction.)\n", + "A common rendering process is ray marching over densities and colors provided by an implicit function.\n", + "In our case, taking samples from a 3D volume grid is a very simple function of spatial coordinates. \n", + "\n", + "Here we define our own implicit function, which uses PyTorch3D's existing functionality for sampling from a volume grid.\n", + "We do this by subclassing `ImplicitFunctionBase`.\n", + "We need to register our subclass with a special decorator.\n", + "We use Python's dataclass annotations for configuring the module." + ], + "attachments": {} + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "61b55043-dc52-4de7-992e-e2195edd2123", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "dfaace3c-098c-4ffe-9240-6a7ae0ff271e", + "customOutput": null, + "executionStartTime": 1659613575850, + "executionStopTime": 1659613575940, + "code_folding": [], + "hidden_ranges": [] + }, + "source": [ + "@registry.register\n", + "class MyVolumes(ImplicitFunctionBase, torch.nn.Module):\n", + " grid_resolution: int = 50 # common HWD of volumes, the number of voxels in each direction\n", + " extent: float = 1.0 # In world coordinates, the volume occupies is [-extent, extent] along each axis\n", + "\n", + " def __post_init__(self):\n", + " # We have to call this explicitly if there are other base classes like Module\n", + " super().__init__()\n", + "\n", + " # We define parameters like other torch.nn.Module objects.\n", + " # In this case, both our parameter tensors are trainable; they govern the contents of the volume grid.\n", + " density = torch.full((self.grid_resolution, self.grid_resolution, self.grid_resolution), -2.0)\n", + " self.density = torch.nn.Parameter(density)\n", + " color = torch.full((3, self.grid_resolution, self.grid_resolution, self.grid_resolution), 0.0)\n", + " self.color = torch.nn.Parameter(color)\n", + " self.density_activation = torch.nn.Softplus()\n", + "\n", + " def forward(\n", + " self,\n", + " ray_bundle: RayBundle,\n", + " fun_viewpool=None,\n", + " global_code=None,\n", + " ):\n", + " densities = self.density_activation(self.density[None, None])\n", + " voxel_size = 2.0 * float(self.extent) / self.grid_resolution\n", + " features = self.color.sigmoid()[None]\n", + "\n", + " # Like other PyTorch3D structures, the actual Volumes object should only exist as long\n", + " # as one iteration of training. 
+ "\n",
+ "        # Like other PyTorch3D structures, the actual Volumes object should only exist\n",
+ "        # for as long as one iteration of training. It is local to this function.\n",
+ "\n",
+ "        volume = Volumes(densities=densities, features=features, voxel_size=voxel_size)\n",
+ "        sampler = VolumeSampler(volumes=volume)\n",
+ "        densities, features = sampler(ray_bundle)\n",
+ "\n",
+ "        # When an implicit function is used for raymarching, i.e. for MultiPassEmissionAbsorptionRenderer,\n",
+ "        # it must return (densities, features, an auxiliary tuple).\n",
+ "        return densities, features, {}\n",
+ ""
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "abaf2cd6-1b68-400e-a142-8fb9f49953f3",
+ "showInput": false,
+ "customInput": null
+ },
+ "source": [
+ "## 3. Construct the model object\n",
+ "\n",
+ "The main model object in PyTorch3D is `GenericModel`, which has pluggable components for the major steps, including the renderer and the implicit function(s).\n",
+ "There are two ways to construct it, which are equivalent here."
+ ],
+ "attachments": {}
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "f26c3dce-fbae-4592-bd0e-e4a8abc57c2c",
+ "showInput": true,
+ "customInput": null,
+ "collapsed": false,
+ "requestMsgId": "9213687e-1caf-46a8-a4e5-a9c531530092",
+ "customOutput": null,
+ "executionStartTime": 1659621267561,
+ "executionStopTime": 1659621267938
+ },
+ "source": [
+ "CONSTRUCT_MODEL_FROM_CONFIG = True\n",
+ "if CONSTRUCT_MODEL_FROM_CONFIG:\n",
+ "    # Via a DictConfig - this is how our training loop with Hydra works\n",
+ "    cfg = get_default_args(GenericModel)\n",
+ "    cfg.implicit_function_class_type = \"MyVolumes\"\n",
+ "    cfg.render_image_height = output_resolution\n",
+ "    cfg.render_image_width = output_resolution\n",
+ "    cfg.loss_weights = {\"loss_rgb_huber\": 1.0}\n",
+ "    cfg.tqdm_trigger_threshold = 19000\n",
+ "    cfg.raysampler_AdaptiveRaySampler_args.scene_extent = 4.0\n",
+ "    gm = GenericModel(**cfg)\n",
+ "else:\n",
+ "    # constructing GenericModel directly\n",
+ "    expand_args_fields(GenericModel)\n",
+ "    gm = GenericModel(\n",
+ "        implicit_function_class_type=\"MyVolumes\",\n",
+ "        render_image_height=output_resolution,\n",
+ "        render_image_width=output_resolution,\n",
+ "        loss_weights={\"loss_rgb_huber\": 1.0},\n",
+ "        tqdm_trigger_threshold=19000,\n",
+ "        raysampler_AdaptiveRaySampler_args={\"scene_extent\": 4.0},\n",
+ "    )\n",
+ "\n",
+ "    # In this case, we can obtain the equivalent DictConfig cfg object as follows.\n",
+ "    cfg = OmegaConf.structured(gm)\n",
+ ""
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "4e659f7d-ce66-4999-83de-005eb09d7705",
+ "showInput": false,
+ "customInput": null,
+ "collapsed": false,
+ "requestMsgId": "7b815b2b-cf19-44d0-ae89-76fde6df35ec",
+ "customOutput": null,
+ "executionStartTime": 1659611214689,
+ "executionStopTime": 1659611214748,
+ "code_folding": [],
+ "hidden_ranges": []
+ },
+ "source": [
+ "The default renderer is an emission-absorption raymarcher. We keep that default."
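+ ],
+ "attachments": {}
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The renderer is pluggable in the same way as the implicit function. As a sketch (assuming the registry names current at the time of writing), a different registered renderer could be requested through the same config mechanism:\n",
+ "\n",
+ "```python\n",
+ "# \"MultiPassEmissionAbsorptionRenderer\" is the default we are already using;\n",
+ "# another registered renderer, e.g. \"LSTMRenderer\", could be substituted here.\n",
+ "cfg.renderer_class_type = \"MultiPassEmissionAbsorptionRenderer\"\n",
+ "```"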
+ ],
+ "attachments": {}
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "d37ae488-c57c-44d3-9def-825dc1a6495b",
+ "showInput": true,
+ "customInput": null,
+ "collapsed": false,
+ "requestMsgId": "71143ec1-730f-4876-8a14-e46eea9d6dd1",
+ "customOutput": null,
+ "executionStartTime": 1659621268007,
+ "executionStopTime": 1659621268190,
+ "code_folding": [],
+ "hidden_ranges": []
+ },
+ "source": [
+ "# We can display the configuration in use as follows.\n",
+ "remove_unused_components(cfg)\n",
+ "yaml = OmegaConf.to_yaml(cfg, sort_keys=False)\n",
+ "%page -r yaml"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "52e53179-3c6e-4c1f-a38a-3a6d803687bb",
+ "showInput": true,
+ "customInput": null,
+ "collapsed": false,
+ "requestMsgId": "05de9bc3-3f74-4a6f-851c-9ec919b59506",
+ "customOutput": null,
+ "executionStartTime": 1659621268727,
+ "executionStopTime": 1659621268776,
+ "code_folding": [],
+ "hidden_ranges": []
+ },
+ "source": [
+ "device = torch.device(\"cuda:0\")\n",
+ "gm.to(device)\n",
+ "assert next(gm.parameters()).is_cuda"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "528a7d53-c645-49c2-9021-09adbb18cd23",
+ "showInput": false,
+ "customInput": null
+ },
+ "source": [
+ "## 4. Train the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "953280bd-3161-42ba-8dcb-0c8ef2d5cc25",
+ "showInput": true,
+ "customInput": null,
+ "collapsed": false,
+ "requestMsgId": "9bba424b-7bfd-4e5a-9d79-ae316e20bab0",
+ "customOutput": null,
+ "executionStartTime": 1659621270236,
+ "executionStopTime": 1659621270446,
+ "code_folding": [],
+ "hidden_ranges": []
+ },
+ "source": [
+ "train_data_collated = [FrameData.collate([frame.to(device)]) for frame in dataset_map.train]"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "2fcf07f0-0c28-49c7-8c76-1c9a9d810167",
+ "showInput": true,
+ "customInput": null,
+ "collapsed": false,
+ "requestMsgId": "821deb43-6084-4ece-83c3-dee214562c47",
+ "customOutput": null,
+ "executionStartTime": 1659621270815,
+ "executionStopTime": 1659621270948,
+ "code_folding": [],
+ "hidden_ranges": []
+ },
+ "source": [
+ "gm.train()\n",
+ "optimizer = torch.optim.Adam(gm.parameters(), lr=0.1)"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "105099f7-ed0c-4e7f-a976-61a93fd0a8fe",
+ "showInput": true,
+ "collapsed": false,
+ "requestMsgId": "0c87c108-83e3-4129-ad02-85e0140f1368",
+ "customOutput": null,
+ "executionStartTime": 1659621271875,
+ "executionStopTime": 1659621298146,
+ "code_folding": [],
+ "hidden_ranges": []
+ },
+ "source": [
+ "iterator = tqdm.tqdm(range(2000))\n",
+ "for n_batch in iterator:\n",
+ "    optimizer.zero_grad()\n",
+ "\n",
+ "    frame = train_data_collated[n_batch % len(dataset_map.train)]\n",
+ "    out = gm(**frame, evaluation_mode=EvaluationMode.TRAINING)\n",
+ "    # out[\"objective\"] is the weighted sum of the losses named in loss_weights,\n",
+ "    # here just the Huber RGB loss.\n",
+ "    out[\"objective\"].backward()\n",
+ "    if n_batch % 100 == 0:\n",
+ "        iterator.set_postfix_str(f\"loss: {float(out['objective']):.5f}\")\n",
+ "    optimizer.step()"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "e3cd494a-536b-48bc-8290-c048118c82eb",
+ "showInput": false,
+ "customInput": null,
+ "collapsed": false,
+ "requestMsgId": "e3cd494a-536b-48bc-8290-c048118c82eb",
+ "customOutput": null,
+ "executionStartTime": 1659535024768,
+ "executionStopTime": 1659535024906
+ },
+ "source": [
+ "## 5. Evaluate the module\n",
+ "\n",
+ "We generate complete images from all the viewpoints to see how they look."
+ ],
+ "attachments": {}
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "originalKey": "fbe1b2ea-cc24-4b20-a2d7-0249185e34a5",
+ "showInput": true,
+ "customInput": null,
+ "collapsed": false,
+ "requestMsgId": "771ef1f8-5eee-4932-9e81-33604bf0512a",
+ "customOutput": null,
+ "executionStartTime": 1659621299859,
+ "executionStopTime": 1659621311133,
+ "code_folding": [],
+ "hidden_ranges": []
+ },
+ "source": [
+ "def to_numpy_image(image):\n",
+ "    # Takes an image of shape (C, H, W) in [0, 1], where C=3 or 1,\n",
+ "    # to a numpy uint8 image of shape (H, W, 3)\n",
+ "    return (image * 255).to(torch.uint8).permute(1, 2, 0).detach().cpu().expand(-1, -1, 3).numpy()\n",
+ "\n",
+ "def resize_image(image):\n",
+ "    # Takes images of shape (B, C, H, W) to (B, C, output_resolution, output_resolution)\n",
+ "    return torch.nn.functional.interpolate(image, size=(output_resolution, output_resolution))\n",
+ "\n",
+ "gm.eval()\n",
+ "images = []\n",
+ "expected = []\n",
+ "masks = []\n",
+ "masks_expected = []\n",
+ "for frame in tqdm.tqdm(train_data_collated):\n",
+ "    with torch.no_grad():\n",
+ "        out = gm(**frame, evaluation_mode=EvaluationMode.EVALUATION)\n",
+ "\n",
+ "    image_rgb = to_numpy_image(out[\"images_render\"][0])\n",
+ "    mask = to_numpy_image(out[\"masks_render\"][0])\n",
+ "    expd = to_numpy_image(resize_image(frame.image_rgb)[0])\n",
+ "    mask_expected = to_numpy_image(resize_image(frame.fg_probability)[0])\n",
+ "\n",
+ "    images.append(image_rgb)\n",
+ "    masks.append(mask)\n",
+ "    expected.append(expd)\n",
+ "    masks_expected.append(mask_expected)"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "originalKey": "24953039-9780-40fd-bd81-5d63e9f40069",
+ "showInput": false,
+ "customInput": null,
+ "collapsed": false,
+ "requestMsgId": "7af895a3-dfe4-4c28-ac3b-4ff0fbb40c7f",
+ "customOutput": null,
+ "executionStartTime": 1659614622542,
+ "executionStopTime": 1659614622757
+ },
+ "source": [
+ "We draw a grid showing the predicted and expected image, followed by the predicted and expected mask, from each viewpoint.\n",
+ "This is a grid of four rows of images, wrapped into several large rows, i.e.:\n",
+ "\n",
\n", + "```\n", + "β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”\n", + "β”‚pred β”‚pred β”‚ β”‚pred β”‚\n", + "β”‚image β”‚image β”‚ β”‚image β”‚\n", + "β”‚1 β”‚2 β”‚ β”‚n β”‚\n", + "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€\n", + "β”‚expectedβ”‚expectedβ”‚ β”‚expectedβ”‚\n", + "β”‚image β”‚image β”‚ ... β”‚image β”‚\n", + "β”‚1 β”‚2 β”‚ β”‚n β”‚\n", + "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€\n", + "β”‚pred β”‚pred β”‚ β”‚pred β”‚\n", + "β”‚mask β”‚mask β”‚ β”‚mask β”‚\n", + "β”‚1 β”‚2 β”‚ β”‚n β”‚\n", + "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€\n", + "β”‚expectedβ”‚expectedβ”‚ β”‚expectedβ”‚\n", + "β”‚mask β”‚mask β”‚ β”‚mask β”‚\n", + "β”‚1 β”‚2 β”‚ β”‚n β”‚\n", + "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€\n", + "β”‚pred β”‚pred β”‚ β”‚pred β”‚\n", + "β”‚image β”‚image β”‚ β”‚image β”‚\n", + "β”‚n+1 β”‚n+1 β”‚ β”‚2n β”‚\n", + "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€\n", + "β”‚expectedβ”‚expectedβ”‚ β”‚expectedβ”‚\n", + "β”‚image β”‚image β”‚ ... β”‚image β”‚\n", + "β”‚n+1 β”‚n+2 β”‚ β”‚2n β”‚\n", + "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€\n", + "β”‚pred β”‚pred β”‚ β”‚pred β”‚\n", + "β”‚mask β”‚mask β”‚ β”‚mask β”‚\n", + "β”‚n+1 β”‚n+2 β”‚ β”‚2n β”‚\n", + "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€\n", + "β”‚expectedβ”‚expectedβ”‚ β”‚expectedβ”‚\n", + "β”‚mask β”‚mask β”‚ β”‚mask β”‚\n", + "β”‚n+1 β”‚n+2 β”‚ β”‚2n β”‚\n", + "β””β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”˜\n", + " ...\n", + "```\n", + "
" + ], + "attachments": {} + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "c488a34a-e46d-4649-93fb-4b1bb5a0e439", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "4221e632-fca1-4fe5-b2e3-f92c37aa40e4", + "customOutput": null, + "executionStartTime": 1659621313894, + "executionStopTime": 1659621314042, + "code_folding": [], + "hidden_ranges": [] + }, + "source": [ + "images_to_display = [images.copy(), expected.copy(), masks.copy(), masks_expected.copy()]\n", + "n_rows = 4\n", + "n_images = len(images)\n", + "blank_image = images[0] * 0\n", + "n_per_row = 1+(n_images-1)//n_rows\n", + "for _ in range(n_per_row*n_rows - n_images):\n", + " for group in images_to_display:\n", + " group.append(blank_image)\n", + "\n", + "images_to_display_listed = [[[i] for i in j] for j in images_to_display]\n", + "split = []\n", + "for row in range(n_rows):\n", + " for group in images_to_display_listed:\n", + " split.append(group[row*n_per_row:(row+1)*n_per_row]) \n", + "\n", + "Image.fromarray(np.block(split))\n", + "" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "49eab9e1-4fe2-4fbe-b4f3-7b6953340170", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "85b402ad-f903-431f-a13e-c2d697e869bb", + "customOutput": null, + "executionStartTime": 1659621323795, + "executionStopTime": 1659621323820, + "code_folding": [], + "hidden_ranges": [] + }, + "source": [ + "# Print the maximum channel intensity in the first image.\n", + "print(images[1].max()/255)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "137d2c43-d39d-4266-ac5e-2b714da5e0ee", + "showInput": true, + "customInput": null, + "customOutput": null, + "code_folding": [], + "hidden_ranges": [], + "collapsed": false, + "requestMsgId": "8e27ec57-c2d6-4ae0-be69-b63b6af929ff", + "executionStartTime": 1659621408642, + "executionStopTime": 1659621409559 + }, + "source": [ + "plt.ioff()\n", + "fig, ax = plt.subplots(figsize=(3,3))\n", + "\n", + "ax.grid(None)\n", + "ims = [[ax.imshow(im, animated=True)] for im in images]\n", + "ani = animation.ArtistAnimation(fig, ims, interval=80, blit=True)\n", + "ani_html = ani.to_jshtml()\n", + "" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "783e70d6-7cf1-4d76-a126-ba11ffc2f5be", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "b6843506-c5fa-4508-80fc-8ecae51a934a", + "customOutput": null, + "executionStartTime": 1659621409620, + "executionStopTime": 1659621409725 + }, + "source": [ + "HTML(ani_html)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "originalKey": "0286c350-2362-4f47-8181-2fc2ba51cfcf", + "showInput": true, + "customInput": null, + "collapsed": false, + "requestMsgId": "976f4db9-d4c7-466c-bcfd-218234400226", + "customOutput": null, + "executionStartTime": 1659614670081, + "executionStopTime": 1659614670168 + }, + "source": [ + "# If you want to see the output of the model with the volume forced to opaque white, run this and re-evaluate\n", + "# with torch.no_grad():\n", + "# gm._implicit_functions[0]._fn.density.fill_(9.0)\n", + "# gm._implicit_functions[0]._fn.color.fill_(9.0)\n", + "" + ], + "execution_count": null, + "outputs": [] + } + ] +}