diff --git a/pytorch3d/renderer/cameras.py b/pytorch3d/renderer/cameras.py index 4f8fdbd7..f2de3d5a 100644 --- a/pytorch3d/renderer/cameras.py +++ b/pytorch3d/renderer/cameras.py @@ -916,8 +916,8 @@ def look_at_rotation( msg = "Expected arg %s to have shape (N, 3); got %r" raise ValueError(msg % (n, t.shape)) z_axis = F.normalize(at - camera_position, eps=1e-5) - x_axis = F.normalize(torch.cross(up, z_axis), eps=1e-5) - y_axis = F.normalize(torch.cross(z_axis, x_axis), eps=1e-5) + x_axis = F.normalize(torch.cross(up, z_axis, dim=1), eps=1e-5) + y_axis = F.normalize(torch.cross(z_axis, x_axis, dim=1), eps=1e-5) R = torch.cat((x_axis[:, None, :], y_axis[:, None, :], z_axis[:, None, :]), dim=1) return R.transpose(1, 2)