mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
Summary: # Background There is an intermittent error during training (it can occur after several minutes or after several hours). The error is connected to the `torch.det()` call in `_check_valid_rotation_matrix()`. If I remove the `torch.det()` call from `_check_valid_rotation_matrix()`, or remove the whole `_check_valid_rotation_matrix()` function, the error disappears (D29555876). # Solution Replace `torch.det()` with a manual implementation for 3x3 matrices. Reviewed By: patricklabatut Differential Revision: D29655924 fbshipit-source-id: 41bde1119274a705ab849751ece28873d2c45155
32 lines
870 B
Python
32 lines
870 B
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
import torch
|
|
|
|
|
|
def _safe_det_3x3(t: torch.Tensor):
|
|
"""
|
|
Fast determinant calculation for a batch of 3x3 matrices.
|
|
|
|
Note, result of this function might not be the same as `torch.det()`.
|
|
The differences might be in the last significant digit.
|
|
|
|
Args:
|
|
t: Tensor of shape (N, 3, 3).
|
|
|
|
Returns:
|
|
Tensor of shape (N) with determinants.
|
|
"""
|
|
|
|
det = (
|
|
t[..., 0, 0] * (t[..., 1, 1] * t[..., 2, 2] - t[..., 1, 2] * t[..., 2, 1])
|
|
- t[..., 0, 1] * (t[..., 1, 0] * t[..., 2, 2] - t[..., 2, 0] * t[..., 1, 2])
|
|
+ t[..., 0, 2] * (t[..., 1, 0] * t[..., 2, 1] - t[..., 2, 0] * t[..., 1, 1])
|
|
)
|
|
|
|
return det
|