more code blocks for readthedocs

Summary: More rst syntax fixes

Reviewed By: davidsonic

Differential Revision: D40977328

fbshipit-source-id: a3a3accbf2ba7cd9c84a0a82d0265010764a9d61
This commit is contained in:
Jeremy Reizenstein 2022-11-29 03:13:40 -08:00 committed by Facebook GitHub Bot
parent 94f321fa3d
commit a2c6af9250
14 changed files with 167 additions and 174 deletions

View File

@ -25,19 +25,16 @@ This module has the master functions for loading and saving data.
The main usage is via the IO object, and its methods The main usage is via the IO object, and its methods
`load_mesh`, `save_mesh`, `load_pointcloud` and `save_pointcloud`. `load_mesh`, `save_mesh`, `load_pointcloud` and `save_pointcloud`.
For example, to load a mesh you might do For example, to load a mesh you might do::
```
from pytorch3d.io import IO from pytorch3d.io import IO
mesh = IO().load_mesh("mymesh.obj") mesh = IO().load_mesh("mymesh.obj")
```
and to save a point cloud you might do and to save a point cloud you might do::
```
pcl = Pointclouds(...) pcl = Pointclouds(...)
IO().save_pointcloud(pcl, "output_pointcloud.obj") IO().save_pointcloud(pcl, "output_pointcloud.obj")
```
""" """

View File

@ -1067,7 +1067,7 @@ def load_ply(
is to use the IO.load_mesh and IO.load_pointcloud functions, is to use the IO.load_mesh and IO.load_pointcloud functions,
which can read more of the data. which can read more of the data.
Example .ply file format: Example .ply file format::
ply ply
format ascii 1.0 { ascii/binary, format version number } format ascii 1.0 { ascii/binary, format version number }

View File

@ -208,8 +208,8 @@ def add_pointclouds_to_volumes(
of `initial_volumes` with its `features` and `densities` updated with the of `initial_volumes` with its `features` and `densities` updated with the
result of the pointcloud addition. result of the pointcloud addition.
Example: Example::
```
# init a random point cloud # init a random point cloud
pointclouds = Pointclouds( pointclouds = Pointclouds(
points=torch.randn(4, 100, 3), features=torch.rand(4, 100, 5) points=torch.randn(4, 100, 3), features=torch.rand(4, 100, 5)
@ -229,7 +229,6 @@ def add_pointclouds_to_volumes(
initial_volumes=initial_volumes, initial_volumes=initial_volumes,
mode="trilinear", mode="trilinear",
) )
```
Args: Args:
pointclouds: Batch of 3D pointclouds represented with a `Pointclouds` pointclouds: Batch of 3D pointclouds represented with a `Pointclouds`

View File

@ -21,8 +21,8 @@ class HarmonicEmbedding(torch.nn.Module):
(i.e. vector along the last dimension) in `x` (i.e. vector along the last dimension) in `x`
into a series of harmonic features `embedding`, into a series of harmonic features `embedding`,
where for each i in range(dim) the following are present where for each i in range(dim) the following are present
in embedding[...]: in embedding[...]::
```
[ [
sin(f_1*x[..., i]), sin(f_1*x[..., i]),
sin(f_2*x[..., i]), sin(f_2*x[..., i]),
@ -34,7 +34,7 @@ class HarmonicEmbedding(torch.nn.Module):
cos(f_N * x[..., i]), cos(f_N * x[..., i]),
x[..., i], # only present if append_input is True. x[..., i], # only present if append_input is True.
] ]
```
where N corresponds to `n_harmonic_functions-1`, and f_i is a scalar where N corresponds to `n_harmonic_functions-1`, and f_i is a scalar
denoting the i-th frequency of the harmonic embedding. denoting the i-th frequency of the harmonic embedding.

View File

@ -25,20 +25,20 @@ class EmissionAbsorptionRaymarcher(torch.nn.Module):
(i.e. its density -> 1.0). (i.e. its density -> 1.0).
EA first utilizes `rays_densities` to compute the absorption function EA first utilizes `rays_densities` to compute the absorption function
along each ray as follows: along each ray as follows::
```
absorption = cumprod(1 - rays_densities, dim=-1) absorption = cumprod(1 - rays_densities, dim=-1)
```
The value of absorption at position `absorption[..., k]` specifies The value of absorption at position `absorption[..., k]` specifies
how much light has reached `k`-th point along a ray since starting how much light has reached `k`-th point along a ray since starting
its trajectory at `k=0`-th point. its trajectory at `k=0`-th point.
Each ray is then rendered into a tensor `features` of shape `(..., feature_dim)` Each ray is then rendered into a tensor `features` of shape `(..., feature_dim)`
by taking a weighed combination of per-ray features `rays_features` as follows: by taking a weighed combination of per-ray features `rays_features` as follows::
```
weights = absorption * rays_densities weights = absorption * rays_densities
features = (rays_features * weights).sum(dim=-2) features = (rays_features * weights).sum(dim=-2)
```
Where `weights` denote a function that has a strong peak around the location Where `weights` denote a function that has a strong peak around the location
of the first surface point that a given ray passes through. of the first surface point that a given ray passes through.

View File

@ -32,8 +32,8 @@ class MultinomialRaysampler(torch.nn.Module):
have uniformly-spaced z-coordinates between a predefined have uniformly-spaced z-coordinates between a predefined
minimum and maximum depth. minimum and maximum depth.
The raysampler first generates a 3D coordinate grid of the following form: The raysampler first generates a 3D coordinate grid of the following form::
```
/ min_x, min_y, max_depth -------------- / max_x, min_y, max_depth / min_x, min_y, max_depth -------------- / max_x, min_y, max_depth
/ /| / /|
/ / | ^ / / | ^
@ -48,7 +48,6 @@ class MultinomialRaysampler(torch.nn.Module):
min_x max_y / / n_pts_per_ray min_x max_y / / n_pts_per_ray
max_y ----------------------------- max_x/ min_depth v max_y ----------------------------- max_x/ min_depth v
< --- image_width --- > < --- image_width --- >
```
In order to generate ray points, `MultinomialRaysampler` takes each 3D point of In order to generate ray points, `MultinomialRaysampler` takes each 3D point of
the grid (with coordinates `[x, y, depth]`) and unprojects it the grid (with coordinates `[x, y, depth]`) and unprojects it

View File

@ -41,13 +41,13 @@ class ImplicitRenderer(torch.nn.Module):
as well as the `volumetric_function` `Callable`, which defines a field of opacity as well as the `volumetric_function` `Callable`, which defines a field of opacity
and feature vectors over the 3D domain of the scene. and feature vectors over the 3D domain of the scene.
A standard `volumetric_function` has the following signature: A standard `volumetric_function` has the following signature::
```
def volumetric_function( def volumetric_function(
ray_bundle: Union[RayBundle, HeterogeneousRayBundle], ray_bundle: Union[RayBundle, HeterogeneousRayBundle],
**kwargs, **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor]
```
With the following arguments: With the following arguments:
`ray_bundle`: A RayBundle or HeterogeneousRayBundle object `ray_bundle`: A RayBundle or HeterogeneousRayBundle object
containing the following variables: containing the following variables:
@ -79,8 +79,8 @@ class ImplicitRenderer(torch.nn.Module):
Example: Example:
A simple volumetric function of a 0-centered A simple volumetric function of a 0-centered
RGB sphere with a unit diameter is defined as follows: RGB sphere with a unit diameter is defined as follows::
```
def volumetric_function( def volumetric_function(
ray_bundle: Union[RayBundle, HeterogeneousRayBundle], ray_bundle: Union[RayBundle, HeterogeneousRayBundle],
**kwargs, **kwargs,
@ -104,7 +104,7 @@ class ImplicitRenderer(torch.nn.Module):
) * 0.5 + 0.5 ) * 0.5 + 0.5
return rays_densities, rays_features return rays_densities, rays_features
```
""" """
def __init__(self, raysampler: Callable, raymarcher: Callable) -> None: def __init__(self, raysampler: Callable, raymarcher: Callable) -> None:

View File

@ -73,13 +73,13 @@ def ray_bundle_to_ray_points(
extending each ray according to the corresponding length. extending each ray according to the corresponding length.
E.g. for 2 dimensional tensors `ray_bundle.origins`, `ray_bundle.directions` E.g. for 2 dimensional tensors `ray_bundle.origins`, `ray_bundle.directions`
and `ray_bundle.lengths`, the ray point at position `[i, j]` is: and `ray_bundle.lengths`, the ray point at position `[i, j]` is::
```
ray_bundle.points[i, j, :] = ( ray_bundle.points[i, j, :] = (
ray_bundle.origins[i, :] ray_bundle.origins[i, :]
+ ray_bundle.directions[i, :] * ray_bundle.lengths[i, j] + ray_bundle.directions[i, :] * ray_bundle.lengths[i, j]
) )
```
Note that both the directions and magnitudes of the vectors in Note that both the directions and magnitudes of the vectors in
`ray_bundle.directions` matter. `ray_bundle.directions` matter.
@ -109,13 +109,13 @@ def ray_bundle_variables_to_ray_points(
ray length: ray length:
E.g. for 2 dimensional input tensors `rays_origins`, `rays_directions` E.g. for 2 dimensional input tensors `rays_origins`, `rays_directions`
and `rays_lengths`, the ray point at position `[i, j]` is: and `rays_lengths`, the ray point at position `[i, j]` is::
```
rays_points[i, j, :] = ( rays_points[i, j, :] = (
rays_origins[i, :] rays_origins[i, :]
+ rays_directions[i, :] * rays_lengths[i, j] + rays_directions[i, :] * rays_lengths[i, j]
) )
```
Note that both the directions and magnitudes of the vectors in Note that both the directions and magnitudes of the vectors in
`rays_directions` matter. `rays_directions` matter.

View File

@ -285,26 +285,26 @@ class _DeviceContextStore:
The EGL/CUDA contexts are not meant to be created and destroyed all the time, The EGL/CUDA contexts are not meant to be created and destroyed all the time,
and having multiple on a single device can be troublesome. Intended use is entirely and having multiple on a single device can be troublesome. Intended use is entirely
transparent to the user: transparent to the user::
```
rasterizer1 = MeshRasterizerOpenGL(...some args...) rasterizer1 = MeshRasterizerOpenGL(...some args...)
mesh1 = load_mesh_on_cuda_0() mesh1 = load_mesh_on_cuda_0()
# Now rasterizer1 will request EGL/CUDA contexts from global_device_context_store # Now rasterizer1 will request EGL/CUDA contexts from
# on cuda:0, and since there aren't any, the store will create new ones. # global_device_context_store on cuda:0, and since there aren't any, the
# store will create new ones.
rasterizer1.rasterize(mesh1) rasterizer1.rasterize(mesh1)
# rasterizer2 also needs EGL & CUDA contexts. But global_context_store already has # rasterizer2 also needs EGL & CUDA contexts. But global_context_store
# them for cuda:0. Instead of creating new contexts, the store will tell rasterizer2 # already has them for cuda:0. Instead of creating new contexts, the store
# to use them. # will tell rasterizer2 to use them.
rasterizer2 = MeshRasterizerOpenGL(dcs) rasterizer2 = MeshRasterizerOpenGL(dcs)
rasterize2.rasterize(mesh1) rasterize2.rasterize(mesh1)
# When rasterizer1 needs to render on cuda:1, the store will create new contexts. # When rasterizer1 needs to render on cuda:1, the store will create new contexts.
mesh2 = load_mesh_on_cuda_1() mesh2 = load_mesh_on_cuda_1()
rasterizer1.rasterize(mesh2) rasterizer1.rasterize(mesh2)
```
""" """
def __init__(self): def __init__(self):

View File

@ -11,7 +11,6 @@ Proper Python support for pytorch requires creating a torch.autograd.function
here and a torch.nn.Module is exposed for the use in more complex models. here and a torch.nn.Module is exposed for the use in more complex models.
""" """
import logging import logging
import math
import warnings import warnings
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union

View File

@ -97,12 +97,12 @@ class Volumes:
in the world coordinates. in the world coordinates.
- They are specified with the following mapping that converts - They are specified with the following mapping that converts
points `x_local` in the local coordinates to points `x_world` points `x_local` in the local coordinates to points `x_world`
in the world coordinates: in the world coordinates::
```
x_world = ( x_world = (
x_local * (volume_size - 1) * 0.5 * voxel_size x_local * (volume_size - 1) * 0.5 * voxel_size
) - volume_translation, ) - volume_translation,
```
here `voxel_size` specifies the size of each voxel of the volume, here `voxel_size` specifies the size of each voxel of the volume,
and `volume_translation` is the 3D offset of the central voxel of and `volume_translation` is the 3D offset of the central voxel of
the volume w.r.t. the origin of the world coordinate frame. the volume w.r.t. the origin of the world coordinate frame.
@ -110,12 +110,12 @@ class Volumes:
the world coordinate units. `volume_size` is the spatial size of the world coordinate units. `volume_size` is the spatial size of
the volume in form of a 3D vector `[width, height, depth]`. the volume in form of a 3D vector `[width, height, depth]`.
- Given the above definition of `x_world`, one can derive the - Given the above definition of `x_world`, one can derive the
inverse mapping from `x_world` to `x_local` as follows: inverse mapping from `x_world` to `x_local` as follows::
```
x_local = ( x_local = (
(x_world + volume_translation) / (0.5 * voxel_size) (x_world + volume_translation) / (0.5 * voxel_size)
) / (volume_size - 1) ) / (volume_size - 1)
```
- For a trivial volume with `volume_translation==[0, 0, 0]` - For a trivial volume with `volume_translation==[0, 0, 0]`
with `voxel_size=-1`, `x_world` would range with `voxel_size=-1`, `x_world` would range
from -(volume_size-1)/2` to `+(volume_size-1)/2`. from -(volume_size-1)/2` to `+(volume_size-1)/2`.
@ -139,13 +139,13 @@ class Volumes:
to `x_local=(-1, 0, 1)`. to `x_local=(-1, 0, 1)`.
- For a "trivial" volume `v` with `voxel_size = 1.`, - For a "trivial" volume `v` with `voxel_size = 1.`,
`volume_translation=[0., 0., 0.]`, the following holds: `volume_translation=[0., 0., 0.]`, the following holds:
```
torch.nn.functional.grid_sample( torch.nn.functional.grid_sample(
v.densities(), v.densities(),
v.get_coord_grid(world_coordinates=False), v.get_coord_grid(world_coordinates=False),
align_corners=True, align_corners=True,
) == v.densities(), ) == v.densities(),
```
i.e. sampling the volume at trivial local coordinates i.e. sampling the volume at trivial local coordinates
(no scaling with `voxel_size`` or shift with `volume_translation`) (no scaling with `voxel_size`` or shift with `volume_translation`)
results in the same volume. results in the same volume.
@ -588,12 +588,12 @@ class VolumeLocator:
in the world coordinates. in the world coordinates.
- They are specified with the following mapping that converts - They are specified with the following mapping that converts
points `x_local` in the local coordinates to points `x_world` points `x_local` in the local coordinates to points `x_world`
in the world coordinates: in the world coordinates::
```
x_world = ( x_world = (
x_local * (volume_size - 1) * 0.5 * voxel_size x_local * (volume_size - 1) * 0.5 * voxel_size
) - volume_translation, ) - volume_translation,
```
here `voxel_size` specifies the size of each voxel of the volume, here `voxel_size` specifies the size of each voxel of the volume,
and `volume_translation` is the 3D offset of the central voxel of and `volume_translation` is the 3D offset of the central voxel of
the volume w.r.t. the origin of the world coordinate frame. the volume w.r.t. the origin of the world coordinate frame.
@ -601,12 +601,12 @@ class VolumeLocator:
the world coordinate units. `volume_size` is the spatial size of the world coordinate units. `volume_size` is the spatial size of
the volume in form of a 3D vector `[width, height, depth]`. the volume in form of a 3D vector `[width, height, depth]`.
- Given the above definition of `x_world`, one can derive the - Given the above definition of `x_world`, one can derive the
inverse mapping from `x_world` to `x_local` as follows: inverse mapping from `x_world` to `x_local` as follows::
```
x_local = ( x_local = (
(x_world + volume_translation) / (0.5 * voxel_size) (x_world + volume_translation) / (0.5 * voxel_size)
) / (volume_size - 1) ) / (volume_size - 1)
```
- For a trivial volume with `volume_translation==[0, 0, 0]` - For a trivial volume with `volume_translation==[0, 0, 0]`
with `voxel_size=-1`, `x_world` would range with `voxel_size=-1`, `x_world` would range
from -(volume_size-1)/2` to `+(volume_size-1)/2`. from -(volume_size-1)/2` to `+(volume_size-1)/2`.
@ -629,14 +629,14 @@ class VolumeLocator:
`DxHxW = 5x5x5`, the point `x_world = (-2, 0, 2)` gets mapped `DxHxW = 5x5x5`, the point `x_world = (-2, 0, 2)` gets mapped
to `x_local=(-1, 0, 1)`. to `x_local=(-1, 0, 1)`.
- For a "trivial" volume `v` with `voxel_size = 1.`, - For a "trivial" volume `v` with `voxel_size = 1.`,
`volume_translation=[0., 0., 0.]`, the following holds: `volume_translation=[0., 0., 0.]`, the following holds::
```
torch.nn.functional.grid_sample( torch.nn.functional.grid_sample(
v.densities(), v.densities(),
v.get_coord_grid(world_coordinates=False), v.get_coord_grid(world_coordinates=False),
align_corners=True, align_corners=True,
) == v.densities(), ) == v.densities(),
```
i.e. sampling the volume at trivial local coordinates i.e. sampling the volume at trivial local coordinates
(no scaling with `voxel_size`` or shift with `volume_translation`) (no scaling with `voxel_size`` or shift with `volume_translation`)
results in the same volume. results in the same volume.

View File

@ -22,8 +22,8 @@ def acos_linear_extrapolation(
domain of `(-1, 1)`. This allows for stable backpropagation in case `x` domain of `(-1, 1)`. This allows for stable backpropagation in case `x`
is not guaranteed to be strictly within `(-1, 1)`. is not guaranteed to be strictly within `(-1, 1)`.
More specifically: More specifically::
```
bounds=(lower_bound, upper_bound) bounds=(lower_bound, upper_bound)
if lower_bound <= x <= upper_bound: if lower_bound <= x <= upper_bound:
acos_linear_extrapolation(x) = acos(x) acos_linear_extrapolation(x) = acos(x)
@ -33,7 +33,6 @@ def acos_linear_extrapolation(
else: # x >= upper_bound else: # x >= upper_bound
acos_linear_extrapolation(x) acos_linear_extrapolation(x)
= acos(upper_bound) + dacos/dx(upper_bound) * (x - upper_bound) = acos(upper_bound) + dacos/dx(upper_bound) * (x - upper_bound)
```
Args: Args:
x: Input `Tensor`. x: Input `Tensor`.

View File

@ -256,12 +256,12 @@ class Transform3d:
The conversion from the 4x4 SE(3) matrix `transform` to the The conversion from the 4x4 SE(3) matrix `transform` to the
6D representation `log_transform = [log_translation | log_rotation]` 6D representation `log_transform = [log_translation | log_rotation]`
is done as follows: is done as follows::
```
log_transform = log(transform.get_matrix()) log_transform = log(transform.get_matrix())
log_translation = log_transform[3, :3] log_translation = log_transform[3, :3]
log_rotation = inv_hat(log_transform[:3, :3]) log_rotation = inv_hat(log_transform[:3, :3])
```
where `log` is the matrix logarithm where `log` is the matrix logarithm
and `inv_hat` is the inverse of the Hat operator [2]. and `inv_hat` is the inverse of the Hat operator [2].

View File

@ -35,10 +35,10 @@ def cameras_from_opencv_projection(
to the NDC screen convention of PyTorch3D. to the NDC screen convention of PyTorch3D.
More specifically, the OpenCV convention projects points to the OpenCV screen More specifically, the OpenCV convention projects points to the OpenCV screen
space as follows: space as follows::
```
x_screen_opencv = camera_matrix @ (R @ x_world + tvec) x_screen_opencv = camera_matrix @ (R @ x_world + tvec)
```
followed by the homogenization of `x_screen_opencv`. followed by the homogenization of `x_screen_opencv`.
Note: Note: