mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Omit _check_valid_rotation_matrix by default
Summary: According to the profiler trace D40326775, _check_valid_rotation_matrix is slow because of aten::all_close operation and _safe_det_3x3 bottlenecks. Disable the check by default unless environment variable PYTORCH3D_CHECK_ROTATION_MATRICES is set to 1. Comparison after applying the change: ``` Profiling/Function get_world_to_view (ms) Transform_points(ms) specular(ms) before 12.751 18.577 21.384 after 4.432 (34.7%) 9.248 (49.8%) 11.507 (53.8%) ``` Profiling trace: https://pxl.cl/2h687 More details in https://docs.google.com/document/d/1kfhEQfpeQToikr5OH9ZssM39CskxWoJ2p8DO5-t6eWk/edit?usp=sharing Reviewed By: kjchalup Differential Revision: D40442503 fbshipit-source-id: 954b58de47de235c9d93af441643c22868b547d0
This commit is contained in:
parent
8339cf2610
commit
46cb5aaaae
@ -5,6 +5,7 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
import os
|
||||
import warnings
|
||||
from typing import List, Optional, Union
|
||||
|
||||
@ -636,7 +637,10 @@ class Rotate(Transform3d):
|
||||
msg = "R must have shape (3, 3) or (N, 3, 3); got %s"
|
||||
raise ValueError(msg % repr(R.shape))
|
||||
R = R.to(device=device_, dtype=dtype)
|
||||
_check_valid_rotation_matrix(R, tol=orthogonal_tol)
|
||||
if os.environ.get("PYTORCH3D_CHECK_ROTATION_MATRICES", "0") == "1":
|
||||
# Note: aten::all_close in the check is computationally slow, so we
|
||||
# only run the check when PYTORCH3D_CHECK_ROTATION_MATRICES is on.
|
||||
_check_valid_rotation_matrix(R, tol=orthogonal_tol)
|
||||
N = R.shape[0]
|
||||
mat = torch.eye(4, dtype=dtype, device=device_)
|
||||
mat = mat.view(1, 4, 4).repeat(N, 1, 1)
|
||||
|
@ -4,9 +4,10 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import math
|
||||
import os
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
import torch
|
||||
from pytorch3d.transforms import random_rotations
|
||||
@ -191,7 +192,25 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(points_out, points_out_expected))
|
||||
self.assertTrue(torch.allclose(normals_out, normals_out_expected))
|
||||
|
||||
def test_rotate(self):
|
||||
@mock.patch.dict(os.environ, {"PYTORCH3D_CHECK_ROTATION_MATRICES": "1"}, clear=True)
|
||||
def test_rotate_check_rot_valid_on(self):
|
||||
R = so3_exp_map(torch.randn((1, 3)))
|
||||
t = Transform3d().rotate(R)
|
||||
points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
|
||||
1, 3, 3
|
||||
)
|
||||
normals = torch.tensor(
|
||||
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]
|
||||
).view(1, 3, 3)
|
||||
points_out = t.transform_points(points)
|
||||
normals_out = t.transform_normals(normals)
|
||||
points_out_expected = torch.bmm(points, R)
|
||||
normals_out_expected = torch.bmm(normals, R)
|
||||
self.assertTrue(torch.allclose(points_out, points_out_expected))
|
||||
self.assertTrue(torch.allclose(normals_out, normals_out_expected))
|
||||
|
||||
@mock.patch.dict(os.environ, {"PYTORCH3D_CHECK_ROTATION_MATRICES": "0"}, clear=True)
|
||||
def test_rotate_check_rot_valid_off(self):
|
||||
R = so3_exp_map(torch.randn((1, 3)))
|
||||
t = Transform3d().rotate(R)
|
||||
points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
|
||||
|
Loading…
x
Reference in New Issue
Block a user