diff --git a/pytorch3d/renderer/cameras.py b/pytorch3d/renderer/cameras.py index 24ee3981..f130dbc0 100644 --- a/pytorch3d/renderer/cameras.py +++ b/pytorch3d/renderer/cameras.py @@ -6,7 +6,7 @@ import math import warnings -from typing import List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -1851,12 +1851,15 @@ def get_screen_to_ndc_transform( return transform -def try_get_projection_transform(cameras: CamerasBase, kwargs) -> Optional[Transform3d]: +def try_get_projection_transform( + cameras: CamerasBase, cameras_kwargs: Dict[str, Any] +) -> Optional[Transform3d]: """ - Try block to get projection transform. + Try block to get projection transform from cameras and cameras_kwargs. Args: - cameras instance, can be linear cameras or nonliear cameras + cameras: cameras instance, can be linear cameras or nonliear cameras + cameras_kwargs: camera parameters to be passed to cameras Returns: If the camera implemented projection_transform, return the @@ -1865,7 +1868,7 @@ def try_get_projection_transform(cameras: CamerasBase, kwargs) -> Optional[Trans transform = None try: - transform = cameras.get_projection_transform(**kwargs) + transform = cameras.get_projection_transform(**cameras_kwargs) except NotImplementedError: pass return transform