From bcee361d048f14b3d1fbfa2c3e498d64c06a7612 Mon Sep 17 00:00:00 2001 From: Alexey Sidnev Date: Mon, 19 Jul 2021 05:01:56 -0700 Subject: [PATCH] Replace `torch.det()` with manual implementation for 3x3 matrix Summary: # Background There is an unstable error during training (it can happen after several minutes or after several hours). The error is connected to `torch.det()` function in `_check_valid_rotation_matrix()`. if I remove the function `torch.det()` in `_check_valid_rotation_matrix()` or remove the whole functions `_check_valid_rotation_matrix()` the error is disappeared (D29555876). # Solution Replace `torch.det()` with manual implementation for 3x3 matrix. Reviewed By: patricklabatut Differential Revision: D29655924 fbshipit-source-id: 41bde1119274a705ab849751ece28873d2c45155 --- pytorch3d/common/workaround.py | 31 ++++++++++++++++ pytorch3d/transforms/transform3d.py | 3 +- tests/test_common_workaround.py | 56 +++++++++++++++++++++++++++++ 3 files changed, 89 insertions(+), 1 deletion(-) create mode 100644 pytorch3d/common/workaround.py create mode 100644 tests/test_common_workaround.py diff --git a/pytorch3d/common/workaround.py b/pytorch3d/common/workaround.py new file mode 100644 index 00000000..dbc5f156 --- /dev/null +++ b/pytorch3d/common/workaround.py @@ -0,0 +1,31 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch + + +def _safe_det_3x3(t: torch.Tensor): + """ + Fast determinant calculation for a batch of 3x3 matrices. + + Note, result of this function might not be the same as `torch.det()`. + The differences might be in the last significant digit. + + Args: + t: Tensor of shape (N, 3, 3). + + Returns: + Tensor of shape (N) with determinants. + """ + + det = ( + t[..., 0, 0] * (t[..., 1, 1] * t[..., 2, 2] - t[..., 1, 2] * t[..., 2, 1]) + - t[..., 0, 1] * (t[..., 1, 0] * t[..., 2, 2] - t[..., 2, 0] * t[..., 1, 2]) + + t[..., 0, 2] * (t[..., 1, 0] * t[..., 2, 1] - t[..., 2, 0] * t[..., 1, 1]) + ) + + return det diff --git a/pytorch3d/transforms/transform3d.py b/pytorch3d/transforms/transform3d.py index 11543044..cab6a5ff 100644 --- a/pytorch3d/transforms/transform3d.py +++ b/pytorch3d/transforms/transform3d.py @@ -11,6 +11,7 @@ from typing import List, Optional, Union import torch from ..common.types import Device, get_device, make_device +from ..common.workaround import _safe_det_3x3 from .rotation_conversions import _axis_angle_rotation @@ -774,7 +775,7 @@ def _check_valid_rotation_matrix(R, tol: float = 1e-7): eye = torch.eye(3, dtype=R.dtype, device=R.device) eye = eye.view(1, 3, 3).expand(N, -1, -1) orthogonal = torch.allclose(R.bmm(R.transpose(1, 2)), eye, atol=tol) - det_R = torch.det(R) + det_R = _safe_det_3x3(R) no_distortion = torch.allclose(det_R, torch.ones_like(det_R)) if not (orthogonal and no_distortion): msg = "R is not a valid rotation matrix" diff --git a/tests/test_common_workaround.py b/tests/test_common_workaround.py new file mode 100644 index 00000000..a40ba171 --- /dev/null +++ b/tests/test_common_workaround.py @@ -0,0 +1,56 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import unittest + +import numpy as np +import torch +from common_testing import TestCaseMixin +from pytorch3d.common.workaround import _safe_det_3x3 + + +class TestSafeDet3x3(TestCaseMixin, unittest.TestCase): + def setUp(self) -> None: + super().setUp() + torch.manual_seed(42) + np.random.seed(42) + + def _test_det_3x3(self, batch_size, device): + t = torch.rand((batch_size, 3, 3), dtype=torch.float32, device=device) + actual_det = _safe_det_3x3(t) + expected_det = t.det() + self.assertClose(actual_det, expected_det, atol=1e-7) + + def test_empty_batch(self): + self._test_det_3x3(0, torch.device("cpu")) + self._test_det_3x3(0, torch.device("cuda:0")) + + def test_manual(self): + t = torch.Tensor( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[2, -5, 3], [0, 7, -2], [-1, 4, 1]], + [[6, 1, 1], [4, -2, 5], [2, 8, 7]], + ] + ).to(dtype=torch.float32) + expected_det = torch.Tensor([1, 41, -306]).to(dtype=torch.float32) + self.assertClose(_safe_det_3x3(t), expected_det) + + device_cuda = torch.device("cuda:0") + self.assertClose( + _safe_det_3x3(t.to(device=device_cuda)), expected_det.to(device=device_cuda) + ) + + def test_regression(self): + tries = 32 + device_cpu = torch.device("cpu") + device_cuda = torch.device("cuda:0") + batch_sizes = np.random.randint(low=1, high=128, size=tries) + + for batch_size in batch_sizes: + self._test_det_3x3(batch_size, device_cpu) + self._test_det_3x3(batch_size, device_cuda)