Avoid torch.square

Summary: Fix axis_angle conversions where I used torch.square which doesn't work with pytorch 1.4

Reviewed By: nikhilaravi

Differential Revision: D24451546

fbshipit-source-id: ba26f7dad5fa991f0a8f7d3d09ee7151163aecf4
This commit is contained in:
Jeremy Reizenstein 2020-10-22 02:21:02 -07:00 committed by Facebook GitHub Bot
parent c93c4dd7b2
commit 7e986cfba8

View File

@ -469,7 +469,7 @@ def axis_angle_to_quaternion(axis_angle):
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
# so sin(x/2)/x is about 1/2 - (x*x)/48
sin_half_angles_over_angles[small_angles] = (
0.5 - torch.square(angles[small_angles]) / 48
0.5 - (angles[small_angles] * angles[small_angles]) / 48
)
quaternions = torch.cat(
[torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
@ -503,7 +503,7 @@ def quaternion_to_axis_angle(quaternions):
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
# so sin(x/2)/x is about 1/2 - (x*x)/48
sin_half_angles_over_angles[small_angles] = (
0.5 - torch.square(angles[small_angles]) / 48
0.5 - (angles[small_angles] * angles[small_angles]) / 48
)
return quaternions[..., 1:] / sin_half_angles_over_angles