diff --git a/pytorch3d/transforms/transform3d.py b/pytorch3d/transforms/transform3d.py index 455edfeb..21ec6e9e 100644 --- a/pytorch3d/transforms/transform3d.py +++ b/pytorch3d/transforms/transform3d.py @@ -21,7 +21,7 @@ class Transform3d: points = torch.randn(N, P, 3) normals = torch.randn(N, P, 3) points_transformed = t.transform_points(points) # => (N, P, 3) - normals_transformed = t.transform_points(normals) # => (N, P, 3) + normals_transformed = t.transform_normals(normals) # => (N, P, 3) BROADCASTING