mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-02 02:05:59 +08:00
Compare commits
1 Commits
main
...
bottler/ac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9c586b1351 |
@@ -88,6 +88,7 @@ def workflow_pair(
|
|||||||
upload=False,
|
upload=False,
|
||||||
filter_branch,
|
filter_branch,
|
||||||
):
|
):
|
||||||
|
|
||||||
w = []
|
w = []
|
||||||
py = python_version.replace(".", "")
|
py = python_version.replace(".", "")
|
||||||
pyt = pytorch_version.replace(".", "")
|
pyt = pytorch_version.replace(".", "")
|
||||||
@@ -126,6 +127,7 @@ def generate_base_workflow(
|
|||||||
btype,
|
btype,
|
||||||
filter_branch=None,
|
filter_branch=None,
|
||||||
):
|
):
|
||||||
|
|
||||||
d = {
|
d = {
|
||||||
"name": base_workflow_name,
|
"name": base_workflow_name,
|
||||||
"python_version": python_version,
|
"python_version": python_version,
|
||||||
|
|||||||
3
.github/workflows/build.yml
vendored
3
.github/workflows/build.yml
vendored
@@ -3,9 +3,6 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
push:
|
|
||||||
branches:
|
|
||||||
- main
|
|
||||||
jobs:
|
jobs:
|
||||||
binary_linux_conda_cuda:
|
binary_linux_conda_cuda:
|
||||||
runs-on: 4-core-ubuntu-gpu-t4
|
runs-on: 4-core-ubuntu-gpu-t4
|
||||||
|
|||||||
@@ -10,7 +10,7 @@
|
|||||||
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
||||||
DIR=$(dirname "${DIR}")
|
DIR=$(dirname "${DIR}")
|
||||||
|
|
||||||
if [[ -f "${DIR}/BUCK" ]]
|
if [[ -f "${DIR}/TARGETS" ]]
|
||||||
then
|
then
|
||||||
pyfmt "${DIR}"
|
pyfmt "${DIR}"
|
||||||
else
|
else
|
||||||
@@ -36,5 +36,5 @@ then
|
|||||||
|
|
||||||
echo "Running pyre..."
|
echo "Running pyre..."
|
||||||
echo "To restart/kill pyre server, run 'pyre restart' or 'pyre kill' in fbcode/"
|
echo "To restart/kill pyre server, run 'pyre restart' or 'pyre kill' in fbcode/"
|
||||||
( cd ~/fbsource/fbcode; arc pyre check //vision/fair/pytorch3d/... )
|
( cd ~/fbsource/fbcode; pyre -l vision/fair/pytorch3d/ )
|
||||||
fi
|
fi
|
||||||
|
|||||||
@@ -19,6 +19,7 @@
|
|||||||
#
|
#
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import unittest.mock as mock
|
import unittest.mock as mock
|
||||||
|
|
||||||
from recommonmark.parser import CommonMarkParser
|
from recommonmark.parser import CommonMarkParser
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ This example demonstrates the most trivial, direct interface of the pulsar
|
|||||||
sphere renderer. It renders and saves an image with 10 random spheres.
|
sphere renderer. It renders and saves an image with 10 random spheres.
|
||||||
Output: basic.png.
|
Output: basic.png.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from os import path
|
from os import path
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ interface for sphere renderering. It renders and saves an image with
|
|||||||
10 random spheres.
|
10 random spheres.
|
||||||
Output: basic-pt3d.png.
|
Output: basic-pt3d.png.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from os import path
|
from os import path
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ distorted. Gradient-based optimization is used to converge towards the
|
|||||||
original camera parameters.
|
original camera parameters.
|
||||||
Output: cam.gif.
|
Output: cam.gif.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from os import path
|
from os import path
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ distorted. Gradient-based optimization is used to converge towards the
|
|||||||
original camera parameters.
|
original camera parameters.
|
||||||
Output: cam-pt3d.gif
|
Output: cam-pt3d.gif
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from os import path
|
from os import path
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ This example is not available yet through the 'unified' interface,
|
|||||||
because opacity support has not landed in PyTorch3D for general data
|
because opacity support has not landed in PyTorch3D for general data
|
||||||
structures yet.
|
structures yet.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from os import path
|
from os import path
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ The scene is initialized with random spheres. Gradient-based
|
|||||||
optimization is used to converge towards a faithful
|
optimization is used to converge towards a faithful
|
||||||
scene representation.
|
scene representation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ The scene is initialized with random spheres. Gradient-based
|
|||||||
optimization is used to converge towards a faithful
|
optimization is used to converge towards a faithful
|
||||||
scene representation.
|
scene representation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ requirements:
|
|||||||
|
|
||||||
build:
|
build:
|
||||||
string: py{{py}}_{{ environ['CU_VERSION'] }}_pyt{{ environ['PYTORCH_VERSION_NODOT']}}
|
string: py{{py}}_{{ environ['CU_VERSION'] }}_pyt{{ environ['PYTORCH_VERSION_NODOT']}}
|
||||||
|
# script: LD_LIBRARY_PATH=$PREFIX/lib:$BUILD_PREFIX/lib:$LD_LIBRARY_PATH python setup.py install --single-version-externally-managed --record=record.txt # [not win]
|
||||||
script: python setup.py install --single-version-externally-managed --record=record.txt # [not win]
|
script: python setup.py install --single-version-externally-managed --record=record.txt # [not win]
|
||||||
script_env:
|
script_env:
|
||||||
- CUDA_HOME
|
- CUDA_HOME
|
||||||
@@ -56,6 +57,7 @@ test:
|
|||||||
- pandas
|
- pandas
|
||||||
- sqlalchemy
|
- sqlalchemy
|
||||||
commands:
|
commands:
|
||||||
|
#pytest .
|
||||||
python -m unittest discover -v -s tests -t .
|
python -m unittest discover -v -s tests -t .
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
|
|
||||||
# pyre-unsafe
|
# pyre-unsafe
|
||||||
|
|
||||||
""" "
|
""""
|
||||||
This file is the entry point for launching experiments with Implicitron.
|
This file is the entry point for launching experiments with Implicitron.
|
||||||
|
|
||||||
Launch Training
|
Launch Training
|
||||||
@@ -44,22 +44,25 @@ The outputs of the experiment are saved and logged in multiple ways:
|
|||||||
config file.
|
config file.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from dataclasses import field
|
from dataclasses import field
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from pytorch3d.implicitron.dataset.data_source import (
|
from pytorch3d.implicitron.dataset.data_source import (
|
||||||
DataSourceBase,
|
DataSourceBase,
|
||||||
ImplicitronDataSource,
|
ImplicitronDataSource,
|
||||||
)
|
)
|
||||||
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
|
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
|
||||||
|
|
||||||
from pytorch3d.implicitron.models.renderer.multipass_ea import (
|
from pytorch3d.implicitron.models.renderer.multipass_ea import (
|
||||||
MultiPassEmissionAbsorptionRenderer,
|
MultiPassEmissionAbsorptionRenderer,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import os
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch.optim
|
import torch.optim
|
||||||
|
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
|
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
|
||||||
from pytorch3d.implicitron.tools import model_io
|
from pytorch3d.implicitron.tools import model_io
|
||||||
@@ -25,6 +26,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class ModelFactoryBase(ReplaceableBase):
|
class ModelFactoryBase(ReplaceableBase):
|
||||||
|
|
||||||
resume: bool = True # resume from the last checkpoint
|
resume: bool = True # resume from the last checkpoint
|
||||||
|
|
||||||
def __call__(self, **kwargs) -> ImplicitronModelBase:
|
def __call__(self, **kwargs) -> ImplicitronModelBase:
|
||||||
@@ -114,9 +116,7 @@ class ImplicitronModelFactory(ModelFactoryBase):
|
|||||||
"cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
|
"cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
|
||||||
}
|
}
|
||||||
model_state_dict = torch.load(
|
model_state_dict = torch.load(
|
||||||
model_io.get_model_path(model_path),
|
model_io.get_model_path(model_path), map_location=map_location
|
||||||
map_location=map_location,
|
|
||||||
weights_only=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -14,7 +14,9 @@ from dataclasses import field
|
|||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch.optim
|
import torch.optim
|
||||||
|
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
|
|
||||||
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
|
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
|
||||||
from pytorch3d.implicitron.tools import model_io
|
from pytorch3d.implicitron.tools import model_io
|
||||||
from pytorch3d.implicitron.tools.config import (
|
from pytorch3d.implicitron.tools.config import (
|
||||||
@@ -121,7 +123,6 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
|
|||||||
"""
|
"""
|
||||||
# Get the parameters to optimize
|
# Get the parameters to optimize
|
||||||
if hasattr(model, "_get_param_groups"): # use the model function
|
if hasattr(model, "_get_param_groups"): # use the model function
|
||||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
|
||||||
p_groups = model._get_param_groups(self.lr, wd=self.weight_decay)
|
p_groups = model._get_param_groups(self.lr, wd=self.weight_decay)
|
||||||
else:
|
else:
|
||||||
p_groups = [
|
p_groups = [
|
||||||
@@ -240,7 +241,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
|
|||||||
map_location = {
|
map_location = {
|
||||||
"cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
|
"cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
|
||||||
}
|
}
|
||||||
optimizer_state = torch.load(opt_path, map_location, weights_only=True)
|
optimizer_state = torch.load(opt_path, map_location)
|
||||||
else:
|
else:
|
||||||
raise FileNotFoundError(f"Optimizer state {opt_path} does not exist.")
|
raise FileNotFoundError(f"Optimizer state {opt_path} does not exist.")
|
||||||
return optimizer_state
|
return optimizer_state
|
||||||
|
|||||||
@@ -161,6 +161,7 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
|
|||||||
for epoch in range(start_epoch, self.max_epochs):
|
for epoch in range(start_epoch, self.max_epochs):
|
||||||
# automatic new_epoch and plotting of stats at every epoch start
|
# automatic new_epoch and plotting of stats at every epoch start
|
||||||
with stats:
|
with stats:
|
||||||
|
|
||||||
# Make sure to re-seed random generators to ensure reproducibility
|
# Make sure to re-seed random generators to ensure reproducibility
|
||||||
# even after restart.
|
# even after restart.
|
||||||
seed_all_random_engines(seed + epoch)
|
seed_all_random_engines(seed + epoch)
|
||||||
@@ -394,7 +395,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
|
|||||||
):
|
):
|
||||||
prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
|
prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
|
||||||
if hasattr(model, "visualize"):
|
if hasattr(model, "visualize"):
|
||||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
|
||||||
model.visualize(
|
model.visualize(
|
||||||
viz,
|
viz,
|
||||||
visdom_env_imgs,
|
visdom_env_imgs,
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import unittest
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from hydra import compose, initialize_config_dir
|
from hydra import compose, initialize_config_dir
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from projects.implicitron_trainer.impl.optimizer_factory import (
|
from projects.implicitron_trainer.impl.optimizer_factory import (
|
||||||
@@ -52,8 +53,12 @@ class TestExperiment(unittest.TestCase):
|
|||||||
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_class_type = (
|
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_class_type = (
|
||||||
"JsonIndexDatasetMapProvider"
|
"JsonIndexDatasetMapProvider"
|
||||||
)
|
)
|
||||||
dataset_args = cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
dataset_args = (
|
||||||
dataloader_args = cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
|
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||||
|
)
|
||||||
|
dataloader_args = (
|
||||||
|
cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
|
||||||
|
)
|
||||||
dataset_args.category = "skateboard"
|
dataset_args.category = "skateboard"
|
||||||
dataset_args.test_restrict_sequence_id = 0
|
dataset_args.test_restrict_sequence_id = 0
|
||||||
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
||||||
@@ -89,8 +94,12 @@ class TestExperiment(unittest.TestCase):
|
|||||||
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_class_type = (
|
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_class_type = (
|
||||||
"JsonIndexDatasetMapProvider"
|
"JsonIndexDatasetMapProvider"
|
||||||
)
|
)
|
||||||
dataset_args = cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
dataset_args = (
|
||||||
dataloader_args = cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
|
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||||
|
)
|
||||||
|
dataloader_args = (
|
||||||
|
cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
|
||||||
|
)
|
||||||
dataset_args.category = "skateboard"
|
dataset_args.category = "skateboard"
|
||||||
dataset_args.test_restrict_sequence_id = 0
|
dataset_args.test_restrict_sequence_id = 0
|
||||||
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
||||||
@@ -102,7 +111,9 @@ class TestExperiment(unittest.TestCase):
|
|||||||
cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs = 2
|
cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs = 2
|
||||||
cfg.training_loop_ImplicitronTrainingLoop_args.store_checkpoints = False
|
cfg.training_loop_ImplicitronTrainingLoop_args.store_checkpoints = False
|
||||||
cfg.optimizer_factory_ImplicitronOptimizerFactory_args.lr_policy = "Exponential"
|
cfg.optimizer_factory_ImplicitronOptimizerFactory_args.lr_policy = "Exponential"
|
||||||
cfg.optimizer_factory_ImplicitronOptimizerFactory_args.exponential_lr_step_size = 2
|
cfg.optimizer_factory_ImplicitronOptimizerFactory_args.exponential_lr_step_size = (
|
||||||
|
2
|
||||||
|
)
|
||||||
|
|
||||||
if DEBUG:
|
if DEBUG:
|
||||||
experiment.dump_cfg(cfg)
|
experiment.dump_cfg(cfg)
|
||||||
|
|||||||
@@ -81,9 +81,8 @@ class TestOptimizerFactory(unittest.TestCase):
|
|||||||
|
|
||||||
def test_param_overrides_self_param_group_assignment(self):
|
def test_param_overrides_self_param_group_assignment(self):
|
||||||
pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)]
|
pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)]
|
||||||
na, nb = (
|
na, nb = Node(params=[pa]), Node(
|
||||||
Node(params=[pa]),
|
params=[pb], param_groups={"self": "pb_self", "p1": "pb_param"}
|
||||||
Node(params=[pb], param_groups={"self": "pb_self", "p1": "pb_param"}),
|
|
||||||
)
|
)
|
||||||
root = Node(children=[na, nb], params=[pc], param_groups={"m1": "pb_member"})
|
root = Node(children=[na, nb], params=[pc], param_groups={"m1": "pb_member"})
|
||||||
param_groups = self._get_param_groups(root)
|
param_groups = self._get_param_groups(root)
|
||||||
|
|||||||
@@ -84,9 +84,9 @@ def get_nerf_datasets(
|
|||||||
|
|
||||||
if autodownload and any(not os.path.isfile(p) for p in (cameras_path, image_path)):
|
if autodownload and any(not os.path.isfile(p) for p in (cameras_path, image_path)):
|
||||||
# Automatically download the data files if missing.
|
# Automatically download the data files if missing.
|
||||||
download_data([dataset_name], data_root=data_root)
|
download_data((dataset_name,), data_root=data_root)
|
||||||
|
|
||||||
train_data = torch.load(cameras_path, weights_only=True)
|
train_data = torch.load(cameras_path)
|
||||||
n_cameras = train_data["cameras"]["R"].shape[0]
|
n_cameras = train_data["cameras"]["R"].shape[0]
|
||||||
|
|
||||||
_image_max_image_pixels = Image.MAX_IMAGE_PIXELS
|
_image_max_image_pixels = Image.MAX_IMAGE_PIXELS
|
||||||
|
|||||||
@@ -194,6 +194,7 @@ class Stats:
|
|||||||
it = self.it[stat_set]
|
it = self.it[stat_set]
|
||||||
|
|
||||||
for stat in self.log_vars:
|
for stat in self.log_vars:
|
||||||
|
|
||||||
if stat not in self.stats[stat_set]:
|
if stat not in self.stats[stat_set]:
|
||||||
self.stats[stat_set][stat] = AverageMeter()
|
self.stats[stat_set][stat] = AverageMeter()
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ CONFIG_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs"
|
|||||||
|
|
||||||
@hydra.main(config_path=CONFIG_DIR, config_name="lego")
|
@hydra.main(config_path=CONFIG_DIR, config_name="lego")
|
||||||
def main(cfg: DictConfig):
|
def main(cfg: DictConfig):
|
||||||
|
|
||||||
# Device on which to run.
|
# Device on which to run.
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
@@ -62,7 +63,7 @@ def main(cfg: DictConfig):
|
|||||||
raise ValueError(f"Model checkpoint {checkpoint_path} does not exist!")
|
raise ValueError(f"Model checkpoint {checkpoint_path} does not exist!")
|
||||||
|
|
||||||
print(f"Loading checkpoint {checkpoint_path}.")
|
print(f"Loading checkpoint {checkpoint_path}.")
|
||||||
loaded_data = torch.load(checkpoint_path, weights_only=True)
|
loaded_data = torch.load(checkpoint_path)
|
||||||
# Do not load the cached xy grid.
|
# Do not load the cached xy grid.
|
||||||
# - this allows setting an arbitrary evaluation image size.
|
# - this allows setting an arbitrary evaluation image size.
|
||||||
state_dict = {
|
state_dict = {
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ class TestRaysampler(unittest.TestCase):
|
|||||||
cameras, rays = [], []
|
cameras, rays = [], []
|
||||||
|
|
||||||
for _ in range(batch_size):
|
for _ in range(batch_size):
|
||||||
|
|
||||||
R = random_rotations(1)
|
R = random_rotations(1)
|
||||||
T = torch.randn(1, 3)
|
T = torch.randn(1, 3)
|
||||||
focal_length = torch.rand(1, 2) + 0.5
|
focal_length = torch.rand(1, 2) + 0.5
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ CONFIG_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs"
|
|||||||
|
|
||||||
@hydra.main(config_path=CONFIG_DIR, config_name="lego")
|
@hydra.main(config_path=CONFIG_DIR, config_name="lego")
|
||||||
def main(cfg: DictConfig):
|
def main(cfg: DictConfig):
|
||||||
|
|
||||||
# Set the relevant seeds for reproducibility.
|
# Set the relevant seeds for reproducibility.
|
||||||
np.random.seed(cfg.seed)
|
np.random.seed(cfg.seed)
|
||||||
torch.manual_seed(cfg.seed)
|
torch.manual_seed(cfg.seed)
|
||||||
@@ -76,7 +77,7 @@ def main(cfg: DictConfig):
|
|||||||
# Resume training if requested.
|
# Resume training if requested.
|
||||||
if cfg.resume and os.path.isfile(checkpoint_path):
|
if cfg.resume and os.path.isfile(checkpoint_path):
|
||||||
print(f"Resuming from checkpoint {checkpoint_path}.")
|
print(f"Resuming from checkpoint {checkpoint_path}.")
|
||||||
loaded_data = torch.load(checkpoint_path, weights_only=True)
|
loaded_data = torch.load(checkpoint_path)
|
||||||
model.load_state_dict(loaded_data["model"])
|
model.load_state_dict(loaded_data["model"])
|
||||||
stats = pickle.loads(loaded_data["stats"])
|
stats = pickle.loads(loaded_data["stats"])
|
||||||
print(f" => resuming from epoch {stats.epoch}.")
|
print(f" => resuming from epoch {stats.epoch}.")
|
||||||
@@ -218,6 +219,7 @@ def main(cfg: DictConfig):
|
|||||||
|
|
||||||
# Validation
|
# Validation
|
||||||
if epoch % cfg.validation_epoch_interval == 0 and epoch > 0:
|
if epoch % cfg.validation_epoch_interval == 0 and epoch > 0:
|
||||||
|
|
||||||
# Sample a validation camera/image.
|
# Sample a validation camera/image.
|
||||||
val_batch = next(val_dataloader.__iter__())
|
val_batch = next(val_dataloader.__iter__())
|
||||||
val_image, val_camera, camera_idx = val_batch[0].values()
|
val_image, val_camera, camera_idx = val_batch[0].values()
|
||||||
|
|||||||
@@ -6,4 +6,4 @@
|
|||||||
|
|
||||||
# pyre-unsafe
|
# pyre-unsafe
|
||||||
|
|
||||||
__version__ = "0.7.9"
|
__version__ = "0.7.8"
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ Some functions which depend on PyTorch or Python versions.
|
|||||||
|
|
||||||
|
|
||||||
def meshgrid_ij(
|
def meshgrid_ij(
|
||||||
*A: Union[torch.Tensor, Sequence[torch.Tensor]],
|
*A: Union[torch.Tensor, Sequence[torch.Tensor]]
|
||||||
) -> Tuple[torch.Tensor, ...]: # pragma: no cover
|
) -> Tuple[torch.Tensor, ...]: # pragma: no cover
|
||||||
"""
|
"""
|
||||||
Like torch.meshgrid was before PyTorch 1.10.0, i.e. with indexing set to ij
|
Like torch.meshgrid was before PyTorch 1.10.0, i.e. with indexing set to ij
|
||||||
|
|||||||
@@ -82,12 +82,10 @@ class _SymEig3x3(nn.Module):
|
|||||||
q = inputs_trace / 3.0
|
q = inputs_trace / 3.0
|
||||||
|
|
||||||
# Calculate squared sum of elements outside the main diagonal / 2
|
# Calculate squared sum of elements outside the main diagonal / 2
|
||||||
p1 = (
|
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||||
torch.square(inputs).sum(dim=(-1, -2)) - torch.square(inputs_diag).sum(-1)
|
p1 = ((inputs**2).sum(dim=(-1, -2)) - (inputs_diag**2).sum(-1)) / 2
|
||||||
) / 2
|
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||||
p2 = torch.square(inputs_diag - q[..., None]).sum(dim=-1) + 2.0 * p1.clamp(
|
p2 = ((inputs_diag - q[..., None]) ** 2).sum(dim=-1) + 2.0 * p1.clamp(self._eps)
|
||||||
self._eps
|
|
||||||
)
|
|
||||||
|
|
||||||
p = torch.sqrt(p2 / 6.0)
|
p = torch.sqrt(p2 / 6.0)
|
||||||
B = (inputs - q[..., None, None] * self._identity) / p[..., None, None]
|
B = (inputs - q[..., None, None] * self._identity) / p[..., None, None]
|
||||||
@@ -106,9 +104,7 @@ class _SymEig3x3(nn.Module):
|
|||||||
# Soft dispatch between the degenerate case (diagonal A) and general.
|
# Soft dispatch between the degenerate case (diagonal A) and general.
|
||||||
# diag_soft_cond -> 1.0 when p1 < 6 * eps and diag_soft_cond -> 0.0 otherwise.
|
# diag_soft_cond -> 1.0 when p1 < 6 * eps and diag_soft_cond -> 0.0 otherwise.
|
||||||
# We use 6 * eps to take into account the error accumulated during the p1 summation
|
# We use 6 * eps to take into account the error accumulated during the p1 summation
|
||||||
diag_soft_cond = torch.exp(-torch.square(p1 / (6 * self._eps))).detach()[
|
diag_soft_cond = torch.exp(-((p1 / (6 * self._eps)) ** 2)).detach()[..., None]
|
||||||
..., None
|
|
||||||
]
|
|
||||||
|
|
||||||
# Eigenvalues are the ordered elements of main diagonal in the degenerate case
|
# Eigenvalues are the ordered elements of main diagonal in the degenerate case
|
||||||
diag_eigenvals, _ = torch.sort(inputs_diag, dim=-1)
|
diag_eigenvals, _ = torch.sort(inputs_diag, dim=-1)
|
||||||
@@ -203,7 +199,8 @@ class _SymEig3x3(nn.Module):
|
|||||||
cross_products[..., :1, :]
|
cross_products[..., :1, :]
|
||||||
)
|
)
|
||||||
|
|
||||||
norms_sq = torch.square(cross_products).sum(dim=-1)
|
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||||
|
norms_sq = (cross_products**2).sum(dim=-1)
|
||||||
max_norms_index = norms_sq.argmax(dim=-1)
|
max_norms_index = norms_sq.argmax(dim=-1)
|
||||||
|
|
||||||
# Pick only the cross-product with highest squared norm for each input
|
# Pick only the cross-product with highest squared norm for each input
|
||||||
|
|||||||
@@ -32,9 +32,7 @@ __global__ void BallQueryKernel(
|
|||||||
at::PackedTensorAccessor64<int64_t, 3, at::RestrictPtrTraits> idxs,
|
at::PackedTensorAccessor64<int64_t, 3, at::RestrictPtrTraits> idxs,
|
||||||
at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> dists,
|
at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> dists,
|
||||||
const int64_t K,
|
const int64_t K,
|
||||||
const float radius,
|
const float radius2) {
|
||||||
const float radius2,
|
|
||||||
const bool skip_points_outside_cube) {
|
|
||||||
const int64_t N = p1.size(0);
|
const int64_t N = p1.size(0);
|
||||||
const int64_t chunks_per_cloud = (1 + (p1.size(1) - 1) / blockDim.x);
|
const int64_t chunks_per_cloud = (1 + (p1.size(1) - 1) / blockDim.x);
|
||||||
const int64_t chunks_to_do = N * chunks_per_cloud;
|
const int64_t chunks_to_do = N * chunks_per_cloud;
|
||||||
@@ -53,19 +51,7 @@ __global__ void BallQueryKernel(
|
|||||||
// Iterate over points in p2 until desired count is reached or
|
// Iterate over points in p2 until desired count is reached or
|
||||||
// all points have been considered
|
// all points have been considered
|
||||||
for (int64_t j = 0, count = 0; j < lengths2[n] && count < K; ++j) {
|
for (int64_t j = 0, count = 0; j < lengths2[n] && count < K; ++j) {
|
||||||
if (skip_points_outside_cube) {
|
// Calculate the distance between the points
|
||||||
bool is_within_radius = true;
|
|
||||||
// Filter when any one coordinate is already outside the radius
|
|
||||||
for (int d = 0; is_within_radius && d < D; ++d) {
|
|
||||||
scalar_t abs_diff = fabs(p1[n][i][d] - p2[n][j][d]);
|
|
||||||
is_within_radius = (abs_diff <= radius);
|
|
||||||
}
|
|
||||||
if (!is_within_radius) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Else, calculate the distance between the points and compare
|
|
||||||
scalar_t dist2 = 0.0;
|
scalar_t dist2 = 0.0;
|
||||||
for (int d = 0; d < D; ++d) {
|
for (int d = 0; d < D; ++d) {
|
||||||
scalar_t diff = p1[n][i][d] - p2[n][j][d];
|
scalar_t diff = p1[n][i][d] - p2[n][j][d];
|
||||||
@@ -91,8 +77,7 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
|
|||||||
const at::Tensor& lengths1, // (N,)
|
const at::Tensor& lengths1, // (N,)
|
||||||
const at::Tensor& lengths2, // (N,)
|
const at::Tensor& lengths2, // (N,)
|
||||||
int K,
|
int K,
|
||||||
float radius,
|
float radius) {
|
||||||
bool skip_points_outside_cube) {
|
|
||||||
// Check inputs are on the same device
|
// Check inputs are on the same device
|
||||||
at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
|
at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
|
||||||
lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4};
|
lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4};
|
||||||
@@ -135,9 +120,7 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
|
|||||||
idxs.packed_accessor64<int64_t, 3, at::RestrictPtrTraits>(),
|
idxs.packed_accessor64<int64_t, 3, at::RestrictPtrTraits>(),
|
||||||
dists.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
|
dists.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
|
||||||
K_64,
|
K_64,
|
||||||
radius,
|
radius2);
|
||||||
radius2,
|
|
||||||
skip_points_outside_cube);
|
|
||||||
}));
|
}));
|
||||||
|
|
||||||
AT_CUDA_CHECK(cudaGetLastError());
|
AT_CUDA_CHECK(cudaGetLastError());
|
||||||
|
|||||||
@@ -25,9 +25,6 @@
|
|||||||
// within the radius
|
// within the radius
|
||||||
// radius: the radius around each point within which the neighbors need to be
|
// radius: the radius around each point within which the neighbors need to be
|
||||||
// located
|
// located
|
||||||
// skip_points_outside_cube: If true, reduce multiplications of float values
|
|
||||||
// by not explicitly calculating distances to points that fall outside the
|
|
||||||
// D-cube with side length (2*radius) centered at each point in p1.
|
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
|
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
|
||||||
@@ -49,8 +46,7 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
|
|||||||
const at::Tensor& lengths1,
|
const at::Tensor& lengths1,
|
||||||
const at::Tensor& lengths2,
|
const at::Tensor& lengths2,
|
||||||
const int K,
|
const int K,
|
||||||
const float radius,
|
const float radius);
|
||||||
const bool skip_points_outside_cube);
|
|
||||||
|
|
||||||
// CUDA implementation
|
// CUDA implementation
|
||||||
std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
|
std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
|
||||||
@@ -59,8 +55,7 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
|
|||||||
const at::Tensor& lengths1,
|
const at::Tensor& lengths1,
|
||||||
const at::Tensor& lengths2,
|
const at::Tensor& lengths2,
|
||||||
const int K,
|
const int K,
|
||||||
const float radius,
|
const float radius);
|
||||||
const bool skip_points_outside_cube);
|
|
||||||
|
|
||||||
// Implementation which is exposed
|
// Implementation which is exposed
|
||||||
// Note: the backward pass reuses the KNearestNeighborBackward kernel
|
// Note: the backward pass reuses the KNearestNeighborBackward kernel
|
||||||
@@ -70,8 +65,7 @@ inline std::tuple<at::Tensor, at::Tensor> BallQuery(
|
|||||||
const at::Tensor& lengths1,
|
const at::Tensor& lengths1,
|
||||||
const at::Tensor& lengths2,
|
const at::Tensor& lengths2,
|
||||||
int K,
|
int K,
|
||||||
float radius,
|
float radius) {
|
||||||
bool skip_points_outside_cube) {
|
|
||||||
if (p1.is_cuda() || p2.is_cuda()) {
|
if (p1.is_cuda() || p2.is_cuda()) {
|
||||||
#ifdef WITH_CUDA
|
#ifdef WITH_CUDA
|
||||||
CHECK_CUDA(p1);
|
CHECK_CUDA(p1);
|
||||||
@@ -82,20 +76,16 @@ inline std::tuple<at::Tensor, at::Tensor> BallQuery(
|
|||||||
lengths1.contiguous(),
|
lengths1.contiguous(),
|
||||||
lengths2.contiguous(),
|
lengths2.contiguous(),
|
||||||
K,
|
K,
|
||||||
radius,
|
radius);
|
||||||
skip_points_outside_cube);
|
|
||||||
#else
|
#else
|
||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(p1);
|
|
||||||
CHECK_CPU(p2);
|
|
||||||
return BallQueryCpu(
|
return BallQueryCpu(
|
||||||
p1.contiguous(),
|
p1.contiguous(),
|
||||||
p2.contiguous(),
|
p2.contiguous(),
|
||||||
lengths1.contiguous(),
|
lengths1.contiguous(),
|
||||||
lengths2.contiguous(),
|
lengths2.contiguous(),
|
||||||
K,
|
K,
|
||||||
radius,
|
radius);
|
||||||
skip_points_outside_cube);
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,8 +6,8 @@
|
|||||||
* LICENSE file in the root directory of this source tree.
|
* LICENSE file in the root directory of this source tree.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <math.h>
|
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
|
#include <queue>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
|
|
||||||
std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
|
std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
|
||||||
@@ -16,8 +16,7 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
|
|||||||
const at::Tensor& lengths1,
|
const at::Tensor& lengths1,
|
||||||
const at::Tensor& lengths2,
|
const at::Tensor& lengths2,
|
||||||
int K,
|
int K,
|
||||||
float radius,
|
float radius) {
|
||||||
bool skip_points_outside_cube) {
|
|
||||||
const int N = p1.size(0);
|
const int N = p1.size(0);
|
||||||
const int P1 = p1.size(1);
|
const int P1 = p1.size(1);
|
||||||
const int D = p1.size(2);
|
const int D = p1.size(2);
|
||||||
@@ -39,16 +38,6 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
|
|||||||
const int64_t length2 = lengths2_a[n];
|
const int64_t length2 = lengths2_a[n];
|
||||||
for (int64_t i = 0; i < length1; ++i) {
|
for (int64_t i = 0; i < length1; ++i) {
|
||||||
for (int64_t j = 0, count = 0; j < length2 && count < K; ++j) {
|
for (int64_t j = 0, count = 0; j < length2 && count < K; ++j) {
|
||||||
if (skip_points_outside_cube) {
|
|
||||||
bool is_within_radius = true;
|
|
||||||
for (int d = 0; is_within_radius && d < D; ++d) {
|
|
||||||
float abs_diff = fabs(p1_a[n][i][d] - p2_a[n][j][d]);
|
|
||||||
is_within_radius = (abs_diff <= radius);
|
|
||||||
}
|
|
||||||
if (!is_within_radius) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
float dist2 = 0;
|
float dist2 = 0;
|
||||||
for (int d = 0; d < D; ++d) {
|
for (int d = 0; d < D; ++d) {
|
||||||
float diff = p1_a[n][i][d] - p2_a[n][j][d];
|
float diff = p1_a[n][i][d] - p2_a[n][j][d];
|
||||||
|
|||||||
@@ -98,11 +98,6 @@ at::Tensor SigmoidAlphaBlendBackward(
|
|||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(distances);
|
|
||||||
CHECK_CPU(pix_to_face);
|
|
||||||
CHECK_CPU(alphas);
|
|
||||||
CHECK_CPU(grad_alphas);
|
|
||||||
|
|
||||||
return SigmoidAlphaBlendBackwardCpu(
|
return SigmoidAlphaBlendBackwardCpu(
|
||||||
grad_alphas, alphas, distances, pix_to_face, sigma);
|
grad_alphas, alphas, distances, pix_to_face, sigma);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,16 +28,17 @@ __global__ void alphaCompositeCudaForwardKernel(
|
|||||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
const int64_t batch_size = result.size(0);
|
||||||
const int64_t C = features.size(0);
|
const int64_t C = features.size(0);
|
||||||
const int64_t H = points_idx.size(2);
|
const int64_t H = points_idx.size(2);
|
||||||
const int64_t W = points_idx.size(3);
|
const int64_t W = points_idx.size(3);
|
||||||
|
|
||||||
// Get the batch and index
|
// Get the batch and index
|
||||||
const auto batch = blockIdx.x;
|
const int batch = blockIdx.x;
|
||||||
|
|
||||||
const int num_pixels = C * H * W;
|
const int num_pixels = C * H * W;
|
||||||
const auto num_threads = gridDim.y * blockDim.x;
|
const int num_threads = gridDim.y * blockDim.x;
|
||||||
const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
|
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
// Iterate over each feature in each pixel
|
// Iterate over each feature in each pixel
|
||||||
for (int pid = tid; pid < num_pixels; pid += num_threads) {
|
for (int pid = tid; pid < num_pixels; pid += num_threads) {
|
||||||
@@ -78,16 +79,17 @@ __global__ void alphaCompositeCudaBackwardKernel(
|
|||||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
const int64_t batch_size = points_idx.size(0);
|
||||||
const int64_t C = features.size(0);
|
const int64_t C = features.size(0);
|
||||||
const int64_t H = points_idx.size(2);
|
const int64_t H = points_idx.size(2);
|
||||||
const int64_t W = points_idx.size(3);
|
const int64_t W = points_idx.size(3);
|
||||||
|
|
||||||
// Get the batch and index
|
// Get the batch and index
|
||||||
const auto batch = blockIdx.x;
|
const int batch = blockIdx.x;
|
||||||
|
|
||||||
const int num_pixels = C * H * W;
|
const int num_pixels = C * H * W;
|
||||||
const auto num_threads = gridDim.y * blockDim.x;
|
const int num_threads = gridDim.y * blockDim.x;
|
||||||
const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
|
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
// Parallelize over each feature in each pixel in images of size H * W,
|
// Parallelize over each feature in each pixel in images of size H * W,
|
||||||
// for each image in the batch of size batch_size
|
// for each image in the batch of size batch_size
|
||||||
|
|||||||
@@ -74,9 +74,6 @@ torch::Tensor alphaCompositeForward(
|
|||||||
AT_ERROR("Not compiled with GPU support");
|
AT_ERROR("Not compiled with GPU support");
|
||||||
#endif
|
#endif
|
||||||
} else {
|
} else {
|
||||||
CHECK_CPU(features);
|
|
||||||
CHECK_CPU(alphas);
|
|
||||||
CHECK_CPU(points_idx);
|
|
||||||
return alphaCompositeCpuForward(features, alphas, points_idx);
|
return alphaCompositeCpuForward(features, alphas, points_idx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -104,11 +101,6 @@ std::tuple<torch::Tensor, torch::Tensor> alphaCompositeBackward(
|
|||||||
AT_ERROR("Not compiled with GPU support");
|
AT_ERROR("Not compiled with GPU support");
|
||||||
#endif
|
#endif
|
||||||
} else {
|
} else {
|
||||||
CHECK_CPU(grad_outputs);
|
|
||||||
CHECK_CPU(features);
|
|
||||||
CHECK_CPU(alphas);
|
|
||||||
CHECK_CPU(points_idx);
|
|
||||||
|
|
||||||
return alphaCompositeCpuBackward(
|
return alphaCompositeCpuBackward(
|
||||||
grad_outputs, features, alphas, points_idx);
|
grad_outputs, features, alphas, points_idx);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,16 +28,17 @@ __global__ void weightedSumNormCudaForwardKernel(
|
|||||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
const int64_t batch_size = result.size(0);
|
||||||
const int64_t C = features.size(0);
|
const int64_t C = features.size(0);
|
||||||
const int64_t H = points_idx.size(2);
|
const int64_t H = points_idx.size(2);
|
||||||
const int64_t W = points_idx.size(3);
|
const int64_t W = points_idx.size(3);
|
||||||
|
|
||||||
// Get the batch and index
|
// Get the batch and index
|
||||||
const auto batch = blockIdx.x;
|
const int batch = blockIdx.x;
|
||||||
|
|
||||||
const int num_pixels = C * H * W;
|
const int num_pixels = C * H * W;
|
||||||
const auto num_threads = gridDim.y * blockDim.x;
|
const int num_threads = gridDim.y * blockDim.x;
|
||||||
const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
|
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
// Parallelize over each feature in each pixel in images of size H * W,
|
// Parallelize over each feature in each pixel in images of size H * W,
|
||||||
// for each image in the batch of size batch_size
|
// for each image in the batch of size batch_size
|
||||||
@@ -91,16 +92,17 @@ __global__ void weightedSumNormCudaBackwardKernel(
|
|||||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
const int64_t batch_size = points_idx.size(0);
|
||||||
const int64_t C = features.size(0);
|
const int64_t C = features.size(0);
|
||||||
const int64_t H = points_idx.size(2);
|
const int64_t H = points_idx.size(2);
|
||||||
const int64_t W = points_idx.size(3);
|
const int64_t W = points_idx.size(3);
|
||||||
|
|
||||||
// Get the batch and index
|
// Get the batch and index
|
||||||
const auto batch = blockIdx.x;
|
const int batch = blockIdx.x;
|
||||||
|
|
||||||
const int num_pixels = C * W * H;
|
const int num_pixels = C * W * H;
|
||||||
const auto num_threads = gridDim.y * blockDim.x;
|
const int num_threads = gridDim.y * blockDim.x;
|
||||||
const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
|
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
// Parallelize over each feature in each pixel in images of size H * W,
|
// Parallelize over each feature in each pixel in images of size H * W,
|
||||||
// for each image in the batch of size batch_size
|
// for each image in the batch of size batch_size
|
||||||
|
|||||||
@@ -73,10 +73,6 @@ torch::Tensor weightedSumNormForward(
|
|||||||
AT_ERROR("Not compiled with GPU support");
|
AT_ERROR("Not compiled with GPU support");
|
||||||
#endif
|
#endif
|
||||||
} else {
|
} else {
|
||||||
CHECK_CPU(features);
|
|
||||||
CHECK_CPU(alphas);
|
|
||||||
CHECK_CPU(points_idx);
|
|
||||||
|
|
||||||
return weightedSumNormCpuForward(features, alphas, points_idx);
|
return weightedSumNormCpuForward(features, alphas, points_idx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -104,11 +100,6 @@ std::tuple<torch::Tensor, torch::Tensor> weightedSumNormBackward(
|
|||||||
AT_ERROR("Not compiled with GPU support");
|
AT_ERROR("Not compiled with GPU support");
|
||||||
#endif
|
#endif
|
||||||
} else {
|
} else {
|
||||||
CHECK_CPU(grad_outputs);
|
|
||||||
CHECK_CPU(features);
|
|
||||||
CHECK_CPU(alphas);
|
|
||||||
CHECK_CPU(points_idx);
|
|
||||||
|
|
||||||
return weightedSumNormCpuBackward(
|
return weightedSumNormCpuBackward(
|
||||||
grad_outputs, features, alphas, points_idx);
|
grad_outputs, features, alphas, points_idx);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,16 +26,17 @@ __global__ void weightedSumCudaForwardKernel(
|
|||||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
const int64_t batch_size = result.size(0);
|
||||||
const int64_t C = features.size(0);
|
const int64_t C = features.size(0);
|
||||||
const int64_t H = points_idx.size(2);
|
const int64_t H = points_idx.size(2);
|
||||||
const int64_t W = points_idx.size(3);
|
const int64_t W = points_idx.size(3);
|
||||||
|
|
||||||
// Get the batch and index
|
// Get the batch and index
|
||||||
const auto batch = blockIdx.x;
|
const int batch = blockIdx.x;
|
||||||
|
|
||||||
const int num_pixels = C * H * W;
|
const int num_pixels = C * H * W;
|
||||||
const auto num_threads = gridDim.y * blockDim.x;
|
const int num_threads = gridDim.y * blockDim.x;
|
||||||
const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
|
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
// Parallelize over each feature in each pixel in images of size H * W,
|
// Parallelize over each feature in each pixel in images of size H * W,
|
||||||
// for each image in the batch of size batch_size
|
// for each image in the batch of size batch_size
|
||||||
@@ -73,16 +74,17 @@ __global__ void weightedSumCudaBackwardKernel(
|
|||||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
const int64_t batch_size = points_idx.size(0);
|
||||||
const int64_t C = features.size(0);
|
const int64_t C = features.size(0);
|
||||||
const int64_t H = points_idx.size(2);
|
const int64_t H = points_idx.size(2);
|
||||||
const int64_t W = points_idx.size(3);
|
const int64_t W = points_idx.size(3);
|
||||||
|
|
||||||
// Get the batch and index
|
// Get the batch and index
|
||||||
const auto batch = blockIdx.x;
|
const int batch = blockIdx.x;
|
||||||
|
|
||||||
const int num_pixels = C * H * W;
|
const int num_pixels = C * H * W;
|
||||||
const auto num_threads = gridDim.y * blockDim.x;
|
const int num_threads = gridDim.y * blockDim.x;
|
||||||
const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
|
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
// Iterate over each pixel to compute the contribution to the
|
// Iterate over each pixel to compute the contribution to the
|
||||||
// gradient for the features and weights
|
// gradient for the features and weights
|
||||||
|
|||||||
@@ -72,9 +72,6 @@ torch::Tensor weightedSumForward(
|
|||||||
AT_ERROR("Not compiled with GPU support");
|
AT_ERROR("Not compiled with GPU support");
|
||||||
#endif
|
#endif
|
||||||
} else {
|
} else {
|
||||||
CHECK_CPU(features);
|
|
||||||
CHECK_CPU(alphas);
|
|
||||||
CHECK_CPU(points_idx);
|
|
||||||
return weightedSumCpuForward(features, alphas, points_idx);
|
return weightedSumCpuForward(features, alphas, points_idx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -101,11 +98,6 @@ std::tuple<torch::Tensor, torch::Tensor> weightedSumBackward(
|
|||||||
AT_ERROR("Not compiled with GPU support");
|
AT_ERROR("Not compiled with GPU support");
|
||||||
#endif
|
#endif
|
||||||
} else {
|
} else {
|
||||||
CHECK_CPU(grad_outputs);
|
|
||||||
CHECK_CPU(features);
|
|
||||||
CHECK_CPU(alphas);
|
|
||||||
CHECK_CPU(points_idx);
|
|
||||||
|
|
||||||
return weightedSumCpuBackward(grad_outputs, features, alphas, points_idx);
|
return weightedSumCpuBackward(grad_outputs, features, alphas, points_idx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@
|
|||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
#include "./pulsar/global.h" // Include before <torch/extension.h>.
|
#include "./pulsar/global.h" // Include before <torch/extension.h>.
|
||||||
|
#include <torch/extension.h>
|
||||||
// clang-format on
|
// clang-format on
|
||||||
#include "./pulsar/pytorch/renderer.h"
|
#include "./pulsar/pytorch/renderer.h"
|
||||||
#include "./pulsar/pytorch/tensor_util.h"
|
#include "./pulsar/pytorch/tensor_util.h"
|
||||||
@@ -105,16 +106,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||||||
py::class_<
|
py::class_<
|
||||||
pulsar::pytorch::Renderer,
|
pulsar::pytorch::Renderer,
|
||||||
std::shared_ptr<pulsar::pytorch::Renderer>>(m, "PulsarRenderer")
|
std::shared_ptr<pulsar::pytorch::Renderer>>(m, "PulsarRenderer")
|
||||||
.def(
|
.def(py::init<
|
||||||
py::init<
|
const uint&,
|
||||||
const uint&,
|
const uint&,
|
||||||
const uint&,
|
const uint&,
|
||||||
const uint&,
|
const bool&,
|
||||||
const bool&,
|
const bool&,
|
||||||
const bool&,
|
const float&,
|
||||||
const float&,
|
const uint&,
|
||||||
const uint&,
|
const uint&>())
|
||||||
const uint&>())
|
|
||||||
.def(
|
.def(
|
||||||
"__eq__",
|
"__eq__",
|
||||||
[](const pulsar::pytorch::Renderer& a,
|
[](const pulsar::pytorch::Renderer& a,
|
||||||
@@ -149,10 +149,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||||||
py::arg("gamma"),
|
py::arg("gamma"),
|
||||||
py::arg("max_depth"),
|
py::arg("max_depth"),
|
||||||
py::arg("min_depth") /* = 0.f*/,
|
py::arg("min_depth") /* = 0.f*/,
|
||||||
py::arg("bg_col") /* = std::nullopt not exposed properly in
|
py::arg(
|
||||||
pytorch 1.1. */
|
"bg_col") /* = at::nullopt not exposed properly in pytorch 1.1. */
|
||||||
,
|
,
|
||||||
py::arg("opacity") /* = std::nullopt ... */,
|
py::arg("opacity") /* = at::nullopt ... */,
|
||||||
py::arg("percent_allowed_difference") = 0.01f,
|
py::arg("percent_allowed_difference") = 0.01f,
|
||||||
py::arg("max_n_hits") = MAX_UINT,
|
py::arg("max_n_hits") = MAX_UINT,
|
||||||
py::arg("mode") = 0)
|
py::arg("mode") = 0)
|
||||||
|
|||||||
@@ -60,8 +60,6 @@ std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForward(
|
|||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(verts);
|
|
||||||
CHECK_CPU(faces);
|
|
||||||
return FaceAreasNormalsForwardCpu(verts, faces);
|
return FaceAreasNormalsForwardCpu(verts, faces);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,9 +80,5 @@ at::Tensor FaceAreasNormalsBackward(
|
|||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(grad_areas);
|
|
||||||
CHECK_CPU(grad_normals);
|
|
||||||
CHECK_CPU(verts);
|
|
||||||
CHECK_CPU(faces);
|
|
||||||
return FaceAreasNormalsBackwardCpu(grad_areas, grad_normals, verts, faces);
|
return FaceAreasNormalsBackwardCpu(grad_areas, grad_normals, verts, faces);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,14 +20,14 @@ __global__ void GatherScatterCudaKernel(
|
|||||||
const size_t V,
|
const size_t V,
|
||||||
const size_t D,
|
const size_t D,
|
||||||
const size_t E) {
|
const size_t E) {
|
||||||
const auto tid = threadIdx.x;
|
const int tid = threadIdx.x;
|
||||||
|
|
||||||
// Reverse the vertex order if backward.
|
// Reverse the vertex order if backward.
|
||||||
const int v0_idx = backward ? 1 : 0;
|
const int v0_idx = backward ? 1 : 0;
|
||||||
const int v1_idx = backward ? 0 : 1;
|
const int v1_idx = backward ? 0 : 1;
|
||||||
|
|
||||||
// Edges are split evenly across the blocks.
|
// Edges are split evenly across the blocks.
|
||||||
for (auto e = blockIdx.x; e < E; e += gridDim.x) {
|
for (int e = blockIdx.x; e < E; e += gridDim.x) {
|
||||||
// Get indices of vertices which form the edge.
|
// Get indices of vertices which form the edge.
|
||||||
const int64_t v0 = edges[2 * e + v0_idx];
|
const int64_t v0 = edges[2 * e + v0_idx];
|
||||||
const int64_t v1 = edges[2 * e + v1_idx];
|
const int64_t v1 = edges[2 * e + v1_idx];
|
||||||
@@ -35,7 +35,7 @@ __global__ void GatherScatterCudaKernel(
|
|||||||
// Split vertex features evenly across threads.
|
// Split vertex features evenly across threads.
|
||||||
// This implementation will be quite wasteful when D<128 since there will be
|
// This implementation will be quite wasteful when D<128 since there will be
|
||||||
// a lot of threads doing nothing.
|
// a lot of threads doing nothing.
|
||||||
for (auto d = tid; d < D; d += blockDim.x) {
|
for (int d = tid; d < D; d += blockDim.x) {
|
||||||
const float val = input[v1 * D + d];
|
const float val = input[v1 * D + d];
|
||||||
float* address = output + v0 * D + d;
|
float* address = output + v0 * D + d;
|
||||||
atomicAdd(address, val);
|
atomicAdd(address, val);
|
||||||
|
|||||||
@@ -53,7 +53,5 @@ at::Tensor GatherScatter(
|
|||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(input);
|
|
||||||
CHECK_CPU(edges);
|
|
||||||
return GatherScatterCpu(input, edges, directed, backward);
|
return GatherScatterCpu(input, edges, directed, backward);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ __global__ void InterpFaceAttrsForwardKernel(
|
|||||||
const size_t P,
|
const size_t P,
|
||||||
const size_t F,
|
const size_t F,
|
||||||
const size_t D) {
|
const size_t D) {
|
||||||
const auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
const auto num_threads = blockDim.x * gridDim.x;
|
const int num_threads = blockDim.x * gridDim.x;
|
||||||
for (int pd = tid; pd < P * D; pd += num_threads) {
|
for (int pd = tid; pd < P * D; pd += num_threads) {
|
||||||
const int p = pd / D;
|
const int p = pd / D;
|
||||||
const int d = pd % D;
|
const int d = pd % D;
|
||||||
@@ -93,8 +93,8 @@ __global__ void InterpFaceAttrsBackwardKernel(
|
|||||||
const size_t P,
|
const size_t P,
|
||||||
const size_t F,
|
const size_t F,
|
||||||
const size_t D) {
|
const size_t D) {
|
||||||
const auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
const auto num_threads = blockDim.x * gridDim.x;
|
const int num_threads = blockDim.x * gridDim.x;
|
||||||
for (int pd = tid; pd < P * D; pd += num_threads) {
|
for (int pd = tid; pd < P * D; pd += num_threads) {
|
||||||
const int p = pd / D;
|
const int p = pd / D;
|
||||||
const int d = pd % D;
|
const int d = pd % D;
|
||||||
|
|||||||
@@ -57,8 +57,6 @@ at::Tensor InterpFaceAttrsForward(
|
|||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(face_attrs);
|
|
||||||
CHECK_CPU(barycentric_coords);
|
|
||||||
return InterpFaceAttrsForwardCpu(pix_to_face, barycentric_coords, face_attrs);
|
return InterpFaceAttrsForwardCpu(pix_to_face, barycentric_coords, face_attrs);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -108,9 +106,6 @@ std::tuple<at::Tensor, at::Tensor> InterpFaceAttrsBackward(
|
|||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(face_attrs);
|
|
||||||
CHECK_CPU(barycentric_coords);
|
|
||||||
CHECK_CPU(grad_pix_attrs);
|
|
||||||
return InterpFaceAttrsBackwardCpu(
|
return InterpFaceAttrsBackwardCpu(
|
||||||
pix_to_face, barycentric_coords, face_attrs, grad_pix_attrs);
|
pix_to_face, barycentric_coords, face_attrs, grad_pix_attrs);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,7 +44,5 @@ inline std::tuple<at::Tensor, at::Tensor> IoUBox3D(
|
|||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(boxes1);
|
|
||||||
CHECK_CPU(boxes2);
|
|
||||||
return IoUBox3DCpu(boxes1.contiguous(), boxes2.contiguous());
|
return IoUBox3DCpu(boxes1.contiguous(), boxes2.contiguous());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,10 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
|
#include <torch/torch.h>
|
||||||
#include <list>
|
#include <list>
|
||||||
|
#include <numeric>
|
||||||
|
#include <queue>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include "iou_box3d/iou_utils.h"
|
#include "iou_box3d/iou_utils.h"
|
||||||
|
|
||||||
|
|||||||
@@ -461,8 +461,10 @@ __device__ inline std::tuple<float3, float3> ArgMaxVerts(
|
|||||||
__device__ inline bool IsCoplanarTriTri(
|
__device__ inline bool IsCoplanarTriTri(
|
||||||
const FaceVerts& tri1,
|
const FaceVerts& tri1,
|
||||||
const FaceVerts& tri2) {
|
const FaceVerts& tri2) {
|
||||||
|
const float3 tri1_ctr = FaceCenter({tri1.v0, tri1.v1, tri1.v2});
|
||||||
const float3 tri1_n = FaceNormal({tri1.v0, tri1.v1, tri1.v2});
|
const float3 tri1_n = FaceNormal({tri1.v0, tri1.v1, tri1.v2});
|
||||||
|
|
||||||
|
const float3 tri2_ctr = FaceCenter({tri2.v0, tri2.v1, tri2.v2});
|
||||||
const float3 tri2_n = FaceNormal({tri2.v0, tri2.v1, tri2.v2});
|
const float3 tri2_n = FaceNormal({tri2.v0, tri2.v1, tri2.v2});
|
||||||
|
|
||||||
// Check if parallel
|
// Check if parallel
|
||||||
@@ -498,6 +500,7 @@ __device__ inline bool IsCoplanarTriPlane(
|
|||||||
const FaceVerts& tri,
|
const FaceVerts& tri,
|
||||||
const FaceVerts& plane,
|
const FaceVerts& plane,
|
||||||
const float3& normal) {
|
const float3& normal) {
|
||||||
|
const float3 tri_ctr = FaceCenter({tri.v0, tri.v1, tri.v2});
|
||||||
const float3 nt = FaceNormal({tri.v0, tri.v1, tri.v2});
|
const float3 nt = FaceNormal({tri.v0, tri.v1, tri.v2});
|
||||||
|
|
||||||
// check if parallel
|
// check if parallel
|
||||||
@@ -725,7 +728,7 @@ __device__ inline int BoxIntersections(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Update the face_verts_out tris
|
// Update the face_verts_out tris
|
||||||
num_tris = min(MAX_TRIS, offset);
|
num_tris = offset;
|
||||||
for (int j = 0; j < num_tris; ++j) {
|
for (int j = 0; j < num_tris; ++j) {
|
||||||
face_verts_out[j] = tri_verts_updated[j];
|
face_verts_out[j] = tri_verts_updated[j];
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -74,8 +74,6 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
|
|||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(p1);
|
|
||||||
CHECK_CPU(p2);
|
|
||||||
return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, norm, K);
|
return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, norm, K);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -142,8 +140,6 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackward(
|
|||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(p1);
|
|
||||||
CHECK_CPU(p2);
|
|
||||||
return KNearestNeighborBackwardCpu(
|
return KNearestNeighborBackwardCpu(
|
||||||
p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
|
p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -58,6 +58,5 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubes(
|
|||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(vol);
|
|
||||||
return MarchingCubesCpu(vol.contiguous(), isolevel);
|
return MarchingCubesCpu(vol.contiguous(), isolevel);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -88,8 +88,6 @@ at::Tensor PackedToPadded(
|
|||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(inputs_packed);
|
|
||||||
CHECK_CPU(first_idxs);
|
|
||||||
return PackedToPaddedCpu(inputs_packed, first_idxs, max_size);
|
return PackedToPaddedCpu(inputs_packed, first_idxs, max_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -107,7 +105,5 @@ at::Tensor PaddedToPacked(
|
|||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(inputs_padded);
|
|
||||||
CHECK_CPU(first_idxs);
|
|
||||||
return PaddedToPackedCpu(inputs_padded, first_idxs, num_inputs);
|
return PaddedToPackedCpu(inputs_padded, first_idxs, num_inputs);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -174,8 +174,8 @@ std::tuple<at::Tensor, at::Tensor> HullHullDistanceForwardCpu(
|
|||||||
at::Tensor idxs = at::zeros({A_N,}, as_first_idx.options());
|
at::Tensor idxs = at::zeros({A_N,}, as_first_idx.options());
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
auto as_a = as.accessor<float, H1 == 1 ? 2 : 3>();
|
auto as_a = as.accessor < float, H1 == 1 ? 2 : 3 > ();
|
||||||
auto bs_a = bs.accessor<float, H2 == 1 ? 2 : 3>();
|
auto bs_a = bs.accessor < float, H2 == 1 ? 2 : 3 > ();
|
||||||
auto as_first_idx_a = as_first_idx.accessor<int64_t, 1>();
|
auto as_first_idx_a = as_first_idx.accessor<int64_t, 1>();
|
||||||
auto bs_first_idx_a = bs_first_idx.accessor<int64_t, 1>();
|
auto bs_first_idx_a = bs_first_idx.accessor<int64_t, 1>();
|
||||||
auto dists_a = dists.accessor<float, 1>();
|
auto dists_a = dists.accessor<float, 1>();
|
||||||
@@ -230,10 +230,10 @@ std::tuple<at::Tensor, at::Tensor> HullHullDistanceBackwardCpu(
|
|||||||
at::Tensor grad_as = at::zeros_like(as);
|
at::Tensor grad_as = at::zeros_like(as);
|
||||||
at::Tensor grad_bs = at::zeros_like(bs);
|
at::Tensor grad_bs = at::zeros_like(bs);
|
||||||
|
|
||||||
auto as_a = as.accessor<float, H1 == 1 ? 2 : 3>();
|
auto as_a = as.accessor < float, H1 == 1 ? 2 : 3 > ();
|
||||||
auto bs_a = bs.accessor<float, H2 == 1 ? 2 : 3>();
|
auto bs_a = bs.accessor < float, H2 == 1 ? 2 : 3 > ();
|
||||||
auto grad_as_a = grad_as.accessor<float, H1 == 1 ? 2 : 3>();
|
auto grad_as_a = grad_as.accessor < float, H1 == 1 ? 2 : 3 > ();
|
||||||
auto grad_bs_a = grad_bs.accessor<float, H2 == 1 ? 2 : 3>();
|
auto grad_bs_a = grad_bs.accessor < float, H2 == 1 ? 2 : 3 > ();
|
||||||
auto idx_bs_a = idx_bs.accessor<int64_t, 1>();
|
auto idx_bs_a = idx_bs.accessor<int64_t, 1>();
|
||||||
auto grad_dists_a = grad_dists.accessor<float, 1>();
|
auto grad_dists_a = grad_dists.accessor<float, 1>();
|
||||||
|
|
||||||
|
|||||||
@@ -110,7 +110,7 @@ __global__ void DistanceForwardKernel(
|
|||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
// Perform reduction in shared memory.
|
// Perform reduction in shared memory.
|
||||||
for (auto s = blockDim.x / 2; s > 32; s >>= 1) {
|
for (int s = blockDim.x / 2; s > 32; s >>= 1) {
|
||||||
if (tid < s) {
|
if (tid < s) {
|
||||||
if (min_dists[tid] > min_dists[tid + s]) {
|
if (min_dists[tid] > min_dists[tid + s]) {
|
||||||
min_dists[tid] = min_dists[tid + s];
|
min_dists[tid] = min_dists[tid + s];
|
||||||
@@ -502,8 +502,8 @@ __global__ void PointFaceArrayForwardKernel(
|
|||||||
const float3* tris_f3 = (float3*)tris;
|
const float3* tris_f3 = (float3*)tris;
|
||||||
|
|
||||||
// Parallelize over P * S computations
|
// Parallelize over P * S computations
|
||||||
const auto num_threads = gridDim.x * blockDim.x;
|
const int num_threads = gridDim.x * blockDim.x;
|
||||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
for (int t_i = tid; t_i < P * T; t_i += num_threads) {
|
for (int t_i = tid; t_i < P * T; t_i += num_threads) {
|
||||||
const int t = t_i / P; // segment index.
|
const int t = t_i / P; // segment index.
|
||||||
@@ -576,8 +576,8 @@ __global__ void PointFaceArrayBackwardKernel(
|
|||||||
const float3* tris_f3 = (float3*)tris;
|
const float3* tris_f3 = (float3*)tris;
|
||||||
|
|
||||||
// Parallelize over P * S computations
|
// Parallelize over P * S computations
|
||||||
const auto num_threads = gridDim.x * blockDim.x;
|
const int num_threads = gridDim.x * blockDim.x;
|
||||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
for (int t_i = tid; t_i < P * T; t_i += num_threads) {
|
for (int t_i = tid; t_i < P * T; t_i += num_threads) {
|
||||||
const int t = t_i / P; // triangle index.
|
const int t = t_i / P; // triangle index.
|
||||||
@@ -683,8 +683,8 @@ __global__ void PointEdgeArrayForwardKernel(
|
|||||||
float3* segms_f3 = (float3*)segms;
|
float3* segms_f3 = (float3*)segms;
|
||||||
|
|
||||||
// Parallelize over P * S computations
|
// Parallelize over P * S computations
|
||||||
const auto num_threads = gridDim.x * blockDim.x;
|
const int num_threads = gridDim.x * blockDim.x;
|
||||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
for (int t_i = tid; t_i < P * S; t_i += num_threads) {
|
for (int t_i = tid; t_i < P * S; t_i += num_threads) {
|
||||||
const int s = t_i / P; // segment index.
|
const int s = t_i / P; // segment index.
|
||||||
@@ -752,8 +752,8 @@ __global__ void PointEdgeArrayBackwardKernel(
|
|||||||
float3* segms_f3 = (float3*)segms;
|
float3* segms_f3 = (float3*)segms;
|
||||||
|
|
||||||
// Parallelize over P * S computations
|
// Parallelize over P * S computations
|
||||||
const auto num_threads = gridDim.x * blockDim.x;
|
const int num_threads = gridDim.x * blockDim.x;
|
||||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
for (int t_i = tid; t_i < P * S; t_i += num_threads) {
|
for (int t_i = tid; t_i < P * S; t_i += num_threads) {
|
||||||
const int s = t_i / P; // segment index.
|
const int s = t_i / P; // segment index.
|
||||||
|
|||||||
@@ -88,10 +88,6 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForward(
|
|||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(points);
|
|
||||||
CHECK_CPU(points_first_idx);
|
|
||||||
CHECK_CPU(tris);
|
|
||||||
CHECK_CPU(tris_first_idx);
|
|
||||||
return PointFaceDistanceForwardCpu(
|
return PointFaceDistanceForwardCpu(
|
||||||
points, points_first_idx, tris, tris_first_idx, min_triangle_area);
|
points, points_first_idx, tris, tris_first_idx, min_triangle_area);
|
||||||
}
|
}
|
||||||
@@ -147,10 +143,6 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackward(
|
|||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(points);
|
|
||||||
CHECK_CPU(tris);
|
|
||||||
CHECK_CPU(idx_points);
|
|
||||||
CHECK_CPU(grad_dists);
|
|
||||||
return PointFaceDistanceBackwardCpu(
|
return PointFaceDistanceBackwardCpu(
|
||||||
points, tris, idx_points, grad_dists, min_triangle_area);
|
points, tris, idx_points, grad_dists, min_triangle_area);
|
||||||
}
|
}
|
||||||
@@ -229,10 +221,6 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForward(
|
|||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(points);
|
|
||||||
CHECK_CPU(points_first_idx);
|
|
||||||
CHECK_CPU(tris);
|
|
||||||
CHECK_CPU(tris_first_idx);
|
|
||||||
return FacePointDistanceForwardCpu(
|
return FacePointDistanceForwardCpu(
|
||||||
points, points_first_idx, tris, tris_first_idx, min_triangle_area);
|
points, points_first_idx, tris, tris_first_idx, min_triangle_area);
|
||||||
}
|
}
|
||||||
@@ -289,10 +277,6 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackward(
|
|||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(points);
|
|
||||||
CHECK_CPU(tris);
|
|
||||||
CHECK_CPU(idx_tris);
|
|
||||||
CHECK_CPU(grad_dists);
|
|
||||||
return FacePointDistanceBackwardCpu(
|
return FacePointDistanceBackwardCpu(
|
||||||
points, tris, idx_tris, grad_dists, min_triangle_area);
|
points, tris, idx_tris, grad_dists, min_triangle_area);
|
||||||
}
|
}
|
||||||
@@ -362,10 +346,6 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForward(
|
|||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(points);
|
|
||||||
CHECK_CPU(points_first_idx);
|
|
||||||
CHECK_CPU(segms);
|
|
||||||
CHECK_CPU(segms_first_idx);
|
|
||||||
return PointEdgeDistanceForwardCpu(
|
return PointEdgeDistanceForwardCpu(
|
||||||
points, points_first_idx, segms, segms_first_idx, max_points);
|
points, points_first_idx, segms, segms_first_idx, max_points);
|
||||||
}
|
}
|
||||||
@@ -416,10 +396,6 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackward(
|
|||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(points);
|
|
||||||
CHECK_CPU(segms);
|
|
||||||
CHECK_CPU(idx_points);
|
|
||||||
CHECK_CPU(grad_dists);
|
|
||||||
return PointEdgeDistanceBackwardCpu(points, segms, idx_points, grad_dists);
|
return PointEdgeDistanceBackwardCpu(points, segms, idx_points, grad_dists);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -488,10 +464,6 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForward(
|
|||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(points);
|
|
||||||
CHECK_CPU(points_first_idx);
|
|
||||||
CHECK_CPU(segms);
|
|
||||||
CHECK_CPU(segms_first_idx);
|
|
||||||
return EdgePointDistanceForwardCpu(
|
return EdgePointDistanceForwardCpu(
|
||||||
points, points_first_idx, segms, segms_first_idx, max_segms);
|
points, points_first_idx, segms, segms_first_idx, max_segms);
|
||||||
}
|
}
|
||||||
@@ -542,10 +514,6 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackward(
|
|||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(points);
|
|
||||||
CHECK_CPU(segms);
|
|
||||||
CHECK_CPU(idx_segms);
|
|
||||||
CHECK_CPU(grad_dists);
|
|
||||||
return EdgePointDistanceBackwardCpu(points, segms, idx_segms, grad_dists);
|
return EdgePointDistanceBackwardCpu(points, segms, idx_segms, grad_dists);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -599,8 +567,6 @@ torch::Tensor PointFaceArrayDistanceForward(
|
|||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(points);
|
|
||||||
CHECK_CPU(tris);
|
|
||||||
return PointFaceArrayDistanceForwardCpu(points, tris, min_triangle_area);
|
return PointFaceArrayDistanceForwardCpu(points, tris, min_triangle_area);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -647,9 +613,6 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackward(
|
|||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(points);
|
|
||||||
CHECK_CPU(tris);
|
|
||||||
CHECK_CPU(grad_dists);
|
|
||||||
return PointFaceArrayDistanceBackwardCpu(
|
return PointFaceArrayDistanceBackwardCpu(
|
||||||
points, tris, grad_dists, min_triangle_area);
|
points, tris, grad_dists, min_triangle_area);
|
||||||
}
|
}
|
||||||
@@ -698,8 +661,6 @@ torch::Tensor PointEdgeArrayDistanceForward(
|
|||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(points);
|
|
||||||
CHECK_CPU(segms);
|
|
||||||
return PointEdgeArrayDistanceForwardCpu(points, segms);
|
return PointEdgeArrayDistanceForwardCpu(points, segms);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -742,8 +703,5 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackward(
|
|||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(points);
|
|
||||||
CHECK_CPU(segms);
|
|
||||||
CHECK_CPU(grad_dists);
|
|
||||||
return PointEdgeArrayDistanceBackwardCpu(points, segms, grad_dists);
|
return PointEdgeArrayDistanceBackwardCpu(points, segms, grad_dists);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -104,12 +104,6 @@ inline void PointsToVolumesForward(
|
|||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(points_3d);
|
|
||||||
CHECK_CPU(points_features);
|
|
||||||
CHECK_CPU(volume_densities);
|
|
||||||
CHECK_CPU(volume_features);
|
|
||||||
CHECK_CPU(grid_sizes);
|
|
||||||
CHECK_CPU(mask);
|
|
||||||
PointsToVolumesForwardCpu(
|
PointsToVolumesForwardCpu(
|
||||||
points_3d,
|
points_3d,
|
||||||
points_features,
|
points_features,
|
||||||
@@ -189,14 +183,6 @@ inline void PointsToVolumesBackward(
|
|||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(points_3d);
|
|
||||||
CHECK_CPU(points_features);
|
|
||||||
CHECK_CPU(grid_sizes);
|
|
||||||
CHECK_CPU(mask);
|
|
||||||
CHECK_CPU(grad_volume_densities);
|
|
||||||
CHECK_CPU(grad_volume_features);
|
|
||||||
CHECK_CPU(grad_points_3d);
|
|
||||||
CHECK_CPU(grad_points_features);
|
|
||||||
PointsToVolumesBackwardCpu(
|
PointsToVolumesBackwardCpu(
|
||||||
points_3d,
|
points_3d,
|
||||||
points_features,
|
points_features,
|
||||||
|
|||||||
@@ -8,7 +8,9 @@
|
|||||||
|
|
||||||
#include <torch/csrc/autograd/VariableTypeUtils.h>
|
#include <torch/csrc/autograd/VariableTypeUtils.h>
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
|
#include <algorithm>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
#include <thread>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// In the x direction, the location {0, ..., grid_size_x - 1} correspond to
|
// In the x direction, the location {0, ..., grid_size_x - 1} correspond to
|
||||||
|
|||||||
@@ -15,8 +15,8 @@
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(_WIN64) || defined(_WIN32)
|
#if defined(_WIN64) || defined(_WIN32)
|
||||||
using uint = unsigned int;
|
#define uint unsigned int
|
||||||
using ushort = unsigned short;
|
#define ushort unsigned short
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "./logging.h" // <- include before torch/extension.h
|
#include "./logging.h" // <- include before torch/extension.h
|
||||||
|
|||||||
@@ -417,7 +417,7 @@ __device__ static float atomicMin(float* address, float val) {
|
|||||||
(OUT_PTR), \
|
(OUT_PTR), \
|
||||||
(NUM_SELECTED_PTR), \
|
(NUM_SELECTED_PTR), \
|
||||||
(NUM_ITEMS), \
|
(NUM_ITEMS), \
|
||||||
(STREAM));
|
stream = (STREAM));
|
||||||
|
|
||||||
#define COPY_HOST_DEV(PTR_D, PTR_H, TYPE, SIZE) \
|
#define COPY_HOST_DEV(PTR_D, PTR_H, TYPE, SIZE) \
|
||||||
HANDLECUDA(cudaMemcpy( \
|
HANDLECUDA(cudaMemcpy( \
|
||||||
|
|||||||
@@ -357,11 +357,11 @@ void MAX_WS(
|
|||||||
//
|
//
|
||||||
//
|
//
|
||||||
#define END_PARALLEL() \
|
#define END_PARALLEL() \
|
||||||
end_parallel:; \
|
end_parallel :; \
|
||||||
}
|
}
|
||||||
#define END_PARALLEL_NORET() }
|
#define END_PARALLEL_NORET() }
|
||||||
#define END_PARALLEL_2D() \
|
#define END_PARALLEL_2D() \
|
||||||
end_parallel:; \
|
end_parallel :; \
|
||||||
} \
|
} \
|
||||||
}
|
}
|
||||||
#define END_PARALLEL_2D_NORET() \
|
#define END_PARALLEL_2D_NORET() \
|
||||||
|
|||||||
@@ -70,6 +70,11 @@ struct CamGradInfo {
|
|||||||
float3 pixel_dir_y;
|
float3 pixel_dir_y;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// TODO: remove once https://github.com/NVlabs/cub/issues/172 is resolved.
|
||||||
|
struct IntWrapper {
|
||||||
|
int val;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace pulsar
|
} // namespace pulsar
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@@ -149,6 +149,11 @@ IHD CamGradInfo operator*(const CamGradInfo& a, const float& b) {
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
IHD IntWrapper operator+(const IntWrapper& a, const IntWrapper& b) {
|
||||||
|
IntWrapper res;
|
||||||
|
res.val = a.val + b.val;
|
||||||
|
return res;
|
||||||
|
}
|
||||||
} // namespace pulsar
|
} // namespace pulsar
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@@ -155,8 +155,8 @@ void backward(
|
|||||||
stream);
|
stream);
|
||||||
CHECKLAUNCH();
|
CHECKLAUNCH();
|
||||||
SUM_WS(
|
SUM_WS(
|
||||||
self->ids_sorted_d,
|
(IntWrapper*)(self->ids_sorted_d),
|
||||||
self->n_grad_contributions_d,
|
(IntWrapper*)(self->n_grad_contributions_d),
|
||||||
static_cast<int>(num_balls),
|
static_cast<int>(num_balls),
|
||||||
self->workspace_d,
|
self->workspace_d,
|
||||||
self->workspace_size,
|
self->workspace_size,
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ HOST void construct(
|
|||||||
self->cam.film_width = width;
|
self->cam.film_width = width;
|
||||||
self->cam.film_height = height;
|
self->cam.film_height = height;
|
||||||
self->max_num_balls = max_num_balls;
|
self->max_num_balls = max_num_balls;
|
||||||
MALLOC(self->result_d, float, width * height * n_channels);
|
MALLOC(self->result_d, float, width* height* n_channels);
|
||||||
self->cam.orthogonal_projection = orthogonal_projection;
|
self->cam.orthogonal_projection = orthogonal_projection;
|
||||||
self->cam.right_handed = right_handed_system;
|
self->cam.right_handed = right_handed_system;
|
||||||
self->cam.background_normalization_depth = background_normalization_depth;
|
self->cam.background_normalization_depth = background_normalization_depth;
|
||||||
@@ -93,7 +93,7 @@ HOST void construct(
|
|||||||
MALLOC(self->di_sorted_d, DrawInfo, max_num_balls);
|
MALLOC(self->di_sorted_d, DrawInfo, max_num_balls);
|
||||||
MALLOC(self->region_flags_d, char, max_num_balls);
|
MALLOC(self->region_flags_d, char, max_num_balls);
|
||||||
MALLOC(self->num_selected_d, size_t, 1);
|
MALLOC(self->num_selected_d, size_t, 1);
|
||||||
MALLOC(self->forw_info_d, float, width * height * (3 + 2 * n_track));
|
MALLOC(self->forw_info_d, float, width* height * (3 + 2 * n_track));
|
||||||
MALLOC(self->min_max_pixels_d, IntersectInfo, 1);
|
MALLOC(self->min_max_pixels_d, IntersectInfo, 1);
|
||||||
MALLOC(self->grad_pos_d, float3, max_num_balls);
|
MALLOC(self->grad_pos_d, float3, max_num_balls);
|
||||||
MALLOC(self->grad_col_d, float, max_num_balls* n_channels);
|
MALLOC(self->grad_col_d, float, max_num_balls* n_channels);
|
||||||
|
|||||||
@@ -18,89 +18,68 @@ namespace Renderer {
|
|||||||
|
|
||||||
template <bool DEV>
|
template <bool DEV>
|
||||||
HOST void destruct(Renderer* self) {
|
HOST void destruct(Renderer* self) {
|
||||||
if (self->result_d != NULL) {
|
if (self->result_d != NULL)
|
||||||
FREE(self->result_d);
|
FREE(self->result_d);
|
||||||
}
|
|
||||||
self->result_d = NULL;
|
self->result_d = NULL;
|
||||||
if (self->min_depth_d != NULL) {
|
if (self->min_depth_d != NULL)
|
||||||
FREE(self->min_depth_d);
|
FREE(self->min_depth_d);
|
||||||
}
|
|
||||||
self->min_depth_d = NULL;
|
self->min_depth_d = NULL;
|
||||||
if (self->min_depth_sorted_d != NULL) {
|
if (self->min_depth_sorted_d != NULL)
|
||||||
FREE(self->min_depth_sorted_d);
|
FREE(self->min_depth_sorted_d);
|
||||||
}
|
|
||||||
self->min_depth_sorted_d = NULL;
|
self->min_depth_sorted_d = NULL;
|
||||||
if (self->ii_d != NULL) {
|
if (self->ii_d != NULL)
|
||||||
FREE(self->ii_d);
|
FREE(self->ii_d);
|
||||||
}
|
|
||||||
self->ii_d = NULL;
|
self->ii_d = NULL;
|
||||||
if (self->ii_sorted_d != NULL) {
|
if (self->ii_sorted_d != NULL)
|
||||||
FREE(self->ii_sorted_d);
|
FREE(self->ii_sorted_d);
|
||||||
}
|
|
||||||
self->ii_sorted_d = NULL;
|
self->ii_sorted_d = NULL;
|
||||||
if (self->ids_d != NULL) {
|
if (self->ids_d != NULL)
|
||||||
FREE(self->ids_d);
|
FREE(self->ids_d);
|
||||||
}
|
|
||||||
self->ids_d = NULL;
|
self->ids_d = NULL;
|
||||||
if (self->ids_sorted_d != NULL) {
|
if (self->ids_sorted_d != NULL)
|
||||||
FREE(self->ids_sorted_d);
|
FREE(self->ids_sorted_d);
|
||||||
}
|
|
||||||
self->ids_sorted_d = NULL;
|
self->ids_sorted_d = NULL;
|
||||||
if (self->workspace_d != NULL) {
|
if (self->workspace_d != NULL)
|
||||||
FREE(self->workspace_d);
|
FREE(self->workspace_d);
|
||||||
}
|
|
||||||
self->workspace_d = NULL;
|
self->workspace_d = NULL;
|
||||||
if (self->di_d != NULL) {
|
if (self->di_d != NULL)
|
||||||
FREE(self->di_d);
|
FREE(self->di_d);
|
||||||
}
|
|
||||||
self->di_d = NULL;
|
self->di_d = NULL;
|
||||||
if (self->di_sorted_d != NULL) {
|
if (self->di_sorted_d != NULL)
|
||||||
FREE(self->di_sorted_d);
|
FREE(self->di_sorted_d);
|
||||||
}
|
|
||||||
self->di_sorted_d = NULL;
|
self->di_sorted_d = NULL;
|
||||||
if (self->region_flags_d != NULL) {
|
if (self->region_flags_d != NULL)
|
||||||
FREE(self->region_flags_d);
|
FREE(self->region_flags_d);
|
||||||
}
|
|
||||||
self->region_flags_d = NULL;
|
self->region_flags_d = NULL;
|
||||||
if (self->num_selected_d != NULL) {
|
if (self->num_selected_d != NULL)
|
||||||
FREE(self->num_selected_d);
|
FREE(self->num_selected_d);
|
||||||
}
|
|
||||||
self->num_selected_d = NULL;
|
self->num_selected_d = NULL;
|
||||||
if (self->forw_info_d != NULL) {
|
if (self->forw_info_d != NULL)
|
||||||
FREE(self->forw_info_d);
|
FREE(self->forw_info_d);
|
||||||
}
|
|
||||||
self->forw_info_d = NULL;
|
self->forw_info_d = NULL;
|
||||||
if (self->min_max_pixels_d != NULL) {
|
if (self->min_max_pixels_d != NULL)
|
||||||
FREE(self->min_max_pixels_d);
|
FREE(self->min_max_pixels_d);
|
||||||
}
|
|
||||||
self->min_max_pixels_d = NULL;
|
self->min_max_pixels_d = NULL;
|
||||||
if (self->grad_pos_d != NULL) {
|
if (self->grad_pos_d != NULL)
|
||||||
FREE(self->grad_pos_d);
|
FREE(self->grad_pos_d);
|
||||||
}
|
|
||||||
self->grad_pos_d = NULL;
|
self->grad_pos_d = NULL;
|
||||||
if (self->grad_col_d != NULL) {
|
if (self->grad_col_d != NULL)
|
||||||
FREE(self->grad_col_d);
|
FREE(self->grad_col_d);
|
||||||
}
|
|
||||||
self->grad_col_d = NULL;
|
self->grad_col_d = NULL;
|
||||||
if (self->grad_rad_d != NULL) {
|
if (self->grad_rad_d != NULL)
|
||||||
FREE(self->grad_rad_d);
|
FREE(self->grad_rad_d);
|
||||||
}
|
|
||||||
self->grad_rad_d = NULL;
|
self->grad_rad_d = NULL;
|
||||||
if (self->grad_cam_d != NULL) {
|
if (self->grad_cam_d != NULL)
|
||||||
FREE(self->grad_cam_d);
|
FREE(self->grad_cam_d);
|
||||||
}
|
|
||||||
self->grad_cam_d = NULL;
|
self->grad_cam_d = NULL;
|
||||||
if (self->grad_cam_buf_d != NULL) {
|
if (self->grad_cam_buf_d != NULL)
|
||||||
FREE(self->grad_cam_buf_d);
|
FREE(self->grad_cam_buf_d);
|
||||||
}
|
|
||||||
self->grad_cam_buf_d = NULL;
|
self->grad_cam_buf_d = NULL;
|
||||||
if (self->grad_opy_d != NULL) {
|
if (self->grad_opy_d != NULL)
|
||||||
FREE(self->grad_opy_d);
|
FREE(self->grad_opy_d);
|
||||||
}
|
|
||||||
self->grad_opy_d = NULL;
|
self->grad_opy_d = NULL;
|
||||||
if (self->n_grad_contributions_d != NULL) {
|
if (self->n_grad_contributions_d != NULL)
|
||||||
FREE(self->n_grad_contributions_d);
|
FREE(self->n_grad_contributions_d);
|
||||||
}
|
|
||||||
self->n_grad_contributions_d = NULL;
|
self->n_grad_contributions_d = NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -255,7 +255,7 @@ GLOBAL void calc_signature(
|
|||||||
* for every iteration through the loading loop every thread could add a
|
* for every iteration through the loading loop every thread could add a
|
||||||
* 'hit' to the buffer.
|
* 'hit' to the buffer.
|
||||||
*/
|
*/
|
||||||
#define RENDER_BUFFER_SIZE RENDER_BLOCK_SIZE * RENDER_BLOCK_SIZE * 2
|
#define RENDER_BUFFER_SIZE RENDER_BLOCK_SIZE* RENDER_BLOCK_SIZE * 2
|
||||||
/**
|
/**
|
||||||
* The threshold after which the spheres that are in the render buffer
|
* The threshold after which the spheres that are in the render buffer
|
||||||
* are rendered and the buffer is flushed.
|
* are rendered and the buffer is flushed.
|
||||||
|
|||||||
@@ -64,9 +64,8 @@ GLOBAL void norm_sphere_gradients(Renderer renderer, const int num_balls) {
|
|||||||
// The sphere only contributes to the camera gradients if it is
|
// The sphere only contributes to the camera gradients if it is
|
||||||
// large enough in screen space.
|
// large enough in screen space.
|
||||||
if (renderer.ids_sorted_d[idx] > 0 && ii.max.x >= ii.min.x + 3 &&
|
if (renderer.ids_sorted_d[idx] > 0 && ii.max.x >= ii.min.x + 3 &&
|
||||||
ii.max.y >= ii.min.y + 3) {
|
ii.max.y >= ii.min.y + 3)
|
||||||
renderer.ids_sorted_d[idx] = 1;
|
renderer.ids_sorted_d[idx] = 1;
|
||||||
}
|
|
||||||
END_PARALLEL_NORET();
|
END_PARALLEL_NORET();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -139,9 +139,8 @@ GLOBAL void render(
|
|||||||
coord_y < cam_norm.film_border_top + cam_norm.film_height) {
|
coord_y < cam_norm.film_border_top + cam_norm.film_height) {
|
||||||
// Initialize the result.
|
// Initialize the result.
|
||||||
if (mode == 0u) {
|
if (mode == 0u) {
|
||||||
for (uint c_id = 0; c_id < cam_norm.n_channels; ++c_id) {
|
for (uint c_id = 0; c_id < cam_norm.n_channels; ++c_id)
|
||||||
result[c_id] = bg_col[c_id];
|
result[c_id] = bg_col[c_id];
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
result[0] = 0.f;
|
result[0] = 0.f;
|
||||||
}
|
}
|
||||||
@@ -191,22 +190,20 @@ GLOBAL void render(
|
|||||||
"render|found intersection with sphere %u.\n",
|
"render|found intersection with sphere %u.\n",
|
||||||
sphere_id_l[write_idx]);
|
sphere_id_l[write_idx]);
|
||||||
}
|
}
|
||||||
if (ii.min.x == MAX_USHORT) {
|
if (ii.min.x == MAX_USHORT)
|
||||||
// This is an invalid sphere (out of image). These spheres have
|
// This is an invalid sphere (out of image). These spheres have
|
||||||
// maximum depth. Since we ordered the spheres by earliest possible
|
// maximum depth. Since we ordered the spheres by earliest possible
|
||||||
// intersection depth we re certain that there will no other sphere
|
// intersection depth we re certain that there will no other sphere
|
||||||
// that is relevant after this one.
|
// that is relevant after this one.
|
||||||
loading_done = true;
|
loading_done = true;
|
||||||
}
|
|
||||||
}
|
}
|
||||||
// Reset n_pixels_done.
|
// Reset n_pixels_done.
|
||||||
n_pixels_done = 0;
|
n_pixels_done = 0;
|
||||||
thread_block.sync(); // Make sure n_loaded is updated.
|
thread_block.sync(); // Make sure n_loaded is updated.
|
||||||
if (n_loaded > RENDER_BUFFER_LOAD_THRESH) {
|
if (n_loaded > RENDER_BUFFER_LOAD_THRESH) {
|
||||||
// The load buffer is full enough. Draw.
|
// The load buffer is full enough. Draw.
|
||||||
if (thread_block.thread_rank() == 0) {
|
if (thread_block.thread_rank() == 0)
|
||||||
n_balls_loaded += n_loaded;
|
n_balls_loaded += n_loaded;
|
||||||
}
|
|
||||||
max_closest_possible_intersection = 0.f;
|
max_closest_possible_intersection = 0.f;
|
||||||
// This excludes threads outside of the image boundary. Also, it reduces
|
// This excludes threads outside of the image boundary. Also, it reduces
|
||||||
// block artifacts.
|
// block artifacts.
|
||||||
@@ -293,9 +290,8 @@ GLOBAL void render(
|
|||||||
uint warp_done = thread_warp.ballot(done);
|
uint warp_done = thread_warp.ballot(done);
|
||||||
int warp_done_bit_cnt = POPC(warp_done);
|
int warp_done_bit_cnt = POPC(warp_done);
|
||||||
#endif //__CUDACC__ && __HIP_PLATFORM_AMD__
|
#endif //__CUDACC__ && __HIP_PLATFORM_AMD__
|
||||||
if (thread_warp.thread_rank() == 0) {
|
if (thread_warp.thread_rank() == 0)
|
||||||
ATOMICADD_B(&n_pixels_done, warp_done_bit_cnt);
|
ATOMICADD_B(&n_pixels_done, warp_done_bit_cnt);
|
||||||
}
|
|
||||||
// This sync is necessary to keep n_loaded until all threads are done with
|
// This sync is necessary to keep n_loaded until all threads are done with
|
||||||
// painting.
|
// painting.
|
||||||
thread_block.sync();
|
thread_block.sync();
|
||||||
@@ -303,9 +299,8 @@ GLOBAL void render(
|
|||||||
}
|
}
|
||||||
thread_block.sync();
|
thread_block.sync();
|
||||||
}
|
}
|
||||||
if (thread_block.thread_rank() == 0) {
|
if (thread_block.thread_rank() == 0)
|
||||||
n_balls_loaded += n_loaded;
|
n_balls_loaded += n_loaded;
|
||||||
}
|
|
||||||
PULSAR_LOG_DEV_PIX(
|
PULSAR_LOG_DEV_PIX(
|
||||||
PULSAR_LOG_RENDER_PIX,
|
PULSAR_LOG_RENDER_PIX,
|
||||||
"render|loaded %d balls in total.\n",
|
"render|loaded %d balls in total.\n",
|
||||||
@@ -391,9 +386,8 @@ GLOBAL void render(
|
|||||||
static_cast<float>(tracker.get_n_hits());
|
static_cast<float>(tracker.get_n_hits());
|
||||||
} else {
|
} else {
|
||||||
float sm_d_normfac = FRCP(FMAX(sm_d, FEPS));
|
float sm_d_normfac = FRCP(FMAX(sm_d, FEPS));
|
||||||
for (uint c_id = 0; c_id < cam_norm.n_channels; ++c_id) {
|
for (uint c_id = 0; c_id < cam_norm.n_channels; ++c_id)
|
||||||
result[c_id] *= sm_d_normfac;
|
result[c_id] *= sm_d_normfac;
|
||||||
}
|
|
||||||
int write_loc = (coord_y - cam_norm.film_border_top) * cam_norm.film_width *
|
int write_loc = (coord_y - cam_norm.film_border_top) * cam_norm.film_width *
|
||||||
(3 + 2 * n_track) +
|
(3 + 2 * n_track) +
|
||||||
(coord_x - cam_norm.film_border_left) * (3 + 2 * n_track);
|
(coord_x - cam_norm.film_border_left) * (3 + 2 * n_track);
|
||||||
|
|||||||
@@ -860,9 +860,8 @@ std::tuple<torch::Tensor, torch::Tensor> Renderer::forward(
|
|||||||
? (cudaStream_t) nullptr
|
? (cudaStream_t) nullptr
|
||||||
#endif
|
#endif
|
||||||
: (cudaStream_t) nullptr);
|
: (cudaStream_t) nullptr);
|
||||||
if (mode == 1) {
|
if (mode == 1)
|
||||||
results[batch_i] = results[batch_i].slice(2, 0, 1, 1);
|
results[batch_i] = results[batch_i].slice(2, 0, 1, 1);
|
||||||
}
|
|
||||||
forw_infos[batch_i] = from_blob(
|
forw_infos[batch_i] = from_blob(
|
||||||
this->renderer_vec[batch_i].forw_info_d,
|
this->renderer_vec[batch_i].forw_info_d,
|
||||||
{this->renderer_vec[0].cam.film_height,
|
{this->renderer_vec[0].cam.film_height,
|
||||||
|
|||||||
@@ -128,9 +128,8 @@ struct Renderer {
|
|||||||
stream << "pulsar::Renderer[";
|
stream << "pulsar::Renderer[";
|
||||||
// Device info.
|
// Device info.
|
||||||
stream << self.device_type;
|
stream << self.device_type;
|
||||||
if (self.device_index != -1) {
|
if (self.device_index != -1)
|
||||||
stream << ", ID " << self.device_index;
|
stream << ", ID " << self.device_index;
|
||||||
}
|
|
||||||
stream << "]";
|
stream << "]";
|
||||||
return stream;
|
return stream;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,6 @@
|
|||||||
|
|
||||||
#ifdef WITH_CUDA
|
#ifdef WITH_CUDA
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <c10/cuda/CUDAException.h>
|
|
||||||
#include <cuda_runtime_api.h>
|
#include <cuda_runtime_api.h>
|
||||||
#endif
|
#endif
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
@@ -34,13 +33,13 @@ torch::Tensor sphere_ids_from_result_info_nograd(
|
|||||||
.contiguous();
|
.contiguous();
|
||||||
if (forw_info.device().type() == c10::DeviceType::CUDA) {
|
if (forw_info.device().type() == c10::DeviceType::CUDA) {
|
||||||
#ifdef WITH_CUDA
|
#ifdef WITH_CUDA
|
||||||
C10_CUDA_CHECK(cudaMemcpyAsync(
|
cudaMemcpyAsync(
|
||||||
result.data_ptr(),
|
result.data_ptr(),
|
||||||
tmp.data_ptr(),
|
tmp.data_ptr(),
|
||||||
sizeof(uint32_t) * tmp.size(0) * tmp.size(1) * tmp.size(2) *
|
sizeof(uint32_t) * tmp.size(0) * tmp.size(1) * tmp.size(2) *
|
||||||
tmp.size(3),
|
tmp.size(3),
|
||||||
cudaMemcpyDeviceToDevice,
|
cudaMemcpyDeviceToDevice,
|
||||||
at::cuda::getCurrentCUDAStream()));
|
at::cuda::getCurrentCUDAStream());
|
||||||
#else
|
#else
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"Copy on CUDA device initiated but built "
|
"Copy on CUDA device initiated but built "
|
||||||
|
|||||||
@@ -7,7 +7,6 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
#ifdef WITH_CUDA
|
#ifdef WITH_CUDA
|
||||||
#include <c10/cuda/CUDAException.h>
|
|
||||||
#include <cuda_runtime_api.h>
|
#include <cuda_runtime_api.h>
|
||||||
|
|
||||||
namespace pulsar {
|
namespace pulsar {
|
||||||
@@ -18,8 +17,7 @@ void cudaDevToDev(
|
|||||||
const void* src,
|
const void* src,
|
||||||
const int& size,
|
const int& size,
|
||||||
const cudaStream_t& stream) {
|
const cudaStream_t& stream) {
|
||||||
C10_CUDA_CHECK(
|
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToDevice, stream);
|
||||||
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToDevice, stream));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void cudaDevToHost(
|
void cudaDevToHost(
|
||||||
@@ -27,8 +25,7 @@ void cudaDevToHost(
|
|||||||
const void* src,
|
const void* src,
|
||||||
const int& size,
|
const int& size,
|
||||||
const cudaStream_t& stream) {
|
const cudaStream_t& stream) {
|
||||||
C10_CUDA_CHECK(
|
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToHost, stream);
|
||||||
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToHost, stream));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace pytorch
|
} // namespace pytorch
|
||||||
|
|||||||
@@ -6,6 +6,9 @@
|
|||||||
* LICENSE file in the root directory of this source tree.
|
* LICENSE file in the root directory of this source tree.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include "./global.h"
|
||||||
|
#include "./logging.h"
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A compilation unit to provide warnings about the code and avoid
|
* A compilation unit to provide warnings about the code and avoid
|
||||||
* repeated messages.
|
* repeated messages.
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ class BitMask {
|
|||||||
|
|
||||||
// Use all threads in the current block to clear all bits of this BitMask
|
// Use all threads in the current block to clear all bits of this BitMask
|
||||||
__device__ void block_clear() {
|
__device__ void block_clear() {
|
||||||
for (auto i = threadIdx.x; i < H * W * D; i += blockDim.x) {
|
for (int i = threadIdx.x; i < H * W * D; i += blockDim.x) {
|
||||||
data[i] = 0;
|
data[i] = 0;
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|||||||
@@ -23,8 +23,8 @@ __global__ void TriangleBoundingBoxKernel(
|
|||||||
const float blur_radius,
|
const float blur_radius,
|
||||||
float* bboxes, // (4, F)
|
float* bboxes, // (4, F)
|
||||||
bool* skip_face) { // (F,)
|
bool* skip_face) { // (F,)
|
||||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
const auto num_threads = blockDim.x * gridDim.x;
|
const int num_threads = blockDim.x * gridDim.x;
|
||||||
const float sqrt_radius = sqrt(blur_radius);
|
const float sqrt_radius = sqrt(blur_radius);
|
||||||
for (int f = tid; f < F; f += num_threads) {
|
for (int f = tid; f < F; f += num_threads) {
|
||||||
const float v0x = face_verts[f * 9 + 0 * 3 + 0];
|
const float v0x = face_verts[f * 9 + 0 * 3 + 0];
|
||||||
@@ -56,8 +56,8 @@ __global__ void PointBoundingBoxKernel(
|
|||||||
const int P,
|
const int P,
|
||||||
float* bboxes, // (4, P)
|
float* bboxes, // (4, P)
|
||||||
bool* skip_points) {
|
bool* skip_points) {
|
||||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
const auto num_threads = blockDim.x * gridDim.x;
|
const int num_threads = blockDim.x * gridDim.x;
|
||||||
for (int p = tid; p < P; p += num_threads) {
|
for (int p = tid; p < P; p += num_threads) {
|
||||||
const float x = points[p * 3 + 0];
|
const float x = points[p * 3 + 0];
|
||||||
const float y = points[p * 3 + 1];
|
const float y = points[p * 3 + 1];
|
||||||
@@ -113,7 +113,7 @@ __global__ void RasterizeCoarseCudaKernel(
|
|||||||
const int chunks_per_batch = 1 + (E - 1) / chunk_size;
|
const int chunks_per_batch = 1 + (E - 1) / chunk_size;
|
||||||
const int num_chunks = N * chunks_per_batch;
|
const int num_chunks = N * chunks_per_batch;
|
||||||
|
|
||||||
for (auto chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
|
for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
|
||||||
const int batch_idx = chunk / chunks_per_batch; // batch index
|
const int batch_idx = chunk / chunks_per_batch; // batch index
|
||||||
const int chunk_idx = chunk % chunks_per_batch;
|
const int chunk_idx = chunk % chunks_per_batch;
|
||||||
const int elem_chunk_start_idx = chunk_idx * chunk_size;
|
const int elem_chunk_start_idx = chunk_idx * chunk_size;
|
||||||
@@ -123,7 +123,7 @@ __global__ void RasterizeCoarseCudaKernel(
|
|||||||
const int64_t elem_stop_idx = elem_start_idx + elems_per_batch[batch_idx];
|
const int64_t elem_stop_idx = elem_start_idx + elems_per_batch[batch_idx];
|
||||||
|
|
||||||
// Have each thread handle a different face within the chunk
|
// Have each thread handle a different face within the chunk
|
||||||
for (auto e = threadIdx.x; e < chunk_size; e += blockDim.x) {
|
for (int e = threadIdx.x; e < chunk_size; e += blockDim.x) {
|
||||||
const int e_idx = elem_chunk_start_idx + e;
|
const int e_idx = elem_chunk_start_idx + e;
|
||||||
|
|
||||||
// Check that we are still within the same element of the batch
|
// Check that we are still within the same element of the batch
|
||||||
@@ -170,7 +170,7 @@ __global__ void RasterizeCoarseCudaKernel(
|
|||||||
// Now we have processed every elem in the current chunk. We need to
|
// Now we have processed every elem in the current chunk. We need to
|
||||||
// count the number of elems in each bin so we can write the indices
|
// count the number of elems in each bin so we can write the indices
|
||||||
// out to global memory. We have each thread handle a different bin.
|
// out to global memory. We have each thread handle a different bin.
|
||||||
for (auto byx = threadIdx.x; byx < num_bins_y * num_bins_x;
|
for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x;
|
||||||
byx += blockDim.x) {
|
byx += blockDim.x) {
|
||||||
const int by = byx / num_bins_x;
|
const int by = byx / num_bins_x;
|
||||||
const int bx = byx % num_bins_x;
|
const int bx = byx % num_bins_x;
|
||||||
|
|||||||
@@ -260,8 +260,8 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
|
|||||||
float* pix_dists,
|
float* pix_dists,
|
||||||
float* bary) {
|
float* bary) {
|
||||||
// Simple version: One thread per output pixel
|
// Simple version: One thread per output pixel
|
||||||
auto num_threads = gridDim.x * blockDim.x;
|
int num_threads = gridDim.x * blockDim.x;
|
||||||
auto tid = blockDim.x * blockIdx.x + threadIdx.x;
|
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
for (int i = tid; i < N * H * W; i += num_threads) {
|
for (int i = tid; i < N * H * W; i += num_threads) {
|
||||||
// Convert linear index to 3D index
|
// Convert linear index to 3D index
|
||||||
@@ -446,8 +446,8 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
|
|||||||
|
|
||||||
// Parallelize over each pixel in images of
|
// Parallelize over each pixel in images of
|
||||||
// size H * W, for each image in the batch of size N.
|
// size H * W, for each image in the batch of size N.
|
||||||
const auto num_threads = gridDim.x * blockDim.x;
|
const int num_threads = gridDim.x * blockDim.x;
|
||||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
for (int t_i = tid; t_i < N * H * W; t_i += num_threads) {
|
for (int t_i = tid; t_i < N * H * W; t_i += num_threads) {
|
||||||
// Convert linear index to 3D index
|
// Convert linear index to 3D index
|
||||||
@@ -650,8 +650,8 @@ __global__ void RasterizeMeshesFineCudaKernel(
|
|||||||
) {
|
) {
|
||||||
// This can be more than H * W if H or W are not divisible by bin_size.
|
// This can be more than H * W if H or W are not divisible by bin_size.
|
||||||
int num_pixels = N * BH * BW * bin_size * bin_size;
|
int num_pixels = N * BH * BW * bin_size * bin_size;
|
||||||
auto num_threads = gridDim.x * blockDim.x;
|
int num_threads = gridDim.x * blockDim.x;
|
||||||
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
for (int pid = tid; pid < num_pixels; pid += num_threads) {
|
for (int pid = tid; pid < num_pixels; pid += num_threads) {
|
||||||
// Convert linear index into bin and pixel indices. We make the within
|
// Convert linear index into bin and pixel indices. We make the within
|
||||||
|
|||||||
@@ -138,9 +138,6 @@ RasterizeMeshesNaive(
|
|||||||
AT_ERROR("Not compiled with GPU support");
|
AT_ERROR("Not compiled with GPU support");
|
||||||
#endif
|
#endif
|
||||||
} else {
|
} else {
|
||||||
CHECK_CPU(face_verts);
|
|
||||||
CHECK_CPU(mesh_to_face_first_idx);
|
|
||||||
CHECK_CPU(num_faces_per_mesh);
|
|
||||||
return RasterizeMeshesNaiveCpu(
|
return RasterizeMeshesNaiveCpu(
|
||||||
face_verts,
|
face_verts,
|
||||||
mesh_to_face_first_idx,
|
mesh_to_face_first_idx,
|
||||||
@@ -235,11 +232,6 @@ torch::Tensor RasterizeMeshesBackward(
|
|||||||
AT_ERROR("Not compiled with GPU support");
|
AT_ERROR("Not compiled with GPU support");
|
||||||
#endif
|
#endif
|
||||||
} else {
|
} else {
|
||||||
CHECK_CPU(face_verts);
|
|
||||||
CHECK_CPU(pix_to_face);
|
|
||||||
CHECK_CPU(grad_zbuf);
|
|
||||||
CHECK_CPU(grad_bary);
|
|
||||||
CHECK_CPU(grad_dists);
|
|
||||||
return RasterizeMeshesBackwardCpu(
|
return RasterizeMeshesBackwardCpu(
|
||||||
face_verts,
|
face_verts,
|
||||||
pix_to_face,
|
pix_to_face,
|
||||||
@@ -314,9 +306,6 @@ torch::Tensor RasterizeMeshesCoarse(
|
|||||||
AT_ERROR("Not compiled with GPU support");
|
AT_ERROR("Not compiled with GPU support");
|
||||||
#endif
|
#endif
|
||||||
} else {
|
} else {
|
||||||
CHECK_CPU(face_verts);
|
|
||||||
CHECK_CPU(mesh_to_face_first_idx);
|
|
||||||
CHECK_CPU(num_faces_per_mesh);
|
|
||||||
return RasterizeMeshesCoarseCpu(
|
return RasterizeMeshesCoarseCpu(
|
||||||
face_verts,
|
face_verts,
|
||||||
mesh_to_face_first_idx,
|
mesh_to_face_first_idx,
|
||||||
@@ -434,8 +423,6 @@ RasterizeMeshesFine(
|
|||||||
AT_ERROR("Not compiled with GPU support");
|
AT_ERROR("Not compiled with GPU support");
|
||||||
#endif
|
#endif
|
||||||
} else {
|
} else {
|
||||||
CHECK_CPU(face_verts);
|
|
||||||
CHECK_CPU(bin_faces);
|
|
||||||
AT_ERROR("NOT IMPLEMENTED");
|
AT_ERROR("NOT IMPLEMENTED");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <list>
|
#include <list>
|
||||||
|
#include <queue>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include "ATen/core/TensorAccessor.h"
|
#include "ATen/core/TensorAccessor.h"
|
||||||
@@ -106,8 +107,6 @@ auto ComputeFaceAreas(const torch::Tensor& face_verts) {
|
|||||||
return face_areas;
|
return face_areas;
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
// Helper function to use with std::find_if to find the index of any
|
// Helper function to use with std::find_if to find the index of any
|
||||||
// values in the top k struct which match a given idx.
|
// values in the top k struct which match a given idx.
|
||||||
struct IsNeighbor {
|
struct IsNeighbor {
|
||||||
@@ -120,6 +119,7 @@ struct IsNeighbor {
|
|||||||
int neighbor_idx;
|
int neighbor_idx;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
namespace {
|
||||||
void RasterizeMeshesNaiveCpu_worker(
|
void RasterizeMeshesNaiveCpu_worker(
|
||||||
const int start_yi,
|
const int start_yi,
|
||||||
const int end_yi,
|
const int end_yi,
|
||||||
|
|||||||
@@ -97,8 +97,8 @@ __global__ void RasterizePointsNaiveCudaKernel(
|
|||||||
float* zbuf, // (N, H, W, K)
|
float* zbuf, // (N, H, W, K)
|
||||||
float* pix_dists) { // (N, H, W, K)
|
float* pix_dists) { // (N, H, W, K)
|
||||||
// Simple version: One thread per output pixel
|
// Simple version: One thread per output pixel
|
||||||
const auto num_threads = gridDim.x * blockDim.x;
|
const int num_threads = gridDim.x * blockDim.x;
|
||||||
const auto tid = blockDim.x * blockIdx.x + threadIdx.x;
|
const int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
for (int i = tid; i < N * H * W; i += num_threads) {
|
for (int i = tid; i < N * H * W; i += num_threads) {
|
||||||
// Convert linear index to 3D index
|
// Convert linear index to 3D index
|
||||||
const int n = i / (H * W); // Batch index
|
const int n = i / (H * W); // Batch index
|
||||||
@@ -237,8 +237,8 @@ __global__ void RasterizePointsFineCudaKernel(
|
|||||||
float* pix_dists) { // (N, H, W, K)
|
float* pix_dists) { // (N, H, W, K)
|
||||||
// This can be more than H * W if H or W are not divisible by bin_size.
|
// This can be more than H * W if H or W are not divisible by bin_size.
|
||||||
const int num_pixels = N * BH * BW * bin_size * bin_size;
|
const int num_pixels = N * BH * BW * bin_size * bin_size;
|
||||||
const auto num_threads = gridDim.x * blockDim.x;
|
const int num_threads = gridDim.x * blockDim.x;
|
||||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
for (int pid = tid; pid < num_pixels; pid += num_threads) {
|
for (int pid = tid; pid < num_pixels; pid += num_threads) {
|
||||||
// Convert linear index into bin and pixel indices. We make the within
|
// Convert linear index into bin and pixel indices. We make the within
|
||||||
@@ -376,8 +376,8 @@ __global__ void RasterizePointsBackwardCudaKernel(
|
|||||||
float* grad_points) { // (P, 3)
|
float* grad_points) { // (P, 3)
|
||||||
// Parallelized over each of K points per pixel, for each pixel in images of
|
// Parallelized over each of K points per pixel, for each pixel in images of
|
||||||
// size H * W, for each image in the batch of size N.
|
// size H * W, for each image in the batch of size N.
|
||||||
auto num_threads = gridDim.x * blockDim.x;
|
int num_threads = gridDim.x * blockDim.x;
|
||||||
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
for (int i = tid; i < N * H * W * K; i += num_threads) {
|
for (int i = tid; i < N * H * W * K; i += num_threads) {
|
||||||
// const int n = i / (H * W * K); // batch index (not needed).
|
// const int n = i / (H * W * K); // batch index (not needed).
|
||||||
const int yxk = i % (H * W * K);
|
const int yxk = i % (H * W * K);
|
||||||
|
|||||||
@@ -91,10 +91,6 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaive(
|
|||||||
AT_ERROR("Not compiled with GPU support");
|
AT_ERROR("Not compiled with GPU support");
|
||||||
#endif
|
#endif
|
||||||
} else {
|
} else {
|
||||||
CHECK_CPU(points);
|
|
||||||
CHECK_CPU(cloud_to_packed_first_idx);
|
|
||||||
CHECK_CPU(num_points_per_cloud);
|
|
||||||
CHECK_CPU(radius);
|
|
||||||
return RasterizePointsNaiveCpu(
|
return RasterizePointsNaiveCpu(
|
||||||
points,
|
points,
|
||||||
cloud_to_packed_first_idx,
|
cloud_to_packed_first_idx,
|
||||||
@@ -170,10 +166,6 @@ torch::Tensor RasterizePointsCoarse(
|
|||||||
AT_ERROR("Not compiled with GPU support");
|
AT_ERROR("Not compiled with GPU support");
|
||||||
#endif
|
#endif
|
||||||
} else {
|
} else {
|
||||||
CHECK_CPU(points);
|
|
||||||
CHECK_CPU(cloud_to_packed_first_idx);
|
|
||||||
CHECK_CPU(num_points_per_cloud);
|
|
||||||
CHECK_CPU(radius);
|
|
||||||
return RasterizePointsCoarseCpu(
|
return RasterizePointsCoarseCpu(
|
||||||
points,
|
points,
|
||||||
cloud_to_packed_first_idx,
|
cloud_to_packed_first_idx,
|
||||||
@@ -240,8 +232,6 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFine(
|
|||||||
AT_ERROR("Not compiled with GPU support");
|
AT_ERROR("Not compiled with GPU support");
|
||||||
#endif
|
#endif
|
||||||
} else {
|
} else {
|
||||||
CHECK_CPU(points);
|
|
||||||
CHECK_CPU(bin_points);
|
|
||||||
AT_ERROR("NOT IMPLEMENTED");
|
AT_ERROR("NOT IMPLEMENTED");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -294,10 +284,6 @@ torch::Tensor RasterizePointsBackward(
|
|||||||
AT_ERROR("Not compiled with GPU support");
|
AT_ERROR("Not compiled with GPU support");
|
||||||
#endif
|
#endif
|
||||||
} else {
|
} else {
|
||||||
CHECK_CPU(points);
|
|
||||||
CHECK_CPU(idxs);
|
|
||||||
CHECK_CPU(grad_zbuf);
|
|
||||||
CHECK_CPU(grad_dists);
|
|
||||||
return RasterizePointsBackwardCpu(points, idxs, grad_zbuf, grad_dists);
|
return RasterizePointsBackwardCpu(points, idxs, grad_zbuf, grad_dists);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -35,6 +35,8 @@ __global__ void FarthestPointSamplingKernel(
|
|||||||
__shared__ int64_t selected_store;
|
__shared__ int64_t selected_store;
|
||||||
|
|
||||||
// Get constants
|
// Get constants
|
||||||
|
const int64_t N = points.size(0);
|
||||||
|
const int64_t P = points.size(1);
|
||||||
const int64_t D = points.size(2);
|
const int64_t D = points.size(2);
|
||||||
|
|
||||||
// Get batch index and thread index
|
// Get batch index and thread index
|
||||||
@@ -107,8 +109,7 @@ at::Tensor FarthestPointSamplingCuda(
|
|||||||
const at::Tensor& points, // (N, P, 3)
|
const at::Tensor& points, // (N, P, 3)
|
||||||
const at::Tensor& lengths, // (N,)
|
const at::Tensor& lengths, // (N,)
|
||||||
const at::Tensor& K, // (N,)
|
const at::Tensor& K, // (N,)
|
||||||
const at::Tensor& start_idxs,
|
const at::Tensor& start_idxs) {
|
||||||
const int64_t max_K_known = -1) {
|
|
||||||
// Check inputs are on the same device
|
// Check inputs are on the same device
|
||||||
at::TensorArg p_t{points, "points", 1}, lengths_t{lengths, "lengths", 2},
|
at::TensorArg p_t{points, "points", 1}, lengths_t{lengths, "lengths", 2},
|
||||||
k_t{K, "K", 3}, start_idxs_t{start_idxs, "start_idxs", 4};
|
k_t{K, "K", 3}, start_idxs_t{start_idxs, "start_idxs", 4};
|
||||||
@@ -130,12 +131,7 @@ at::Tensor FarthestPointSamplingCuda(
|
|||||||
|
|
||||||
const int64_t N = points.size(0);
|
const int64_t N = points.size(0);
|
||||||
const int64_t P = points.size(1);
|
const int64_t P = points.size(1);
|
||||||
int64_t max_K;
|
const int64_t max_K = at::max(K).item<int64_t>();
|
||||||
if (max_K_known > 0) {
|
|
||||||
max_K = max_K_known;
|
|
||||||
} else {
|
|
||||||
max_K = at::max(K).item<int64_t>();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize the output tensor with the sampled indices
|
// Initialize the output tensor with the sampled indices
|
||||||
auto idxs = at::full({N, max_K}, -1, lengths.options());
|
auto idxs = at::full({N, max_K}, -1, lengths.options());
|
||||||
|
|||||||
@@ -43,8 +43,7 @@ at::Tensor FarthestPointSamplingCuda(
|
|||||||
const at::Tensor& points,
|
const at::Tensor& points,
|
||||||
const at::Tensor& lengths,
|
const at::Tensor& lengths,
|
||||||
const at::Tensor& K,
|
const at::Tensor& K,
|
||||||
const at::Tensor& start_idxs,
|
const at::Tensor& start_idxs);
|
||||||
const int64_t max_K_known = -1);
|
|
||||||
|
|
||||||
at::Tensor FarthestPointSamplingCpu(
|
at::Tensor FarthestPointSamplingCpu(
|
||||||
const at::Tensor& points,
|
const at::Tensor& points,
|
||||||
@@ -57,23 +56,17 @@ at::Tensor FarthestPointSampling(
|
|||||||
const at::Tensor& points,
|
const at::Tensor& points,
|
||||||
const at::Tensor& lengths,
|
const at::Tensor& lengths,
|
||||||
const at::Tensor& K,
|
const at::Tensor& K,
|
||||||
const at::Tensor& start_idxs,
|
const at::Tensor& start_idxs) {
|
||||||
const int64_t max_K_known = -1) {
|
|
||||||
if (points.is_cuda() || lengths.is_cuda() || K.is_cuda()) {
|
if (points.is_cuda() || lengths.is_cuda() || K.is_cuda()) {
|
||||||
#ifdef WITH_CUDA
|
#ifdef WITH_CUDA
|
||||||
CHECK_CUDA(points);
|
CHECK_CUDA(points);
|
||||||
CHECK_CUDA(lengths);
|
CHECK_CUDA(lengths);
|
||||||
CHECK_CUDA(K);
|
CHECK_CUDA(K);
|
||||||
CHECK_CUDA(start_idxs);
|
CHECK_CUDA(start_idxs);
|
||||||
return FarthestPointSamplingCuda(
|
return FarthestPointSamplingCuda(points, lengths, K, start_idxs);
|
||||||
points, lengths, K, start_idxs, max_K_known);
|
|
||||||
#else
|
#else
|
||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(points);
|
|
||||||
CHECK_CPU(lengths);
|
|
||||||
CHECK_CPU(K);
|
|
||||||
CHECK_CPU(start_idxs);
|
|
||||||
return FarthestPointSamplingCpu(points, lengths, K, start_idxs);
|
return FarthestPointSamplingCpu(points, lengths, K, start_idxs);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -71,8 +71,6 @@ inline void SamplePdf(
|
|||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
CHECK_CPU(weights);
|
|
||||||
CHECK_CPU(outputs);
|
|
||||||
CHECK_CONTIGUOUS(outputs);
|
CHECK_CONTIGUOUS(outputs);
|
||||||
SamplePdfCpu(bins, weights, outputs, eps);
|
SamplePdfCpu(bins, weights, outputs, eps);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -99,7 +99,8 @@ namespace {
|
|||||||
// and increment it via template recursion until it is equal to the run-time
|
// and increment it via template recursion until it is equal to the run-time
|
||||||
// argument N.
|
// argument N.
|
||||||
template <
|
template <
|
||||||
template <typename, int64_t> class Kernel,
|
template <typename, int64_t>
|
||||||
|
class Kernel,
|
||||||
typename T,
|
typename T,
|
||||||
int64_t minN,
|
int64_t minN,
|
||||||
int64_t maxN,
|
int64_t maxN,
|
||||||
@@ -123,7 +124,8 @@ struct DispatchKernelHelper1D {
|
|||||||
// 1D dispatch: Specialization when curN == maxN
|
// 1D dispatch: Specialization when curN == maxN
|
||||||
// We need this base case to avoid infinite template recursion.
|
// We need this base case to avoid infinite template recursion.
|
||||||
template <
|
template <
|
||||||
template <typename, int64_t> class Kernel,
|
template <typename, int64_t>
|
||||||
|
class Kernel,
|
||||||
typename T,
|
typename T,
|
||||||
int64_t minN,
|
int64_t minN,
|
||||||
int64_t maxN,
|
int64_t maxN,
|
||||||
@@ -143,7 +145,8 @@ struct DispatchKernelHelper1D<Kernel, T, minN, maxN, maxN, Args...> {
|
|||||||
// the run-time values of N and M, at which point we dispatch to the run
|
// the run-time values of N and M, at which point we dispatch to the run
|
||||||
// method of the kernel.
|
// method of the kernel.
|
||||||
template <
|
template <
|
||||||
template <typename, int64_t, int64_t> class Kernel,
|
template <typename, int64_t, int64_t>
|
||||||
|
class Kernel,
|
||||||
typename T,
|
typename T,
|
||||||
int64_t minN,
|
int64_t minN,
|
||||||
int64_t maxN,
|
int64_t maxN,
|
||||||
@@ -200,7 +203,8 @@ struct DispatchKernelHelper2D {
|
|||||||
|
|
||||||
// 2D dispatch, specialization for curN == maxN
|
// 2D dispatch, specialization for curN == maxN
|
||||||
template <
|
template <
|
||||||
template <typename, int64_t, int64_t> class Kernel,
|
template <typename, int64_t, int64_t>
|
||||||
|
class Kernel,
|
||||||
typename T,
|
typename T,
|
||||||
int64_t minN,
|
int64_t minN,
|
||||||
int64_t maxN,
|
int64_t maxN,
|
||||||
@@ -239,7 +243,8 @@ struct DispatchKernelHelper2D<
|
|||||||
|
|
||||||
// 2D dispatch, specialization for curM == maxM
|
// 2D dispatch, specialization for curM == maxM
|
||||||
template <
|
template <
|
||||||
template <typename, int64_t, int64_t> class Kernel,
|
template <typename, int64_t, int64_t>
|
||||||
|
class Kernel,
|
||||||
typename T,
|
typename T,
|
||||||
int64_t minN,
|
int64_t minN,
|
||||||
int64_t maxN,
|
int64_t maxN,
|
||||||
@@ -278,7 +283,8 @@ struct DispatchKernelHelper2D<
|
|||||||
|
|
||||||
// 2D dispatch, specialization for curN == maxN, curM == maxM
|
// 2D dispatch, specialization for curN == maxN, curM == maxM
|
||||||
template <
|
template <
|
||||||
template <typename, int64_t, int64_t> class Kernel,
|
template <typename, int64_t, int64_t>
|
||||||
|
class Kernel,
|
||||||
typename T,
|
typename T,
|
||||||
int64_t minN,
|
int64_t minN,
|
||||||
int64_t maxN,
|
int64_t maxN,
|
||||||
@@ -307,7 +313,8 @@ struct DispatchKernelHelper2D<
|
|||||||
|
|
||||||
// This is the function we expect users to call to dispatch to 1D functions
|
// This is the function we expect users to call to dispatch to 1D functions
|
||||||
template <
|
template <
|
||||||
template <typename, int64_t> class Kernel,
|
template <typename, int64_t>
|
||||||
|
class Kernel,
|
||||||
typename T,
|
typename T,
|
||||||
int64_t minN,
|
int64_t minN,
|
||||||
int64_t maxN,
|
int64_t maxN,
|
||||||
@@ -323,7 +330,8 @@ void DispatchKernel1D(const int64_t N, Args... args) {
|
|||||||
|
|
||||||
// This is the function we expect users to call to dispatch to 2D functions
|
// This is the function we expect users to call to dispatch to 2D functions
|
||||||
template <
|
template <
|
||||||
template <typename, int64_t, int64_t> class Kernel,
|
template <typename, int64_t, int64_t>
|
||||||
|
class Kernel,
|
||||||
typename T,
|
typename T,
|
||||||
int64_t minN,
|
int64_t minN,
|
||||||
int64_t maxN,
|
int64_t maxN,
|
||||||
|
|||||||
@@ -376,6 +376,8 @@ PointLineDistanceBackward(
|
|||||||
float tt = t_top / t_bot;
|
float tt = t_top / t_bot;
|
||||||
tt = __saturatef(tt);
|
tt = __saturatef(tt);
|
||||||
const float2 p_proj = (1.0f - tt) * v0 + tt * v1;
|
const float2 p_proj = (1.0f - tt) * v0 + tt * v1;
|
||||||
|
const float2 d = p - p_proj;
|
||||||
|
const float dist = sqrt(dot(d, d));
|
||||||
|
|
||||||
const float2 grad_p = -1.0f * grad_dist * 2.0f * (p_proj - p);
|
const float2 grad_p = -1.0f * grad_dist * 2.0f * (p_proj - p);
|
||||||
const float2 grad_v0 = grad_dist * (1.0f - tt) * 2.0f * (p_proj - p);
|
const float2 grad_v0 = grad_dist * (1.0f - tt) * 2.0f * (p_proj - p);
|
||||||
|
|||||||
@@ -15,7 +15,3 @@
|
|||||||
#define CHECK_CONTIGUOUS_CUDA(x) \
|
#define CHECK_CONTIGUOUS_CUDA(x) \
|
||||||
CHECK_CUDA(x); \
|
CHECK_CUDA(x); \
|
||||||
CHECK_CONTIGUOUS(x)
|
CHECK_CONTIGUOUS(x)
|
||||||
#define CHECK_CPU(x) \
|
|
||||||
TORCH_CHECK( \
|
|
||||||
x.device().type() == torch::kCPU, \
|
|
||||||
"Cannot use CPU implementation: " #x " not on CPU.")
|
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ class ShapeNetCore(ShapeNetBase): # pragma: no cover
|
|||||||
):
|
):
|
||||||
synset_set.add(synset)
|
synset_set.add(synset)
|
||||||
elif (synset in self.synset_inv.keys()) and (
|
elif (synset in self.synset_inv.keys()) and (
|
||||||
path.isdir(path.join(data_dir, self.synset_inv[synset]))
|
(path.isdir(path.join(data_dir, self.synset_inv[synset])))
|
||||||
):
|
):
|
||||||
synset_set.add(self.synset_inv[synset])
|
synset_set.add(self.synset_inv[synset])
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ def collate_batched_meshes(batch: List[Dict]): # pragma: no cover
|
|||||||
|
|
||||||
collated_dict["mesh"] = None
|
collated_dict["mesh"] = None
|
||||||
if {"verts", "faces"}.issubset(collated_dict.keys()):
|
if {"verts", "faces"}.issubset(collated_dict.keys()):
|
||||||
|
|
||||||
textures = None
|
textures = None
|
||||||
if "textures" in collated_dict:
|
if "textures" in collated_dict:
|
||||||
textures = TexturesAtlas(atlas=collated_dict["textures"])
|
textures = TexturesAtlas(atlas=collated_dict["textures"])
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch3d.implicitron.dataset.frame_data import FrameData
|
from pytorch3d.implicitron.dataset.frame_data import FrameData
|
||||||
from pytorch3d.implicitron.dataset.utils import GenericWorkaround
|
from pytorch3d.implicitron.dataset.utils import GenericWorkaround
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,8 @@ from typing import (
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.dataset import orm_types, types
|
|
||||||
|
from pytorch3d.implicitron.dataset import types
|
||||||
from pytorch3d.implicitron.dataset.utils import (
|
from pytorch3d.implicitron.dataset.utils import (
|
||||||
adjust_camera_to_bbox_crop_,
|
adjust_camera_to_bbox_crop_,
|
||||||
adjust_camera_to_image_scale_,
|
adjust_camera_to_image_scale_,
|
||||||
@@ -47,12 +48,8 @@ from pytorch3d.implicitron.dataset.utils import (
|
|||||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||||
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
||||||
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
|
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
|
||||||
from pytorch3d.structures.meshes import join_meshes_as_batch, Meshes
|
|
||||||
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
|
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
|
||||||
|
|
||||||
FrameAnnotationT = types.FrameAnnotation | orm_types.SqlFrameAnnotation
|
|
||||||
SequenceAnnotationT = types.SequenceAnnotation | orm_types.SqlSequenceAnnotation
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FrameData(Mapping[str, Any]):
|
class FrameData(Mapping[str, Any]):
|
||||||
@@ -125,9 +122,9 @@ class FrameData(Mapping[str, Any]):
|
|||||||
meta: A dict for storing additional frame information.
|
meta: A dict for storing additional frame information.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
frame_number: Optional[torch.LongTensor] = None
|
frame_number: Optional[torch.LongTensor]
|
||||||
sequence_name: Union[str, List[str]] = ""
|
sequence_name: Union[str, List[str]]
|
||||||
sequence_category: Union[str, List[str]] = ""
|
sequence_category: Union[str, List[str]]
|
||||||
frame_timestamp: Optional[torch.Tensor] = None
|
frame_timestamp: Optional[torch.Tensor] = None
|
||||||
image_size_hw: Optional[torch.LongTensor] = None
|
image_size_hw: Optional[torch.LongTensor] = None
|
||||||
effective_image_size_hw: Optional[torch.LongTensor] = None
|
effective_image_size_hw: Optional[torch.LongTensor] = None
|
||||||
@@ -158,7 +155,7 @@ class FrameData(Mapping[str, Any]):
|
|||||||
new_params = {}
|
new_params = {}
|
||||||
for field_name in iter(self):
|
for field_name in iter(self):
|
||||||
value = getattr(self, field_name)
|
value = getattr(self, field_name)
|
||||||
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase, Meshes)):
|
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
|
||||||
new_params[field_name] = value.to(*args, **kwargs)
|
new_params[field_name] = value.to(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
new_params[field_name] = value
|
new_params[field_name] = value
|
||||||
@@ -420,6 +417,7 @@ class FrameData(Mapping[str, Any]):
|
|||||||
for f in fields(elem):
|
for f in fields(elem):
|
||||||
if not f.init:
|
if not f.init:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
list_values = override_fields.get(
|
list_values = override_fields.get(
|
||||||
f.name, [getattr(d, f.name) for d in batch]
|
f.name, [getattr(d, f.name) for d in batch]
|
||||||
)
|
)
|
||||||
@@ -428,7 +426,7 @@ class FrameData(Mapping[str, Any]):
|
|||||||
if all(list_value is not None for list_value in list_values)
|
if all(list_value is not None for list_value in list_values)
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
return type(elem)(**collated)
|
return cls(**collated)
|
||||||
|
|
||||||
elif isinstance(elem, Pointclouds):
|
elif isinstance(elem, Pointclouds):
|
||||||
return join_pointclouds_as_batch(batch)
|
return join_pointclouds_as_batch(batch)
|
||||||
@@ -436,8 +434,6 @@ class FrameData(Mapping[str, Any]):
|
|||||||
elif isinstance(elem, CamerasBase):
|
elif isinstance(elem, CamerasBase):
|
||||||
# TODO: don't store K; enforce working in NDC space
|
# TODO: don't store K; enforce working in NDC space
|
||||||
return join_cameras_as_batch(batch)
|
return join_cameras_as_batch(batch)
|
||||||
elif isinstance(elem, Meshes):
|
|
||||||
return join_meshes_as_batch(batch)
|
|
||||||
else:
|
else:
|
||||||
return torch.utils.data.dataloader.default_collate(batch)
|
return torch.utils.data.dataloader.default_collate(batch)
|
||||||
|
|
||||||
@@ -458,8 +454,8 @@ class FrameDataBuilderBase(ReplaceableBase, Generic[FrameDataSubtype], ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def build(
|
def build(
|
||||||
self,
|
self,
|
||||||
frame_annotation: FrameAnnotationT,
|
frame_annotation: types.FrameAnnotation,
|
||||||
sequence_annotation: SequenceAnnotationT,
|
sequence_annotation: types.SequenceAnnotation,
|
||||||
*,
|
*,
|
||||||
load_blobs: bool = True,
|
load_blobs: bool = True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -545,8 +541,8 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
|||||||
|
|
||||||
def build(
|
def build(
|
||||||
self,
|
self,
|
||||||
frame_annotation: FrameAnnotationT,
|
frame_annotation: types.FrameAnnotation,
|
||||||
sequence_annotation: SequenceAnnotationT,
|
sequence_annotation: types.SequenceAnnotation,
|
||||||
*,
|
*,
|
||||||
load_blobs: bool = True,
|
load_blobs: bool = True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -590,81 +586,58 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset_root = self.dataset_root
|
fg_mask_np: Optional[np.ndarray] = None
|
||||||
mask_annotation = frame_annotation.mask
|
mask_annotation = frame_annotation.mask
|
||||||
depth_annotation = frame_annotation.depth
|
|
||||||
image_path: str | None = None
|
|
||||||
mask_path: str | None = None
|
|
||||||
depth_path: str | None = None
|
|
||||||
pcl_path: str | None = None
|
|
||||||
if dataset_root is not None: # set all paths even if we won’t load blobs
|
|
||||||
if frame_annotation.image.path is not None:
|
|
||||||
image_path = os.path.join(dataset_root, frame_annotation.image.path)
|
|
||||||
frame_data.image_path = image_path
|
|
||||||
|
|
||||||
if mask_annotation is not None and mask_annotation.path:
|
|
||||||
mask_path = os.path.join(dataset_root, mask_annotation.path)
|
|
||||||
frame_data.mask_path = mask_path
|
|
||||||
|
|
||||||
if depth_annotation is not None and depth_annotation.path is not None:
|
|
||||||
depth_path = os.path.join(dataset_root, depth_annotation.path)
|
|
||||||
frame_data.depth_path = depth_path
|
|
||||||
|
|
||||||
if point_cloud is not None:
|
|
||||||
pcl_path = os.path.join(dataset_root, point_cloud.path)
|
|
||||||
frame_data.sequence_point_cloud_path = pcl_path
|
|
||||||
|
|
||||||
fg_mask_np: np.ndarray | None = None
|
|
||||||
bbox_xywh: tuple[float, float, float, float] | None = None
|
|
||||||
|
|
||||||
if mask_annotation is not None:
|
if mask_annotation is not None:
|
||||||
if load_blobs and self.load_masks and mask_path:
|
if load_blobs and self.load_masks:
|
||||||
fg_mask_np = self._load_fg_probability(frame_annotation, mask_path)
|
fg_mask_np, mask_path = self._load_fg_probability(frame_annotation)
|
||||||
|
frame_data.mask_path = mask_path
|
||||||
frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)
|
frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)
|
||||||
|
|
||||||
bbox_xywh = mask_annotation.bounding_box_xywh
|
bbox_xywh = mask_annotation.bounding_box_xywh
|
||||||
|
if bbox_xywh is None and fg_mask_np is not None:
|
||||||
|
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
|
||||||
|
|
||||||
|
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)
|
||||||
|
|
||||||
if frame_annotation.image is not None:
|
if frame_annotation.image is not None:
|
||||||
image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
|
image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
|
||||||
frame_data.image_size_hw = image_size_hw # original image size
|
frame_data.image_size_hw = image_size_hw # original image size
|
||||||
# image size after crop/resize
|
# image size after crop/resize
|
||||||
frame_data.effective_image_size_hw = image_size_hw
|
frame_data.effective_image_size_hw = image_size_hw
|
||||||
|
image_path = None
|
||||||
|
dataset_root = self.dataset_root
|
||||||
|
if frame_annotation.image.path is not None and dataset_root is not None:
|
||||||
|
image_path = os.path.join(dataset_root, frame_annotation.image.path)
|
||||||
|
frame_data.image_path = image_path
|
||||||
|
|
||||||
if load_blobs and self.load_images:
|
if load_blobs and self.load_images:
|
||||||
if image_path is None:
|
if image_path is None:
|
||||||
raise ValueError("Image path is required to load images.")
|
raise ValueError("Image path is required to load images.")
|
||||||
|
|
||||||
no_mask = fg_mask_np is None # didn’t read the mask file
|
image_np = load_image(self._local_path(image_path))
|
||||||
image_np = load_image(
|
|
||||||
self._local_path(image_path), try_read_alpha=no_mask
|
|
||||||
)
|
|
||||||
if image_np.shape[0] == 4: # RGBA image
|
|
||||||
if no_mask:
|
|
||||||
fg_mask_np = image_np[3:]
|
|
||||||
frame_data.fg_probability = safe_as_tensor(
|
|
||||||
fg_mask_np, torch.float
|
|
||||||
)
|
|
||||||
|
|
||||||
image_np = image_np[:3]
|
|
||||||
|
|
||||||
frame_data.image_rgb = self._postprocess_image(
|
frame_data.image_rgb = self._postprocess_image(
|
||||||
image_np, frame_annotation.image.size, frame_data.fg_probability
|
image_np, frame_annotation.image.size, frame_data.fg_probability
|
||||||
)
|
)
|
||||||
|
|
||||||
if bbox_xywh is None and fg_mask_np is not None:
|
if (
|
||||||
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
|
load_blobs
|
||||||
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)
|
and self.load_depths
|
||||||
|
and frame_annotation.depth is not None
|
||||||
if load_blobs and self.load_depths and depth_path is not None:
|
and frame_annotation.depth.path is not None
|
||||||
frame_data.depth_map, frame_data.depth_mask = self._load_mask_depth(
|
):
|
||||||
frame_annotation, depth_path, fg_mask_np
|
(
|
||||||
)
|
frame_data.depth_map,
|
||||||
|
frame_data.depth_path,
|
||||||
|
frame_data.depth_mask,
|
||||||
|
) = self._load_mask_depth(frame_annotation, fg_mask_np)
|
||||||
|
|
||||||
if load_blobs and self.load_point_clouds and point_cloud is not None:
|
if load_blobs and self.load_point_clouds and point_cloud is not None:
|
||||||
assert pcl_path is not None
|
pcl_path = self._fix_point_cloud_path(point_cloud.path)
|
||||||
frame_data.sequence_point_cloud = load_pointcloud(
|
frame_data.sequence_point_cloud = load_pointcloud(
|
||||||
self._local_path(pcl_path), max_points=self.max_points
|
self._local_path(pcl_path), max_points=self.max_points
|
||||||
)
|
)
|
||||||
|
frame_data.sequence_point_cloud_path = pcl_path
|
||||||
|
|
||||||
if frame_annotation.viewpoint is not None:
|
if frame_annotation.viewpoint is not None:
|
||||||
frame_data.camera = self._get_pytorch3d_camera(frame_annotation)
|
frame_data.camera = self._get_pytorch3d_camera(frame_annotation)
|
||||||
@@ -680,14 +653,18 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
|||||||
|
|
||||||
return frame_data
|
return frame_data
|
||||||
|
|
||||||
def _load_fg_probability(self, entry: FrameAnnotationT, path: str) -> np.ndarray:
|
def _load_fg_probability(
|
||||||
fg_probability = load_mask(self._local_path(path))
|
self, entry: types.FrameAnnotation
|
||||||
|
) -> Tuple[np.ndarray, str]:
|
||||||
|
assert self.dataset_root is not None and entry.mask is not None
|
||||||
|
full_path = os.path.join(self.dataset_root, entry.mask.path)
|
||||||
|
fg_probability = load_mask(self._local_path(full_path))
|
||||||
if fg_probability.shape[-2:] != entry.image.size:
|
if fg_probability.shape[-2:] != entry.image.size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!"
|
f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!"
|
||||||
)
|
)
|
||||||
|
|
||||||
return fg_probability
|
return fg_probability, full_path
|
||||||
|
|
||||||
def _postprocess_image(
|
def _postprocess_image(
|
||||||
self,
|
self,
|
||||||
@@ -708,14 +685,14 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
|||||||
|
|
||||||
def _load_mask_depth(
|
def _load_mask_depth(
|
||||||
self,
|
self,
|
||||||
entry: FrameAnnotationT,
|
entry: types.FrameAnnotation,
|
||||||
path: str,
|
|
||||||
fg_mask: Optional[np.ndarray],
|
fg_mask: Optional[np.ndarray],
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, str, torch.Tensor]:
|
||||||
entry_depth = entry.depth
|
entry_depth = entry.depth
|
||||||
dataset_root = self.dataset_root
|
dataset_root = self.dataset_root
|
||||||
assert dataset_root is not None
|
assert dataset_root is not None
|
||||||
assert entry_depth is not None
|
assert entry_depth is not None and entry_depth.path is not None
|
||||||
|
path = os.path.join(dataset_root, entry_depth.path)
|
||||||
depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)
|
depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)
|
||||||
|
|
||||||
if self.mask_depths:
|
if self.mask_depths:
|
||||||
@@ -729,11 +706,11 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
|||||||
else:
|
else:
|
||||||
depth_mask = (depth_map > 0.0).astype(np.float32)
|
depth_mask = (depth_map > 0.0).astype(np.float32)
|
||||||
|
|
||||||
return torch.tensor(depth_map), torch.tensor(depth_mask)
|
return torch.tensor(depth_map), path, torch.tensor(depth_mask)
|
||||||
|
|
||||||
def _get_pytorch3d_camera(
|
def _get_pytorch3d_camera(
|
||||||
self,
|
self,
|
||||||
entry: FrameAnnotationT,
|
entry: types.FrameAnnotation,
|
||||||
) -> PerspectiveCameras:
|
) -> PerspectiveCameras:
|
||||||
entry_viewpoint = entry.viewpoint
|
entry_viewpoint = entry.viewpoint
|
||||||
assert entry_viewpoint is not None
|
assert entry_viewpoint is not None
|
||||||
@@ -762,6 +739,19 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
|||||||
T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
|
T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _fix_point_cloud_path(self, path: str) -> str:
|
||||||
|
"""
|
||||||
|
Fix up a point cloud path from the dataset.
|
||||||
|
Some files in Co3Dv2 have an accidental absolute path stored.
|
||||||
|
"""
|
||||||
|
unwanted_prefix = (
|
||||||
|
"/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/"
|
||||||
|
)
|
||||||
|
if path.startswith(unwanted_prefix):
|
||||||
|
path = path[len(unwanted_prefix) :]
|
||||||
|
assert self.dataset_root is not None
|
||||||
|
return os.path.join(self.dataset_root, path)
|
||||||
|
|
||||||
def _local_path(self, path: str) -> str:
|
def _local_path(self, path: str) -> str:
|
||||||
if self.path_manager is None:
|
if self.path_manager is None:
|
||||||
return path
|
return path
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from pytorch3d.implicitron.dataset.utils import is_known_frame_scalar
|
|||||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||||
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
@@ -326,9 +327,9 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
assert os.path.normpath(
|
assert os.path.normpath(
|
||||||
# pyre-ignore[16]
|
# pyre-ignore[16]
|
||||||
self.frame_annots[idx]["frame_annotation"].image.path
|
self.frame_annots[idx]["frame_annotation"].image.path
|
||||||
) == os.path.normpath(path), (
|
) == os.path.normpath(
|
||||||
f"Inconsistent frame indices {seq_name, frame_no, path}."
|
path
|
||||||
)
|
), f"Inconsistent frame indices {seq_name, frame_no, path}."
|
||||||
return idx
|
return idx
|
||||||
|
|
||||||
dataset_idx = [
|
dataset_idx = [
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from pytorch3d.renderer.cameras import CamerasBase
|
|||||||
|
|
||||||
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
|
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
|
||||||
from .json_index_dataset import JsonIndexDataset
|
from .json_index_dataset import JsonIndexDataset
|
||||||
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
DATASET_TYPE_KNOWN,
|
DATASET_TYPE_KNOWN,
|
||||||
DATASET_TYPE_TEST,
|
DATASET_TYPE_TEST,
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from typing import Dict, List, Optional, Tuple, Type, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from iopath.common.file_io import PathManager
|
from iopath.common.file_io import PathManager
|
||||||
|
|
||||||
from omegaconf import DictConfig
|
from omegaconf import DictConfig
|
||||||
from pytorch3d.implicitron.dataset.dataset_map_provider import (
|
from pytorch3d.implicitron.dataset.dataset_map_provider import (
|
||||||
DatasetMap,
|
DatasetMap,
|
||||||
@@ -30,6 +31,7 @@ from pytorch3d.implicitron.tools.config import (
|
|||||||
registry,
|
registry,
|
||||||
run_auto_creation,
|
run_auto_creation,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
@@ -220,6 +222,7 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):
|
|||||||
self.dataset_map = dataset_map
|
self.dataset_map = dataset_map
|
||||||
|
|
||||||
def _load_category(self, category: str) -> DatasetMap:
|
def _load_category(self, category: str) -> DatasetMap:
|
||||||
|
|
||||||
frame_file = os.path.join(self.dataset_root, category, "frame_annotations.jgz")
|
frame_file = os.path.join(self.dataset_root, category, "frame_annotations.jgz")
|
||||||
sequence_file = os.path.join(
|
sequence_file = os.path.join(
|
||||||
self.dataset_root, category, "sequence_annotations.jgz"
|
self.dataset_root, category, "sequence_annotations.jgz"
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import torch
|
|||||||
from pytorch3d.implicitron.tools.config import registry
|
from pytorch3d.implicitron.tools.config import registry
|
||||||
|
|
||||||
from .load_llff import load_llff_data
|
from .load_llff import load_llff_data
|
||||||
|
|
||||||
from .single_sequence_dataset import (
|
from .single_sequence_dataset import (
|
||||||
_interpret_blender_cameras,
|
_interpret_blender_cameras,
|
||||||
SingleSceneDatasetMapProviderBase,
|
SingleSceneDatasetMapProviderBase,
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import os
|
|||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
@@ -74,6 +75,7 @@ def _minify(basedir, path_manager, factors=(), resolutions=()):
|
|||||||
def _load_data(
|
def _load_data(
|
||||||
basedir, factor=None, width=None, height=None, load_imgs=True, path_manager=None
|
basedir, factor=None, width=None, height=None, load_imgs=True, path_manager=None
|
||||||
):
|
):
|
||||||
|
|
||||||
poses_arr = np.load(
|
poses_arr = np.load(
|
||||||
_local_path(path_manager, os.path.join(basedir, "poses_bounds.npy"))
|
_local_path(path_manager, os.path.join(basedir, "poses_bounds.npy"))
|
||||||
)
|
)
|
||||||
@@ -162,6 +164,7 @@ def ptstocam(pts, c2w):
|
|||||||
|
|
||||||
|
|
||||||
def poses_avg(poses):
|
def poses_avg(poses):
|
||||||
|
|
||||||
hwf = poses[0, :3, -1:]
|
hwf = poses[0, :3, -1:]
|
||||||
|
|
||||||
center = poses[:, :3, 3].mean(0)
|
center = poses[:, :3, 3].mean(0)
|
||||||
@@ -189,6 +192,7 @@ def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
|
|||||||
|
|
||||||
|
|
||||||
def recenter_poses(poses):
|
def recenter_poses(poses):
|
||||||
|
|
||||||
poses_ = poses + 0
|
poses_ = poses + 0
|
||||||
bottom = np.reshape([0, 0, 0, 1.0], [1, 4])
|
bottom = np.reshape([0, 0, 0, 1.0], [1, 4])
|
||||||
c2w = poses_avg(poses)
|
c2w = poses_avg(poses)
|
||||||
@@ -252,6 +256,7 @@ def spherify_poses(poses, bds):
|
|||||||
new_poses = []
|
new_poses = []
|
||||||
|
|
||||||
for th in np.linspace(0.0, 2.0 * np.pi, 120):
|
for th in np.linspace(0.0, 2.0 * np.pi, 120):
|
||||||
|
|
||||||
camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
|
camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
|
||||||
up = np.array([0, 0, -1.0])
|
up = np.array([0, 0, -1.0])
|
||||||
|
|
||||||
@@ -306,6 +311,7 @@ def load_llff_data(
|
|||||||
path_zflat=False,
|
path_zflat=False,
|
||||||
path_manager=None,
|
path_manager=None,
|
||||||
):
|
):
|
||||||
|
|
||||||
poses, bds, imgs = _load_data(
|
poses, bds, imgs = _load_data(
|
||||||
basedir, factor=factor, path_manager=path_manager
|
basedir, factor=factor, path_manager=path_manager
|
||||||
) # factor=8 downsamples original imgs by 8x
|
) # factor=8 downsamples original imgs by 8x
|
||||||
|
|||||||
@@ -4,8 +4,6 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
# pyre-unsafe
|
|
||||||
|
|
||||||
# This functionality requires SQLAlchemy 2.0 or later.
|
# This functionality requires SQLAlchemy 2.0 or later.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
@@ -13,6 +11,7 @@ import struct
|
|||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from pytorch3d.implicitron.dataset.types import (
|
from pytorch3d.implicitron.dataset.types import (
|
||||||
DepthAnnotation,
|
DepthAnnotation,
|
||||||
ImageAnnotation,
|
ImageAnnotation,
|
||||||
@@ -21,6 +20,7 @@ from pytorch3d.implicitron.dataset.types import (
|
|||||||
VideoAnnotation,
|
VideoAnnotation,
|
||||||
ViewpointAnnotation,
|
ViewpointAnnotation,
|
||||||
)
|
)
|
||||||
|
|
||||||
from sqlalchemy import LargeBinary
|
from sqlalchemy import LargeBinary
|
||||||
from sqlalchemy.orm import (
|
from sqlalchemy.orm import (
|
||||||
composite,
|
composite,
|
||||||
|
|||||||
@@ -4,14 +4,11 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
# pyre-unsafe
|
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import urllib
|
from dataclasses import dataclass
|
||||||
from dataclasses import dataclass, Field, field
|
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
ClassVar,
|
ClassVar,
|
||||||
@@ -31,9 +28,10 @@ import pandas as pd
|
|||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
|
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
|
||||||
from pytorch3d.implicitron.dataset.frame_data import (
|
|
||||||
|
from pytorch3d.implicitron.dataset.frame_data import ( # noqa
|
||||||
FrameData,
|
FrameData,
|
||||||
FrameDataBuilder, # noqa
|
FrameDataBuilder,
|
||||||
FrameDataBuilderBase,
|
FrameDataBuilderBase,
|
||||||
)
|
)
|
||||||
from pytorch3d.implicitron.tools.config import (
|
from pytorch3d.implicitron.tools.config import (
|
||||||
@@ -41,7 +39,7 @@ from pytorch3d.implicitron.tools.config import (
|
|||||||
ReplaceableBase,
|
ReplaceableBase,
|
||||||
run_auto_creation,
|
run_auto_creation,
|
||||||
)
|
)
|
||||||
from sqlalchemy.orm import scoped_session, Session, sessionmaker
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from .orm_types import SqlFrameAnnotation, SqlSequenceAnnotation
|
from .orm_types import SqlFrameAnnotation, SqlSequenceAnnotation
|
||||||
|
|
||||||
@@ -53,7 +51,7 @@ _SET_LISTS_TABLE: str = "set_lists"
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||||
"""
|
"""
|
||||||
A dataset with annotations stored as SQLite tables. This is an index-based dataset.
|
A dataset with annotations stored as SQLite tables. This is an index-based dataset.
|
||||||
The length is returned after all sequence and frame filters are applied (see param
|
The length is returned after all sequence and frame filters are applied (see param
|
||||||
@@ -90,7 +88,6 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
engine verbatim. Don’t expose it to end users of your application!
|
engine verbatim. Don’t expose it to end users of your application!
|
||||||
pick_categories: Restrict the dataset to the given list of categories.
|
pick_categories: Restrict the dataset to the given list of categories.
|
||||||
pick_sequences: A Sequence of sequence names to restrict the dataset to.
|
pick_sequences: A Sequence of sequence names to restrict the dataset to.
|
||||||
pick_sequences_sql_clause: Custom SQL WHERE clause to constrain sequence annotations.
|
|
||||||
exclude_sequences: A Sequence of the names of the sequences to exclude.
|
exclude_sequences: A Sequence of the names of the sequences to exclude.
|
||||||
limit_sequences_per_category_to: Limit the dataset to the first up to N
|
limit_sequences_per_category_to: Limit the dataset to the first up to N
|
||||||
sequences within each category (applies after all other sequence filters
|
sequences within each category (applies after all other sequence filters
|
||||||
@@ -105,16 +102,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
more frames than that; applied after other frame-level filters.
|
more frames than that; applied after other frame-level filters.
|
||||||
seed: The seed of the random generator sampling `n_frames_per_sequence`
|
seed: The seed of the random generator sampling `n_frames_per_sequence`
|
||||||
random frames per sequence.
|
random frames per sequence.
|
||||||
preload_metadata: If True, the metadata is preloaded into memory.
|
|
||||||
precompute_seq_to_idx: If True, precomputes the mapping from sequence name to indices.
|
|
||||||
scoped_session: If True, allows different parts of the code to share
|
|
||||||
a global session to access the database.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation
|
frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation
|
||||||
sequence_annotations_type: ClassVar[Type[SqlSequenceAnnotation]] = (
|
|
||||||
SqlSequenceAnnotation
|
|
||||||
)
|
|
||||||
|
|
||||||
sqlite_metadata_file: str = ""
|
sqlite_metadata_file: str = ""
|
||||||
dataset_root: Optional[str] = None
|
dataset_root: Optional[str] = None
|
||||||
@@ -127,7 +117,6 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
pick_categories: Tuple[str, ...] = ()
|
pick_categories: Tuple[str, ...] = ()
|
||||||
|
|
||||||
pick_sequences: Tuple[str, ...] = ()
|
pick_sequences: Tuple[str, ...] = ()
|
||||||
pick_sequences_sql_clause: Optional[str] = None
|
|
||||||
exclude_sequences: Tuple[str, ...] = ()
|
exclude_sequences: Tuple[str, ...] = ()
|
||||||
limit_sequences_per_category_to: int = 0
|
limit_sequences_per_category_to: int = 0
|
||||||
limit_sequences_to: int = 0
|
limit_sequences_to: int = 0
|
||||||
@@ -135,22 +124,12 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
n_frames_per_sequence: int = -1
|
n_frames_per_sequence: int = -1
|
||||||
seed: int = 0
|
seed: int = 0
|
||||||
remove_empty_masks_poll_whole_table_threshold: int = 300_000
|
remove_empty_masks_poll_whole_table_threshold: int = 300_000
|
||||||
preload_metadata: bool = False
|
|
||||||
precompute_seq_to_idx: bool = False
|
|
||||||
# we set it manually in the constructor
|
# we set it manually in the constructor
|
||||||
_index: pd.DataFrame = field(init=False, metadata={"omegaconf_ignore": True})
|
# _index: pd.DataFrame = field(init=False)
|
||||||
_sql_engine: sa.engine.Engine = field(
|
|
||||||
init=False, metadata={"omegaconf_ignore": True}
|
|
||||||
)
|
|
||||||
eval_batches: Optional[List[Any]] = field(
|
|
||||||
init=False, metadata={"omegaconf_ignore": True}
|
|
||||||
)
|
|
||||||
|
|
||||||
frame_data_builder: FrameDataBuilderBase # pyre-ignore[13]
|
frame_data_builder: FrameDataBuilderBase
|
||||||
frame_data_builder_class_type: str = "FrameDataBuilder"
|
frame_data_builder_class_type: str = "FrameDataBuilder"
|
||||||
|
|
||||||
scoped_session: bool = False
|
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
if sa.__version__ < "2.0":
|
if sa.__version__ < "2.0":
|
||||||
raise ImportError("This class requires SQL Alchemy 2.0 or later")
|
raise ImportError("This class requires SQL Alchemy 2.0 or later")
|
||||||
@@ -159,28 +138,19 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
raise ValueError("sqlite_metadata_file must be set")
|
raise ValueError("sqlite_metadata_file must be set")
|
||||||
|
|
||||||
if self.dataset_root:
|
if self.dataset_root:
|
||||||
frame_args = f"frame_data_builder_{self.frame_data_builder_class_type}_args"
|
frame_builder_type = self.frame_data_builder_class_type
|
||||||
getattr(self, frame_args)["dataset_root"] = self.dataset_root
|
getattr(self, f"frame_data_builder_{frame_builder_type}_args")[
|
||||||
getattr(self, frame_args)["path_manager"] = self.path_manager
|
"dataset_root"
|
||||||
|
] = self.dataset_root
|
||||||
|
|
||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
|
self.frame_data_builder.path_manager = self.path_manager
|
||||||
|
|
||||||
if self.path_manager is not None:
|
# pyre-ignore # NOTE: sqlite-specific args (read-only mode).
|
||||||
self.sqlite_metadata_file = self.path_manager.get_local_path(
|
|
||||||
self.sqlite_metadata_file
|
|
||||||
)
|
|
||||||
self.subset_lists_file = self.path_manager.get_local_path(
|
|
||||||
self.subset_lists_file
|
|
||||||
)
|
|
||||||
|
|
||||||
# NOTE: sqlite-specific args (read-only mode).
|
|
||||||
self._sql_engine = sa.create_engine(
|
self._sql_engine = sa.create_engine(
|
||||||
f"sqlite:///file:{urllib.parse.quote(self.sqlite_metadata_file)}?mode=ro&uri=true"
|
f"sqlite:///file:{self.sqlite_metadata_file}?mode=ro&uri=true"
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.preload_metadata:
|
|
||||||
self._sql_engine = self._preload_database(self._sql_engine)
|
|
||||||
|
|
||||||
sequences = self._get_filtered_sequences_if_any()
|
sequences = self._get_filtered_sequences_if_any()
|
||||||
|
|
||||||
if self.subsets:
|
if self.subsets:
|
||||||
@@ -196,29 +166,16 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
if len(index) == 0:
|
if len(index) == 0:
|
||||||
raise ValueError(f"There are no frames in the subsets: {self.subsets}!")
|
raise ValueError(f"There are no frames in the subsets: {self.subsets}!")
|
||||||
|
|
||||||
self._index = index.set_index(["sequence_name", "frame_number"])
|
self._index = index.set_index(["sequence_name", "frame_number"]) # pyre-ignore
|
||||||
|
|
||||||
self.eval_batches = None
|
self.eval_batches = None # pyre-ignore
|
||||||
if self.eval_batches_file:
|
if self.eval_batches_file:
|
||||||
self.eval_batches = self._load_filter_eval_batches()
|
self.eval_batches = self._load_filter_eval_batches()
|
||||||
|
|
||||||
logger.info(str(self))
|
logger.info(str(self))
|
||||||
|
|
||||||
if self.scoped_session:
|
|
||||||
self._session_factory = sessionmaker(bind=self._sql_engine) # pyre-ignore
|
|
||||||
|
|
||||||
if self.precompute_seq_to_idx:
|
|
||||||
# This is deprecated and will be removed in the future.
|
|
||||||
# After we backport https://github.com/facebookresearch/uco3d/pull/3
|
|
||||||
logger.warning(
|
|
||||||
"Using precompute_seq_to_idx is deprecated and will be removed in the future."
|
|
||||||
)
|
|
||||||
self._index["rowid"] = np.arange(len(self._index))
|
|
||||||
groupby = self._index.groupby("sequence_name", sort=False)["rowid"]
|
|
||||||
self._seq_to_indices = dict(groupby.apply(list)) # pyre-ignore
|
|
||||||
del self._index["rowid"]
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
|
# pyre-ignore[16]
|
||||||
return len(self._index)
|
return len(self._index)
|
||||||
|
|
||||||
def __getitem__(self, frame_idx: Union[int, Tuple[str, int]]) -> FrameData:
|
def __getitem__(self, frame_idx: Union[int, Tuple[str, int]]) -> FrameData:
|
||||||
@@ -275,18 +232,12 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
self.frame_annotations_type.frame_number
|
self.frame_annotations_type.frame_number
|
||||||
== int(frame), # cast from np.int64
|
== int(frame), # cast from np.int64
|
||||||
)
|
)
|
||||||
seq_stmt = sa.select(self.sequence_annotations_type).where(
|
seq_stmt = sa.select(SqlSequenceAnnotation).where(
|
||||||
self.sequence_annotations_type.sequence_name == seq
|
SqlSequenceAnnotation.sequence_name == seq
|
||||||
)
|
)
|
||||||
if self.scoped_session:
|
with Session(self._sql_engine) as session:
|
||||||
# pyre-ignore
|
entry = session.scalars(stmt).one()
|
||||||
with scoped_session(self._session_factory)() as session:
|
seq_metadata = session.scalars(seq_stmt).one()
|
||||||
entry = session.scalars(stmt).one()
|
|
||||||
seq_metadata = session.scalars(seq_stmt).one()
|
|
||||||
else:
|
|
||||||
with Session(self._sql_engine) as session:
|
|
||||||
entry = session.scalars(stmt).one()
|
|
||||||
seq_metadata = session.scalars(seq_stmt).one()
|
|
||||||
|
|
||||||
assert entry.image.path == self._index.loc[(seq, frame), "_image_path"]
|
assert entry.image.path == self._index.loc[(seq, frame), "_image_path"]
|
||||||
|
|
||||||
@@ -299,6 +250,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
return frame_data
|
return frame_data
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
|
# pyre-ignore[16]
|
||||||
return f"SqlIndexDataset #frames={len(self._index)}"
|
return f"SqlIndexDataset #frames={len(self._index)}"
|
||||||
|
|
||||||
def sequence_names(self) -> Iterable[str]:
|
def sequence_names(self) -> Iterable[str]:
|
||||||
@@ -308,10 +260,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
# override
|
# override
|
||||||
def category_to_sequence_names(self) -> Dict[str, List[str]]:
|
def category_to_sequence_names(self) -> Dict[str, List[str]]:
|
||||||
stmt = sa.select(
|
stmt = sa.select(
|
||||||
self.sequence_annotations_type.category,
|
SqlSequenceAnnotation.category, SqlSequenceAnnotation.sequence_name
|
||||||
self.sequence_annotations_type.sequence_name,
|
|
||||||
).where( # we limit results to sequences that have frames after all filters
|
).where( # we limit results to sequences that have frames after all filters
|
||||||
self.sequence_annotations_type.sequence_name.in_(self.sequence_names())
|
SqlSequenceAnnotation.sequence_name.in_(self.sequence_names())
|
||||||
)
|
)
|
||||||
with self._sql_engine.connect() as connection:
|
with self._sql_engine.connect() as connection:
|
||||||
cat_to_seqs = pd.read_sql(stmt, connection)
|
cat_to_seqs = pd.read_sql(stmt, connection)
|
||||||
@@ -384,31 +335,17 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
rows = self._index.index.get_loc(seq_name)
|
rows = self._index.index.get_loc(seq_name)
|
||||||
if isinstance(rows, slice):
|
if isinstance(rows, slice):
|
||||||
assert rows.stop is not None, "Unexpected result from pandas"
|
assert rows.stop is not None, "Unexpected result from pandas"
|
||||||
rows_seq = range(rows.start or 0, rows.stop, rows.step or 1)
|
rows = range(rows.start or 0, rows.stop, rows.step or 1)
|
||||||
else:
|
else:
|
||||||
rows_seq = list(np.where(rows)[0])
|
rows = np.where(rows)[0]
|
||||||
|
|
||||||
index_slice, idx = self._get_frame_no_coalesced_ts_by_row_indices(
|
index_slice, idx = self._get_frame_no_coalesced_ts_by_row_indices(
|
||||||
rows_seq, seq_name, subset_filter
|
rows, seq_name, subset_filter
|
||||||
)
|
)
|
||||||
index_slice["idx"] = idx
|
index_slice["idx"] = idx
|
||||||
|
|
||||||
yield from index_slice.itertuples(index=False)
|
yield from index_slice.itertuples(index=False)
|
||||||
|
|
||||||
# override
|
|
||||||
def sequence_indices_in_order(
|
|
||||||
self, seq_name: str, subset_filter: Optional[Sequence[str]] = None
|
|
||||||
) -> Iterator[int]:
|
|
||||||
"""Same as `sequence_frames_in_order` but returns the iterator over
|
|
||||||
only dataset indices.
|
|
||||||
"""
|
|
||||||
if self.precompute_seq_to_idx and subset_filter is None:
|
|
||||||
# pyre-ignore
|
|
||||||
yield from self._seq_to_indices[seq_name]
|
|
||||||
else:
|
|
||||||
for _, _, idx in self.sequence_frames_in_order(seq_name, subset_filter):
|
|
||||||
yield idx
|
|
||||||
|
|
||||||
# override
|
# override
|
||||||
def get_eval_batches(self) -> Optional[List[Any]]:
|
def get_eval_batches(self) -> Optional[List[Any]]:
|
||||||
"""
|
"""
|
||||||
@@ -442,35 +379,11 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
or self.limit_sequences_to > 0
|
or self.limit_sequences_to > 0
|
||||||
or self.limit_sequences_per_category_to > 0
|
or self.limit_sequences_per_category_to > 0
|
||||||
or len(self.pick_sequences) > 0
|
or len(self.pick_sequences) > 0
|
||||||
or self.pick_sequences_sql_clause is not None
|
|
||||||
or len(self.exclude_sequences) > 0
|
or len(self.exclude_sequences) > 0
|
||||||
or len(self.pick_categories) > 0
|
or len(self.pick_categories) > 0
|
||||||
or self.n_frames_per_sequence > 0
|
or self.n_frames_per_sequence > 0
|
||||||
)
|
)
|
||||||
|
|
||||||
def _preload_database(
|
|
||||||
self, source_engine: sa.engine.base.Engine
|
|
||||||
) -> sa.engine.base.Engine:
|
|
||||||
destination_engine = sa.create_engine("sqlite:///:memory:")
|
|
||||||
metadata = sa.MetaData()
|
|
||||||
metadata.reflect(bind=source_engine)
|
|
||||||
metadata.create_all(bind=destination_engine)
|
|
||||||
|
|
||||||
with source_engine.connect() as source_conn:
|
|
||||||
with destination_engine.connect() as destination_conn:
|
|
||||||
for table_obj in metadata.tables.values():
|
|
||||||
# Select all rows from the source table
|
|
||||||
source_rows = source_conn.execute(table_obj.select())
|
|
||||||
|
|
||||||
# Insert rows into the destination table
|
|
||||||
for row in source_rows:
|
|
||||||
destination_conn.execute(table_obj.insert().values(row))
|
|
||||||
|
|
||||||
# Commit the changes for each table
|
|
||||||
destination_conn.commit()
|
|
||||||
|
|
||||||
return destination_engine
|
|
||||||
|
|
||||||
def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
|
def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
|
||||||
# maximum possible filter (if limit_sequences_per_category_to == 0):
|
# maximum possible filter (if limit_sequences_per_category_to == 0):
|
||||||
# WHERE category IN 'self.pick_categories'
|
# WHERE category IN 'self.pick_categories'
|
||||||
@@ -483,30 +396,25 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
*self._get_pick_filters(),
|
*self._get_pick_filters(),
|
||||||
*self._get_exclude_filters(),
|
*self._get_exclude_filters(),
|
||||||
]
|
]
|
||||||
if pick_sequences_sql_clause := self.pick_sequences_sql_clause:
|
|
||||||
print("Applying the custom SQL clause.")
|
|
||||||
# pyre-ignore[6]: TextClause is compatible with where conditions
|
|
||||||
where_conditions.append(sa.text(pick_sequences_sql_clause))
|
|
||||||
|
|
||||||
def add_where(stmt):
|
def add_where(stmt):
|
||||||
return stmt.where(*where_conditions) if where_conditions else stmt
|
return stmt.where(*where_conditions) if where_conditions else stmt
|
||||||
|
|
||||||
if self.limit_sequences_per_category_to <= 0:
|
if self.limit_sequences_per_category_to <= 0:
|
||||||
stmt = add_where(sa.select(self.sequence_annotations_type.sequence_name))
|
stmt = add_where(sa.select(SqlSequenceAnnotation.sequence_name))
|
||||||
else:
|
else:
|
||||||
subquery = sa.select(
|
subquery = sa.select(
|
||||||
self.sequence_annotations_type.sequence_name,
|
SqlSequenceAnnotation.sequence_name,
|
||||||
sa.func.row_number()
|
sa.func.row_number()
|
||||||
.over(
|
.over(
|
||||||
order_by=sa.text("ROWID"), # NOTE: ROWID is SQLite-specific
|
order_by=sa.text("ROWID"), # NOTE: ROWID is SQLite-specific
|
||||||
partition_by=self.sequence_annotations_type.category,
|
partition_by=SqlSequenceAnnotation.category,
|
||||||
)
|
)
|
||||||
.label("row_number"),
|
.label("row_number"),
|
||||||
)
|
)
|
||||||
|
|
||||||
subquery = add_where(subquery).subquery()
|
subquery = add_where(subquery).subquery()
|
||||||
stmt = sa.select(subquery.c.sequence_name).where(
|
stmt = sa.select(subquery.c.sequence_name).where(
|
||||||
# pyre-ignore[6]: SQLAlchemy column comparison returns ColumnElement, not bool
|
|
||||||
subquery.c.row_number <= self.limit_sequences_per_category_to
|
subquery.c.row_number <= self.limit_sequences_per_category_to
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -536,34 +444,31 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
logger.info(f"Limiting dataset to categories: {self.pick_categories}")
|
logger.info(f"Limiting dataset to categories: {self.pick_categories}")
|
||||||
return [self.sequence_annotations_type.category.in_(self.pick_categories)]
|
return [SqlSequenceAnnotation.category.in_(self.pick_categories)]
|
||||||
|
|
||||||
def _get_pick_filters(self) -> List[sa.ColumnElement]:
|
def _get_pick_filters(self) -> List[sa.ColumnElement]:
|
||||||
if not self.pick_sequences:
|
if not self.pick_sequences:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
logger.info(f"Limiting dataset to sequences: {self.pick_sequences}")
|
logger.info(f"Limiting dataset to sequences: {self.pick_sequences}")
|
||||||
return [self.sequence_annotations_type.sequence_name.in_(self.pick_sequences)]
|
return [SqlSequenceAnnotation.sequence_name.in_(self.pick_sequences)]
|
||||||
|
|
||||||
def _get_exclude_filters(self) -> List[sa.ColumnOperators]:
|
def _get_exclude_filters(self) -> List[sa.ColumnOperators]:
|
||||||
if not self.exclude_sequences:
|
if not self.exclude_sequences:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
logger.info(f"Removing sequences from the dataset: {self.exclude_sequences}")
|
logger.info(f"Removing sequences from the dataset: {self.exclude_sequences}")
|
||||||
return [
|
return [SqlSequenceAnnotation.sequence_name.notin_(self.exclude_sequences)]
|
||||||
self.sequence_annotations_type.sequence_name.notin_(self.exclude_sequences)
|
|
||||||
]
|
|
||||||
|
|
||||||
def _load_subsets_from_json(self, subset_lists_path: str) -> pd.DataFrame:
|
def _load_subsets_from_json(self, subset_lists_path: str) -> pd.DataFrame:
|
||||||
subsets = self.subsets
|
assert self.subsets is not None
|
||||||
assert subsets is not None
|
|
||||||
with open(subset_lists_path, "r") as f:
|
with open(subset_lists_path, "r") as f:
|
||||||
subset_to_seq_frame = json.load(f)
|
subset_to_seq_frame = json.load(f)
|
||||||
|
|
||||||
seq_frame_list = sum(
|
seq_frame_list = sum(
|
||||||
(
|
(
|
||||||
[(*row, subset) for row in subset_to_seq_frame[subset]]
|
[(*row, subset) for row in subset_to_seq_frame[subset]]
|
||||||
for subset in subsets
|
for subset in self.subsets
|
||||||
),
|
),
|
||||||
[],
|
[],
|
||||||
)
|
)
|
||||||
@@ -617,7 +522,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
stmt = sa.select(
|
stmt = sa.select(
|
||||||
self.frame_annotations_type.sequence_name,
|
self.frame_annotations_type.sequence_name,
|
||||||
self.frame_annotations_type.frame_number,
|
self.frame_annotations_type.frame_number,
|
||||||
).where(self.frame_annotations_type._mask_mass == 0) # pyre-ignore[16]
|
).where(self.frame_annotations_type._mask_mass == 0)
|
||||||
with Session(self._sql_engine) as session:
|
with Session(self._sql_engine) as session:
|
||||||
to_remove = session.execute(stmt).all()
|
to_remove = session.execute(stmt).all()
|
||||||
|
|
||||||
@@ -635,10 +540,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if pick_frames_sql_clause := self.pick_frames_sql_clause:
|
if self.pick_frames_sql_clause:
|
||||||
logger.info("Applying the custom SQL clause.")
|
logger.info("Applying the custom SQL clause.")
|
||||||
# pyre-ignore[6]: TextClause is compatible with where conditions
|
pick_frames_criteria.append(sa.text(self.pick_frames_sql_clause))
|
||||||
pick_frames_criteria.append(sa.text(pick_frames_sql_clause))
|
|
||||||
|
|
||||||
if pick_frames_criteria:
|
if pick_frames_criteria:
|
||||||
index = self._pick_frames_by_criteria(index, pick_frames_criteria)
|
index = self._pick_frames_by_criteria(index, pick_frames_criteria)
|
||||||
@@ -682,7 +586,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
stmt = sa.select(
|
stmt = sa.select(
|
||||||
self.frame_annotations_type.sequence_name,
|
self.frame_annotations_type.sequence_name,
|
||||||
self.frame_annotations_type.frame_number,
|
self.frame_annotations_type.frame_number,
|
||||||
self.frame_annotations_type._image_path, # pyre-ignore[16]
|
self.frame_annotations_type._image_path,
|
||||||
sa.null().label("subset"),
|
sa.null().label("subset"),
|
||||||
)
|
)
|
||||||
where_conditions = []
|
where_conditions = []
|
||||||
@@ -696,15 +600,14 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
logger.info(" excluding samples with empty masks")
|
logger.info(" excluding samples with empty masks")
|
||||||
where_conditions.append(
|
where_conditions.append(
|
||||||
sa.or_(
|
sa.or_(
|
||||||
self.frame_annotations_type._mask_mass.is_(None), # pyre-ignore[16]
|
self.frame_annotations_type._mask_mass.is_(None),
|
||||||
self.frame_annotations_type._mask_mass != 0,
|
self.frame_annotations_type._mask_mass != 0,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if pick_frames_sql_clause := self.pick_frames_sql_clause:
|
if self.pick_frames_sql_clause:
|
||||||
logger.info(" applying custom SQL clause")
|
logger.info(" applying custom SQL clause")
|
||||||
# pyre-ignore[6]: TextClause is compatible with where conditions
|
where_conditions.append(sa.text(self.pick_frames_sql_clause))
|
||||||
where_conditions.append(sa.text(pick_frames_sql_clause))
|
|
||||||
|
|
||||||
if where_conditions:
|
if where_conditions:
|
||||||
stmt = stmt.where(*where_conditions)
|
stmt = stmt.where(*where_conditions)
|
||||||
@@ -731,9 +634,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
assert self.eval_batches_file
|
assert self.eval_batches_file
|
||||||
logger.info(f"Loading eval batches from {self.eval_batches_file}")
|
logger.info(f"Loading eval batches from {self.eval_batches_file}")
|
||||||
|
|
||||||
if (
|
if not os.path.isfile(self.eval_batches_file):
|
||||||
self.path_manager and not self.path_manager.isfile(self.eval_batches_file)
|
|
||||||
) or (not self.path_manager and not os.path.isfile(self.eval_batches_file)):
|
|
||||||
# The batch indices file does not exist.
|
# The batch indices file does not exist.
|
||||||
# Most probably the user has not specified the root folder.
|
# Most probably the user has not specified the root folder.
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -741,8 +642,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
+ "Please specify a correct dataset_root folder."
|
+ "Please specify a correct dataset_root folder."
|
||||||
)
|
)
|
||||||
|
|
||||||
eval_batches_file = self._local_path(self.eval_batches_file)
|
with open(self.eval_batches_file, "r") as f:
|
||||||
with open(eval_batches_file, "r") as f:
|
|
||||||
eval_batches = json.load(f)
|
eval_batches = json.load(f)
|
||||||
|
|
||||||
# limit the dataset to sequences to allow multiple evaluations in one file
|
# limit the dataset to sequences to allow multiple evaluations in one file
|
||||||
@@ -756,7 +656,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
if pick_sequences:
|
if pick_sequences:
|
||||||
old_len = len(eval_batches)
|
old_len = len(eval_batches)
|
||||||
eval_batches = [b for b in eval_batches if b[0][0] in pick_sequences]
|
eval_batches = [b for b in eval_batches if b[0][0] in pick_sequences]
|
||||||
logger.warning(
|
logger.warn(
|
||||||
f"Picked eval batches by sequence/cat: {old_len} -> {len(eval_batches)}"
|
f"Picked eval batches by sequence/cat: {old_len} -> {len(eval_batches)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -764,7 +664,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
old_len = len(eval_batches)
|
old_len = len(eval_batches)
|
||||||
exclude_sequences = set(self.exclude_sequences)
|
exclude_sequences = set(self.exclude_sequences)
|
||||||
eval_batches = [b for b in eval_batches if b[0][0] not in exclude_sequences]
|
eval_batches = [b for b in eval_batches if b[0][0] not in exclude_sequences]
|
||||||
logger.warning(
|
logger.warn(
|
||||||
f"Excluded eval batches by sequence: {old_len} -> {len(eval_batches)}"
|
f"Excluded eval batches by sequence: {old_len} -> {len(eval_batches)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -826,15 +726,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
self.frame_annotations_type.sequence_name == seq_name,
|
self.frame_annotations_type.sequence_name == seq_name,
|
||||||
self.frame_annotations_type.frame_number.in_(frames),
|
self.frame_annotations_type.frame_number.in_(frames),
|
||||||
)
|
)
|
||||||
frame_no_ts = None
|
|
||||||
|
|
||||||
if self.scoped_session:
|
with self._sql_engine.connect() as connection:
|
||||||
stmt_text = str(stmt.compile(compile_kwargs={"literal_binds": True}))
|
frame_no_ts = pd.read_sql_query(stmt, connection)
|
||||||
with scoped_session(self._session_factory)() as session: # pyre-ignore
|
|
||||||
frame_no_ts = pd.read_sql_query(stmt_text, session.connection())
|
|
||||||
else:
|
|
||||||
with self._sql_engine.connect() as connection:
|
|
||||||
frame_no_ts = pd.read_sql_query(stmt, connection)
|
|
||||||
|
|
||||||
if len(frame_no_ts) != len(index_slice):
|
if len(frame_no_ts) != len(index_slice):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -864,18 +758,11 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
prefixes=["TEMP"], # NOTE SQLite specific!
|
prefixes=["TEMP"], # NOTE SQLite specific!
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def pre_expand(cls) -> None:
|
|
||||||
# remove dataclass annotations that are not meant to be init params
|
|
||||||
# because they cause troubles for OmegaConf
|
|
||||||
for attr, attr_value in list(cls.__dict__.items()): # need to copy as we mutate
|
|
||||||
if isinstance(attr_value, Field) and attr_value.metadata.get(
|
|
||||||
"omegaconf_ignore", False
|
|
||||||
):
|
|
||||||
delattr(cls, attr)
|
|
||||||
del cls.__annotations__[attr]
|
|
||||||
|
|
||||||
|
|
||||||
def _seq_name_to_seed(seq_name) -> int:
|
def _seq_name_to_seed(seq_name) -> int:
|
||||||
"""Generates numbers in [0, 2 ** 28)"""
|
"""Generates numbers in [0, 2 ** 28)"""
|
||||||
return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest()[:7], 16)
|
return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest()[:7], 16)
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_as_tensor(data, dtype):
|
||||||
|
return torch.tensor(data, dtype=dtype) if data is not None else None
|
||||||
|
|||||||
@@ -4,15 +4,15 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
# pyre-unsafe
|
|
||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import List, Optional, Tuple, Type
|
from typing import List, Optional, Tuple, Type
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
|
|
||||||
from pytorch3d.implicitron.dataset.dataset_map_provider import (
|
from pytorch3d.implicitron.dataset.dataset_map_provider import (
|
||||||
DatasetMap,
|
DatasetMap,
|
||||||
DatasetMapProviderBase,
|
DatasetMapProviderBase,
|
||||||
@@ -43,7 +43,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class SqlIndexDatasetMapProvider(DatasetMapProviderBase):
|
class SqlIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
||||||
"""
|
"""
|
||||||
Generates the training, validation, and testing dataset objects for
|
Generates the training, validation, and testing dataset objects for
|
||||||
a dataset laid out on disk like SQL-CO3D, with annotations in an SQLite data base.
|
a dataset laid out on disk like SQL-CO3D, with annotations in an SQLite data base.
|
||||||
@@ -193,9 +193,9 @@ class SqlIndexDatasetMapProvider(DatasetMapProviderBase):
|
|||||||
|
|
||||||
# this is a mould that is never constructed, used to build self._dataset_map values
|
# this is a mould that is never constructed, used to build self._dataset_map values
|
||||||
dataset_class_type: str = "SqlIndexDataset"
|
dataset_class_type: str = "SqlIndexDataset"
|
||||||
dataset: SqlIndexDataset # pyre-ignore [13]
|
dataset: SqlIndexDataset
|
||||||
|
|
||||||
path_manager_factory: PathManagerFactory # pyre-ignore [13]
|
path_manager_factory: PathManagerFactory
|
||||||
path_manager_factory_class_type: str = "PathManagerFactory"
|
path_manager_factory_class_type: str = "PathManagerFactory"
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -282,14 +282,8 @@ class SqlIndexDatasetMapProvider(DatasetMapProviderBase):
|
|||||||
logger.info(f"Val dataset: {str(val_dataset)}")
|
logger.info(f"Val dataset: {str(val_dataset)}")
|
||||||
|
|
||||||
logger.debug("Extracting test dataset.")
|
logger.debug("Extracting test dataset.")
|
||||||
if self.eval_batches_path is None:
|
eval_batches_file = self._get_lists_file("eval_batches")
|
||||||
eval_batches_file = None
|
del common_dataset_kwargs["eval_batches_file"]
|
||||||
else:
|
|
||||||
eval_batches_file = self._get_lists_file("eval_batches")
|
|
||||||
|
|
||||||
if "eval_batches_file" in common_dataset_kwargs:
|
|
||||||
common_dataset_kwargs.pop("eval_batches_file", None)
|
|
||||||
|
|
||||||
test_dataset = dataset_type(
|
test_dataset = dataset_type(
|
||||||
**common_dataset_kwargs,
|
**common_dataset_kwargs,
|
||||||
subsets=self._get_subsets(self.test_subsets, True),
|
subsets=self._get_subsets(self.test_subsets, True),
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
|
|||||||
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
|
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
|
||||||
from pytorch3d.implicitron.dataset.frame_data import FrameData
|
from pytorch3d.implicitron.dataset.frame_data import FrameData
|
||||||
from pytorch3d.implicitron.tools.config import registry, run_auto_creation
|
from pytorch3d.implicitron.tools.config import registry, run_auto_creation
|
||||||
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from typing import List, Optional, Tuple, TypeVar, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from pytorch3d.io import IO
|
from pytorch3d.io import IO
|
||||||
from pytorch3d.renderer.cameras import PerspectiveCameras
|
from pytorch3d.renderer.cameras import PerspectiveCameras
|
||||||
from pytorch3d.structures.pointclouds import Pointclouds
|
from pytorch3d.structures.pointclouds import Pointclouds
|
||||||
@@ -86,15 +87,6 @@ def is_train_frame(
|
|||||||
def get_bbox_from_mask(
|
def get_bbox_from_mask(
|
||||||
mask: np.ndarray, thr: float, decrease_quant: float = 0.05
|
mask: np.ndarray, thr: float, decrease_quant: float = 0.05
|
||||||
) -> Tuple[int, int, int, int]:
|
) -> Tuple[int, int, int, int]:
|
||||||
# these corner cases need to be handled in order to avoid an infinite loop
|
|
||||||
if mask.size == 0:
|
|
||||||
warnings.warn("Empty mask is provided for bbox extraction.", stacklevel=1)
|
|
||||||
return 0, 0, 1, 1
|
|
||||||
|
|
||||||
if not mask.min() >= 0.0:
|
|
||||||
warnings.warn("Negative values in the mask for bbox extraction.", stacklevel=1)
|
|
||||||
mask = mask.clip(min=0.0)
|
|
||||||
|
|
||||||
# bbox in xywh
|
# bbox in xywh
|
||||||
masks_for_box = np.zeros_like(mask)
|
masks_for_box = np.zeros_like(mask)
|
||||||
while masks_for_box.sum() <= 1.0:
|
while masks_for_box.sum() <= 1.0:
|
||||||
@@ -142,15 +134,7 @@ T = TypeVar("T", bound=torch.Tensor)
|
|||||||
def bbox_xyxy_to_xywh(xyxy: T) -> T:
|
def bbox_xyxy_to_xywh(xyxy: T) -> T:
|
||||||
wh = xyxy[2:] - xyxy[:2]
|
wh = xyxy[2:] - xyxy[:2]
|
||||||
xywh = torch.cat([xyxy[:2], wh])
|
xywh = torch.cat([xyxy[:2], wh])
|
||||||
return xywh # pyre-ignore[7]
|
return xywh # pyre-ignore
|
||||||
|
|
||||||
|
|
||||||
def bbox_xywh_to_xyxy(xywh: T, clamp_size: float | int | None = None) -> T:
|
|
||||||
wh = xywh[2:]
|
|
||||||
if clamp_size is not None:
|
|
||||||
wh = wh.clamp(min=clamp_size)
|
|
||||||
xyxy = torch.cat([xywh[:2], xywh[:2] + wh])
|
|
||||||
return xyxy # pyre-ignore[7]
|
|
||||||
|
|
||||||
|
|
||||||
def get_clamp_bbox(
|
def get_clamp_bbox(
|
||||||
@@ -196,6 +180,16 @@ def rescale_bbox(
|
|||||||
return bbox * rel_size
|
return bbox * rel_size
|
||||||
|
|
||||||
|
|
||||||
|
def bbox_xywh_to_xyxy(
|
||||||
|
xywh: torch.Tensor, clamp_size: Optional[int] = None
|
||||||
|
) -> torch.Tensor:
|
||||||
|
xyxy = xywh.clone()
|
||||||
|
if clamp_size is not None:
|
||||||
|
xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
|
||||||
|
xyxy[2:] += xyxy[:2]
|
||||||
|
return xyxy
|
||||||
|
|
||||||
|
|
||||||
def get_1d_bounds(arr: np.ndarray) -> Tuple[int, int]:
|
def get_1d_bounds(arr: np.ndarray) -> Tuple[int, int]:
|
||||||
nz = np.flatnonzero(arr)
|
nz = np.flatnonzero(arr)
|
||||||
return nz[0], nz[-1] + 1
|
return nz[0], nz[-1] + 1
|
||||||
@@ -207,24 +201,18 @@ def resize_image(
|
|||||||
image_width: Optional[int],
|
image_width: Optional[int],
|
||||||
mode: str = "bilinear",
|
mode: str = "bilinear",
|
||||||
) -> Tuple[torch.Tensor, float, torch.Tensor]:
|
) -> Tuple[torch.Tensor, float, torch.Tensor]:
|
||||||
|
|
||||||
if isinstance(image, np.ndarray):
|
if isinstance(image, np.ndarray):
|
||||||
image = torch.from_numpy(image)
|
image = torch.from_numpy(image)
|
||||||
|
|
||||||
if (
|
if image_height is None or image_width is None:
|
||||||
image_height is None
|
|
||||||
or image_width is None
|
|
||||||
or image.shape[-2] == 0
|
|
||||||
or image.shape[-1] == 0
|
|
||||||
):
|
|
||||||
# skip the resizing
|
# skip the resizing
|
||||||
return image, 1.0, torch.ones_like(image[:1])
|
return image, 1.0, torch.ones_like(image[:1])
|
||||||
|
|
||||||
# takes numpy array or tensor, returns pytorch tensor
|
# takes numpy array or tensor, returns pytorch tensor
|
||||||
minscale = min(
|
minscale = min(
|
||||||
image_height / image.shape[-2],
|
image_height / image.shape[-2],
|
||||||
image_width / image.shape[-1],
|
image_width / image.shape[-1],
|
||||||
)
|
)
|
||||||
|
|
||||||
imre = torch.nn.functional.interpolate(
|
imre = torch.nn.functional.interpolate(
|
||||||
image[None],
|
image[None],
|
||||||
scale_factor=minscale,
|
scale_factor=minscale,
|
||||||
@@ -232,7 +220,6 @@ def resize_image(
|
|||||||
align_corners=False if mode == "bilinear" else None,
|
align_corners=False if mode == "bilinear" else None,
|
||||||
recompute_scale_factor=True,
|
recompute_scale_factor=True,
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
imre_ = torch.zeros(image.shape[0], image_height, image_width)
|
imre_ = torch.zeros(image.shape[0], image_height, image_width)
|
||||||
imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
|
imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
|
||||||
mask = torch.zeros(1, image_height, image_width)
|
mask = torch.zeros(1, image_height, image_width)
|
||||||
@@ -245,21 +232,9 @@ def transpose_normalize_image(image: np.ndarray) -> np.ndarray:
|
|||||||
return im.astype(np.float32) / 255.0
|
return im.astype(np.float32) / 255.0
|
||||||
|
|
||||||
|
|
||||||
def load_image(
|
def load_image(path: str) -> np.ndarray:
|
||||||
path: str, try_read_alpha: bool = False, pil_format: str = "RGB"
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Load an image from a path and return it as a numpy array.
|
|
||||||
If try_read_alpha is True, the image is read as RGBA and the alpha channel is
|
|
||||||
returned as the fourth channel.
|
|
||||||
Otherwise, the image is read as RGB and a three-channel image is returned.
|
|
||||||
"""
|
|
||||||
with Image.open(path) as pil_im:
|
with Image.open(path) as pil_im:
|
||||||
# Check if the image has an alpha channel
|
im = np.array(pil_im.convert("RGB"))
|
||||||
if try_read_alpha and pil_im.mode == "RGBA":
|
|
||||||
im = np.array(pil_im)
|
|
||||||
else:
|
|
||||||
im = np.array(pil_im.convert(pil_format))
|
|
||||||
|
|
||||||
return transpose_normalize_image(im)
|
return transpose_normalize_image(im)
|
||||||
|
|
||||||
@@ -354,7 +329,6 @@ def adjust_camera_to_bbox_crop_(
|
|||||||
|
|
||||||
focal_length_px, principal_point_px = _convert_ndc_to_pixels(
|
focal_length_px, principal_point_px = _convert_ndc_to_pixels(
|
||||||
camera.focal_length[0],
|
camera.focal_length[0],
|
||||||
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
|
|
||||||
camera.principal_point[0],
|
camera.principal_point[0],
|
||||||
image_size_wh,
|
image_size_wh,
|
||||||
)
|
)
|
||||||
@@ -367,7 +341,6 @@ def adjust_camera_to_bbox_crop_(
|
|||||||
)
|
)
|
||||||
|
|
||||||
camera.focal_length = focal_length[None]
|
camera.focal_length = focal_length[None]
|
||||||
# pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
|
|
||||||
camera.principal_point = principal_point_cropped[None]
|
camera.principal_point = principal_point_cropped[None]
|
||||||
|
|
||||||
|
|
||||||
@@ -379,7 +352,6 @@ def adjust_camera_to_image_scale_(
|
|||||||
) -> PerspectiveCameras:
|
) -> PerspectiveCameras:
|
||||||
focal_length_px, principal_point_px = _convert_ndc_to_pixels(
|
focal_length_px, principal_point_px = _convert_ndc_to_pixels(
|
||||||
camera.focal_length[0],
|
camera.focal_length[0],
|
||||||
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
|
|
||||||
camera.principal_point[0],
|
camera.principal_point[0],
|
||||||
original_size_wh,
|
original_size_wh,
|
||||||
)
|
)
|
||||||
@@ -396,8 +368,7 @@ def adjust_camera_to_image_scale_(
|
|||||||
image_size_wh_output,
|
image_size_wh_output,
|
||||||
)
|
)
|
||||||
camera.focal_length = focal_length_scaled[None]
|
camera.focal_length = focal_length_scaled[None]
|
||||||
# pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
|
camera.principal_point = principal_point_scaled[None]
|
||||||
camera.principal_point = principal_point_scaled[None] # pyre-ignore[16]
|
|
||||||
|
|
||||||
|
|
||||||
# NOTE this cache is per-worker; they are implemented as processes.
|
# NOTE this cache is per-worker; they are implemented as processes.
|
||||||
|
|||||||
@@ -299,6 +299,7 @@ def eval_batch(
|
|||||||
)
|
)
|
||||||
|
|
||||||
for loss_fg_mask, name_postfix in zip((mask_crop, mask_fg), ("_masked", "_fg")):
|
for loss_fg_mask, name_postfix in zip((mask_crop, mask_fg), ("_masked", "_fg")):
|
||||||
|
|
||||||
loss_mask_now = mask_crop * loss_fg_mask
|
loss_mask_now = mask_crop * loss_fg_mask
|
||||||
|
|
||||||
for rgb_metric_name, rgb_metric_fun in zip(
|
for rgb_metric_name, rgb_metric_fun in zip(
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import warnings
|
|||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import tqdm
|
import tqdm
|
||||||
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
|
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
|
||||||
from pytorch3d.implicitron.models.base_model import EvaluationMode, ImplicitronModelBase
|
from pytorch3d.implicitron.models.base_model import EvaluationMode, ImplicitronModelBase
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from dataclasses import dataclass, field
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
|
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
|
||||||
from pytorch3d.implicitron.tools.config import ReplaceableBase
|
from pytorch3d.implicitron.tools.config import ReplaceableBase
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
|
|||||||
@@ -106,7 +106,7 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
|
|||||||
self.layers = torch.nn.ModuleList()
|
self.layers = torch.nn.ModuleList()
|
||||||
self.proj_layers = torch.nn.ModuleList()
|
self.proj_layers = torch.nn.ModuleList()
|
||||||
for stage in range(self.max_stage):
|
for stage in range(self.max_stage):
|
||||||
stage_name = f"layer{stage + 1}"
|
stage_name = f"layer{stage+1}"
|
||||||
feature_name = self._get_resnet_stage_feature_name(stage)
|
feature_name = self._get_resnet_stage_feature_name(stage)
|
||||||
if (stage + 1) in self.stages:
|
if (stage + 1) in self.stages:
|
||||||
if (
|
if (
|
||||||
@@ -139,18 +139,12 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
|
|||||||
self.stages = set(self.stages) # convert to set for faster "in"
|
self.stages = set(self.stages) # convert to set for faster "in"
|
||||||
|
|
||||||
def _get_resnet_stage_feature_name(self, stage) -> str:
|
def _get_resnet_stage_feature_name(self, stage) -> str:
|
||||||
return f"res_layer_{stage + 1}"
|
return f"res_layer_{stage+1}"
|
||||||
|
|
||||||
def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
|
def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
|
||||||
# pyre-fixme[58]: `-` is not supported for operand types `Tensor` and
|
|
||||||
# `Union[Tensor, Module]`.
|
|
||||||
# pyre-fixme[58]: `/` is not supported for operand types `Tensor` and
|
|
||||||
# `Union[Tensor, Module]`.
|
|
||||||
return (img - self._resnet_mean) / self._resnet_std
|
return (img - self._resnet_mean) / self._resnet_std
|
||||||
|
|
||||||
def get_feat_dims(self) -> int:
|
def get_feat_dims(self) -> int:
|
||||||
# pyre-fixme[29]: `Union[(self: TensorBase) -> Tensor, Tensor, Module]` is
|
|
||||||
# not a function.
|
|
||||||
return sum(self._feat_dim.values())
|
return sum(self._feat_dim.values())
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -189,12 +183,7 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
|
|||||||
else:
|
else:
|
||||||
imgs_normed = imgs_resized
|
imgs_normed = imgs_resized
|
||||||
# is not a function.
|
# is not a function.
|
||||||
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
|
||||||
feats = self.stem(imgs_normed)
|
feats = self.stem(imgs_normed)
|
||||||
# pyre-fixme[6]: For 1st argument expected `Iterable[_T1]` but got
|
|
||||||
# `Union[Tensor, Module]`.
|
|
||||||
# pyre-fixme[6]: For 2nd argument expected `Iterable[_T2]` but got
|
|
||||||
# `Union[Tensor, Module]`.
|
|
||||||
for stage, (layer, proj) in enumerate(zip(self.layers, self.proj_layers)):
|
for stage, (layer, proj) in enumerate(zip(self.layers, self.proj_layers)):
|
||||||
feats = layer(feats)
|
feats = layer(feats)
|
||||||
# just a sanity check below
|
# just a sanity check below
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user