Avoid to keep in memory lengths and bins for ImplicitronRayBundle

Summary:
Convert ImplicitronRayBundle to a "classic" class instead of a dataclass. This change is introduced as a way to preserve the ImplicitronRayBundle interface while allowing two outcomes:
- init lengths arguments is now a Optional[torch.Tensor] instead of torch.Tensor
- lengths is now a property which returns a `torch.Tensor`. The lengths property will either recompute lengths from bins or return the stored _lengths. `_lenghts` is None if bins is set. It saves us a bit of memory.

Reviewed By: shapovalov

Differential Revision: D46686094

fbshipit-source-id: 3c75c0947216476ebff542b6f552d311024a679b
This commit is contained in:
Emilien Garreau
2023-07-06 02:41:15 -07:00
committed by Facebook GitHub Bot
parent 3d011a9198
commit 9446d91fae
5 changed files with 103 additions and 61 deletions

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
import torch
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields
@@ -106,14 +108,13 @@ class RayPointRefiner(Configurable, torch.nn.Module):
z_vals = z_samples
# Resort by depth.
z_vals, _ = torch.sort(z_vals, dim=-1)
kwargs_ray = dict(vars(input_ray_bundle))
ray_bundle = copy.copy(input_ray_bundle)
if input_ray_bundle.bins is None:
kwargs_ray["lengths"] = z_vals
return ImplicitronRayBundle(**kwargs_ray)
kwargs_ray["bins"] = z_vals
del kwargs_ray["lengths"]
return ImplicitronRayBundle.from_bins(**kwargs_ray)
ray_bundle.lengths = z_vals
else:
ray_bundle.bins = z_vals
return ray_bundle
def apply_blurpool_on_weights(weights) -> torch.Tensor: