mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Prepare for "Fix type-safety of torch.nn.Module
instances": wave 2
Summary: See D52890934 Reviewed By: malfet, r-barnes Differential Revision: D66245100 fbshipit-source-id: 019058106ac7eaacf29c1c55912922ea55894d23
This commit is contained in:
parent
e20cbe9b0e
commit
f6c2ca6bfc
@ -123,6 +123,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
|
|||||||
"""
|
"""
|
||||||
# Get the parameters to optimize
|
# Get the parameters to optimize
|
||||||
if hasattr(model, "_get_param_groups"): # use the model function
|
if hasattr(model, "_get_param_groups"): # use the model function
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
p_groups = model._get_param_groups(self.lr, wd=self.weight_decay)
|
p_groups = model._get_param_groups(self.lr, wd=self.weight_decay)
|
||||||
else:
|
else:
|
||||||
p_groups = [
|
p_groups = [
|
||||||
|
@ -395,6 +395,7 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
|
|||||||
):
|
):
|
||||||
prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
|
prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
|
||||||
if hasattr(model, "visualize"):
|
if hasattr(model, "visualize"):
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
model.visualize(
|
model.visualize(
|
||||||
viz,
|
viz,
|
||||||
visdom_env_imgs,
|
visdom_env_imgs,
|
||||||
|
@ -329,6 +329,7 @@ def adjust_camera_to_bbox_crop_(
|
|||||||
|
|
||||||
focal_length_px, principal_point_px = _convert_ndc_to_pixels(
|
focal_length_px, principal_point_px = _convert_ndc_to_pixels(
|
||||||
camera.focal_length[0],
|
camera.focal_length[0],
|
||||||
|
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
|
||||||
camera.principal_point[0],
|
camera.principal_point[0],
|
||||||
image_size_wh,
|
image_size_wh,
|
||||||
)
|
)
|
||||||
@ -341,6 +342,7 @@ def adjust_camera_to_bbox_crop_(
|
|||||||
)
|
)
|
||||||
|
|
||||||
camera.focal_length = focal_length[None]
|
camera.focal_length = focal_length[None]
|
||||||
|
# pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
|
||||||
camera.principal_point = principal_point_cropped[None]
|
camera.principal_point = principal_point_cropped[None]
|
||||||
|
|
||||||
|
|
||||||
@ -352,6 +354,7 @@ def adjust_camera_to_image_scale_(
|
|||||||
) -> PerspectiveCameras:
|
) -> PerspectiveCameras:
|
||||||
focal_length_px, principal_point_px = _convert_ndc_to_pixels(
|
focal_length_px, principal_point_px = _convert_ndc_to_pixels(
|
||||||
camera.focal_length[0],
|
camera.focal_length[0],
|
||||||
|
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
|
||||||
camera.principal_point[0],
|
camera.principal_point[0],
|
||||||
original_size_wh,
|
original_size_wh,
|
||||||
)
|
)
|
||||||
@ -368,6 +371,7 @@ def adjust_camera_to_image_scale_(
|
|||||||
image_size_wh_output,
|
image_size_wh_output,
|
||||||
)
|
)
|
||||||
camera.focal_length = focal_length_scaled[None]
|
camera.focal_length = focal_length_scaled[None]
|
||||||
|
# pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
|
||||||
camera.principal_point = principal_point_scaled[None]
|
camera.principal_point = principal_point_scaled[None]
|
||||||
|
|
||||||
|
|
||||||
|
@ -142,9 +142,15 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
|
|||||||
return f"res_layer_{stage + 1}"
|
return f"res_layer_{stage + 1}"
|
||||||
|
|
||||||
def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
|
def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
|
||||||
|
# pyre-fixme[58]: `-` is not supported for operand types `Tensor` and
|
||||||
|
# `Union[Tensor, Module]`.
|
||||||
|
# pyre-fixme[58]: `/` is not supported for operand types `Tensor` and
|
||||||
|
# `Union[Tensor, Module]`.
|
||||||
return (img - self._resnet_mean) / self._resnet_std
|
return (img - self._resnet_mean) / self._resnet_std
|
||||||
|
|
||||||
def get_feat_dims(self) -> int:
|
def get_feat_dims(self) -> int:
|
||||||
|
# pyre-fixme[29]: `Union[(self: TensorBase) -> Tensor, Tensor, Module]` is
|
||||||
|
# not a function.
|
||||||
return sum(self._feat_dim.values())
|
return sum(self._feat_dim.values())
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -183,7 +189,12 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
|
|||||||
else:
|
else:
|
||||||
imgs_normed = imgs_resized
|
imgs_normed = imgs_resized
|
||||||
# is not a function.
|
# is not a function.
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
feats = self.stem(imgs_normed)
|
feats = self.stem(imgs_normed)
|
||||||
|
# pyre-fixme[6]: For 1st argument expected `Iterable[_T1]` but got
|
||||||
|
# `Union[Tensor, Module]`.
|
||||||
|
# pyre-fixme[6]: For 2nd argument expected `Iterable[_T2]` but got
|
||||||
|
# `Union[Tensor, Module]`.
|
||||||
for stage, (layer, proj) in enumerate(zip(self.layers, self.proj_layers)):
|
for stage, (layer, proj) in enumerate(zip(self.layers, self.proj_layers)):
|
||||||
feats = layer(feats)
|
feats = layer(feats)
|
||||||
# just a sanity check below
|
# just a sanity check below
|
||||||
|
@ -478,6 +478,8 @@ class GenericModel(ImplicitronModelBase):
|
|||||||
)
|
)
|
||||||
custom_args["global_code"] = global_code
|
custom_args["global_code"] = global_code
|
||||||
|
|
||||||
|
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a
|
||||||
|
# function.
|
||||||
for func in self._implicit_functions:
|
for func in self._implicit_functions:
|
||||||
func.bind_args(**custom_args)
|
func.bind_args(**custom_args)
|
||||||
|
|
||||||
@ -500,6 +502,8 @@ class GenericModel(ImplicitronModelBase):
|
|||||||
# Unbind the custom arguments to prevent pytorch from storing
|
# Unbind the custom arguments to prevent pytorch from storing
|
||||||
# large buffers of intermediate results due to points in the
|
# large buffers of intermediate results due to points in the
|
||||||
# bound arguments.
|
# bound arguments.
|
||||||
|
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a
|
||||||
|
# function.
|
||||||
for func in self._implicit_functions:
|
for func in self._implicit_functions:
|
||||||
func.unbind_args()
|
func.unbind_args()
|
||||||
|
|
||||||
|
@ -71,6 +71,7 @@ class Autodecoder(Configurable, torch.nn.Module):
|
|||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
|
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
|
||||||
|
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `weight`.
|
||||||
return (self._autodecoder_codes.weight**2).mean()
|
return (self._autodecoder_codes.weight**2).mean()
|
||||||
|
|
||||||
def get_encoding_dim(self) -> int:
|
def get_encoding_dim(self) -> int:
|
||||||
@ -95,6 +96,7 @@ class Autodecoder(Configurable, torch.nn.Module):
|
|||||||
# pyre-fixme[9]: x has type `Union[List[str], LongTensor]`; used as
|
# pyre-fixme[9]: x has type `Union[List[str], LongTensor]`; used as
|
||||||
# `Tensor`.
|
# `Tensor`.
|
||||||
x = torch.tensor(
|
x = torch.tensor(
|
||||||
|
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, ...
|
||||||
[self._key_map[elem] for elem in x],
|
[self._key_map[elem] for elem in x],
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
device=next(self.parameters()).device,
|
device=next(self.parameters()).device,
|
||||||
@ -102,6 +104,7 @@ class Autodecoder(Configurable, torch.nn.Module):
|
|||||||
except StopIteration:
|
except StopIteration:
|
||||||
raise ValueError("Not enough n_instances in the autodecoder") from None
|
raise ValueError("Not enough n_instances in the autodecoder") from None
|
||||||
|
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
return self._autodecoder_codes(x)
|
return self._autodecoder_codes(x)
|
||||||
|
|
||||||
def _load_key_map_hook(
|
def _load_key_map_hook(
|
||||||
|
@ -122,6 +122,7 @@ class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
|
|||||||
if frame_timestamp.shape[-1] != 1:
|
if frame_timestamp.shape[-1] != 1:
|
||||||
raise ValueError("Frame timestamp's last dimensions should be one.")
|
raise ValueError("Frame timestamp's last dimensions should be one.")
|
||||||
time = frame_timestamp / self.time_divisor
|
time = frame_timestamp / self.time_divisor
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
return self._harmonic_embedding(time)
|
return self._harmonic_embedding(time)
|
||||||
|
|
||||||
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
|
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
|
||||||
|
@ -232,9 +232,14 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
|
|||||||
# if the skip tensor is None, we use `x` instead.
|
# if the skip tensor is None, we use `x` instead.
|
||||||
z = x
|
z = x
|
||||||
skipi = 0
|
skipi = 0
|
||||||
|
# pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got
|
||||||
|
# `Union[Tensor, Module]`.
|
||||||
for li, layer in enumerate(self.mlp):
|
for li, layer in enumerate(self.mlp):
|
||||||
|
# pyre-fixme[58]: `in` is not supported for right operand type
|
||||||
|
# `Union[Tensor, Module]`.
|
||||||
if li in self._input_skips:
|
if li in self._input_skips:
|
||||||
if self._skip_affine_trans:
|
if self._skip_affine_trans:
|
||||||
|
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, ...
|
||||||
y = self._apply_affine_layer(self.skip_affines[skipi], y, z)
|
y = self._apply_affine_layer(self.skip_affines[skipi], y, z)
|
||||||
else:
|
else:
|
||||||
y = torch.cat((y, z), dim=-1)
|
y = torch.cat((y, z), dim=-1)
|
||||||
|
@ -141,11 +141,16 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
self.embed_fn is None and fun_viewpool is None and global_code is None
|
self.embed_fn is None and fun_viewpool is None and global_code is None
|
||||||
):
|
):
|
||||||
return torch.tensor(
|
return torch.tensor(
|
||||||
[], device=rays_points_world.device, dtype=rays_points_world.dtype
|
[],
|
||||||
|
device=rays_points_world.device,
|
||||||
|
dtype=rays_points_world.dtype,
|
||||||
|
# pyre-fixme[6]: For 2nd argument expected `Union[int, SymInt]` but got
|
||||||
|
# `Union[Module, Tensor]`.
|
||||||
).view(0, self.out_dim)
|
).view(0, self.out_dim)
|
||||||
|
|
||||||
embeddings = []
|
embeddings = []
|
||||||
if self.embed_fn is not None:
|
if self.embed_fn is not None:
|
||||||
|
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
|
||||||
embeddings.append(self.embed_fn(rays_points_world))
|
embeddings.append(self.embed_fn(rays_points_world))
|
||||||
|
|
||||||
if fun_viewpool is not None:
|
if fun_viewpool is not None:
|
||||||
@ -164,13 +169,19 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
|
|
||||||
embedding = torch.cat(embeddings, dim=-1)
|
embedding = torch.cat(embeddings, dim=-1)
|
||||||
x = embedding
|
x = embedding
|
||||||
|
# pyre-fixme[29]: `Union[(self: TensorBase, other: Union[bool, complex,
|
||||||
|
# float, int, Tensor]) -> Tensor, Module, Tensor]` is not a function.
|
||||||
for layer_idx in range(self.num_layers - 1):
|
for layer_idx in range(self.num_layers - 1):
|
||||||
if layer_idx in self.skip_in:
|
if layer_idx in self.skip_in:
|
||||||
x = torch.cat([x, embedding], dim=-1) / 2**0.5
|
x = torch.cat([x, embedding], dim=-1) / 2**0.5
|
||||||
|
|
||||||
|
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
|
||||||
x = self.linear_layers[layer_idx](x)
|
x = self.linear_layers[layer_idx](x)
|
||||||
|
|
||||||
|
# pyre-fixme[29]: `Union[(self: TensorBase, other: Union[bool, complex,
|
||||||
|
# float, int, Tensor]) -> Tensor, Module, Tensor]` is not a function.
|
||||||
if layer_idx < self.num_layers - 2:
|
if layer_idx < self.num_layers - 2:
|
||||||
|
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
|
||||||
x = self.softplus(x)
|
x = self.softplus(x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
@ -123,8 +123,10 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
# Normalize the ray_directions to unit l2 norm.
|
# Normalize the ray_directions to unit l2 norm.
|
||||||
rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
|
rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
|
||||||
# Obtain the harmonic embedding of the normalized ray directions.
|
# Obtain the harmonic embedding of the normalized ray directions.
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
rays_embedding = self.harmonic_embedding_dir(rays_directions_normed)
|
rays_embedding = self.harmonic_embedding_dir(rays_directions_normed)
|
||||||
|
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
return self.color_layer((self.intermediate_linear(features), rays_embedding))
|
return self.color_layer((self.intermediate_linear(features), rays_embedding))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -195,6 +197,8 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
embeds = create_embeddings_for_implicit_function(
|
embeds = create_embeddings_for_implicit_function(
|
||||||
xyz_world=rays_points_world,
|
xyz_world=rays_points_world,
|
||||||
# for 2nd param but got `Union[None, torch.Tensor, torch.nn.Module]`.
|
# for 2nd param but got `Union[None, torch.Tensor, torch.nn.Module]`.
|
||||||
|
# pyre-fixme[6]: For 2nd argument expected `Optional[(...) -> Any]` but
|
||||||
|
# got `Union[None, Tensor, Module]`.
|
||||||
xyz_embedding_function=(
|
xyz_embedding_function=(
|
||||||
self.harmonic_embedding_xyz if self.input_xyz else None
|
self.harmonic_embedding_xyz if self.input_xyz else None
|
||||||
),
|
),
|
||||||
@ -206,12 +210,14 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# embeds.shape = [minibatch x n_src x n_rays x n_pts x self.n_harmonic_functions*6+3]
|
# embeds.shape = [minibatch x n_src x n_rays x n_pts x self.n_harmonic_functions*6+3]
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
features = self.xyz_encoder(embeds)
|
features = self.xyz_encoder(embeds)
|
||||||
# features.shape = [minibatch x ... x self.n_hidden_neurons_xyz]
|
# features.shape = [minibatch x ... x self.n_hidden_neurons_xyz]
|
||||||
# NNs operate on the flattenned rays; reshaping to the correct spatial size
|
# NNs operate on the flattenned rays; reshaping to the correct spatial size
|
||||||
# TODO: maybe make the transformer work on non-flattened tensors to avoid this reshape
|
# TODO: maybe make the transformer work on non-flattened tensors to avoid this reshape
|
||||||
features = features.reshape(*rays_points_world.shape[:-1], -1)
|
features = features.reshape(*rays_points_world.shape[:-1], -1)
|
||||||
|
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
raw_densities = self.density_layer(features)
|
raw_densities = self.density_layer(features)
|
||||||
# raw_densities.shape = [minibatch x ... x 1] in [0-1]
|
# raw_densities.shape = [minibatch x ... x 1] in [0-1]
|
||||||
|
|
||||||
@ -219,6 +225,8 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
if camera is None:
|
if camera is None:
|
||||||
raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
|
raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
|
||||||
|
|
||||||
|
# pyre-fixme[58]: `@` is not supported for operand types `Tensor` and
|
||||||
|
# `Union[Tensor, Module]`.
|
||||||
directions = ray_bundle.directions @ camera.R
|
directions = ray_bundle.directions @ camera.R
|
||||||
else:
|
else:
|
||||||
directions = ray_bundle.directions
|
directions = ray_bundle.directions
|
||||||
|
@ -103,6 +103,8 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module):
|
|||||||
|
|
||||||
embeds = create_embeddings_for_implicit_function(
|
embeds = create_embeddings_for_implicit_function(
|
||||||
xyz_world=rays_points_world,
|
xyz_world=rays_points_world,
|
||||||
|
# pyre-fixme[6]: For 2nd argument expected `Optional[(...) -> Any]` but
|
||||||
|
# got `Union[Tensor, Module]`.
|
||||||
xyz_embedding_function=self._harmonic_embedding,
|
xyz_embedding_function=self._harmonic_embedding,
|
||||||
global_code=global_code,
|
global_code=global_code,
|
||||||
fun_viewpool=fun_viewpool,
|
fun_viewpool=fun_viewpool,
|
||||||
@ -112,6 +114,7 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module):
|
|||||||
|
|
||||||
# Before running the network, we have to resize embeds to ndims=3,
|
# Before running the network, we have to resize embeds to ndims=3,
|
||||||
# otherwise the SRN layers consume huge amounts of memory.
|
# otherwise the SRN layers consume huge amounts of memory.
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
raymarch_features = self._net(
|
raymarch_features = self._net(
|
||||||
embeds.view(embeds.shape[0], -1, embeds.shape[-1])
|
embeds.view(embeds.shape[0], -1, embeds.shape[-1])
|
||||||
)
|
)
|
||||||
@ -166,7 +169,9 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
|
|||||||
# Normalize the ray_directions to unit l2 norm.
|
# Normalize the ray_directions to unit l2 norm.
|
||||||
rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
|
rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
|
||||||
# Obtain the harmonic embedding of the normalized ray directions.
|
# Obtain the harmonic embedding of the normalized ray directions.
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
rays_embedding = self._harmonic_embedding(rays_directions_normed)
|
rays_embedding = self._harmonic_embedding(rays_directions_normed)
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
return self._color_layer((features, rays_embedding))
|
return self._color_layer((features, rays_embedding))
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -195,6 +200,7 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
|
|||||||
denoting the color of each ray point.
|
denoting the color of each ray point.
|
||||||
"""
|
"""
|
||||||
# raymarch_features.shape = [minibatch x ... x pts_per_ray x 3]
|
# raymarch_features.shape = [minibatch x ... x pts_per_ray x 3]
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
features = self._net(raymarch_features)
|
features = self._net(raymarch_features)
|
||||||
# features.shape = [minibatch x ... x self.n_hidden_units]
|
# features.shape = [minibatch x ... x self.n_hidden_units]
|
||||||
|
|
||||||
@ -202,6 +208,8 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
|
|||||||
if camera is None:
|
if camera is None:
|
||||||
raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
|
raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
|
||||||
|
|
||||||
|
# pyre-fixme[58]: `@` is not supported for operand types `Tensor` and
|
||||||
|
# `Union[Tensor, Module]`.
|
||||||
directions = ray_bundle.directions @ camera.R
|
directions = ray_bundle.directions @ camera.R
|
||||||
else:
|
else:
|
||||||
directions = ray_bundle.directions
|
directions = ray_bundle.directions
|
||||||
@ -209,6 +217,7 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
|
|||||||
# NNs operate on the flattenned rays; reshaping to the correct spatial size
|
# NNs operate on the flattenned rays; reshaping to the correct spatial size
|
||||||
features = features.reshape(*raymarch_features.shape[:-1], -1)
|
features = features.reshape(*raymarch_features.shape[:-1], -1)
|
||||||
|
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
raw_densities = self._density_layer(features)
|
raw_densities = self._density_layer(features)
|
||||||
|
|
||||||
rays_colors = self._get_colors(features, directions)
|
rays_colors = self._get_colors(features, directions)
|
||||||
@ -269,6 +278,7 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
|
|||||||
srn_raymarch_function.
|
srn_raymarch_function.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
net = self._hypernet(global_code)
|
net = self._hypernet(global_code)
|
||||||
|
|
||||||
# use the hyper-net generated network to instantiate the raymarch module
|
# use the hyper-net generated network to instantiate the raymarch module
|
||||||
@ -304,6 +314,8 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
|
|||||||
# across LSTM iterations for the same global_code.
|
# across LSTM iterations for the same global_code.
|
||||||
if self.cached_srn_raymarch_function is None:
|
if self.cached_srn_raymarch_function is None:
|
||||||
# generate the raymarching network from the hypernet
|
# generate the raymarching network from the hypernet
|
||||||
|
# pyre-fixme[16]: `SRNRaymarchHyperNet` has no attribute
|
||||||
|
# `cached_srn_raymarch_function`.
|
||||||
self.cached_srn_raymarch_function = self._run_hypernet(global_code)
|
self.cached_srn_raymarch_function = self._run_hypernet(global_code)
|
||||||
(srn_raymarch_function,) = cast(
|
(srn_raymarch_function,) = cast(
|
||||||
Tuple[SRNRaymarchFunction], self.cached_srn_raymarch_function
|
Tuple[SRNRaymarchFunction], self.cached_srn_raymarch_function
|
||||||
@ -331,6 +343,7 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
def create_raymarch_function(self) -> None:
|
def create_raymarch_function(self) -> None:
|
||||||
self.raymarch_function = SRNRaymarchFunction(
|
self.raymarch_function = SRNRaymarchFunction(
|
||||||
latent_dim=self.latent_dim,
|
latent_dim=self.latent_dim,
|
||||||
|
# pyre-fixme[32]: Keyword argument must be a mapping with string keys.
|
||||||
**self.raymarch_function_args,
|
**self.raymarch_function_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -389,6 +402,7 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
self.hypernet = SRNRaymarchHyperNet(
|
self.hypernet = SRNRaymarchHyperNet(
|
||||||
latent_dim=self.latent_dim,
|
latent_dim=self.latent_dim,
|
||||||
latent_dim_hypernet=self.latent_dim_hypernet,
|
latent_dim_hypernet=self.latent_dim_hypernet,
|
||||||
|
# pyre-fixme[32]: Keyword argument must be a mapping with string keys.
|
||||||
**self.hypernet_args,
|
**self.hypernet_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -269,6 +269,7 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
|||||||
for name, tensor in vars(grid_values_with_wanted_resolution).items()
|
for name, tensor in vars(grid_values_with_wanted_resolution).items()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
return self.values_type(**params), True
|
return self.values_type(**params), True
|
||||||
|
|
||||||
def get_resolution_change_epochs(self) -> Tuple[int, ...]:
|
def get_resolution_change_epochs(self) -> Tuple[int, ...]:
|
||||||
@ -882,6 +883,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
|||||||
torch.Tensor of shape (..., n_features)
|
torch.Tensor of shape (..., n_features)
|
||||||
"""
|
"""
|
||||||
locator = self._get_volume_locator()
|
locator = self._get_volume_locator()
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
grid_values = self.voxel_grid.values_type(**self.params)
|
grid_values = self.voxel_grid.values_type(**self.params)
|
||||||
# voxel grids operate with extra n_grids dimension, which we fix to one
|
# voxel grids operate with extra n_grids dimension, which we fix to one
|
||||||
return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]
|
return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]
|
||||||
@ -895,6 +897,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
|||||||
replace current parameters
|
replace current parameters
|
||||||
"""
|
"""
|
||||||
if self.hold_voxel_grid_as_parameters:
|
if self.hold_voxel_grid_as_parameters:
|
||||||
|
# pyre-fixme[16]: `VoxelGridModule` has no attribute `params`.
|
||||||
self.params = torch.nn.ParameterDict(
|
self.params = torch.nn.ParameterDict(
|
||||||
{
|
{
|
||||||
k: torch.nn.Parameter(val)
|
k: torch.nn.Parameter(val)
|
||||||
@ -945,6 +948,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
True if parameter change has happened else False.
|
True if parameter change has happened else False.
|
||||||
"""
|
"""
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
grid_values = self.voxel_grid.values_type(**self.params)
|
grid_values = self.voxel_grid.values_type(**self.params)
|
||||||
grid_values, change = self.voxel_grid.change_resolution(
|
grid_values, change = self.voxel_grid.change_resolution(
|
||||||
grid_values, epoch=epoch
|
grid_values, epoch=epoch
|
||||||
@ -992,16 +996,21 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
'''
|
'''
|
||||||
new_params = {}
|
new_params = {}
|
||||||
|
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a
|
||||||
|
# function.
|
||||||
for name in self.params:
|
for name in self.params:
|
||||||
key = prefix + "params." + name
|
key = prefix + "params." + name
|
||||||
if key in state_dict:
|
if key in state_dict:
|
||||||
new_params[name] = torch.zeros_like(state_dict[key])
|
new_params[name] = torch.zeros_like(state_dict[key])
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
self.set_voxel_grid_parameters(self.voxel_grid.values_type(**new_params))
|
self.set_voxel_grid_parameters(self.voxel_grid.values_type(**new_params))
|
||||||
|
|
||||||
def get_device(self) -> torch.device:
|
def get_device(self) -> torch.device:
|
||||||
"""
|
"""
|
||||||
Returns torch.device on which module parameters are located
|
Returns torch.device on which module parameters are located
|
||||||
"""
|
"""
|
||||||
|
# pyre-fixme[29]: `Union[(self: TensorBase) -> Tensor, Tensor, Module]` is
|
||||||
|
# not a function.
|
||||||
return next(val for val in self.params.values() if val is not None).device
|
return next(val for val in self.params.values() if val is not None).device
|
||||||
|
|
||||||
def crop_self(self, min_point: torch.Tensor, max_point: torch.Tensor) -> None:
|
def crop_self(self, min_point: torch.Tensor, max_point: torch.Tensor) -> None:
|
||||||
@ -1018,6 +1027,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
locator = self._get_volume_locator()
|
locator = self._get_volume_locator()
|
||||||
# torch.nn.modules.module.Module]` is not a function.
|
# torch.nn.modules.module.Module]` is not a function.
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
old_grid_values = self.voxel_grid.values_type(**self.params)
|
old_grid_values = self.voxel_grid.values_type(**self.params)
|
||||||
new_grid_values = self.voxel_grid.crop_world(
|
new_grid_values = self.voxel_grid.crop_world(
|
||||||
min_point, max_point, old_grid_values, locator
|
min_point, max_point, old_grid_values, locator
|
||||||
@ -1025,6 +1035,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
|||||||
grid_values, _ = self.voxel_grid.change_resolution(
|
grid_values, _ = self.voxel_grid.change_resolution(
|
||||||
new_grid_values, grid_values_with_wanted_resolution=old_grid_values
|
new_grid_values, grid_values_with_wanted_resolution=old_grid_values
|
||||||
)
|
)
|
||||||
|
# pyre-fixme[16]: `VoxelGridModule` has no attribute `params`.
|
||||||
self.params = torch.nn.ParameterDict(
|
self.params = torch.nn.ParameterDict(
|
||||||
{
|
{
|
||||||
k: torch.nn.Parameter(val)
|
k: torch.nn.Parameter(val)
|
||||||
|
@ -192,16 +192,26 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
|
# pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
|
||||||
|
# `voxel_grid_scaffold`.
|
||||||
self.voxel_grid_scaffold = self._create_voxel_grid_scaffold()
|
self.voxel_grid_scaffold = self._create_voxel_grid_scaffold()
|
||||||
|
# pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
|
||||||
|
# `harmonic_embedder_xyz_density`.
|
||||||
self.harmonic_embedder_xyz_density = HarmonicEmbedding(
|
self.harmonic_embedder_xyz_density = HarmonicEmbedding(
|
||||||
**self.harmonic_embedder_xyz_density_args
|
**self.harmonic_embedder_xyz_density_args
|
||||||
)
|
)
|
||||||
|
# pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
|
||||||
|
# `harmonic_embedder_xyz_color`.
|
||||||
self.harmonic_embedder_xyz_color = HarmonicEmbedding(
|
self.harmonic_embedder_xyz_color = HarmonicEmbedding(
|
||||||
**self.harmonic_embedder_xyz_color_args
|
**self.harmonic_embedder_xyz_color_args
|
||||||
)
|
)
|
||||||
|
# pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
|
||||||
|
# `harmonic_embedder_dir_color`.
|
||||||
self.harmonic_embedder_dir_color = HarmonicEmbedding(
|
self.harmonic_embedder_dir_color = HarmonicEmbedding(
|
||||||
**self.harmonic_embedder_dir_color_args
|
**self.harmonic_embedder_dir_color_args
|
||||||
)
|
)
|
||||||
|
# pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
|
||||||
|
# `_scaffold_ready`.
|
||||||
self._scaffold_ready = False
|
self._scaffold_ready = False
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -252,6 +262,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
# ########## filter the points using the scaffold ########## #
|
# ########## filter the points using the scaffold ########## #
|
||||||
if self._scaffold_ready and self.scaffold_filter_points:
|
if self._scaffold_ready and self.scaffold_filter_points:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0
|
non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0
|
||||||
points = points[non_empty_points]
|
points = points[non_empty_points]
|
||||||
if len(points) == 0:
|
if len(points) == 0:
|
||||||
@ -363,6 +374,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
feature dimensionality which `decoder_density` returns
|
feature dimensionality which `decoder_density` returns
|
||||||
"""
|
"""
|
||||||
embeds_density = self.voxel_grid_density(points)
|
embeds_density = self.voxel_grid_density(points)
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
harmonic_embedding_density = self.harmonic_embedder_xyz_density(embeds_density)
|
harmonic_embedding_density = self.harmonic_embedder_xyz_density(embeds_density)
|
||||||
# shape = [..., density_dim]
|
# shape = [..., density_dim]
|
||||||
return self.decoder_density(harmonic_embedding_density)
|
return self.decoder_density(harmonic_embedding_density)
|
||||||
@ -397,6 +409,8 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
if self.xyz_ray_dir_in_camera_coords:
|
if self.xyz_ray_dir_in_camera_coords:
|
||||||
if camera is None:
|
if camera is None:
|
||||||
raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
|
raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
|
||||||
|
# pyre-fixme[58]: `@` is not supported for operand types `Tensor` and
|
||||||
|
# `Union[Tensor, Module]`.
|
||||||
directions = directions @ camera.R
|
directions = directions @ camera.R
|
||||||
|
|
||||||
# ########## get voxel grid output ########## #
|
# ########## get voxel grid output ########## #
|
||||||
@ -405,11 +419,13 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
|
|
||||||
# ########## embed with the harmonic function ########## #
|
# ########## embed with the harmonic function ########## #
|
||||||
# Obtain the harmonic embedding of the voxel grid output.
|
# Obtain the harmonic embedding of the voxel grid output.
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
harmonic_embedding_color = self.harmonic_embedder_xyz_color(embeds_color)
|
harmonic_embedding_color = self.harmonic_embedder_xyz_color(embeds_color)
|
||||||
|
|
||||||
# Normalize the ray_directions to unit l2 norm.
|
# Normalize the ray_directions to unit l2 norm.
|
||||||
rays_directions_normed = torch.nn.functional.normalize(directions, dim=-1)
|
rays_directions_normed = torch.nn.functional.normalize(directions, dim=-1)
|
||||||
# Obtain the harmonic embedding of the normalized ray directions.
|
# Obtain the harmonic embedding of the normalized ray directions.
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
harmonic_embedding_dir = self.harmonic_embedder_dir_color(
|
harmonic_embedding_dir = self.harmonic_embedder_dir_color(
|
||||||
rays_directions_normed
|
rays_directions_normed
|
||||||
)
|
)
|
||||||
@ -478,8 +494,11 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
an object inside, else False.
|
an object inside, else False.
|
||||||
"""
|
"""
|
||||||
# find bounding box
|
# find bounding box
|
||||||
|
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
|
||||||
|
# `get_grid_points`.
|
||||||
points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch)
|
points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch)
|
||||||
assert self._scaffold_ready, "Scaffold has to be calculated before cropping."
|
assert self._scaffold_ready, "Scaffold has to be calculated before cropping."
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
occupancy = self.voxel_grid_scaffold(points)[..., 0] > 0
|
occupancy = self.voxel_grid_scaffold(points)[..., 0] > 0
|
||||||
non_zero_idxs = torch.nonzero(occupancy)
|
non_zero_idxs = torch.nonzero(occupancy)
|
||||||
if len(non_zero_idxs) == 0:
|
if len(non_zero_idxs) == 0:
|
||||||
@ -511,6 +530,8 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
planes = []
|
planes = []
|
||||||
|
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
|
||||||
|
# `get_grid_points`.
|
||||||
points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch)
|
points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch)
|
||||||
|
|
||||||
chunk_size = (
|
chunk_size = (
|
||||||
@ -530,7 +551,10 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
stride=1,
|
stride=1,
|
||||||
)
|
)
|
||||||
occupancy_cube = density_cube > self.scaffold_empty_space_threshold
|
occupancy_cube = density_cube > self.scaffold_empty_space_threshold
|
||||||
|
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `params`.
|
||||||
self.voxel_grid_scaffold.params["voxel_grid"] = occupancy_cube.float()
|
self.voxel_grid_scaffold.params["voxel_grid"] = occupancy_cube.float()
|
||||||
|
# pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
|
||||||
|
# `_scaffold_ready`.
|
||||||
self._scaffold_ready = True
|
self._scaffold_ready = True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
@ -547,6 +571,8 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
decoding function to this value.
|
decoding function to this value.
|
||||||
"""
|
"""
|
||||||
grid_args = self.voxel_grid_density_args
|
grid_args = self.voxel_grid_density_args
|
||||||
|
# pyre-fixme[6]: For 1st argument expected `DictConfig` but got
|
||||||
|
# `Union[Tensor, Module]`.
|
||||||
grid_output_dim = VoxelGridModule.get_output_dim(grid_args)
|
grid_output_dim = VoxelGridModule.get_output_dim(grid_args)
|
||||||
|
|
||||||
embedder_args = self.harmonic_embedder_xyz_density_args
|
embedder_args = self.harmonic_embedder_xyz_density_args
|
||||||
@ -575,6 +601,8 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
decoding function to this value.
|
decoding function to this value.
|
||||||
"""
|
"""
|
||||||
grid_args = self.voxel_grid_color_args
|
grid_args = self.voxel_grid_color_args
|
||||||
|
# pyre-fixme[6]: For 1st argument expected `DictConfig` but got
|
||||||
|
# `Union[Tensor, Module]`.
|
||||||
grid_output_dim = VoxelGridModule.get_output_dim(grid_args)
|
grid_output_dim = VoxelGridModule.get_output_dim(grid_args)
|
||||||
|
|
||||||
embedder_args = self.harmonic_embedder_xyz_color_args
|
embedder_args = self.harmonic_embedder_xyz_color_args
|
||||||
@ -608,7 +636,9 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
`self.voxel_grid_density`
|
`self.voxel_grid_density`
|
||||||
"""
|
"""
|
||||||
return VoxelGridModule(
|
return VoxelGridModule(
|
||||||
|
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
|
||||||
extents=self.voxel_grid_density_args["extents"],
|
extents=self.voxel_grid_density_args["extents"],
|
||||||
|
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
|
||||||
translation=self.voxel_grid_density_args["translation"],
|
translation=self.voxel_grid_density_args["translation"],
|
||||||
voxel_grid_class_type="FullResolutionVoxelGrid",
|
voxel_grid_class_type="FullResolutionVoxelGrid",
|
||||||
hold_voxel_grid_as_parameters=False,
|
hold_voxel_grid_as_parameters=False,
|
||||||
|
@ -135,6 +135,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
break
|
break
|
||||||
|
|
||||||
# run the lstm marcher
|
# run the lstm marcher
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
state_h, state_c = self._lstm(
|
state_h, state_c = self._lstm(
|
||||||
raymarch_features.view(-1, raymarch_features.shape[-1]),
|
raymarch_features.view(-1, raymarch_features.shape[-1]),
|
||||||
states[-1],
|
states[-1],
|
||||||
@ -142,6 +143,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
if state_h.requires_grad:
|
if state_h.requires_grad:
|
||||||
state_h.register_hook(lambda x: x.clamp(min=-10, max=10))
|
state_h.register_hook(lambda x: x.clamp(min=-10, max=10))
|
||||||
# predict the next step size
|
# predict the next step size
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
signed_distance = self._out_layer(state_h).view(ray_bundle_t.lengths.shape)
|
signed_distance = self._out_layer(state_h).view(ray_bundle_t.lengths.shape)
|
||||||
# log the lstm states
|
# log the lstm states
|
||||||
states.append((state_h, state_c))
|
states.append((state_h, state_c))
|
||||||
|
@ -207,6 +207,7 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
sample_mask = None
|
sample_mask = None
|
||||||
if (
|
if (
|
||||||
|
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
|
||||||
self._sampling_mode[evaluation_mode] == RenderSamplingMode.MASK_SAMPLE
|
self._sampling_mode[evaluation_mode] == RenderSamplingMode.MASK_SAMPLE
|
||||||
and mask is not None
|
and mask is not None
|
||||||
):
|
):
|
||||||
@ -223,6 +224,7 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
|||||||
EvaluationMode.EVALUATION: self._evaluation_raysampler,
|
EvaluationMode.EVALUATION: self._evaluation_raysampler,
|
||||||
}[evaluation_mode]
|
}[evaluation_mode]
|
||||||
|
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
ray_bundle = raysampler(
|
ray_bundle = raysampler(
|
||||||
cameras=cameras,
|
cameras=cameras,
|
||||||
mask=sample_mask,
|
mask=sample_mask,
|
||||||
@ -240,6 +242,8 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
|||||||
"Heterogeneous ray bundle is not supported for conical frustum computation yet"
|
"Heterogeneous ray bundle is not supported for conical frustum computation yet"
|
||||||
)
|
)
|
||||||
elif self.cast_ray_bundle_as_cone:
|
elif self.cast_ray_bundle_as_cone:
|
||||||
|
# pyre-fixme[9]: pixel_hw has type `Tuple[float, float]`; used as
|
||||||
|
# `Tuple[Union[Tensor, Module], Union[Tensor, Module]]`.
|
||||||
pixel_hw: Tuple[float, float] = (self.pixel_height, self.pixel_width)
|
pixel_hw: Tuple[float, float] = (self.pixel_height, self.pixel_width)
|
||||||
pixel_radii_2d = compute_radii(cameras, ray_bundle.xys[..., :2], pixel_hw)
|
pixel_radii_2d = compute_radii(cameras, ray_bundle.xys[..., :2], pixel_hw)
|
||||||
return ImplicitronRayBundle(
|
return ImplicitronRayBundle(
|
||||||
|
@ -179,8 +179,10 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
|
|||||||
rays_densities = torch.relu(rays_densities)
|
rays_densities = torch.relu(rays_densities)
|
||||||
|
|
||||||
weighted_densities = deltas * rays_densities
|
weighted_densities = deltas * rays_densities
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
capped_densities = self._capping_function(weighted_densities)
|
capped_densities = self._capping_function(weighted_densities)
|
||||||
|
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
rays_opacities = self._capping_function(
|
rays_opacities = self._capping_function(
|
||||||
torch.cumsum(weighted_densities, dim=-1)
|
torch.cumsum(weighted_densities, dim=-1)
|
||||||
)
|
)
|
||||||
@ -190,6 +192,7 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
|
|||||||
)
|
)
|
||||||
absorption_shifted[..., : self.surface_thickness] = 1.0
|
absorption_shifted[..., : self.surface_thickness] = 1.0
|
||||||
|
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
weights = self._weight_function(capped_densities, absorption_shifted)
|
weights = self._weight_function(capped_densities, absorption_shifted)
|
||||||
features = (weights[..., None] * rays_features).sum(dim=-2)
|
features = (weights[..., None] * rays_features).sum(dim=-2)
|
||||||
depth = (weights * ray_lengths)[..., None].sum(dim=-2)
|
depth = (weights * ray_lengths)[..., None].sum(dim=-2)
|
||||||
@ -197,6 +200,8 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
|
|||||||
alpha = opacities if self.blend_output else 1
|
alpha = opacities if self.blend_output else 1
|
||||||
if self._bg_color.shape[-1] not in [1, features.shape[-1]]:
|
if self._bg_color.shape[-1] not in [1, features.shape[-1]]:
|
||||||
raise ValueError("Wrong number of background color channels.")
|
raise ValueError("Wrong number of background color channels.")
|
||||||
|
# pyre-fixme[58]: `*` is not supported for operand types `int` and
|
||||||
|
# `Union[Tensor, Module]`.
|
||||||
features = alpha * features + (1 - opacities) * self._bg_color
|
features = alpha * features + (1 - opacities) * self._bg_color
|
||||||
|
|
||||||
return RendererOutput(
|
return RendererOutput(
|
||||||
|
@ -61,6 +61,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
|
|
||||||
def create_ray_tracer(self) -> None:
|
def create_ray_tracer(self) -> None:
|
||||||
self.ray_tracer = RayTracing(
|
self.ray_tracer = RayTracing(
|
||||||
|
# pyre-fixme[32]: Keyword argument must be a mapping with string keys.
|
||||||
**self.ray_tracer_args,
|
**self.ray_tracer_args,
|
||||||
object_bounding_sphere=self.object_bounding_sphere,
|
object_bounding_sphere=self.object_bounding_sphere,
|
||||||
)
|
)
|
||||||
@ -149,6 +150,8 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
n_eik_points,
|
n_eik_points,
|
||||||
3,
|
3,
|
||||||
# but got `Union[device, Tensor, Module]`.
|
# but got `Union[device, Tensor, Module]`.
|
||||||
|
# pyre-fixme[6]: For 3rd argument expected `Union[None, int, str,
|
||||||
|
# device]` but got `Union[device, Tensor, Module]`.
|
||||||
device=self._bg_color.device,
|
device=self._bg_color.device,
|
||||||
).uniform_(-eik_bounding_box, eik_bounding_box)
|
).uniform_(-eik_bounding_box, eik_bounding_box)
|
||||||
eikonal_pixel_points = points.clone()
|
eikonal_pixel_points = points.clone()
|
||||||
@ -205,6 +208,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
]
|
]
|
||||||
normals_full.view(-1, 3)[surface_mask] = normals
|
normals_full.view(-1, 3)[surface_mask] = normals
|
||||||
render_full.view(-1, self.render_features_dimensions)[surface_mask] = (
|
render_full.view(-1, self.render_features_dimensions)[surface_mask] = (
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
self._rgb_network(
|
self._rgb_network(
|
||||||
features,
|
features,
|
||||||
differentiable_surface_points[None],
|
differentiable_surface_points[None],
|
||||||
|
@ -532,6 +532,7 @@ def _get_ray_dir_dot_prods(camera: CamerasBase, pts: torch.Tensor):
|
|||||||
|
|
||||||
# does not produce nans randomly unlike get_camera_center() below
|
# does not produce nans randomly unlike get_camera_center() below
|
||||||
cam_centers_rep = -torch.bmm(
|
cam_centers_rep = -torch.bmm(
|
||||||
|
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
|
||||||
camera_rep.T[:, None],
|
camera_rep.T[:, None],
|
||||||
camera_rep.R.permute(0, 2, 1),
|
camera_rep.R.permute(0, 2, 1),
|
||||||
).reshape(-1, *([1] * (pts.ndim - 2)), 3)
|
).reshape(-1, *([1] * (pts.ndim - 2)), 3)
|
||||||
|
@ -122,12 +122,17 @@ def corresponding_cameras_alignment(
|
|||||||
|
|
||||||
# create a new cameras object and set the R and T accordingly
|
# create a new cameras object and set the R and T accordingly
|
||||||
cameras_src_aligned = cameras_src.clone()
|
cameras_src_aligned = cameras_src.clone()
|
||||||
|
# pyre-fixme[6]: For 2nd argument expected `Tensor` but got `Union[Tensor, Module]`.
|
||||||
cameras_src_aligned.R = torch.bmm(align_t_R.expand_as(cameras_src.R), cameras_src.R)
|
cameras_src_aligned.R = torch.bmm(align_t_R.expand_as(cameras_src.R), cameras_src.R)
|
||||||
cameras_src_aligned.T = (
|
cameras_src_aligned.T = (
|
||||||
torch.bmm(
|
torch.bmm(
|
||||||
align_t_T[:, None].repeat(cameras_src.R.shape[0], 1, 1),
|
align_t_T[:, None].repeat(cameras_src.R.shape[0], 1, 1),
|
||||||
|
# pyre-fixme[6]: For 2nd argument expected `Tensor` but got
|
||||||
|
# `Union[Tensor, Module]`.
|
||||||
cameras_src.R,
|
cameras_src.R,
|
||||||
)[:, 0]
|
)[:, 0]
|
||||||
|
# pyre-fixme[29]: `Union[(self: TensorBase, other: Union[bool, complex,
|
||||||
|
# float, int, Tensor]) -> Tensor, Tensor, Module]` is not a function.
|
||||||
+ cameras_src.T * align_t_s
|
+ cameras_src.T * align_t_s
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -175,6 +180,7 @@ def _align_camera_extrinsics(
|
|||||||
R_A = (U V^T)^T
|
R_A = (U V^T)^T
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
# pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Tensor, Module]`.
|
||||||
RRcov = torch.bmm(cameras_src.R, cameras_tgt.R.transpose(2, 1)).mean(0)
|
RRcov = torch.bmm(cameras_src.R, cameras_tgt.R.transpose(2, 1)).mean(0)
|
||||||
U, _, V = torch.svd(RRcov)
|
U, _, V = torch.svd(RRcov)
|
||||||
align_t_R = V @ U.t()
|
align_t_R = V @ U.t()
|
||||||
@ -204,7 +210,11 @@ def _align_camera_extrinsics(
|
|||||||
T_A = mean(B) - mean(A) * s_A
|
T_A = mean(B) - mean(A) * s_A
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
# pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Tensor, Module]`.
|
||||||
|
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, Any, ...
|
||||||
A = torch.bmm(cameras_src.R, cameras_src.T[:, :, None])[:, :, 0]
|
A = torch.bmm(cameras_src.R, cameras_src.T[:, :, None])[:, :, 0]
|
||||||
|
# pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Tensor, Module]`.
|
||||||
|
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, Any, ...
|
||||||
B = torch.bmm(cameras_src.R, cameras_tgt.T[:, :, None])[:, :, 0]
|
B = torch.bmm(cameras_src.R, cameras_tgt.T[:, :, None])[:, :, 0]
|
||||||
Amu = A.mean(0, keepdim=True)
|
Amu = A.mean(0, keepdim=True)
|
||||||
Bmu = B.mean(0, keepdim=True)
|
Bmu = B.mean(0, keepdim=True)
|
||||||
|
@ -65,7 +65,11 @@ def _opencv_from_cameras_projection(
|
|||||||
cameras: PerspectiveCameras,
|
cameras: PerspectiveCameras,
|
||||||
image_size: torch.Tensor,
|
image_size: torch.Tensor,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
# pyre-fixme[29]: `Union[(self: TensorBase, memory_format:
|
||||||
|
# Optional[memory_format] = ...) -> Tensor, Tensor, Module]` is not a function.
|
||||||
R_pytorch3d = cameras.R.clone()
|
R_pytorch3d = cameras.R.clone()
|
||||||
|
# pyre-fixme[29]: `Union[(self: TensorBase, memory_format:
|
||||||
|
# Optional[memory_format] = ...) -> Tensor, Tensor, Module]` is not a function.
|
||||||
T_pytorch3d = cameras.T.clone()
|
T_pytorch3d = cameras.T.clone()
|
||||||
focal_pytorch3d = cameras.focal_length
|
focal_pytorch3d = cameras.focal_length
|
||||||
p0_pytorch3d = cameras.principal_point
|
p0_pytorch3d = cameras.principal_point
|
||||||
|
@ -203,7 +203,9 @@ class CamerasBase(TensorProperties):
|
|||||||
"""
|
"""
|
||||||
R: torch.Tensor = kwargs.get("R", self.R)
|
R: torch.Tensor = kwargs.get("R", self.R)
|
||||||
T: torch.Tensor = kwargs.get("T", self.T)
|
T: torch.Tensor = kwargs.get("T", self.T)
|
||||||
|
# pyre-fixme[16]: `CamerasBase` has no attribute `R`.
|
||||||
self.R = R
|
self.R = R
|
||||||
|
# pyre-fixme[16]: `CamerasBase` has no attribute `T`.
|
||||||
self.T = T
|
self.T = T
|
||||||
world_to_view_transform = get_world_to_view_transform(R=R, T=T)
|
world_to_view_transform = get_world_to_view_transform(R=R, T=T)
|
||||||
return world_to_view_transform
|
return world_to_view_transform
|
||||||
@ -228,7 +230,9 @@ class CamerasBase(TensorProperties):
|
|||||||
a Transform3d object which represents a batch of transforms
|
a Transform3d object which represents a batch of transforms
|
||||||
of shape (N, 3, 3)
|
of shape (N, 3, 3)
|
||||||
"""
|
"""
|
||||||
|
# pyre-fixme[16]: `CamerasBase` has no attribute `R`.
|
||||||
self.R: torch.Tensor = kwargs.get("R", self.R)
|
self.R: torch.Tensor = kwargs.get("R", self.R)
|
||||||
|
# pyre-fixme[16]: `CamerasBase` has no attribute `T`.
|
||||||
self.T: torch.Tensor = kwargs.get("T", self.T)
|
self.T: torch.Tensor = kwargs.get("T", self.T)
|
||||||
world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T)
|
world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T)
|
||||||
view_to_proj_transform = self.get_projection_transform(**kwargs)
|
view_to_proj_transform = self.get_projection_transform(**kwargs)
|
||||||
|
@ -266,7 +266,9 @@ class PointLights(TensorProperties):
|
|||||||
shape (P, 3) or (N, H, W, K, 3).
|
shape (P, 3) or (N, H, W, K, 3).
|
||||||
"""
|
"""
|
||||||
if self.location.ndim == points.ndim:
|
if self.location.ndim == points.ndim:
|
||||||
|
# pyre-fixme[7]: Expected `Tensor` but got `Union[Tensor, Module]`.
|
||||||
return self.location
|
return self.location
|
||||||
|
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
|
||||||
return self.location[:, None, None, None, :]
|
return self.location[:, None, None, None, :]
|
||||||
|
|
||||||
def diffuse(self, normals, points) -> torch.Tensor:
|
def diffuse(self, normals, points) -> torch.Tensor:
|
||||||
|
@ -588,9 +588,15 @@ def _add_struct_from_batch(
|
|||||||
if isinstance(batched_struct, CamerasBase):
|
if isinstance(batched_struct, CamerasBase):
|
||||||
# we can't index directly into camera batches
|
# we can't index directly into camera batches
|
||||||
R, T = batched_struct.R, batched_struct.T
|
R, T = batched_struct.R, batched_struct.T
|
||||||
|
# pyre-fixme[6]: For 1st argument expected
|
||||||
|
# `pyre_extensions.PyreReadOnly[Sized]` but got `Union[Tensor, Module]`.
|
||||||
r_idx = min(scene_num, len(R) - 1)
|
r_idx = min(scene_num, len(R) - 1)
|
||||||
|
# pyre-fixme[6]: For 1st argument expected
|
||||||
|
# `pyre_extensions.PyreReadOnly[Sized]` but got `Union[Tensor, Module]`.
|
||||||
t_idx = min(scene_num, len(T) - 1)
|
t_idx = min(scene_num, len(T) - 1)
|
||||||
|
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
|
||||||
R = R[r_idx].unsqueeze(0)
|
R = R[r_idx].unsqueeze(0)
|
||||||
|
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
|
||||||
T = T[t_idx].unsqueeze(0)
|
T = T[t_idx].unsqueeze(0)
|
||||||
struct = CamerasBase(device=batched_struct.device, R=R, T=T)
|
struct = CamerasBase(device=batched_struct.device, R=R, T=T)
|
||||||
elif _is_ray_bundle(batched_struct) and not _is_heterogeneous_ray_bundle(
|
elif _is_ray_bundle(batched_struct) and not _is_heterogeneous_ray_bundle(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user