{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "customInput": null,
    "customOutput": null,
    "executionStartTime": 1659619824914,
    "executionStopTime": 1659619825485,
    "originalKey": "d38652e8-200a-413c-a36a-f4d349b78a9d",
    "requestMsgId": "641de8aa-0e42-4446-9304-c160a2d226bf",
    "showInput": true
   },
   "outputs": [],
   "source": [
    "# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "customInput": null,
    "originalKey": "a48a9dcf-e80f-474b-a0c4-2c9a765b15c5",
    "showInput": false
   },
   "source": [
    "# A simple model using Implicitron\n",
    "\n",
    "In this demo, we use the VolumeRenderer from PyTorch3D as a custom implicit function in Implicitron. We will see\n",
    "* some of the main objects in Implicitron\n",
    "* how to plug in a custom part of a model"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "customInput": null,
    "originalKey": "51337c0e-ad27-4b75-ad6a-737dca5d7b95",
    "showInput": false
   },
   "source": [
    "## 0. Install and import modules\n",
    "\n",
    "Ensure `torch` and `torchvision` are installed. If `pytorch3d` is not installed, install it using the following cell:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "customInput": null,
    "customOutput": null,
    "executionStartTime": 1659619898147,
    "executionStopTime": 1659619898274,
    "originalKey": "76f1ecd4-6b73-4214-81b0-118ef8d86872",
    "requestMsgId": "deb6a860-6923-4227-abef-d31388b5142d",
    "showInput": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import torch\n",
    "import subprocess\n",
    "need_pytorch3d = False\n",
    "try:\n",
    "    import pytorch3d\n",
    "except ModuleNotFoundError:\n",
    "    need_pytorch3d = True\n",
    "if need_pytorch3d:\n",
    "    pyt_version_str = torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
    "    version_str = \"\".join([\n",
    "        f\"py3{sys.version_info.minor}_cu\",\n",
    "        torch.version.cuda.replace(\".\", \"\"),\n",
    "        f\"_pyt{pyt_version_str}\"\n",
    "    ])\n",
    "    !pip install iopath\n",
    "    if sys.platform.startswith(\"linux\"):\n",
    "        print(\"Trying to install wheel for PyTorch3D\")\n",
    "        !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
    "        pip_list = !pip freeze\n",
    "        need_pytorch3d = not any(i.startswith(\"pytorch3d==\") for i in pip_list)\n",
    "    if need_pytorch3d:\n",
    "        print(f\"failed to find/install wheel for {version_str}\")\n",
    "if need_pytorch3d:\n",
    "    print(\"Installing PyTorch3D from source\")\n",
    "    !pip install ninja\n",
    "    !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "customInput": null,
    "originalKey": "2c1020e6-eb4a-4644-9719-9147500d8e4f",
    "showInput": false
   },
   "source": [
    "Ensure omegaconf and visdom are installed. If not, run this cell. (It should not be necessary to restart the runtime.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "customInput": null,
    "customOutput": null,
    "originalKey": "9e751931-a38d-44c9-9ff1-ac2f7d3a3f99",
    "showInput": true
   },
   "outputs": [],
   "source": [
    "!pip install omegaconf visdom"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [],
    "collapsed": false,
    "customOutput": null,
    "executionStartTime": 1659612480556,
    "executionStopTime": 1659612480644,
    "hidden_ranges": [],
    "originalKey": "86807e4a-1675-4520-a033-c7af85b233ec",
    "requestMsgId": "880a7e20-4a90-4b37-a5eb-bccc0b23cac6"
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "from typing import Tuple\n",
    "\n",
    "import matplotlib.animation as animation\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import tqdm\n",
    "from IPython.display import HTML\n",
    "from omegaconf import OmegaConf\n",
    "from PIL import Image\n",
    "from pytorch3d.implicitron.dataset.dataset_base import FrameData\n",
    "from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider\n",
    "from pytorch3d.implicitron.models.generic_model import GenericModel\n",
    "from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase, ImplicitronRayBundle\n",
    "from pytorch3d.implicitron.models.renderer.base import EvaluationMode\n",
    "from pytorch3d.implicitron.tools.config import get_default_args, registry, remove_unused_components\n",
    "from pytorch3d.renderer.implicit.renderer import VolumeSampler\n",
    "from pytorch3d.structures import Volumes\n",
    "from pytorch3d.vis.plotly_vis import plot_batch_individually, plot_scene"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [],
    "collapsed": false,
    "customInput": null,
    "customOutput": null,
    "executionStartTime": 1659610929375,
    "executionStopTime": 1659610929383,
    "hidden_ranges": [],
    "originalKey": "b2d9f5bd-a9d4-4f78-b21e-92f2658e0fe9",
    "requestMsgId": "7e43e623-4030-438b-af4e-b96170c9a052",
    "showInput": true
   },
   "outputs": [],
   "source": [
    "output_resolution = 80"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [],
    "collapsed": false,
    "customInput": null,
    "customOutput": null,
    "executionStartTime": 1659610930042,
    "executionStopTime": 1659610930050,
    "hidden_ranges": [],
    "originalKey": "0b0c2087-4c86-4c57-b0ee-6f48a70a9c78",
    "requestMsgId": "46883aad-f00b-4fd4-ac17-eec0b2ac272a",
    "showInput": true
   },
   "outputs": [],
   "source": [
    "torch.set_printoptions(sci_mode=False)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "customInput": null,
    "originalKey": "37809d0d-b02e-42df-85b6-cdd038373653",
    "showInput": false
   },
   "source": [
    "## 1. Load renders of a mesh (the cow mesh) as a dataset\n",
    "\n",
    "A dataset's train, val and test parts in Implicitron are represented as a `dataset_map`, and provided by an implementation of `DatasetMapProvider`.\n",
    "`RenderedMeshDatasetMapProvider` is one which generates a single-scene dataset with only a train component by taking a mesh and rendering it.\n",
    "We use it with the cow mesh."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "customInput": null,
    "customOutput": null,
    "executionStartTime": 1659620739780,
    "executionStopTime": 1659620739914,
    "originalKey": "cc68cb9c-b8bf-4e9e-bef1-2cfafdf6caa2",
    "requestMsgId": "398cfcae-5d43-4b6f-9c75-db3d297364d4",
    "showInput": false
   },
   "source": [
    "If running this notebook using **Google Colab**, run the following cell to fetch the mesh obj and texture files and save them at the path `data/cow_mesh`.\n",
    "If running locally, the data is already available at the correct path."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "customInput": null,
    "customOutput": null,
    "originalKey": "2c55e002-a885-4169-8fdc-af9078b05968",
    "showInput": true
   },
   "outputs": [],
   "source": [
    "!mkdir -p data/cow_mesh\n",
    "!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.obj\n",
    "!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.mtl\n",
    "!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow_texture.png"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [],
    "collapsed": false,
    "customOutput": null,
    "executionStartTime": 1659621652237,
    "executionStopTime": 1659621652903,
    "hidden_ranges": [],
    "originalKey": "eb77aaec-048c-40bd-bd69-0e66b6ab60b1",
    "requestMsgId": "09b9975c-ff86-41c9-b4a9-975d23afc562",
    "showInput": true
   },
   "outputs": [],
   "source": [
    "cow_provider = RenderedMeshDatasetMapProvider(\n",
    "    data_file=\"data/cow_mesh/cow.obj\",\n",
    "    use_point_light=False,\n",
    "    resolution=output_resolution,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [],
    "collapsed": false,
    "customInput": null,
    "customOutput": null,
    "executionStartTime": 1659610966145,
    "executionStopTime": 1659610966255,
    "hidden_ranges": [],
    "originalKey": "8210e15b-da48-4306-a49a-41c4e7e7d42f",
    "requestMsgId": "c243edd2-a106-4fba-8471-dfa4f99a2088",
    "showInput": true
   },
   "outputs": [],
   "source": [
    "dataset_map = cow_provider.get_dataset_map()\n",
    "tr_cameras = [training_frame.camera for training_frame in dataset_map.train]"
   ]
  },
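  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, we can look at the size of the generated train split and the shape of one rendered frame. This is a minimal sketch: it assumes the train split supports `len()` and indexing, and that each frame is a `FrameData` carrying an `image_rgb` tensor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A quick look at the generated dataset (a sketch; assumes the train split\n",
    "# supports len() and indexing, and frames carry an image_rgb tensor).\n",
    "print(f\"number of training frames: {len(dataset_map.train)}\")\n",
    "example_frame = dataset_map.train[0]\n",
    "print(f\"image_rgb shape: {tuple(example_frame.image_rgb.shape)}\")  # (C, H, W)"
   ]
  },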
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [],
    "collapsed": false,
    "customInput": null,
    "customOutput": null,
    "executionStartTime": 1659610967703,
    "executionStopTime": 1659610967848,
    "hidden_ranges": [],
    "originalKey": "458d72ad-d9a7-4f13-b5b7-90d2aec61c16",
    "requestMsgId": "7f9431f3-8717-4d89-a7fe-1420dd0e00c4",
    "showInput": true
   },
   "outputs": [],
   "source": [
    "# The cameras are all in the XZ plane, in a circle about 2.7 from the origin\n",
    "centers = torch.cat([i.get_camera_center() for i in tr_cameras])\n",
    "print(centers.min(0).values)\n",
    "print(centers.max(0).values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [],
    "collapsed": false,
    "customInput": null,
    "customOutput": null,
    "executionStartTime": 1659552920194,
    "executionStopTime": 1659552923122,
    "hidden_ranges": [],
    "originalKey": "931e712b-b141-437a-97fb-dc2a07ce3458",
    "requestMsgId": "931e712b-b141-437a-97fb-dc2a07ce3458",
    "showInput": true
   },
   "outputs": [],
   "source": [
    "# visualization of the cameras\n",
    "plot = plot_scene({\"k\": {i: camera for i, camera in enumerate(tr_cameras)}}, camera_scale=0.25)\n",
    "plot.layout.scene.aspectmode = \"data\"\n",
    "plot"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "customInput": null,
    "originalKey": "afa9c02d-f76b-4f68-83e9-9733c615406b",
    "showInput": false
   },
   "source": [
    "## 2. Custom implicit function 🧊\n",
    "\n",
    "At the core of neural rendering methods are functions of spatial coordinates called implicit functions, which are used in some kind of rendering process.\n",
    "(Often those functions can additionally take other data as well, such as view direction.)\n",
    "A common rendering process is ray marching over densities and colors provided by an implicit function.\n",
    "In our case, taking samples from a 3D volume grid is a very simple function of spatial coordinates.\n",
    "\n",
    "Here we define our own implicit function, which uses PyTorch3D's existing functionality for sampling from a volume grid.\n",
    "We do this by subclassing `ImplicitFunctionBase`.\n",
    "We need to register our subclass with a special decorator.\n",
    "We use Python's dataclass annotations for configuring the module."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [],
    "collapsed": false,
    "customInput": null,
    "customOutput": null,
    "executionStartTime": 1659613575850,
    "executionStopTime": 1659613575940,
    "hidden_ranges": [],
    "originalKey": "61b55043-dc52-4de7-992e-e2195edd2123",
    "requestMsgId": "dfaace3c-098c-4ffe-9240-6a7ae0ff271e",
    "showInput": true
   },
   "outputs": [],
   "source": [
    "@registry.register\n",
    "class MyVolumes(ImplicitFunctionBase, torch.nn.Module):\n",
    "    grid_resolution: int = 50  # common H, W, D of the volume: the number of voxels in each direction\n",
    "    extent: float = 1.0  # In world coordinates, the volume occupies [-extent, extent] along each axis\n",
    "\n",
    "    def __post_init__(self):\n",
    "        # We have to call this explicitly if there are other base classes like Module\n",
    "        super().__init__()\n",
    "\n",
    "        # We define parameters like other torch.nn.Module objects.\n",
    "        # In this case, both our parameter tensors are trainable; they govern the contents of the volume grid.\n",
    "        density = torch.full((self.grid_resolution, self.grid_resolution, self.grid_resolution), -2.0)\n",
    "        self.density = torch.nn.Parameter(density)\n",
    "        color = torch.full((3, self.grid_resolution, self.grid_resolution, self.grid_resolution), 0.0)\n",
    "        self.color = torch.nn.Parameter(color)\n",
    "        self.density_activation = torch.nn.Softplus()\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        ray_bundle: ImplicitronRayBundle,\n",
    "        fun_viewpool=None,\n",
    "        global_code=None,\n",
    "    ):\n",
    "        densities = self.density_activation(self.density[None, None])\n",
    "        voxel_size = 2.0 * float(self.extent) / self.grid_resolution\n",
    "        features = self.color.sigmoid()[None]\n",
    "\n",
    "        # Like other PyTorch3D structures, the actual Volumes object should only exist as long\n",
    "        # as one iteration of training. It is local to this function.\n",
    "\n",
    "        volume = Volumes(densities=densities, features=features, voxel_size=voxel_size)\n",
    "        sampler = VolumeSampler(volumes=volume)\n",
    "        densities, features = sampler(ray_bundle)\n",
    "\n",
    "        # When an implicit function is used for raymarching, i.e. for MultiPassEmissionAbsorptionRenderer,\n",
    "        # it must return (densities, features, and an auxiliary dict).\n",
    "        return densities, features, {}\n"
   ]
  },
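  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because of the `@registry.register` decorator above, Implicitron can now find `MyVolumes` by name. The following cell is a small sketch of how to verify this; it assumes `registry.get(base_class, name)` is available in this version of `pytorch3d.implicitron.tools.config`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check (a sketch): look up the new implicit function in the registry.\n",
    "# Assumes registry.get(base_class, name) exists in this version.\n",
    "MyVolumesClass = registry.get(ImplicitFunctionBase, \"MyVolumes\")\n",
    "print(MyVolumesClass is MyVolumes)  # expected: True"
   ]
  },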
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "customInput": null,
    "originalKey": "abaf2cd6-1b68-400e-a142-8fb9f49953f3",
    "showInput": false
   },
   "source": [
    "## 3. Construct the model object\n",
    "\n",
    "The main model object in Implicitron is `GenericModel`, which has pluggable components for the major steps, including the renderer and the implicit function(s).\n",
    "There are two ways to construct it, which are equivalent here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "customInput": null,
    "customOutput": null,
    "executionStartTime": 1659621267561,
    "executionStopTime": 1659621267938,
    "originalKey": "f26c3dce-fbae-4592-bd0e-e4a8abc57c2c",
    "requestMsgId": "9213687e-1caf-46a8-a4e5-a9c531530092",
    "showInput": true
   },
   "outputs": [],
   "source": [
    "CONSTRUCT_MODEL_FROM_CONFIG = True\n",
    "if CONSTRUCT_MODEL_FROM_CONFIG:\n",
    "    # Via a DictConfig - this is how our training loop with Hydra works\n",
    "    cfg = get_default_args(GenericModel)\n",
    "    cfg.implicit_function_class_type = \"MyVolumes\"\n",
    "    cfg.render_image_height = output_resolution\n",
    "    cfg.render_image_width = output_resolution\n",
    "    cfg.loss_weights = {\"loss_rgb_huber\": 1.0}\n",
    "    cfg.tqdm_trigger_threshold = 19000\n",
    "    cfg.raysampler_AdaptiveRaySampler_args.scene_extent = 4.0\n",
    "    gm = GenericModel(**cfg)\n",
    "else:\n",
    "    # Constructing GenericModel directly\n",
    "    gm = GenericModel(\n",
    "        implicit_function_class_type=\"MyVolumes\",\n",
    "        render_image_height=output_resolution,\n",
    "        render_image_width=output_resolution,\n",
    "        loss_weights={\"loss_rgb_huber\": 1.0},\n",
    "        tqdm_trigger_threshold=19000,\n",
    "        raysampler_AdaptiveRaySampler_args={\"scene_extent\": 4.0},\n",
    "    )\n",
    "\n",
    "    # In this case, we can get the DictConfig cfg object equivalent to the way gm is configured as follows\n",
    "    cfg = OmegaConf.structured(gm)\n"
   ]
  },
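  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Either way, the implicit function actually instantiated inside the model should be our `MyVolumes`. The following sketch confirms this via the private `_implicit_functions` wrapper list, which the commented-out cell at the end of this notebook also relies on; being private, its layout may change between versions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Confirm the model picked up the custom implicit function (a sketch;\n",
    "# _implicit_functions is private, so its layout may change between versions).\n",
    "print(type(gm._implicit_functions[0]._fn).__name__)  # expected: MyVolumes"
   ]
  },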
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "code_folding": [],
    "collapsed": false,
    "customInput": null,
    "customOutput": null,
    "executionStartTime": 1659611214689,
    "executionStopTime": 1659611214748,
    "hidden_ranges": [],
    "originalKey": "4e659f7d-ce66-4999-83de-005eb09d7705",
    "requestMsgId": "7b815b2b-cf19-44d0-ae89-76fde6df35ec",
    "showInput": false
   },
   "source": [
    "The default renderer is an emission-absorption raymarcher. We keep that default."
   ]
  },
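  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For instance, we can read the renderer choice straight off the config. This is a sketch which assumes the field is named `renderer_class_type`, matching the `*_class_type` naming used for the implicit function above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the renderer choice (a sketch; assumes the field is named\n",
    "# renderer_class_type, like implicit_function_class_type above).\n",
    "print(cfg.renderer_class_type)"
   ]
  },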
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [],
    "collapsed": false,
    "customInput": null,
    "customOutput": null,
    "executionStartTime": 1659621268007,
    "executionStopTime": 1659621268190,
    "hidden_ranges": [],
    "originalKey": "d37ae488-c57c-44d3-9def-825dc1a6495b",
    "requestMsgId": "71143ec1-730f-4876-8a14-e46eea9d6dd1",
    "showInput": true
   },
   "outputs": [],
   "source": [
    "# We can display the configuration in use as follows.\n",
    "remove_unused_components(cfg)\n",
    "yaml = OmegaConf.to_yaml(cfg, sort_keys=False)\n",
    "%page -r yaml"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [],
    "collapsed": false,
    "customInput": null,
    "customOutput": null,
    "executionStartTime": 1659621268727,
    "executionStopTime": 1659621268776,
    "hidden_ranges": [],
    "originalKey": "52e53179-3c6e-4c1f-a38a-3a6d803687bb",
    "requestMsgId": "05de9bc3-3f74-4a6f-851c-9ec919b59506",
    "showInput": true
   },
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\")\n",
    "gm.to(device)\n",
    "assert next(gm.parameters()).is_cuda"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "customInput": null,
    "originalKey": "528a7d53-c645-49c2-9021-09adbb18cd23",
    "showInput": false
   },
   "source": [
    "## 4. Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [],
    "collapsed": false,
    "customInput": null,
    "customOutput": null,
    "executionStartTime": 1659621270236,
    "executionStopTime": 1659621270446,
    "hidden_ranges": [],
    "originalKey": "953280bd-3161-42ba-8dcb-0c8ef2d5cc25",
    "requestMsgId": "9bba424b-7bfd-4e5a-9d79-ae316e20bab0",
    "showInput": true
   },
   "outputs": [],
   "source": [
    "train_data_collated = [FrameData.collate([frame.to(device)]) for frame in dataset_map.train]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [],
    "collapsed": false,
    "customInput": null,
    "customOutput": null,
    "executionStartTime": 1659621270815,
    "executionStopTime": 1659621270948,
    "hidden_ranges": [],
    "originalKey": "2fcf07f0-0c28-49c7-8c76-1c9a9d810167",
    "requestMsgId": "821deb43-6084-4ece-83c3-dee214562c47",
    "showInput": true
   },
   "outputs": [],
   "source": [
    "gm.train()\n",
    "optimizer = torch.optim.Adam(gm.parameters(), lr=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [],
    "collapsed": false,
    "customOutput": null,
    "executionStartTime": 1659621271875,
    "executionStopTime": 1659621298146,
    "hidden_ranges": [],
    "originalKey": "105099f7-ed0c-4e7f-a976-61a93fd0a8fe",
    "requestMsgId": "0c87c108-83e3-4129-ad02-85e0140f1368",
    "showInput": true
   },
   "outputs": [],
   "source": [
    "iterator = tqdm.tqdm(range(2000))\n",
    "for n_batch in iterator:\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    frame = train_data_collated[n_batch % len(dataset_map.train)]\n",
    "    out = gm(**frame, evaluation_mode=EvaluationMode.TRAINING)\n",
    "    out[\"objective\"].backward()\n",
    "    if n_batch % 100 == 0:\n",
    "        iterator.set_postfix_str(f\"loss: {float(out['objective']):.5f}\")\n",
    "    optimizer.step()"
   ]
  },
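  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you want to keep the trained volume, the weights can be saved in the usual PyTorch way; `simple_model.pth` below is just an example filename."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optionally save the trained weights (standard PyTorch checkpointing;\n",
    "# \"simple_model.pth\" is an example filename).\n",
    "torch.save(gm.state_dict(), \"simple_model.pth\")\n",
    "# To restore later into a model built with the same cfg:\n",
    "# gm.load_state_dict(torch.load(\"simple_model.pth\"))"
   ]
  },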
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "customInput": null,
    "customOutput": null,
    "executionStartTime": 1659535024768,
    "executionStopTime": 1659535024906,
    "originalKey": "e3cd494a-536b-48bc-8290-c048118c82eb",
    "requestMsgId": "e3cd494a-536b-48bc-8290-c048118c82eb",
    "showInput": false
   },
   "source": [
    "## 5. Evaluate the module\n",
    "\n",
    "We generate complete images from all the viewpoints to see how they look."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [],
    "collapsed": false,
    "customInput": null,
    "customOutput": null,
    "executionStartTime": 1659621299859,
    "executionStopTime": 1659621311133,
    "hidden_ranges": [],
    "originalKey": "fbe1b2ea-cc24-4b20-a2d7-0249185e34a5",
    "requestMsgId": "771ef1f8-5eee-4932-9e81-33604bf0512a",
    "showInput": true
   },
   "outputs": [],
   "source": [
    "def to_numpy_image(image):\n",
    "    # Takes an image of shape (C, H, W) in [0, 1], where C = 3 or 1,\n",
    "    # to a numpy uint8 image of shape (H, W, 3).\n",
    "    return (image * 255).to(torch.uint8).permute(1, 2, 0).detach().cpu().expand(-1, -1, 3).numpy()\n",
    "\n",
    "def resize_image(image):\n",
    "    # Takes images of shape (B, C, H, W) to (B, C, output_resolution, output_resolution)\n",
    "    return torch.nn.functional.interpolate(image, size=(output_resolution, output_resolution))\n",
    "\n",
    "gm.eval()\n",
    "images = []\n",
    "expected = []\n",
    "masks = []\n",
    "masks_expected = []\n",
    "for frame in tqdm.tqdm(train_data_collated):\n",
    "    with torch.no_grad():\n",
    "        out = gm(**frame, evaluation_mode=EvaluationMode.EVALUATION)\n",
    "\n",
    "    image_rgb = to_numpy_image(out[\"images_render\"][0])\n",
    "    mask = to_numpy_image(out[\"masks_render\"][0])\n",
    "    expd = to_numpy_image(resize_image(frame.image_rgb)[0])\n",
    "    mask_expected = to_numpy_image(resize_image(frame.fg_probability)[0])\n",
    "\n",
    "    images.append(image_rgb)\n",
    "    masks.append(mask)\n",
    "    expected.append(expd)\n",
    "    masks_expected.append(mask_expected)"
   ]
  },
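  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough quantitative check to go with the visual grid below, we can compute the mean absolute pixel error between each rendered image and its ground-truth render. This is a sketch using the uint8 numpy arrays collected above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean absolute pixel error per view, in [0, 255] units (a rough check\n",
    "# using the uint8 arrays collected above).\n",
    "errors = [\n",
    "    np.abs(im.astype(np.float32) - ex.astype(np.float32)).mean()\n",
    "    for im, ex in zip(images, expected)\n",
    "]\n",
    "print(f\"mean abs error over {len(errors)} views: {np.mean(errors):.2f}\")"
   ]
  },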
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "customInput": null,
    "customOutput": null,
    "executionStartTime": 1659614622542,
    "executionStopTime": 1659614622757,
    "originalKey": "24953039-9780-40fd-bd81-5d63e9f40069",
    "requestMsgId": "7af895a3-dfe4-4c28-ac3b-4ff0fbb40c7f",
    "showInput": false
   },
   "source": [
    "We draw a grid showing the predicted image and expected image, followed by the predicted mask and expected mask, from each viewpoint.\n",
    "This is a grid of four rows of images, wrapped into several large rows, i.e.:\n",
    "<small><center>\n",
    "```\n",
    "┌────────┬────────┐     ┌────────┐\n",
    "│pred    │pred    │     │pred    │\n",
    "│image   │image   │     │image   │\n",
    "│1       │2       │     │n       │\n",
    "├────────┼────────┤     ├────────┤\n",
    "│expected│expected│     │expected│\n",
    "│image   │image   │ ... │image   │\n",
    "│1       │2       │     │n       │\n",
    "├────────┼────────┤     ├────────┤\n",
    "│pred    │pred    │     │pred    │\n",
    "│mask    │mask    │     │mask    │\n",
    "│1       │2       │     │n       │\n",
    "├────────┼────────┤     ├────────┤\n",
    "│expected│expected│     │expected│\n",
    "│mask    │mask    │     │mask    │\n",
    "│1       │2       │     │n       │\n",
    "├────────┼────────┤     ├────────┤\n",
    "│pred    │pred    │     │pred    │\n",
    "│image   │image   │     │image   │\n",
    "│n+1     │n+2     │     │2n      │\n",
    "├────────┼────────┤     ├────────┤\n",
    "│expected│expected│     │expected│\n",
    "│image   │image   │ ... │image   │\n",
    "│n+1     │n+2     │     │2n      │\n",
    "├────────┼────────┤     ├────────┤\n",
    "│pred    │pred    │     │pred    │\n",
    "│mask    │mask    │     │mask    │\n",
    "│n+1     │n+2     │     │2n      │\n",
    "├────────┼────────┤     ├────────┤\n",
    "│expected│expected│     │expected│\n",
    "│mask    │mask    │     │mask    │\n",
    "│n+1     │n+2     │     │2n      │\n",
    "└────────┴────────┘     └────────┘\n",
    "          ...\n",
    "```\n",
    "</center></small>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [],
    "collapsed": false,
    "customInput": null,
    "customOutput": null,
    "executionStartTime": 1659621313894,
    "executionStopTime": 1659621314042,
    "hidden_ranges": [],
    "originalKey": "c488a34a-e46d-4649-93fb-4b1bb5a0e439",
    "requestMsgId": "4221e632-fca1-4fe5-b2e3-f92c37aa40e4",
    "showInput": true
   },
   "outputs": [],
   "source": [
    "images_to_display = [images.copy(), expected.copy(), masks.copy(), masks_expected.copy()]\n",
    "n_rows = 4\n",
    "n_images = len(images)\n",
    "blank_image = images[0] * 0\n",
    "n_per_row = 1 + (n_images - 1) // n_rows\n",
    "for _ in range(n_per_row * n_rows - n_images):\n",
    "    for group in images_to_display:\n",
    "        group.append(blank_image)\n",
    "\n",
    "images_to_display_listed = [[[i] for i in j] for j in images_to_display]\n",
    "split = []\n",
    "for row in range(n_rows):\n",
    "    for group in images_to_display_listed:\n",
    "        split.append(group[row * n_per_row:(row + 1) * n_per_row])\n",
    "\n",
    "Image.fromarray(np.block(split))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [],
    "collapsed": false,
    "customInput": null,
    "customOutput": null,
    "executionStartTime": 1659621323795,
    "executionStopTime": 1659621323820,
    "hidden_ranges": [],
    "originalKey": "49eab9e1-4fe2-4fbe-b4f3-7b6953340170",
    "requestMsgId": "85b402ad-f903-431f-a13e-c2d697e869bb",
    "showInput": true
   },
   "outputs": [],
   "source": [
    "# Print the maximum channel intensity in the second image.\n",
    "print(images[1].max() / 255)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [],
    "collapsed": false,
    "customInput": null,
    "customOutput": null,
    "executionStartTime": 1659621408642,
    "executionStopTime": 1659621409559,
    "hidden_ranges": [],
    "originalKey": "137d2c43-d39d-4266-ac5e-2b714da5e0ee",
    "requestMsgId": "8e27ec57-c2d6-4ae0-be69-b63b6af929ff",
    "showInput": true
   },
   "outputs": [],
   "source": [
    "plt.ioff()\n",
    "fig, ax = plt.subplots(figsize=(3, 3))\n",
    "\n",
    "ax.grid(None)\n",
    "ims = [[ax.imshow(im, animated=True)] for im in images]\n",
    "ani = animation.ArtistAnimation(fig, ims, interval=80, blit=True)\n",
    "ani_html = ani.to_jshtml()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "customInput": null,
    "customOutput": null,
    "executionStartTime": 1659621409620,
    "executionStopTime": 1659621409725,
    "originalKey": "783e70d6-7cf1-4d76-a126-ba11ffc2f5be",
    "requestMsgId": "b6843506-c5fa-4508-80fc-8ecae51a934a",
    "showInput": true
   },
   "outputs": [],
   "source": [
    "HTML(ani_html)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "customInput": null,
    "customOutput": null,
    "executionStartTime": 1659614670081,
    "executionStopTime": 1659614670168,
    "originalKey": "0286c350-2362-4f47-8181-2fc2ba51cfcf",
    "requestMsgId": "976f4db9-d4c7-466c-bcfd-218234400226",
    "showInput": true
   },
   "outputs": [],
   "source": [
    "# If you want to see the output of the model with the volume forced to opaque white, run this and re-evaluate.\n",
    "# with torch.no_grad():\n",
    "#     gm._implicit_functions[0]._fn.density.fill_(9.0)\n",
    "#     gm._implicit_functions[0]._fn.color.fill_(9.0)\n"
   ]
  }
 ],
 "metadata": {
  "bento_stylesheets": {
   "bento/extensions/flow/main.css": true,
   "bento/extensions/kernel_selector/main.css": true,
   "bento/extensions/kernel_ui/main.css": true,
   "bento/extensions/new_kernel/main.css": true,
   "bento/extensions/system_usage/main.css": true,
   "bento/extensions/theme/main.css": true
  },
  "captumWidgetMessage": {},
  "dataExplorerConfig": {},
  "kernelspec": {
   "display_name": "pytorch3d",
   "language": "python",
   "metadata": {
    "cinder_runtime": false,
    "fbpkg_supported": true,
    "is_prebuilt": true,
    "kernel_name": "bento_kernel_pytorch3d",
    "nightly_builds": true
   },
   "name": "bento_kernel_pytorch3d"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  },
  "last_base_url": "https://9177.od.fbinfra.net:443/",
  "last_kernel_id": "bb33cd83-7924-489a-8bd8-2d9d62eb0126",
  "last_msg_id": "99f7088e-d22b355b859660479ef0574e_5743",
  "last_server_session_id": "2944b203-9ea8-4c0e-9634-645dfea5f26b",
  "outputWidgetContext": {}
 },
 "nbformat": 4,
 "nbformat_minor": 2
}