mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-02 10:15:59 +08:00
Avoid to keep in memory lengths and bins for ImplicitronRayBundle
Summary: Convert ImplicitronRayBundle to a "classic" class instead of a dataclass. This change is introduced as a way to preserve the ImplicitronRayBundle interface while allowing two outcomes: - init lengths arguments is now a Optional[torch.Tensor] instead of torch.Tensor - lengths is now a property which returns a `torch.Tensor`. The lengths property will either recompute lengths from bins or return the stored _lengths. `_lenghts` is None if bins is set. It saves us a bit of memory. Reviewed By: shapovalov Differential Revision: D46686094 fbshipit-source-id: 3c75c0947216476ebff542b6f552d311024a679b
This commit is contained in:
committed by
Facebook GitHub Bot
parent
3d011a9198
commit
9446d91fae
@@ -4,6 +4,8 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import copy
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
||||
from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields
|
||||
@@ -106,14 +108,13 @@ class RayPointRefiner(Configurable, torch.nn.Module):
|
||||
z_vals = z_samples
|
||||
# Resort by depth.
|
||||
z_vals, _ = torch.sort(z_vals, dim=-1)
|
||||
|
||||
kwargs_ray = dict(vars(input_ray_bundle))
|
||||
ray_bundle = copy.copy(input_ray_bundle)
|
||||
if input_ray_bundle.bins is None:
|
||||
kwargs_ray["lengths"] = z_vals
|
||||
return ImplicitronRayBundle(**kwargs_ray)
|
||||
kwargs_ray["bins"] = z_vals
|
||||
del kwargs_ray["lengths"]
|
||||
return ImplicitronRayBundle.from_bins(**kwargs_ray)
|
||||
ray_bundle.lengths = z_vals
|
||||
else:
|
||||
ray_bundle.bins = z_vals
|
||||
|
||||
return ray_bundle
|
||||
|
||||
|
||||
def apply_blurpool_on_weights(weights) -> torch.Tensor:
|
||||
|
||||
Reference in New Issue
Block a user