avoid math.prod for python 3.7

Summary: This makes the new volumes tutorial work on google colab.

Reviewed By: kjchalup

Differential Revision: D38501906

fbshipit-source-id: a606a357e929dae903dc4d9067bd1519f05b1458
This commit is contained in:
Jeremy Reizenstein 2022-08-09 20:48:51 -07:00 committed by Facebook GitHub Bot
parent c49ebad249
commit 791a068183
4 changed files with 21 additions and 11 deletions

View File

@ -10,7 +10,7 @@ import torch
""" """
Some functions which depend on PyTorch versions. Some functions which depend on PyTorch or Python versions.
""" """
@ -79,3 +79,12 @@ def meshgrid_ij(
# pyre-fixme[6]: For 1st param expected `Union[List[Tensor], Tensor]` but got # pyre-fixme[6]: For 1st param expected `Union[List[Tensor], Tensor]` but got
# `Union[Sequence[Tensor], Tensor]`. # `Union[Sequence[Tensor], Tensor]`.
return torch.meshgrid(*A) return torch.meshgrid(*A)
def prod(iterable, *, start=1):
"""
Like math.prod in Python 3.8 and later.
"""
for i in iterable:
start *= i
return start

View File

@ -17,6 +17,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
import tqdm import tqdm
from omegaconf import DictConfig from omegaconf import DictConfig
from pytorch3d.common.compat import prod
from pytorch3d.implicitron.models.metrics import ( from pytorch3d.implicitron.models.metrics import (
RegularizationMetricsBase, RegularizationMetricsBase,
ViewMetricsBase, ViewMetricsBase,
@ -919,7 +920,7 @@ def _chunk_generator(
f"by n_pts_per_ray ({n_pts_per_ray})" f"by n_pts_per_ray ({n_pts_per_ray})"
) )
n_rays = math.prod(spatial_dim) n_rays = prod(spatial_dim)
# special handling for raytracing-based methods # special handling for raytracing-based methods
n_chunks = -(-n_rays * max(n_pts_per_ray, 1) // chunk_size) n_chunks = -(-n_rays * max(n_pts_per_ray, 1) // chunk_size)
chunk_size_in_rays = -(-n_rays // n_chunks) chunk_size_in_rays = -(-n_rays // n_chunks)
@ -935,9 +936,9 @@ def _chunk_generator(
directions=ray_bundle.directions.reshape(batch_size, -1, 3)[ directions=ray_bundle.directions.reshape(batch_size, -1, 3)[
:, start_idx:end_idx :, start_idx:end_idx
], ],
lengths=ray_bundle.lengths.reshape( lengths=ray_bundle.lengths.reshape(batch_size, n_rays, n_pts_per_ray)[
batch_size, math.prod(spatial_dim), n_pts_per_ray :, start_idx:end_idx
)[:, start_idx:end_idx], ],
xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx], xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx],
) )
extra_args = kwargs.copy() extra_args = kwargs.copy()

View File

@ -4,10 +4,10 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import math
from typing import Callable, Optional from typing import Callable, Optional
import torch import torch
from pytorch3d.common.compat import prod
from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.cameras import CamerasBase
@ -52,7 +52,7 @@ def create_embeddings_for_implicit_function(
embeds = torch.empty( embeds = torch.empty(
bs, bs,
1, 1,
math.prod(spatial_size), prod(spatial_size),
pts_per_ray, pts_per_ray,
0, 0,
dtype=xyz_world.dtype, dtype=xyz_world.dtype,
@ -62,7 +62,7 @@ def create_embeddings_for_implicit_function(
embeds = xyz_embedding_function(ray_points_for_embed).reshape( embeds = xyz_embedding_function(ray_points_for_embed).reshape(
bs, bs,
1, 1,
math.prod(spatial_size), prod(spatial_size),
pts_per_ray, pts_per_ray,
-1, -1,
) # flatten spatial, add n_src dim ) # flatten spatial, add n_src dim
@ -73,7 +73,7 @@ def create_embeddings_for_implicit_function(
embed_shape = ( embed_shape = (
bs, bs,
embeds_viewpooled.shape[1], embeds_viewpooled.shape[1],
math.prod(spatial_size), prod(spatial_size),
pts_per_ray, pts_per_ray,
-1, -1,
) )

View File

@ -3,11 +3,11 @@
# implicit_differentiable_renderer.py # implicit_differentiable_renderer.py
# Copyright (c) 2020 Lior Yariv # Copyright (c) 2020 Lior Yariv
import functools import functools
import math
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from omegaconf import DictConfig from omegaconf import DictConfig
from pytorch3d.common.compat import prod
from pytorch3d.implicitron.tools.config import ( from pytorch3d.implicitron.tools.config import (
get_default_args_field, get_default_args_field,
registry, registry,
@ -105,7 +105,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
# object_mask: silhouette of the object # object_mask: silhouette of the object
batch_size, *spatial_size, _ = ray_bundle.lengths.shape batch_size, *spatial_size, _ = ray_bundle.lengths.shape
num_pixels = math.prod(spatial_size) num_pixels = prod(spatial_size)
cam_loc = ray_bundle.origins.reshape(batch_size, -1, 3) cam_loc = ray_bundle.origins.reshape(batch_size, -1, 3)
ray_dirs = ray_bundle.directions.reshape(batch_size, -1, 3) ray_dirs = ray_bundle.directions.reshape(batch_size, -1, 3)