mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Add benchmark for diffuse and specular lighting
Summary: I was trying to speed up the lighting computations, but my ideas didn't work. Even if that didn't work, we can at least commit the benchmarking script I wrote for diffuse and specular shading. Reviewed By: nikhilaravi Differential Revision: D21580171 fbshipit-source-id: 8b60c0284e91ecbe258b6aae839bd5c2bbe788aa
This commit is contained in:
parent
3fef506895
commit
d8987c6f48
@ -61,6 +61,7 @@ def diffuse(normals, color, direction) -> torch.Tensor:
|
||||
color = color.view(expand_dims)
|
||||
|
||||
# Renormalize the normals in case they have been interpolated.
|
||||
# We tried to replace the following with F.cosine_similarity, but it wasn't faster.
|
||||
normals = F.normalize(normals, p=2, dim=-1, eps=1e-6)
|
||||
direction = F.normalize(direction, p=2, dim=-1, eps=1e-6)
|
||||
angle = F.relu(torch.sum(normals * direction, dim=-1))
|
||||
@ -132,6 +133,8 @@ def specular(
|
||||
shininess = shininess.view(expand_dims)
|
||||
|
||||
# Renormalize the normals in case they have been interpolated.
|
||||
# We tried a version that uses F.cosine_similarity instead of renormalizing,
|
||||
# but it was slower.
|
||||
normals = F.normalize(normals, p=2, dim=-1, eps=1e-6)
|
||||
direction = F.normalize(direction, p=2, dim=-1, eps=1e-6)
|
||||
cos_angle = torch.sum(normals * direction, dim=-1)
|
||||
|
47
tests/bm_lighting.py
Normal file
47
tests/bm_lighting.py
Normal file
@ -0,0 +1,47 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from itertools import product
|
||||
|
||||
import torch
|
||||
from fvcore.common.benchmark import benchmark
|
||||
from pytorch3d.renderer.lighting import diffuse, specular
|
||||
|
||||
|
||||
def _bm_diffuse_cuda_with_init(N, S, K):
|
||||
device = torch.device("cuda")
|
||||
normals = torch.randn(N, S, S, K, 3, device=device)
|
||||
color = torch.randn(1, 3, device=device)
|
||||
direction = torch.randn(N, S, S, K, 3, device=device)
|
||||
args = (normals, color, direction)
|
||||
torch.cuda.synchronize()
|
||||
return lambda: diffuse(*args)
|
||||
|
||||
|
||||
def _bm_specular_cuda_with_init(N, S, K):
|
||||
device = torch.device("cuda")
|
||||
points = torch.randn(N, S, S, K, 3, device=device)
|
||||
normals = torch.randn(N, S, S, K, 3, device=device)
|
||||
direction = torch.randn(N, S, S, K, 3, device=device)
|
||||
color = torch.randn(1, 3, device=device)
|
||||
camera_position = torch.randn(N, 3, device=device)
|
||||
shininess = torch.randn(N, device=device)
|
||||
args = (points, normals, direction, color, camera_position, shininess)
|
||||
torch.cuda.synchronize()
|
||||
return lambda: specular(*args)
|
||||
|
||||
|
||||
def bm_lighting() -> None:
|
||||
# For now only benchmark lighting on GPU
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
|
||||
kwargs_list = []
|
||||
Ns = [1, 8]
|
||||
Ss = [128, 256]
|
||||
Ks = [1, 10, 80]
|
||||
test_cases = product(Ns, Ss, Ks)
|
||||
for case in test_cases:
|
||||
N, S, K = case
|
||||
kwargs_list.append({"N": N, "S": S, "K": K})
|
||||
benchmark(_bm_diffuse_cuda_with_init, "DIFFUSE", kwargs_list, warmup_iters=3)
|
||||
benchmark(_bm_specular_cuda_with_init, "SPECULAR", kwargs_list, warmup_iters=3)
|
Loading…
x
Reference in New Issue
Block a user