# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import itertools

from fvcore.common.benchmark import benchmark
from pytorch3d.renderer import (
    FoVOrthographicCameras,
    FoVPerspectiveCameras,
    GridRaysampler,
    MonteCarloRaysampler,
    NDCGridRaysampler,
    OrthographicCameras,
    PerspectiveCameras,
)
from test_raysampling import TestRaysampling


def bm_raysampling() -> None:
    """Time ``TestRaysampling.raysampler`` across a grid of configurations.

    Each combination of raysampler class, camera class, batch size, number of
    points per ray, and image width/height becomes one keyword-argument dict.
    The dicts are handed to fvcore's ``benchmark`` under the name
    ``"RAYSAMPLER"``, with one warmup iteration per case.
    """
    param_grid = {
        "raysampler_type": [GridRaysampler, NDCGridRaysampler, MonteCarloRaysampler],
        "camera_type": [
            PerspectiveCameras,
            OrthographicCameras,
            FoVPerspectiveCameras,
            FoVOrthographicCameras,
        ],
        "batch_size": [1, 10],
        "n_pts_per_ray": [10, 1000, 10000],
        "image_width": [10, 300],
        "image_height": [10, 300],
    }

    # Take the key order once so every combination zips against the same
    # names. The full cross product yields 3*4*2*3*2*2 = 288 cases.
    param_names = tuple(param_grid)
    combinations = itertools.product(*(param_grid[name] for name in param_names))
    kwargs_list = [dict(zip(param_names, combo)) for combo in combinations]

    benchmark(TestRaysampling.raysampler, "RAYSAMPLER", kwargs_list, warmup_iters=1)


if __name__ == "__main__":
    bm_raysampling()