1 Commits

Author SHA1 Message Date
bottler
9c586b1351 Run tests in github action not circleci (#1896)
Summary: Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1896

Differential Revision: D65272512

Pulled By: bottler
2024-10-31 08:41:20 -07:00
147 changed files with 578 additions and 873 deletions

View File

@@ -88,6 +88,7 @@ def workflow_pair(
upload=False,
filter_branch,
):
w = []
py = python_version.replace(".", "")
pyt = pytorch_version.replace(".", "")
@@ -126,6 +127,7 @@ def generate_base_workflow(
btype,
filter_branch=None,
):
d = {
"name": base_workflow_name,
"python_version": python_version,

View File

@@ -3,9 +3,6 @@ on:
pull_request:
branches:
- main
push:
branches:
- main
jobs:
binary_linux_conda_cuda:
runs-on: 4-core-ubuntu-gpu-t4

View File

@@ -36,5 +36,5 @@ then
echo "Running pyre..."
echo "To restart/kill pyre server, run 'pyre restart' or 'pyre kill' in fbcode/"
( cd ~/fbsource/fbcode; arc pyre check //vision/fair/pytorch3d/... )
( cd ~/fbsource/fbcode; pyre -l vision/fair/pytorch3d/ )
fi

View File

@@ -10,7 +10,6 @@ This example demonstrates the most trivial, direct interface of the pulsar
sphere renderer. It renders and saves an image with 10 random spheres.
Output: basic.png.
"""
import logging
import math
from os import path

View File

@@ -11,7 +11,6 @@ interface for sphere renderering. It renders and saves an image with
10 random spheres.
Output: basic-pt3d.png.
"""
import logging
from os import path

View File

@@ -14,7 +14,6 @@ distorted. Gradient-based optimization is used to converge towards the
original camera parameters.
Output: cam.gif.
"""
import logging
import math
from os import path

View File

@@ -14,7 +14,6 @@ distorted. Gradient-based optimization is used to converge towards the
original camera parameters.
Output: cam-pt3d.gif
"""
import logging
from os import path

View File

@@ -18,7 +18,6 @@ This example is not available yet through the 'unified' interface,
because opacity support has not landed in PyTorch3D for general data
structures yet.
"""
import logging
import math
from os import path

View File

@@ -13,7 +13,6 @@ The scene is initialized with random spheres. Gradient-based
optimization is used to converge towards a faithful
scene representation.
"""
import logging
import math

View File

@@ -13,7 +13,6 @@ The scene is initialized with random spheres. Gradient-based
optimization is used to converge towards a faithful
scene representation.
"""
import logging
import math

View File

@@ -32,6 +32,7 @@ requirements:
build:
string: py{{py}}_{{ environ['CU_VERSION'] }}_pyt{{ environ['PYTORCH_VERSION_NODOT']}}
# script: LD_LIBRARY_PATH=$PREFIX/lib:$BUILD_PREFIX/lib:$LD_LIBRARY_PATH python setup.py install --single-version-externally-managed --record=record.txt # [not win]
script: python setup.py install --single-version-externally-managed --record=record.txt # [not win]
script_env:
- CUDA_HOME
@@ -56,6 +57,7 @@ test:
- pandas
- sqlalchemy
commands:
#pytest .
python -m unittest discover -v -s tests -t .

View File

@@ -7,7 +7,7 @@
# pyre-unsafe
""" "
""""
This file is the entry point for launching experiments with Implicitron.
Launch Training
@@ -44,7 +44,6 @@ The outputs of the experiment are saved and logged in multiple ways:
config file.
"""
import logging
import os
import warnings

View File

@@ -26,6 +26,7 @@ logger = logging.getLogger(__name__)
class ModelFactoryBase(ReplaceableBase):
resume: bool = True # resume from the last checkpoint
def __call__(self, **kwargs) -> ImplicitronModelBase:
@@ -115,9 +116,7 @@ class ImplicitronModelFactory(ModelFactoryBase):
"cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
}
model_state_dict = torch.load(
model_io.get_model_path(model_path),
map_location=map_location,
weights_only=True,
model_io.get_model_path(model_path), map_location=map_location
)
try:

View File

@@ -123,7 +123,6 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
"""
# Get the parameters to optimize
if hasattr(model, "_get_param_groups"): # use the model function
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
p_groups = model._get_param_groups(self.lr, wd=self.weight_decay)
else:
p_groups = [
@@ -242,7 +241,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
map_location = {
"cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
}
optimizer_state = torch.load(opt_path, map_location, weights_only=True)
optimizer_state = torch.load(opt_path, map_location)
else:
raise FileNotFoundError(f"Optimizer state {opt_path} does not exist.")
return optimizer_state

View File

@@ -161,6 +161,7 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
for epoch in range(start_epoch, self.max_epochs):
# automatic new_epoch and plotting of stats at every epoch start
with stats:
# Make sure to re-seed random generators to ensure reproducibility
# even after restart.
seed_all_random_engines(seed + epoch)
@@ -394,7 +395,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
):
prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
if hasattr(model, "visualize"):
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
model.visualize(
viz,
visdom_env_imgs,

View File

@@ -53,8 +53,12 @@ class TestExperiment(unittest.TestCase):
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_class_type = (
"JsonIndexDatasetMapProvider"
)
dataset_args = cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
dataloader_args = cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
dataset_args = (
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
)
dataloader_args = (
cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
)
dataset_args.category = "skateboard"
dataset_args.test_restrict_sequence_id = 0
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
@@ -90,8 +94,12 @@ class TestExperiment(unittest.TestCase):
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_class_type = (
"JsonIndexDatasetMapProvider"
)
dataset_args = cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
dataloader_args = cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
dataset_args = (
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
)
dataloader_args = (
cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
)
dataset_args.category = "skateboard"
dataset_args.test_restrict_sequence_id = 0
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
@@ -103,7 +111,9 @@ class TestExperiment(unittest.TestCase):
cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs = 2
cfg.training_loop_ImplicitronTrainingLoop_args.store_checkpoints = False
cfg.optimizer_factory_ImplicitronOptimizerFactory_args.lr_policy = "Exponential"
cfg.optimizer_factory_ImplicitronOptimizerFactory_args.exponential_lr_step_size = 2
cfg.optimizer_factory_ImplicitronOptimizerFactory_args.exponential_lr_step_size = (
2
)
if DEBUG:
experiment.dump_cfg(cfg)

View File

@@ -81,9 +81,8 @@ class TestOptimizerFactory(unittest.TestCase):
def test_param_overrides_self_param_group_assignment(self):
pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)]
na, nb = (
Node(params=[pa]),
Node(params=[pb], param_groups={"self": "pb_self", "p1": "pb_param"}),
na, nb = Node(params=[pa]), Node(
params=[pb], param_groups={"self": "pb_self", "p1": "pb_param"}
)
root = Node(children=[na, nb], params=[pc], param_groups={"m1": "pb_member"})
param_groups = self._get_param_groups(root)

View File

@@ -84,9 +84,9 @@ def get_nerf_datasets(
if autodownload and any(not os.path.isfile(p) for p in (cameras_path, image_path)):
# Automatically download the data files if missing.
download_data([dataset_name], data_root=data_root)
download_data((dataset_name,), data_root=data_root)
train_data = torch.load(cameras_path, weights_only=True)
train_data = torch.load(cameras_path)
n_cameras = train_data["cameras"]["R"].shape[0]
_image_max_image_pixels = Image.MAX_IMAGE_PIXELS

View File

@@ -194,6 +194,7 @@ class Stats:
it = self.it[stat_set]
for stat in self.log_vars:
if stat not in self.stats[stat_set]:
self.stats[stat_set][stat] = AverageMeter()

View File

@@ -24,6 +24,7 @@ CONFIG_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs"
@hydra.main(config_path=CONFIG_DIR, config_name="lego")
def main(cfg: DictConfig):
# Device on which to run.
if torch.cuda.is_available():
device = "cuda"
@@ -62,7 +63,7 @@ def main(cfg: DictConfig):
raise ValueError(f"Model checkpoint {checkpoint_path} does not exist!")
print(f"Loading checkpoint {checkpoint_path}.")
loaded_data = torch.load(checkpoint_path, weights_only=True)
loaded_data = torch.load(checkpoint_path)
# Do not load the cached xy grid.
# - this allows setting an arbitrary evaluation image size.
state_dict = {

View File

@@ -42,6 +42,7 @@ class TestRaysampler(unittest.TestCase):
cameras, rays = [], []
for _ in range(batch_size):
R = random_rotations(1)
T = torch.randn(1, 3)
focal_length = torch.rand(1, 2) + 0.5

View File

@@ -25,6 +25,7 @@ CONFIG_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs"
@hydra.main(config_path=CONFIG_DIR, config_name="lego")
def main(cfg: DictConfig):
# Set the relevant seeds for reproducibility.
np.random.seed(cfg.seed)
torch.manual_seed(cfg.seed)
@@ -76,7 +77,7 @@ def main(cfg: DictConfig):
# Resume training if requested.
if cfg.resume and os.path.isfile(checkpoint_path):
print(f"Resuming from checkpoint {checkpoint_path}.")
loaded_data = torch.load(checkpoint_path, weights_only=True)
loaded_data = torch.load(checkpoint_path)
model.load_state_dict(loaded_data["model"])
stats = pickle.loads(loaded_data["stats"])
print(f" => resuming from epoch {stats.epoch}.")
@@ -218,6 +219,7 @@ def main(cfg: DictConfig):
# Validation
if epoch % cfg.validation_epoch_interval == 0 and epoch > 0:
# Sample a validation camera/image.
val_batch = next(val_dataloader.__iter__())
val_image, val_camera, camera_idx = val_batch[0].values()

View File

@@ -17,7 +17,7 @@ Some functions which depend on PyTorch or Python versions.
def meshgrid_ij(
*A: Union[torch.Tensor, Sequence[torch.Tensor]],
*A: Union[torch.Tensor, Sequence[torch.Tensor]]
) -> Tuple[torch.Tensor, ...]: # pragma: no cover
"""
Like torch.meshgrid was before PyTorch 1.10.0, i.e. with indexing set to ij

View File

@@ -7,6 +7,7 @@
*/
#include <torch/extension.h>
#include <queue>
#include <tuple>
std::tuple<at::Tensor, at::Tensor> BallQueryCpu(

View File

@@ -28,6 +28,7 @@ __global__ void alphaCompositeCudaForwardKernel(
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
// clang-format on
const int64_t batch_size = result.size(0);
const int64_t C = features.size(0);
const int64_t H = points_idx.size(2);
const int64_t W = points_idx.size(3);
@@ -78,6 +79,7 @@ __global__ void alphaCompositeCudaBackwardKernel(
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
// clang-format on
const int64_t batch_size = points_idx.size(0);
const int64_t C = features.size(0);
const int64_t H = points_idx.size(2);
const int64_t W = points_idx.size(3);

View File

@@ -28,6 +28,7 @@ __global__ void weightedSumNormCudaForwardKernel(
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
// clang-format on
const int64_t batch_size = result.size(0);
const int64_t C = features.size(0);
const int64_t H = points_idx.size(2);
const int64_t W = points_idx.size(3);
@@ -91,6 +92,7 @@ __global__ void weightedSumNormCudaBackwardKernel(
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
// clang-format on
const int64_t batch_size = points_idx.size(0);
const int64_t C = features.size(0);
const int64_t H = points_idx.size(2);
const int64_t W = points_idx.size(3);

View File

@@ -26,6 +26,7 @@ __global__ void weightedSumCudaForwardKernel(
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
// clang-format on
const int64_t batch_size = result.size(0);
const int64_t C = features.size(0);
const int64_t H = points_idx.size(2);
const int64_t W = points_idx.size(3);
@@ -73,6 +74,7 @@ __global__ void weightedSumCudaBackwardKernel(
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
// clang-format on
const int64_t batch_size = points_idx.size(0);
const int64_t C = features.size(0);
const int64_t H = points_idx.size(2);
const int64_t W = points_idx.size(3);

View File

@@ -149,10 +149,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("gamma"),
py::arg("max_depth"),
py::arg("min_depth") /* = 0.f*/,
py::arg("bg_col") /* = std::nullopt not exposed properly in
pytorch 1.1. */
py::arg(
"bg_col") /* = at::nullopt not exposed properly in pytorch 1.1. */
,
py::arg("opacity") /* = std::nullopt ... */,
py::arg("opacity") /* = at::nullopt ... */,
py::arg("percent_allowed_difference") = 0.01f,
py::arg("max_n_hits") = MAX_UINT,
py::arg("mode") = 0)

View File

@@ -7,7 +7,10 @@
*/
#include <torch/extension.h>
#include <torch/torch.h>
#include <list>
#include <numeric>
#include <queue>
#include <tuple>
#include "iou_box3d/iou_utils.h"

View File

@@ -461,8 +461,10 @@ __device__ inline std::tuple<float3, float3> ArgMaxVerts(
__device__ inline bool IsCoplanarTriTri(
const FaceVerts& tri1,
const FaceVerts& tri2) {
const float3 tri1_ctr = FaceCenter({tri1.v0, tri1.v1, tri1.v2});
const float3 tri1_n = FaceNormal({tri1.v0, tri1.v1, tri1.v2});
const float3 tri2_ctr = FaceCenter({tri2.v0, tri2.v1, tri2.v2});
const float3 tri2_n = FaceNormal({tri2.v0, tri2.v1, tri2.v2});
// Check if parallel
@@ -498,6 +500,7 @@ __device__ inline bool IsCoplanarTriPlane(
const FaceVerts& tri,
const FaceVerts& plane,
const float3& normal) {
const float3 tri_ctr = FaceCenter({tri.v0, tri.v1, tri.v2});
const float3 nt = FaceNormal({tri.v0, tri.v1, tri.v2});
// check if parallel
@@ -725,7 +728,7 @@ __device__ inline int BoxIntersections(
}
}
// Update the face_verts_out tris
num_tris = min(MAX_TRIS, offset);
num_tris = offset;
for (int j = 0; j < num_tris; ++j) {
face_verts_out[j] = tri_verts_updated[j];
}

View File

@@ -8,7 +8,9 @@
#include <torch/csrc/autograd/VariableTypeUtils.h>
#include <torch/extension.h>
#include <algorithm>
#include <cmath>
#include <thread>
#include <vector>
// In the x direction, the location {0, ..., grid_size_x - 1} correspond to

View File

@@ -8,7 +8,6 @@
#ifdef WITH_CUDA
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime_api.h>
#endif
#include <torch/extension.h>
@@ -34,13 +33,13 @@ torch::Tensor sphere_ids_from_result_info_nograd(
.contiguous();
if (forw_info.device().type() == c10::DeviceType::CUDA) {
#ifdef WITH_CUDA
C10_CUDA_CHECK(cudaMemcpyAsync(
cudaMemcpyAsync(
result.data_ptr(),
tmp.data_ptr(),
sizeof(uint32_t) * tmp.size(0) * tmp.size(1) * tmp.size(2) *
tmp.size(3),
cudaMemcpyDeviceToDevice,
at::cuda::getCurrentCUDAStream()));
at::cuda::getCurrentCUDAStream());
#else
throw std::runtime_error(
"Copy on CUDA device initiated but built "

View File

@@ -7,7 +7,6 @@
*/
#ifdef WITH_CUDA
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime_api.h>
namespace pulsar {
@@ -18,8 +17,7 @@ void cudaDevToDev(
const void* src,
const int& size,
const cudaStream_t& stream) {
C10_CUDA_CHECK(
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToDevice, stream));
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToDevice, stream);
}
void cudaDevToHost(
@@ -27,8 +25,7 @@ void cudaDevToHost(
const void* src,
const int& size,
const cudaStream_t& stream) {
C10_CUDA_CHECK(
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToHost, stream));
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToHost, stream);
}
} // namespace pytorch

View File

@@ -9,6 +9,7 @@
#include <torch/extension.h>
#include <algorithm>
#include <list>
#include <queue>
#include <thread>
#include <tuple>
#include "ATen/core/TensorAccessor.h"

View File

@@ -35,6 +35,8 @@ __global__ void FarthestPointSamplingKernel(
__shared__ int64_t selected_store;
// Get constants
const int64_t N = points.size(0);
const int64_t P = points.size(1);
const int64_t D = points.size(2);
// Get batch index and thread index

View File

@@ -376,6 +376,8 @@ PointLineDistanceBackward(
float tt = t_top / t_bot;
tt = __saturatef(tt);
const float2 p_proj = (1.0f - tt) * v0 + tt * v1;
const float2 d = p - p_proj;
const float dist = sqrt(dot(d, d));
const float2 grad_p = -1.0f * grad_dist * 2.0f * (p_proj - p);
const float2 grad_v0 = grad_dist * (1.0f - tt) * 2.0f * (p_proj - p);

View File

@@ -83,7 +83,7 @@ class ShapeNetCore(ShapeNetBase): # pragma: no cover
):
synset_set.add(synset)
elif (synset in self.synset_inv.keys()) and (
path.isdir(path.join(data_dir, self.synset_inv[synset]))
(path.isdir(path.join(data_dir, self.synset_inv[synset])))
):
synset_set.add(self.synset_inv[synset])
else:

View File

@@ -36,6 +36,7 @@ def collate_batched_meshes(batch: List[Dict]): # pragma: no cover
collated_dict["mesh"] = None
if {"verts", "faces"}.issubset(collated_dict.keys()):
textures = None
if "textures" in collated_dict:
textures = TexturesAtlas(atlas=collated_dict["textures"])

View File

@@ -26,7 +26,7 @@ from typing import (
import numpy as np
import torch
from pytorch3d.implicitron.dataset import orm_types, types
from pytorch3d.implicitron.dataset import types
from pytorch3d.implicitron.dataset.utils import (
adjust_camera_to_bbox_crop_,
adjust_camera_to_image_scale_,
@@ -48,12 +48,8 @@ from pytorch3d.implicitron.dataset.utils import (
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
from pytorch3d.structures.meshes import join_meshes_as_batch, Meshes
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
FrameAnnotationT = types.FrameAnnotation | orm_types.SqlFrameAnnotation
SequenceAnnotationT = types.SequenceAnnotation | orm_types.SqlSequenceAnnotation
@dataclass
class FrameData(Mapping[str, Any]):
@@ -126,9 +122,9 @@ class FrameData(Mapping[str, Any]):
meta: A dict for storing additional frame information.
"""
frame_number: Optional[torch.LongTensor] = None
sequence_name: Union[str, List[str]] = ""
sequence_category: Union[str, List[str]] = ""
frame_number: Optional[torch.LongTensor]
sequence_name: Union[str, List[str]]
sequence_category: Union[str, List[str]]
frame_timestamp: Optional[torch.Tensor] = None
image_size_hw: Optional[torch.LongTensor] = None
effective_image_size_hw: Optional[torch.LongTensor] = None
@@ -159,7 +155,7 @@ class FrameData(Mapping[str, Any]):
new_params = {}
for field_name in iter(self):
value = getattr(self, field_name)
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase, Meshes)):
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
new_params[field_name] = value.to(*args, **kwargs)
else:
new_params[field_name] = value
@@ -421,6 +417,7 @@ class FrameData(Mapping[str, Any]):
for f in fields(elem):
if not f.init:
continue
list_values = override_fields.get(
f.name, [getattr(d, f.name) for d in batch]
)
@@ -429,7 +426,7 @@ class FrameData(Mapping[str, Any]):
if all(list_value is not None for list_value in list_values)
else None
)
return type(elem)(**collated)
return cls(**collated)
elif isinstance(elem, Pointclouds):
return join_pointclouds_as_batch(batch)
@@ -437,8 +434,6 @@ class FrameData(Mapping[str, Any]):
elif isinstance(elem, CamerasBase):
# TODO: don't store K; enforce working in NDC space
return join_cameras_as_batch(batch)
elif isinstance(elem, Meshes):
return join_meshes_as_batch(batch)
else:
return torch.utils.data.dataloader.default_collate(batch)
@@ -459,8 +454,8 @@ class FrameDataBuilderBase(ReplaceableBase, Generic[FrameDataSubtype], ABC):
@abstractmethod
def build(
self,
frame_annotation: FrameAnnotationT,
sequence_annotation: SequenceAnnotationT,
frame_annotation: types.FrameAnnotation,
sequence_annotation: types.SequenceAnnotation,
*,
load_blobs: bool = True,
**kwargs,
@@ -546,8 +541,8 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
def build(
self,
frame_annotation: FrameAnnotationT,
sequence_annotation: SequenceAnnotationT,
frame_annotation: types.FrameAnnotation,
sequence_annotation: types.SequenceAnnotation,
*,
load_blobs: bool = True,
**kwargs,
@@ -591,81 +586,58 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
),
)
dataset_root = self.dataset_root
fg_mask_np: Optional[np.ndarray] = None
mask_annotation = frame_annotation.mask
depth_annotation = frame_annotation.depth
image_path: str | None = None
mask_path: str | None = None
depth_path: str | None = None
pcl_path: str | None = None
if dataset_root is not None: # set all paths even if we wont load blobs
if frame_annotation.image.path is not None:
image_path = os.path.join(dataset_root, frame_annotation.image.path)
frame_data.image_path = image_path
if mask_annotation is not None and mask_annotation.path:
mask_path = os.path.join(dataset_root, mask_annotation.path)
frame_data.mask_path = mask_path
if depth_annotation is not None and depth_annotation.path is not None:
depth_path = os.path.join(dataset_root, depth_annotation.path)
frame_data.depth_path = depth_path
if point_cloud is not None:
pcl_path = os.path.join(dataset_root, point_cloud.path)
frame_data.sequence_point_cloud_path = pcl_path
fg_mask_np: np.ndarray | None = None
bbox_xywh: tuple[float, float, float, float] | None = None
if mask_annotation is not None:
if load_blobs and self.load_masks and mask_path:
fg_mask_np = self._load_fg_probability(frame_annotation, mask_path)
if load_blobs and self.load_masks:
fg_mask_np, mask_path = self._load_fg_probability(frame_annotation)
frame_data.mask_path = mask_path
frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)
bbox_xywh = mask_annotation.bounding_box_xywh
if bbox_xywh is None and fg_mask_np is not None:
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)
if frame_annotation.image is not None:
image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
frame_data.image_size_hw = image_size_hw # original image size
# image size after crop/resize
frame_data.effective_image_size_hw = image_size_hw
image_path = None
dataset_root = self.dataset_root
if frame_annotation.image.path is not None and dataset_root is not None:
image_path = os.path.join(dataset_root, frame_annotation.image.path)
frame_data.image_path = image_path
if load_blobs and self.load_images:
if image_path is None:
raise ValueError("Image path is required to load images.")
no_mask = fg_mask_np is None # didnt read the mask file
image_np = load_image(
self._local_path(image_path), try_read_alpha=no_mask
)
if image_np.shape[0] == 4: # RGBA image
if no_mask:
fg_mask_np = image_np[3:]
frame_data.fg_probability = safe_as_tensor(
fg_mask_np, torch.float
)
image_np = image_np[:3]
image_np = load_image(self._local_path(image_path))
frame_data.image_rgb = self._postprocess_image(
image_np, frame_annotation.image.size, frame_data.fg_probability
)
if bbox_xywh is None and fg_mask_np is not None:
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)
if load_blobs and self.load_depths and depth_path is not None:
frame_data.depth_map, frame_data.depth_mask = self._load_mask_depth(
frame_annotation, depth_path, fg_mask_np
)
if (
load_blobs
and self.load_depths
and frame_annotation.depth is not None
and frame_annotation.depth.path is not None
):
(
frame_data.depth_map,
frame_data.depth_path,
frame_data.depth_mask,
) = self._load_mask_depth(frame_annotation, fg_mask_np)
if load_blobs and self.load_point_clouds and point_cloud is not None:
assert pcl_path is not None
pcl_path = self._fix_point_cloud_path(point_cloud.path)
frame_data.sequence_point_cloud = load_pointcloud(
self._local_path(pcl_path), max_points=self.max_points
)
frame_data.sequence_point_cloud_path = pcl_path
if frame_annotation.viewpoint is not None:
frame_data.camera = self._get_pytorch3d_camera(frame_annotation)
@@ -681,14 +653,18 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
return frame_data
def _load_fg_probability(self, entry: FrameAnnotationT, path: str) -> np.ndarray:
fg_probability = load_mask(self._local_path(path))
def _load_fg_probability(
self, entry: types.FrameAnnotation
) -> Tuple[np.ndarray, str]:
assert self.dataset_root is not None and entry.mask is not None
full_path = os.path.join(self.dataset_root, entry.mask.path)
fg_probability = load_mask(self._local_path(full_path))
if fg_probability.shape[-2:] != entry.image.size:
raise ValueError(
f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!"
)
return fg_probability
return fg_probability, full_path
def _postprocess_image(
self,
@@ -709,14 +685,14 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
def _load_mask_depth(
self,
entry: FrameAnnotationT,
path: str,
entry: types.FrameAnnotation,
fg_mask: Optional[np.ndarray],
) -> tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, str, torch.Tensor]:
entry_depth = entry.depth
dataset_root = self.dataset_root
assert dataset_root is not None
assert entry_depth is not None
assert entry_depth is not None and entry_depth.path is not None
path = os.path.join(dataset_root, entry_depth.path)
depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)
if self.mask_depths:
@@ -730,11 +706,11 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
else:
depth_mask = (depth_map > 0.0).astype(np.float32)
return torch.tensor(depth_map), torch.tensor(depth_mask)
return torch.tensor(depth_map), path, torch.tensor(depth_mask)
def _get_pytorch3d_camera(
self,
entry: FrameAnnotationT,
entry: types.FrameAnnotation,
) -> PerspectiveCameras:
entry_viewpoint = entry.viewpoint
assert entry_viewpoint is not None
@@ -763,6 +739,19 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
)
def _fix_point_cloud_path(self, path: str) -> str:
"""
Fix up a point cloud path from the dataset.
Some files in Co3Dv2 have an accidental absolute path stored.
"""
unwanted_prefix = (
"/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/"
)
if path.startswith(unwanted_prefix):
path = path[len(unwanted_prefix) :]
assert self.dataset_root is not None
return os.path.join(self.dataset_root, path)
def _local_path(self, path: str) -> str:
if self.path_manager is None:
return path

View File

@@ -222,6 +222,7 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):
self.dataset_map = dataset_map
def _load_category(self, category: str) -> DatasetMap:
frame_file = os.path.join(self.dataset_root, category, "frame_annotations.jgz")
sequence_file = os.path.join(
self.dataset_root, category, "sequence_annotations.jgz"

View File

@@ -75,6 +75,7 @@ def _minify(basedir, path_manager, factors=(), resolutions=()):
def _load_data(
basedir, factor=None, width=None, height=None, load_imgs=True, path_manager=None
):
poses_arr = np.load(
_local_path(path_manager, os.path.join(basedir, "poses_bounds.npy"))
)
@@ -163,6 +164,7 @@ def ptstocam(pts, c2w):
def poses_avg(poses):
hwf = poses[0, :3, -1:]
center = poses[:, :3, 3].mean(0)
@@ -190,6 +192,7 @@ def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
def recenter_poses(poses):
poses_ = poses + 0
bottom = np.reshape([0, 0, 0, 1.0], [1, 4])
c2w = poses_avg(poses)
@@ -253,6 +256,7 @@ def spherify_poses(poses, bds):
new_poses = []
for th in np.linspace(0.0, 2.0 * np.pi, 120):
camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
up = np.array([0, 0, -1.0])
@@ -307,6 +311,7 @@ def load_llff_data(
path_zflat=False,
path_manager=None,
):
poses, bds, imgs = _load_data(
basedir, factor=factor, path_manager=path_manager
) # factor=8 downsamples original imgs by 8x

View File

@@ -4,8 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
# This functionality requires SQLAlchemy 2.0 or later.
import math

View File

@@ -4,15 +4,11 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
import hashlib
import json
import logging
import os
import urllib
from dataclasses import dataclass, Field, field
from dataclasses import dataclass
from typing import (
Any,
ClassVar,
@@ -33,18 +29,17 @@ import sqlalchemy as sa
import torch
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
from pytorch3d.implicitron.dataset.frame_data import (
from pytorch3d.implicitron.dataset.frame_data import ( # noqa
FrameData,
FrameDataBuilder, # noqa
FrameDataBuilder,
FrameDataBuilderBase,
)
from pytorch3d.implicitron.tools.config import (
registry,
ReplaceableBase,
run_auto_creation,
)
from sqlalchemy.orm import scoped_session, Session, sessionmaker
from sqlalchemy.orm import Session
from .orm_types import SqlFrameAnnotation, SqlSequenceAnnotation
@@ -56,7 +51,7 @@ _SET_LISTS_TABLE: str = "set_lists"
@registry.register
class SqlIndexDataset(DatasetBase, ReplaceableBase):
class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
"""
A dataset with annotations stored as SQLite tables. This is an index-based dataset.
The length is returned after all sequence and frame filters are applied (see param
@@ -93,7 +88,6 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
engine verbatim. Dont expose it to end users of your application!
pick_categories: Restrict the dataset to the given list of categories.
pick_sequences: A Sequence of sequence names to restrict the dataset to.
pick_sequences_sql_clause: Custom SQL WHERE clause to constrain sequence annotations.
exclude_sequences: A Sequence of the names of the sequences to exclude.
limit_sequences_per_category_to: Limit the dataset to the first up to N
sequences within each category (applies after all other sequence filters
@@ -108,16 +102,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
more frames than that; applied after other frame-level filters.
seed: The seed of the random generator sampling `n_frames_per_sequence`
random frames per sequence.
preload_metadata: If True, the metadata is preloaded into memory.
precompute_seq_to_idx: If True, precomputes the mapping from sequence name to indices.
scoped_session: If True, allows different parts of the code to share
a global session to access the database.
"""
frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation
sequence_annotations_type: ClassVar[Type[SqlSequenceAnnotation]] = (
SqlSequenceAnnotation
)
sqlite_metadata_file: str = ""
dataset_root: Optional[str] = None
@@ -130,7 +117,6 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
pick_categories: Tuple[str, ...] = ()
pick_sequences: Tuple[str, ...] = ()
pick_sequences_sql_clause: Optional[str] = None
exclude_sequences: Tuple[str, ...] = ()
limit_sequences_per_category_to: int = 0
limit_sequences_to: int = 0
@@ -138,22 +124,12 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
n_frames_per_sequence: int = -1
seed: int = 0
remove_empty_masks_poll_whole_table_threshold: int = 300_000
preload_metadata: bool = False
precompute_seq_to_idx: bool = False
# we set it manually in the constructor
_index: pd.DataFrame = field(init=False, metadata={"omegaconf_ignore": True})
_sql_engine: sa.engine.Engine = field(
init=False, metadata={"omegaconf_ignore": True}
)
eval_batches: Optional[List[Any]] = field(
init=False, metadata={"omegaconf_ignore": True}
)
# _index: pd.DataFrame = field(init=False)
frame_data_builder: FrameDataBuilderBase # pyre-ignore[13]
frame_data_builder: FrameDataBuilderBase
frame_data_builder_class_type: str = "FrameDataBuilder"
scoped_session: bool = False
def __post_init__(self) -> None:
if sa.__version__ < "2.0":
raise ImportError("This class requires SQL Alchemy 2.0 or later")
@@ -162,28 +138,19 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
raise ValueError("sqlite_metadata_file must be set")
if self.dataset_root:
frame_args = f"frame_data_builder_{self.frame_data_builder_class_type}_args"
getattr(self, frame_args)["dataset_root"] = self.dataset_root
getattr(self, frame_args)["path_manager"] = self.path_manager
frame_builder_type = self.frame_data_builder_class_type
getattr(self, f"frame_data_builder_{frame_builder_type}_args")[
"dataset_root"
] = self.dataset_root
run_auto_creation(self)
self.frame_data_builder.path_manager = self.path_manager
if self.path_manager is not None:
self.sqlite_metadata_file = self.path_manager.get_local_path(
self.sqlite_metadata_file
)
self.subset_lists_file = self.path_manager.get_local_path(
self.subset_lists_file
)
# NOTE: sqlite-specific args (read-only mode).
# pyre-ignore # NOTE: sqlite-specific args (read-only mode).
self._sql_engine = sa.create_engine(
f"sqlite:///file:{urllib.parse.quote(self.sqlite_metadata_file)}?mode=ro&uri=true"
f"sqlite:///file:{self.sqlite_metadata_file}?mode=ro&uri=true"
)
if self.preload_metadata:
self._sql_engine = self._preload_database(self._sql_engine)
sequences = self._get_filtered_sequences_if_any()
if self.subsets:
@@ -199,29 +166,16 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
if len(index) == 0:
raise ValueError(f"There are no frames in the subsets: {self.subsets}!")
self._index = index.set_index(["sequence_name", "frame_number"])
self._index = index.set_index(["sequence_name", "frame_number"]) # pyre-ignore
self.eval_batches = None
self.eval_batches = None # pyre-ignore
if self.eval_batches_file:
self.eval_batches = self._load_filter_eval_batches()
logger.info(str(self))
if self.scoped_session:
self._session_factory = sessionmaker(bind=self._sql_engine) # pyre-ignore
if self.precompute_seq_to_idx:
# This is deprecated and will be removed in the future.
# After we backport https://github.com/facebookresearch/uco3d/pull/3
logger.warning(
"Using precompute_seq_to_idx is deprecated and will be removed in the future."
)
self._index["rowid"] = np.arange(len(self._index))
groupby = self._index.groupby("sequence_name", sort=False)["rowid"]
self._seq_to_indices = dict(groupby.apply(list)) # pyre-ignore
del self._index["rowid"]
def __len__(self) -> int:
# pyre-ignore[16]
return len(self._index)
def __getitem__(self, frame_idx: Union[int, Tuple[str, int]]) -> FrameData:
@@ -278,18 +232,12 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
self.frame_annotations_type.frame_number
== int(frame), # cast from np.int64
)
seq_stmt = sa.select(self.sequence_annotations_type).where(
self.sequence_annotations_type.sequence_name == seq
seq_stmt = sa.select(SqlSequenceAnnotation).where(
SqlSequenceAnnotation.sequence_name == seq
)
if self.scoped_session:
# pyre-ignore
with scoped_session(self._session_factory)() as session:
entry = session.scalars(stmt).one()
seq_metadata = session.scalars(seq_stmt).one()
else:
with Session(self._sql_engine) as session:
entry = session.scalars(stmt).one()
seq_metadata = session.scalars(seq_stmt).one()
with Session(self._sql_engine) as session:
entry = session.scalars(stmt).one()
seq_metadata = session.scalars(seq_stmt).one()
assert entry.image.path == self._index.loc[(seq, frame), "_image_path"]
@@ -302,6 +250,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
return frame_data
def __str__(self) -> str:
# pyre-ignore[16]
return f"SqlIndexDataset #frames={len(self._index)}"
def sequence_names(self) -> Iterable[str]:
@@ -311,10 +260,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
# override
def category_to_sequence_names(self) -> Dict[str, List[str]]:
stmt = sa.select(
self.sequence_annotations_type.category,
self.sequence_annotations_type.sequence_name,
SqlSequenceAnnotation.category, SqlSequenceAnnotation.sequence_name
).where( # we limit results to sequences that have frames after all filters
self.sequence_annotations_type.sequence_name.in_(self.sequence_names())
SqlSequenceAnnotation.sequence_name.in_(self.sequence_names())
)
with self._sql_engine.connect() as connection:
cat_to_seqs = pd.read_sql(stmt, connection)
@@ -387,31 +335,17 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
rows = self._index.index.get_loc(seq_name)
if isinstance(rows, slice):
assert rows.stop is not None, "Unexpected result from pandas"
rows_seq = range(rows.start or 0, rows.stop, rows.step or 1)
rows = range(rows.start or 0, rows.stop, rows.step or 1)
else:
rows_seq = list(np.where(rows)[0])
rows = np.where(rows)[0]
index_slice, idx = self._get_frame_no_coalesced_ts_by_row_indices(
rows_seq, seq_name, subset_filter
rows, seq_name, subset_filter
)
index_slice["idx"] = idx
yield from index_slice.itertuples(index=False)
# override
def sequence_indices_in_order(
self, seq_name: str, subset_filter: Optional[Sequence[str]] = None
) -> Iterator[int]:
"""Same as `sequence_frames_in_order` but returns the iterator over
only dataset indices.
"""
if self.precompute_seq_to_idx and subset_filter is None:
# pyre-ignore
yield from self._seq_to_indices[seq_name]
else:
for _, _, idx in self.sequence_frames_in_order(seq_name, subset_filter):
yield idx
# override
def get_eval_batches(self) -> Optional[List[Any]]:
"""
@@ -445,35 +379,11 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
or self.limit_sequences_to > 0
or self.limit_sequences_per_category_to > 0
or len(self.pick_sequences) > 0
or self.pick_sequences_sql_clause is not None
or len(self.exclude_sequences) > 0
or len(self.pick_categories) > 0
or self.n_frames_per_sequence > 0
)
def _preload_database(
self, source_engine: sa.engine.base.Engine
) -> sa.engine.base.Engine:
destination_engine = sa.create_engine("sqlite:///:memory:")
metadata = sa.MetaData()
metadata.reflect(bind=source_engine)
metadata.create_all(bind=destination_engine)
with source_engine.connect() as source_conn:
with destination_engine.connect() as destination_conn:
for table_obj in metadata.tables.values():
# Select all rows from the source table
source_rows = source_conn.execute(table_obj.select())
# Insert rows into the destination table
for row in source_rows:
destination_conn.execute(table_obj.insert().values(row))
# Commit the changes for each table
destination_conn.commit()
return destination_engine
def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
# maximum possible filter (if limit_sequences_per_category_to == 0):
# WHERE category IN 'self.pick_categories'
@@ -486,22 +396,19 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
*self._get_pick_filters(),
*self._get_exclude_filters(),
]
if self.pick_sequences_sql_clause:
print("Applying the custom SQL clause.")
where_conditions.append(sa.text(self.pick_sequences_sql_clause))
def add_where(stmt):
return stmt.where(*where_conditions) if where_conditions else stmt
if self.limit_sequences_per_category_to <= 0:
stmt = add_where(sa.select(self.sequence_annotations_type.sequence_name))
stmt = add_where(sa.select(SqlSequenceAnnotation.sequence_name))
else:
subquery = sa.select(
self.sequence_annotations_type.sequence_name,
SqlSequenceAnnotation.sequence_name,
sa.func.row_number()
.over(
order_by=sa.text("ROWID"), # NOTE: ROWID is SQLite-specific
partition_by=self.sequence_annotations_type.category,
partition_by=SqlSequenceAnnotation.category,
)
.label("row_number"),
)
@@ -537,34 +444,31 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
return []
logger.info(f"Limiting dataset to categories: {self.pick_categories}")
return [self.sequence_annotations_type.category.in_(self.pick_categories)]
return [SqlSequenceAnnotation.category.in_(self.pick_categories)]
def _get_pick_filters(self) -> List[sa.ColumnElement]:
if not self.pick_sequences:
return []
logger.info(f"Limiting dataset to sequences: {self.pick_sequences}")
return [self.sequence_annotations_type.sequence_name.in_(self.pick_sequences)]
return [SqlSequenceAnnotation.sequence_name.in_(self.pick_sequences)]
def _get_exclude_filters(self) -> List[sa.ColumnOperators]:
if not self.exclude_sequences:
return []
logger.info(f"Removing sequences from the dataset: {self.exclude_sequences}")
return [
self.sequence_annotations_type.sequence_name.notin_(self.exclude_sequences)
]
return [SqlSequenceAnnotation.sequence_name.notin_(self.exclude_sequences)]
def _load_subsets_from_json(self, subset_lists_path: str) -> pd.DataFrame:
subsets = self.subsets
assert subsets is not None
assert self.subsets is not None
with open(subset_lists_path, "r") as f:
subset_to_seq_frame = json.load(f)
seq_frame_list = sum(
(
[(*row, subset) for row in subset_to_seq_frame[subset]]
for subset in subsets
for subset in self.subsets
),
[],
)
@@ -618,7 +522,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
stmt = sa.select(
self.frame_annotations_type.sequence_name,
self.frame_annotations_type.frame_number,
).where(self.frame_annotations_type._mask_mass == 0) # pyre-ignore[16]
).where(self.frame_annotations_type._mask_mass == 0)
with Session(self._sql_engine) as session:
to_remove = session.execute(stmt).all()
@@ -682,7 +586,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
stmt = sa.select(
self.frame_annotations_type.sequence_name,
self.frame_annotations_type.frame_number,
self.frame_annotations_type._image_path, # pyre-ignore[16]
self.frame_annotations_type._image_path,
sa.null().label("subset"),
)
where_conditions = []
@@ -696,7 +600,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
logger.info(" excluding samples with empty masks")
where_conditions.append(
sa.or_(
self.frame_annotations_type._mask_mass.is_(None), # pyre-ignore[16]
self.frame_annotations_type._mask_mass.is_(None),
self.frame_annotations_type._mask_mass != 0,
)
)
@@ -730,9 +634,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
assert self.eval_batches_file
logger.info(f"Loading eval batches from {self.eval_batches_file}")
if (
self.path_manager and not self.path_manager.isfile(self.eval_batches_file)
) or (not self.path_manager and not os.path.isfile(self.eval_batches_file)):
if not os.path.isfile(self.eval_batches_file):
# The batch indices file does not exist.
# Most probably the user has not specified the root folder.
raise ValueError(
@@ -740,8 +642,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
+ "Please specify a correct dataset_root folder."
)
eval_batches_file = self._local_path(self.eval_batches_file)
with open(eval_batches_file, "r") as f:
with open(self.eval_batches_file, "r") as f:
eval_batches = json.load(f)
# limit the dataset to sequences to allow multiple evaluations in one file
@@ -825,15 +726,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
self.frame_annotations_type.sequence_name == seq_name,
self.frame_annotations_type.frame_number.in_(frames),
)
frame_no_ts = None
if self.scoped_session:
stmt_text = str(stmt.compile(compile_kwargs={"literal_binds": True}))
with scoped_session(self._session_factory)() as session: # pyre-ignore
frame_no_ts = pd.read_sql_query(stmt_text, session.connection())
else:
with self._sql_engine.connect() as connection:
frame_no_ts = pd.read_sql_query(stmt, connection)
with self._sql_engine.connect() as connection:
frame_no_ts = pd.read_sql_query(stmt, connection)
if len(frame_no_ts) != len(index_slice):
raise ValueError(
@@ -863,18 +758,11 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
prefixes=["TEMP"], # NOTE SQLite specific!
)
@classmethod
def pre_expand(cls) -> None:
# remove dataclass annotations that are not meant to be init params
# because they cause troubles for OmegaConf
for attr, attr_value in list(cls.__dict__.items()): # need to copy as we mutate
if isinstance(attr_value, Field) and attr_value.metadata.get(
"omegaconf_ignore", False
):
delattr(cls, attr)
del cls.__annotations__[attr]
def _seq_name_to_seed(seq_name) -> int:
"""Generates numbers in [0, 2 ** 28)"""
return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest()[:7], 16)
def _safe_as_tensor(data, dtype):
return torch.tensor(data, dtype=dtype) if data is not None else None

View File

@@ -4,8 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
import logging
import os
@@ -45,7 +43,7 @@ logger = logging.getLogger(__name__)
@registry.register
class SqlIndexDatasetMapProvider(DatasetMapProviderBase):
class SqlIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
"""
Generates the training, validation, and testing dataset objects for
a dataset laid out on disk like SQL-CO3D, with annotations in an SQLite data base.
@@ -195,9 +193,9 @@ class SqlIndexDatasetMapProvider(DatasetMapProviderBase):
# this is a mould that is never constructed, used to build self._dataset_map values
dataset_class_type: str = "SqlIndexDataset"
dataset: SqlIndexDataset # pyre-ignore [13]
dataset: SqlIndexDataset
path_manager_factory: PathManagerFactory # pyre-ignore [13]
path_manager_factory: PathManagerFactory
path_manager_factory_class_type: str = "PathManagerFactory"
def __post_init__(self):
@@ -284,14 +282,8 @@ class SqlIndexDatasetMapProvider(DatasetMapProviderBase):
logger.info(f"Val dataset: {str(val_dataset)}")
logger.debug("Extracting test dataset.")
if self.eval_batches_path is None:
eval_batches_file = None
else:
eval_batches_file = self._get_lists_file("eval_batches")
if "eval_batches_file" in common_dataset_kwargs:
common_dataset_kwargs.pop("eval_batches_file", None)
eval_batches_file = self._get_lists_file("eval_batches")
del common_dataset_kwargs["eval_batches_file"]
test_dataset = dataset_type(
**common_dataset_kwargs,
subsets=self._get_subsets(self.test_subsets, True),

View File

@@ -87,15 +87,6 @@ def is_train_frame(
def get_bbox_from_mask(
mask: np.ndarray, thr: float, decrease_quant: float = 0.05
) -> Tuple[int, int, int, int]:
# these corner cases need to be handled in order to avoid an infinite loop
if mask.size == 0:
warnings.warn("Empty mask is provided for bbox extraction.", stacklevel=1)
return 0, 0, 1, 1
if not mask.min() >= 0.0:
warnings.warn("Negative values in the mask for bbox extraction.", stacklevel=1)
mask = mask.clip(min=0.0)
# bbox in xywh
masks_for_box = np.zeros_like(mask)
while masks_for_box.sum() <= 1.0:
@@ -143,15 +134,7 @@ T = TypeVar("T", bound=torch.Tensor)
def bbox_xyxy_to_xywh(xyxy: T) -> T:
wh = xyxy[2:] - xyxy[:2]
xywh = torch.cat([xyxy[:2], wh])
return xywh # pyre-ignore[7]
def bbox_xywh_to_xyxy(xywh: T, clamp_size: float | int | None = None) -> T:
wh = xywh[2:]
if clamp_size is not None:
wh = wh.clamp(min=clamp_size)
xyxy = torch.cat([xywh[:2], xywh[:2] + wh])
return xyxy # pyre-ignore[7]
return xywh # pyre-ignore
def get_clamp_bbox(
@@ -197,6 +180,16 @@ def rescale_bbox(
return bbox * rel_size
def bbox_xywh_to_xyxy(
xywh: torch.Tensor, clamp_size: Optional[int] = None
) -> torch.Tensor:
xyxy = xywh.clone()
if clamp_size is not None:
xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
xyxy[2:] += xyxy[:2]
return xyxy
def get_1d_bounds(arr: np.ndarray) -> Tuple[int, int]:
nz = np.flatnonzero(arr)
return nz[0], nz[-1] + 1
@@ -208,24 +201,18 @@ def resize_image(
image_width: Optional[int],
mode: str = "bilinear",
) -> Tuple[torch.Tensor, float, torch.Tensor]:
if isinstance(image, np.ndarray):
image = torch.from_numpy(image)
if (
image_height is None
or image_width is None
or image.shape[-2] == 0
or image.shape[-1] == 0
):
if image_height is None or image_width is None:
# skip the resizing
return image, 1.0, torch.ones_like(image[:1])
# takes numpy array or tensor, returns pytorch tensor
minscale = min(
image_height / image.shape[-2],
image_width / image.shape[-1],
)
imre = torch.nn.functional.interpolate(
image[None],
scale_factor=minscale,
@@ -233,7 +220,6 @@ def resize_image(
align_corners=False if mode == "bilinear" else None,
recompute_scale_factor=True,
)[0]
imre_ = torch.zeros(image.shape[0], image_height, image_width)
imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
mask = torch.zeros(1, image_height, image_width)
@@ -246,21 +232,9 @@ def transpose_normalize_image(image: np.ndarray) -> np.ndarray:
return im.astype(np.float32) / 255.0
def load_image(
path: str, try_read_alpha: bool = False, pil_format: str = "RGB"
) -> np.ndarray:
"""
Load an image from a path and return it as a numpy array.
If try_read_alpha is True, the image is read as RGBA and the alpha channel is
returned as the fourth channel.
Otherwise, the image is read as RGB and a three-channel image is returned.
"""
def load_image(path: str) -> np.ndarray:
with Image.open(path) as pil_im:
# Check if the image has an alpha channel
if try_read_alpha and pil_im.mode == "RGBA":
im = np.array(pil_im)
else:
im = np.array(pil_im.convert(pil_format))
im = np.array(pil_im.convert("RGB"))
return transpose_normalize_image(im)
@@ -355,7 +329,6 @@ def adjust_camera_to_bbox_crop_(
focal_length_px, principal_point_px = _convert_ndc_to_pixels(
camera.focal_length[0],
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
camera.principal_point[0],
image_size_wh,
)
@@ -368,7 +341,6 @@ def adjust_camera_to_bbox_crop_(
)
camera.focal_length = focal_length[None]
# pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
camera.principal_point = principal_point_cropped[None]
@@ -380,7 +352,6 @@ def adjust_camera_to_image_scale_(
) -> PerspectiveCameras:
focal_length_px, principal_point_px = _convert_ndc_to_pixels(
camera.focal_length[0],
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
camera.principal_point[0],
original_size_wh,
)
@@ -397,8 +368,7 @@ def adjust_camera_to_image_scale_(
image_size_wh_output,
)
camera.focal_length = focal_length_scaled[None]
# pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
camera.principal_point = principal_point_scaled[None] # pyre-ignore[16]
camera.principal_point = principal_point_scaled[None]
# NOTE this cache is per-worker; they are implemented as processes.

View File

@@ -299,6 +299,7 @@ def eval_batch(
)
for loss_fg_mask, name_postfix in zip((mask_crop, mask_fg), ("_masked", "_fg")):
loss_mask_now = mask_crop * loss_fg_mask
for rgb_metric_name, rgb_metric_fun in zip(

View File

@@ -106,7 +106,7 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
self.layers = torch.nn.ModuleList()
self.proj_layers = torch.nn.ModuleList()
for stage in range(self.max_stage):
stage_name = f"layer{stage + 1}"
stage_name = f"layer{stage+1}"
feature_name = self._get_resnet_stage_feature_name(stage)
if (stage + 1) in self.stages:
if (
@@ -139,18 +139,12 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
self.stages = set(self.stages) # convert to set for faster "in"
def _get_resnet_stage_feature_name(self, stage) -> str:
return f"res_layer_{stage + 1}"
return f"res_layer_{stage+1}"
def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
# pyre-fixme[58]: `-` is not supported for operand types `Tensor` and
# `Union[Tensor, Module]`.
# pyre-fixme[58]: `/` is not supported for operand types `Tensor` and
# `Union[Tensor, Module]`.
return (img - self._resnet_mean) / self._resnet_std
def get_feat_dims(self) -> int:
# pyre-fixme[29]: `Union[(self: TensorBase) -> Tensor, Tensor, Module]` is
# not a function.
return sum(self._feat_dim.values())
def forward(
@@ -189,12 +183,7 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
else:
imgs_normed = imgs_resized
# is not a function.
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
feats = self.stem(imgs_normed)
# pyre-fixme[6]: For 1st argument expected `Iterable[_T1]` but got
# `Union[Tensor, Module]`.
# pyre-fixme[6]: For 2nd argument expected `Iterable[_T2]` but got
# `Union[Tensor, Module]`.
for stage, (layer, proj) in enumerate(zip(self.layers, self.proj_layers)):
feats = layer(feats)
# just a sanity check below

View File

@@ -478,8 +478,6 @@ class GenericModel(ImplicitronModelBase):
)
custom_args["global_code"] = global_code
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a
# function.
for func in self._implicit_functions:
func.bind_args(**custom_args)
@@ -502,8 +500,6 @@ class GenericModel(ImplicitronModelBase):
# Unbind the custom arguments to prevent pytorch from storing
# large buffers of intermediate results due to points in the
# bound arguments.
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a
# function.
for func in self._implicit_functions:
func.unbind_args()

View File

@@ -71,7 +71,6 @@ class Autodecoder(Configurable, torch.nn.Module):
return key_map
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `weight`.
return (self._autodecoder_codes.weight**2).mean()
def get_encoding_dim(self) -> int:
@@ -96,7 +95,6 @@ class Autodecoder(Configurable, torch.nn.Module):
# pyre-fixme[9]: x has type `Union[List[str], LongTensor]`; used as
# `Tensor`.
x = torch.tensor(
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, ...
[self._key_map[elem] for elem in x],
dtype=torch.long,
device=next(self.parameters()).device,
@@ -104,7 +102,6 @@ class Autodecoder(Configurable, torch.nn.Module):
except StopIteration:
raise ValueError("Not enough n_instances in the autodecoder") from None
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
return self._autodecoder_codes(x)
def _load_key_map_hook(

View File

@@ -122,7 +122,6 @@ class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
if frame_timestamp.shape[-1] != 1:
raise ValueError("Frame timestamp's last dimensions should be one.")
time = frame_timestamp / self.time_divisor
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
return self._harmonic_embedding(time)
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:

View File

@@ -232,14 +232,9 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
# if the skip tensor is None, we use `x` instead.
z = x
skipi = 0
# pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got
# `Union[Tensor, Module]`.
for li, layer in enumerate(self.mlp):
# pyre-fixme[58]: `in` is not supported for right operand type
# `Union[Tensor, Module]`.
if li in self._input_skips:
if self._skip_affine_trans:
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, ...
y = self._apply_affine_layer(self.skip_affines[skipi], y, z)
else:
y = torch.cat((y, z), dim=-1)

View File

@@ -141,16 +141,11 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
self.embed_fn is None and fun_viewpool is None and global_code is None
):
return torch.tensor(
[],
device=rays_points_world.device,
dtype=rays_points_world.dtype,
# pyre-fixme[6]: For 2nd argument expected `Union[int, SymInt]` but got
# `Union[Module, Tensor]`.
[], device=rays_points_world.device, dtype=rays_points_world.dtype
).view(0, self.out_dim)
embeddings = []
if self.embed_fn is not None:
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
embeddings.append(self.embed_fn(rays_points_world))
if fun_viewpool is not None:
@@ -169,19 +164,13 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
embedding = torch.cat(embeddings, dim=-1)
x = embedding
# pyre-fixme[29]: `Union[(self: TensorBase, other: Union[bool, complex,
# float, int, Tensor]) -> Tensor, Module, Tensor]` is not a function.
for layer_idx in range(self.num_layers - 1):
if layer_idx in self.skip_in:
x = torch.cat([x, embedding], dim=-1) / 2**0.5
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
x = self.linear_layers[layer_idx](x)
# pyre-fixme[29]: `Union[(self: TensorBase, other: Union[bool, complex,
# float, int, Tensor]) -> Tensor, Module, Tensor]` is not a function.
if layer_idx < self.num_layers - 2:
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
x = self.softplus(x)
return x

View File

@@ -123,10 +123,8 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
# Normalize the ray_directions to unit l2 norm.
rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
# Obtain the harmonic embedding of the normalized ray directions.
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
rays_embedding = self.harmonic_embedding_dir(rays_directions_normed)
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
return self.color_layer((self.intermediate_linear(features), rays_embedding))
@staticmethod
@@ -197,8 +195,6 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
embeds = create_embeddings_for_implicit_function(
xyz_world=rays_points_world,
# for 2nd param but got `Union[None, torch.Tensor, torch.nn.Module]`.
# pyre-fixme[6]: For 2nd argument expected `Optional[(...) -> Any]` but
# got `Union[None, Tensor, Module]`.
xyz_embedding_function=(
self.harmonic_embedding_xyz if self.input_xyz else None
),
@@ -210,14 +206,12 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
)
# embeds.shape = [minibatch x n_src x n_rays x n_pts x self.n_harmonic_functions*6+3]
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
features = self.xyz_encoder(embeds)
# features.shape = [minibatch x ... x self.n_hidden_neurons_xyz]
# NNs operate on the flattenned rays; reshaping to the correct spatial size
# TODO: maybe make the transformer work on non-flattened tensors to avoid this reshape
features = features.reshape(*rays_points_world.shape[:-1], -1)
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
raw_densities = self.density_layer(features)
# raw_densities.shape = [minibatch x ... x 1] in [0-1]
@@ -225,8 +219,6 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
if camera is None:
raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
# pyre-fixme[58]: `@` is not supported for operand types `Tensor` and
# `Union[Tensor, Module]`.
directions = ray_bundle.directions @ camera.R
else:
directions = ray_bundle.directions

View File

@@ -103,8 +103,6 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module):
embeds = create_embeddings_for_implicit_function(
xyz_world=rays_points_world,
# pyre-fixme[6]: For 2nd argument expected `Optional[(...) -> Any]` but
# got `Union[Tensor, Module]`.
xyz_embedding_function=self._harmonic_embedding,
global_code=global_code,
fun_viewpool=fun_viewpool,
@@ -114,7 +112,6 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module):
# Before running the network, we have to resize embeds to ndims=3,
# otherwise the SRN layers consume huge amounts of memory.
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
raymarch_features = self._net(
embeds.view(embeds.shape[0], -1, embeds.shape[-1])
)
@@ -169,9 +166,7 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
# Normalize the ray_directions to unit l2 norm.
rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
# Obtain the harmonic embedding of the normalized ray directions.
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
rays_embedding = self._harmonic_embedding(rays_directions_normed)
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
return self._color_layer((features, rays_embedding))
def forward(
@@ -200,7 +195,6 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
denoting the color of each ray point.
"""
# raymarch_features.shape = [minibatch x ... x pts_per_ray x 3]
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
features = self._net(raymarch_features)
# features.shape = [minibatch x ... x self.n_hidden_units]
@@ -208,8 +202,6 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
if camera is None:
raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
# pyre-fixme[58]: `@` is not supported for operand types `Tensor` and
# `Union[Tensor, Module]`.
directions = ray_bundle.directions @ camera.R
else:
directions = ray_bundle.directions
@@ -217,7 +209,6 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
# NNs operate on the flattenned rays; reshaping to the correct spatial size
features = features.reshape(*raymarch_features.shape[:-1], -1)
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
raw_densities = self._density_layer(features)
rays_colors = self._get_colors(features, directions)
@@ -278,7 +269,6 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
srn_raymarch_function.
"""
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
net = self._hypernet(global_code)
# use the hyper-net generated network to instantiate the raymarch module
@@ -306,6 +296,7 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
global_code=None,
**kwargs,
):
if global_code is None:
raise ValueError("SRN Hypernetwork requires a non-trivial global code.")
@@ -313,8 +304,6 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
# across LSTM iterations for the same global_code.
if self.cached_srn_raymarch_function is None:
# generate the raymarching network from the hypernet
# pyre-fixme[16]: `SRNRaymarchHyperNet` has no attribute
# `cached_srn_raymarch_function`.
self.cached_srn_raymarch_function = self._run_hypernet(global_code)
(srn_raymarch_function,) = cast(
Tuple[SRNRaymarchFunction], self.cached_srn_raymarch_function
@@ -342,7 +331,6 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
def create_raymarch_function(self) -> None:
self.raymarch_function = SRNRaymarchFunction(
latent_dim=self.latent_dim,
# pyre-fixme[32]: Keyword argument must be a mapping with string keys.
**self.raymarch_function_args,
)
@@ -401,7 +389,6 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
self.hypernet = SRNRaymarchHyperNet(
latent_dim=self.latent_dim,
latent_dim_hypernet=self.latent_dim_hypernet,
# pyre-fixme[32]: Keyword argument must be a mapping with string keys.
**self.hypernet_args,
)

View File

@@ -40,6 +40,7 @@ def create_embeddings_for_implicit_function(
xyz_embedding_function: Optional[Callable],
diag_cov: Optional[torch.Tensor] = None,
) -> torch.Tensor:
bs, *spatial_size, pts_per_ray, _ = xyz_world.shape
if xyz_in_camera_coords:
@@ -63,6 +64,7 @@ def create_embeddings_for_implicit_function(
0,
)
else:
embeds = xyz_embedding_function(ray_points_for_embed, diag_cov=diag_cov)
embeds = embeds.reshape(
bs,

View File

@@ -269,7 +269,6 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
for name, tensor in vars(grid_values_with_wanted_resolution).items()
}
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
return self.values_type(**params), True
def get_resolution_change_epochs(self) -> Tuple[int, ...]:
@@ -883,7 +882,6 @@ class VoxelGridModule(Configurable, torch.nn.Module):
torch.Tensor of shape (..., n_features)
"""
locator = self._get_volume_locator()
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
grid_values = self.voxel_grid.values_type(**self.params)
# voxel grids operate with extra n_grids dimension, which we fix to one
return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]
@@ -897,7 +895,6 @@ class VoxelGridModule(Configurable, torch.nn.Module):
replace current parameters
"""
if self.hold_voxel_grid_as_parameters:
# pyre-fixme[16]: `VoxelGridModule` has no attribute `params`.
self.params = torch.nn.ParameterDict(
{
k: torch.nn.Parameter(val)
@@ -948,7 +945,6 @@ class VoxelGridModule(Configurable, torch.nn.Module):
Returns:
True if parameter change has happened else False.
"""
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
grid_values = self.voxel_grid.values_type(**self.params)
grid_values, change = self.voxel_grid.change_resolution(
grid_values, epoch=epoch
@@ -996,21 +992,16 @@ class VoxelGridModule(Configurable, torch.nn.Module):
"""
'''
new_params = {}
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a
# function.
for name in self.params:
key = prefix + "params." + name
if key in state_dict:
new_params[name] = torch.zeros_like(state_dict[key])
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
self.set_voxel_grid_parameters(self.voxel_grid.values_type(**new_params))
def get_device(self) -> torch.device:
"""
Returns torch.device on which module parameters are located
"""
# pyre-fixme[29]: `Union[(self: TensorBase) -> Tensor, Tensor, Module]` is
# not a function.
return next(val for val in self.params.values() if val is not None).device
def crop_self(self, min_point: torch.Tensor, max_point: torch.Tensor) -> None:
@@ -1027,7 +1018,6 @@ class VoxelGridModule(Configurable, torch.nn.Module):
"""
locator = self._get_volume_locator()
# torch.nn.modules.module.Module]` is not a function.
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
old_grid_values = self.voxel_grid.values_type(**self.params)
new_grid_values = self.voxel_grid.crop_world(
min_point, max_point, old_grid_values, locator
@@ -1035,7 +1025,6 @@ class VoxelGridModule(Configurable, torch.nn.Module):
grid_values, _ = self.voxel_grid.change_resolution(
new_grid_values, grid_values_with_wanted_resolution=old_grid_values
)
# pyre-fixme[16]: `VoxelGridModule` has no attribute `params`.
self.params = torch.nn.ParameterDict(
{
k: torch.nn.Parameter(val)

View File

@@ -192,26 +192,16 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
def __post_init__(self) -> None:
run_auto_creation(self)
# pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
# `voxel_grid_scaffold`.
self.voxel_grid_scaffold = self._create_voxel_grid_scaffold()
# pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
# `harmonic_embedder_xyz_density`.
self.harmonic_embedder_xyz_density = HarmonicEmbedding(
**self.harmonic_embedder_xyz_density_args
)
# pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
# `harmonic_embedder_xyz_color`.
self.harmonic_embedder_xyz_color = HarmonicEmbedding(
**self.harmonic_embedder_xyz_color_args
)
# pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
# `harmonic_embedder_dir_color`.
self.harmonic_embedder_dir_color = HarmonicEmbedding(
**self.harmonic_embedder_dir_color_args
)
# pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
# `_scaffold_ready`.
self._scaffold_ready = False
def forward(
@@ -262,7 +252,6 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
# ########## filter the points using the scaffold ########## #
if self._scaffold_ready and self.scaffold_filter_points:
with torch.no_grad():
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0
points = points[non_empty_points]
if len(points) == 0:
@@ -374,7 +363,6 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
feature dimensionality which `decoder_density` returns
"""
embeds_density = self.voxel_grid_density(points)
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
harmonic_embedding_density = self.harmonic_embedder_xyz_density(embeds_density)
# shape = [..., density_dim]
return self.decoder_density(harmonic_embedding_density)
@@ -409,8 +397,6 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
if self.xyz_ray_dir_in_camera_coords:
if camera is None:
raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
# pyre-fixme[58]: `@` is not supported for operand types `Tensor` and
# `Union[Tensor, Module]`.
directions = directions @ camera.R
# ########## get voxel grid output ########## #
@@ -419,13 +405,11 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
# ########## embed with the harmonic function ########## #
# Obtain the harmonic embedding of the voxel grid output.
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
harmonic_embedding_color = self.harmonic_embedder_xyz_color(embeds_color)
# Normalize the ray_directions to unit l2 norm.
rays_directions_normed = torch.nn.functional.normalize(directions, dim=-1)
# Obtain the harmonic embedding of the normalized ray directions.
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
harmonic_embedding_dir = self.harmonic_embedder_dir_color(
rays_directions_normed
)
@@ -494,11 +478,8 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
an object inside, else False.
"""
# find bounding box
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `get_grid_points`.
points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch)
assert self._scaffold_ready, "Scaffold has to be calculated before cropping."
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
occupancy = self.voxel_grid_scaffold(points)[..., 0] > 0
non_zero_idxs = torch.nonzero(occupancy)
if len(non_zero_idxs) == 0:
@@ -530,8 +511,6 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
"""
planes = []
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `get_grid_points`.
points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch)
chunk_size = (
@@ -551,10 +530,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
stride=1,
)
occupancy_cube = density_cube > self.scaffold_empty_space_threshold
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `params`.
self.voxel_grid_scaffold.params["voxel_grid"] = occupancy_cube.float()
# pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
# `_scaffold_ready`.
self._scaffold_ready = True
return False
@@ -571,8 +547,6 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
decoding function to this value.
"""
grid_args = self.voxel_grid_density_args
# pyre-fixme[6]: For 1st argument expected `DictConfig` but got
# `Union[Tensor, Module]`.
grid_output_dim = VoxelGridModule.get_output_dim(grid_args)
embedder_args = self.harmonic_embedder_xyz_density_args
@@ -601,8 +575,6 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
decoding function to this value.
"""
grid_args = self.voxel_grid_color_args
# pyre-fixme[6]: For 1st argument expected `DictConfig` but got
# `Union[Tensor, Module]`.
grid_output_dim = VoxelGridModule.get_output_dim(grid_args)
embedder_args = self.harmonic_embedder_xyz_color_args
@@ -636,9 +608,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
`self.voxel_grid_density`
"""
return VoxelGridModule(
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
extents=self.voxel_grid_density_args["extents"],
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
translation=self.voxel_grid_density_args["translation"],
voxel_grid_class_type="FullResolutionVoxelGrid",
hold_voxel_grid_as_parameters=False,

View File

@@ -6,6 +6,7 @@
# pyre-unsafe
import warnings
from typing import Any, Dict, Optional
@@ -297,8 +298,9 @@ class ViewMetrics(ViewMetricsBase):
_rgb_metrics(
image_rgb,
image_rgb_pred,
masks=fg_probability,
masks_crop=mask_crop,
fg_probability,
fg_probability_pred,
mask_crop,
)
)
@@ -308,21 +310,9 @@ class ViewMetrics(ViewMetricsBase):
metrics["mask_neg_iou"] = utils.neg_iou_loss(
fg_probability_pred, fg_probability, mask=mask_crop
)
if torch.is_autocast_enabled():
# To avoid issues with mixed precision
metrics["mask_bce"] = utils.calc_bce(
fg_probability_pred.logit(),
fg_probability,
mask=mask_crop,
pred_logits=True,
)
else:
metrics["mask_bce"] = utils.calc_bce(
fg_probability_pred,
fg_probability,
mask=mask_crop,
pred_logits=False,
)
metrics["mask_bce"] = utils.calc_bce(
fg_probability_pred, fg_probability, mask=mask_crop
)
if depth_map is not None and depth_map_pred is not None:
assert mask_crop is not None
@@ -334,11 +324,7 @@ class ViewMetrics(ViewMetricsBase):
if fg_probability is not None:
mask = fg_probability * mask_crop
_, abs_ = utils.eval_depth(
depth_map_pred,
depth_map,
get_best_scale=True,
mask=mask,
crop=0,
depth_map_pred, depth_map, get_best_scale=True, mask=mask, crop=0
)
metrics["depth_abs_fg"] = abs_.mean()
@@ -360,26 +346,18 @@ class ViewMetrics(ViewMetricsBase):
return metrics
def _rgb_metrics(
images,
images_pred,
masks=None,
masks_crop=None,
huber_scaling: float = 0.03,
):
def _rgb_metrics(images, images_pred, masks, masks_pred, masks_crop):
assert masks_crop is not None
if images.shape[1] != images_pred.shape[1]:
raise ValueError(
f"Network output's RGB images had {images_pred.shape[1]} "
f"channels. {images.shape[1]} expected."
)
rgb_abs = ((images_pred - images).abs()).mean(dim=1, keepdim=True)
rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True)
rgb_loss = utils.huber(rgb_squared, scaling=huber_scaling)
rgb_loss = utils.huber(rgb_squared, scaling=0.03)
crop_mass = masks_crop.sum().clamp(1.0)
results = {
"rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass,
"rgb_l1": (rgb_abs * masks_crop).sum() / crop_mass,
"rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass,
"rgb_psnr": utils.calc_psnr(images_pred, images, mask=masks_crop),
}

View File

@@ -135,7 +135,6 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
break
# run the lstm marcher
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
state_h, state_c = self._lstm(
raymarch_features.view(-1, raymarch_features.shape[-1]),
states[-1],
@@ -143,7 +142,6 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
if state_h.requires_grad:
state_h.register_hook(lambda x: x.clamp(min=-10, max=10))
# predict the next step size
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
signed_distance = self._out_layer(state_h).view(ray_bundle_t.lengths.shape)
# log the lstm states
states.append((state_h, state_c))

View File

@@ -207,7 +207,6 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
"""
sample_mask = None
if (
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
self._sampling_mode[evaluation_mode] == RenderSamplingMode.MASK_SAMPLE
and mask is not None
):
@@ -224,7 +223,6 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
EvaluationMode.EVALUATION: self._evaluation_raysampler,
}[evaluation_mode]
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
ray_bundle = raysampler(
cameras=cameras,
mask=sample_mask,
@@ -242,8 +240,6 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
"Heterogeneous ray bundle is not supported for conical frustum computation yet"
)
elif self.cast_ray_bundle_as_cone:
# pyre-fixme[9]: pixel_hw has type `Tuple[float, float]`; used as
# `Tuple[Union[Tensor, Module], Union[Tensor, Module]]`.
pixel_hw: Tuple[float, float] = (self.pixel_height, self.pixel_width)
pixel_radii_2d = compute_radii(cameras, ray_bundle.xys[..., :2], pixel_hw)
return ImplicitronRayBundle(

View File

@@ -179,10 +179,8 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
rays_densities = torch.relu(rays_densities)
weighted_densities = deltas * rays_densities
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
capped_densities = self._capping_function(weighted_densities)
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
rays_opacities = self._capping_function(
torch.cumsum(weighted_densities, dim=-1)
)
@@ -192,7 +190,6 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
)
absorption_shifted[..., : self.surface_thickness] = 1.0
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
weights = self._weight_function(capped_densities, absorption_shifted)
features = (weights[..., None] * rays_features).sum(dim=-2)
depth = (weights * ray_lengths)[..., None].sum(dim=-2)
@@ -200,8 +197,6 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
alpha = opacities if self.blend_output else 1
if self._bg_color.shape[-1] not in [1, features.shape[-1]]:
raise ValueError("Wrong number of background color channels.")
# pyre-fixme[58]: `*` is not supported for operand types `int` and
# `Union[Tensor, Module]`.
features = alpha * features + (1 - opacities) * self._bg_color
return RendererOutput(

View File

@@ -61,7 +61,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
def create_ray_tracer(self) -> None:
self.ray_tracer = RayTracing(
# pyre-fixme[32]: Keyword argument must be a mapping with string keys.
**self.ray_tracer_args,
object_bounding_sphere=self.object_bounding_sphere,
)
@@ -150,8 +149,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
n_eik_points,
3,
# but got `Union[device, Tensor, Module]`.
# pyre-fixme[6]: For 3rd argument expected `Union[None, int, str,
# device]` but got `Union[device, Tensor, Module]`.
device=self._bg_color.device,
).uniform_(-eik_bounding_box, eik_bounding_box)
eikonal_pixel_points = points.clone()
@@ -208,7 +205,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
]
normals_full.view(-1, 3)[surface_mask] = normals
render_full.view(-1, self.render_features_dimensions)[surface_mask] = (
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
self._rgb_network(
features,
differentiable_surface_points[None],
@@ -220,7 +216,8 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
)
mask_full.view(-1, 1)[~surface_mask] = torch.sigmoid(
# pyre-fixme[6]: For 1st param expected `Tensor` but got `float`.
-self.soft_mask_alpha * sdf_output[~surface_mask]
-self.soft_mask_alpha
* sdf_output[~surface_mask]
)
# scatter points with surface_mask

View File

@@ -532,7 +532,6 @@ def _get_ray_dir_dot_prods(camera: CamerasBase, pts: torch.Tensor):
# does not produce nans randomly unlike get_camera_center() below
cam_centers_rep = -torch.bmm(
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
camera_rep.T[:, None],
camera_rep.R.permute(0, 2, 1),
).reshape(-1, *([1] * (pts.ndim - 2)), 3)

View File

@@ -209,7 +209,6 @@ def handle_seq_id(
seq_id = torch.tensor(seq_id, dtype=torch.long, device=device)
# pyre-fixme[16]: Item `List` of `Union[List[int], List[str], LongTensor]` has
# no attribute `to`.
# pyre-fixme[7]: Expected `LongTensor` but got `Tensor`.
return seq_id.to(device)

View File

@@ -21,6 +21,7 @@ def cleanup_eval_depth(
sigma: float = 0.01,
image=None,
):
ba, _, H, W = depth.shape
pcl = point_cloud.points_padded()

View File

@@ -6,15 +6,12 @@
# pyre-unsafe
import logging
import math
from typing import Optional, Tuple
import torch
from torch.nn import functional as F
logger = logging.getLogger(__name__)
def eval_depth(
pred: torch.Tensor,
@@ -24,8 +21,6 @@ def eval_depth(
get_best_scale: bool = True,
mask_thr: float = 0.5,
best_scale_clamp_thr: float = 1e-4,
use_disparity: bool = False,
disparity_eps: float = 1e-4,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Evaluate the depth error between the prediction `pred` and the ground
@@ -69,13 +64,6 @@ def eval_depth(
# s.t. we get best possible mse error
scale_best = estimate_depth_scale_factor(pred, gt, dmask, best_scale_clamp_thr)
pred = pred * scale_best[:, None, None, None]
if use_disparity:
gt = torch.div(1.0, (gt + disparity_eps))
pred = torch.div(1.0, (pred + disparity_eps))
scale_best = estimate_depth_scale_factor(
pred, gt, dmask, best_scale_clamp_thr
).detach()
pred = pred * scale_best[:, None, None, None]
df = gt - pred
@@ -129,7 +117,6 @@ def calc_bce(
pred_eps: float = 0.01,
mask: Optional[torch.Tensor] = None,
lerp_bound: Optional[float] = None,
pred_logits: bool = False,
) -> torch.Tensor:
"""
Calculates the binary cross entropy.
@@ -152,23 +139,9 @@ def calc_bce(
weight = torch.ones_like(gt) * mask
if lerp_bound is not None:
# binary_cross_entropy_lerp requires pred to be in [0, 1]
if pred_logits:
pred = F.sigmoid(pred)
return binary_cross_entropy_lerp(pred, gt, weight, lerp_bound)
else:
if pred_logits:
loss = F.binary_cross_entropy_with_logits(
pred,
gt,
reduction="none",
weight=weight,
)
else:
loss = F.binary_cross_entropy(pred, gt, reduction="none", weight=weight)
return loss.mean()
return F.binary_cross_entropy(pred, gt, reduction="mean", weight=weight)
def binary_cross_entropy_lerp(

View File

@@ -111,10 +111,10 @@ def load_model(fl, map_location: Optional[dict]):
flstats = get_stats_path(fl)
flmodel = get_model_path(fl)
flopt = get_optimizer_path(fl)
model_state_dict = torch.load(flmodel, map_location=map_location, weights_only=True)
model_state_dict = torch.load(flmodel, map_location=map_location)
stats = load_stats(flstats)
if os.path.isfile(flopt):
optimizer = torch.load(flopt, map_location=map_location, weights_only=True)
optimizer = torch.load(flopt, map_location=map_location)
else:
optimizer = None

View File

@@ -100,6 +100,7 @@ def render_point_cloud_pytorch3d(
bin_size: Optional[int] = None,
**kwargs,
):
# feature dimension
featdim = point_cloud.features_packed().shape[-1]

View File

@@ -37,6 +37,7 @@ class AverageMeter:
self.count = 0
def update(self, val, n=1, epoch=0):
# make sure the history is of the same len as epoch
while len(self.history) <= epoch:
self.history.append([])
@@ -114,6 +115,7 @@ class Stats:
visdom_server="http://localhost",
visdom_port=8097,
):
self.log_vars = log_vars
self.visdom_env = visdom_env
self.visdom_server = visdom_server
@@ -200,6 +202,7 @@ class Stats:
self.log_vars.append(add_log_var)
def update(self, preds, time_start=None, freeze_iter=False, stat_set="train"):
if self.epoch == -1: # uninitialized
logger.warning(
"epoch==-1 means uninitialized stats structure -> new_epoch() called"
@@ -216,6 +219,7 @@ class Stats:
epoch = self.epoch
for stat in self.log_vars:
if stat not in self.stats[stat_set]:
self.stats[stat_set][stat] = AverageMeter()
@@ -244,6 +248,7 @@ class Stats:
self.stats[stat_set][stat].update(val, epoch=epoch, n=1)
def get_epoch_averages(self, epoch=None):
stat_sets = list(self.stats.keys())
if epoch is None:
@@ -340,6 +345,7 @@ class Stats:
def plot_stats(
self, visdom_env=None, plot_file=None, visdom_server=None, visdom_port=None
):
# use the cached visdom env if none supplied
if visdom_env is None:
visdom_env = self.visdom_env
@@ -443,6 +449,7 @@ class Stats:
warnings.warn("Cant dump stats due to insufficient permissions!")
def synchronize_logged_vars(self, log_vars, default_val=float("NaN")):
stat_sets = list(self.stats.keys())
# remove the additional log_vars
@@ -483,12 +490,11 @@ class Stats:
for ep in range(lastep):
self.stats[stat_set][stat].update(default_val, n=1, epoch=ep)
epoch_generated = self.stats[stat_set][stat].get_epoch()
assert epoch_generated == self.epoch + 1, (
"bad epoch of synchronized log_var! %d vs %d"
% (
self.epoch + 1,
epoch_generated,
)
assert (
epoch_generated == self.epoch + 1
), "bad epoch of synchronized log_var! %d vs %d" % (
self.epoch + 1,
epoch_generated,
)

View File

@@ -16,17 +16,8 @@ from typing import Optional, Tuple, Union
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
_NO_TORCHVISION = False
try:
import torchvision
except ImportError:
_NO_TORCHVISION = True
_DEFAULT_FFMPEG = os.environ.get("FFMPEG", "ffmpeg")
matplotlib.use("Agg")
@@ -45,7 +36,6 @@ class VideoWriter:
fps: int = 20,
output_format: str = "visdom",
rmdir_allowed: bool = False,
use_torchvision_video_writer: bool = False,
**kwargs,
) -> None:
"""
@@ -59,8 +49,6 @@ class VideoWriter:
is supported.
rmdir_allowed: If `True` delete and create `cache_dir` in case
it is not empty.
use_torchvision_video_writer: If `True` use `torchvision.io.write_video`
to write the video
"""
self.rmdir_allowed = rmdir_allowed
self.output_format = output_format
@@ -68,14 +56,10 @@ class VideoWriter:
self.out_path = out_path
self.cache_dir = cache_dir
self.ffmpeg_bin = ffmpeg_bin
self.use_torchvision_video_writer = use_torchvision_video_writer
self.frames = []
self.regexp = "frame_%08d.png"
self.frame_num = 0
if self.use_torchvision_video_writer:
assert not _NO_TORCHVISION, "torchvision not available"
if self.cache_dir is not None:
self.tmp_dir = None
if os.path.isdir(self.cache_dir):
@@ -130,7 +114,7 @@ class VideoWriter:
resize = im.size
# make sure size is divisible by 2
resize = tuple([resize[i] + resize[i] % 2 for i in (0, 1)])
# pyre-fixme[16]: Module `Image` has no attribute `ANTIALIAS`.
im = im.resize(resize, Image.ANTIALIAS)
im.save(outfile)
@@ -155,56 +139,38 @@ class VideoWriter:
# got `Optional[str]`.
regexp = os.path.join(self.cache_dir, self.regexp)
if shutil.which(self.ffmpeg_bin) is None:
raise ValueError(
f"Cannot find ffmpeg as `{self.ffmpeg_bin}`. "
+ "Please set FFMPEG in the environment or ffmpeg_bin on this class."
)
if self.output_format == "visdom": # works for ppt too
# Video codec parameters
video_codec = "h264"
crf = "18"
b = "2000k"
pix_fmt = "yuv420p"
if self.use_torchvision_video_writer:
torchvision.io.write_video(
self.out_path,
torch.stack(
[torch.from_numpy(np.array(Image.open(f))) for f in self.frames]
),
fps=self.fps,
video_codec=video_codec,
options={"crf": crf, "b": b, "pix_fmt": pix_fmt},
args = [
self.ffmpeg_bin,
"-r",
str(self.fps),
"-i",
regexp,
"-vcodec",
"h264",
"-f",
"mp4",
"-y",
"-crf",
"18",
"-b",
"2000k",
"-pix_fmt",
"yuv420p",
self.out_path,
]
if quiet:
subprocess.check_call(
args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
)
else:
if shutil.which(self.ffmpeg_bin) is None:
raise ValueError(
f"Cannot find ffmpeg as `{self.ffmpeg_bin}`. "
+ "Please set FFMPEG in the environment or ffmpeg_bin on this class."
)
args = [
self.ffmpeg_bin,
"-r",
str(self.fps),
"-i",
regexp,
"-vcodec",
video_codec,
"-f",
"mp4",
"-y",
"-crf",
crf,
"-b",
b,
"-pix_fmt",
pix_fmt,
self.out_path,
]
if quiet:
subprocess.check_call(
args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
)
else:
subprocess.check_call(args)
subprocess.check_call(args)
else:
raise ValueError("no such output type %s" % str(self.output_format))

View File

@@ -163,8 +163,9 @@ def _read_chunks(
if binary_data is not None:
binary_data = np.frombuffer(binary_data, dtype=np.uint8)
assert binary_data is not None
# pyre-fixme[7]: Expected `Optional[Tuple[Dict[str, typing.Any],
# ndarray[typing.Any, typing.Any]]]` but got `Tuple[typing.Any,
# Optional[ndarray[typing.Any, dtype[typing.Any]]]]`.
return json_data, binary_data

View File

@@ -7,7 +7,6 @@
# pyre-unsafe
"""This module implements utility functions for loading .mtl files and textures."""
import os
import warnings
from typing import Dict, List, Optional, Tuple

View File

@@ -8,7 +8,6 @@
"""This module implements utility functions for loading and saving meshes."""
import os
import warnings
from collections import namedtuple
@@ -814,6 +813,7 @@ def _save(
save_texture: bool = False,
save_normals: bool = False,
) -> None:
if len(verts) and (verts.dim() != 2 or verts.size(1) != 3):
message = "'verts' should either be empty or of shape (num_verts, 3)."
raise ValueError(message)

View File

@@ -14,7 +14,6 @@ meshes as .off files.
This format is introduced, for example, at
http://www.geomview.org/docs/html/OFF.html .
"""
import warnings
from typing import cast, Optional, Tuple, Union
@@ -85,7 +84,7 @@ def _read_faces_lump(
)
data = np.loadtxt(file, dtype=np.float32, ndmin=2, max_rows=n_faces)
except ValueError as e:
if n_faces > 1 and "number of columns" in e.args[0]:
if n_faces > 1 and "Wrong number of columns" in e.args[0]:
file.seek(old_offset)
return None
raise ValueError("Not enough face data.") from None

View File

@@ -11,7 +11,6 @@
This module implements utility functions for loading and saving
meshes and point clouds as PLY files.
"""
import itertools
import os
import struct
@@ -1247,10 +1246,13 @@ def _save_ply(
return
color_np_type = np.ubyte if colors_as_uint8 else np.float32
verts_dtype: list = [("verts", np.float32, 3)]
verts_dtype = [("verts", np.float32, 3)]
if verts_normals is not None:
verts_dtype.append(("normals", np.float32, 3))
if verts_colors is not None:
# pyre-fixme[6]: For 1st argument expected `Tuple[str,
# Type[floating[_32Bit]], int]` but got `Tuple[str,
# Type[Union[floating[_32Bit], unsignedinteger[typing.Any]]], int]`.
verts_dtype.append(("colors", color_np_type, 3))
vert_data = np.zeros(verts.shape[0], dtype=verts_dtype)

View File

@@ -122,17 +122,12 @@ def corresponding_cameras_alignment(
# create a new cameras object and set the R and T accordingly
cameras_src_aligned = cameras_src.clone()
# pyre-fixme[6]: For 2nd argument expected `Tensor` but got `Union[Tensor, Module]`.
cameras_src_aligned.R = torch.bmm(align_t_R.expand_as(cameras_src.R), cameras_src.R)
cameras_src_aligned.T = (
torch.bmm(
align_t_T[:, None].repeat(cameras_src.R.shape[0], 1, 1),
# pyre-fixme[6]: For 2nd argument expected `Tensor` but got
# `Union[Tensor, Module]`.
cameras_src.R,
)[:, 0]
# pyre-fixme[29]: `Union[(self: TensorBase, other: Union[bool, complex,
# float, int, Tensor]) -> Tensor, Tensor, Module]` is not a function.
+ cameras_src.T * align_t_s
)
@@ -180,7 +175,6 @@ def _align_camera_extrinsics(
R_A = (U V^T)^T
```
"""
# pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Tensor, Module]`.
RRcov = torch.bmm(cameras_src.R, cameras_tgt.R.transpose(2, 1)).mean(0)
U, _, V = torch.svd(RRcov)
align_t_R = V @ U.t()
@@ -210,11 +204,7 @@ def _align_camera_extrinsics(
T_A = mean(B) - mean(A) * s_A
```
"""
# pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Tensor, Module]`.
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, Any, ...
A = torch.bmm(cameras_src.R, cameras_src.T[:, :, None])[:, :, 0]
# pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Tensor, Module]`.
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, Any, ...
B = torch.bmm(cameras_src.R, cameras_tgt.T[:, :, None])[:, :, 0]
Amu = A.mean(0, keepdim=True)
Bmu = B.mean(0, keepdim=True)

View File

@@ -62,7 +62,7 @@ def cubify(
*,
feats: Optional[torch.Tensor] = None,
device=None,
align: str = "topleft",
align: str = "topleft"
) -> Meshes:
r"""
Converts a voxel to a mesh by replacing each occupied voxel with a cube

View File

@@ -85,6 +85,7 @@ class _points_to_volumes_function(Function):
align_corners: bool,
splat: bool,
):
ctx.mark_dirty(volume_densities, volume_features)
N, P, D = points_3d.shape
@@ -496,6 +497,7 @@ def _check_points_to_volumes_inputs(
grid_sizes: torch.LongTensor,
mask: Optional[torch.Tensor] = None,
) -> None:
max_grid_size = grid_sizes.max(dim=0).values
if torch.prod(max_grid_size) > volume_densities.shape[1]:
raise ValueError(

View File

@@ -11,7 +11,6 @@
This module implements utility functions for sampling points from
batches of meshes.
"""
import sys
from typing import Tuple, Union

View File

@@ -353,16 +353,45 @@ def _create_verts_index(verts_per_mesh, edges_per_mesh, device=None):
# e.g. verts_per_mesh = (4, 5, 6)
# e.g. edges_per_mesh = (5, 7, 9)
rng = torch.arange(verts_per_mesh.shape[0], device=device) # (0,1,2)
verts_nums = rng.repeat_interleave(
verts_per_mesh
) # (0,0,0,0,1,1,1,1,1,2,2,2,2,2,2)
edges_nums = rng.repeat_interleave(
edges_per_mesh
) # (0,0,0,0,0,1,1,1,1,1,1,1,2,2,2,2,2,2,2,2,2)
nums = torch.cat([verts_nums, edges_nums])
V = verts_per_mesh.sum() # e.g. 15
E = edges_per_mesh.sum() # e.g. 21
verts_per_mesh_cumsum = verts_per_mesh.cumsum(dim=0) # (N,) e.g. (4, 9, 15)
edges_per_mesh_cumsum = edges_per_mesh.cumsum(dim=0) # (N,) e.g. (5, 12, 21)
v_to_e_idx = verts_per_mesh_cumsum.clone()
# vertex to edge index.
v_to_e_idx[1:] += edges_per_mesh_cumsum[
:-1
] # e.g. (4, 9, 15) + (0, 5, 12) = (4, 14, 27)
# vertex to edge offset.
v_to_e_offset = V - verts_per_mesh_cumsum # e.g. 15 - (4, 9, 15) = (11, 6, 0)
v_to_e_offset[1:] += edges_per_mesh_cumsum[
:-1
] # e.g. (11, 6, 0) + (0, 5, 12) = (11, 11, 12)
e_to_v_idx = (
verts_per_mesh_cumsum[:-1] + edges_per_mesh_cumsum[:-1]
) # (4, 9) + (5, 12) = (9, 21)
e_to_v_offset = (
verts_per_mesh_cumsum[:-1] - edges_per_mesh_cumsum[:-1] - V
) # (4, 9) - (5, 12) - 15 = (-16, -18)
# Add one new vertex per edge.
idx_diffs = torch.ones(V + E, device=device, dtype=torch.int64) # (36,)
idx_diffs[v_to_e_idx] += v_to_e_offset
idx_diffs[e_to_v_idx] += e_to_v_offset
# e.g.
# [
# 1, 1, 1, 1, 12, 1, 1, 1, 1,
# -15, 1, 1, 1, 1, 12, 1, 1, 1, 1, 1, 1,
# -17, 1, 1, 1, 1, 1, 13, 1, 1, 1, 1, 1, 1, 1
# ]
verts_idx = idx_diffs.cumsum(dim=0) - 1
verts_idx = torch.argsort(nums, stable=True)
# e.g.
# [
# 0, 1, 2, 3, 15, 16, 17, 18, 19, --> mesh 0
@@ -371,6 +400,7 @@ def _create_verts_index(verts_per_mesh, edges_per_mesh, device=None):
# ]
# where for mesh 0, [0, 1, 2, 3] are the indices of the existing verts, and
# [15, 16, 17, 18, 19] are the indices of the new verts after subdivision.
return verts_idx
@@ -391,9 +421,44 @@ def _create_faces_index(faces_per_mesh: torch.Tensor, device=None):
"""
# e.g. faces_per_mesh = [2, 5, 3]
rng = torch.arange(faces_per_mesh.shape[0], device=device) # (0,1,2)
nums = rng.repeat_interleave(faces_per_mesh).repeat(4)
faces_idx = torch.argsort(nums, stable=True)
F = faces_per_mesh.sum() # e.g. 10
faces_per_mesh_cumsum = faces_per_mesh.cumsum(dim=0) # (N,) e.g. (2, 7, 10)
switch1_idx = faces_per_mesh_cumsum.clone()
switch1_idx[1:] += (
3 * faces_per_mesh_cumsum[:-1]
) # e.g. (2, 7, 10) + (0, 6, 21) = (2, 13, 31)
switch2_idx = 2 * faces_per_mesh_cumsum # e.g. (4, 14, 20)
switch2_idx[1:] += (
2 * faces_per_mesh_cumsum[:-1]
) # e.g. (4, 14, 20) + (0, 4, 14) = (4, 18, 34)
switch3_idx = 3 * faces_per_mesh_cumsum # e.g. (6, 21, 30)
switch3_idx[1:] += faces_per_mesh_cumsum[
:-1
] # e.g. (6, 21, 30) + (0, 2, 7) = (6, 23, 37)
switch4_idx = 4 * faces_per_mesh_cumsum[:-1] # e.g. (8, 28)
switch123_offset = F - faces_per_mesh # e.g. (8, 5, 7)
# pyre-fixme[6]: For 1st param expected `Union[List[int], Size,
# typing.Tuple[int, ...]]` but got `Tensor`.
idx_diffs = torch.ones(4 * F, device=device, dtype=torch.int64)
idx_diffs[switch1_idx] += switch123_offset
idx_diffs[switch2_idx] += switch123_offset
idx_diffs[switch3_idx] += switch123_offset
idx_diffs[switch4_idx] -= 3 * F
# e.g
# [
# 1, 1, 9, 1, 9, 1, 9, 1, -> mesh 0
# -29, 1, 1, 1, 1, 6, 1, 1, 1, 1, 6, 1, 1, 1, 1, 6, 1, 1, 1, 1, -> mesh 1
# -29, 1, 1, 8, 1, 1, 8, 1, 1, 8, 1, 1 -> mesh 2
# ]
faces_idx = idx_diffs.cumsum(dim=0) - 1
# e.g.
# [

View File

@@ -65,11 +65,7 @@ def _opencv_from_cameras_projection(
cameras: PerspectiveCameras,
image_size: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# pyre-fixme[29]: `Union[(self: TensorBase, memory_format:
# Optional[memory_format] = ...) -> Tensor, Tensor, Module]` is not a function.
R_pytorch3d = cameras.R.clone()
# pyre-fixme[29]: `Union[(self: TensorBase, memory_format:
# Optional[memory_format] = ...) -> Tensor, Tensor, Module]` is not a function.
T_pytorch3d = cameras.T.clone()
focal_pytorch3d = cameras.focal_length
p0_pytorch3d = cameras.principal_point
@@ -110,32 +106,30 @@ def _pulsar_from_opencv_projection(
# Validate parameters.
image_size_wh = image_size.to(R).flip(dims=(1,))
assert torch.all(image_size_wh > 0), (
"height and width must be positive but min is: %s"
% (str(image_size_wh.min().item()))
assert torch.all(
image_size_wh > 0
), "height and width must be positive but min is: %s" % (
str(image_size_wh.min().item())
)
assert camera_matrix.size(1) == 3 and camera_matrix.size(2) == 3, (
"Incorrect camera matrix shape: expected 3x3 but got %dx%d"
% (
camera_matrix.size(1),
camera_matrix.size(2),
)
assert (
camera_matrix.size(1) == 3 and camera_matrix.size(2) == 3
), "Incorrect camera matrix shape: expected 3x3 but got %dx%d" % (
camera_matrix.size(1),
camera_matrix.size(2),
)
assert R.size(1) == 3 and R.size(2) == 3, (
"Incorrect R shape: expected 3x3 but got %dx%d"
% (
R.size(1),
R.size(2),
)
assert (
R.size(1) == 3 and R.size(2) == 3
), "Incorrect R shape: expected 3x3 but got %dx%d" % (
R.size(1),
R.size(2),
)
if len(tvec.size()) == 2:
tvec = tvec.unsqueeze(2)
assert tvec.size(1) == 3 and tvec.size(2) == 1, (
"Incorrect tvec shape: expected 3x1 but got %dx%d"
% (
tvec.size(1),
tvec.size(2),
)
assert (
tvec.size(1) == 3 and tvec.size(2) == 1
), "Incorrect tvec shape: expected 3x1 but got %dx%d" % (
tvec.size(1),
tvec.size(2),
)
# Check batch size.
batch_size = camera_matrix.size(0)
@@ -143,12 +137,11 @@ def _pulsar_from_opencv_projection(
batch_size,
R.size(0),
)
assert tvec.size(0) == batch_size, (
"Expected tvec to have batch size %d. Has size %d."
% (
batch_size,
tvec.size(0),
)
assert (
tvec.size(0) == batch_size
), "Expected tvec to have batch size %d. Has size %d." % (
batch_size,
tvec.size(0),
)
# Check image sizes.
image_w = image_size_wh[0, 0]

View File

@@ -203,9 +203,7 @@ class CamerasBase(TensorProperties):
"""
R: torch.Tensor = kwargs.get("R", self.R)
T: torch.Tensor = kwargs.get("T", self.T)
# pyre-fixme[16]: `CamerasBase` has no attribute `R`.
self.R = R
# pyre-fixme[16]: `CamerasBase` has no attribute `T`.
self.T = T
world_to_view_transform = get_world_to_view_transform(R=R, T=T)
return world_to_view_transform
@@ -230,9 +228,7 @@ class CamerasBase(TensorProperties):
a Transform3d object which represents a batch of transforms
of shape (N, 3, 3)
"""
# pyre-fixme[16]: `CamerasBase` has no attribute `R`.
self.R: torch.Tensor = kwargs.get("R", self.R)
# pyre-fixme[16]: `CamerasBase` has no attribute `T`.
self.T: torch.Tensor = kwargs.get("T", self.T)
world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T)
view_to_proj_transform = self.get_projection_transform(**kwargs)
@@ -1176,12 +1172,7 @@ class PerspectiveCameras(CamerasBase):
unprojection_transform = to_camera_transform.inverse()
xy_inv_depth = torch.cat(
# pyre-fixme[6]: For 1st argument expected `Union[List[Tensor],
# tuple[Tensor, ...]]` but got `Tuple[Tensor, float]`.
# pyre-fixme[58]: `/` is not supported for operand types `float` and
# `Tensor`.
(xy_depth[..., :2], 1.0 / xy_depth[..., 2:3]),
dim=-1, # type: ignore
(xy_depth[..., :2], 1.0 / xy_depth[..., 2:3]), dim=-1 # type: ignore
)
return unprojection_transform.transform_points(xy_inv_depth)

View File

@@ -281,10 +281,8 @@ class FishEyeCameras(CamerasBase):
# project from camera space to image space
N = len(self.radial_params)
if not self.check_input(points, N):
msg = (
"Expected points of (P, 3) with batch_size 1 or N, or shape (M, P, 3) \
msg = "Expected points of (P, 3) with batch_size 1 or N, or shape (M, P, 3) \
with batch_size 1; got points of shape %r and batch_size %r"
)
raise ValueError(msg % (points.shape, N))
if N == 1:

View File

@@ -67,7 +67,7 @@ class HeterogeneousRayBundle:
def ray_bundle_to_ray_points(
ray_bundle: Union[RayBundle, HeterogeneousRayBundle],
ray_bundle: Union[RayBundle, HeterogeneousRayBundle]
) -> torch.Tensor:
"""
Converts rays parametrized with a `ray_bundle` (an instance of the `RayBundle`

View File

@@ -266,9 +266,7 @@ class PointLights(TensorProperties):
shape (P, 3) or (N, H, W, K, 3).
"""
if self.location.ndim == points.ndim:
# pyre-fixme[7]: Expected `Tensor` but got `Union[Tensor, Module]`.
return self.location
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
return self.location[:, None, None, None, :]
def diffuse(self, normals, points) -> torch.Tensor:

View File

@@ -168,7 +168,7 @@ def _get_culled_faces(face_verts: torch.Tensor, frustum: ClipFrustum) -> torch.T
position of the clipping planes.
Returns:
faces_culled: boolean tensor of size F specifying whether or not each face should be
faces_culled: An boolean tensor of size F specifying whether or not each face should be
culled.
"""
clipping_planes = (

View File

@@ -726,17 +726,15 @@ class TexturesUV(TexturesBase):
for each face
verts_uvs: (N, V, 2) tensor giving the uv coordinates per vertex
(a FloatTensor with values between 0 and 1).
maps_ids: Used if there are to be multiple maps per face.
This can be either a list of map_ids [(F,)]
maps_ids: Used if there are to be multiple maps per face. This can be either a list of map_ids [(F,)]
or a long tensor of shape (N, F) giving the id of the texture map
for each face. If maps_ids is present, the maps has an extra dimension M
(so maps_padded is (N, M, H, W, C) and maps_list has elements of
shape (M, H, W, C)).
Specifically, the color
of a vertex V is given by an average of
maps_padded[i, maps_ids[i, f], u, v, :]
of a vertex V is given by an average of maps_padded[i, maps_ids[i, f], u, v, :]
over u and v integers adjacent to
_verts_uvs_padded[i, _faces_uvs_padded[i, f, 0], :] .
_verts_uvs_padded[i, _faces_uvs_padded[i, f, 0], :] .
align_corners: If true, the extreme values 0 and 1 for verts_uvs
indicate the centers of the edge pixels in the maps.
padding_mode: padding mode for outside grid values
@@ -1239,8 +1237,7 @@ class TexturesUV(TexturesBase):
texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2)
return texels
else:
# We have maps_ids_padded: (N, F), textures_map: (N, M, Hi, Wi, C),
# fragments.pix_to_face: (N, Ho, Wo, K)
# We have maps_ids_padded: (N, F), textures_map: (N, M, Hi, Wi, C),fragmenmts.pix_to_face: (N, Ho, Wo, K)
# Get pixel_to_map_ids: (N, K, Ho, Wo) by indexing pix_to_face into maps_ids
N, M, H_in, W_in, C = texture_maps.shape # 3 for RGB
@@ -1251,9 +1248,8 @@ class TexturesUV(TexturesBase):
pixel_to_map_ids = (
maps_ids_padded.flatten()
.gather(0, pix_to_face.flatten())
.view(N, H_out, W_out, K, 1)
.permute(0, 3, 1, 2, 4)
) # N x H_out x W_out x K x 1
.view(N, K, H_out, W_out)
)
# Normalize between -1 and 1 with M (number of maps)
pixel_to_map_ids = (2.0 * pixel_to_map_ids.float() / float(M - 1)) - 1
@@ -1262,10 +1258,10 @@ class TexturesUV(TexturesBase):
pixel_uvs.new_tensor([-1.0, 1.0]),
pixel_uvs.new_tensor([1.0, -1.0]),
pixel_uvs,
) # N x H_out x W_out x K x 2
)
# N x H_out x W_out x K x 3
pixel_uvms = torch.cat((pixel_uvs, pixel_to_map_ids), dim=4)
pixel_uvms = torch.cat((pixel_uvs, pixel_to_map_ids.unsqueeze(4)), dim=4)
# (N, M, H, W, C) -> (N, C, M, H, W)
texture_maps = texture_maps.permute(0, 4, 1, 2, 3)
if texture_maps.device != pixel_uvs.device:
@@ -1830,7 +1826,7 @@ class TexturesVertex(TexturesBase):
representation) which overlap the pixel.
Returns:
texels: A texture per pixel of shape (N, H, W, K, C).
texels: An texture per pixel of shape (N, H, W, K, C).
There will be one C dimensional value for each element in
fragments.pix_to_face.
"""

View File

@@ -184,7 +184,7 @@ class EGLContext:
"""
# Lock used to prevent multiple threads from rendering on the same device
# at the same time, creating/destroying contexts at the same time, etc.
self.lock = threading.RLock()
self.lock = threading.Lock()
self.cuda_device_id = cuda_device_id
self.device = _get_cuda_device(self.cuda_device_id)
self.width = width
@@ -224,14 +224,15 @@ class EGLContext:
Throws:
EGLError when the context cannot be made current or make non-current.
"""
with self.lock:
egl.eglMakeCurrent(self.dpy, self.surface, self.surface, self.context)
try:
yield
finally:
egl.eglMakeCurrent(
self.dpy, egl.EGL_NO_SURFACE, egl.EGL_NO_SURFACE, egl.EGL_NO_CONTEXT
)
self.lock.acquire()
egl.eglMakeCurrent(self.dpy, self.surface, self.surface, self.context)
try:
yield
finally:
egl.eglMakeCurrent(
self.dpy, egl.EGL_NO_SURFACE, egl.EGL_NO_SURFACE, egl.EGL_NO_CONTEXT
)
self.lock.release()
def get_context_info(self) -> Dict[str, Any]:
"""

View File

@@ -12,7 +12,6 @@ Proper Python support for pytorch requires creating a torch.autograd.function
(independent of whether this is being done within the C++ module). This is done
here and a torch.nn.Module is exposed for the use in more complex models.
"""
import logging
import warnings
from typing import Optional, Tuple, Union

View File

@@ -133,7 +133,8 @@ def _get_splat_kernel_normalization(
epsilon = 0.05
normalization_constant = torch.exp(
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
-(offsets**2).sum(dim=1) / (2 * sigma**2)
-(offsets**2).sum(dim=1)
/ (2 * sigma**2)
).sum()
# We add an epsilon to the normalization constant to ensure the gradient will travel

View File

@@ -114,6 +114,7 @@ class TensorProperties(nn.Module):
self.device = make_device(device)
self._N = 0
if kwargs is not None:
# broadcast all inputs which are float/int/list/tuple/tensor/array
# set as attributes anything else e.g. strings, bools
args_to_broadcast = {}
@@ -438,7 +439,7 @@ def ndc_to_grid_sample_coords(
def parse_image_size(
image_size: Union[List[int], Tuple[int, int], int],
image_size: Union[List[int], Tuple[int, int], int]
) -> Tuple[int, int]:
"""
Args:

View File

@@ -1531,6 +1531,7 @@ class Meshes:
def sample_textures(self, fragments):
if self.textures is not None:
# Check dimensions of textures match that of meshes
shape_ok = self.textures.check_shapes(self._N, self._V, self._F)
if not shape_ok:

View File

@@ -1274,7 +1274,7 @@ def join_pointclouds_as_batch(pointclouds: Sequence[Pointclouds]) -> Pointclouds
def join_pointclouds_as_scene(
pointclouds: Union[Pointclouds, List[Pointclouds]],
pointclouds: Union[Pointclouds, List[Pointclouds]]
) -> Pointclouds:
"""
Joins a batch of point cloud in the form of a Pointclouds object or a list of Pointclouds

View File

@@ -463,7 +463,7 @@ def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Ten
return out[..., 1:]
def axis_angle_to_matrix(axis_angle: torch.Tensor, fast: bool = False) -> torch.Tensor:
def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
"""
Convert rotations given as axis/angle to rotation matrices.
@@ -472,91 +472,27 @@ def axis_angle_to_matrix(axis_angle: torch.Tensor, fast: bool = False) -> torch.
as a tensor of shape (..., 3), where the magnitude is
the angle turned anticlockwise in radians around the
vector's direction.
fast: Whether to use the new faster implementation (based on the
Rodrigues formula) instead of the original implementation (which
first converted to a quaternion and then back to a rotation matrix).
Returns:
Rotation matrices as tensor of shape (..., 3, 3).
"""
if not fast:
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
shape = axis_angle.shape
device, dtype = axis_angle.device, axis_angle.dtype
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True).unsqueeze(-1)
rx, ry, rz = axis_angle[..., 0], axis_angle[..., 1], axis_angle[..., 2]
zeros = torch.zeros(shape[:-1], dtype=dtype, device=device)
cross_product_matrix = torch.stack(
[zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=-1
).view(shape + (3,))
cross_product_matrix_sqrd = cross_product_matrix @ cross_product_matrix
identity = torch.eye(3, dtype=dtype, device=device)
angles_sqrd = angles * angles
angles_sqrd = torch.where(angles_sqrd == 0, 1, angles_sqrd)
return (
identity.expand(cross_product_matrix.shape)
+ torch.sinc(angles / torch.pi) * cross_product_matrix
+ ((1 - torch.cos(angles)) / angles_sqrd) * cross_product_matrix_sqrd
)
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
def matrix_to_axis_angle(matrix: torch.Tensor, fast: bool = False) -> torch.Tensor:
def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
"""
Convert rotations given as rotation matrices to axis/angle.
Args:
matrix: Rotation matrices as tensor of shape (..., 3, 3).
fast: Whether to use the new faster implementation (based on the
Rodrigues formula) instead of the original implementation (which
first converted to a quaternion and then back to a rotation matrix).
Returns:
Rotations given as a vector in axis angle form, as a tensor
of shape (..., 3), where the magnitude is the angle
turned anticlockwise in radians around the vector's
direction.
"""
if not fast:
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
omegas = torch.stack(
[
matrix[..., 2, 1] - matrix[..., 1, 2],
matrix[..., 0, 2] - matrix[..., 2, 0],
matrix[..., 1, 0] - matrix[..., 0, 1],
],
dim=-1,
)
norms = torch.norm(omegas, p=2, dim=-1, keepdim=True)
traces = torch.diagonal(matrix, dim1=-2, dim2=-1).sum(-1).unsqueeze(-1)
angles = torch.atan2(norms, traces - 1)
zeros = torch.zeros(3, dtype=matrix.dtype, device=matrix.device)
omegas = torch.where(torch.isclose(angles, torch.zeros_like(angles)), zeros, omegas)
near_pi = angles.isclose(angles.new_full((1,), torch.pi)).squeeze(-1)
axis_angles = torch.empty_like(omegas)
axis_angles[~near_pi] = (
0.5 * omegas[~near_pi] / torch.sinc(angles[~near_pi] / torch.pi)
)
# this derives from: nnT = (R + 1) / 2
n = 0.5 * (
matrix[near_pi][..., 0, :]
+ torch.eye(1, 3, dtype=matrix.dtype, device=matrix.device)
)
axis_angles[near_pi] = angles[near_pi] * n / torch.norm(n)
return axis_angles
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
@@ -573,10 +509,22 @@ def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
quaternions with real part first, as tensor of shape (..., 4).
"""
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
sin_half_angles_over_angles = 0.5 * torch.sinc(angles * 0.5 / torch.pi)
return torch.cat(
[torch.cos(angles * 0.5), axis_angle * sin_half_angles_over_angles], dim=-1
half_angles = angles * 0.5
eps = 1e-6
small_angles = angles.abs() < eps
sin_half_angles_over_angles = torch.empty_like(angles)
sin_half_angles_over_angles[~small_angles] = (
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
)
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
# so sin(x/2)/x is about 1/2 - (x*x)/48
sin_half_angles_over_angles[small_angles] = (
0.5 - (angles[small_angles] * angles[small_angles]) / 48
)
quaternions = torch.cat(
[torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
)
return quaternions
def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
@@ -595,9 +543,18 @@ def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
"""
norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
half_angles = torch.atan2(norms, quaternions[..., :1])
sin_half_angles_over_angles = 0.5 * torch.sinc(half_angles / torch.pi)
# angles/2 are between [-pi/2, pi/2], thus sin_half_angles_over_angles
# can't be zero
angles = 2 * half_angles
eps = 1e-6
small_angles = angles.abs() < eps
sin_half_angles_over_angles = torch.empty_like(angles)
sin_half_angles_over_angles[~small_angles] = (
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
)
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
# so sin(x/2)/x is about 1/2 - (x*x)/48
sin_half_angles_over_angles[small_angles] = (
0.5 - (angles[small_angles] * angles[small_angles]) / 48
)
return quaternions[..., 1:] / sin_half_angles_over_angles

View File

@@ -311,7 +311,9 @@ def plot_scene(
)
else:
msg = "Invalid number {} of viewpoint cameras were provided. Either 1 \
or {} cameras are required".format(len(viewpoint_cameras), len(subplots))
or {} cameras are required".format(
len(viewpoint_cameras), len(subplots)
)
warnings.warn(msg)
for subplot_idx in range(len(subplots)):
@@ -586,15 +588,9 @@ def _add_struct_from_batch(
if isinstance(batched_struct, CamerasBase):
# we can't index directly into camera batches
R, T = batched_struct.R, batched_struct.T
# pyre-fixme[6]: For 1st argument expected
# `pyre_extensions.PyreReadOnly[Sized]` but got `Union[Tensor, Module]`.
r_idx = min(scene_num, len(R) - 1)
# pyre-fixme[6]: For 1st argument expected
# `pyre_extensions.PyreReadOnly[Sized]` but got `Union[Tensor, Module]`.
t_idx = min(scene_num, len(T) - 1)
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
R = R[r_idx].unsqueeze(0)
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
T = T[t_idx].unsqueeze(0)
struct = CamerasBase(device=batched_struct.device, R=R, T=T)
elif _is_ray_bundle(batched_struct) and not _is_heterogeneous_ray_bundle(

View File

@@ -11,6 +11,7 @@ from tests.test_ball_query import TestBallQuery
def bm_ball_query() -> None:
backends = ["cpu", "cuda:0"]
kwargs_list = []

View File

@@ -11,6 +11,7 @@ from tests.test_cameras_alignment import TestCamerasAlignment
def bm_cameras_alignment() -> None:
case_grid = {
"batch_size": [10, 100, 1000],
"mode": ["centers", "extrinsics"],

View File

@@ -11,6 +11,7 @@ from tests.test_knn import TestKNN
def bm_knn() -> None:
backends = ["cpu", "cuda:0"]
kwargs_list = []

View File

@@ -12,6 +12,7 @@ from tests.test_point_mesh_distance import TestPointMeshDistance
def bm_point_mesh_distance() -> None:
backend = ["cuda:0"]
kwargs_list = []

View File

@@ -12,6 +12,7 @@ from tests.test_points_alignment import TestCorrespondingPointsAlignment, TestIC
def bm_iterative_closest_point() -> None:
case_grid = {
"batch_size": [1, 10],
"dim": [3, 20],
@@ -42,6 +43,7 @@ def bm_iterative_closest_point() -> None:
def bm_corresponding_points_alignment() -> None:
case_grid = {
"allow_reflection": [True, False],
"batch_size": [1, 10, 100],

Some files were not shown because too many files have changed in this diff Show More