mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-04 03:05:59 +08:00
Compare commits
5 Commits
v0.7.9
...
cbcae096a0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cbcae096a0 | ||
|
|
5b1cce56bc | ||
|
|
0c3b204375 | ||
|
|
6be5e2da06 | ||
|
|
f5f6b78e70 |
@@ -19,7 +19,6 @@
|
|||||||
#
|
#
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import unittest.mock as mock
|
import unittest.mock as mock
|
||||||
|
|
||||||
from recommonmark.parser import CommonMarkParser
|
from recommonmark.parser import CommonMarkParser
|
||||||
|
|||||||
@@ -48,22 +48,18 @@ The outputs of the experiment are saved and logged in multiple ways:
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from dataclasses import field
|
from dataclasses import field
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from pytorch3d.implicitron.dataset.data_source import (
|
from pytorch3d.implicitron.dataset.data_source import (
|
||||||
DataSourceBase,
|
DataSourceBase,
|
||||||
ImplicitronDataSource,
|
ImplicitronDataSource,
|
||||||
)
|
)
|
||||||
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
|
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
|
||||||
|
|
||||||
from pytorch3d.implicitron.models.renderer.multipass_ea import (
|
from pytorch3d.implicitron.models.renderer.multipass_ea import (
|
||||||
MultiPassEmissionAbsorptionRenderer,
|
MultiPassEmissionAbsorptionRenderer,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import os
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch.optim
|
import torch.optim
|
||||||
|
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
|
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
|
||||||
from pytorch3d.implicitron.tools import model_io
|
from pytorch3d.implicitron.tools import model_io
|
||||||
|
|||||||
@@ -14,9 +14,7 @@ from dataclasses import field
|
|||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch.optim
|
import torch.optim
|
||||||
|
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
|
|
||||||
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
|
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
|
||||||
from pytorch3d.implicitron.tools import model_io
|
from pytorch3d.implicitron.tools import model_io
|
||||||
from pytorch3d.implicitron.tools.config import (
|
from pytorch3d.implicitron.tools.config import (
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import unittest
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from hydra import compose, initialize_config_dir
|
from hydra import compose, initialize_config_dir
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from projects.implicitron_trainer.impl.optimizer_factory import (
|
from projects.implicitron_trainer.impl.optimizer_factory import (
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch3d.implicitron.dataset.frame_data import FrameData
|
from pytorch3d.implicitron.dataset.frame_data import FrameData
|
||||||
from pytorch3d.implicitron.dataset.utils import GenericWorkaround
|
from pytorch3d.implicitron.dataset.utils import GenericWorkaround
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ from typing import (
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch3d.implicitron.dataset import orm_types, types
|
from pytorch3d.implicitron.dataset import orm_types, types
|
||||||
from pytorch3d.implicitron.dataset.utils import (
|
from pytorch3d.implicitron.dataset.utils import (
|
||||||
adjust_camera_to_bbox_crop_,
|
adjust_camera_to_bbox_crop_,
|
||||||
|
|||||||
@@ -38,7 +38,6 @@ from pytorch3d.implicitron.dataset.utils import is_known_frame_scalar
|
|||||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||||
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
@@ -327,9 +326,9 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
assert os.path.normpath(
|
assert os.path.normpath(
|
||||||
# pyre-ignore[16]
|
# pyre-ignore[16]
|
||||||
self.frame_annots[idx]["frame_annotation"].image.path
|
self.frame_annots[idx]["frame_annotation"].image.path
|
||||||
) == os.path.normpath(
|
) == os.path.normpath(path), (
|
||||||
path
|
f"Inconsistent frame indices {seq_name, frame_no, path}."
|
||||||
), f"Inconsistent frame indices {seq_name, frame_no, path}."
|
)
|
||||||
return idx
|
return idx
|
||||||
|
|
||||||
dataset_idx = [
|
dataset_idx = [
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ from pytorch3d.renderer.cameras import CamerasBase
|
|||||||
|
|
||||||
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
|
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
|
||||||
from .json_index_dataset import JsonIndexDataset
|
from .json_index_dataset import JsonIndexDataset
|
||||||
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
DATASET_TYPE_KNOWN,
|
DATASET_TYPE_KNOWN,
|
||||||
DATASET_TYPE_TEST,
|
DATASET_TYPE_TEST,
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ from typing import Dict, List, Optional, Tuple, Type, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from iopath.common.file_io import PathManager
|
from iopath.common.file_io import PathManager
|
||||||
|
|
||||||
from omegaconf import DictConfig
|
from omegaconf import DictConfig
|
||||||
from pytorch3d.implicitron.dataset.dataset_map_provider import (
|
from pytorch3d.implicitron.dataset.dataset_map_provider import (
|
||||||
DatasetMap,
|
DatasetMap,
|
||||||
@@ -31,7 +30,6 @@ from pytorch3d.implicitron.tools.config import (
|
|||||||
registry,
|
registry,
|
||||||
run_auto_creation,
|
run_auto_creation,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import torch
|
|||||||
from pytorch3d.implicitron.tools.config import registry
|
from pytorch3d.implicitron.tools.config import registry
|
||||||
|
|
||||||
from .load_llff import load_llff_data
|
from .load_llff import load_llff_data
|
||||||
|
|
||||||
from .single_sequence_dataset import (
|
from .single_sequence_dataset import (
|
||||||
_interpret_blender_cameras,
|
_interpret_blender_cameras,
|
||||||
SingleSceneDatasetMapProviderBase,
|
SingleSceneDatasetMapProviderBase,
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import os
|
|||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import struct
|
|||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from pytorch3d.implicitron.dataset.types import (
|
from pytorch3d.implicitron.dataset.types import (
|
||||||
DepthAnnotation,
|
DepthAnnotation,
|
||||||
ImageAnnotation,
|
ImageAnnotation,
|
||||||
@@ -22,7 +21,6 @@ from pytorch3d.implicitron.dataset.types import (
|
|||||||
VideoAnnotation,
|
VideoAnnotation,
|
||||||
ViewpointAnnotation,
|
ViewpointAnnotation,
|
||||||
)
|
)
|
||||||
|
|
||||||
from sqlalchemy import LargeBinary
|
from sqlalchemy import LargeBinary
|
||||||
from sqlalchemy.orm import (
|
from sqlalchemy.orm import (
|
||||||
composite,
|
composite,
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import hashlib
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import urllib
|
import urllib
|
||||||
from dataclasses import dataclass, Field, field
|
from dataclasses import dataclass, Field, field
|
||||||
from typing import (
|
from typing import (
|
||||||
@@ -32,13 +31,11 @@ import pandas as pd
|
|||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
|
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
|
||||||
|
|
||||||
from pytorch3d.implicitron.dataset.frame_data import (
|
from pytorch3d.implicitron.dataset.frame_data import (
|
||||||
FrameData,
|
FrameData,
|
||||||
FrameDataBuilder, # noqa
|
FrameDataBuilder, # noqa
|
||||||
FrameDataBuilderBase,
|
FrameDataBuilderBase,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pytorch3d.implicitron.tools.config import (
|
from pytorch3d.implicitron.tools.config import (
|
||||||
registry,
|
registry,
|
||||||
ReplaceableBase,
|
ReplaceableBase,
|
||||||
@@ -486,9 +483,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
*self._get_pick_filters(),
|
*self._get_pick_filters(),
|
||||||
*self._get_exclude_filters(),
|
*self._get_exclude_filters(),
|
||||||
]
|
]
|
||||||
if self.pick_sequences_sql_clause:
|
if pick_sequences_sql_clause := self.pick_sequences_sql_clause:
|
||||||
print("Applying the custom SQL clause.")
|
print("Applying the custom SQL clause.")
|
||||||
where_conditions.append(sa.text(self.pick_sequences_sql_clause))
|
# pyre-ignore[6]: TextClause is compatible with where conditions
|
||||||
|
where_conditions.append(sa.text(pick_sequences_sql_clause))
|
||||||
|
|
||||||
def add_where(stmt):
|
def add_where(stmt):
|
||||||
return stmt.where(*where_conditions) if where_conditions else stmt
|
return stmt.where(*where_conditions) if where_conditions else stmt
|
||||||
@@ -508,6 +506,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
|
|
||||||
subquery = add_where(subquery).subquery()
|
subquery = add_where(subquery).subquery()
|
||||||
stmt = sa.select(subquery.c.sequence_name).where(
|
stmt = sa.select(subquery.c.sequence_name).where(
|
||||||
|
# pyre-ignore[6]: SQLAlchemy column comparison returns ColumnElement, not bool
|
||||||
subquery.c.row_number <= self.limit_sequences_per_category_to
|
subquery.c.row_number <= self.limit_sequences_per_category_to
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -636,9 +635,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.pick_frames_sql_clause:
|
if pick_frames_sql_clause := self.pick_frames_sql_clause:
|
||||||
logger.info("Applying the custom SQL clause.")
|
logger.info("Applying the custom SQL clause.")
|
||||||
pick_frames_criteria.append(sa.text(self.pick_frames_sql_clause))
|
# pyre-ignore[6]: TextClause is compatible with where conditions
|
||||||
|
pick_frames_criteria.append(sa.text(pick_frames_sql_clause))
|
||||||
|
|
||||||
if pick_frames_criteria:
|
if pick_frames_criteria:
|
||||||
index = self._pick_frames_by_criteria(index, pick_frames_criteria)
|
index = self._pick_frames_by_criteria(index, pick_frames_criteria)
|
||||||
@@ -701,9 +701,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.pick_frames_sql_clause:
|
if pick_frames_sql_clause := self.pick_frames_sql_clause:
|
||||||
logger.info(" applying custom SQL clause")
|
logger.info(" applying custom SQL clause")
|
||||||
where_conditions.append(sa.text(self.pick_frames_sql_clause))
|
# pyre-ignore[6]: TextClause is compatible with where conditions
|
||||||
|
where_conditions.append(sa.text(pick_frames_sql_clause))
|
||||||
|
|
||||||
if where_conditions:
|
if where_conditions:
|
||||||
stmt = stmt.where(*where_conditions)
|
stmt = stmt.where(*where_conditions)
|
||||||
|
|||||||
@@ -12,9 +12,7 @@ import os
|
|||||||
from typing import List, Optional, Tuple, Type
|
from typing import List, Optional, Tuple, Type
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
|
|
||||||
from pytorch3d.implicitron.dataset.dataset_map_provider import (
|
from pytorch3d.implicitron.dataset.dataset_map_provider import (
|
||||||
DatasetMap,
|
DatasetMap,
|
||||||
DatasetMapProviderBase,
|
DatasetMapProviderBase,
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
|
|||||||
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
|
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
|
||||||
from pytorch3d.implicitron.dataset.frame_data import FrameData
|
from pytorch3d.implicitron.dataset.frame_data import FrameData
|
||||||
from pytorch3d.implicitron.tools.config import registry, run_auto_creation
|
from pytorch3d.implicitron.tools.config import registry, run_auto_creation
|
||||||
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ from typing import List, Optional, Tuple, TypeVar, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from pytorch3d.io import IO
|
from pytorch3d.io import IO
|
||||||
from pytorch3d.renderer.cameras import PerspectiveCameras
|
from pytorch3d.renderer.cameras import PerspectiveCameras
|
||||||
from pytorch3d.structures.pointclouds import Pointclouds
|
from pytorch3d.structures.pointclouds import Pointclouds
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ import warnings
|
|||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import tqdm
|
import tqdm
|
||||||
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
|
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
|
||||||
from pytorch3d.implicitron.models.base_model import EvaluationMode, ImplicitronModelBase
|
from pytorch3d.implicitron.models.base_model import EvaluationMode, ImplicitronModelBase
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from dataclasses import dataclass, field
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
|
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
|
||||||
from pytorch3d.implicitron.tools.config import ReplaceableBase
|
from pytorch3d.implicitron.tools.config import ReplaceableBase
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from omegaconf import DictConfig
|
from omegaconf import DictConfig
|
||||||
|
|
||||||
from pytorch3d.implicitron.models.base_model import (
|
from pytorch3d.implicitron.models.base_model import (
|
||||||
ImplicitronModelBase,
|
ImplicitronModelBase,
|
||||||
ImplicitronRender,
|
ImplicitronRender,
|
||||||
@@ -28,7 +27,6 @@ from pytorch3d.implicitron.models.metrics import (
|
|||||||
RegularizationMetricsBase,
|
RegularizationMetricsBase,
|
||||||
ViewMetricsBase,
|
ViewMetricsBase,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pytorch3d.implicitron.models.renderer.base import (
|
from pytorch3d.implicitron.models.renderer.base import (
|
||||||
BaseRenderer,
|
BaseRenderer,
|
||||||
EvaluationMode,
|
EvaluationMode,
|
||||||
@@ -38,7 +36,6 @@ from pytorch3d.implicitron.models.renderer.base import (
|
|||||||
RenderSamplingMode,
|
RenderSamplingMode,
|
||||||
)
|
)
|
||||||
from pytorch3d.implicitron.models.renderer.ray_sampler import RaySamplerBase
|
from pytorch3d.implicitron.models.renderer.ray_sampler import RaySamplerBase
|
||||||
|
|
||||||
from pytorch3d.implicitron.models.utils import (
|
from pytorch3d.implicitron.models.utils import (
|
||||||
apply_chunked,
|
apply_chunked,
|
||||||
chunk_generator,
|
chunk_generator,
|
||||||
@@ -53,7 +50,6 @@ from pytorch3d.implicitron.tools.config import (
|
|||||||
registry,
|
registry,
|
||||||
run_auto_creation,
|
run_auto_creation,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_sparse_ray_bundle
|
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_sparse_ray_bundle
|
||||||
from pytorch3d.renderer import utils as rend_utils
|
from pytorch3d.renderer import utils as rend_utils
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from abc import ABC, abstractmethod
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
||||||
|
|
||||||
from pytorch3d.implicitron.tools.config import ReplaceableBase
|
from pytorch3d.implicitron.tools.config import ReplaceableBase
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
|
|
||||||
|
|||||||
@@ -16,14 +16,11 @@ This file contains
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import field
|
from dataclasses import field
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, Optional, Tuple
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from omegaconf import DictConfig
|
from omegaconf import DictConfig
|
||||||
|
|
||||||
from pytorch3d.implicitron.tools.config import (
|
from pytorch3d.implicitron.tools.config import (
|
||||||
Configurable,
|
Configurable,
|
||||||
registry,
|
registry,
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import torch
|
|||||||
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
||||||
from pytorch3d.implicitron.tools.config import registry
|
from pytorch3d.implicitron.tools.config import registry
|
||||||
from pytorch3d.renderer.implicit import HarmonicEmbedding
|
from pytorch3d.renderer.implicit import HarmonicEmbedding
|
||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from .base import ImplicitFunctionBase
|
from .base import ImplicitFunctionBase
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ from pytorch3d.renderer.implicit import HarmonicEmbedding
|
|||||||
from pytorch3d.renderer.implicit.utils import ray_bundle_to_ray_points
|
from pytorch3d.renderer.implicit.utils import ray_bundle_to_ray_points
|
||||||
|
|
||||||
from .base import ImplicitFunctionBase
|
from .base import ImplicitFunctionBase
|
||||||
|
|
||||||
from .decoding_functions import ( # noqa
|
from .decoding_functions import ( # noqa
|
||||||
_xavier_init,
|
_xavier_init,
|
||||||
MLPWithInputSkips,
|
MLPWithInputSkips,
|
||||||
|
|||||||
@@ -9,7 +9,6 @@
|
|||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from pytorch3d.common.compat import prod
|
from pytorch3d.common.compat import prod
|
||||||
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
||||||
|
|||||||
@@ -13,9 +13,7 @@ from dataclasses import fields
|
|||||||
from typing import Callable, Dict, Optional, Tuple
|
from typing import Callable, Dict, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from omegaconf import DictConfig
|
from omegaconf import DictConfig
|
||||||
|
|
||||||
from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase
|
from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase
|
||||||
from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
|
from pytorch3d.implicitron.models.implicit_function.decoding_functions import (
|
||||||
DecoderFunctionBase,
|
DecoderFunctionBase,
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Un
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from omegaconf import DictConfig
|
from omegaconf import DictConfig
|
||||||
|
|
||||||
from pytorch3d.implicitron.models.base_model import (
|
from pytorch3d.implicitron.models.base_model import (
|
||||||
ImplicitronModelBase,
|
ImplicitronModelBase,
|
||||||
ImplicitronRender,
|
ImplicitronRender,
|
||||||
@@ -28,7 +27,6 @@ from pytorch3d.implicitron.models.metrics import (
|
|||||||
RegularizationMetricsBase,
|
RegularizationMetricsBase,
|
||||||
ViewMetricsBase,
|
ViewMetricsBase,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pytorch3d.implicitron.models.renderer.base import (
|
from pytorch3d.implicitron.models.renderer.base import (
|
||||||
BaseRenderer,
|
BaseRenderer,
|
||||||
EvaluationMode,
|
EvaluationMode,
|
||||||
@@ -50,7 +48,6 @@ from pytorch3d.implicitron.tools.config import (
|
|||||||
registry,
|
registry,
|
||||||
run_auto_creation,
|
run_auto_creation,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_sparse_ray_bundle
|
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_sparse_ray_bundle
|
||||||
from pytorch3d.renderer import utils as rend_utils
|
from pytorch3d.renderer import utils as rend_utils
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import copy
|
|||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
||||||
from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields
|
from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields
|
||||||
|
|
||||||
from pytorch3d.renderer.implicit.sample_pdf import sample_pdf
|
from pytorch3d.renderer.implicit.sample_pdf import sample_pdf
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import torch
|
|||||||
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
||||||
from pytorch3d.implicitron.tools.config import enable_get_default_args
|
from pytorch3d.implicitron.tools.config import enable_get_default_args
|
||||||
from pytorch3d.renderer.implicit import HarmonicEmbedding
|
from pytorch3d.renderer.implicit import HarmonicEmbedding
|
||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -17,11 +17,8 @@ from typing import Any, Dict, Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
from pytorch3d.common.compat import prod
|
from pytorch3d.common.compat import prod
|
||||||
|
|
||||||
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
||||||
|
|
||||||
from pytorch3d.implicitron.tools import image_utils
|
from pytorch3d.implicitron.tools import image_utils
|
||||||
|
|
||||||
from pytorch3d.implicitron.tools.utils import cat_dataclass
|
from pytorch3d.implicitron.tools.utils import cat_dataclass
|
||||||
|
|
||||||
|
|
||||||
@@ -83,9 +80,9 @@ def preprocess_input(
|
|||||||
|
|
||||||
if mask_depths and fg_mask is not None and depth_map is not None:
|
if mask_depths and fg_mask is not None and depth_map is not None:
|
||||||
# mask the depths
|
# mask the depths
|
||||||
assert (
|
assert mask_threshold > 0.0, (
|
||||||
mask_threshold > 0.0
|
"Depths should be masked only with thresholded masks"
|
||||||
), "Depths should be masked only with thresholded masks"
|
)
|
||||||
warnings.warn("Masking depths!")
|
warnings.warn("Masking depths!")
|
||||||
depth_map = depth_map * fg_mask
|
depth_map = depth_map * fg_mask
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import math
|
|||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import pytorch3d
|
import pytorch3d
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.ops import packed_to_padded
|
from pytorch3d.ops import packed_to_padded
|
||||||
from pytorch3d.renderer import PerspectiveCameras
|
from pytorch3d.renderer import PerspectiveCameras
|
||||||
|
|||||||
@@ -499,7 +499,7 @@ class StatsJSONEncoder(json.JSONEncoder):
|
|||||||
return enc
|
return enc
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Object of type {o.__class__.__name__} " f"is not JSON serializable"
|
f"Object of type {o.__class__.__name__} is not JSON serializable"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ import matplotlib
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
_NO_TORCHVISION = False
|
_NO_TORCHVISION = False
|
||||||
|
|||||||
@@ -796,7 +796,7 @@ def save_obj(
|
|||||||
# Create .mtl file with the material name and texture map filename
|
# Create .mtl file with the material name and texture map filename
|
||||||
# TODO: enable material properties to also be saved.
|
# TODO: enable material properties to also be saved.
|
||||||
with _open_file(mtl_path, path_manager, "w") as f_mtl:
|
with _open_file(mtl_path, path_manager, "w") as f_mtl:
|
||||||
lines = f"newmtl mesh\n" f"map_Kd {output_path.stem}.png\n"
|
lines = f"newmtl mesh\nmap_Kd {output_path.stem}.png\n"
|
||||||
f_mtl.write(lines)
|
f_mtl.write(lines)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -8,11 +8,8 @@
|
|||||||
|
|
||||||
|
|
||||||
from .chamfer import chamfer_distance
|
from .chamfer import chamfer_distance
|
||||||
|
|
||||||
from .mesh_edge_loss import mesh_edge_loss
|
from .mesh_edge_loss import mesh_edge_loss
|
||||||
|
|
||||||
from .mesh_laplacian_smoothing import mesh_laplacian_smoothing
|
from .mesh_laplacian_smoothing import mesh_laplacian_smoothing
|
||||||
|
|
||||||
from .mesh_normal_consistency import mesh_normal_consistency
|
from .mesh_normal_consistency import mesh_normal_consistency
|
||||||
from .point_mesh_distance import point_mesh_edge_distance, point_mesh_face_distance
|
from .point_mesh_distance import point_mesh_edge_distance, point_mesh_face_distance
|
||||||
|
|
||||||
|
|||||||
@@ -8,17 +8,14 @@
|
|||||||
|
|
||||||
from .ball_query import ball_query
|
from .ball_query import ball_query
|
||||||
from .cameras_alignment import corresponding_cameras_alignment
|
from .cameras_alignment import corresponding_cameras_alignment
|
||||||
|
|
||||||
from .cubify import cubify
|
from .cubify import cubify
|
||||||
from .graph_conv import GraphConv
|
from .graph_conv import GraphConv
|
||||||
from .interp_face_attrs import interpolate_face_attributes
|
from .interp_face_attrs import interpolate_face_attributes
|
||||||
from .iou_box3d import box3d_overlap
|
from .iou_box3d import box3d_overlap
|
||||||
from .knn import knn_gather, knn_points
|
from .knn import knn_gather, knn_points
|
||||||
from .laplacian_matrices import cot_laplacian, laplacian, norm_laplacian
|
from .laplacian_matrices import cot_laplacian, laplacian, norm_laplacian
|
||||||
|
|
||||||
from .mesh_face_areas_normals import mesh_face_areas_normals
|
from .mesh_face_areas_normals import mesh_face_areas_normals
|
||||||
from .mesh_filtering import taubin_smoothing
|
from .mesh_filtering import taubin_smoothing
|
||||||
|
|
||||||
from .packed_to_padded import packed_to_padded, padded_to_packed
|
from .packed_to_padded import packed_to_padded, padded_to_packed
|
||||||
from .perspective_n_points import efficient_pnp
|
from .perspective_n_points import efficient_pnp
|
||||||
from .points_alignment import corresponding_points_alignment, iterative_closest_point
|
from .points_alignment import corresponding_points_alignment, iterative_closest_point
|
||||||
@@ -30,9 +27,7 @@ from .points_to_volumes import (
|
|||||||
add_pointclouds_to_volumes,
|
add_pointclouds_to_volumes,
|
||||||
add_points_features_to_volume_densities_features,
|
add_points_features_to_volume_densities_features,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .sample_farthest_points import sample_farthest_points
|
from .sample_farthest_points import sample_farthest_points
|
||||||
|
|
||||||
from .sample_points_from_meshes import sample_points_from_meshes
|
from .sample_points_from_meshes import sample_points_from_meshes
|
||||||
from .subdivide_meshes import SubdivideMeshes
|
from .subdivide_meshes import SubdivideMeshes
|
||||||
from .utils import (
|
from .utils import (
|
||||||
@@ -42,7 +37,6 @@ from .utils import (
|
|||||||
is_pointclouds,
|
is_pointclouds,
|
||||||
wmean,
|
wmean,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .vert_align import vert_align
|
from .vert_align import vert_align
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -11,9 +11,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from pytorch3d.common.compat import meshgrid_ij
|
from pytorch3d.common.compat import meshgrid_ij
|
||||||
|
|
||||||
from pytorch3d.structures import Meshes
|
from pytorch3d.structures import Meshes
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -16,9 +16,7 @@ import sys
|
|||||||
from typing import Tuple, Union
|
from typing import Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch3d.ops.mesh_face_areas_normals import mesh_face_areas_normals
|
from pytorch3d.ops.mesh_face_areas_normals import mesh_face_areas_normals
|
||||||
|
|
||||||
from pytorch3d.ops.packed_to_padded import packed_to_padded
|
from pytorch3d.ops.packed_to_padded import packed_to_padded
|
||||||
from pytorch3d.renderer.mesh.rasterizer import Fragments as MeshFragments
|
from pytorch3d.renderer.mesh.rasterizer import Fragments as MeshFragments
|
||||||
|
|
||||||
|
|||||||
@@ -69,7 +69,6 @@ from .mesh import (
|
|||||||
TexturesUV,
|
TexturesUV,
|
||||||
TexturesVertex,
|
TexturesVertex,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .points import (
|
from .points import (
|
||||||
AlphaCompositor,
|
AlphaCompositor,
|
||||||
NormWeightedCompositor,
|
NormWeightedCompositor,
|
||||||
|
|||||||
@@ -153,12 +153,12 @@ def _pulsar_from_opencv_projection(
|
|||||||
# Check image sizes.
|
# Check image sizes.
|
||||||
image_w = image_size_wh[0, 0]
|
image_w = image_size_wh[0, 0]
|
||||||
image_h = image_size_wh[0, 1]
|
image_h = image_size_wh[0, 1]
|
||||||
assert torch.all(
|
assert torch.all(image_size_wh[:, 0] == image_w), (
|
||||||
image_size_wh[:, 0] == image_w
|
"All images in a batch must have the same width!"
|
||||||
), "All images in a batch must have the same width!"
|
)
|
||||||
assert torch.all(
|
assert torch.all(image_size_wh[:, 1] == image_h), (
|
||||||
image_size_wh[:, 1] == image_h
|
"All images in a batch must have the same height!"
|
||||||
), "All images in a batch must have the same height!"
|
)
|
||||||
# Focal length.
|
# Focal length.
|
||||||
fx = camera_matrix[:, 0, 0].unsqueeze(1)
|
fx = camera_matrix[:, 0, 0].unsqueeze(1)
|
||||||
fy = camera_matrix[:, 1, 1].unsqueeze(1)
|
fy = camera_matrix[:, 1, 1].unsqueeze(1)
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ from .clip import (
|
|||||||
ClippedFaces,
|
ClippedFaces,
|
||||||
convert_clipped_rasterization_to_original_faces,
|
convert_clipped_rasterization_to_original_faces,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .rasterize_meshes import rasterize_meshes
|
from .rasterize_meshes import rasterize_meshes
|
||||||
from .rasterizer import MeshRasterizer, RasterizationSettings
|
from .rasterizer import MeshRasterizer, RasterizationSettings
|
||||||
from .renderer import MeshRenderer, MeshRendererWithFragments
|
from .renderer import MeshRenderer, MeshRendererWithFragments
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ import torch
|
|||||||
from pytorch3d import _C
|
from pytorch3d import _C
|
||||||
|
|
||||||
from ..utils import parse_image_size
|
from ..utils import parse_image_size
|
||||||
|
|
||||||
from .clip import (
|
from .clip import (
|
||||||
clip_faces,
|
clip_faces,
|
||||||
ClipFrustum,
|
ClipFrustum,
|
||||||
|
|||||||
@@ -625,9 +625,7 @@ class TexturesAtlas(TexturesBase):
|
|||||||
of length `k`.
|
of length `k`.
|
||||||
"""
|
"""
|
||||||
if len(faces_ids_list) != len(self.atlas_list()):
|
if len(faces_ids_list) != len(self.atlas_list()):
|
||||||
raise IndexError(
|
raise IndexError("faces_ids_list must be of the same length as atlas_list.")
|
||||||
"faces_ids_list must be of " "the same length as atlas_list."
|
|
||||||
)
|
|
||||||
|
|
||||||
sub_features = []
|
sub_features = []
|
||||||
for atlas, faces_ids in zip(self.atlas_list(), faces_ids_list):
|
for atlas, faces_ids in zip(self.atlas_list(), faces_ids_list):
|
||||||
@@ -1657,7 +1655,7 @@ class TexturesUV(TexturesBase):
|
|||||||
raise NotImplementedError("This function does not support multiple maps.")
|
raise NotImplementedError("This function does not support multiple maps.")
|
||||||
if len(faces_ids_list) != len(self.faces_uvs_padded()):
|
if len(faces_ids_list) != len(self.faces_uvs_padded()):
|
||||||
raise IndexError(
|
raise IndexError(
|
||||||
"faces_uvs_padded must be of " "the same length as face_ids_list."
|
"faces_uvs_padded must be of the same length as face_ids_list."
|
||||||
)
|
)
|
||||||
|
|
||||||
sub_faces_uvs, sub_verts_uvs, sub_maps = [], [], []
|
sub_faces_uvs, sub_verts_uvs, sub_maps = [], [], []
|
||||||
@@ -1871,7 +1869,7 @@ class TexturesVertex(TexturesBase):
|
|||||||
"""
|
"""
|
||||||
if len(vertex_ids_list) != len(self.verts_features_list()):
|
if len(vertex_ids_list) != len(self.verts_features_list()):
|
||||||
raise IndexError(
|
raise IndexError(
|
||||||
"verts_features_list must be of " "the same length as vertex_ids_list."
|
"verts_features_list must be of the same length as vertex_ids_list."
|
||||||
)
|
)
|
||||||
|
|
||||||
sub_features = []
|
sub_features = []
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ from typing import Any, Dict
|
|||||||
|
|
||||||
os.environ["PYOPENGL_PLATFORM"] = "egl"
|
os.environ["PYOPENGL_PLATFORM"] = "egl"
|
||||||
import OpenGL.EGL as egl # noqa
|
import OpenGL.EGL as egl # noqa
|
||||||
|
|
||||||
import pycuda.driver as cuda # noqa
|
import pycuda.driver as cuda # noqa
|
||||||
from OpenGL._opaque import opaque_pointer_cls # noqa
|
from OpenGL._opaque import opaque_pointer_cls # noqa
|
||||||
from OpenGL.raw.EGL._errors import EGLError # noqa
|
from OpenGL.raw.EGL._errors import EGLError # noqa
|
||||||
|
|||||||
@@ -17,15 +17,12 @@ import numpy as np
|
|||||||
import OpenGL.GL as gl
|
import OpenGL.GL as gl
|
||||||
import pycuda.gl
|
import pycuda.gl
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from pytorch3d.structures.meshes import Meshes
|
from pytorch3d.structures.meshes import Meshes
|
||||||
|
|
||||||
from ..cameras import FoVOrthographicCameras, FoVPerspectiveCameras
|
from ..cameras import FoVOrthographicCameras, FoVPerspectiveCameras
|
||||||
from ..mesh.rasterizer import Fragments, RasterizationSettings
|
from ..mesh.rasterizer import Fragments, RasterizationSettings
|
||||||
from ..utils import parse_image_size
|
from ..utils import parse_image_size
|
||||||
|
|
||||||
from .opengl_utils import _torch_to_opengl, global_device_context_store
|
from .opengl_utils import _torch_to_opengl, global_device_context_store
|
||||||
|
|
||||||
# Shader strings, used below to compile an OpenGL program.
|
# Shader strings, used below to compile an OpenGL program.
|
||||||
|
|||||||
@@ -9,9 +9,7 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .compositor import AlphaCompositor, NormWeightedCompositor
|
from .compositor import AlphaCompositor, NormWeightedCompositor
|
||||||
|
|
||||||
from .pulsar.unified import PulsarPointsRenderer
|
from .pulsar.unified import PulsarPointsRenderer
|
||||||
|
|
||||||
from .rasterize_points import rasterize_points
|
from .rasterize_points import rasterize_points
|
||||||
from .rasterizer import PointsRasterizationSettings, PointsRasterizer
|
from .rasterizer import PointsRasterizationSettings, PointsRasterizer
|
||||||
from .renderer import PointsRenderer
|
from .renderer import PointsRenderer
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from typing import List, Optional, Tuple, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d import _C
|
from pytorch3d import _C
|
||||||
|
|
||||||
from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_non_square_ndc
|
from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_non_square_ndc
|
||||||
|
|
||||||
from ..utils import parse_image_size
|
from ..utils import parse_image_size
|
||||||
|
|||||||
@@ -531,9 +531,9 @@ class Meshes:
|
|||||||
list of tensors of vertices of shape (V_n, 3).
|
list of tensors of vertices of shape (V_n, 3).
|
||||||
"""
|
"""
|
||||||
if self._verts_list is None:
|
if self._verts_list is None:
|
||||||
assert (
|
assert self._verts_padded is not None, (
|
||||||
self._verts_padded is not None
|
"verts_padded is required to compute verts_list."
|
||||||
), "verts_padded is required to compute verts_list."
|
)
|
||||||
self._verts_list = struct_utils.padded_to_list(
|
self._verts_list = struct_utils.padded_to_list(
|
||||||
self._verts_padded, self.num_verts_per_mesh().tolist()
|
self._verts_padded, self.num_verts_per_mesh().tolist()
|
||||||
)
|
)
|
||||||
@@ -547,9 +547,9 @@ class Meshes:
|
|||||||
list of tensors of faces of shape (F_n, 3).
|
list of tensors of faces of shape (F_n, 3).
|
||||||
"""
|
"""
|
||||||
if self._faces_list is None:
|
if self._faces_list is None:
|
||||||
assert (
|
assert self._faces_padded is not None, (
|
||||||
self._faces_padded is not None
|
"faces_padded is required to compute faces_list."
|
||||||
), "faces_padded is required to compute faces_list."
|
)
|
||||||
self._faces_list = struct_utils.padded_to_list(
|
self._faces_list = struct_utils.padded_to_list(
|
||||||
self._faces_padded, self.num_faces_per_mesh().tolist()
|
self._faces_padded, self.num_faces_per_mesh().tolist()
|
||||||
)
|
)
|
||||||
@@ -925,9 +925,9 @@ class Meshes:
|
|||||||
|
|
||||||
verts_list = self.verts_list()
|
verts_list = self.verts_list()
|
||||||
faces_list = self.faces_list()
|
faces_list = self.faces_list()
|
||||||
assert (
|
assert faces_list is not None and verts_list is not None, (
|
||||||
faces_list is not None and verts_list is not None
|
"faces_list and verts_list arguments are required"
|
||||||
), "faces_list and verts_list arguments are required"
|
)
|
||||||
|
|
||||||
if self.isempty():
|
if self.isempty():
|
||||||
self._faces_padded = torch.zeros(
|
self._faces_padded = torch.zeros(
|
||||||
|
|||||||
@@ -433,9 +433,9 @@ class Pointclouds:
|
|||||||
list of tensors of points of shape (P_n, 3).
|
list of tensors of points of shape (P_n, 3).
|
||||||
"""
|
"""
|
||||||
if self._points_list is None:
|
if self._points_list is None:
|
||||||
assert (
|
assert self._points_padded is not None, (
|
||||||
self._points_padded is not None
|
"points_padded is required to compute points_list."
|
||||||
), "points_padded is required to compute points_list."
|
)
|
||||||
points_list = []
|
points_list = []
|
||||||
for i in range(self._N):
|
for i in range(self._N):
|
||||||
points_list.append(
|
points_list.append(
|
||||||
|
|||||||
@@ -12,11 +12,8 @@ from .camera_conversions import (
|
|||||||
pulsar_from_cameras_projection,
|
pulsar_from_cameras_projection,
|
||||||
pulsar_from_opencv_projection,
|
pulsar_from_opencv_projection,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .checkerboard import checkerboard
|
from .checkerboard import checkerboard
|
||||||
|
|
||||||
from .ico_sphere import ico_sphere
|
from .ico_sphere import ico_sphere
|
||||||
|
|
||||||
from .torus import torus
|
from .torus import torus
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
15
setup.py
15
setup.py
@@ -75,6 +75,21 @@ def get_extensions():
|
|||||||
]
|
]
|
||||||
if os.name != "nt":
|
if os.name != "nt":
|
||||||
nvcc_args.append("-std=c++17")
|
nvcc_args.append("-std=c++17")
|
||||||
|
|
||||||
|
# CUDA 13.0+ compatibility flags for pulsar.
|
||||||
|
# Starting with CUDA 13, __global__ function visibility changed.
|
||||||
|
# See: https://developer.nvidia.com/blog/
|
||||||
|
# cuda-c-compiler-updates-impacting-elf-visibility-and-linkage/
|
||||||
|
cuda_version = torch.version.cuda
|
||||||
|
if cuda_version is not None:
|
||||||
|
major = int(cuda_version.split(".")[0])
|
||||||
|
if major >= 13:
|
||||||
|
nvcc_args.extend(
|
||||||
|
[
|
||||||
|
"--device-entity-has-hidden-visibility=false",
|
||||||
|
"-static-global-template-stub=false",
|
||||||
|
]
|
||||||
|
)
|
||||||
if cub_home is None:
|
if cub_home is None:
|
||||||
prefix = os.environ.get("CONDA_PREFIX", None)
|
prefix = os.environ.get("CONDA_PREFIX", None)
|
||||||
if prefix is not None and os.path.isdir(prefix + "/include/cub"):
|
if prefix is not None and os.path.isdir(prefix + "/include/cub"):
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from itertools import product
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from fvcore.common.benchmark import benchmark
|
from fvcore.common.benchmark import benchmark
|
||||||
|
|
||||||
from pytorch3d.ops.ball_query import ball_query
|
from pytorch3d.ops.ball_query import ball_query
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,6 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch3d.implicitron.models.utils import preprocess_input, weighted_sum_losses
|
from pytorch3d.implicitron.models.utils import preprocess_input, weighted_sum_losses
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -11,12 +11,10 @@ from dataclasses import dataclass
|
|||||||
from itertools import product
|
from itertools import product
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
|
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
|
||||||
DoublePoolBatchSampler,
|
DoublePoolBatchSampler,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
|
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
|
||||||
from pytorch3d.implicitron.dataset.frame_data import FrameData
|
from pytorch3d.implicitron.dataset.frame_data import FrameData
|
||||||
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
|
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
|
||||||
|
|||||||
@@ -7,9 +7,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch3d.implicitron.dataset.utils import (
|
from pytorch3d.implicitron.dataset.utils import (
|
||||||
bbox_xywh_to_xyxy,
|
bbox_xywh_to_xyxy,
|
||||||
bbox_xyxy_to_xywh,
|
bbox_xyxy_to_xywh,
|
||||||
@@ -21,7 +19,6 @@ from pytorch3d.implicitron.dataset.utils import (
|
|||||||
rescale_bbox,
|
rescale_bbox,
|
||||||
resize_image,
|
resize_image,
|
||||||
)
|
)
|
||||||
|
|
||||||
from tests.common_testing import TestCaseMixin
|
from tests.common_testing import TestCaseMixin
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import os
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch3d.implicitron.dataset.data_loader_map_provider import ( # noqa
|
from pytorch3d.implicitron.dataset.data_loader_map_provider import ( # noqa
|
||||||
SequenceDataLoaderMapProvider,
|
SequenceDataLoaderMapProvider,
|
||||||
SimpleDataLoaderMapProvider,
|
SimpleDataLoaderMapProvider,
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import os
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from pytorch3d.implicitron import eval_demo
|
from pytorch3d.implicitron import eval_demo
|
||||||
|
|
||||||
from tests.common_testing import interactive_testing_requested
|
from tests.common_testing import interactive_testing_requested
|
||||||
|
|
||||||
from .common_resources import CO3D_MANIFOLD_PATH
|
from .common_resources import CO3D_MANIFOLD_PATH
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ import unittest
|
|||||||
import lpips
|
import lpips
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch3d.implicitron.dataset.frame_data import FrameData
|
from pytorch3d.implicitron.dataset.frame_data import FrameData
|
||||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
||||||
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import eval_batch
|
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import eval_batch
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ from typing import ClassVar, Optional, Type
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
from pytorch3d.implicitron.dataset import types
|
from pytorch3d.implicitron.dataset import types
|
||||||
from pytorch3d.implicitron.dataset.frame_data import FrameData, GenericFrameDataBuilder
|
from pytorch3d.implicitron.dataset.frame_data import FrameData, GenericFrameDataBuilder
|
||||||
from pytorch3d.implicitron.dataset.orm_types import (
|
from pytorch3d.implicitron.dataset.orm_types import (
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ from typing import List
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch3d.implicitron.dataset import types
|
from pytorch3d.implicitron.dataset import types
|
||||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData
|
from pytorch3d.implicitron.dataset.dataset_base import FrameData
|
||||||
from pytorch3d.implicitron.dataset.frame_data import FrameDataBuilder
|
from pytorch3d.implicitron.dataset.frame_data import FrameDataBuilder
|
||||||
@@ -29,7 +28,6 @@ from pytorch3d.implicitron.dataset.utils import (
|
|||||||
)
|
)
|
||||||
from pytorch3d.implicitron.tools.config import get_default_args
|
from pytorch3d.implicitron.tools.config import get_default_args
|
||||||
from pytorch3d.renderer.cameras import PerspectiveCameras
|
from pytorch3d.renderer.cameras import PerspectiveCameras
|
||||||
|
|
||||||
from tests.common_testing import TestCaseMixin
|
from tests.common_testing import TestCaseMixin
|
||||||
from tests.implicitron.common_resources import get_skateboard_data
|
from tests.implicitron.common_resources import get_skateboard_data
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import unittest
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from typing import Tuple
|
|||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
||||||
from pytorch3d.implicitron.dataset.visualize import get_implicitron_sequence_pointcloud
|
from pytorch3d.implicitron.dataset.visualize import get_implicitron_sequence_pointcloud
|
||||||
|
|
||||||
from pytorch3d.implicitron.models.visualization.render_flyaround import render_flyaround
|
from pytorch3d.implicitron.models.visualization.render_flyaround import render_flyaround
|
||||||
from pytorch3d.implicitron.tools.config import expand_args_fields
|
from pytorch3d.implicitron.tools.config import expand_args_fields
|
||||||
from pytorch3d.implicitron.tools.point_cloud_utils import render_point_cloud_pytorch3d
|
from pytorch3d.implicitron.tools.point_cloud_utils import render_point_cloud_pytorch3d
|
||||||
|
|||||||
@@ -8,9 +8,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch3d.implicitron.models.renderer.base import (
|
from pytorch3d.implicitron.models.renderer.base import (
|
||||||
approximate_conical_frustum_as_gaussians,
|
approximate_conical_frustum_as_gaussians,
|
||||||
compute_3d_diagonal_covariance_gaussian,
|
compute_3d_diagonal_covariance_gaussian,
|
||||||
@@ -18,7 +16,6 @@ from pytorch3d.implicitron.models.renderer.base import (
|
|||||||
ImplicitronRayBundle,
|
ImplicitronRayBundle,
|
||||||
)
|
)
|
||||||
from pytorch3d.implicitron.models.renderer.ray_sampler import AbstractMaskRaySampler
|
from pytorch3d.implicitron.models.renderer.ray_sampler import AbstractMaskRaySampler
|
||||||
|
|
||||||
from tests.common_testing import TestCaseMixin
|
from tests.common_testing import TestCaseMixin
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@
|
|||||||
import unittest
|
import unittest
|
||||||
from itertools import product
|
from itertools import product
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -18,7 +17,6 @@ from pytorch3d.implicitron.models.renderer.ray_sampler import (
|
|||||||
compute_radii,
|
compute_radii,
|
||||||
NearFarRaySampler,
|
NearFarRaySampler,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pytorch3d.renderer.cameras import (
|
from pytorch3d.renderer.cameras import (
|
||||||
CamerasBase,
|
CamerasBase,
|
||||||
FoVOrthographicCameras,
|
FoVOrthographicCameras,
|
||||||
@@ -28,7 +26,6 @@ from pytorch3d.renderer.cameras import (
|
|||||||
)
|
)
|
||||||
from pytorch3d.renderer.implicit.utils import HeterogeneousRayBundle
|
from pytorch3d.renderer.implicit.utils import HeterogeneousRayBundle
|
||||||
from tests.common_camera_utils import init_random_cameras
|
from tests.common_camera_utils import init_random_cameras
|
||||||
|
|
||||||
from tests.common_testing import TestCaseMixin
|
from tests.common_testing import TestCaseMixin
|
||||||
|
|
||||||
CAMERA_TYPES = (
|
CAMERA_TYPES = (
|
||||||
|
|||||||
@@ -7,7 +7,6 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from pytorch3d.implicitron.dataset.orm_types import ArrayTypeFactory, TupleTypeFactory
|
from pytorch3d.implicitron.dataset.orm_types import ArrayTypeFactory, TupleTypeFactory
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import unittest
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
|
from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
|
||||||
|
|
||||||
from pytorch3d.renderer.cameras import PerspectiveCameras
|
from pytorch3d.renderer.cameras import PerspectiveCameras
|
||||||
from tests.common_testing import TestCaseMixin
|
from tests.common_testing import TestCaseMixin
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import unittest
|
|||||||
from itertools import product
|
from itertools import product
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch3d.implicitron.models.renderer.ray_point_refiner import (
|
from pytorch3d.implicitron.models.renderer.ray_point_refiner import (
|
||||||
apply_blurpool_on_weights,
|
apply_blurpool_on_weights,
|
||||||
RayPointRefiner,
|
RayPointRefiner,
|
||||||
|
|||||||
@@ -10,9 +10,7 @@ import unittest
|
|||||||
from collections import Counter
|
from collections import Counter
|
||||||
|
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset
|
from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset
|
||||||
|
|
||||||
NO_BLOBS_KWARGS = {
|
NO_BLOBS_KWARGS = {
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ from pytorch3d.implicitron.models.implicit_function.scene_representation_network
|
|||||||
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
|
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
|
||||||
from pytorch3d.implicitron.tools.config import get_default_args
|
from pytorch3d.implicitron.tools.config import get_default_args
|
||||||
from pytorch3d.renderer import PerspectiveCameras
|
from pytorch3d.renderer import PerspectiveCameras
|
||||||
|
|
||||||
from tests.common_testing import TestCaseMixin
|
from tests.common_testing import TestCaseMixin
|
||||||
|
|
||||||
_BATCH_SIZE: int = 3
|
_BATCH_SIZE: int = 3
|
||||||
|
|||||||
@@ -8,13 +8,11 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function import (
|
from pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function import (
|
||||||
VoxelGridImplicitFunction,
|
VoxelGridImplicitFunction,
|
||||||
)
|
)
|
||||||
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
||||||
|
|
||||||
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
|
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
|
||||||
from pytorch3d.renderer import ray_bundle_to_ray_points
|
from pytorch3d.renderer import ray_bundle_to_ray_points
|
||||||
from tests.common_testing import TestCaseMixin
|
from tests.common_testing import TestCaseMixin
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from typing import Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
|
|
||||||
from pytorch3d.implicitron.models.implicit_function.utils import (
|
from pytorch3d.implicitron.models.implicit_function.utils import (
|
||||||
interpolate_line,
|
interpolate_line,
|
||||||
interpolate_plane,
|
interpolate_plane,
|
||||||
@@ -22,7 +21,6 @@ from pytorch3d.implicitron.models.implicit_function.voxel_grid import (
|
|||||||
VMFactorizedVoxelGrid,
|
VMFactorizedVoxelGrid,
|
||||||
VoxelGridModule,
|
VoxelGridModule,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
|
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
|
||||||
from tests.common_testing import TestCaseMixin
|
from tests.common_testing import TestCaseMixin
|
||||||
|
|
||||||
|
|||||||
@@ -60,7 +60,6 @@ from pytorch3d.transforms.rotation_conversions import random_rotations
|
|||||||
from pytorch3d.transforms.so3 import so3_exp_map
|
from pytorch3d.transforms.so3 import so3_exp_map
|
||||||
|
|
||||||
from .common_camera_utils import init_random_cameras
|
from .common_camera_utils import init_random_cameras
|
||||||
|
|
||||||
from .common_testing import TestCaseMixin
|
from .common_testing import TestCaseMixin
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -673,9 +673,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
def test_load_simple_binary(self):
|
def test_load_simple_binary(self):
|
||||||
for big_endian in [True, False]:
|
for big_endian in [True, False]:
|
||||||
verts = (
|
verts = ("0 0 0 0 0 1 0 1 1 0 1 0 1 0 0 1 0 1 1 1 1 1 1 0").split()
|
||||||
"0 0 0 " "0 0 1 " "0 1 1 " "0 1 0 " "1 0 0 " "1 0 1 " "1 1 1 " "1 1 0"
|
|
||||||
).split()
|
|
||||||
faces = (
|
faces = (
|
||||||
"4 0 1 2 3 "
|
"4 0 1 2 3 "
|
||||||
"4 7 6 5 4 "
|
"4 7 6 5 4 "
|
||||||
@@ -688,7 +686,7 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
|||||||
"3 4 5 1"
|
"3 4 5 1"
|
||||||
).split()
|
).split()
|
||||||
short_one = b"\00\01" if big_endian else b"\01\00"
|
short_one = b"\00\01" if big_endian else b"\01\00"
|
||||||
mixed_data = b"\00\00" b"\03\03" + (short_one + b"\00\01\01\01" b"\00\02")
|
mixed_data = b"\00\00\03\03" + (short_one + b"\00\01\01\01\00\02")
|
||||||
minus_one_data = b"\xff" * 14
|
minus_one_data = b"\xff" * 14
|
||||||
endian_char = ">" if big_endian else "<"
|
endian_char = ">" if big_endian else "<"
|
||||||
format = (
|
format = (
|
||||||
|
|||||||
@@ -604,9 +604,9 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
|||||||
# test weather they are of the correct shape
|
# test weather they are of the correct shape
|
||||||
for attr in ("origins", "directions", "lengths", "xys"):
|
for attr in ("origins", "directions", "lengths", "xys"):
|
||||||
tensor = getattr(ray_bundle, attr)
|
tensor = getattr(ray_bundle, attr)
|
||||||
assert tensor.shape[:2] == torch.Size(
|
assert tensor.shape[:2] == torch.Size((n_rays_total, 1)), (
|
||||||
(n_rays_total, 1)
|
tensor.shape
|
||||||
), tensor.shape
|
)
|
||||||
|
|
||||||
# if two camera ids are same than origins should also be the same
|
# if two camera ids are same than origins should also be the same
|
||||||
# directions and xys are always different and lengths equal
|
# directions and xys are always different and lengths equal
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ Sanity checks for output images from the renderer.
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
from itertools import product
|
from itertools import product
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|||||||
@@ -148,14 +148,14 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
|
|||||||
for t_pair in ((t1, t2), (t1, t3), (t2, t3)):
|
for t_pair in ((t1, t2), (t1, t3), (t2, t3)):
|
||||||
matrix1 = t_pair[0].get_matrix()
|
matrix1 = t_pair[0].get_matrix()
|
||||||
matrix2 = t_pair[1].get_matrix()
|
matrix2 = t_pair[1].get_matrix()
|
||||||
self.assertTrue(torch.allclose(matrix1, matrix2))
|
self.assertClose(matrix1, matrix2)
|
||||||
|
|
||||||
def test_init_with_custom_matrix(self):
|
def test_init_with_custom_matrix(self):
|
||||||
for matrix in (torch.randn(10, 4, 4), torch.randn(4, 4)):
|
for matrix in (torch.randn(10, 4, 4), torch.randn(4, 4)):
|
||||||
t = Transform3d(matrix=matrix)
|
t = Transform3d(matrix=matrix)
|
||||||
self.assertTrue(t.device == matrix.device)
|
self.assertTrue(t.device == matrix.device)
|
||||||
self.assertTrue(t._matrix.dtype == matrix.dtype)
|
self.assertTrue(t._matrix.dtype == matrix.dtype)
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix.view(t._matrix.shape)))
|
self.assertClose(t._matrix, matrix.view(t._matrix.shape))
|
||||||
|
|
||||||
def test_init_with_custom_matrix_errors(self):
|
def test_init_with_custom_matrix_errors(self):
|
||||||
bad_shapes = [[10, 5, 4], [3, 4], [10, 4, 4, 1], [10, 4, 4, 2], [4, 4, 4, 3]]
|
bad_shapes = [[10, 5, 4], [3, 4], [10, 4, 4, 1], [10, 4, 4, 2], [4, 4, 4, 3]]
|
||||||
@@ -189,8 +189,8 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
|
|||||||
normals_out_expected = torch.tensor(
|
normals_out_expected = torch.tensor(
|
||||||
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]
|
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]
|
||||||
).view(1, 3, 3)
|
).view(1, 3, 3)
|
||||||
self.assertTrue(torch.allclose(points_out, points_out_expected))
|
self.assertClose(points_out, points_out_expected)
|
||||||
self.assertTrue(torch.allclose(normals_out, normals_out_expected))
|
self.assertClose(normals_out, normals_out_expected)
|
||||||
|
|
||||||
@mock.patch.dict(os.environ, {"PYTORCH3D_CHECK_ROTATION_MATRICES": "1"}, clear=True)
|
@mock.patch.dict(os.environ, {"PYTORCH3D_CHECK_ROTATION_MATRICES": "1"}, clear=True)
|
||||||
def test_rotate_check_rot_valid_on(self):
|
def test_rotate_check_rot_valid_on(self):
|
||||||
@@ -206,8 +206,8 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
|
|||||||
normals_out = t.transform_normals(normals)
|
normals_out = t.transform_normals(normals)
|
||||||
points_out_expected = torch.bmm(points, R)
|
points_out_expected = torch.bmm(points, R)
|
||||||
normals_out_expected = torch.bmm(normals, R)
|
normals_out_expected = torch.bmm(normals, R)
|
||||||
self.assertTrue(torch.allclose(points_out, points_out_expected))
|
self.assertClose(points_out, points_out_expected)
|
||||||
self.assertTrue(torch.allclose(normals_out, normals_out_expected))
|
self.assertClose(normals_out, normals_out_expected)
|
||||||
|
|
||||||
@mock.patch.dict(os.environ, {"PYTORCH3D_CHECK_ROTATION_MATRICES": "0"}, clear=True)
|
@mock.patch.dict(os.environ, {"PYTORCH3D_CHECK_ROTATION_MATRICES": "0"}, clear=True)
|
||||||
def test_rotate_check_rot_valid_off(self):
|
def test_rotate_check_rot_valid_off(self):
|
||||||
@@ -223,8 +223,8 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
|
|||||||
normals_out = t.transform_normals(normals)
|
normals_out = t.transform_normals(normals)
|
||||||
points_out_expected = torch.bmm(points, R)
|
points_out_expected = torch.bmm(points, R)
|
||||||
normals_out_expected = torch.bmm(normals, R)
|
normals_out_expected = torch.bmm(normals, R)
|
||||||
self.assertTrue(torch.allclose(points_out, points_out_expected))
|
self.assertClose(points_out, points_out_expected)
|
||||||
self.assertTrue(torch.allclose(normals_out, normals_out_expected))
|
self.assertClose(normals_out, normals_out_expected)
|
||||||
|
|
||||||
def test_scale(self):
|
def test_scale(self):
|
||||||
t = Transform3d().scale(2.0).scale(0.5, 0.25, 1.0)
|
t = Transform3d().scale(2.0).scale(0.5, 0.25, 1.0)
|
||||||
@@ -242,8 +242,8 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
|
|||||||
normals_out_expected = torch.tensor(
|
normals_out_expected = torch.tensor(
|
||||||
[[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [1.0, 2.0, 0.0]]
|
[[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [1.0, 2.0, 0.0]]
|
||||||
).view(1, 3, 3)
|
).view(1, 3, 3)
|
||||||
self.assertTrue(torch.allclose(points_out, points_out_expected))
|
self.assertClose(points_out, points_out_expected)
|
||||||
self.assertTrue(torch.allclose(normals_out, normals_out_expected))
|
self.assertClose(normals_out, normals_out_expected)
|
||||||
|
|
||||||
def test_scale_translate(self):
|
def test_scale_translate(self):
|
||||||
t = Transform3d().scale(2, 1, 3).translate(1, 2, 3)
|
t = Transform3d().scale(2, 1, 3).translate(1, 2, 3)
|
||||||
@@ -261,8 +261,8 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
|
|||||||
normals_out_expected = torch.tensor(
|
normals_out_expected = torch.tensor(
|
||||||
[[0.5, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 1.0, 0.0]]
|
[[0.5, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 1.0, 0.0]]
|
||||||
).view(1, 3, 3)
|
).view(1, 3, 3)
|
||||||
self.assertTrue(torch.allclose(points_out, points_out_expected))
|
self.assertClose(points_out, points_out_expected)
|
||||||
self.assertTrue(torch.allclose(normals_out, normals_out_expected))
|
self.assertClose(normals_out, normals_out_expected)
|
||||||
|
|
||||||
def test_rotate_axis_angle(self):
|
def test_rotate_axis_angle(self):
|
||||||
t = Transform3d().rotate_axis_angle(90.0, axis="Z")
|
t = Transform3d().rotate_axis_angle(90.0, axis="Z")
|
||||||
@@ -280,8 +280,8 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
|
|||||||
normals_out_expected = torch.tensor(
|
normals_out_expected = torch.tensor(
|
||||||
[[0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]]
|
[[0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]]
|
||||||
).view(1, 3, 3)
|
).view(1, 3, 3)
|
||||||
self.assertTrue(torch.allclose(points_out, points_out_expected, atol=1e-7))
|
self.assertClose(points_out, points_out_expected, atol=1e-7)
|
||||||
self.assertTrue(torch.allclose(normals_out, normals_out_expected, atol=1e-7))
|
self.assertClose(normals_out, normals_out_expected, atol=1e-7)
|
||||||
|
|
||||||
def test_transform_points_fail(self):
|
def test_transform_points_fail(self):
|
||||||
t1 = Scale(0.1, 0.1, 0.1)
|
t1 = Scale(0.1, 0.1, 0.1)
|
||||||
@@ -369,7 +369,7 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
# assert all same
|
# assert all same
|
||||||
for m in (m1, m2, m3, m4):
|
for m in (m1, m2, m3, m4):
|
||||||
self.assertTrue(torch.allclose(m, m5, atol=1e-3))
|
self.assertClose(m, m5, atol=1e-3)
|
||||||
|
|
||||||
def _check_indexed_transforms(self, t3d, t3d_selected, indices):
|
def _check_indexed_transforms(self, t3d, t3d_selected, indices):
|
||||||
t3d_matrix = t3d.get_matrix()
|
t3d_matrix = t3d.get_matrix()
|
||||||
@@ -488,7 +488,7 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
|
|||||||
self.assertClose(new_points, new_points_expect)
|
self.assertClose(new_points, new_points_expect)
|
||||||
|
|
||||||
|
|
||||||
class TestTranslate(unittest.TestCase):
|
class TestTranslate(TestCaseMixin, unittest.TestCase):
|
||||||
def test_python_scalar(self):
|
def test_python_scalar(self):
|
||||||
t = Translate(0.2, 0.3, 0.4)
|
t = Translate(0.2, 0.3, 0.4)
|
||||||
matrix = torch.tensor(
|
matrix = torch.tensor(
|
||||||
@@ -502,7 +502,7 @@ class TestTranslate(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix))
|
self.assertClose(t._matrix, matrix)
|
||||||
|
|
||||||
def test_torch_scalar(self):
|
def test_torch_scalar(self):
|
||||||
x = torch.tensor(0.2)
|
x = torch.tensor(0.2)
|
||||||
@@ -520,7 +520,7 @@ class TestTranslate(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix))
|
self.assertClose(t._matrix, matrix)
|
||||||
|
|
||||||
def test_mixed_scalars(self):
|
def test_mixed_scalars(self):
|
||||||
x = 0.2
|
x = 0.2
|
||||||
@@ -538,7 +538,7 @@ class TestTranslate(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix))
|
self.assertClose(t._matrix, matrix)
|
||||||
|
|
||||||
def test_torch_scalar_grads(self):
|
def test_torch_scalar_grads(self):
|
||||||
# Make sure backprop works if we give torch scalars
|
# Make sure backprop works if we give torch scalars
|
||||||
@@ -549,8 +549,8 @@ class TestTranslate(unittest.TestCase):
|
|||||||
t._matrix.sum().backward()
|
t._matrix.sum().backward()
|
||||||
self.assertTrue(hasattr(x, "grad"))
|
self.assertTrue(hasattr(x, "grad"))
|
||||||
self.assertTrue(hasattr(y, "grad"))
|
self.assertTrue(hasattr(y, "grad"))
|
||||||
self.assertTrue(torch.allclose(x.grad, x.new_ones(x.shape)))
|
self.assertClose(x.grad, x.new_ones(x.shape))
|
||||||
self.assertTrue(torch.allclose(y.grad, y.new_ones(y.shape)))
|
self.assertClose(y.grad, y.new_ones(y.shape))
|
||||||
|
|
||||||
def test_torch_vectors(self):
|
def test_torch_vectors(self):
|
||||||
x = torch.tensor([0.2, 2.0])
|
x = torch.tensor([0.2, 2.0])
|
||||||
@@ -574,7 +574,7 @@ class TestTranslate(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix))
|
self.assertClose(t._matrix, matrix)
|
||||||
|
|
||||||
def test_vector_broadcast(self):
|
def test_vector_broadcast(self):
|
||||||
x = torch.tensor([0.2, 2.0])
|
x = torch.tensor([0.2, 2.0])
|
||||||
@@ -598,7 +598,7 @@ class TestTranslate(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix))
|
self.assertClose(t._matrix, matrix)
|
||||||
|
|
||||||
def test_bad_broadcast(self):
|
def test_bad_broadcast(self):
|
||||||
x = torch.tensor([0.2, 2.0, 20.0])
|
x = torch.tensor([0.2, 2.0, 20.0])
|
||||||
@@ -629,7 +629,7 @@ class TestTranslate(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix))
|
self.assertClose(t._matrix, matrix)
|
||||||
|
|
||||||
def test_mixed_broadcast_grad(self):
|
def test_mixed_broadcast_grad(self):
|
||||||
x = 0.2
|
x = 0.2
|
||||||
@@ -643,8 +643,8 @@ class TestTranslate(unittest.TestCase):
|
|||||||
z_grad = torch.tensor([1.0, 1.0])
|
z_grad = torch.tensor([1.0, 1.0])
|
||||||
self.assertEqual(y.grad.shape, y_grad.shape)
|
self.assertEqual(y.grad.shape, y_grad.shape)
|
||||||
self.assertEqual(z.grad.shape, z_grad.shape)
|
self.assertEqual(z.grad.shape, z_grad.shape)
|
||||||
self.assertTrue(torch.allclose(y.grad, y_grad))
|
self.assertClose(y.grad, y_grad)
|
||||||
self.assertTrue(torch.allclose(z.grad, z_grad))
|
self.assertClose(z.grad, z_grad)
|
||||||
|
|
||||||
def test_matrix(self):
|
def test_matrix(self):
|
||||||
xyz = torch.tensor([[0.2, 0.3, 0.4], [2.0, 3.0, 4.0]])
|
xyz = torch.tensor([[0.2, 0.3, 0.4], [2.0, 3.0, 4.0]])
|
||||||
@@ -666,7 +666,7 @@ class TestTranslate(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix))
|
self.assertClose(t._matrix, matrix)
|
||||||
|
|
||||||
def test_matrix_extra_args(self):
|
def test_matrix_extra_args(self):
|
||||||
xyz = torch.tensor([[0.2, 0.3, 0.4], [2.0, 3.0, 4.0]])
|
xyz = torch.tensor([[0.2, 0.3, 0.4], [2.0, 3.0, 4.0]])
|
||||||
@@ -679,8 +679,8 @@ class TestTranslate(unittest.TestCase):
|
|||||||
im = t.inverse()._matrix
|
im = t.inverse()._matrix
|
||||||
im_2 = t._matrix.inverse()
|
im_2 = t._matrix.inverse()
|
||||||
im_comp = t.get_matrix().inverse()
|
im_comp = t.get_matrix().inverse()
|
||||||
self.assertTrue(torch.allclose(im, im_comp))
|
self.assertClose(im, im_comp, atol=1e-4)
|
||||||
self.assertTrue(torch.allclose(im, im_2))
|
self.assertClose(im, im_2, atol=1e-4)
|
||||||
|
|
||||||
def test_get_item(self, batch_size=5):
|
def test_get_item(self, batch_size=5):
|
||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda:0")
|
||||||
@@ -692,7 +692,7 @@ class TestTranslate(unittest.TestCase):
|
|||||||
self.assertIsInstance(t3d_selected, Translate)
|
self.assertIsInstance(t3d_selected, Translate)
|
||||||
|
|
||||||
|
|
||||||
class TestScale(unittest.TestCase):
|
class TestScale(TestCaseMixin, unittest.TestCase):
|
||||||
def test_single_python_scalar(self):
|
def test_single_python_scalar(self):
|
||||||
t = Scale(0.1)
|
t = Scale(0.1)
|
||||||
matrix = torch.tensor(
|
matrix = torch.tensor(
|
||||||
@@ -706,7 +706,7 @@ class TestScale(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix))
|
self.assertClose(t._matrix, matrix)
|
||||||
|
|
||||||
def test_single_torch_scalar(self):
|
def test_single_torch_scalar(self):
|
||||||
t = Scale(torch.tensor(0.1))
|
t = Scale(torch.tensor(0.1))
|
||||||
@@ -721,7 +721,7 @@ class TestScale(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix))
|
self.assertClose(t._matrix, matrix)
|
||||||
|
|
||||||
def test_single_vector(self):
|
def test_single_vector(self):
|
||||||
t = Scale(torch.tensor([0.1, 0.2]))
|
t = Scale(torch.tensor([0.1, 0.2]))
|
||||||
@@ -742,7 +742,7 @@ class TestScale(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix))
|
self.assertClose(t._matrix, matrix)
|
||||||
|
|
||||||
def test_single_matrix(self):
|
def test_single_matrix(self):
|
||||||
xyz = torch.tensor([[0.1, 0.2, 0.3], [1.0, 2.0, 3.0]])
|
xyz = torch.tensor([[0.1, 0.2, 0.3], [1.0, 2.0, 3.0]])
|
||||||
@@ -764,7 +764,7 @@ class TestScale(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix))
|
self.assertClose(t._matrix, matrix)
|
||||||
|
|
||||||
def test_three_python_scalar(self):
|
def test_three_python_scalar(self):
|
||||||
t = Scale(0.1, 0.2, 0.3)
|
t = Scale(0.1, 0.2, 0.3)
|
||||||
@@ -779,7 +779,7 @@ class TestScale(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix))
|
self.assertClose(t._matrix, matrix)
|
||||||
|
|
||||||
def test_three_torch_scalar(self):
|
def test_three_torch_scalar(self):
|
||||||
t = Scale(torch.tensor(0.1), torch.tensor(0.2), torch.tensor(0.3))
|
t = Scale(torch.tensor(0.1), torch.tensor(0.2), torch.tensor(0.3))
|
||||||
@@ -794,7 +794,7 @@ class TestScale(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix))
|
self.assertClose(t._matrix, matrix)
|
||||||
|
|
||||||
def test_three_mixed_scalar(self):
|
def test_three_mixed_scalar(self):
|
||||||
t = Scale(torch.tensor(0.1), 0.2, torch.tensor(0.3))
|
t = Scale(torch.tensor(0.1), 0.2, torch.tensor(0.3))
|
||||||
@@ -809,7 +809,7 @@ class TestScale(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix))
|
self.assertClose(t._matrix, matrix)
|
||||||
|
|
||||||
def test_three_vector_broadcast(self):
|
def test_three_vector_broadcast(self):
|
||||||
x = torch.tensor([0.1])
|
x = torch.tensor([0.1])
|
||||||
@@ -833,7 +833,7 @@ class TestScale(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix))
|
self.assertClose(t._matrix, matrix)
|
||||||
|
|
||||||
def test_three_mixed_broadcast_grad(self):
|
def test_three_mixed_broadcast_grad(self):
|
||||||
x = 0.1
|
x = 0.1
|
||||||
@@ -857,14 +857,14 @@ class TestScale(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix))
|
self.assertClose(t._matrix, matrix)
|
||||||
t._matrix.sum().backward()
|
t._matrix.sum().backward()
|
||||||
self.assertTrue(hasattr(y, "grad"))
|
self.assertTrue(hasattr(y, "grad"))
|
||||||
self.assertTrue(hasattr(z, "grad"))
|
self.assertTrue(hasattr(z, "grad"))
|
||||||
y_grad = torch.tensor(2.0)
|
y_grad = torch.tensor(2.0)
|
||||||
z_grad = torch.tensor([1.0, 1.0])
|
z_grad = torch.tensor([1.0, 1.0])
|
||||||
self.assertTrue(torch.allclose(y.grad, y_grad))
|
self.assertClose(y.grad, y_grad)
|
||||||
self.assertTrue(torch.allclose(z.grad, z_grad))
|
self.assertClose(z.grad, z_grad)
|
||||||
|
|
||||||
def test_inverse(self):
|
def test_inverse(self):
|
||||||
x = torch.tensor([0.1])
|
x = torch.tensor([0.1])
|
||||||
@@ -874,8 +874,8 @@ class TestScale(unittest.TestCase):
|
|||||||
im = t.inverse()._matrix
|
im = t.inverse()._matrix
|
||||||
im_2 = t._matrix.inverse()
|
im_2 = t._matrix.inverse()
|
||||||
im_comp = t.get_matrix().inverse()
|
im_comp = t.get_matrix().inverse()
|
||||||
self.assertTrue(torch.allclose(im, im_comp))
|
self.assertClose(im, im_comp)
|
||||||
self.assertTrue(torch.allclose(im, im_2))
|
self.assertClose(im, im_2)
|
||||||
|
|
||||||
def test_get_item(self, batch_size=5):
|
def test_get_item(self, batch_size=5):
|
||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda:0")
|
||||||
@@ -887,7 +887,7 @@ class TestScale(unittest.TestCase):
|
|||||||
self.assertIsInstance(t3d_selected, Scale)
|
self.assertIsInstance(t3d_selected, Scale)
|
||||||
|
|
||||||
|
|
||||||
class TestTransformBroadcast(unittest.TestCase):
|
class TestTransformBroadcast(TestCaseMixin, unittest.TestCase):
|
||||||
def test_broadcast_transform_points(self):
|
def test_broadcast_transform_points(self):
|
||||||
t1 = Scale(0.1, 0.1, 0.1)
|
t1 = Scale(0.1, 0.1, 0.1)
|
||||||
N = 10
|
N = 10
|
||||||
@@ -965,10 +965,10 @@ class TestTransformBroadcast(unittest.TestCase):
|
|||||||
composed_mat = t1N2.get_matrix()
|
composed_mat = t1N2.get_matrix()
|
||||||
self.assertTrue(composed_mat.shape == (N, 4, 4))
|
self.assertTrue(composed_mat.shape == (N, 4, 4))
|
||||||
expected_mat = torch.eye(3, dtype=torch.float32) * 0.3 * 0.2 * 0.1
|
expected_mat = torch.eye(3, dtype=torch.float32) * 0.3 * 0.2 * 0.1
|
||||||
self.assertTrue(torch.allclose(composed_mat[0, :3, :3], expected_mat))
|
self.assertClose(composed_mat[0, :3, :3], expected_mat)
|
||||||
|
|
||||||
|
|
||||||
class TestRotate(unittest.TestCase):
|
class TestRotate(TestCaseMixin, unittest.TestCase):
|
||||||
def test_single_matrix(self):
|
def test_single_matrix(self):
|
||||||
R = torch.eye(3)
|
R = torch.eye(3)
|
||||||
t = Rotate(R)
|
t = Rotate(R)
|
||||||
@@ -983,7 +983,7 @@ class TestRotate(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix))
|
self.assertClose(t._matrix, matrix)
|
||||||
|
|
||||||
def test_invalid_dimensions(self):
|
def test_invalid_dimensions(self):
|
||||||
R = torch.eye(4)
|
R = torch.eye(4)
|
||||||
@@ -998,8 +998,8 @@ class TestRotate(unittest.TestCase):
|
|||||||
im = t.inverse()._matrix
|
im = t.inverse()._matrix
|
||||||
im_2 = t._matrix.inverse()
|
im_2 = t._matrix.inverse()
|
||||||
im_comp = t.get_matrix().inverse()
|
im_comp = t.get_matrix().inverse()
|
||||||
self.assertTrue(torch.allclose(im, im_comp, atol=1e-4))
|
self.assertClose(im, im_comp, atol=1e-4)
|
||||||
self.assertTrue(torch.allclose(im, im_2, atol=1e-4))
|
self.assertClose(im, im_2, atol=1e-4)
|
||||||
|
|
||||||
def test_get_item(self, batch_size=5):
|
def test_get_item(self, batch_size=5):
|
||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda:0")
|
||||||
@@ -1011,7 +1011,7 @@ class TestRotate(unittest.TestCase):
|
|||||||
self.assertIsInstance(t3d_selected, Rotate)
|
self.assertIsInstance(t3d_selected, Rotate)
|
||||||
|
|
||||||
|
|
||||||
class TestRotateAxisAngle(unittest.TestCase):
|
class TestRotateAxisAngle(TestCaseMixin, unittest.TestCase):
|
||||||
def test_rotate_x_python_scalar(self):
|
def test_rotate_x_python_scalar(self):
|
||||||
t = RotateAxisAngle(angle=90, axis="X")
|
t = RotateAxisAngle(angle=90, axis="X")
|
||||||
# fmt: off
|
# fmt: off
|
||||||
@@ -1030,10 +1030,8 @@ class TestRotateAxisAngle(unittest.TestCase):
|
|||||||
points = torch.tensor([0.0, 1.0, 0.0])[None, None, :] # (1, 1, 3)
|
points = torch.tensor([0.0, 1.0, 0.0])[None, None, :] # (1, 1, 3)
|
||||||
transformed_points = t.transform_points(points)
|
transformed_points = t.transform_points(points)
|
||||||
expected_points = torch.tensor([0.0, 0.0, 1.0])
|
expected_points = torch.tensor([0.0, 0.0, 1.0])
|
||||||
self.assertTrue(
|
self.assertClose(transformed_points.squeeze(), expected_points, atol=1e-7)
|
||||||
torch.allclose(transformed_points.squeeze(), expected_points, atol=1e-7)
|
self.assertClose(t._matrix, matrix, atol=1e-7)
|
||||||
)
|
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
|
|
||||||
|
|
||||||
def test_rotate_x_torch_scalar(self):
|
def test_rotate_x_torch_scalar(self):
|
||||||
angle = torch.tensor(90.0)
|
angle = torch.tensor(90.0)
|
||||||
@@ -1054,10 +1052,8 @@ class TestRotateAxisAngle(unittest.TestCase):
|
|||||||
points = torch.tensor([0.0, 1.0, 0.0])[None, None, :] # (1, 1, 3)
|
points = torch.tensor([0.0, 1.0, 0.0])[None, None, :] # (1, 1, 3)
|
||||||
transformed_points = t.transform_points(points)
|
transformed_points = t.transform_points(points)
|
||||||
expected_points = torch.tensor([0.0, 0.0, 1.0])
|
expected_points = torch.tensor([0.0, 0.0, 1.0])
|
||||||
self.assertTrue(
|
self.assertClose(transformed_points.squeeze(), expected_points, atol=1e-7)
|
||||||
torch.allclose(transformed_points.squeeze(), expected_points, atol=1e-7)
|
self.assertClose(t._matrix, matrix, atol=1e-7)
|
||||||
)
|
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
|
|
||||||
|
|
||||||
def test_rotate_x_torch_tensor(self):
|
def test_rotate_x_torch_tensor(self):
|
||||||
angle = torch.tensor([0, 45.0, 90.0]) # (N)
|
angle = torch.tensor([0, 45.0, 90.0]) # (N)
|
||||||
@@ -1089,10 +1085,10 @@ class TestRotateAxisAngle(unittest.TestCase):
|
|||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
|
self.assertClose(t._matrix, matrix, atol=1e-7)
|
||||||
angle = angle
|
angle = angle
|
||||||
t = RotateAxisAngle(angle=angle, axis="X")
|
t = RotateAxisAngle(angle=angle, axis="X")
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
|
self.assertClose(t._matrix, matrix, atol=1e-7)
|
||||||
|
|
||||||
def test_rotate_y_python_scalar(self):
|
def test_rotate_y_python_scalar(self):
|
||||||
t = RotateAxisAngle(angle=90, axis="Y")
|
t = RotateAxisAngle(angle=90, axis="Y")
|
||||||
@@ -1112,10 +1108,8 @@ class TestRotateAxisAngle(unittest.TestCase):
|
|||||||
points = torch.tensor([1.0, 0.0, 0.0])[None, None, :] # (1, 1, 3)
|
points = torch.tensor([1.0, 0.0, 0.0])[None, None, :] # (1, 1, 3)
|
||||||
transformed_points = t.transform_points(points)
|
transformed_points = t.transform_points(points)
|
||||||
expected_points = torch.tensor([0.0, 0.0, -1.0])
|
expected_points = torch.tensor([0.0, 0.0, -1.0])
|
||||||
self.assertTrue(
|
self.assertClose(transformed_points.squeeze(), expected_points, atol=1e-7)
|
||||||
torch.allclose(transformed_points.squeeze(), expected_points, atol=1e-7)
|
self.assertClose(t._matrix, matrix, atol=1e-7)
|
||||||
)
|
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
|
|
||||||
|
|
||||||
def test_rotate_y_torch_scalar(self):
|
def test_rotate_y_torch_scalar(self):
|
||||||
"""
|
"""
|
||||||
@@ -1141,10 +1135,8 @@ class TestRotateAxisAngle(unittest.TestCase):
|
|||||||
points = torch.tensor([1.0, 0.0, 0.0])[None, None, :] # (1, 1, 3)
|
points = torch.tensor([1.0, 0.0, 0.0])[None, None, :] # (1, 1, 3)
|
||||||
transformed_points = t.transform_points(points)
|
transformed_points = t.transform_points(points)
|
||||||
expected_points = torch.tensor([0.0, 0.0, -1.0])
|
expected_points = torch.tensor([0.0, 0.0, -1.0])
|
||||||
self.assertTrue(
|
self.assertClose(transformed_points.squeeze(), expected_points, atol=1e-7)
|
||||||
torch.allclose(transformed_points.squeeze(), expected_points, atol=1e-7)
|
self.assertClose(t._matrix, matrix, atol=1e-7)
|
||||||
)
|
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
|
|
||||||
|
|
||||||
def test_rotate_y_torch_tensor(self):
|
def test_rotate_y_torch_tensor(self):
|
||||||
angle = torch.tensor([0, 45.0, 90.0])
|
angle = torch.tensor([0, 45.0, 90.0])
|
||||||
@@ -1176,7 +1168,7 @@ class TestRotateAxisAngle(unittest.TestCase):
|
|||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
|
self.assertClose(t._matrix, matrix, atol=1e-7)
|
||||||
|
|
||||||
def test_rotate_z_python_scalar(self):
|
def test_rotate_z_python_scalar(self):
|
||||||
t = RotateAxisAngle(angle=90, axis="Z")
|
t = RotateAxisAngle(angle=90, axis="Z")
|
||||||
@@ -1196,10 +1188,8 @@ class TestRotateAxisAngle(unittest.TestCase):
|
|||||||
points = torch.tensor([1.0, 0.0, 0.0])[None, None, :] # (1, 1, 3)
|
points = torch.tensor([1.0, 0.0, 0.0])[None, None, :] # (1, 1, 3)
|
||||||
transformed_points = t.transform_points(points)
|
transformed_points = t.transform_points(points)
|
||||||
expected_points = torch.tensor([0.0, 1.0, 0.0])
|
expected_points = torch.tensor([0.0, 1.0, 0.0])
|
||||||
self.assertTrue(
|
self.assertClose(transformed_points.squeeze(), expected_points, atol=1e-7)
|
||||||
torch.allclose(transformed_points.squeeze(), expected_points, atol=1e-7)
|
self.assertClose(t._matrix, matrix, atol=1e-7)
|
||||||
)
|
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
|
|
||||||
|
|
||||||
def test_rotate_z_torch_scalar(self):
|
def test_rotate_z_torch_scalar(self):
|
||||||
angle = torch.tensor(90.0)
|
angle = torch.tensor(90.0)
|
||||||
@@ -1220,10 +1210,8 @@ class TestRotateAxisAngle(unittest.TestCase):
|
|||||||
points = torch.tensor([1.0, 0.0, 0.0])[None, None, :] # (1, 1, 3)
|
points = torch.tensor([1.0, 0.0, 0.0])[None, None, :] # (1, 1, 3)
|
||||||
transformed_points = t.transform_points(points)
|
transformed_points = t.transform_points(points)
|
||||||
expected_points = torch.tensor([0.0, 1.0, 0.0])
|
expected_points = torch.tensor([0.0, 1.0, 0.0])
|
||||||
self.assertTrue(
|
self.assertClose(transformed_points.squeeze(), expected_points, atol=1e-7)
|
||||||
torch.allclose(transformed_points.squeeze(), expected_points, atol=1e-7)
|
self.assertClose(t._matrix, matrix, atol=1e-7)
|
||||||
)
|
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
|
|
||||||
|
|
||||||
def test_rotate_z_torch_tensor(self):
|
def test_rotate_z_torch_tensor(self):
|
||||||
angle = torch.tensor([0, 45.0, 90.0])
|
angle = torch.tensor([0, 45.0, 90.0])
|
||||||
@@ -1255,7 +1243,7 @@ class TestRotateAxisAngle(unittest.TestCase):
|
|||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
|
self.assertClose(t._matrix, matrix, atol=1e-7)
|
||||||
|
|
||||||
def test_rotate_compose_x_y_z(self):
|
def test_rotate_compose_x_y_z(self):
|
||||||
angle = torch.tensor(90.0)
|
angle = torch.tensor(90.0)
|
||||||
@@ -1301,7 +1289,7 @@ class TestRotateAxisAngle(unittest.TestCase):
|
|||||||
# order of transforms is t1 -> t2
|
# order of transforms is t1 -> t2
|
||||||
matrix = torch.matmul(matrix1, torch.matmul(matrix2, matrix3))
|
matrix = torch.matmul(matrix1, torch.matmul(matrix2, matrix3))
|
||||||
composed_matrix = t.get_matrix()
|
composed_matrix = t.get_matrix()
|
||||||
self.assertTrue(torch.allclose(composed_matrix, matrix, atol=1e-7))
|
self.assertClose(composed_matrix, matrix, atol=1e-7)
|
||||||
|
|
||||||
def test_rotate_angle_radians(self):
|
def test_rotate_angle_radians(self):
|
||||||
t = RotateAxisAngle(angle=math.pi / 2, degrees=False, axis="Z")
|
t = RotateAxisAngle(angle=math.pi / 2, degrees=False, axis="Z")
|
||||||
@@ -1318,7 +1306,7 @@ class TestRotateAxisAngle(unittest.TestCase):
|
|||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
|
self.assertClose(t._matrix, matrix, atol=1e-7)
|
||||||
|
|
||||||
def test_lower_case_axis(self):
|
def test_lower_case_axis(self):
|
||||||
t = RotateAxisAngle(angle=90.0, axis="z")
|
t = RotateAxisAngle(angle=90.0, axis="z")
|
||||||
@@ -1335,7 +1323,7 @@ class TestRotateAxisAngle(unittest.TestCase):
|
|||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
|
self.assertClose(t._matrix, matrix, atol=1e-7)
|
||||||
|
|
||||||
def test_axis_fail(self):
|
def test_axis_fail(self):
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
|
|||||||
Reference in New Issue
Block a user