mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Replace torch.det()
with manual implementation for 3x3 matrix
Summary: # Background There is an unstable error during training (it can happen after several minutes or after several hours). The error is connected to `torch.det()` function in `_check_valid_rotation_matrix()`. if I remove the function `torch.det()` in `_check_valid_rotation_matrix()` or remove the whole functions `_check_valid_rotation_matrix()` the error is disappeared (D29555876). # Solution Replace `torch.det()` with manual implementation for 3x3 matrix. Reviewed By: patricklabatut Differential Revision: D29655924 fbshipit-source-id: 41bde1119274a705ab849751ece28873d2c45155
This commit is contained in:
parent
2f668ecefe
commit
bcee361d04
31
pytorch3d/common/workaround.py
Normal file
31
pytorch3d/common/workaround.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_det_3x3(t: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Fast determinant calculation for a batch of 3x3 matrices.
|
||||||
|
|
||||||
|
Note, result of this function might not be the same as `torch.det()`.
|
||||||
|
The differences might be in the last significant digit.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
t: Tensor of shape (N, 3, 3).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor of shape (N) with determinants.
|
||||||
|
"""
|
||||||
|
|
||||||
|
det = (
|
||||||
|
t[..., 0, 0] * (t[..., 1, 1] * t[..., 2, 2] - t[..., 1, 2] * t[..., 2, 1])
|
||||||
|
- t[..., 0, 1] * (t[..., 1, 0] * t[..., 2, 2] - t[..., 2, 0] * t[..., 1, 2])
|
||||||
|
+ t[..., 0, 2] * (t[..., 1, 0] * t[..., 2, 1] - t[..., 2, 0] * t[..., 1, 1])
|
||||||
|
)
|
||||||
|
|
||||||
|
return det
|
@ -11,6 +11,7 @@ from typing import List, Optional, Union
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..common.types import Device, get_device, make_device
|
from ..common.types import Device, get_device, make_device
|
||||||
|
from ..common.workaround import _safe_det_3x3
|
||||||
from .rotation_conversions import _axis_angle_rotation
|
from .rotation_conversions import _axis_angle_rotation
|
||||||
|
|
||||||
|
|
||||||
@ -774,7 +775,7 @@ def _check_valid_rotation_matrix(R, tol: float = 1e-7):
|
|||||||
eye = torch.eye(3, dtype=R.dtype, device=R.device)
|
eye = torch.eye(3, dtype=R.dtype, device=R.device)
|
||||||
eye = eye.view(1, 3, 3).expand(N, -1, -1)
|
eye = eye.view(1, 3, 3).expand(N, -1, -1)
|
||||||
orthogonal = torch.allclose(R.bmm(R.transpose(1, 2)), eye, atol=tol)
|
orthogonal = torch.allclose(R.bmm(R.transpose(1, 2)), eye, atol=tol)
|
||||||
det_R = torch.det(R)
|
det_R = _safe_det_3x3(R)
|
||||||
no_distortion = torch.allclose(det_R, torch.ones_like(det_R))
|
no_distortion = torch.allclose(det_R, torch.ones_like(det_R))
|
||||||
if not (orthogonal and no_distortion):
|
if not (orthogonal and no_distortion):
|
||||||
msg = "R is not a valid rotation matrix"
|
msg = "R is not a valid rotation matrix"
|
||||||
|
56
tests/test_common_workaround.py
Normal file
56
tests/test_common_workaround.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from common_testing import TestCaseMixin
|
||||||
|
from pytorch3d.common.workaround import _safe_det_3x3
|
||||||
|
|
||||||
|
|
||||||
|
class TestSafeDet3x3(TestCaseMixin, unittest.TestCase):
|
||||||
|
def setUp(self) -> None:
|
||||||
|
super().setUp()
|
||||||
|
torch.manual_seed(42)
|
||||||
|
np.random.seed(42)
|
||||||
|
|
||||||
|
def _test_det_3x3(self, batch_size, device):
|
||||||
|
t = torch.rand((batch_size, 3, 3), dtype=torch.float32, device=device)
|
||||||
|
actual_det = _safe_det_3x3(t)
|
||||||
|
expected_det = t.det()
|
||||||
|
self.assertClose(actual_det, expected_det, atol=1e-7)
|
||||||
|
|
||||||
|
def test_empty_batch(self):
|
||||||
|
self._test_det_3x3(0, torch.device("cpu"))
|
||||||
|
self._test_det_3x3(0, torch.device("cuda:0"))
|
||||||
|
|
||||||
|
def test_manual(self):
|
||||||
|
t = torch.Tensor(
|
||||||
|
[
|
||||||
|
[[1, 0, 0], [0, 1, 0], [0, 0, 1]],
|
||||||
|
[[2, -5, 3], [0, 7, -2], [-1, 4, 1]],
|
||||||
|
[[6, 1, 1], [4, -2, 5], [2, 8, 7]],
|
||||||
|
]
|
||||||
|
).to(dtype=torch.float32)
|
||||||
|
expected_det = torch.Tensor([1, 41, -306]).to(dtype=torch.float32)
|
||||||
|
self.assertClose(_safe_det_3x3(t), expected_det)
|
||||||
|
|
||||||
|
device_cuda = torch.device("cuda:0")
|
||||||
|
self.assertClose(
|
||||||
|
_safe_det_3x3(t.to(device=device_cuda)), expected_det.to(device=device_cuda)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_regression(self):
|
||||||
|
tries = 32
|
||||||
|
device_cpu = torch.device("cpu")
|
||||||
|
device_cuda = torch.device("cuda:0")
|
||||||
|
batch_sizes = np.random.randint(low=1, high=128, size=tries)
|
||||||
|
|
||||||
|
for batch_size in batch_sizes:
|
||||||
|
self._test_det_3x3(batch_size, device_cpu)
|
||||||
|
self._test_det_3x3(batch_size, device_cuda)
|
Loading…
x
Reference in New Issue
Block a user