mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-02-27 00:36:02 +08:00
Compare commits
1 Commits
bottler/un
...
bottler/ac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9c586b1351 |
@@ -88,6 +88,7 @@ def workflow_pair(
|
||||
upload=False,
|
||||
filter_branch,
|
||||
):
|
||||
|
||||
w = []
|
||||
py = python_version.replace(".", "")
|
||||
pyt = pytorch_version.replace(".", "")
|
||||
@@ -126,6 +127,7 @@ def generate_base_workflow(
|
||||
btype,
|
||||
filter_branch=None,
|
||||
):
|
||||
|
||||
d = {
|
||||
"name": base_workflow_name,
|
||||
"python_version": python_version,
|
||||
|
||||
3
.github/workflows/build.yml
vendored
3
.github/workflows/build.yml
vendored
@@ -3,9 +3,6 @@ on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
jobs:
|
||||
binary_linux_conda_cuda:
|
||||
runs-on: 4-core-ubuntu-gpu-t4
|
||||
|
||||
@@ -36,5 +36,5 @@ then
|
||||
|
||||
echo "Running pyre..."
|
||||
echo "To restart/kill pyre server, run 'pyre restart' or 'pyre kill' in fbcode/"
|
||||
( cd ~/fbsource/fbcode; arc pyre check //vision/fair/pytorch3d/... )
|
||||
( cd ~/fbsource/fbcode; pyre -l vision/fair/pytorch3d/ )
|
||||
fi
|
||||
|
||||
@@ -10,7 +10,6 @@ This example demonstrates the most trivial, direct interface of the pulsar
|
||||
sphere renderer. It renders and saves an image with 10 random spheres.
|
||||
Output: basic.png.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
from os import path
|
||||
|
||||
@@ -11,7 +11,6 @@ interface for sphere renderering. It renders and saves an image with
|
||||
10 random spheres.
|
||||
Output: basic-pt3d.png.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from os import path
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ distorted. Gradient-based optimization is used to converge towards the
|
||||
original camera parameters.
|
||||
Output: cam.gif.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
from os import path
|
||||
|
||||
@@ -14,7 +14,6 @@ distorted. Gradient-based optimization is used to converge towards the
|
||||
original camera parameters.
|
||||
Output: cam-pt3d.gif
|
||||
"""
|
||||
|
||||
import logging
|
||||
from os import path
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ This example is not available yet through the 'unified' interface,
|
||||
because opacity support has not landed in PyTorch3D for general data
|
||||
structures yet.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
from os import path
|
||||
|
||||
@@ -13,7 +13,6 @@ The scene is initialized with random spheres. Gradient-based
|
||||
optimization is used to converge towards a faithful
|
||||
scene representation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ The scene is initialized with random spheres. Gradient-based
|
||||
optimization is used to converge towards a faithful
|
||||
scene representation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ requirements:
|
||||
|
||||
build:
|
||||
string: py{{py}}_{{ environ['CU_VERSION'] }}_pyt{{ environ['PYTORCH_VERSION_NODOT']}}
|
||||
# script: LD_LIBRARY_PATH=$PREFIX/lib:$BUILD_PREFIX/lib:$LD_LIBRARY_PATH python setup.py install --single-version-externally-managed --record=record.txt # [not win]
|
||||
script: python setup.py install --single-version-externally-managed --record=record.txt # [not win]
|
||||
script_env:
|
||||
- CUDA_HOME
|
||||
@@ -56,6 +57,7 @@ test:
|
||||
- pandas
|
||||
- sqlalchemy
|
||||
commands:
|
||||
#pytest .
|
||||
python -m unittest discover -v -s tests -t .
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
|
||||
# pyre-unsafe
|
||||
|
||||
""" "
|
||||
""""
|
||||
This file is the entry point for launching experiments with Implicitron.
|
||||
|
||||
Launch Training
|
||||
@@ -44,7 +44,6 @@ The outputs of the experiment are saved and logged in multiple ways:
|
||||
config file.
|
||||
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
|
||||
@@ -26,6 +26,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelFactoryBase(ReplaceableBase):
|
||||
|
||||
resume: bool = True # resume from the last checkpoint
|
||||
|
||||
def __call__(self, **kwargs) -> ImplicitronModelBase:
|
||||
@@ -115,9 +116,7 @@ class ImplicitronModelFactory(ModelFactoryBase):
|
||||
"cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
|
||||
}
|
||||
model_state_dict = torch.load(
|
||||
model_io.get_model_path(model_path),
|
||||
map_location=map_location,
|
||||
weights_only=True,
|
||||
model_io.get_model_path(model_path), map_location=map_location
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -123,7 +123,6 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
|
||||
"""
|
||||
# Get the parameters to optimize
|
||||
if hasattr(model, "_get_param_groups"): # use the model function
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
p_groups = model._get_param_groups(self.lr, wd=self.weight_decay)
|
||||
else:
|
||||
p_groups = [
|
||||
@@ -242,7 +241,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
|
||||
map_location = {
|
||||
"cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
|
||||
}
|
||||
optimizer_state = torch.load(opt_path, map_location, weights_only=True)
|
||||
optimizer_state = torch.load(opt_path, map_location)
|
||||
else:
|
||||
raise FileNotFoundError(f"Optimizer state {opt_path} does not exist.")
|
||||
return optimizer_state
|
||||
|
||||
@@ -161,6 +161,7 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
|
||||
for epoch in range(start_epoch, self.max_epochs):
|
||||
# automatic new_epoch and plotting of stats at every epoch start
|
||||
with stats:
|
||||
|
||||
# Make sure to re-seed random generators to ensure reproducibility
|
||||
# even after restart.
|
||||
seed_all_random_engines(seed + epoch)
|
||||
@@ -394,7 +395,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
|
||||
):
|
||||
prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
|
||||
if hasattr(model, "visualize"):
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
model.visualize(
|
||||
viz,
|
||||
visdom_env_imgs,
|
||||
|
||||
@@ -53,8 +53,12 @@ class TestExperiment(unittest.TestCase):
|
||||
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_class_type = (
|
||||
"JsonIndexDatasetMapProvider"
|
||||
)
|
||||
dataset_args = cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||
dataloader_args = cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
|
||||
dataset_args = (
|
||||
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||
)
|
||||
dataloader_args = (
|
||||
cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
|
||||
)
|
||||
dataset_args.category = "skateboard"
|
||||
dataset_args.test_restrict_sequence_id = 0
|
||||
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
||||
@@ -90,8 +94,12 @@ class TestExperiment(unittest.TestCase):
|
||||
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_class_type = (
|
||||
"JsonIndexDatasetMapProvider"
|
||||
)
|
||||
dataset_args = cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||
dataloader_args = cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
|
||||
dataset_args = (
|
||||
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||
)
|
||||
dataloader_args = (
|
||||
cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
|
||||
)
|
||||
dataset_args.category = "skateboard"
|
||||
dataset_args.test_restrict_sequence_id = 0
|
||||
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
||||
@@ -103,7 +111,9 @@ class TestExperiment(unittest.TestCase):
|
||||
cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs = 2
|
||||
cfg.training_loop_ImplicitronTrainingLoop_args.store_checkpoints = False
|
||||
cfg.optimizer_factory_ImplicitronOptimizerFactory_args.lr_policy = "Exponential"
|
||||
cfg.optimizer_factory_ImplicitronOptimizerFactory_args.exponential_lr_step_size = 2
|
||||
cfg.optimizer_factory_ImplicitronOptimizerFactory_args.exponential_lr_step_size = (
|
||||
2
|
||||
)
|
||||
|
||||
if DEBUG:
|
||||
experiment.dump_cfg(cfg)
|
||||
|
||||
@@ -81,9 +81,8 @@ class TestOptimizerFactory(unittest.TestCase):
|
||||
|
||||
def test_param_overrides_self_param_group_assignment(self):
|
||||
pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)]
|
||||
na, nb = (
|
||||
Node(params=[pa]),
|
||||
Node(params=[pb], param_groups={"self": "pb_self", "p1": "pb_param"}),
|
||||
na, nb = Node(params=[pa]), Node(
|
||||
params=[pb], param_groups={"self": "pb_self", "p1": "pb_param"}
|
||||
)
|
||||
root = Node(children=[na, nb], params=[pc], param_groups={"m1": "pb_member"})
|
||||
param_groups = self._get_param_groups(root)
|
||||
|
||||
@@ -84,9 +84,9 @@ def get_nerf_datasets(
|
||||
|
||||
if autodownload and any(not os.path.isfile(p) for p in (cameras_path, image_path)):
|
||||
# Automatically download the data files if missing.
|
||||
download_data([dataset_name], data_root=data_root)
|
||||
download_data((dataset_name,), data_root=data_root)
|
||||
|
||||
train_data = torch.load(cameras_path, weights_only=True)
|
||||
train_data = torch.load(cameras_path)
|
||||
n_cameras = train_data["cameras"]["R"].shape[0]
|
||||
|
||||
_image_max_image_pixels = Image.MAX_IMAGE_PIXELS
|
||||
|
||||
@@ -194,6 +194,7 @@ class Stats:
|
||||
it = self.it[stat_set]
|
||||
|
||||
for stat in self.log_vars:
|
||||
|
||||
if stat not in self.stats[stat_set]:
|
||||
self.stats[stat_set][stat] = AverageMeter()
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ CONFIG_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs"
|
||||
|
||||
@hydra.main(config_path=CONFIG_DIR, config_name="lego")
|
||||
def main(cfg: DictConfig):
|
||||
|
||||
# Device on which to run.
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
@@ -62,7 +63,7 @@ def main(cfg: DictConfig):
|
||||
raise ValueError(f"Model checkpoint {checkpoint_path} does not exist!")
|
||||
|
||||
print(f"Loading checkpoint {checkpoint_path}.")
|
||||
loaded_data = torch.load(checkpoint_path, weights_only=True)
|
||||
loaded_data = torch.load(checkpoint_path)
|
||||
# Do not load the cached xy grid.
|
||||
# - this allows setting an arbitrary evaluation image size.
|
||||
state_dict = {
|
||||
|
||||
@@ -42,6 +42,7 @@ class TestRaysampler(unittest.TestCase):
|
||||
cameras, rays = [], []
|
||||
|
||||
for _ in range(batch_size):
|
||||
|
||||
R = random_rotations(1)
|
||||
T = torch.randn(1, 3)
|
||||
focal_length = torch.rand(1, 2) + 0.5
|
||||
|
||||
@@ -25,6 +25,7 @@ CONFIG_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs"
|
||||
|
||||
@hydra.main(config_path=CONFIG_DIR, config_name="lego")
|
||||
def main(cfg: DictConfig):
|
||||
|
||||
# Set the relevant seeds for reproducibility.
|
||||
np.random.seed(cfg.seed)
|
||||
torch.manual_seed(cfg.seed)
|
||||
@@ -76,7 +77,7 @@ def main(cfg: DictConfig):
|
||||
# Resume training if requested.
|
||||
if cfg.resume and os.path.isfile(checkpoint_path):
|
||||
print(f"Resuming from checkpoint {checkpoint_path}.")
|
||||
loaded_data = torch.load(checkpoint_path, weights_only=True)
|
||||
loaded_data = torch.load(checkpoint_path)
|
||||
model.load_state_dict(loaded_data["model"])
|
||||
stats = pickle.loads(loaded_data["stats"])
|
||||
print(f" => resuming from epoch {stats.epoch}.")
|
||||
@@ -218,6 +219,7 @@ def main(cfg: DictConfig):
|
||||
|
||||
# Validation
|
||||
if epoch % cfg.validation_epoch_interval == 0 and epoch > 0:
|
||||
|
||||
# Sample a validation camera/image.
|
||||
val_batch = next(val_dataloader.__iter__())
|
||||
val_image, val_camera, camera_idx = val_batch[0].values()
|
||||
|
||||
@@ -17,7 +17,7 @@ Some functions which depend on PyTorch or Python versions.
|
||||
|
||||
|
||||
def meshgrid_ij(
|
||||
*A: Union[torch.Tensor, Sequence[torch.Tensor]],
|
||||
*A: Union[torch.Tensor, Sequence[torch.Tensor]]
|
||||
) -> Tuple[torch.Tensor, ...]: # pragma: no cover
|
||||
"""
|
||||
Like torch.meshgrid was before PyTorch 1.10.0, i.e. with indexing set to ij
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
*/
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <queue>
|
||||
#include <tuple>
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
|
||||
|
||||
@@ -28,6 +28,7 @@ __global__ void alphaCompositeCudaForwardKernel(
|
||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||
// clang-format on
|
||||
const int64_t batch_size = result.size(0);
|
||||
const int64_t C = features.size(0);
|
||||
const int64_t H = points_idx.size(2);
|
||||
const int64_t W = points_idx.size(3);
|
||||
@@ -78,6 +79,7 @@ __global__ void alphaCompositeCudaBackwardKernel(
|
||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||
// clang-format on
|
||||
const int64_t batch_size = points_idx.size(0);
|
||||
const int64_t C = features.size(0);
|
||||
const int64_t H = points_idx.size(2);
|
||||
const int64_t W = points_idx.size(3);
|
||||
|
||||
@@ -28,6 +28,7 @@ __global__ void weightedSumNormCudaForwardKernel(
|
||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||
// clang-format on
|
||||
const int64_t batch_size = result.size(0);
|
||||
const int64_t C = features.size(0);
|
||||
const int64_t H = points_idx.size(2);
|
||||
const int64_t W = points_idx.size(3);
|
||||
@@ -91,6 +92,7 @@ __global__ void weightedSumNormCudaBackwardKernel(
|
||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||
// clang-format on
|
||||
const int64_t batch_size = points_idx.size(0);
|
||||
const int64_t C = features.size(0);
|
||||
const int64_t H = points_idx.size(2);
|
||||
const int64_t W = points_idx.size(3);
|
||||
|
||||
@@ -26,6 +26,7 @@ __global__ void weightedSumCudaForwardKernel(
|
||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||
// clang-format on
|
||||
const int64_t batch_size = result.size(0);
|
||||
const int64_t C = features.size(0);
|
||||
const int64_t H = points_idx.size(2);
|
||||
const int64_t W = points_idx.size(3);
|
||||
@@ -73,6 +74,7 @@ __global__ void weightedSumCudaBackwardKernel(
|
||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||
// clang-format on
|
||||
const int64_t batch_size = points_idx.size(0);
|
||||
const int64_t C = features.size(0);
|
||||
const int64_t H = points_idx.size(2);
|
||||
const int64_t W = points_idx.size(3);
|
||||
|
||||
@@ -149,10 +149,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
py::arg("gamma"),
|
||||
py::arg("max_depth"),
|
||||
py::arg("min_depth") /* = 0.f*/,
|
||||
py::arg("bg_col") /* = std::nullopt not exposed properly in
|
||||
pytorch 1.1. */
|
||||
py::arg(
|
||||
"bg_col") /* = at::nullopt not exposed properly in pytorch 1.1. */
|
||||
,
|
||||
py::arg("opacity") /* = std::nullopt ... */,
|
||||
py::arg("opacity") /* = at::nullopt ... */,
|
||||
py::arg("percent_allowed_difference") = 0.01f,
|
||||
py::arg("max_n_hits") = MAX_UINT,
|
||||
py::arg("mode") = 0)
|
||||
|
||||
@@ -7,7 +7,10 @@
|
||||
*/
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <torch/torch.h>
|
||||
#include <list>
|
||||
#include <numeric>
|
||||
#include <queue>
|
||||
#include <tuple>
|
||||
#include "iou_box3d/iou_utils.h"
|
||||
|
||||
|
||||
@@ -461,8 +461,10 @@ __device__ inline std::tuple<float3, float3> ArgMaxVerts(
|
||||
__device__ inline bool IsCoplanarTriTri(
|
||||
const FaceVerts& tri1,
|
||||
const FaceVerts& tri2) {
|
||||
const float3 tri1_ctr = FaceCenter({tri1.v0, tri1.v1, tri1.v2});
|
||||
const float3 tri1_n = FaceNormal({tri1.v0, tri1.v1, tri1.v2});
|
||||
|
||||
const float3 tri2_ctr = FaceCenter({tri2.v0, tri2.v1, tri2.v2});
|
||||
const float3 tri2_n = FaceNormal({tri2.v0, tri2.v1, tri2.v2});
|
||||
|
||||
// Check if parallel
|
||||
@@ -498,6 +500,7 @@ __device__ inline bool IsCoplanarTriPlane(
|
||||
const FaceVerts& tri,
|
||||
const FaceVerts& plane,
|
||||
const float3& normal) {
|
||||
const float3 tri_ctr = FaceCenter({tri.v0, tri.v1, tri.v2});
|
||||
const float3 nt = FaceNormal({tri.v0, tri.v1, tri.v2});
|
||||
|
||||
// check if parallel
|
||||
@@ -725,7 +728,7 @@ __device__ inline int BoxIntersections(
|
||||
}
|
||||
}
|
||||
// Update the face_verts_out tris
|
||||
num_tris = min(MAX_TRIS, offset);
|
||||
num_tris = offset;
|
||||
for (int j = 0; j < num_tris; ++j) {
|
||||
face_verts_out[j] = tri_verts_updated[j];
|
||||
}
|
||||
|
||||
@@ -8,7 +8,9 @@
|
||||
|
||||
#include <torch/csrc/autograd/VariableTypeUtils.h>
|
||||
#include <torch/extension.h>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
// In the x direction, the location {0, ..., grid_size_x - 1} correspond to
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAException.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#endif
|
||||
#include <torch/extension.h>
|
||||
@@ -34,13 +33,13 @@ torch::Tensor sphere_ids_from_result_info_nograd(
|
||||
.contiguous();
|
||||
if (forw_info.device().type() == c10::DeviceType::CUDA) {
|
||||
#ifdef WITH_CUDA
|
||||
C10_CUDA_CHECK(cudaMemcpyAsync(
|
||||
cudaMemcpyAsync(
|
||||
result.data_ptr(),
|
||||
tmp.data_ptr(),
|
||||
sizeof(uint32_t) * tmp.size(0) * tmp.size(1) * tmp.size(2) *
|
||||
tmp.size(3),
|
||||
cudaMemcpyDeviceToDevice,
|
||||
at::cuda::getCurrentCUDAStream()));
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
#else
|
||||
throw std::runtime_error(
|
||||
"Copy on CUDA device initiated but built "
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
*/
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#include <c10/cuda/CUDAException.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
namespace pulsar {
|
||||
@@ -18,8 +17,7 @@ void cudaDevToDev(
|
||||
const void* src,
|
||||
const int& size,
|
||||
const cudaStream_t& stream) {
|
||||
C10_CUDA_CHECK(
|
||||
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToDevice, stream));
|
||||
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToDevice, stream);
|
||||
}
|
||||
|
||||
void cudaDevToHost(
|
||||
@@ -27,8 +25,7 @@ void cudaDevToHost(
|
||||
const void* src,
|
||||
const int& size,
|
||||
const cudaStream_t& stream) {
|
||||
C10_CUDA_CHECK(
|
||||
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToHost, stream));
|
||||
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToHost, stream);
|
||||
}
|
||||
|
||||
} // namespace pytorch
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <torch/extension.h>
|
||||
#include <algorithm>
|
||||
#include <list>
|
||||
#include <queue>
|
||||
#include <thread>
|
||||
#include <tuple>
|
||||
#include "ATen/core/TensorAccessor.h"
|
||||
|
||||
@@ -35,6 +35,8 @@ __global__ void FarthestPointSamplingKernel(
|
||||
__shared__ int64_t selected_store;
|
||||
|
||||
// Get constants
|
||||
const int64_t N = points.size(0);
|
||||
const int64_t P = points.size(1);
|
||||
const int64_t D = points.size(2);
|
||||
|
||||
// Get batch index and thread index
|
||||
|
||||
@@ -376,6 +376,8 @@ PointLineDistanceBackward(
|
||||
float tt = t_top / t_bot;
|
||||
tt = __saturatef(tt);
|
||||
const float2 p_proj = (1.0f - tt) * v0 + tt * v1;
|
||||
const float2 d = p - p_proj;
|
||||
const float dist = sqrt(dot(d, d));
|
||||
|
||||
const float2 grad_p = -1.0f * grad_dist * 2.0f * (p_proj - p);
|
||||
const float2 grad_v0 = grad_dist * (1.0f - tt) * 2.0f * (p_proj - p);
|
||||
|
||||
@@ -83,7 +83,7 @@ class ShapeNetCore(ShapeNetBase): # pragma: no cover
|
||||
):
|
||||
synset_set.add(synset)
|
||||
elif (synset in self.synset_inv.keys()) and (
|
||||
path.isdir(path.join(data_dir, self.synset_inv[synset]))
|
||||
(path.isdir(path.join(data_dir, self.synset_inv[synset])))
|
||||
):
|
||||
synset_set.add(self.synset_inv[synset])
|
||||
else:
|
||||
|
||||
@@ -36,6 +36,7 @@ def collate_batched_meshes(batch: List[Dict]): # pragma: no cover
|
||||
|
||||
collated_dict["mesh"] = None
|
||||
if {"verts", "faces"}.issubset(collated_dict.keys()):
|
||||
|
||||
textures = None
|
||||
if "textures" in collated_dict:
|
||||
textures = TexturesAtlas(atlas=collated_dict["textures"])
|
||||
|
||||
@@ -26,7 +26,7 @@ from typing import (
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from pytorch3d.implicitron.dataset import orm_types, types
|
||||
from pytorch3d.implicitron.dataset import types
|
||||
from pytorch3d.implicitron.dataset.utils import (
|
||||
adjust_camera_to_bbox_crop_,
|
||||
adjust_camera_to_image_scale_,
|
||||
@@ -48,12 +48,8 @@ from pytorch3d.implicitron.dataset.utils import (
|
||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
||||
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
|
||||
from pytorch3d.structures.meshes import join_meshes_as_batch, Meshes
|
||||
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
|
||||
|
||||
FrameAnnotationT = types.FrameAnnotation | orm_types.SqlFrameAnnotation
|
||||
SequenceAnnotationT = types.SequenceAnnotation | orm_types.SqlSequenceAnnotation
|
||||
|
||||
|
||||
@dataclass
|
||||
class FrameData(Mapping[str, Any]):
|
||||
@@ -126,9 +122,9 @@ class FrameData(Mapping[str, Any]):
|
||||
meta: A dict for storing additional frame information.
|
||||
"""
|
||||
|
||||
frame_number: Optional[torch.LongTensor] = None
|
||||
sequence_name: Union[str, List[str]] = ""
|
||||
sequence_category: Union[str, List[str]] = ""
|
||||
frame_number: Optional[torch.LongTensor]
|
||||
sequence_name: Union[str, List[str]]
|
||||
sequence_category: Union[str, List[str]]
|
||||
frame_timestamp: Optional[torch.Tensor] = None
|
||||
image_size_hw: Optional[torch.LongTensor] = None
|
||||
effective_image_size_hw: Optional[torch.LongTensor] = None
|
||||
@@ -159,7 +155,7 @@ class FrameData(Mapping[str, Any]):
|
||||
new_params = {}
|
||||
for field_name in iter(self):
|
||||
value = getattr(self, field_name)
|
||||
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase, Meshes)):
|
||||
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
|
||||
new_params[field_name] = value.to(*args, **kwargs)
|
||||
else:
|
||||
new_params[field_name] = value
|
||||
@@ -421,6 +417,7 @@ class FrameData(Mapping[str, Any]):
|
||||
for f in fields(elem):
|
||||
if not f.init:
|
||||
continue
|
||||
|
||||
list_values = override_fields.get(
|
||||
f.name, [getattr(d, f.name) for d in batch]
|
||||
)
|
||||
@@ -429,7 +426,7 @@ class FrameData(Mapping[str, Any]):
|
||||
if all(list_value is not None for list_value in list_values)
|
||||
else None
|
||||
)
|
||||
return type(elem)(**collated)
|
||||
return cls(**collated)
|
||||
|
||||
elif isinstance(elem, Pointclouds):
|
||||
return join_pointclouds_as_batch(batch)
|
||||
@@ -437,8 +434,6 @@ class FrameData(Mapping[str, Any]):
|
||||
elif isinstance(elem, CamerasBase):
|
||||
# TODO: don't store K; enforce working in NDC space
|
||||
return join_cameras_as_batch(batch)
|
||||
elif isinstance(elem, Meshes):
|
||||
return join_meshes_as_batch(batch)
|
||||
else:
|
||||
return torch.utils.data.dataloader.default_collate(batch)
|
||||
|
||||
@@ -459,8 +454,8 @@ class FrameDataBuilderBase(ReplaceableBase, Generic[FrameDataSubtype], ABC):
|
||||
@abstractmethod
|
||||
def build(
|
||||
self,
|
||||
frame_annotation: FrameAnnotationT,
|
||||
sequence_annotation: SequenceAnnotationT,
|
||||
frame_annotation: types.FrameAnnotation,
|
||||
sequence_annotation: types.SequenceAnnotation,
|
||||
*,
|
||||
load_blobs: bool = True,
|
||||
**kwargs,
|
||||
@@ -546,8 +541,8 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
|
||||
def build(
|
||||
self,
|
||||
frame_annotation: FrameAnnotationT,
|
||||
sequence_annotation: SequenceAnnotationT,
|
||||
frame_annotation: types.FrameAnnotation,
|
||||
sequence_annotation: types.SequenceAnnotation,
|
||||
*,
|
||||
load_blobs: bool = True,
|
||||
**kwargs,
|
||||
@@ -591,81 +586,58 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
),
|
||||
)
|
||||
|
||||
dataset_root = self.dataset_root
|
||||
fg_mask_np: Optional[np.ndarray] = None
|
||||
mask_annotation = frame_annotation.mask
|
||||
depth_annotation = frame_annotation.depth
|
||||
image_path: str | None = None
|
||||
mask_path: str | None = None
|
||||
depth_path: str | None = None
|
||||
pcl_path: str | None = None
|
||||
if dataset_root is not None: # set all paths even if we won’t load blobs
|
||||
if frame_annotation.image.path is not None:
|
||||
image_path = os.path.join(dataset_root, frame_annotation.image.path)
|
||||
frame_data.image_path = image_path
|
||||
|
||||
if mask_annotation is not None and mask_annotation.path:
|
||||
mask_path = os.path.join(dataset_root, mask_annotation.path)
|
||||
frame_data.mask_path = mask_path
|
||||
|
||||
if depth_annotation is not None and depth_annotation.path is not None:
|
||||
depth_path = os.path.join(dataset_root, depth_annotation.path)
|
||||
frame_data.depth_path = depth_path
|
||||
|
||||
if point_cloud is not None:
|
||||
pcl_path = os.path.join(dataset_root, point_cloud.path)
|
||||
frame_data.sequence_point_cloud_path = pcl_path
|
||||
|
||||
fg_mask_np: np.ndarray | None = None
|
||||
bbox_xywh: tuple[float, float, float, float] | None = None
|
||||
|
||||
if mask_annotation is not None:
|
||||
if load_blobs and self.load_masks and mask_path:
|
||||
fg_mask_np = self._load_fg_probability(frame_annotation, mask_path)
|
||||
if load_blobs and self.load_masks:
|
||||
fg_mask_np, mask_path = self._load_fg_probability(frame_annotation)
|
||||
frame_data.mask_path = mask_path
|
||||
frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)
|
||||
|
||||
bbox_xywh = mask_annotation.bounding_box_xywh
|
||||
if bbox_xywh is None and fg_mask_np is not None:
|
||||
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
|
||||
|
||||
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)
|
||||
|
||||
if frame_annotation.image is not None:
|
||||
image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
|
||||
frame_data.image_size_hw = image_size_hw # original image size
|
||||
# image size after crop/resize
|
||||
frame_data.effective_image_size_hw = image_size_hw
|
||||
image_path = None
|
||||
dataset_root = self.dataset_root
|
||||
if frame_annotation.image.path is not None and dataset_root is not None:
|
||||
image_path = os.path.join(dataset_root, frame_annotation.image.path)
|
||||
frame_data.image_path = image_path
|
||||
|
||||
if load_blobs and self.load_images:
|
||||
if image_path is None:
|
||||
raise ValueError("Image path is required to load images.")
|
||||
|
||||
no_mask = fg_mask_np is None # didn’t read the mask file
|
||||
image_np = load_image(
|
||||
self._local_path(image_path), try_read_alpha=no_mask
|
||||
)
|
||||
if image_np.shape[0] == 4: # RGBA image
|
||||
if no_mask:
|
||||
fg_mask_np = image_np[3:]
|
||||
frame_data.fg_probability = safe_as_tensor(
|
||||
fg_mask_np, torch.float
|
||||
)
|
||||
|
||||
image_np = image_np[:3]
|
||||
|
||||
image_np = load_image(self._local_path(image_path))
|
||||
frame_data.image_rgb = self._postprocess_image(
|
||||
image_np, frame_annotation.image.size, frame_data.fg_probability
|
||||
)
|
||||
|
||||
if bbox_xywh is None and fg_mask_np is not None:
|
||||
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
|
||||
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)
|
||||
|
||||
if load_blobs and self.load_depths and depth_path is not None:
|
||||
frame_data.depth_map, frame_data.depth_mask = self._load_mask_depth(
|
||||
frame_annotation, depth_path, fg_mask_np
|
||||
)
|
||||
if (
|
||||
load_blobs
|
||||
and self.load_depths
|
||||
and frame_annotation.depth is not None
|
||||
and frame_annotation.depth.path is not None
|
||||
):
|
||||
(
|
||||
frame_data.depth_map,
|
||||
frame_data.depth_path,
|
||||
frame_data.depth_mask,
|
||||
) = self._load_mask_depth(frame_annotation, fg_mask_np)
|
||||
|
||||
if load_blobs and self.load_point_clouds and point_cloud is not None:
|
||||
assert pcl_path is not None
|
||||
pcl_path = self._fix_point_cloud_path(point_cloud.path)
|
||||
frame_data.sequence_point_cloud = load_pointcloud(
|
||||
self._local_path(pcl_path), max_points=self.max_points
|
||||
)
|
||||
frame_data.sequence_point_cloud_path = pcl_path
|
||||
|
||||
if frame_annotation.viewpoint is not None:
|
||||
frame_data.camera = self._get_pytorch3d_camera(frame_annotation)
|
||||
@@ -681,14 +653,18 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
|
||||
return frame_data
|
||||
|
||||
def _load_fg_probability(self, entry: FrameAnnotationT, path: str) -> np.ndarray:
|
||||
fg_probability = load_mask(self._local_path(path))
|
||||
def _load_fg_probability(
|
||||
self, entry: types.FrameAnnotation
|
||||
) -> Tuple[np.ndarray, str]:
|
||||
assert self.dataset_root is not None and entry.mask is not None
|
||||
full_path = os.path.join(self.dataset_root, entry.mask.path)
|
||||
fg_probability = load_mask(self._local_path(full_path))
|
||||
if fg_probability.shape[-2:] != entry.image.size:
|
||||
raise ValueError(
|
||||
f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!"
|
||||
)
|
||||
|
||||
return fg_probability
|
||||
return fg_probability, full_path
|
||||
|
||||
def _postprocess_image(
|
||||
self,
|
||||
@@ -709,14 +685,14 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
|
||||
def _load_mask_depth(
|
||||
self,
|
||||
entry: FrameAnnotationT,
|
||||
path: str,
|
||||
entry: types.FrameAnnotation,
|
||||
fg_mask: Optional[np.ndarray],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> Tuple[torch.Tensor, str, torch.Tensor]:
|
||||
entry_depth = entry.depth
|
||||
dataset_root = self.dataset_root
|
||||
assert dataset_root is not None
|
||||
assert entry_depth is not None
|
||||
assert entry_depth is not None and entry_depth.path is not None
|
||||
path = os.path.join(dataset_root, entry_depth.path)
|
||||
depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)
|
||||
|
||||
if self.mask_depths:
|
||||
@@ -730,11 +706,11 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
else:
|
||||
depth_mask = (depth_map > 0.0).astype(np.float32)
|
||||
|
||||
return torch.tensor(depth_map), torch.tensor(depth_mask)
|
||||
return torch.tensor(depth_map), path, torch.tensor(depth_mask)
|
||||
|
||||
def _get_pytorch3d_camera(
|
||||
self,
|
||||
entry: FrameAnnotationT,
|
||||
entry: types.FrameAnnotation,
|
||||
) -> PerspectiveCameras:
|
||||
entry_viewpoint = entry.viewpoint
|
||||
assert entry_viewpoint is not None
|
||||
@@ -763,6 +739,19 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
|
||||
)
|
||||
|
||||
def _fix_point_cloud_path(self, path: str) -> str:
|
||||
"""
|
||||
Fix up a point cloud path from the dataset.
|
||||
Some files in Co3Dv2 have an accidental absolute path stored.
|
||||
"""
|
||||
unwanted_prefix = (
|
||||
"/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/"
|
||||
)
|
||||
if path.startswith(unwanted_prefix):
|
||||
path = path[len(unwanted_prefix) :]
|
||||
assert self.dataset_root is not None
|
||||
return os.path.join(self.dataset_root, path)
|
||||
|
||||
def _local_path(self, path: str) -> str:
|
||||
if self.path_manager is None:
|
||||
return path
|
||||
|
||||
@@ -222,6 +222,7 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):
|
||||
self.dataset_map = dataset_map
|
||||
|
||||
def _load_category(self, category: str) -> DatasetMap:
|
||||
|
||||
frame_file = os.path.join(self.dataset_root, category, "frame_annotations.jgz")
|
||||
sequence_file = os.path.join(
|
||||
self.dataset_root, category, "sequence_annotations.jgz"
|
||||
|
||||
@@ -75,6 +75,7 @@ def _minify(basedir, path_manager, factors=(), resolutions=()):
|
||||
def _load_data(
|
||||
basedir, factor=None, width=None, height=None, load_imgs=True, path_manager=None
|
||||
):
|
||||
|
||||
poses_arr = np.load(
|
||||
_local_path(path_manager, os.path.join(basedir, "poses_bounds.npy"))
|
||||
)
|
||||
@@ -163,6 +164,7 @@ def ptstocam(pts, c2w):
|
||||
|
||||
|
||||
def poses_avg(poses):
|
||||
|
||||
hwf = poses[0, :3, -1:]
|
||||
|
||||
center = poses[:, :3, 3].mean(0)
|
||||
@@ -190,6 +192,7 @@ def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
|
||||
|
||||
|
||||
def recenter_poses(poses):
|
||||
|
||||
poses_ = poses + 0
|
||||
bottom = np.reshape([0, 0, 0, 1.0], [1, 4])
|
||||
c2w = poses_avg(poses)
|
||||
@@ -253,6 +256,7 @@ def spherify_poses(poses, bds):
|
||||
new_poses = []
|
||||
|
||||
for th in np.linspace(0.0, 2.0 * np.pi, 120):
|
||||
|
||||
camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
|
||||
up = np.array([0, 0, -1.0])
|
||||
|
||||
@@ -307,6 +311,7 @@ def load_llff_data(
|
||||
path_zflat=False,
|
||||
path_manager=None,
|
||||
):
|
||||
|
||||
poses, bds, imgs = _load_data(
|
||||
basedir, factor=factor, path_manager=path_manager
|
||||
) # factor=8 downsamples original imgs by 8x
|
||||
|
||||
@@ -4,8 +4,6 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# pyre-unsafe
|
||||
|
||||
# This functionality requires SQLAlchemy 2.0 or later.
|
||||
|
||||
import math
|
||||
|
||||
@@ -4,15 +4,11 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# pyre-unsafe
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
import urllib
|
||||
from dataclasses import dataclass, Field, field
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
ClassVar,
|
||||
@@ -33,18 +29,17 @@ import sqlalchemy as sa
|
||||
import torch
|
||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
|
||||
|
||||
from pytorch3d.implicitron.dataset.frame_data import (
|
||||
from pytorch3d.implicitron.dataset.frame_data import ( # noqa
|
||||
FrameData,
|
||||
FrameDataBuilder, # noqa
|
||||
FrameDataBuilder,
|
||||
FrameDataBuilderBase,
|
||||
)
|
||||
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
registry,
|
||||
ReplaceableBase,
|
||||
run_auto_creation,
|
||||
)
|
||||
from sqlalchemy.orm import scoped_session, Session, sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .orm_types import SqlFrameAnnotation, SqlSequenceAnnotation
|
||||
|
||||
@@ -56,7 +51,7 @@ _SET_LISTS_TABLE: str = "set_lists"
|
||||
|
||||
|
||||
@registry.register
|
||||
class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
"""
|
||||
A dataset with annotations stored as SQLite tables. This is an index-based dataset.
|
||||
The length is returned after all sequence and frame filters are applied (see param
|
||||
@@ -93,7 +88,6 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
engine verbatim. Don’t expose it to end users of your application!
|
||||
pick_categories: Restrict the dataset to the given list of categories.
|
||||
pick_sequences: A Sequence of sequence names to restrict the dataset to.
|
||||
pick_sequences_sql_clause: Custom SQL WHERE clause to constrain sequence annotations.
|
||||
exclude_sequences: A Sequence of the names of the sequences to exclude.
|
||||
limit_sequences_per_category_to: Limit the dataset to the first up to N
|
||||
sequences within each category (applies after all other sequence filters
|
||||
@@ -108,16 +102,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
more frames than that; applied after other frame-level filters.
|
||||
seed: The seed of the random generator sampling `n_frames_per_sequence`
|
||||
random frames per sequence.
|
||||
preload_metadata: If True, the metadata is preloaded into memory.
|
||||
precompute_seq_to_idx: If True, precomputes the mapping from sequence name to indices.
|
||||
scoped_session: If True, allows different parts of the code to share
|
||||
a global session to access the database.
|
||||
"""
|
||||
|
||||
frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation
|
||||
sequence_annotations_type: ClassVar[Type[SqlSequenceAnnotation]] = (
|
||||
SqlSequenceAnnotation
|
||||
)
|
||||
|
||||
sqlite_metadata_file: str = ""
|
||||
dataset_root: Optional[str] = None
|
||||
@@ -130,7 +117,6 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
pick_categories: Tuple[str, ...] = ()
|
||||
|
||||
pick_sequences: Tuple[str, ...] = ()
|
||||
pick_sequences_sql_clause: Optional[str] = None
|
||||
exclude_sequences: Tuple[str, ...] = ()
|
||||
limit_sequences_per_category_to: int = 0
|
||||
limit_sequences_to: int = 0
|
||||
@@ -138,22 +124,12 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
n_frames_per_sequence: int = -1
|
||||
seed: int = 0
|
||||
remove_empty_masks_poll_whole_table_threshold: int = 300_000
|
||||
preload_metadata: bool = False
|
||||
precompute_seq_to_idx: bool = False
|
||||
# we set it manually in the constructor
|
||||
_index: pd.DataFrame = field(init=False, metadata={"omegaconf_ignore": True})
|
||||
_sql_engine: sa.engine.Engine = field(
|
||||
init=False, metadata={"omegaconf_ignore": True}
|
||||
)
|
||||
eval_batches: Optional[List[Any]] = field(
|
||||
init=False, metadata={"omegaconf_ignore": True}
|
||||
)
|
||||
# _index: pd.DataFrame = field(init=False)
|
||||
|
||||
frame_data_builder: FrameDataBuilderBase # pyre-ignore[13]
|
||||
frame_data_builder: FrameDataBuilderBase
|
||||
frame_data_builder_class_type: str = "FrameDataBuilder"
|
||||
|
||||
scoped_session: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if sa.__version__ < "2.0":
|
||||
raise ImportError("This class requires SQL Alchemy 2.0 or later")
|
||||
@@ -162,28 +138,19 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
raise ValueError("sqlite_metadata_file must be set")
|
||||
|
||||
if self.dataset_root:
|
||||
frame_args = f"frame_data_builder_{self.frame_data_builder_class_type}_args"
|
||||
getattr(self, frame_args)["dataset_root"] = self.dataset_root
|
||||
getattr(self, frame_args)["path_manager"] = self.path_manager
|
||||
frame_builder_type = self.frame_data_builder_class_type
|
||||
getattr(self, f"frame_data_builder_{frame_builder_type}_args")[
|
||||
"dataset_root"
|
||||
] = self.dataset_root
|
||||
|
||||
run_auto_creation(self)
|
||||
self.frame_data_builder.path_manager = self.path_manager
|
||||
|
||||
if self.path_manager is not None:
|
||||
self.sqlite_metadata_file = self.path_manager.get_local_path(
|
||||
self.sqlite_metadata_file
|
||||
)
|
||||
self.subset_lists_file = self.path_manager.get_local_path(
|
||||
self.subset_lists_file
|
||||
)
|
||||
|
||||
# NOTE: sqlite-specific args (read-only mode).
|
||||
# pyre-ignore # NOTE: sqlite-specific args (read-only mode).
|
||||
self._sql_engine = sa.create_engine(
|
||||
f"sqlite:///file:{urllib.parse.quote(self.sqlite_metadata_file)}?mode=ro&uri=true"
|
||||
f"sqlite:///file:{self.sqlite_metadata_file}?mode=ro&uri=true"
|
||||
)
|
||||
|
||||
if self.preload_metadata:
|
||||
self._sql_engine = self._preload_database(self._sql_engine)
|
||||
|
||||
sequences = self._get_filtered_sequences_if_any()
|
||||
|
||||
if self.subsets:
|
||||
@@ -199,29 +166,16 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
if len(index) == 0:
|
||||
raise ValueError(f"There are no frames in the subsets: {self.subsets}!")
|
||||
|
||||
self._index = index.set_index(["sequence_name", "frame_number"])
|
||||
self._index = index.set_index(["sequence_name", "frame_number"]) # pyre-ignore
|
||||
|
||||
self.eval_batches = None
|
||||
self.eval_batches = None # pyre-ignore
|
||||
if self.eval_batches_file:
|
||||
self.eval_batches = self._load_filter_eval_batches()
|
||||
|
||||
logger.info(str(self))
|
||||
|
||||
if self.scoped_session:
|
||||
self._session_factory = sessionmaker(bind=self._sql_engine) # pyre-ignore
|
||||
|
||||
if self.precompute_seq_to_idx:
|
||||
# This is deprecated and will be removed in the future.
|
||||
# After we backport https://github.com/facebookresearch/uco3d/pull/3
|
||||
logger.warning(
|
||||
"Using precompute_seq_to_idx is deprecated and will be removed in the future."
|
||||
)
|
||||
self._index["rowid"] = np.arange(len(self._index))
|
||||
groupby = self._index.groupby("sequence_name", sort=False)["rowid"]
|
||||
self._seq_to_indices = dict(groupby.apply(list)) # pyre-ignore
|
||||
del self._index["rowid"]
|
||||
|
||||
def __len__(self) -> int:
|
||||
# pyre-ignore[16]
|
||||
return len(self._index)
|
||||
|
||||
def __getitem__(self, frame_idx: Union[int, Tuple[str, int]]) -> FrameData:
|
||||
@@ -278,18 +232,12 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
self.frame_annotations_type.frame_number
|
||||
== int(frame), # cast from np.int64
|
||||
)
|
||||
seq_stmt = sa.select(self.sequence_annotations_type).where(
|
||||
self.sequence_annotations_type.sequence_name == seq
|
||||
seq_stmt = sa.select(SqlSequenceAnnotation).where(
|
||||
SqlSequenceAnnotation.sequence_name == seq
|
||||
)
|
||||
if self.scoped_session:
|
||||
# pyre-ignore
|
||||
with scoped_session(self._session_factory)() as session:
|
||||
entry = session.scalars(stmt).one()
|
||||
seq_metadata = session.scalars(seq_stmt).one()
|
||||
else:
|
||||
with Session(self._sql_engine) as session:
|
||||
entry = session.scalars(stmt).one()
|
||||
seq_metadata = session.scalars(seq_stmt).one()
|
||||
with Session(self._sql_engine) as session:
|
||||
entry = session.scalars(stmt).one()
|
||||
seq_metadata = session.scalars(seq_stmt).one()
|
||||
|
||||
assert entry.image.path == self._index.loc[(seq, frame), "_image_path"]
|
||||
|
||||
@@ -302,6 +250,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
return frame_data
|
||||
|
||||
def __str__(self) -> str:
|
||||
# pyre-ignore[16]
|
||||
return f"SqlIndexDataset #frames={len(self._index)}"
|
||||
|
||||
def sequence_names(self) -> Iterable[str]:
|
||||
@@ -311,10 +260,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
# override
|
||||
def category_to_sequence_names(self) -> Dict[str, List[str]]:
|
||||
stmt = sa.select(
|
||||
self.sequence_annotations_type.category,
|
||||
self.sequence_annotations_type.sequence_name,
|
||||
SqlSequenceAnnotation.category, SqlSequenceAnnotation.sequence_name
|
||||
).where( # we limit results to sequences that have frames after all filters
|
||||
self.sequence_annotations_type.sequence_name.in_(self.sequence_names())
|
||||
SqlSequenceAnnotation.sequence_name.in_(self.sequence_names())
|
||||
)
|
||||
with self._sql_engine.connect() as connection:
|
||||
cat_to_seqs = pd.read_sql(stmt, connection)
|
||||
@@ -387,31 +335,17 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
rows = self._index.index.get_loc(seq_name)
|
||||
if isinstance(rows, slice):
|
||||
assert rows.stop is not None, "Unexpected result from pandas"
|
||||
rows_seq = range(rows.start or 0, rows.stop, rows.step or 1)
|
||||
rows = range(rows.start or 0, rows.stop, rows.step or 1)
|
||||
else:
|
||||
rows_seq = list(np.where(rows)[0])
|
||||
rows = np.where(rows)[0]
|
||||
|
||||
index_slice, idx = self._get_frame_no_coalesced_ts_by_row_indices(
|
||||
rows_seq, seq_name, subset_filter
|
||||
rows, seq_name, subset_filter
|
||||
)
|
||||
index_slice["idx"] = idx
|
||||
|
||||
yield from index_slice.itertuples(index=False)
|
||||
|
||||
# override
|
||||
def sequence_indices_in_order(
|
||||
self, seq_name: str, subset_filter: Optional[Sequence[str]] = None
|
||||
) -> Iterator[int]:
|
||||
"""Same as `sequence_frames_in_order` but returns the iterator over
|
||||
only dataset indices.
|
||||
"""
|
||||
if self.precompute_seq_to_idx and subset_filter is None:
|
||||
# pyre-ignore
|
||||
yield from self._seq_to_indices[seq_name]
|
||||
else:
|
||||
for _, _, idx in self.sequence_frames_in_order(seq_name, subset_filter):
|
||||
yield idx
|
||||
|
||||
# override
|
||||
def get_eval_batches(self) -> Optional[List[Any]]:
|
||||
"""
|
||||
@@ -445,35 +379,11 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
or self.limit_sequences_to > 0
|
||||
or self.limit_sequences_per_category_to > 0
|
||||
or len(self.pick_sequences) > 0
|
||||
or self.pick_sequences_sql_clause is not None
|
||||
or len(self.exclude_sequences) > 0
|
||||
or len(self.pick_categories) > 0
|
||||
or self.n_frames_per_sequence > 0
|
||||
)
|
||||
|
||||
def _preload_database(
|
||||
self, source_engine: sa.engine.base.Engine
|
||||
) -> sa.engine.base.Engine:
|
||||
destination_engine = sa.create_engine("sqlite:///:memory:")
|
||||
metadata = sa.MetaData()
|
||||
metadata.reflect(bind=source_engine)
|
||||
metadata.create_all(bind=destination_engine)
|
||||
|
||||
with source_engine.connect() as source_conn:
|
||||
with destination_engine.connect() as destination_conn:
|
||||
for table_obj in metadata.tables.values():
|
||||
# Select all rows from the source table
|
||||
source_rows = source_conn.execute(table_obj.select())
|
||||
|
||||
# Insert rows into the destination table
|
||||
for row in source_rows:
|
||||
destination_conn.execute(table_obj.insert().values(row))
|
||||
|
||||
# Commit the changes for each table
|
||||
destination_conn.commit()
|
||||
|
||||
return destination_engine
|
||||
|
||||
def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
|
||||
# maximum possible filter (if limit_sequences_per_category_to == 0):
|
||||
# WHERE category IN 'self.pick_categories'
|
||||
@@ -486,22 +396,19 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
*self._get_pick_filters(),
|
||||
*self._get_exclude_filters(),
|
||||
]
|
||||
if self.pick_sequences_sql_clause:
|
||||
print("Applying the custom SQL clause.")
|
||||
where_conditions.append(sa.text(self.pick_sequences_sql_clause))
|
||||
|
||||
def add_where(stmt):
|
||||
return stmt.where(*where_conditions) if where_conditions else stmt
|
||||
|
||||
if self.limit_sequences_per_category_to <= 0:
|
||||
stmt = add_where(sa.select(self.sequence_annotations_type.sequence_name))
|
||||
stmt = add_where(sa.select(SqlSequenceAnnotation.sequence_name))
|
||||
else:
|
||||
subquery = sa.select(
|
||||
self.sequence_annotations_type.sequence_name,
|
||||
SqlSequenceAnnotation.sequence_name,
|
||||
sa.func.row_number()
|
||||
.over(
|
||||
order_by=sa.text("ROWID"), # NOTE: ROWID is SQLite-specific
|
||||
partition_by=self.sequence_annotations_type.category,
|
||||
partition_by=SqlSequenceAnnotation.category,
|
||||
)
|
||||
.label("row_number"),
|
||||
)
|
||||
@@ -537,34 +444,31 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
return []
|
||||
|
||||
logger.info(f"Limiting dataset to categories: {self.pick_categories}")
|
||||
return [self.sequence_annotations_type.category.in_(self.pick_categories)]
|
||||
return [SqlSequenceAnnotation.category.in_(self.pick_categories)]
|
||||
|
||||
def _get_pick_filters(self) -> List[sa.ColumnElement]:
|
||||
if not self.pick_sequences:
|
||||
return []
|
||||
|
||||
logger.info(f"Limiting dataset to sequences: {self.pick_sequences}")
|
||||
return [self.sequence_annotations_type.sequence_name.in_(self.pick_sequences)]
|
||||
return [SqlSequenceAnnotation.sequence_name.in_(self.pick_sequences)]
|
||||
|
||||
def _get_exclude_filters(self) -> List[sa.ColumnOperators]:
|
||||
if not self.exclude_sequences:
|
||||
return []
|
||||
|
||||
logger.info(f"Removing sequences from the dataset: {self.exclude_sequences}")
|
||||
return [
|
||||
self.sequence_annotations_type.sequence_name.notin_(self.exclude_sequences)
|
||||
]
|
||||
return [SqlSequenceAnnotation.sequence_name.notin_(self.exclude_sequences)]
|
||||
|
||||
def _load_subsets_from_json(self, subset_lists_path: str) -> pd.DataFrame:
|
||||
subsets = self.subsets
|
||||
assert subsets is not None
|
||||
assert self.subsets is not None
|
||||
with open(subset_lists_path, "r") as f:
|
||||
subset_to_seq_frame = json.load(f)
|
||||
|
||||
seq_frame_list = sum(
|
||||
(
|
||||
[(*row, subset) for row in subset_to_seq_frame[subset]]
|
||||
for subset in subsets
|
||||
for subset in self.subsets
|
||||
),
|
||||
[],
|
||||
)
|
||||
@@ -618,7 +522,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
stmt = sa.select(
|
||||
self.frame_annotations_type.sequence_name,
|
||||
self.frame_annotations_type.frame_number,
|
||||
).where(self.frame_annotations_type._mask_mass == 0) # pyre-ignore[16]
|
||||
).where(self.frame_annotations_type._mask_mass == 0)
|
||||
with Session(self._sql_engine) as session:
|
||||
to_remove = session.execute(stmt).all()
|
||||
|
||||
@@ -682,7 +586,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
stmt = sa.select(
|
||||
self.frame_annotations_type.sequence_name,
|
||||
self.frame_annotations_type.frame_number,
|
||||
self.frame_annotations_type._image_path, # pyre-ignore[16]
|
||||
self.frame_annotations_type._image_path,
|
||||
sa.null().label("subset"),
|
||||
)
|
||||
where_conditions = []
|
||||
@@ -696,7 +600,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
logger.info(" excluding samples with empty masks")
|
||||
where_conditions.append(
|
||||
sa.or_(
|
||||
self.frame_annotations_type._mask_mass.is_(None), # pyre-ignore[16]
|
||||
self.frame_annotations_type._mask_mass.is_(None),
|
||||
self.frame_annotations_type._mask_mass != 0,
|
||||
)
|
||||
)
|
||||
@@ -730,9 +634,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
assert self.eval_batches_file
|
||||
logger.info(f"Loading eval batches from {self.eval_batches_file}")
|
||||
|
||||
if (
|
||||
self.path_manager and not self.path_manager.isfile(self.eval_batches_file)
|
||||
) or (not self.path_manager and not os.path.isfile(self.eval_batches_file)):
|
||||
if not os.path.isfile(self.eval_batches_file):
|
||||
# The batch indices file does not exist.
|
||||
# Most probably the user has not specified the root folder.
|
||||
raise ValueError(
|
||||
@@ -740,8 +642,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
+ "Please specify a correct dataset_root folder."
|
||||
)
|
||||
|
||||
eval_batches_file = self._local_path(self.eval_batches_file)
|
||||
with open(eval_batches_file, "r") as f:
|
||||
with open(self.eval_batches_file, "r") as f:
|
||||
eval_batches = json.load(f)
|
||||
|
||||
# limit the dataset to sequences to allow multiple evaluations in one file
|
||||
@@ -825,15 +726,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
self.frame_annotations_type.sequence_name == seq_name,
|
||||
self.frame_annotations_type.frame_number.in_(frames),
|
||||
)
|
||||
frame_no_ts = None
|
||||
|
||||
if self.scoped_session:
|
||||
stmt_text = str(stmt.compile(compile_kwargs={"literal_binds": True}))
|
||||
with scoped_session(self._session_factory)() as session: # pyre-ignore
|
||||
frame_no_ts = pd.read_sql_query(stmt_text, session.connection())
|
||||
else:
|
||||
with self._sql_engine.connect() as connection:
|
||||
frame_no_ts = pd.read_sql_query(stmt, connection)
|
||||
with self._sql_engine.connect() as connection:
|
||||
frame_no_ts = pd.read_sql_query(stmt, connection)
|
||||
|
||||
if len(frame_no_ts) != len(index_slice):
|
||||
raise ValueError(
|
||||
@@ -863,18 +758,11 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
prefixes=["TEMP"], # NOTE SQLite specific!
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def pre_expand(cls) -> None:
|
||||
# remove dataclass annotations that are not meant to be init params
|
||||
# because they cause troubles for OmegaConf
|
||||
for attr, attr_value in list(cls.__dict__.items()): # need to copy as we mutate
|
||||
if isinstance(attr_value, Field) and attr_value.metadata.get(
|
||||
"omegaconf_ignore", False
|
||||
):
|
||||
delattr(cls, attr)
|
||||
del cls.__annotations__[attr]
|
||||
|
||||
|
||||
def _seq_name_to_seed(seq_name) -> int:
|
||||
"""Generates numbers in [0, 2 ** 28)"""
|
||||
return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest()[:7], 16)
|
||||
|
||||
|
||||
def _safe_as_tensor(data, dtype):
|
||||
return torch.tensor(data, dtype=dtype) if data is not None else None
|
||||
|
||||
@@ -4,8 +4,6 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# pyre-unsafe
|
||||
|
||||
|
||||
import logging
|
||||
import os
|
||||
@@ -45,7 +43,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@registry.register
|
||||
class SqlIndexDatasetMapProvider(DatasetMapProviderBase):
|
||||
class SqlIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
"""
|
||||
Generates the training, validation, and testing dataset objects for
|
||||
a dataset laid out on disk like SQL-CO3D, with annotations in an SQLite data base.
|
||||
@@ -195,9 +193,9 @@ class SqlIndexDatasetMapProvider(DatasetMapProviderBase):
|
||||
|
||||
# this is a mould that is never constructed, used to build self._dataset_map values
|
||||
dataset_class_type: str = "SqlIndexDataset"
|
||||
dataset: SqlIndexDataset # pyre-ignore [13]
|
||||
dataset: SqlIndexDataset
|
||||
|
||||
path_manager_factory: PathManagerFactory # pyre-ignore [13]
|
||||
path_manager_factory: PathManagerFactory
|
||||
path_manager_factory_class_type: str = "PathManagerFactory"
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -284,14 +282,8 @@ class SqlIndexDatasetMapProvider(DatasetMapProviderBase):
|
||||
logger.info(f"Val dataset: {str(val_dataset)}")
|
||||
|
||||
logger.debug("Extracting test dataset.")
|
||||
if self.eval_batches_path is None:
|
||||
eval_batches_file = None
|
||||
else:
|
||||
eval_batches_file = self._get_lists_file("eval_batches")
|
||||
|
||||
if "eval_batches_file" in common_dataset_kwargs:
|
||||
common_dataset_kwargs.pop("eval_batches_file", None)
|
||||
|
||||
eval_batches_file = self._get_lists_file("eval_batches")
|
||||
del common_dataset_kwargs["eval_batches_file"]
|
||||
test_dataset = dataset_type(
|
||||
**common_dataset_kwargs,
|
||||
subsets=self._get_subsets(self.test_subsets, True),
|
||||
|
||||
@@ -87,15 +87,6 @@ def is_train_frame(
|
||||
def get_bbox_from_mask(
|
||||
mask: np.ndarray, thr: float, decrease_quant: float = 0.05
|
||||
) -> Tuple[int, int, int, int]:
|
||||
# these corner cases need to be handled in order to avoid an infinite loop
|
||||
if mask.size == 0:
|
||||
warnings.warn("Empty mask is provided for bbox extraction.", stacklevel=1)
|
||||
return 0, 0, 1, 1
|
||||
|
||||
if not mask.min() >= 0.0:
|
||||
warnings.warn("Negative values in the mask for bbox extraction.", stacklevel=1)
|
||||
mask = mask.clip(min=0.0)
|
||||
|
||||
# bbox in xywh
|
||||
masks_for_box = np.zeros_like(mask)
|
||||
while masks_for_box.sum() <= 1.0:
|
||||
@@ -143,15 +134,7 @@ T = TypeVar("T", bound=torch.Tensor)
|
||||
def bbox_xyxy_to_xywh(xyxy: T) -> T:
|
||||
wh = xyxy[2:] - xyxy[:2]
|
||||
xywh = torch.cat([xyxy[:2], wh])
|
||||
return xywh # pyre-ignore[7]
|
||||
|
||||
|
||||
def bbox_xywh_to_xyxy(xywh: T, clamp_size: float | int | None = None) -> T:
|
||||
wh = xywh[2:]
|
||||
if clamp_size is not None:
|
||||
wh = wh.clamp(min=clamp_size)
|
||||
xyxy = torch.cat([xywh[:2], xywh[:2] + wh])
|
||||
return xyxy # pyre-ignore[7]
|
||||
return xywh # pyre-ignore
|
||||
|
||||
|
||||
def get_clamp_bbox(
|
||||
@@ -197,6 +180,16 @@ def rescale_bbox(
|
||||
return bbox * rel_size
|
||||
|
||||
|
||||
def bbox_xywh_to_xyxy(
|
||||
xywh: torch.Tensor, clamp_size: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
xyxy = xywh.clone()
|
||||
if clamp_size is not None:
|
||||
xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
|
||||
xyxy[2:] += xyxy[:2]
|
||||
return xyxy
|
||||
|
||||
|
||||
def get_1d_bounds(arr: np.ndarray) -> Tuple[int, int]:
|
||||
nz = np.flatnonzero(arr)
|
||||
return nz[0], nz[-1] + 1
|
||||
@@ -208,24 +201,18 @@ def resize_image(
|
||||
image_width: Optional[int],
|
||||
mode: str = "bilinear",
|
||||
) -> Tuple[torch.Tensor, float, torch.Tensor]:
|
||||
|
||||
if isinstance(image, np.ndarray):
|
||||
image = torch.from_numpy(image)
|
||||
|
||||
if (
|
||||
image_height is None
|
||||
or image_width is None
|
||||
or image.shape[-2] == 0
|
||||
or image.shape[-1] == 0
|
||||
):
|
||||
if image_height is None or image_width is None:
|
||||
# skip the resizing
|
||||
return image, 1.0, torch.ones_like(image[:1])
|
||||
|
||||
# takes numpy array or tensor, returns pytorch tensor
|
||||
minscale = min(
|
||||
image_height / image.shape[-2],
|
||||
image_width / image.shape[-1],
|
||||
)
|
||||
|
||||
imre = torch.nn.functional.interpolate(
|
||||
image[None],
|
||||
scale_factor=minscale,
|
||||
@@ -233,7 +220,6 @@ def resize_image(
|
||||
align_corners=False if mode == "bilinear" else None,
|
||||
recompute_scale_factor=True,
|
||||
)[0]
|
||||
|
||||
imre_ = torch.zeros(image.shape[0], image_height, image_width)
|
||||
imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
|
||||
mask = torch.zeros(1, image_height, image_width)
|
||||
@@ -246,21 +232,9 @@ def transpose_normalize_image(image: np.ndarray) -> np.ndarray:
|
||||
return im.astype(np.float32) / 255.0
|
||||
|
||||
|
||||
def load_image(
|
||||
path: str, try_read_alpha: bool = False, pil_format: str = "RGB"
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Load an image from a path and return it as a numpy array.
|
||||
If try_read_alpha is True, the image is read as RGBA and the alpha channel is
|
||||
returned as the fourth channel.
|
||||
Otherwise, the image is read as RGB and a three-channel image is returned.
|
||||
"""
|
||||
def load_image(path: str) -> np.ndarray:
|
||||
with Image.open(path) as pil_im:
|
||||
# Check if the image has an alpha channel
|
||||
if try_read_alpha and pil_im.mode == "RGBA":
|
||||
im = np.array(pil_im)
|
||||
else:
|
||||
im = np.array(pil_im.convert(pil_format))
|
||||
im = np.array(pil_im.convert("RGB"))
|
||||
|
||||
return transpose_normalize_image(im)
|
||||
|
||||
@@ -355,7 +329,6 @@ def adjust_camera_to_bbox_crop_(
|
||||
|
||||
focal_length_px, principal_point_px = _convert_ndc_to_pixels(
|
||||
camera.focal_length[0],
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
|
||||
camera.principal_point[0],
|
||||
image_size_wh,
|
||||
)
|
||||
@@ -368,7 +341,6 @@ def adjust_camera_to_bbox_crop_(
|
||||
)
|
||||
|
||||
camera.focal_length = focal_length[None]
|
||||
# pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
|
||||
camera.principal_point = principal_point_cropped[None]
|
||||
|
||||
|
||||
@@ -380,7 +352,6 @@ def adjust_camera_to_image_scale_(
|
||||
) -> PerspectiveCameras:
|
||||
focal_length_px, principal_point_px = _convert_ndc_to_pixels(
|
||||
camera.focal_length[0],
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
|
||||
camera.principal_point[0],
|
||||
original_size_wh,
|
||||
)
|
||||
@@ -397,8 +368,7 @@ def adjust_camera_to_image_scale_(
|
||||
image_size_wh_output,
|
||||
)
|
||||
camera.focal_length = focal_length_scaled[None]
|
||||
# pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
|
||||
camera.principal_point = principal_point_scaled[None] # pyre-ignore[16]
|
||||
camera.principal_point = principal_point_scaled[None]
|
||||
|
||||
|
||||
# NOTE this cache is per-worker; they are implemented as processes.
|
||||
|
||||
@@ -299,6 +299,7 @@ def eval_batch(
|
||||
)
|
||||
|
||||
for loss_fg_mask, name_postfix in zip((mask_crop, mask_fg), ("_masked", "_fg")):
|
||||
|
||||
loss_mask_now = mask_crop * loss_fg_mask
|
||||
|
||||
for rgb_metric_name, rgb_metric_fun in zip(
|
||||
|
||||
@@ -106,7 +106,7 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
|
||||
self.layers = torch.nn.ModuleList()
|
||||
self.proj_layers = torch.nn.ModuleList()
|
||||
for stage in range(self.max_stage):
|
||||
stage_name = f"layer{stage + 1}"
|
||||
stage_name = f"layer{stage+1}"
|
||||
feature_name = self._get_resnet_stage_feature_name(stage)
|
||||
if (stage + 1) in self.stages:
|
||||
if (
|
||||
@@ -139,18 +139,12 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
|
||||
self.stages = set(self.stages) # convert to set for faster "in"
|
||||
|
||||
def _get_resnet_stage_feature_name(self, stage) -> str:
|
||||
return f"res_layer_{stage + 1}"
|
||||
return f"res_layer_{stage+1}"
|
||||
|
||||
def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
|
||||
# pyre-fixme[58]: `-` is not supported for operand types `Tensor` and
|
||||
# `Union[Tensor, Module]`.
|
||||
# pyre-fixme[58]: `/` is not supported for operand types `Tensor` and
|
||||
# `Union[Tensor, Module]`.
|
||||
return (img - self._resnet_mean) / self._resnet_std
|
||||
|
||||
def get_feat_dims(self) -> int:
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase) -> Tensor, Tensor, Module]` is
|
||||
# not a function.
|
||||
return sum(self._feat_dim.values())
|
||||
|
||||
def forward(
|
||||
@@ -189,12 +183,7 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
|
||||
else:
|
||||
imgs_normed = imgs_resized
|
||||
# is not a function.
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
feats = self.stem(imgs_normed)
|
||||
# pyre-fixme[6]: For 1st argument expected `Iterable[_T1]` but got
|
||||
# `Union[Tensor, Module]`.
|
||||
# pyre-fixme[6]: For 2nd argument expected `Iterable[_T2]` but got
|
||||
# `Union[Tensor, Module]`.
|
||||
for stage, (layer, proj) in enumerate(zip(self.layers, self.proj_layers)):
|
||||
feats = layer(feats)
|
||||
# just a sanity check below
|
||||
|
||||
@@ -478,8 +478,6 @@ class GenericModel(ImplicitronModelBase):
|
||||
)
|
||||
custom_args["global_code"] = global_code
|
||||
|
||||
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a
|
||||
# function.
|
||||
for func in self._implicit_functions:
|
||||
func.bind_args(**custom_args)
|
||||
|
||||
@@ -502,8 +500,6 @@ class GenericModel(ImplicitronModelBase):
|
||||
# Unbind the custom arguments to prevent pytorch from storing
|
||||
# large buffers of intermediate results due to points in the
|
||||
# bound arguments.
|
||||
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a
|
||||
# function.
|
||||
for func in self._implicit_functions:
|
||||
func.unbind_args()
|
||||
|
||||
|
||||
@@ -71,7 +71,6 @@ class Autodecoder(Configurable, torch.nn.Module):
|
||||
return key_map
|
||||
|
||||
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
|
||||
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `weight`.
|
||||
return (self._autodecoder_codes.weight**2).mean()
|
||||
|
||||
def get_encoding_dim(self) -> int:
|
||||
@@ -96,7 +95,6 @@ class Autodecoder(Configurable, torch.nn.Module):
|
||||
# pyre-fixme[9]: x has type `Union[List[str], LongTensor]`; used as
|
||||
# `Tensor`.
|
||||
x = torch.tensor(
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, ...
|
||||
[self._key_map[elem] for elem in x],
|
||||
dtype=torch.long,
|
||||
device=next(self.parameters()).device,
|
||||
@@ -104,7 +102,6 @@ class Autodecoder(Configurable, torch.nn.Module):
|
||||
except StopIteration:
|
||||
raise ValueError("Not enough n_instances in the autodecoder") from None
|
||||
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
return self._autodecoder_codes(x)
|
||||
|
||||
def _load_key_map_hook(
|
||||
|
||||
@@ -122,7 +122,6 @@ class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
|
||||
if frame_timestamp.shape[-1] != 1:
|
||||
raise ValueError("Frame timestamp's last dimensions should be one.")
|
||||
time = frame_timestamp / self.time_divisor
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
return self._harmonic_embedding(time)
|
||||
|
||||
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
|
||||
|
||||
@@ -232,14 +232,9 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
|
||||
# if the skip tensor is None, we use `x` instead.
|
||||
z = x
|
||||
skipi = 0
|
||||
# pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got
|
||||
# `Union[Tensor, Module]`.
|
||||
for li, layer in enumerate(self.mlp):
|
||||
# pyre-fixme[58]: `in` is not supported for right operand type
|
||||
# `Union[Tensor, Module]`.
|
||||
if li in self._input_skips:
|
||||
if self._skip_affine_trans:
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, ...
|
||||
y = self._apply_affine_layer(self.skip_affines[skipi], y, z)
|
||||
else:
|
||||
y = torch.cat((y, z), dim=-1)
|
||||
|
||||
@@ -141,16 +141,11 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
|
||||
self.embed_fn is None and fun_viewpool is None and global_code is None
|
||||
):
|
||||
return torch.tensor(
|
||||
[],
|
||||
device=rays_points_world.device,
|
||||
dtype=rays_points_world.dtype,
|
||||
# pyre-fixme[6]: For 2nd argument expected `Union[int, SymInt]` but got
|
||||
# `Union[Module, Tensor]`.
|
||||
[], device=rays_points_world.device, dtype=rays_points_world.dtype
|
||||
).view(0, self.out_dim)
|
||||
|
||||
embeddings = []
|
||||
if self.embed_fn is not None:
|
||||
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
|
||||
embeddings.append(self.embed_fn(rays_points_world))
|
||||
|
||||
if fun_viewpool is not None:
|
||||
@@ -169,19 +164,13 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
|
||||
|
||||
embedding = torch.cat(embeddings, dim=-1)
|
||||
x = embedding
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, other: Union[bool, complex,
|
||||
# float, int, Tensor]) -> Tensor, Module, Tensor]` is not a function.
|
||||
for layer_idx in range(self.num_layers - 1):
|
||||
if layer_idx in self.skip_in:
|
||||
x = torch.cat([x, embedding], dim=-1) / 2**0.5
|
||||
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
|
||||
x = self.linear_layers[layer_idx](x)
|
||||
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, other: Union[bool, complex,
|
||||
# float, int, Tensor]) -> Tensor, Module, Tensor]` is not a function.
|
||||
if layer_idx < self.num_layers - 2:
|
||||
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
|
||||
x = self.softplus(x)
|
||||
|
||||
return x
|
||||
|
||||
@@ -123,10 +123,8 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
||||
# Normalize the ray_directions to unit l2 norm.
|
||||
rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
|
||||
# Obtain the harmonic embedding of the normalized ray directions.
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
rays_embedding = self.harmonic_embedding_dir(rays_directions_normed)
|
||||
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
return self.color_layer((self.intermediate_linear(features), rays_embedding))
|
||||
|
||||
@staticmethod
|
||||
@@ -197,8 +195,6 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
||||
embeds = create_embeddings_for_implicit_function(
|
||||
xyz_world=rays_points_world,
|
||||
# for 2nd param but got `Union[None, torch.Tensor, torch.nn.Module]`.
|
||||
# pyre-fixme[6]: For 2nd argument expected `Optional[(...) -> Any]` but
|
||||
# got `Union[None, Tensor, Module]`.
|
||||
xyz_embedding_function=(
|
||||
self.harmonic_embedding_xyz if self.input_xyz else None
|
||||
),
|
||||
@@ -210,14 +206,12 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
||||
)
|
||||
|
||||
# embeds.shape = [minibatch x n_src x n_rays x n_pts x self.n_harmonic_functions*6+3]
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
features = self.xyz_encoder(embeds)
|
||||
# features.shape = [minibatch x ... x self.n_hidden_neurons_xyz]
|
||||
# NNs operate on the flattenned rays; reshaping to the correct spatial size
|
||||
# TODO: maybe make the transformer work on non-flattened tensors to avoid this reshape
|
||||
features = features.reshape(*rays_points_world.shape[:-1], -1)
|
||||
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
raw_densities = self.density_layer(features)
|
||||
# raw_densities.shape = [minibatch x ... x 1] in [0-1]
|
||||
|
||||
@@ -225,8 +219,6 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
||||
if camera is None:
|
||||
raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
|
||||
|
||||
# pyre-fixme[58]: `@` is not supported for operand types `Tensor` and
|
||||
# `Union[Tensor, Module]`.
|
||||
directions = ray_bundle.directions @ camera.R
|
||||
else:
|
||||
directions = ray_bundle.directions
|
||||
|
||||
@@ -103,8 +103,6 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module):
|
||||
|
||||
embeds = create_embeddings_for_implicit_function(
|
||||
xyz_world=rays_points_world,
|
||||
# pyre-fixme[6]: For 2nd argument expected `Optional[(...) -> Any]` but
|
||||
# got `Union[Tensor, Module]`.
|
||||
xyz_embedding_function=self._harmonic_embedding,
|
||||
global_code=global_code,
|
||||
fun_viewpool=fun_viewpool,
|
||||
@@ -114,7 +112,6 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module):
|
||||
|
||||
# Before running the network, we have to resize embeds to ndims=3,
|
||||
# otherwise the SRN layers consume huge amounts of memory.
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
raymarch_features = self._net(
|
||||
embeds.view(embeds.shape[0], -1, embeds.shape[-1])
|
||||
)
|
||||
@@ -169,9 +166,7 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
|
||||
# Normalize the ray_directions to unit l2 norm.
|
||||
rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
|
||||
# Obtain the harmonic embedding of the normalized ray directions.
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
rays_embedding = self._harmonic_embedding(rays_directions_normed)
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
return self._color_layer((features, rays_embedding))
|
||||
|
||||
def forward(
|
||||
@@ -200,7 +195,6 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
|
||||
denoting the color of each ray point.
|
||||
"""
|
||||
# raymarch_features.shape = [minibatch x ... x pts_per_ray x 3]
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
features = self._net(raymarch_features)
|
||||
# features.shape = [minibatch x ... x self.n_hidden_units]
|
||||
|
||||
@@ -208,8 +202,6 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
|
||||
if camera is None:
|
||||
raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
|
||||
|
||||
# pyre-fixme[58]: `@` is not supported for operand types `Tensor` and
|
||||
# `Union[Tensor, Module]`.
|
||||
directions = ray_bundle.directions @ camera.R
|
||||
else:
|
||||
directions = ray_bundle.directions
|
||||
@@ -217,7 +209,6 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
|
||||
# NNs operate on the flattenned rays; reshaping to the correct spatial size
|
||||
features = features.reshape(*raymarch_features.shape[:-1], -1)
|
||||
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
raw_densities = self._density_layer(features)
|
||||
|
||||
rays_colors = self._get_colors(features, directions)
|
||||
@@ -278,7 +269,6 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
|
||||
srn_raymarch_function.
|
||||
"""
|
||||
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
net = self._hypernet(global_code)
|
||||
|
||||
# use the hyper-net generated network to instantiate the raymarch module
|
||||
@@ -306,6 +296,7 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
|
||||
global_code=None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
if global_code is None:
|
||||
raise ValueError("SRN Hypernetwork requires a non-trivial global code.")
|
||||
|
||||
@@ -313,8 +304,6 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
|
||||
# across LSTM iterations for the same global_code.
|
||||
if self.cached_srn_raymarch_function is None:
|
||||
# generate the raymarching network from the hypernet
|
||||
# pyre-fixme[16]: `SRNRaymarchHyperNet` has no attribute
|
||||
# `cached_srn_raymarch_function`.
|
||||
self.cached_srn_raymarch_function = self._run_hypernet(global_code)
|
||||
(srn_raymarch_function,) = cast(
|
||||
Tuple[SRNRaymarchFunction], self.cached_srn_raymarch_function
|
||||
@@ -342,7 +331,6 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
def create_raymarch_function(self) -> None:
|
||||
self.raymarch_function = SRNRaymarchFunction(
|
||||
latent_dim=self.latent_dim,
|
||||
# pyre-fixme[32]: Keyword argument must be a mapping with string keys.
|
||||
**self.raymarch_function_args,
|
||||
)
|
||||
|
||||
@@ -401,7 +389,6 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
self.hypernet = SRNRaymarchHyperNet(
|
||||
latent_dim=self.latent_dim,
|
||||
latent_dim_hypernet=self.latent_dim_hypernet,
|
||||
# pyre-fixme[32]: Keyword argument must be a mapping with string keys.
|
||||
**self.hypernet_args,
|
||||
)
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ def create_embeddings_for_implicit_function(
|
||||
xyz_embedding_function: Optional[Callable],
|
||||
diag_cov: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
bs, *spatial_size, pts_per_ray, _ = xyz_world.shape
|
||||
|
||||
if xyz_in_camera_coords:
|
||||
@@ -63,6 +64,7 @@ def create_embeddings_for_implicit_function(
|
||||
0,
|
||||
)
|
||||
else:
|
||||
|
||||
embeds = xyz_embedding_function(ray_points_for_embed, diag_cov=diag_cov)
|
||||
embeds = embeds.reshape(
|
||||
bs,
|
||||
|
||||
@@ -269,7 +269,6 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
||||
for name, tensor in vars(grid_values_with_wanted_resolution).items()
|
||||
}
|
||||
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
return self.values_type(**params), True
|
||||
|
||||
def get_resolution_change_epochs(self) -> Tuple[int, ...]:
|
||||
@@ -883,7 +882,6 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
torch.Tensor of shape (..., n_features)
|
||||
"""
|
||||
locator = self._get_volume_locator()
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
grid_values = self.voxel_grid.values_type(**self.params)
|
||||
# voxel grids operate with extra n_grids dimension, which we fix to one
|
||||
return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]
|
||||
@@ -897,7 +895,6 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
replace current parameters
|
||||
"""
|
||||
if self.hold_voxel_grid_as_parameters:
|
||||
# pyre-fixme[16]: `VoxelGridModule` has no attribute `params`.
|
||||
self.params = torch.nn.ParameterDict(
|
||||
{
|
||||
k: torch.nn.Parameter(val)
|
||||
@@ -948,7 +945,6 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
Returns:
|
||||
True if parameter change has happened else False.
|
||||
"""
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
grid_values = self.voxel_grid.values_type(**self.params)
|
||||
grid_values, change = self.voxel_grid.change_resolution(
|
||||
grid_values, epoch=epoch
|
||||
@@ -996,21 +992,16 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
"""
|
||||
'''
|
||||
new_params = {}
|
||||
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a
|
||||
# function.
|
||||
for name in self.params:
|
||||
key = prefix + "params." + name
|
||||
if key in state_dict:
|
||||
new_params[name] = torch.zeros_like(state_dict[key])
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
self.set_voxel_grid_parameters(self.voxel_grid.values_type(**new_params))
|
||||
|
||||
def get_device(self) -> torch.device:
|
||||
"""
|
||||
Returns torch.device on which module parameters are located
|
||||
"""
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase) -> Tensor, Tensor, Module]` is
|
||||
# not a function.
|
||||
return next(val for val in self.params.values() if val is not None).device
|
||||
|
||||
def crop_self(self, min_point: torch.Tensor, max_point: torch.Tensor) -> None:
|
||||
@@ -1027,7 +1018,6 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
"""
|
||||
locator = self._get_volume_locator()
|
||||
# torch.nn.modules.module.Module]` is not a function.
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
old_grid_values = self.voxel_grid.values_type(**self.params)
|
||||
new_grid_values = self.voxel_grid.crop_world(
|
||||
min_point, max_point, old_grid_values, locator
|
||||
@@ -1035,7 +1025,6 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
grid_values, _ = self.voxel_grid.change_resolution(
|
||||
new_grid_values, grid_values_with_wanted_resolution=old_grid_values
|
||||
)
|
||||
# pyre-fixme[16]: `VoxelGridModule` has no attribute `params`.
|
||||
self.params = torch.nn.ParameterDict(
|
||||
{
|
||||
k: torch.nn.Parameter(val)
|
||||
|
||||
@@ -192,26 +192,16 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
run_auto_creation(self)
|
||||
# pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
|
||||
# `voxel_grid_scaffold`.
|
||||
self.voxel_grid_scaffold = self._create_voxel_grid_scaffold()
|
||||
# pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
|
||||
# `harmonic_embedder_xyz_density`.
|
||||
self.harmonic_embedder_xyz_density = HarmonicEmbedding(
|
||||
**self.harmonic_embedder_xyz_density_args
|
||||
)
|
||||
# pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
|
||||
# `harmonic_embedder_xyz_color`.
|
||||
self.harmonic_embedder_xyz_color = HarmonicEmbedding(
|
||||
**self.harmonic_embedder_xyz_color_args
|
||||
)
|
||||
# pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
|
||||
# `harmonic_embedder_dir_color`.
|
||||
self.harmonic_embedder_dir_color = HarmonicEmbedding(
|
||||
**self.harmonic_embedder_dir_color_args
|
||||
)
|
||||
# pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
|
||||
# `_scaffold_ready`.
|
||||
self._scaffold_ready = False
|
||||
|
||||
def forward(
|
||||
@@ -262,7 +252,6 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
# ########## filter the points using the scaffold ########## #
|
||||
if self._scaffold_ready and self.scaffold_filter_points:
|
||||
with torch.no_grad():
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0
|
||||
points = points[non_empty_points]
|
||||
if len(points) == 0:
|
||||
@@ -374,7 +363,6 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
feature dimensionality which `decoder_density` returns
|
||||
"""
|
||||
embeds_density = self.voxel_grid_density(points)
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
harmonic_embedding_density = self.harmonic_embedder_xyz_density(embeds_density)
|
||||
# shape = [..., density_dim]
|
||||
return self.decoder_density(harmonic_embedding_density)
|
||||
@@ -409,8 +397,6 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
if self.xyz_ray_dir_in_camera_coords:
|
||||
if camera is None:
|
||||
raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
|
||||
# pyre-fixme[58]: `@` is not supported for operand types `Tensor` and
|
||||
# `Union[Tensor, Module]`.
|
||||
directions = directions @ camera.R
|
||||
|
||||
# ########## get voxel grid output ########## #
|
||||
@@ -419,13 +405,11 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
|
||||
# ########## embed with the harmonic function ########## #
|
||||
# Obtain the harmonic embedding of the voxel grid output.
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
harmonic_embedding_color = self.harmonic_embedder_xyz_color(embeds_color)
|
||||
|
||||
# Normalize the ray_directions to unit l2 norm.
|
||||
rays_directions_normed = torch.nn.functional.normalize(directions, dim=-1)
|
||||
# Obtain the harmonic embedding of the normalized ray directions.
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
harmonic_embedding_dir = self.harmonic_embedder_dir_color(
|
||||
rays_directions_normed
|
||||
)
|
||||
@@ -494,11 +478,8 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
an object inside, else False.
|
||||
"""
|
||||
# find bounding box
|
||||
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
|
||||
# `get_grid_points`.
|
||||
points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch)
|
||||
assert self._scaffold_ready, "Scaffold has to be calculated before cropping."
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
occupancy = self.voxel_grid_scaffold(points)[..., 0] > 0
|
||||
non_zero_idxs = torch.nonzero(occupancy)
|
||||
if len(non_zero_idxs) == 0:
|
||||
@@ -530,8 +511,6 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
"""
|
||||
|
||||
planes = []
|
||||
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
|
||||
# `get_grid_points`.
|
||||
points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch)
|
||||
|
||||
chunk_size = (
|
||||
@@ -551,10 +530,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
stride=1,
|
||||
)
|
||||
occupancy_cube = density_cube > self.scaffold_empty_space_threshold
|
||||
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `params`.
|
||||
self.voxel_grid_scaffold.params["voxel_grid"] = occupancy_cube.float()
|
||||
# pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
|
||||
# `_scaffold_ready`.
|
||||
self._scaffold_ready = True
|
||||
|
||||
return False
|
||||
@@ -571,8 +547,6 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
decoding function to this value.
|
||||
"""
|
||||
grid_args = self.voxel_grid_density_args
|
||||
# pyre-fixme[6]: For 1st argument expected `DictConfig` but got
|
||||
# `Union[Tensor, Module]`.
|
||||
grid_output_dim = VoxelGridModule.get_output_dim(grid_args)
|
||||
|
||||
embedder_args = self.harmonic_embedder_xyz_density_args
|
||||
@@ -601,8 +575,6 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
decoding function to this value.
|
||||
"""
|
||||
grid_args = self.voxel_grid_color_args
|
||||
# pyre-fixme[6]: For 1st argument expected `DictConfig` but got
|
||||
# `Union[Tensor, Module]`.
|
||||
grid_output_dim = VoxelGridModule.get_output_dim(grid_args)
|
||||
|
||||
embedder_args = self.harmonic_embedder_xyz_color_args
|
||||
@@ -636,9 +608,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
`self.voxel_grid_density`
|
||||
"""
|
||||
return VoxelGridModule(
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
|
||||
extents=self.voxel_grid_density_args["extents"],
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
|
||||
translation=self.voxel_grid_density_args["translation"],
|
||||
voxel_grid_class_type="FullResolutionVoxelGrid",
|
||||
hold_voxel_grid_as_parameters=False,
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
# pyre-unsafe
|
||||
|
||||
|
||||
import warnings
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
@@ -297,8 +298,9 @@ class ViewMetrics(ViewMetricsBase):
|
||||
_rgb_metrics(
|
||||
image_rgb,
|
||||
image_rgb_pred,
|
||||
masks=fg_probability,
|
||||
masks_crop=mask_crop,
|
||||
fg_probability,
|
||||
fg_probability_pred,
|
||||
mask_crop,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -308,21 +310,9 @@ class ViewMetrics(ViewMetricsBase):
|
||||
metrics["mask_neg_iou"] = utils.neg_iou_loss(
|
||||
fg_probability_pred, fg_probability, mask=mask_crop
|
||||
)
|
||||
if torch.is_autocast_enabled():
|
||||
# To avoid issues with mixed precision
|
||||
metrics["mask_bce"] = utils.calc_bce(
|
||||
fg_probability_pred.logit(),
|
||||
fg_probability,
|
||||
mask=mask_crop,
|
||||
pred_logits=True,
|
||||
)
|
||||
else:
|
||||
metrics["mask_bce"] = utils.calc_bce(
|
||||
fg_probability_pred,
|
||||
fg_probability,
|
||||
mask=mask_crop,
|
||||
pred_logits=False,
|
||||
)
|
||||
metrics["mask_bce"] = utils.calc_bce(
|
||||
fg_probability_pred, fg_probability, mask=mask_crop
|
||||
)
|
||||
|
||||
if depth_map is not None and depth_map_pred is not None:
|
||||
assert mask_crop is not None
|
||||
@@ -334,11 +324,7 @@ class ViewMetrics(ViewMetricsBase):
|
||||
if fg_probability is not None:
|
||||
mask = fg_probability * mask_crop
|
||||
_, abs_ = utils.eval_depth(
|
||||
depth_map_pred,
|
||||
depth_map,
|
||||
get_best_scale=True,
|
||||
mask=mask,
|
||||
crop=0,
|
||||
depth_map_pred, depth_map, get_best_scale=True, mask=mask, crop=0
|
||||
)
|
||||
metrics["depth_abs_fg"] = abs_.mean()
|
||||
|
||||
@@ -360,26 +346,18 @@ class ViewMetrics(ViewMetricsBase):
|
||||
return metrics
|
||||
|
||||
|
||||
def _rgb_metrics(
|
||||
images,
|
||||
images_pred,
|
||||
masks=None,
|
||||
masks_crop=None,
|
||||
huber_scaling: float = 0.03,
|
||||
):
|
||||
def _rgb_metrics(images, images_pred, masks, masks_pred, masks_crop):
|
||||
assert masks_crop is not None
|
||||
if images.shape[1] != images_pred.shape[1]:
|
||||
raise ValueError(
|
||||
f"Network output's RGB images had {images_pred.shape[1]} "
|
||||
f"channels. {images.shape[1]} expected."
|
||||
)
|
||||
rgb_abs = ((images_pred - images).abs()).mean(dim=1, keepdim=True)
|
||||
rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True)
|
||||
rgb_loss = utils.huber(rgb_squared, scaling=huber_scaling)
|
||||
rgb_loss = utils.huber(rgb_squared, scaling=0.03)
|
||||
crop_mass = masks_crop.sum().clamp(1.0)
|
||||
results = {
|
||||
"rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass,
|
||||
"rgb_l1": (rgb_abs * masks_crop).sum() / crop_mass,
|
||||
"rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass,
|
||||
"rgb_psnr": utils.calc_psnr(images_pred, images, mask=masks_crop),
|
||||
}
|
||||
|
||||
@@ -135,7 +135,6 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
||||
break
|
||||
|
||||
# run the lstm marcher
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
state_h, state_c = self._lstm(
|
||||
raymarch_features.view(-1, raymarch_features.shape[-1]),
|
||||
states[-1],
|
||||
@@ -143,7 +142,6 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
||||
if state_h.requires_grad:
|
||||
state_h.register_hook(lambda x: x.clamp(min=-10, max=10))
|
||||
# predict the next step size
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
signed_distance = self._out_layer(state_h).view(ray_bundle_t.lengths.shape)
|
||||
# log the lstm states
|
||||
states.append((state_h, state_c))
|
||||
|
||||
@@ -207,7 +207,6 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
||||
"""
|
||||
sample_mask = None
|
||||
if (
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
|
||||
self._sampling_mode[evaluation_mode] == RenderSamplingMode.MASK_SAMPLE
|
||||
and mask is not None
|
||||
):
|
||||
@@ -224,7 +223,6 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
||||
EvaluationMode.EVALUATION: self._evaluation_raysampler,
|
||||
}[evaluation_mode]
|
||||
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
ray_bundle = raysampler(
|
||||
cameras=cameras,
|
||||
mask=sample_mask,
|
||||
@@ -242,8 +240,6 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
||||
"Heterogeneous ray bundle is not supported for conical frustum computation yet"
|
||||
)
|
||||
elif self.cast_ray_bundle_as_cone:
|
||||
# pyre-fixme[9]: pixel_hw has type `Tuple[float, float]`; used as
|
||||
# `Tuple[Union[Tensor, Module], Union[Tensor, Module]]`.
|
||||
pixel_hw: Tuple[float, float] = (self.pixel_height, self.pixel_width)
|
||||
pixel_radii_2d = compute_radii(cameras, ray_bundle.xys[..., :2], pixel_hw)
|
||||
return ImplicitronRayBundle(
|
||||
|
||||
@@ -179,10 +179,8 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
|
||||
rays_densities = torch.relu(rays_densities)
|
||||
|
||||
weighted_densities = deltas * rays_densities
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
capped_densities = self._capping_function(weighted_densities)
|
||||
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
rays_opacities = self._capping_function(
|
||||
torch.cumsum(weighted_densities, dim=-1)
|
||||
)
|
||||
@@ -192,7 +190,6 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
|
||||
)
|
||||
absorption_shifted[..., : self.surface_thickness] = 1.0
|
||||
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
weights = self._weight_function(capped_densities, absorption_shifted)
|
||||
features = (weights[..., None] * rays_features).sum(dim=-2)
|
||||
depth = (weights * ray_lengths)[..., None].sum(dim=-2)
|
||||
@@ -200,8 +197,6 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
|
||||
alpha = opacities if self.blend_output else 1
|
||||
if self._bg_color.shape[-1] not in [1, features.shape[-1]]:
|
||||
raise ValueError("Wrong number of background color channels.")
|
||||
# pyre-fixme[58]: `*` is not supported for operand types `int` and
|
||||
# `Union[Tensor, Module]`.
|
||||
features = alpha * features + (1 - opacities) * self._bg_color
|
||||
|
||||
return RendererOutput(
|
||||
|
||||
@@ -61,7 +61,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
||||
|
||||
def create_ray_tracer(self) -> None:
|
||||
self.ray_tracer = RayTracing(
|
||||
# pyre-fixme[32]: Keyword argument must be a mapping with string keys.
|
||||
**self.ray_tracer_args,
|
||||
object_bounding_sphere=self.object_bounding_sphere,
|
||||
)
|
||||
@@ -150,8 +149,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
||||
n_eik_points,
|
||||
3,
|
||||
# but got `Union[device, Tensor, Module]`.
|
||||
# pyre-fixme[6]: For 3rd argument expected `Union[None, int, str,
|
||||
# device]` but got `Union[device, Tensor, Module]`.
|
||||
device=self._bg_color.device,
|
||||
).uniform_(-eik_bounding_box, eik_bounding_box)
|
||||
eikonal_pixel_points = points.clone()
|
||||
@@ -208,7 +205,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
||||
]
|
||||
normals_full.view(-1, 3)[surface_mask] = normals
|
||||
render_full.view(-1, self.render_features_dimensions)[surface_mask] = (
|
||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||
self._rgb_network(
|
||||
features,
|
||||
differentiable_surface_points[None],
|
||||
@@ -220,7 +216,8 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
||||
)
|
||||
mask_full.view(-1, 1)[~surface_mask] = torch.sigmoid(
|
||||
# pyre-fixme[6]: For 1st param expected `Tensor` but got `float`.
|
||||
-self.soft_mask_alpha * sdf_output[~surface_mask]
|
||||
-self.soft_mask_alpha
|
||||
* sdf_output[~surface_mask]
|
||||
)
|
||||
|
||||
# scatter points with surface_mask
|
||||
|
||||
@@ -532,7 +532,6 @@ def _get_ray_dir_dot_prods(camera: CamerasBase, pts: torch.Tensor):
|
||||
|
||||
# does not produce nans randomly unlike get_camera_center() below
|
||||
cam_centers_rep = -torch.bmm(
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
|
||||
camera_rep.T[:, None],
|
||||
camera_rep.R.permute(0, 2, 1),
|
||||
).reshape(-1, *([1] * (pts.ndim - 2)), 3)
|
||||
|
||||
@@ -209,7 +209,6 @@ def handle_seq_id(
|
||||
seq_id = torch.tensor(seq_id, dtype=torch.long, device=device)
|
||||
# pyre-fixme[16]: Item `List` of `Union[List[int], List[str], LongTensor]` has
|
||||
# no attribute `to`.
|
||||
# pyre-fixme[7]: Expected `LongTensor` but got `Tensor`.
|
||||
return seq_id.to(device)
|
||||
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ def cleanup_eval_depth(
|
||||
sigma: float = 0.01,
|
||||
image=None,
|
||||
):
|
||||
|
||||
ba, _, H, W = depth.shape
|
||||
|
||||
pcl = point_cloud.points_padded()
|
||||
|
||||
@@ -6,15 +6,12 @@
|
||||
|
||||
# pyre-unsafe
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def eval_depth(
|
||||
pred: torch.Tensor,
|
||||
@@ -24,8 +21,6 @@ def eval_depth(
|
||||
get_best_scale: bool = True,
|
||||
mask_thr: float = 0.5,
|
||||
best_scale_clamp_thr: float = 1e-4,
|
||||
use_disparity: bool = False,
|
||||
disparity_eps: float = 1e-4,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Evaluate the depth error between the prediction `pred` and the ground
|
||||
@@ -69,13 +64,6 @@ def eval_depth(
|
||||
# s.t. we get best possible mse error
|
||||
scale_best = estimate_depth_scale_factor(pred, gt, dmask, best_scale_clamp_thr)
|
||||
pred = pred * scale_best[:, None, None, None]
|
||||
if use_disparity:
|
||||
gt = torch.div(1.0, (gt + disparity_eps))
|
||||
pred = torch.div(1.0, (pred + disparity_eps))
|
||||
scale_best = estimate_depth_scale_factor(
|
||||
pred, gt, dmask, best_scale_clamp_thr
|
||||
).detach()
|
||||
pred = pred * scale_best[:, None, None, None]
|
||||
|
||||
df = gt - pred
|
||||
|
||||
@@ -129,7 +117,6 @@ def calc_bce(
|
||||
pred_eps: float = 0.01,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
lerp_bound: Optional[float] = None,
|
||||
pred_logits: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculates the binary cross entropy.
|
||||
@@ -152,23 +139,9 @@ def calc_bce(
|
||||
weight = torch.ones_like(gt) * mask
|
||||
|
||||
if lerp_bound is not None:
|
||||
# binary_cross_entropy_lerp requires pred to be in [0, 1]
|
||||
if pred_logits:
|
||||
pred = F.sigmoid(pred)
|
||||
|
||||
return binary_cross_entropy_lerp(pred, gt, weight, lerp_bound)
|
||||
else:
|
||||
if pred_logits:
|
||||
loss = F.binary_cross_entropy_with_logits(
|
||||
pred,
|
||||
gt,
|
||||
reduction="none",
|
||||
weight=weight,
|
||||
)
|
||||
else:
|
||||
loss = F.binary_cross_entropy(pred, gt, reduction="none", weight=weight)
|
||||
|
||||
return loss.mean()
|
||||
return F.binary_cross_entropy(pred, gt, reduction="mean", weight=weight)
|
||||
|
||||
|
||||
def binary_cross_entropy_lerp(
|
||||
|
||||
@@ -111,10 +111,10 @@ def load_model(fl, map_location: Optional[dict]):
|
||||
flstats = get_stats_path(fl)
|
||||
flmodel = get_model_path(fl)
|
||||
flopt = get_optimizer_path(fl)
|
||||
model_state_dict = torch.load(flmodel, map_location=map_location, weights_only=True)
|
||||
model_state_dict = torch.load(flmodel, map_location=map_location)
|
||||
stats = load_stats(flstats)
|
||||
if os.path.isfile(flopt):
|
||||
optimizer = torch.load(flopt, map_location=map_location, weights_only=True)
|
||||
optimizer = torch.load(flopt, map_location=map_location)
|
||||
else:
|
||||
optimizer = None
|
||||
|
||||
|
||||
@@ -100,6 +100,7 @@ def render_point_cloud_pytorch3d(
|
||||
bin_size: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
# feature dimension
|
||||
featdim = point_cloud.features_packed().shape[-1]
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@ class AverageMeter:
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1, epoch=0):
|
||||
|
||||
# make sure the history is of the same len as epoch
|
||||
while len(self.history) <= epoch:
|
||||
self.history.append([])
|
||||
@@ -114,6 +115,7 @@ class Stats:
|
||||
visdom_server="http://localhost",
|
||||
visdom_port=8097,
|
||||
):
|
||||
|
||||
self.log_vars = log_vars
|
||||
self.visdom_env = visdom_env
|
||||
self.visdom_server = visdom_server
|
||||
@@ -200,6 +202,7 @@ class Stats:
|
||||
self.log_vars.append(add_log_var)
|
||||
|
||||
def update(self, preds, time_start=None, freeze_iter=False, stat_set="train"):
|
||||
|
||||
if self.epoch == -1: # uninitialized
|
||||
logger.warning(
|
||||
"epoch==-1 means uninitialized stats structure -> new_epoch() called"
|
||||
@@ -216,6 +219,7 @@ class Stats:
|
||||
epoch = self.epoch
|
||||
|
||||
for stat in self.log_vars:
|
||||
|
||||
if stat not in self.stats[stat_set]:
|
||||
self.stats[stat_set][stat] = AverageMeter()
|
||||
|
||||
@@ -244,6 +248,7 @@ class Stats:
|
||||
self.stats[stat_set][stat].update(val, epoch=epoch, n=1)
|
||||
|
||||
def get_epoch_averages(self, epoch=None):
|
||||
|
||||
stat_sets = list(self.stats.keys())
|
||||
|
||||
if epoch is None:
|
||||
@@ -340,6 +345,7 @@ class Stats:
|
||||
def plot_stats(
|
||||
self, visdom_env=None, plot_file=None, visdom_server=None, visdom_port=None
|
||||
):
|
||||
|
||||
# use the cached visdom env if none supplied
|
||||
if visdom_env is None:
|
||||
visdom_env = self.visdom_env
|
||||
@@ -443,6 +449,7 @@ class Stats:
|
||||
warnings.warn("Cant dump stats due to insufficient permissions!")
|
||||
|
||||
def synchronize_logged_vars(self, log_vars, default_val=float("NaN")):
|
||||
|
||||
stat_sets = list(self.stats.keys())
|
||||
|
||||
# remove the additional log_vars
|
||||
@@ -483,12 +490,11 @@ class Stats:
|
||||
for ep in range(lastep):
|
||||
self.stats[stat_set][stat].update(default_val, n=1, epoch=ep)
|
||||
epoch_generated = self.stats[stat_set][stat].get_epoch()
|
||||
assert epoch_generated == self.epoch + 1, (
|
||||
"bad epoch of synchronized log_var! %d vs %d"
|
||||
% (
|
||||
self.epoch + 1,
|
||||
epoch_generated,
|
||||
)
|
||||
assert (
|
||||
epoch_generated == self.epoch + 1
|
||||
), "bad epoch of synchronized log_var! %d vs %d" % (
|
||||
self.epoch + 1,
|
||||
epoch_generated,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -16,17 +16,8 @@ from typing import Optional, Tuple, Union
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from PIL import Image
|
||||
|
||||
_NO_TORCHVISION = False
|
||||
try:
|
||||
import torchvision
|
||||
except ImportError:
|
||||
_NO_TORCHVISION = True
|
||||
|
||||
|
||||
_DEFAULT_FFMPEG = os.environ.get("FFMPEG", "ffmpeg")
|
||||
|
||||
matplotlib.use("Agg")
|
||||
@@ -45,7 +36,6 @@ class VideoWriter:
|
||||
fps: int = 20,
|
||||
output_format: str = "visdom",
|
||||
rmdir_allowed: bool = False,
|
||||
use_torchvision_video_writer: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -59,8 +49,6 @@ class VideoWriter:
|
||||
is supported.
|
||||
rmdir_allowed: If `True` delete and create `cache_dir` in case
|
||||
it is not empty.
|
||||
use_torchvision_video_writer: If `True` use `torchvision.io.write_video`
|
||||
to write the video
|
||||
"""
|
||||
self.rmdir_allowed = rmdir_allowed
|
||||
self.output_format = output_format
|
||||
@@ -68,14 +56,10 @@ class VideoWriter:
|
||||
self.out_path = out_path
|
||||
self.cache_dir = cache_dir
|
||||
self.ffmpeg_bin = ffmpeg_bin
|
||||
self.use_torchvision_video_writer = use_torchvision_video_writer
|
||||
self.frames = []
|
||||
self.regexp = "frame_%08d.png"
|
||||
self.frame_num = 0
|
||||
|
||||
if self.use_torchvision_video_writer:
|
||||
assert not _NO_TORCHVISION, "torchvision not available"
|
||||
|
||||
if self.cache_dir is not None:
|
||||
self.tmp_dir = None
|
||||
if os.path.isdir(self.cache_dir):
|
||||
@@ -130,7 +114,7 @@ class VideoWriter:
|
||||
resize = im.size
|
||||
# make sure size is divisible by 2
|
||||
resize = tuple([resize[i] + resize[i] % 2 for i in (0, 1)])
|
||||
|
||||
# pyre-fixme[16]: Module `Image` has no attribute `ANTIALIAS`.
|
||||
im = im.resize(resize, Image.ANTIALIAS)
|
||||
im.save(outfile)
|
||||
|
||||
@@ -155,56 +139,38 @@ class VideoWriter:
|
||||
# got `Optional[str]`.
|
||||
regexp = os.path.join(self.cache_dir, self.regexp)
|
||||
|
||||
if shutil.which(self.ffmpeg_bin) is None:
|
||||
raise ValueError(
|
||||
f"Cannot find ffmpeg as `{self.ffmpeg_bin}`. "
|
||||
+ "Please set FFMPEG in the environment or ffmpeg_bin on this class."
|
||||
)
|
||||
|
||||
if self.output_format == "visdom": # works for ppt too
|
||||
# Video codec parameters
|
||||
video_codec = "h264"
|
||||
crf = "18"
|
||||
b = "2000k"
|
||||
pix_fmt = "yuv420p"
|
||||
|
||||
if self.use_torchvision_video_writer:
|
||||
torchvision.io.write_video(
|
||||
self.out_path,
|
||||
torch.stack(
|
||||
[torch.from_numpy(np.array(Image.open(f))) for f in self.frames]
|
||||
),
|
||||
fps=self.fps,
|
||||
video_codec=video_codec,
|
||||
options={"crf": crf, "b": b, "pix_fmt": pix_fmt},
|
||||
args = [
|
||||
self.ffmpeg_bin,
|
||||
"-r",
|
||||
str(self.fps),
|
||||
"-i",
|
||||
regexp,
|
||||
"-vcodec",
|
||||
"h264",
|
||||
"-f",
|
||||
"mp4",
|
||||
"-y",
|
||||
"-crf",
|
||||
"18",
|
||||
"-b",
|
||||
"2000k",
|
||||
"-pix_fmt",
|
||||
"yuv420p",
|
||||
self.out_path,
|
||||
]
|
||||
if quiet:
|
||||
subprocess.check_call(
|
||||
args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
|
||||
)
|
||||
|
||||
else:
|
||||
if shutil.which(self.ffmpeg_bin) is None:
|
||||
raise ValueError(
|
||||
f"Cannot find ffmpeg as `{self.ffmpeg_bin}`. "
|
||||
+ "Please set FFMPEG in the environment or ffmpeg_bin on this class."
|
||||
)
|
||||
|
||||
args = [
|
||||
self.ffmpeg_bin,
|
||||
"-r",
|
||||
str(self.fps),
|
||||
"-i",
|
||||
regexp,
|
||||
"-vcodec",
|
||||
video_codec,
|
||||
"-f",
|
||||
"mp4",
|
||||
"-y",
|
||||
"-crf",
|
||||
crf,
|
||||
"-b",
|
||||
b,
|
||||
"-pix_fmt",
|
||||
pix_fmt,
|
||||
self.out_path,
|
||||
]
|
||||
if quiet:
|
||||
subprocess.check_call(
|
||||
args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
|
||||
)
|
||||
else:
|
||||
subprocess.check_call(args)
|
||||
subprocess.check_call(args)
|
||||
else:
|
||||
raise ValueError("no such output type %s" % str(self.output_format))
|
||||
|
||||
|
||||
@@ -163,8 +163,9 @@ def _read_chunks(
|
||||
if binary_data is not None:
|
||||
binary_data = np.frombuffer(binary_data, dtype=np.uint8)
|
||||
|
||||
assert binary_data is not None
|
||||
|
||||
# pyre-fixme[7]: Expected `Optional[Tuple[Dict[str, typing.Any],
|
||||
# ndarray[typing.Any, typing.Any]]]` but got `Tuple[typing.Any,
|
||||
# Optional[ndarray[typing.Any, dtype[typing.Any]]]]`.
|
||||
return json_data, binary_data
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
# pyre-unsafe
|
||||
|
||||
"""This module implements utility functions for loading .mtl files and textures."""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
|
||||
|
||||
"""This module implements utility functions for loading and saving meshes."""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
@@ -814,6 +813,7 @@ def _save(
|
||||
save_texture: bool = False,
|
||||
save_normals: bool = False,
|
||||
) -> None:
|
||||
|
||||
if len(verts) and (verts.dim() != 2 or verts.size(1) != 3):
|
||||
message = "'verts' should either be empty or of shape (num_verts, 3)."
|
||||
raise ValueError(message)
|
||||
|
||||
@@ -14,7 +14,6 @@ meshes as .off files.
|
||||
This format is introduced, for example, at
|
||||
http://www.geomview.org/docs/html/OFF.html .
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from typing import cast, Optional, Tuple, Union
|
||||
|
||||
@@ -85,7 +84,7 @@ def _read_faces_lump(
|
||||
)
|
||||
data = np.loadtxt(file, dtype=np.float32, ndmin=2, max_rows=n_faces)
|
||||
except ValueError as e:
|
||||
if n_faces > 1 and "number of columns" in e.args[0]:
|
||||
if n_faces > 1 and "Wrong number of columns" in e.args[0]:
|
||||
file.seek(old_offset)
|
||||
return None
|
||||
raise ValueError("Not enough face data.") from None
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
This module implements utility functions for loading and saving
|
||||
meshes and point clouds as PLY files.
|
||||
"""
|
||||
|
||||
import itertools
|
||||
import os
|
||||
import struct
|
||||
@@ -1247,10 +1246,13 @@ def _save_ply(
|
||||
return
|
||||
|
||||
color_np_type = np.ubyte if colors_as_uint8 else np.float32
|
||||
verts_dtype: list = [("verts", np.float32, 3)]
|
||||
verts_dtype = [("verts", np.float32, 3)]
|
||||
if verts_normals is not None:
|
||||
verts_dtype.append(("normals", np.float32, 3))
|
||||
if verts_colors is not None:
|
||||
# pyre-fixme[6]: For 1st argument expected `Tuple[str,
|
||||
# Type[floating[_32Bit]], int]` but got `Tuple[str,
|
||||
# Type[Union[floating[_32Bit], unsignedinteger[typing.Any]]], int]`.
|
||||
verts_dtype.append(("colors", color_np_type, 3))
|
||||
|
||||
vert_data = np.zeros(verts.shape[0], dtype=verts_dtype)
|
||||
|
||||
@@ -122,17 +122,12 @@ def corresponding_cameras_alignment(
|
||||
|
||||
# create a new cameras object and set the R and T accordingly
|
||||
cameras_src_aligned = cameras_src.clone()
|
||||
# pyre-fixme[6]: For 2nd argument expected `Tensor` but got `Union[Tensor, Module]`.
|
||||
cameras_src_aligned.R = torch.bmm(align_t_R.expand_as(cameras_src.R), cameras_src.R)
|
||||
cameras_src_aligned.T = (
|
||||
torch.bmm(
|
||||
align_t_T[:, None].repeat(cameras_src.R.shape[0], 1, 1),
|
||||
# pyre-fixme[6]: For 2nd argument expected `Tensor` but got
|
||||
# `Union[Tensor, Module]`.
|
||||
cameras_src.R,
|
||||
)[:, 0]
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, other: Union[bool, complex,
|
||||
# float, int, Tensor]) -> Tensor, Tensor, Module]` is not a function.
|
||||
+ cameras_src.T * align_t_s
|
||||
)
|
||||
|
||||
@@ -180,7 +175,6 @@ def _align_camera_extrinsics(
|
||||
R_A = (U V^T)^T
|
||||
```
|
||||
"""
|
||||
# pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Tensor, Module]`.
|
||||
RRcov = torch.bmm(cameras_src.R, cameras_tgt.R.transpose(2, 1)).mean(0)
|
||||
U, _, V = torch.svd(RRcov)
|
||||
align_t_R = V @ U.t()
|
||||
@@ -210,11 +204,7 @@ def _align_camera_extrinsics(
|
||||
T_A = mean(B) - mean(A) * s_A
|
||||
```
|
||||
"""
|
||||
# pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Tensor, Module]`.
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, Any, ...
|
||||
A = torch.bmm(cameras_src.R, cameras_src.T[:, :, None])[:, :, 0]
|
||||
# pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Tensor, Module]`.
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, Any, ...
|
||||
B = torch.bmm(cameras_src.R, cameras_tgt.T[:, :, None])[:, :, 0]
|
||||
Amu = A.mean(0, keepdim=True)
|
||||
Bmu = B.mean(0, keepdim=True)
|
||||
|
||||
@@ -62,7 +62,7 @@ def cubify(
|
||||
*,
|
||||
feats: Optional[torch.Tensor] = None,
|
||||
device=None,
|
||||
align: str = "topleft",
|
||||
align: str = "topleft"
|
||||
) -> Meshes:
|
||||
r"""
|
||||
Converts a voxel to a mesh by replacing each occupied voxel with a cube
|
||||
|
||||
@@ -85,6 +85,7 @@ class _points_to_volumes_function(Function):
|
||||
align_corners: bool,
|
||||
splat: bool,
|
||||
):
|
||||
|
||||
ctx.mark_dirty(volume_densities, volume_features)
|
||||
|
||||
N, P, D = points_3d.shape
|
||||
@@ -496,6 +497,7 @@ def _check_points_to_volumes_inputs(
|
||||
grid_sizes: torch.LongTensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
|
||||
max_grid_size = grid_sizes.max(dim=0).values
|
||||
if torch.prod(max_grid_size) > volume_densities.shape[1]:
|
||||
raise ValueError(
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
This module implements utility functions for sampling points from
|
||||
batches of meshes.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from typing import Tuple, Union
|
||||
|
||||
|
||||
@@ -353,16 +353,45 @@ def _create_verts_index(verts_per_mesh, edges_per_mesh, device=None):
|
||||
# e.g. verts_per_mesh = (4, 5, 6)
|
||||
# e.g. edges_per_mesh = (5, 7, 9)
|
||||
|
||||
rng = torch.arange(verts_per_mesh.shape[0], device=device) # (0,1,2)
|
||||
verts_nums = rng.repeat_interleave(
|
||||
verts_per_mesh
|
||||
) # (0,0,0,0,1,1,1,1,1,2,2,2,2,2,2)
|
||||
edges_nums = rng.repeat_interleave(
|
||||
edges_per_mesh
|
||||
) # (0,0,0,0,0,1,1,1,1,1,1,1,2,2,2,2,2,2,2,2,2)
|
||||
nums = torch.cat([verts_nums, edges_nums])
|
||||
V = verts_per_mesh.sum() # e.g. 15
|
||||
E = edges_per_mesh.sum() # e.g. 21
|
||||
|
||||
verts_per_mesh_cumsum = verts_per_mesh.cumsum(dim=0) # (N,) e.g. (4, 9, 15)
|
||||
edges_per_mesh_cumsum = edges_per_mesh.cumsum(dim=0) # (N,) e.g. (5, 12, 21)
|
||||
|
||||
v_to_e_idx = verts_per_mesh_cumsum.clone()
|
||||
|
||||
# vertex to edge index.
|
||||
v_to_e_idx[1:] += edges_per_mesh_cumsum[
|
||||
:-1
|
||||
] # e.g. (4, 9, 15) + (0, 5, 12) = (4, 14, 27)
|
||||
|
||||
# vertex to edge offset.
|
||||
v_to_e_offset = V - verts_per_mesh_cumsum # e.g. 15 - (4, 9, 15) = (11, 6, 0)
|
||||
v_to_e_offset[1:] += edges_per_mesh_cumsum[
|
||||
:-1
|
||||
] # e.g. (11, 6, 0) + (0, 5, 12) = (11, 11, 12)
|
||||
e_to_v_idx = (
|
||||
verts_per_mesh_cumsum[:-1] + edges_per_mesh_cumsum[:-1]
|
||||
) # (4, 9) + (5, 12) = (9, 21)
|
||||
e_to_v_offset = (
|
||||
verts_per_mesh_cumsum[:-1] - edges_per_mesh_cumsum[:-1] - V
|
||||
) # (4, 9) - (5, 12) - 15 = (-16, -18)
|
||||
|
||||
# Add one new vertex per edge.
|
||||
idx_diffs = torch.ones(V + E, device=device, dtype=torch.int64) # (36,)
|
||||
idx_diffs[v_to_e_idx] += v_to_e_offset
|
||||
idx_diffs[e_to_v_idx] += e_to_v_offset
|
||||
|
||||
# e.g.
|
||||
# [
|
||||
# 1, 1, 1, 1, 12, 1, 1, 1, 1,
|
||||
# -15, 1, 1, 1, 1, 12, 1, 1, 1, 1, 1, 1,
|
||||
# -17, 1, 1, 1, 1, 1, 13, 1, 1, 1, 1, 1, 1, 1
|
||||
# ]
|
||||
|
||||
verts_idx = idx_diffs.cumsum(dim=0) - 1
|
||||
|
||||
verts_idx = torch.argsort(nums, stable=True)
|
||||
# e.g.
|
||||
# [
|
||||
# 0, 1, 2, 3, 15, 16, 17, 18, 19, --> mesh 0
|
||||
@@ -371,6 +400,7 @@ def _create_verts_index(verts_per_mesh, edges_per_mesh, device=None):
|
||||
# ]
|
||||
# where for mesh 0, [0, 1, 2, 3] are the indices of the existing verts, and
|
||||
# [15, 16, 17, 18, 19] are the indices of the new verts after subdivision.
|
||||
|
||||
return verts_idx
|
||||
|
||||
|
||||
@@ -391,9 +421,44 @@ def _create_faces_index(faces_per_mesh: torch.Tensor, device=None):
|
||||
"""
|
||||
# e.g. faces_per_mesh = [2, 5, 3]
|
||||
|
||||
rng = torch.arange(faces_per_mesh.shape[0], device=device) # (0,1,2)
|
||||
nums = rng.repeat_interleave(faces_per_mesh).repeat(4)
|
||||
faces_idx = torch.argsort(nums, stable=True)
|
||||
F = faces_per_mesh.sum() # e.g. 10
|
||||
faces_per_mesh_cumsum = faces_per_mesh.cumsum(dim=0) # (N,) e.g. (2, 7, 10)
|
||||
|
||||
switch1_idx = faces_per_mesh_cumsum.clone()
|
||||
switch1_idx[1:] += (
|
||||
3 * faces_per_mesh_cumsum[:-1]
|
||||
) # e.g. (2, 7, 10) + (0, 6, 21) = (2, 13, 31)
|
||||
|
||||
switch2_idx = 2 * faces_per_mesh_cumsum # e.g. (4, 14, 20)
|
||||
switch2_idx[1:] += (
|
||||
2 * faces_per_mesh_cumsum[:-1]
|
||||
) # e.g. (4, 14, 20) + (0, 4, 14) = (4, 18, 34)
|
||||
|
||||
switch3_idx = 3 * faces_per_mesh_cumsum # e.g. (6, 21, 30)
|
||||
switch3_idx[1:] += faces_per_mesh_cumsum[
|
||||
:-1
|
||||
] # e.g. (6, 21, 30) + (0, 2, 7) = (6, 23, 37)
|
||||
|
||||
switch4_idx = 4 * faces_per_mesh_cumsum[:-1] # e.g. (8, 28)
|
||||
|
||||
switch123_offset = F - faces_per_mesh # e.g. (8, 5, 7)
|
||||
|
||||
# pyre-fixme[6]: For 1st param expected `Union[List[int], Size,
|
||||
# typing.Tuple[int, ...]]` but got `Tensor`.
|
||||
idx_diffs = torch.ones(4 * F, device=device, dtype=torch.int64)
|
||||
idx_diffs[switch1_idx] += switch123_offset
|
||||
idx_diffs[switch2_idx] += switch123_offset
|
||||
idx_diffs[switch3_idx] += switch123_offset
|
||||
idx_diffs[switch4_idx] -= 3 * F
|
||||
|
||||
# e.g
|
||||
# [
|
||||
# 1, 1, 9, 1, 9, 1, 9, 1, -> mesh 0
|
||||
# -29, 1, 1, 1, 1, 6, 1, 1, 1, 1, 6, 1, 1, 1, 1, 6, 1, 1, 1, 1, -> mesh 1
|
||||
# -29, 1, 1, 8, 1, 1, 8, 1, 1, 8, 1, 1 -> mesh 2
|
||||
# ]
|
||||
|
||||
faces_idx = idx_diffs.cumsum(dim=0) - 1
|
||||
|
||||
# e.g.
|
||||
# [
|
||||
|
||||
@@ -65,11 +65,7 @@ def _opencv_from_cameras_projection(
|
||||
cameras: PerspectiveCameras,
|
||||
image_size: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, memory_format:
|
||||
# Optional[memory_format] = ...) -> Tensor, Tensor, Module]` is not a function.
|
||||
R_pytorch3d = cameras.R.clone()
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, memory_format:
|
||||
# Optional[memory_format] = ...) -> Tensor, Tensor, Module]` is not a function.
|
||||
T_pytorch3d = cameras.T.clone()
|
||||
focal_pytorch3d = cameras.focal_length
|
||||
p0_pytorch3d = cameras.principal_point
|
||||
@@ -110,32 +106,30 @@ def _pulsar_from_opencv_projection(
|
||||
|
||||
# Validate parameters.
|
||||
image_size_wh = image_size.to(R).flip(dims=(1,))
|
||||
assert torch.all(image_size_wh > 0), (
|
||||
"height and width must be positive but min is: %s"
|
||||
% (str(image_size_wh.min().item()))
|
||||
assert torch.all(
|
||||
image_size_wh > 0
|
||||
), "height and width must be positive but min is: %s" % (
|
||||
str(image_size_wh.min().item())
|
||||
)
|
||||
assert camera_matrix.size(1) == 3 and camera_matrix.size(2) == 3, (
|
||||
"Incorrect camera matrix shape: expected 3x3 but got %dx%d"
|
||||
% (
|
||||
camera_matrix.size(1),
|
||||
camera_matrix.size(2),
|
||||
)
|
||||
assert (
|
||||
camera_matrix.size(1) == 3 and camera_matrix.size(2) == 3
|
||||
), "Incorrect camera matrix shape: expected 3x3 but got %dx%d" % (
|
||||
camera_matrix.size(1),
|
||||
camera_matrix.size(2),
|
||||
)
|
||||
assert R.size(1) == 3 and R.size(2) == 3, (
|
||||
"Incorrect R shape: expected 3x3 but got %dx%d"
|
||||
% (
|
||||
R.size(1),
|
||||
R.size(2),
|
||||
)
|
||||
assert (
|
||||
R.size(1) == 3 and R.size(2) == 3
|
||||
), "Incorrect R shape: expected 3x3 but got %dx%d" % (
|
||||
R.size(1),
|
||||
R.size(2),
|
||||
)
|
||||
if len(tvec.size()) == 2:
|
||||
tvec = tvec.unsqueeze(2)
|
||||
assert tvec.size(1) == 3 and tvec.size(2) == 1, (
|
||||
"Incorrect tvec shape: expected 3x1 but got %dx%d"
|
||||
% (
|
||||
tvec.size(1),
|
||||
tvec.size(2),
|
||||
)
|
||||
assert (
|
||||
tvec.size(1) == 3 and tvec.size(2) == 1
|
||||
), "Incorrect tvec shape: expected 3x1 but got %dx%d" % (
|
||||
tvec.size(1),
|
||||
tvec.size(2),
|
||||
)
|
||||
# Check batch size.
|
||||
batch_size = camera_matrix.size(0)
|
||||
@@ -143,12 +137,11 @@ def _pulsar_from_opencv_projection(
|
||||
batch_size,
|
||||
R.size(0),
|
||||
)
|
||||
assert tvec.size(0) == batch_size, (
|
||||
"Expected tvec to have batch size %d. Has size %d."
|
||||
% (
|
||||
batch_size,
|
||||
tvec.size(0),
|
||||
)
|
||||
assert (
|
||||
tvec.size(0) == batch_size
|
||||
), "Expected tvec to have batch size %d. Has size %d." % (
|
||||
batch_size,
|
||||
tvec.size(0),
|
||||
)
|
||||
# Check image sizes.
|
||||
image_w = image_size_wh[0, 0]
|
||||
|
||||
@@ -203,9 +203,7 @@ class CamerasBase(TensorProperties):
|
||||
"""
|
||||
R: torch.Tensor = kwargs.get("R", self.R)
|
||||
T: torch.Tensor = kwargs.get("T", self.T)
|
||||
# pyre-fixme[16]: `CamerasBase` has no attribute `R`.
|
||||
self.R = R
|
||||
# pyre-fixme[16]: `CamerasBase` has no attribute `T`.
|
||||
self.T = T
|
||||
world_to_view_transform = get_world_to_view_transform(R=R, T=T)
|
||||
return world_to_view_transform
|
||||
@@ -230,9 +228,7 @@ class CamerasBase(TensorProperties):
|
||||
a Transform3d object which represents a batch of transforms
|
||||
of shape (N, 3, 3)
|
||||
"""
|
||||
# pyre-fixme[16]: `CamerasBase` has no attribute `R`.
|
||||
self.R: torch.Tensor = kwargs.get("R", self.R)
|
||||
# pyre-fixme[16]: `CamerasBase` has no attribute `T`.
|
||||
self.T: torch.Tensor = kwargs.get("T", self.T)
|
||||
world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T)
|
||||
view_to_proj_transform = self.get_projection_transform(**kwargs)
|
||||
@@ -1176,12 +1172,7 @@ class PerspectiveCameras(CamerasBase):
|
||||
|
||||
unprojection_transform = to_camera_transform.inverse()
|
||||
xy_inv_depth = torch.cat(
|
||||
# pyre-fixme[6]: For 1st argument expected `Union[List[Tensor],
|
||||
# tuple[Tensor, ...]]` but got `Tuple[Tensor, float]`.
|
||||
# pyre-fixme[58]: `/` is not supported for operand types `float` and
|
||||
# `Tensor`.
|
||||
(xy_depth[..., :2], 1.0 / xy_depth[..., 2:3]),
|
||||
dim=-1, # type: ignore
|
||||
(xy_depth[..., :2], 1.0 / xy_depth[..., 2:3]), dim=-1 # type: ignore
|
||||
)
|
||||
return unprojection_transform.transform_points(xy_inv_depth)
|
||||
|
||||
|
||||
@@ -281,10 +281,8 @@ class FishEyeCameras(CamerasBase):
|
||||
# project from camera space to image space
|
||||
N = len(self.radial_params)
|
||||
if not self.check_input(points, N):
|
||||
msg = (
|
||||
"Expected points of (P, 3) with batch_size 1 or N, or shape (M, P, 3) \
|
||||
msg = "Expected points of (P, 3) with batch_size 1 or N, or shape (M, P, 3) \
|
||||
with batch_size 1; got points of shape %r and batch_size %r"
|
||||
)
|
||||
raise ValueError(msg % (points.shape, N))
|
||||
|
||||
if N == 1:
|
||||
|
||||
@@ -67,7 +67,7 @@ class HeterogeneousRayBundle:
|
||||
|
||||
|
||||
def ray_bundle_to_ray_points(
|
||||
ray_bundle: Union[RayBundle, HeterogeneousRayBundle],
|
||||
ray_bundle: Union[RayBundle, HeterogeneousRayBundle]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Converts rays parametrized with a `ray_bundle` (an instance of the `RayBundle`
|
||||
|
||||
@@ -266,9 +266,7 @@ class PointLights(TensorProperties):
|
||||
shape (P, 3) or (N, H, W, K, 3).
|
||||
"""
|
||||
if self.location.ndim == points.ndim:
|
||||
# pyre-fixme[7]: Expected `Tensor` but got `Union[Tensor, Module]`.
|
||||
return self.location
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
|
||||
return self.location[:, None, None, None, :]
|
||||
|
||||
def diffuse(self, normals, points) -> torch.Tensor:
|
||||
|
||||
@@ -168,7 +168,7 @@ def _get_culled_faces(face_verts: torch.Tensor, frustum: ClipFrustum) -> torch.T
|
||||
position of the clipping planes.
|
||||
|
||||
Returns:
|
||||
faces_culled: boolean tensor of size F specifying whether or not each face should be
|
||||
faces_culled: An boolean tensor of size F specifying whether or not each face should be
|
||||
culled.
|
||||
"""
|
||||
clipping_planes = (
|
||||
|
||||
@@ -726,17 +726,15 @@ class TexturesUV(TexturesBase):
|
||||
for each face
|
||||
verts_uvs: (N, V, 2) tensor giving the uv coordinates per vertex
|
||||
(a FloatTensor with values between 0 and 1).
|
||||
maps_ids: Used if there are to be multiple maps per face.
|
||||
This can be either a list of map_ids [(F,)]
|
||||
maps_ids: Used if there are to be multiple maps per face. This can be either a list of map_ids [(F,)]
|
||||
or a long tensor of shape (N, F) giving the id of the texture map
|
||||
for each face. If maps_ids is present, the maps has an extra dimension M
|
||||
(so maps_padded is (N, M, H, W, C) and maps_list has elements of
|
||||
shape (M, H, W, C)).
|
||||
Specifically, the color
|
||||
of a vertex V is given by an average of
|
||||
maps_padded[i, maps_ids[i, f], u, v, :]
|
||||
of a vertex V is given by an average of maps_padded[i, maps_ids[i, f], u, v, :]
|
||||
over u and v integers adjacent to
|
||||
_verts_uvs_padded[i, _faces_uvs_padded[i, f, 0], :] .
|
||||
_verts_uvs_padded[i, _faces_uvs_padded[i, f, 0], :] .
|
||||
align_corners: If true, the extreme values 0 and 1 for verts_uvs
|
||||
indicate the centers of the edge pixels in the maps.
|
||||
padding_mode: padding mode for outside grid values
|
||||
@@ -1239,8 +1237,7 @@ class TexturesUV(TexturesBase):
|
||||
texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2)
|
||||
return texels
|
||||
else:
|
||||
# We have maps_ids_padded: (N, F), textures_map: (N, M, Hi, Wi, C),
|
||||
# fragments.pix_to_face: (N, Ho, Wo, K)
|
||||
# We have maps_ids_padded: (N, F), textures_map: (N, M, Hi, Wi, C),fragmenmts.pix_to_face: (N, Ho, Wo, K)
|
||||
# Get pixel_to_map_ids: (N, K, Ho, Wo) by indexing pix_to_face into maps_ids
|
||||
N, M, H_in, W_in, C = texture_maps.shape # 3 for RGB
|
||||
|
||||
@@ -1251,9 +1248,8 @@ class TexturesUV(TexturesBase):
|
||||
pixel_to_map_ids = (
|
||||
maps_ids_padded.flatten()
|
||||
.gather(0, pix_to_face.flatten())
|
||||
.view(N, H_out, W_out, K, 1)
|
||||
.permute(0, 3, 1, 2, 4)
|
||||
) # N x H_out x W_out x K x 1
|
||||
.view(N, K, H_out, W_out)
|
||||
)
|
||||
|
||||
# Normalize between -1 and 1 with M (number of maps)
|
||||
pixel_to_map_ids = (2.0 * pixel_to_map_ids.float() / float(M - 1)) - 1
|
||||
@@ -1262,10 +1258,10 @@ class TexturesUV(TexturesBase):
|
||||
pixel_uvs.new_tensor([-1.0, 1.0]),
|
||||
pixel_uvs.new_tensor([1.0, -1.0]),
|
||||
pixel_uvs,
|
||||
) # N x H_out x W_out x K x 2
|
||||
)
|
||||
|
||||
# N x H_out x W_out x K x 3
|
||||
pixel_uvms = torch.cat((pixel_uvs, pixel_to_map_ids), dim=4)
|
||||
pixel_uvms = torch.cat((pixel_uvs, pixel_to_map_ids.unsqueeze(4)), dim=4)
|
||||
# (N, M, H, W, C) -> (N, C, M, H, W)
|
||||
texture_maps = texture_maps.permute(0, 4, 1, 2, 3)
|
||||
if texture_maps.device != pixel_uvs.device:
|
||||
@@ -1830,7 +1826,7 @@ class TexturesVertex(TexturesBase):
|
||||
representation) which overlap the pixel.
|
||||
|
||||
Returns:
|
||||
texels: A texture per pixel of shape (N, H, W, K, C).
|
||||
texels: An texture per pixel of shape (N, H, W, K, C).
|
||||
There will be one C dimensional value for each element in
|
||||
fragments.pix_to_face.
|
||||
"""
|
||||
|
||||
@@ -184,7 +184,7 @@ class EGLContext:
|
||||
"""
|
||||
# Lock used to prevent multiple threads from rendering on the same device
|
||||
# at the same time, creating/destroying contexts at the same time, etc.
|
||||
self.lock = threading.RLock()
|
||||
self.lock = threading.Lock()
|
||||
self.cuda_device_id = cuda_device_id
|
||||
self.device = _get_cuda_device(self.cuda_device_id)
|
||||
self.width = width
|
||||
@@ -224,14 +224,15 @@ class EGLContext:
|
||||
Throws:
|
||||
EGLError when the context cannot be made current or make non-current.
|
||||
"""
|
||||
with self.lock:
|
||||
egl.eglMakeCurrent(self.dpy, self.surface, self.surface, self.context)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
egl.eglMakeCurrent(
|
||||
self.dpy, egl.EGL_NO_SURFACE, egl.EGL_NO_SURFACE, egl.EGL_NO_CONTEXT
|
||||
)
|
||||
self.lock.acquire()
|
||||
egl.eglMakeCurrent(self.dpy, self.surface, self.surface, self.context)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
egl.eglMakeCurrent(
|
||||
self.dpy, egl.EGL_NO_SURFACE, egl.EGL_NO_SURFACE, egl.EGL_NO_CONTEXT
|
||||
)
|
||||
self.lock.release()
|
||||
|
||||
def get_context_info(self) -> Dict[str, Any]:
|
||||
"""
|
||||
|
||||
@@ -12,7 +12,6 @@ Proper Python support for pytorch requires creating a torch.autograd.function
|
||||
(independent of whether this is being done within the C++ module). This is done
|
||||
here and a torch.nn.Module is exposed for the use in more complex models.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
@@ -133,7 +133,8 @@ def _get_splat_kernel_normalization(
|
||||
epsilon = 0.05
|
||||
normalization_constant = torch.exp(
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||
-(offsets**2).sum(dim=1) / (2 * sigma**2)
|
||||
-(offsets**2).sum(dim=1)
|
||||
/ (2 * sigma**2)
|
||||
).sum()
|
||||
|
||||
# We add an epsilon to the normalization constant to ensure the gradient will travel
|
||||
|
||||
@@ -114,6 +114,7 @@ class TensorProperties(nn.Module):
|
||||
self.device = make_device(device)
|
||||
self._N = 0
|
||||
if kwargs is not None:
|
||||
|
||||
# broadcast all inputs which are float/int/list/tuple/tensor/array
|
||||
# set as attributes anything else e.g. strings, bools
|
||||
args_to_broadcast = {}
|
||||
@@ -438,7 +439,7 @@ def ndc_to_grid_sample_coords(
|
||||
|
||||
|
||||
def parse_image_size(
|
||||
image_size: Union[List[int], Tuple[int, int], int],
|
||||
image_size: Union[List[int], Tuple[int, int], int]
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Args:
|
||||
|
||||
@@ -1531,6 +1531,7 @@ class Meshes:
|
||||
|
||||
def sample_textures(self, fragments):
|
||||
if self.textures is not None:
|
||||
|
||||
# Check dimensions of textures match that of meshes
|
||||
shape_ok = self.textures.check_shapes(self._N, self._V, self._F)
|
||||
if not shape_ok:
|
||||
|
||||
@@ -1274,7 +1274,7 @@ def join_pointclouds_as_batch(pointclouds: Sequence[Pointclouds]) -> Pointclouds
|
||||
|
||||
|
||||
def join_pointclouds_as_scene(
|
||||
pointclouds: Union[Pointclouds, List[Pointclouds]],
|
||||
pointclouds: Union[Pointclouds, List[Pointclouds]]
|
||||
) -> Pointclouds:
|
||||
"""
|
||||
Joins a batch of point cloud in the form of a Pointclouds object or a list of Pointclouds
|
||||
|
||||
@@ -463,7 +463,7 @@ def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Ten
|
||||
return out[..., 1:]
|
||||
|
||||
|
||||
def axis_angle_to_matrix(axis_angle: torch.Tensor, fast: bool = False) -> torch.Tensor:
|
||||
def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Convert rotations given as axis/angle to rotation matrices.
|
||||
|
||||
@@ -472,91 +472,27 @@ def axis_angle_to_matrix(axis_angle: torch.Tensor, fast: bool = False) -> torch.
|
||||
as a tensor of shape (..., 3), where the magnitude is
|
||||
the angle turned anticlockwise in radians around the
|
||||
vector's direction.
|
||||
fast: Whether to use the new faster implementation (based on the
|
||||
Rodrigues formula) instead of the original implementation (which
|
||||
first converted to a quaternion and then back to a rotation matrix).
|
||||
|
||||
Returns:
|
||||
Rotation matrices as tensor of shape (..., 3, 3).
|
||||
"""
|
||||
if not fast:
|
||||
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
|
||||
|
||||
shape = axis_angle.shape
|
||||
device, dtype = axis_angle.device, axis_angle.dtype
|
||||
|
||||
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True).unsqueeze(-1)
|
||||
|
||||
rx, ry, rz = axis_angle[..., 0], axis_angle[..., 1], axis_angle[..., 2]
|
||||
zeros = torch.zeros(shape[:-1], dtype=dtype, device=device)
|
||||
cross_product_matrix = torch.stack(
|
||||
[zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=-1
|
||||
).view(shape + (3,))
|
||||
cross_product_matrix_sqrd = cross_product_matrix @ cross_product_matrix
|
||||
|
||||
identity = torch.eye(3, dtype=dtype, device=device)
|
||||
angles_sqrd = angles * angles
|
||||
angles_sqrd = torch.where(angles_sqrd == 0, 1, angles_sqrd)
|
||||
return (
|
||||
identity.expand(cross_product_matrix.shape)
|
||||
+ torch.sinc(angles / torch.pi) * cross_product_matrix
|
||||
+ ((1 - torch.cos(angles)) / angles_sqrd) * cross_product_matrix_sqrd
|
||||
)
|
||||
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
|
||||
|
||||
|
||||
def matrix_to_axis_angle(matrix: torch.Tensor, fast: bool = False) -> torch.Tensor:
|
||||
def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Convert rotations given as rotation matrices to axis/angle.
|
||||
|
||||
Args:
|
||||
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
||||
fast: Whether to use the new faster implementation (based on the
|
||||
Rodrigues formula) instead of the original implementation (which
|
||||
first converted to a quaternion and then back to a rotation matrix).
|
||||
|
||||
Returns:
|
||||
Rotations given as a vector in axis angle form, as a tensor
|
||||
of shape (..., 3), where the magnitude is the angle
|
||||
turned anticlockwise in radians around the vector's
|
||||
direction.
|
||||
|
||||
"""
|
||||
if not fast:
|
||||
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
|
||||
|
||||
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
||||
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
|
||||
|
||||
omegas = torch.stack(
|
||||
[
|
||||
matrix[..., 2, 1] - matrix[..., 1, 2],
|
||||
matrix[..., 0, 2] - matrix[..., 2, 0],
|
||||
matrix[..., 1, 0] - matrix[..., 0, 1],
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
norms = torch.norm(omegas, p=2, dim=-1, keepdim=True)
|
||||
traces = torch.diagonal(matrix, dim1=-2, dim2=-1).sum(-1).unsqueeze(-1)
|
||||
angles = torch.atan2(norms, traces - 1)
|
||||
|
||||
zeros = torch.zeros(3, dtype=matrix.dtype, device=matrix.device)
|
||||
omegas = torch.where(torch.isclose(angles, torch.zeros_like(angles)), zeros, omegas)
|
||||
|
||||
near_pi = angles.isclose(angles.new_full((1,), torch.pi)).squeeze(-1)
|
||||
|
||||
axis_angles = torch.empty_like(omegas)
|
||||
axis_angles[~near_pi] = (
|
||||
0.5 * omegas[~near_pi] / torch.sinc(angles[~near_pi] / torch.pi)
|
||||
)
|
||||
|
||||
# this derives from: nnT = (R + 1) / 2
|
||||
n = 0.5 * (
|
||||
matrix[near_pi][..., 0, :]
|
||||
+ torch.eye(1, 3, dtype=matrix.dtype, device=matrix.device)
|
||||
)
|
||||
axis_angles[near_pi] = angles[near_pi] * n / torch.norm(n)
|
||||
|
||||
return axis_angles
|
||||
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
|
||||
|
||||
|
||||
def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
|
||||
@@ -573,10 +509,22 @@ def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
|
||||
quaternions with real part first, as tensor of shape (..., 4).
|
||||
"""
|
||||
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
|
||||
sin_half_angles_over_angles = 0.5 * torch.sinc(angles * 0.5 / torch.pi)
|
||||
return torch.cat(
|
||||
[torch.cos(angles * 0.5), axis_angle * sin_half_angles_over_angles], dim=-1
|
||||
half_angles = angles * 0.5
|
||||
eps = 1e-6
|
||||
small_angles = angles.abs() < eps
|
||||
sin_half_angles_over_angles = torch.empty_like(angles)
|
||||
sin_half_angles_over_angles[~small_angles] = (
|
||||
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
|
||||
)
|
||||
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
|
||||
# so sin(x/2)/x is about 1/2 - (x*x)/48
|
||||
sin_half_angles_over_angles[small_angles] = (
|
||||
0.5 - (angles[small_angles] * angles[small_angles]) / 48
|
||||
)
|
||||
quaternions = torch.cat(
|
||||
[torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
|
||||
)
|
||||
return quaternions
|
||||
|
||||
|
||||
def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
|
||||
@@ -595,9 +543,18 @@ def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
|
||||
half_angles = torch.atan2(norms, quaternions[..., :1])
|
||||
sin_half_angles_over_angles = 0.5 * torch.sinc(half_angles / torch.pi)
|
||||
# angles/2 are between [-pi/2, pi/2], thus sin_half_angles_over_angles
|
||||
# can't be zero
|
||||
angles = 2 * half_angles
|
||||
eps = 1e-6
|
||||
small_angles = angles.abs() < eps
|
||||
sin_half_angles_over_angles = torch.empty_like(angles)
|
||||
sin_half_angles_over_angles[~small_angles] = (
|
||||
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
|
||||
)
|
||||
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
|
||||
# so sin(x/2)/x is about 1/2 - (x*x)/48
|
||||
sin_half_angles_over_angles[small_angles] = (
|
||||
0.5 - (angles[small_angles] * angles[small_angles]) / 48
|
||||
)
|
||||
return quaternions[..., 1:] / sin_half_angles_over_angles
|
||||
|
||||
|
||||
|
||||
@@ -311,7 +311,9 @@ def plot_scene(
|
||||
)
|
||||
else:
|
||||
msg = "Invalid number {} of viewpoint cameras were provided. Either 1 \
|
||||
or {} cameras are required".format(len(viewpoint_cameras), len(subplots))
|
||||
or {} cameras are required".format(
|
||||
len(viewpoint_cameras), len(subplots)
|
||||
)
|
||||
warnings.warn(msg)
|
||||
|
||||
for subplot_idx in range(len(subplots)):
|
||||
@@ -586,15 +588,9 @@ def _add_struct_from_batch(
|
||||
if isinstance(batched_struct, CamerasBase):
|
||||
# we can't index directly into camera batches
|
||||
R, T = batched_struct.R, batched_struct.T
|
||||
# pyre-fixme[6]: For 1st argument expected
|
||||
# `pyre_extensions.PyreReadOnly[Sized]` but got `Union[Tensor, Module]`.
|
||||
r_idx = min(scene_num, len(R) - 1)
|
||||
# pyre-fixme[6]: For 1st argument expected
|
||||
# `pyre_extensions.PyreReadOnly[Sized]` but got `Union[Tensor, Module]`.
|
||||
t_idx = min(scene_num, len(T) - 1)
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
|
||||
R = R[r_idx].unsqueeze(0)
|
||||
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
|
||||
T = T[t_idx].unsqueeze(0)
|
||||
struct = CamerasBase(device=batched_struct.device, R=R, T=T)
|
||||
elif _is_ray_bundle(batched_struct) and not _is_heterogeneous_ray_bundle(
|
||||
|
||||
@@ -11,6 +11,7 @@ from tests.test_ball_query import TestBallQuery
|
||||
|
||||
|
||||
def bm_ball_query() -> None:
|
||||
|
||||
backends = ["cpu", "cuda:0"]
|
||||
|
||||
kwargs_list = []
|
||||
|
||||
@@ -11,6 +11,7 @@ from tests.test_cameras_alignment import TestCamerasAlignment
|
||||
|
||||
|
||||
def bm_cameras_alignment() -> None:
|
||||
|
||||
case_grid = {
|
||||
"batch_size": [10, 100, 1000],
|
||||
"mode": ["centers", "extrinsics"],
|
||||
|
||||
@@ -11,6 +11,7 @@ from tests.test_knn import TestKNN
|
||||
|
||||
|
||||
def bm_knn() -> None:
|
||||
|
||||
backends = ["cpu", "cuda:0"]
|
||||
|
||||
kwargs_list = []
|
||||
|
||||
@@ -12,6 +12,7 @@ from tests.test_point_mesh_distance import TestPointMeshDistance
|
||||
|
||||
|
||||
def bm_point_mesh_distance() -> None:
|
||||
|
||||
backend = ["cuda:0"]
|
||||
|
||||
kwargs_list = []
|
||||
|
||||
@@ -12,6 +12,7 @@ from tests.test_points_alignment import TestCorrespondingPointsAlignment, TestIC
|
||||
|
||||
|
||||
def bm_iterative_closest_point() -> None:
|
||||
|
||||
case_grid = {
|
||||
"batch_size": [1, 10],
|
||||
"dim": [3, 20],
|
||||
@@ -42,6 +43,7 @@ def bm_iterative_closest_point() -> None:
|
||||
|
||||
|
||||
def bm_corresponding_points_alignment() -> None:
|
||||
|
||||
case_grid = {
|
||||
"allow_reflection": [True, False],
|
||||
"batch_size": [1, 10, 100],
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user