mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Omit _check_valid_rotation_matrix by default
Summary: According to the profiler trace D40326775, _check_valid_rotation_matrix is slow because of aten::all_close operation and _safe_det_3x3 bottlenecks. Disable the check by default unless environment variable PYTORCH3D_CHECK_ROTATION_MATRICES is set to 1. Comparison after applying the change: ``` Profiling/Function get_world_to_view (ms) Transform_points(ms) specular(ms) before 12.751 18.577 21.384 after 4.432 (34.7%) 9.248 (49.8%) 11.507 (53.8%) ``` Profiling trace: https://pxl.cl/2h687 More details in https://docs.google.com/document/d/1kfhEQfpeQToikr5OH9ZssM39CskxWoJ2p8DO5-t6eWk/edit?usp=sharing Reviewed By: kjchalup Differential Revision: D40442503 fbshipit-source-id: 954b58de47de235c9d93af441643c22868b547d0
This commit is contained in:
		
							parent
							
								
									8339cf2610
								
							
						
					
					
						commit
						46cb5aaaae
					
				@ -5,6 +5,7 @@
 | 
			
		||||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
import math
 | 
			
		||||
import os
 | 
			
		||||
import warnings
 | 
			
		||||
from typing import List, Optional, Union
 | 
			
		||||
 | 
			
		||||
@ -636,7 +637,10 @@ class Rotate(Transform3d):
 | 
			
		||||
            msg = "R must have shape (3, 3) or (N, 3, 3); got %s"
 | 
			
		||||
            raise ValueError(msg % repr(R.shape))
 | 
			
		||||
        R = R.to(device=device_, dtype=dtype)
 | 
			
		||||
        _check_valid_rotation_matrix(R, tol=orthogonal_tol)
 | 
			
		||||
        if os.environ.get("PYTORCH3D_CHECK_ROTATION_MATRICES", "0") == "1":
 | 
			
		||||
            # Note: aten::all_close in the check is computationally slow, so we
 | 
			
		||||
            # only run the check when PYTORCH3D_CHECK_ROTATION_MATRICES is on.
 | 
			
		||||
            _check_valid_rotation_matrix(R, tol=orthogonal_tol)
 | 
			
		||||
        N = R.shape[0]
 | 
			
		||||
        mat = torch.eye(4, dtype=dtype, device=device_)
 | 
			
		||||
        mat = mat.view(1, 4, 4).repeat(N, 1, 1)
 | 
			
		||||
 | 
			
		||||
@ -4,9 +4,10 @@
 | 
			
		||||
# This source code is licensed under the BSD-style license found in the
 | 
			
		||||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
import math
 | 
			
		||||
import os
 | 
			
		||||
import unittest
 | 
			
		||||
from unittest import mock
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from pytorch3d.transforms import random_rotations
 | 
			
		||||
@ -191,7 +192,25 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        self.assertTrue(torch.allclose(points_out, points_out_expected))
 | 
			
		||||
        self.assertTrue(torch.allclose(normals_out, normals_out_expected))
 | 
			
		||||
 | 
			
		||||
    def test_rotate(self):
 | 
			
		||||
    @mock.patch.dict(os.environ, {"PYTORCH3D_CHECK_ROTATION_MATRICES": "1"}, clear=True)
 | 
			
		||||
    def test_rotate_check_rot_valid_on(self):
 | 
			
		||||
        R = so3_exp_map(torch.randn((1, 3)))
 | 
			
		||||
        t = Transform3d().rotate(R)
 | 
			
		||||
        points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
 | 
			
		||||
            1, 3, 3
 | 
			
		||||
        )
 | 
			
		||||
        normals = torch.tensor(
 | 
			
		||||
            [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]
 | 
			
		||||
        ).view(1, 3, 3)
 | 
			
		||||
        points_out = t.transform_points(points)
 | 
			
		||||
        normals_out = t.transform_normals(normals)
 | 
			
		||||
        points_out_expected = torch.bmm(points, R)
 | 
			
		||||
        normals_out_expected = torch.bmm(normals, R)
 | 
			
		||||
        self.assertTrue(torch.allclose(points_out, points_out_expected))
 | 
			
		||||
        self.assertTrue(torch.allclose(normals_out, normals_out_expected))
 | 
			
		||||
 | 
			
		||||
    @mock.patch.dict(os.environ, {"PYTORCH3D_CHECK_ROTATION_MATRICES": "0"}, clear=True)
 | 
			
		||||
    def test_rotate_check_rot_valid_off(self):
 | 
			
		||||
        R = so3_exp_map(torch.randn((1, 3)))
 | 
			
		||||
        t = Transform3d().rotate(R)
 | 
			
		||||
        points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user