# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Benchmark ``symeig3x3`` against PyTorch's built-in symmetric eigensolvers."""

from functools import lru_cache
from itertools import product
from typing import Any, Callable

import torch
from common_testing import get_random_cuda_device
from fvcore.common.benchmark import benchmark
from pytorch3d.common.workaround import symeig3x3
from tests.test_symeig3x3 import TestSymEig3x3


# Single-threaded CPU execution, as in the original script.
torch.set_num_threads(1)

CUDA_DEVICE = get_random_cuda_device()

# Batch size of the example input used when tracing with torch.jit.trace.
_TRACE_BATCH_SIZE = 1024


def create_traced_func(func, device, batch_size):
    """Trace ``func`` with ``torch.jit.trace`` on a random symmetric 3x3 batch.

    Args:
        func: Callable taking a single tensor argument (the output of
            ``TestSymEig3x3.create_random_sym3x3``).
        device: Device on which the example input is created.
        batch_size: Batch size passed to ``create_random_sym3x3`` for the
            example input.

    Returns:
        The traced function returned by ``torch.jit.trace``.
    """
    example_inputs = TestSymEig3x3.create_random_sym3x3(device, batch_size)
    return torch.jit.trace(func, (example_inputs,))


@lru_cache(maxsize=None)
def _traced_sym3x3eig_cuda():
    """Return ``symeig3x3`` traced on ``CUDA_DEVICE``, tracing on first use.

    Tracing allocates its example input on ``CUDA_DEVICE``. Doing that at
    import time breaks the module on machines without CUDA, even though
    ``bm_symeig3x3`` only schedules the traced entry on CUDA devices, so it
    is deferred until the entry is actually run (the benchmark's warmup
    iterations absorb the one-off tracing cost).
    """
    return create_traced_func(
        lambda inputs: symeig3x3(inputs, eigenvectors=True),
        CUDA_DEVICE,
        _TRACE_BATCH_SIZE,
    )


# Benchmark name -> callable taking the batch produced by
# ``TestSymEig3x3.create_random_sym3x3``. Names containing "cuda" are only
# scheduled on CUDA devices; names containing "torch" are PyTorch built-ins.
FUNC_NAME_TO_FUNC = {
    "sym3x3eig": (lambda inputs: symeig3x3(inputs, eigenvectors=True)),
    "sym3x3eig_traced_cuda": (lambda inputs: _traced_sym3x3eig_cuda()(inputs)),
    "torch_symeig": (lambda inputs: torch.symeig(inputs, eigenvectors=True)),
    "torch_linalg_eigh": (lambda inputs: torch.linalg.eigh(inputs)),
    "torch_pca_lowrank": (
        lambda inputs: torch.pca_lowrank(inputs, center=False, niter=1)
    ),
    "sym3x3eig_no_vecs": (lambda inputs: symeig3x3(inputs, eigenvectors=False)),
    "torch_symeig_no_vecs": (lambda inputs: torch.symeig(inputs, eigenvectors=False)),
    "torch_linalg_eigvalsh_no_vecs": (lambda inputs: torch.linalg.eigvalsh(inputs)),
}


def _synchronize(device) -> None:
    """Wait for queued CUDA work on ``device``; no-op for non-CUDA devices."""
    dev = torch.device(device)
    if dev.type == "cuda":
        # Pass the device explicitly: a bare synchronize() only waits on the
        # *current* CUDA device, which need not be the benchmarked one.
        torch.cuda.synchronize(dev)


def _torch_symeig_supported() -> bool:
    """Return True if ``torch.symeig`` exists and can still be called.

    ``torch.symeig`` was deprecated in favour of ``torch.linalg.eigh``; newer
    PyTorch releases either lack the attribute or raise when it is called.
    """
    try:
        torch.symeig(torch.eye(3), eigenvectors=True)
    except (AttributeError, RuntimeError):
        return False
    return True


def test_symeig3x3(func_name, batch_size, device) -> Callable[[], Any]:
    """Prepare one benchmark case and return a closure that runs it once.

    Args:
        func_name: Key into ``FUNC_NAME_TO_FUNC``.
        batch_size: Batch size passed to ``create_random_sym3x3``.
        device: Device (e.g. ``"cpu"`` or ``"cuda:0"``) holding the inputs.

    Returns:
        A zero-argument callable that runs the selected function on the
        pre-built inputs and, on CUDA, waits for it to finish.
    """
    func = FUNC_NAME_TO_FUNC[func_name]
    inputs = TestSymEig3x3.create_random_sym3x3(device, batch_size)
    _synchronize(device)

    def run_symeig3x3() -> None:
        func(inputs)
        # CUDA kernels launch asynchronously; block so the timer covers the
        # computation itself rather than just the kernel launch.
        _synchronize(device)

    return run_symeig3x3


def bm_symeig3x3() -> None:
    """Run the SYMEIG3X3 benchmark over functions, batch sizes and devices."""
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append(CUDA_DEVICE)

    skip_torch_symeig = not _torch_symeig_supported()

    kwargs_list = []
    batch_sizes = [16, 128, 1024, 8192, 65536, 1048576]
    for func_name, batch_size, device in product(
        FUNC_NAME_TO_FUNC, batch_sizes, devices
    ):
        # Run CUDA-only implementations only on GPU
        if "cuda" in func_name and not device.startswith("cuda"):
            continue
        # Torch built-ins are quite slow on larger batches
        if "torch" in func_name and batch_size > 8192:
            continue
        # Avoid running CPU implementations on larger batches as well
        if device == "cpu" and batch_size > 8192:
            continue
        # torch.symeig cannot be benchmarked on PyTorch builds that removed it
        if skip_torch_symeig and func_name.startswith("torch_symeig"):
            continue
        kwargs_list.append(
            {"func_name": func_name, "batch_size": batch_size, "device": device}
        )

    benchmark(
        test_symeig3x3,
        "SYMEIG3X3",
        kwargs_list,
        warmup_iters=3,
    )


if __name__ == "__main__":
    bm_symeig3x3()