mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-21 06:40:35 +08:00
Address black + isort fbsource linter warnings
Summary: Address black + isort fbsource linter warnings from D20558374 (previous diff) Reviewed By: nikhilaravi Differential Revision: D20558373 fbshipit-source-id: d3607de4a01fb24c0d5269634563a7914bddf1c8
This commit is contained in:
committed by
Facebook GitHub Bot
parent
eb512ffde3
commit
d57daa6f85
@@ -1,10 +1,10 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
import numpy as np
|
||||
import unittest
|
||||
import torch
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch3d.transforms.so3 import (
|
||||
hat,
|
||||
so3_exponential_map,
|
||||
@@ -26,9 +26,7 @@ class TestSO3(unittest.TestCase):
|
||||
randomly generated logarithms of rotation matrices.
|
||||
"""
|
||||
device = torch.device("cuda:0")
|
||||
log_rot = torch.randn(
|
||||
(batch_size, 3), dtype=torch.float32, device=device
|
||||
)
|
||||
log_rot = torch.randn((batch_size, 3), dtype=torch.float32, device=device)
|
||||
return log_rot
|
||||
|
||||
@staticmethod
|
||||
@@ -85,16 +83,12 @@ class TestSO3(unittest.TestCase):
|
||||
log_rot = torch.randn(size=[5, 4], device=device)
|
||||
with self.assertRaises(ValueError) as err:
|
||||
so3_exponential_map(log_rot)
|
||||
self.assertTrue(
|
||||
"Input tensor shape has to be Nx3." in str(err.exception)
|
||||
)
|
||||
self.assertTrue("Input tensor shape has to be Nx3." in str(err.exception))
|
||||
|
||||
rot = torch.randn(size=[5, 3, 5], device=device)
|
||||
with self.assertRaises(ValueError) as err:
|
||||
so3_log_map(rot)
|
||||
self.assertTrue(
|
||||
"Input has to be a batch of 3x3 Tensors." in str(err.exception)
|
||||
)
|
||||
self.assertTrue("Input has to be a batch of 3x3 Tensors." in str(err.exception))
|
||||
|
||||
# trace of rot definitely bigger than 3 or smaller than -1
|
||||
rot = torch.cat(
|
||||
|
||||
Reference in New Issue
Block a user