{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "-P3OUvJirQdR"
},
"outputs": [],
"source": [
"# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "44lB2sH-rQdW"
},
"source": [
"# Camera position optimization using differentiable rendering\n",
"\n",
"In this tutorial we will learn the [x, y, z] position of a camera given a reference image, using differentiable rendering.\n",
"\n",
"We will first initialize a renderer with a starting position for the camera. We will then use this renderer to generate an image, compute a loss with respect to the reference image, and finally backpropagate through the entire pipeline to update the position of the camera.\n",
"\n",
"This tutorial shows how to:\n",
"- load a mesh from an `.obj` file\n",
"- initialize a `Camera`, `Shader` and `Renderer`\n",
"- render a mesh\n",
"- set up an optimization loop with a loss function and optimizer\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "AZGmIlmWrQdX"
},
"source": [
"## 0. Install and import modules"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "qkX7DiM6rmeM"
},
"source": [
"Ensure `torch` and `torchvision` are installed. If `pytorch3d` is not installed, install it using the following cell:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 717
},
"colab_type": "code",
"id": "sEVdNGFwripM",
"outputId": "27047061-a29b-4562-c164-c1288e24c266"
},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"import torch\n",
"need_pytorch3d=False\n",
"try:\n",
"    import pytorch3d\n",
"except ModuleNotFoundError:\n",
"    need_pytorch3d=True\n",
"if need_pytorch3d:\n",
"    if torch.__version__.startswith(\"2.1.\") and sys.platform.startswith(\"linux\"):\n",
"        # We try to install PyTorch3D via a released wheel.\n",
"        pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
"        version_str=\"\".join([\n",
"            f\"py3{sys.version_info.minor}_cu\",\n",
"            torch.version.cuda.replace(\".\",\"\"),\n",
"            f\"_pyt{pyt_version_str}\"\n",
"        ])\n",
"        !pip install fvcore iopath\n",
"        !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
"    else:\n",
"        # We try to install PyTorch3D from source.\n",
"        !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "w9mH5iVprQdZ"
},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"import numpy as np\n",
"from tqdm.notebook import tqdm\n",
"import imageio\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import matplotlib.pyplot as plt\n",
"from skimage import img_as_ubyte\n",
"\n",
"# io utils\n",
"from pytorch3d.io import load_obj\n",
"\n",
"# datastructures\n",
"from pytorch3d.structures import Meshes\n",
"\n",
"# 3D transformations functions\n",
"from pytorch3d.transforms import Rotate, Translate\n",
"\n",
"# rendering components\n",
"from pytorch3d.renderer import (\n",
"    FoVPerspectiveCameras, look_at_view_transform, look_at_rotation,\n",
"    RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,\n",
"    SoftSilhouetteShader, HardPhongShader, PointLights, TexturesVertex,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "cpUf2UvirQdc"
},
"source": [
"## 1. Load the Obj\n",
"\n",
"We will load an obj file and create a **Meshes** object. **Meshes** is a unique data structure provided in PyTorch3D for working with **batches of meshes of different sizes**. It has several useful class methods which are used in the rendering pipeline."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "8d-oREfkrt_Z"
},
"source": [
"If you are running this notebook locally after cloning the PyTorch3D repository, the mesh will already be available. **If using Google Colab, fetch the mesh and save it at the path `data/`**:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 204
},
"colab_type": "code",
"id": "sD5KcLuJr0PL",
"outputId": "e65061fa-dbd5-4c06-b559-3592632983ee"
},
"outputs": [],
"source": [
"!mkdir -p data\n",
"!wget -P data https://dl.fbaipublicfiles.com/pytorch3d/data/teapot/teapot.obj"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "VWiPKnEIrQdd"
},
"outputs": [],
"source": [
"# Set the cuda device\n",
"if torch.cuda.is_available():\n",
"    device = torch.device(\"cuda:0\")\n",
"    torch.cuda.set_device(device)\n",
"else:\n",
"    device = torch.device(\"cpu\")\n",
"\n",
"# Load the obj and ignore the textures and materials.\n",
"verts, faces_idx, _ = load_obj(\"./data/teapot.obj\")\n",
"faces = faces_idx.verts_idx\n",
"\n",
"# Initialize each vertex to be white in color.\n",
"verts_rgb = torch.ones_like(verts)[None]  # (1, V, 3)\n",
"textures = TexturesVertex(verts_features=verts_rgb.to(device))\n",
"\n",
"# Create a Meshes object for the teapot. Here we have only one mesh in the batch.\n",
"teapot_mesh = Meshes(\n",
"    verts=[verts.to(device)],\n",
"    faces=[faces.to(device)],\n",
"    textures=textures\n",
")"
]
},
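{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional aside, the next cell is a minimal sketch (not needed for the rest of the tutorial) of how **Meshes** handles batches of meshes of different sizes. It builds a hypothetical two-mesh batch and prints the padded and packed views that the rendering pipeline relies on:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch: a Meshes batch holding two toy meshes of different sizes.\n",
"verts1 = torch.rand(3, 3)  # mesh 1: 3 vertices\n",
"faces1 = torch.tensor([[0, 1, 2]])  # mesh 1: 1 face\n",
"verts2 = torch.rand(4, 3)  # mesh 2: 4 vertices\n",
"faces2 = torch.tensor([[0, 1, 2], [0, 2, 3]])  # mesh 2: 2 faces\n",
"batch = Meshes(verts=[verts1, verts2], faces=[faces1, faces2])\n",
"print(batch.verts_padded().shape)   # (2, 4, 3): padded to the max vertex count\n",
"print(batch.verts_packed().shape)   # (7, 3): all vertices concatenated\n",
"print(batch.num_faces_per_mesh())   # tensor([1, 2])"
]
},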
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "mgtGbQktrQdh"
},
"source": [
"## 2. Optimization setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Q6PzKD_NrQdi"
},
"source": [
"### Create a renderer\n",
"\n",
"A **renderer** in PyTorch3D is composed of a **rasterizer** and a **shader**, each of which has a number of subcomponents such as a **camera** (orthographic/perspective). Here we initialize some of these components and use default values for the rest.\n",
"\n",
"For optimizing the camera position we will use a renderer which produces only a **silhouette** of the object and does not apply any **lighting** or **shading**. We will also initialize another renderer which applies full **Phong shading** and use this for visualizing the outputs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "KPlby75GrQdj"
},
"outputs": [],
"source": [
"# Initialize a perspective camera.\n",
"cameras = FoVPerspectiveCameras(device=device)\n",
"\n",
"# To blend the 100 faces we set a few parameters which control the opacity and the sharpness of\n",
"# edges. Refer to blending.py for more details.\n",
"blend_params = BlendParams(sigma=1e-4, gamma=1e-4)\n",
"\n",
"# Define the settings for rasterization and shading. Here we set the output image to be of size\n",
"# 256x256. To form the blended image we use 100 faces for each pixel. We leave bin_size and\n",
"# max_faces_per_bin at their default of None, which ensures that the faster coarse-to-fine\n",
"# rasterization method is used. Refer to rasterize_meshes.py for explanations of these parameters,\n",
"# and to docs/notes/renderer.md for an explanation of the difference between naive and\n",
"# coarse-to-fine rasterization.\n",
"raster_settings = RasterizationSettings(\n",
"    image_size=256,\n",
"    blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,\n",
"    faces_per_pixel=100,\n",
")\n",
"\n",
"# Create a silhouette mesh renderer by composing a rasterizer and a shader.\n",
"silhouette_renderer = MeshRenderer(\n",
"    rasterizer=MeshRasterizer(\n",
"        cameras=cameras,\n",
"        raster_settings=raster_settings\n",
"    ),\n",
"    shader=SoftSilhouetteShader(blend_params=blend_params)\n",
")\n",
"\n",
"\n",
"# We will also create a Phong renderer. This is simpler and only needs to render one face per pixel.\n",
"raster_settings = RasterizationSettings(\n",
"    image_size=256,\n",
"    blur_radius=0.0,\n",
"    faces_per_pixel=1,\n",
")\n",
"# We can add a point light in front of the object.\n",
"lights = PointLights(device=device, location=((2.0, 2.0, -2.0),))\n",
"phong_renderer = MeshRenderer(\n",
"    rasterizer=MeshRasterizer(\n",
"        cameras=cameras,\n",
"        raster_settings=raster_settings\n",
"    ),\n",
"    shader=HardPhongShader(device=device, cameras=cameras, lights=lights)\n",
")"
]
},
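{
"cell_type": "markdown",
"metadata": {},
"source": [
"A note on the `blur_radius` used above: the soft silhouette blending turns a pixel's distance to a face into a probability via a sigmoid with temperature `sigma`. Setting `blur_radius = log(1 / 1e-4 - 1) * sigma` solves `sigmoid(x / sigma) = 1 - 1e-4` for `x`, so faces are rasterized out to the distance at which their contribution has decayed to roughly `1e-4`. (This is our reading of the formula; see `blending.py`, referenced in the comments above, for the exact blending details.)"
]
},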
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "osOy2OIJrQdn"
},
"source": [
"### Create a reference image\n",
"\n",
"We will first position the teapot and generate an image. We use helper functions to rotate the teapot to a desired viewpoint. Then we can use the renderers to produce an image. Here we will use both renderers and visualize the silhouette and the full shaded image.\n",
"\n",
"The world coordinate system is defined as +Y up, +X left and +Z in. The teapot in world coordinates has the spout pointing to the left.\n",
"\n",
"We define a camera which is positioned on the positive z axis, and hence sees the spout to the right."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 305
},
"colab_type": "code",
"id": "EjJrW7qerQdo",
"outputId": "93545b65-269e-4719-f4a2-52cbc6c9c974"
},
"outputs": [],
"source": [
"# Select the viewpoint using spherical angles\n",
"distance = 3  # distance from camera to the object\n",
"elevation = 50.0  # angle of elevation in degrees\n",
"azimuth = 0.0  # No rotation so the camera is positioned on the +Z axis.\n",
"\n",
"# Get the position of the camera based on the spherical angles\n",
"R, T = look_at_view_transform(distance, elevation, azimuth, device=device)\n",
"\n",
"# Render the teapot providing the values of R and T.\n",
"silhouette = silhouette_renderer(meshes_world=teapot_mesh, R=R, T=T)\n",
"image_ref = phong_renderer(meshes_world=teapot_mesh, R=R, T=T)\n",
"\n",
"silhouette = silhouette.cpu().numpy()\n",
"image_ref = image_ref.cpu().numpy()\n",
"\n",
"plt.figure(figsize=(10, 10))\n",
"plt.subplot(1, 2, 1)\n",
"plt.imshow(silhouette.squeeze()[..., 3])  # only plot the alpha channel of the RGBA image\n",
"plt.grid(False)\n",
"plt.subplot(1, 2, 2)\n",
"plt.imshow(image_ref.squeeze())\n",
"plt.grid(False)"
]
},
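{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both renderers return a batch of images of shape `(N, H, W, 4)`. The `SoftSilhouetteShader` writes the blended silhouette into the alpha (4th) channel, which is why we plot `silhouette[..., 3]` above and why the loss in the model below compares alpha channels."
]
},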
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "plBJwEslrQdt"
},
"source": [
"### Set up a basic model\n",
"\n",
"Here we create a simple model class and initialize a parameter for the camera position."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "YBbP1-EDrQdu"
},
"outputs": [],
"source": [
"class Model(nn.Module):\n",
"    def __init__(self, meshes, renderer, image_ref):\n",
"        super().__init__()\n",
"        self.meshes = meshes\n",
"        self.device = meshes.device\n",
"        self.renderer = renderer\n",
"\n",
"        # Get the silhouette of the reference RGB image by finding all non-white pixel values.\n",
"        image_ref = torch.from_numpy((image_ref[..., :3].max(-1) != 1).astype(np.float32))\n",
"        self.register_buffer('image_ref', image_ref)\n",
"\n",
"        # Create an optimizable parameter for the x, y, z position of the camera.\n",
"        self.camera_position = nn.Parameter(\n",
"            torch.from_numpy(np.array([3.0, 6.9, +2.5], dtype=np.float32)).to(meshes.device))\n",
"\n",
"    def forward(self):\n",
"        # Render the image using the updated camera position. Based on the new position of the\n",
"        # camera we calculate the rotation and translation matrices.\n",
"        R = look_at_rotation(self.camera_position[None, :], device=self.device)  # (1, 3, 3)\n",
"        T = -torch.bmm(R.transpose(1, 2), self.camera_position[None, :, None])[:, :, 0]  # (1, 3)\n",
"\n",
"        image = self.renderer(meshes_world=self.meshes.clone(), R=R, T=T)\n",
"\n",
"        # Calculate the silhouette loss.\n",
"        loss = torch.sum((image[..., 3] - self.image_ref) ** 2)\n",
"        return loss, image\n"
]
},
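{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why `T = -torch.bmm(R.transpose(1, 2), ...)` in `forward`? PyTorch3D treats points as row vectors and maps world to camera coordinates as `X_cam = X_world @ R + T`. For the camera center `C` to land at the origin of the camera frame we need `C @ R + T = 0`, i.e. `T = -C @ R`, which written with column vectors is `-Rᵀ C` — exactly the `bmm` expression above."
]
},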
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "qCGLSJtfrQdy"
},
"source": [
"## 3. Initialize the model and optimizer\n",
"\n",
"Now we can create an instance of the **model** above and set up an **optimizer** for the camera position parameter."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "srZPBU7_rQdz"
},
"outputs": [],
"source": [
"# We will save images periodically and compose them into a GIF.\n",
"filename_output = \"./teapot_optimization_demo.gif\"\n",
"writer = imageio.get_writer(filename_output, mode='I', duration=0.3)\n",
"\n",
"# Initialize a model using the renderer, mesh and reference image.\n",
"model = Model(meshes=teapot_mesh, renderer=silhouette_renderer, image_ref=image_ref).to(device)\n",
"\n",
"# Create an optimizer. Here we are using Adam and we pass in the parameters of the model.\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.05)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "dvTLnrWorQd2"
},
"source": [
"### Visualize the starting position and the reference position"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 335
},
"colab_type": "code",
"id": "qyRXpP3mrQd3",
"outputId": "47ecb12a-e68c-47f5-92fc-821a7a9bd661"
},
"outputs": [],
"source": [
"plt.figure(figsize=(10, 10))\n",
"\n",
"_, image_init = model()\n",
"plt.subplot(1, 2, 1)\n",
"plt.imshow(image_init.detach().squeeze().cpu().numpy()[..., 3])\n",
"plt.grid(False)\n",
"plt.title(\"Starting position\")\n",
"\n",
"plt.subplot(1, 2, 2)\n",
"plt.imshow(model.image_ref.cpu().numpy().squeeze())\n",
"plt.grid(False)\n",
"plt.title(\"Reference silhouette\");\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "aGJu7h-lrQd5"
},
"source": [
"## 4. Run the optimization\n",
"\n",
"We run several iterations of the forward and backward pass and save outputs every 10 iterations. When this has finished, take a look at `./teapot_optimization_demo.gif` for a cool gif of the optimization process!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000,
"referenced_widgets": [
"79d7fc84b5564206ab64b2759474da04",
"02acadb61c3949fcaeab177fd184c388",
"efd9860908c64bfe9d47118be4734648",
"f8df7c6efb7d47f5be760a39b4bdbcf8",
"d8a109658c364a00ab4d298112dac6db",
"2d05db82cc99482bb3d62b6d4e5b1a98",
"c621d425e2c8426c8cd4f9136d392af1",
"3df8063f307040ebb8ff8e2f26ccf729"
]
},
"colab_type": "code",
"id": "HvnK5VI5rQd6",
"outputId": "4019c697-3fc6-4c7b-cdfe-225633cc0d60"
},
"outputs": [],
"source": [
"loop = tqdm(range(200))\n",
"for i in loop:\n",
"    optimizer.zero_grad()\n",
"    loss, _ = model()\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"\n",
"    loop.set_description('Optimizing (loss %.4f)' % loss.data)\n",
"\n",
"    if loss.item() < 200:\n",
"        break\n",
"\n",
"    # Save outputs to create a GIF.\n",
"    if i % 10 == 0:\n",
"        R = look_at_rotation(model.camera_position[None, :], device=model.device)\n",
"        T = -torch.bmm(R.transpose(1, 2), model.camera_position[None, :, None])[:, :, 0]  # (1, 3)\n",
"        image = phong_renderer(meshes_world=model.meshes.clone(), R=R, T=T)\n",
"        image = image[0, ..., :3].detach().squeeze().cpu().numpy()\n",
"        image = img_as_ubyte(image)\n",
"        writer.append_data(image)\n",
"\n",
"        plt.figure()\n",
"        plt.imshow(image[..., :3])\n",
"        plt.title(\"iter: %d, loss: %0.2f\" % (i, loss.data))\n",
"        plt.axis(\"off\")\n",
"\n",
"writer.close()"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "mWj80P_SsPTN"
},
"source": [
"## 5. Conclusion\n",
"\n",
"In this tutorial we learned how to **load** a mesh from an obj file, initialize a PyTorch3D data structure called **Meshes**, set up a **Renderer** consisting of a **Rasterizer** and a **Shader**, set up an optimization loop including a **Model** and a **loss function**, and run the optimization."
]
}
],
"metadata": {
"accelerator": "GPU",
"anp_metadata": {
"path": "fbsource/fbcode/vision/fair/pytorch3d/docs/tutorials/camera_position_optimization_with_differentiable_rendering.ipynb"
},
"bento_stylesheets": {
"bento/extensions/flow/main.css": true,
"bento/extensions/kernel_selector/main.css": true,
"bento/extensions/kernel_ui/main.css": true,
"bento/extensions/new_kernel/main.css": true,
"bento/extensions/system_usage/main.css": true,
"bento/extensions/theme/main.css": true
},
"colab": {
"name": "camera_position_optimization_with_differentiable_rendering.ipynb",
"provenance": [],
"toc_visible": true
},
"disseminate_notebook_info": {
"backup_notebook_id": "1062179640844868"
},
"kernelspec": {
"display_name": "pytorch3d (local)",
"language": "python",
"name": "pytorch3d_local"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5+"
}
},
"nbformat": 4,
"nbformat_minor": 1
}