diff --git a/pytorch3d/renderer/implicit/raysampling.py b/pytorch3d/renderer/implicit/raysampling.py index 36351d0c..4e978b67 100644 --- a/pytorch3d/renderer/implicit/raysampling.py +++ b/pytorch3d/renderer/implicit/raysampling.py @@ -670,12 +670,13 @@ def _xy_to_ray_bundle( # directions are the differences between the two planes of points rays_directions_world = rays_plane_2_world - rays_plane_1_world - if unit_directions: - rays_directions_world = F.normalize(rays_directions_world, dim=-1) # origins are given by subtracting the ray directions from the first plane rays_origins_world = rays_plane_1_world - rays_directions_world + if unit_directions: + rays_directions_world = F.normalize(rays_directions_world, dim=-1) + return RayBundle( rays_origins_world.view(batch_size, *spatial_size, 3), rays_directions_world.view(batch_size, *spatial_size, 3),