mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
upgrade pyre version in fbcode/vision
- batch 2
Reviewed By: bottler Differential Revision: D60992234 fbshipit-source-id: 899db6ed590ef966ff651c11027819e59b8401a3
This commit is contained in:
parent
1e0b1d9c72
commit
38afdcfc68
@ -99,7 +99,7 @@ except ModuleNotFoundError:
|
|||||||
no_accelerate = os.environ.get("PYTORCH3D_NO_ACCELERATE") is not None
|
no_accelerate = os.environ.get("PYTORCH3D_NO_ACCELERATE") is not None
|
||||||
|
|
||||||
|
|
||||||
class Experiment(Configurable): # pyre-ignore: 13
|
class Experiment(Configurable):
|
||||||
"""
|
"""
|
||||||
This class is at the top level of Implicitron's config hierarchy. Its
|
This class is at the top level of Implicitron's config hierarchy. Its
|
||||||
members are high-level components necessary for training an implicit rende-
|
members are high-level components necessary for training an implicit rende-
|
||||||
@ -120,12 +120,16 @@ class Experiment(Configurable): # pyre-ignore: 13
|
|||||||
will be saved here.
|
will be saved here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyre-fixme[13]: Attribute `data_source` is never initialized.
|
||||||
data_source: DataSourceBase
|
data_source: DataSourceBase
|
||||||
data_source_class_type: str = "ImplicitronDataSource"
|
data_source_class_type: str = "ImplicitronDataSource"
|
||||||
|
# pyre-fixme[13]: Attribute `model_factory` is never initialized.
|
||||||
model_factory: ModelFactoryBase
|
model_factory: ModelFactoryBase
|
||||||
model_factory_class_type: str = "ImplicitronModelFactory"
|
model_factory_class_type: str = "ImplicitronModelFactory"
|
||||||
|
# pyre-fixme[13]: Attribute `optimizer_factory` is never initialized.
|
||||||
optimizer_factory: OptimizerFactoryBase
|
optimizer_factory: OptimizerFactoryBase
|
||||||
optimizer_factory_class_type: str = "ImplicitronOptimizerFactory"
|
optimizer_factory_class_type: str = "ImplicitronOptimizerFactory"
|
||||||
|
# pyre-fixme[13]: Attribute `training_loop` is never initialized.
|
||||||
training_loop: TrainingLoopBase
|
training_loop: TrainingLoopBase
|
||||||
training_loop_class_type: str = "ImplicitronTrainingLoop"
|
training_loop_class_type: str = "ImplicitronTrainingLoop"
|
||||||
|
|
||||||
|
@ -45,7 +45,7 @@ class ModelFactoryBase(ReplaceableBase):
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class ImplicitronModelFactory(ModelFactoryBase): # pyre-ignore [13]
|
class ImplicitronModelFactory(ModelFactoryBase):
|
||||||
"""
|
"""
|
||||||
A factory class that initializes an implicit rendering model.
|
A factory class that initializes an implicit rendering model.
|
||||||
|
|
||||||
@ -61,6 +61,7 @@ class ImplicitronModelFactory(ModelFactoryBase): # pyre-ignore [13]
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyre-fixme[13]: Attribute `model` is never initialized.
|
||||||
model: ImplicitronModelBase
|
model: ImplicitronModelBase
|
||||||
model_class_type: str = "GenericModel"
|
model_class_type: str = "GenericModel"
|
||||||
resume: bool = True
|
resume: bool = True
|
||||||
|
@ -30,13 +30,13 @@ from .utils import seed_all_random_engines
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# pyre-fixme[13]: Attribute `evaluator` is never initialized.
|
|
||||||
class TrainingLoopBase(ReplaceableBase):
|
class TrainingLoopBase(ReplaceableBase):
|
||||||
"""
|
"""
|
||||||
Members:
|
Members:
|
||||||
evaluator: An EvaluatorBase instance, used to evaluate training results.
|
evaluator: An EvaluatorBase instance, used to evaluate training results.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyre-fixme[13]: Attribute `evaluator` is never initialized.
|
||||||
evaluator: Optional[EvaluatorBase]
|
evaluator: Optional[EvaluatorBase]
|
||||||
evaluator_class_type: Optional[str] = "ImplicitronEvaluator"
|
evaluator_class_type: Optional[str] = "ImplicitronEvaluator"
|
||||||
|
|
||||||
|
@ -41,7 +41,7 @@ class DataSourceBase(ReplaceableBase):
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
|
class ImplicitronDataSource(DataSourceBase):
|
||||||
"""
|
"""
|
||||||
Represents the data used in Implicitron. This is the only implementation
|
Represents the data used in Implicitron. This is the only implementation
|
||||||
of DataSourceBase provided.
|
of DataSourceBase provided.
|
||||||
@ -52,8 +52,11 @@ class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
|
|||||||
data_loader_map_provider_class_type: identifies type for data_loader_map_provider.
|
data_loader_map_provider_class_type: identifies type for data_loader_map_provider.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyre-fixme[13]: Attribute `dataset_map_provider` is never initialized.
|
||||||
dataset_map_provider: DatasetMapProviderBase
|
dataset_map_provider: DatasetMapProviderBase
|
||||||
|
# pyre-fixme[13]: Attribute `dataset_map_provider_class_type` is never initialized.
|
||||||
dataset_map_provider_class_type: str
|
dataset_map_provider_class_type: str
|
||||||
|
# pyre-fixme[13]: Attribute `data_loader_map_provider` is never initialized.
|
||||||
data_loader_map_provider: DataLoaderMapProviderBase
|
data_loader_map_provider: DataLoaderMapProviderBase
|
||||||
data_loader_map_provider_class_type: str = "SequenceDataLoaderMapProvider"
|
data_loader_map_provider_class_type: str = "SequenceDataLoaderMapProvider"
|
||||||
|
|
||||||
|
@ -66,7 +66,7 @@ _NEED_CONTROL: Tuple[str, ...] = (
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
class JsonIndexDatasetMapProvider(DatasetMapProviderBase):
|
||||||
"""
|
"""
|
||||||
Generates the training / validation and testing dataset objects for
|
Generates the training / validation and testing dataset objects for
|
||||||
a dataset laid out on disk like Co3D, with annotations in json files.
|
a dataset laid out on disk like Co3D, with annotations in json files.
|
||||||
@ -95,6 +95,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
path_manager_factory_class_type: The class type of `path_manager_factory`.
|
path_manager_factory_class_type: The class type of `path_manager_factory`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyre-fixme[13]: Attribute `category` is never initialized.
|
||||||
category: str
|
category: str
|
||||||
task_str: str = "singlesequence"
|
task_str: str = "singlesequence"
|
||||||
dataset_root: str = _CO3D_DATASET_ROOT
|
dataset_root: str = _CO3D_DATASET_ROOT
|
||||||
@ -104,8 +105,10 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
test_restrict_sequence_id: int = -1
|
test_restrict_sequence_id: int = -1
|
||||||
assert_single_seq: bool = False
|
assert_single_seq: bool = False
|
||||||
only_test_set: bool = False
|
only_test_set: bool = False
|
||||||
|
# pyre-fixme[13]: Attribute `dataset` is never initialized.
|
||||||
dataset: JsonIndexDataset
|
dataset: JsonIndexDataset
|
||||||
dataset_class_type: str = "JsonIndexDataset"
|
dataset_class_type: str = "JsonIndexDataset"
|
||||||
|
# pyre-fixme[13]: Attribute `path_manager_factory` is never initialized.
|
||||||
path_manager_factory: PathManagerFactory
|
path_manager_factory: PathManagerFactory
|
||||||
path_manager_factory_class_type: str = "PathManagerFactory"
|
path_manager_factory_class_type: str = "PathManagerFactory"
|
||||||
|
|
||||||
|
@ -56,7 +56,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):
|
||||||
"""
|
"""
|
||||||
Generates the training, validation, and testing dataset objects for
|
Generates the training, validation, and testing dataset objects for
|
||||||
a dataset laid out on disk like CO3Dv2, with annotations in gzipped json files.
|
a dataset laid out on disk like CO3Dv2, with annotations in gzipped json files.
|
||||||
@ -171,7 +171,9 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
path_manager_factory_class_type: The class type of `path_manager_factory`.
|
path_manager_factory_class_type: The class type of `path_manager_factory`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyre-fixme[13]: Attribute `category` is never initialized.
|
||||||
category: str
|
category: str
|
||||||
|
# pyre-fixme[13]: Attribute `subset_name` is never initialized.
|
||||||
subset_name: str
|
subset_name: str
|
||||||
dataset_root: str = _CO3DV2_DATASET_ROOT
|
dataset_root: str = _CO3DV2_DATASET_ROOT
|
||||||
|
|
||||||
@ -183,8 +185,10 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
n_known_frames_for_test: int = 0
|
n_known_frames_for_test: int = 0
|
||||||
|
|
||||||
dataset_class_type: str = "JsonIndexDataset"
|
dataset_class_type: str = "JsonIndexDataset"
|
||||||
|
# pyre-fixme[13]: Attribute `dataset` is never initialized.
|
||||||
dataset: JsonIndexDataset
|
dataset: JsonIndexDataset
|
||||||
|
|
||||||
|
# pyre-fixme[13]: Attribute `path_manager_factory` is never initialized.
|
||||||
path_manager_factory: PathManagerFactory
|
path_manager_factory: PathManagerFactory
|
||||||
path_manager_factory_class_type: str = "PathManagerFactory"
|
path_manager_factory_class_type: str = "PathManagerFactory"
|
||||||
|
|
||||||
|
@ -32,7 +32,7 @@ from .utils import DATASET_TYPE_KNOWN
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class RenderedMeshDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
class RenderedMeshDatasetMapProvider(DatasetMapProviderBase):
|
||||||
"""
|
"""
|
||||||
A simple single-scene dataset based on PyTorch3D renders of a mesh.
|
A simple single-scene dataset based on PyTorch3D renders of a mesh.
|
||||||
Provides `num_views` renders of the mesh as train, with no val
|
Provides `num_views` renders of the mesh as train, with no val
|
||||||
@ -76,6 +76,7 @@ class RenderedMeshDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13
|
|||||||
resolution: int = 128
|
resolution: int = 128
|
||||||
use_point_light: bool = True
|
use_point_light: bool = True
|
||||||
gpu_idx: Optional[int] = 0
|
gpu_idx: Optional[int] = 0
|
||||||
|
# pyre-fixme[13]: Attribute `path_manager_factory` is never initialized.
|
||||||
path_manager_factory: PathManagerFactory
|
path_manager_factory: PathManagerFactory
|
||||||
path_manager_factory_class_type: str = "PathManagerFactory"
|
path_manager_factory_class_type: str = "PathManagerFactory"
|
||||||
|
|
||||||
|
@ -83,7 +83,6 @@ class SingleSceneDataset(DatasetBase, Configurable):
|
|||||||
return self.eval_batches
|
return self.eval_batches
|
||||||
|
|
||||||
|
|
||||||
# pyre-fixme[13]: Uninitialized attribute
|
|
||||||
class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
|
class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
|
||||||
"""
|
"""
|
||||||
Base for provider of data for one scene from LLFF or blender datasets.
|
Base for provider of data for one scene from LLFF or blender datasets.
|
||||||
@ -100,8 +99,11 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
|
|||||||
testing frame.
|
testing frame.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyre-fixme[13]: Attribute `base_dir` is never initialized.
|
||||||
base_dir: str
|
base_dir: str
|
||||||
|
# pyre-fixme[13]: Attribute `object_name` is never initialized.
|
||||||
object_name: str
|
object_name: str
|
||||||
|
# pyre-fixme[13]: Attribute `path_manager_factory` is never initialized.
|
||||||
path_manager_factory: PathManagerFactory
|
path_manager_factory: PathManagerFactory
|
||||||
path_manager_factory_class_type: str = "PathManagerFactory"
|
path_manager_factory_class_type: str = "PathManagerFactory"
|
||||||
n_known_frames_for_test: Optional[int] = None
|
n_known_frames_for_test: Optional[int] = None
|
||||||
|
@ -348,6 +348,7 @@ def adjust_camera_to_image_scale_(
|
|||||||
camera: PerspectiveCameras,
|
camera: PerspectiveCameras,
|
||||||
original_size_wh: torch.Tensor,
|
original_size_wh: torch.Tensor,
|
||||||
new_size_wh: torch.LongTensor,
|
new_size_wh: torch.LongTensor,
|
||||||
|
# pyre-fixme[7]: Expected `PerspectiveCameras` but got implicit return value of `None`.
|
||||||
) -> PerspectiveCameras:
|
) -> PerspectiveCameras:
|
||||||
focal_length_px, principal_point_px = _convert_ndc_to_pixels(
|
focal_length_px, principal_point_px = _convert_ndc_to_pixels(
|
||||||
camera.focal_length[0],
|
camera.focal_length[0],
|
||||||
@ -367,7 +368,7 @@ def adjust_camera_to_image_scale_(
|
|||||||
image_size_wh_output,
|
image_size_wh_output,
|
||||||
)
|
)
|
||||||
camera.focal_length = focal_length_scaled[None]
|
camera.focal_length = focal_length_scaled[None]
|
||||||
camera.principal_point = principal_point_scaled[None] # pyre-ignore
|
camera.principal_point = principal_point_scaled[None]
|
||||||
|
|
||||||
|
|
||||||
# NOTE this cache is per-worker; they are implemented as processes.
|
# NOTE this cache is per-worker; they are implemented as processes.
|
||||||
|
@ -65,7 +65,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
class GenericModel(ImplicitronModelBase):
|
||||||
"""
|
"""
|
||||||
GenericModel is a wrapper for the neural implicit
|
GenericModel is a wrapper for the neural implicit
|
||||||
rendering and reconstruction pipeline which consists
|
rendering and reconstruction pipeline which consists
|
||||||
@ -226,34 +226,42 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
|||||||
|
|
||||||
# ---- global encoder settings
|
# ---- global encoder settings
|
||||||
global_encoder_class_type: Optional[str] = None
|
global_encoder_class_type: Optional[str] = None
|
||||||
|
# pyre-fixme[13]: Attribute `global_encoder` is never initialized.
|
||||||
global_encoder: Optional[GlobalEncoderBase]
|
global_encoder: Optional[GlobalEncoderBase]
|
||||||
|
|
||||||
# ---- raysampler
|
# ---- raysampler
|
||||||
raysampler_class_type: str = "AdaptiveRaySampler"
|
raysampler_class_type: str = "AdaptiveRaySampler"
|
||||||
|
# pyre-fixme[13]: Attribute `raysampler` is never initialized.
|
||||||
raysampler: RaySamplerBase
|
raysampler: RaySamplerBase
|
||||||
|
|
||||||
# ---- renderer configs
|
# ---- renderer configs
|
||||||
renderer_class_type: str = "MultiPassEmissionAbsorptionRenderer"
|
renderer_class_type: str = "MultiPassEmissionAbsorptionRenderer"
|
||||||
|
# pyre-fixme[13]: Attribute `renderer` is never initialized.
|
||||||
renderer: BaseRenderer
|
renderer: BaseRenderer
|
||||||
|
|
||||||
# ---- image feature extractor settings
|
# ---- image feature extractor settings
|
||||||
# (This is only created if view_pooler is enabled)
|
# (This is only created if view_pooler is enabled)
|
||||||
|
# pyre-fixme[13]: Attribute `image_feature_extractor` is never initialized.
|
||||||
image_feature_extractor: Optional[FeatureExtractorBase]
|
image_feature_extractor: Optional[FeatureExtractorBase]
|
||||||
image_feature_extractor_class_type: Optional[str] = None
|
image_feature_extractor_class_type: Optional[str] = None
|
||||||
# ---- view pooler settings
|
# ---- view pooler settings
|
||||||
view_pooler_enabled: bool = False
|
view_pooler_enabled: bool = False
|
||||||
|
# pyre-fixme[13]: Attribute `view_pooler` is never initialized.
|
||||||
view_pooler: Optional[ViewPooler]
|
view_pooler: Optional[ViewPooler]
|
||||||
|
|
||||||
# ---- implicit function settings
|
# ---- implicit function settings
|
||||||
implicit_function_class_type: str = "NeuralRadianceFieldImplicitFunction"
|
implicit_function_class_type: str = "NeuralRadianceFieldImplicitFunction"
|
||||||
# This is just a model, never constructed.
|
# This is just a model, never constructed.
|
||||||
# The actual implicit functions live in self._implicit_functions
|
# The actual implicit functions live in self._implicit_functions
|
||||||
|
# pyre-fixme[13]: Attribute `implicit_function` is never initialized.
|
||||||
implicit_function: ImplicitFunctionBase
|
implicit_function: ImplicitFunctionBase
|
||||||
|
|
||||||
# ----- metrics
|
# ----- metrics
|
||||||
|
# pyre-fixme[13]: Attribute `view_metrics` is never initialized.
|
||||||
view_metrics: ViewMetricsBase
|
view_metrics: ViewMetricsBase
|
||||||
view_metrics_class_type: str = "ViewMetrics"
|
view_metrics_class_type: str = "ViewMetrics"
|
||||||
|
|
||||||
|
# pyre-fixme[13]: Attribute `regularization_metrics` is never initialized.
|
||||||
regularization_metrics: RegularizationMetricsBase
|
regularization_metrics: RegularizationMetricsBase
|
||||||
regularization_metrics_class_type: str = "RegularizationMetrics"
|
regularization_metrics_class_type: str = "RegularizationMetrics"
|
||||||
|
|
||||||
|
@ -59,12 +59,13 @@ class GlobalEncoderBase(ReplaceableBase):
|
|||||||
|
|
||||||
# TODO: probabilistic embeddings?
|
# TODO: probabilistic embeddings?
|
||||||
@registry.register
|
@registry.register
|
||||||
class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module): # pyre-ignore: 13
|
class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
A global encoder implementation which provides an autodecoder encoding
|
A global encoder implementation which provides an autodecoder encoding
|
||||||
of the frame's sequence identifier.
|
of the frame's sequence identifier.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyre-fixme[13]: Attribute `autodecoder` is never initialized.
|
||||||
autodecoder: Autodecoder
|
autodecoder: Autodecoder
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
@ -244,7 +244,6 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
# pyre-fixme[13]: Attribute `network` is never initialized.
|
|
||||||
class MLPDecoder(DecoderFunctionBase):
|
class MLPDecoder(DecoderFunctionBase):
|
||||||
"""
|
"""
|
||||||
Decoding function which uses `MLPWithIputSkips` to convert the embedding to output.
|
Decoding function which uses `MLPWithIputSkips` to convert the embedding to output.
|
||||||
@ -272,6 +271,7 @@ class MLPDecoder(DecoderFunctionBase):
|
|||||||
|
|
||||||
input_dim: int = 3
|
input_dim: int = 3
|
||||||
param_groups: Dict[str, str] = field(default_factory=lambda: {})
|
param_groups: Dict[str, str] = field(default_factory=lambda: {})
|
||||||
|
# pyre-fixme[13]: Attribute `network` is never initialized.
|
||||||
network: MLPWithInputSkips
|
network: MLPWithInputSkips
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
@ -318,10 +318,11 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
# pyre-fixme[13]: Uninitialized attribute
|
|
||||||
class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||||
latent_dim: int = 0
|
latent_dim: int = 0
|
||||||
|
# pyre-fixme[13]: Attribute `raymarch_function` is never initialized.
|
||||||
raymarch_function: SRNRaymarchFunction
|
raymarch_function: SRNRaymarchFunction
|
||||||
|
# pyre-fixme[13]: Attribute `pixel_generator` is never initialized.
|
||||||
pixel_generator: SRNPixelGenerator
|
pixel_generator: SRNPixelGenerator
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@ -366,7 +367,6 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
# pyre-fixme[13]: Uninitialized attribute
|
|
||||||
class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
This implicit function uses a hypernetwork to generate the
|
This implicit function uses a hypernetwork to generate the
|
||||||
@ -377,7 +377,9 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
|
|
||||||
latent_dim_hypernet: int = 0
|
latent_dim_hypernet: int = 0
|
||||||
latent_dim: int = 0
|
latent_dim: int = 0
|
||||||
|
# pyre-fixme[13]: Attribute `hypernet` is never initialized.
|
||||||
hypernet: SRNRaymarchHyperNet
|
hypernet: SRNRaymarchHyperNet
|
||||||
|
# pyre-fixme[13]: Attribute `pixel_generator` is never initialized.
|
||||||
pixel_generator: SRNPixelGenerator
|
pixel_generator: SRNPixelGenerator
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
@ -805,7 +805,6 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# pyre-fixme[13]: Attribute `voxel_grid` is never initialized.
|
|
||||||
class VoxelGridModule(Configurable, torch.nn.Module):
|
class VoxelGridModule(Configurable, torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
A wrapper torch.nn.Module for the VoxelGrid classes, which
|
A wrapper torch.nn.Module for the VoxelGrid classes, which
|
||||||
@ -845,6 +844,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
voxel_grid_class_type: str = "FullResolutionVoxelGrid"
|
voxel_grid_class_type: str = "FullResolutionVoxelGrid"
|
||||||
|
# pyre-fixme[13]: Attribute `voxel_grid` is never initialized.
|
||||||
voxel_grid: VoxelGridBase
|
voxel_grid: VoxelGridBase
|
||||||
|
|
||||||
extents: Tuple[float, float, float] = (2.0, 2.0, 2.0)
|
extents: Tuple[float, float, float] = (2.0, 2.0, 2.0)
|
||||||
|
@ -39,7 +39,6 @@ enable_get_default_args(HarmonicEmbedding)
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
# pyre-ignore[13]
|
|
||||||
class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
This implicit function consists of two streams, one for the density calculation and one
|
This implicit function consists of two streams, one for the density calculation and one
|
||||||
@ -145,9 +144,11 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# ---- voxel grid for density
|
# ---- voxel grid for density
|
||||||
|
# pyre-fixme[13]: Attribute `voxel_grid_density` is never initialized.
|
||||||
voxel_grid_density: VoxelGridModule
|
voxel_grid_density: VoxelGridModule
|
||||||
|
|
||||||
# ---- voxel grid for color
|
# ---- voxel grid for color
|
||||||
|
# pyre-fixme[13]: Attribute `voxel_grid_color` is never initialized.
|
||||||
voxel_grid_color: VoxelGridModule
|
voxel_grid_color: VoxelGridModule
|
||||||
|
|
||||||
# ---- harmonic embeddings density
|
# ---- harmonic embeddings density
|
||||||
@ -163,10 +164,12 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
|
|
||||||
# ---- decoder function for density
|
# ---- decoder function for density
|
||||||
decoder_density_class_type: str = "MLPDecoder"
|
decoder_density_class_type: str = "MLPDecoder"
|
||||||
|
# pyre-fixme[13]: Attribute `decoder_density` is never initialized.
|
||||||
decoder_density: DecoderFunctionBase
|
decoder_density: DecoderFunctionBase
|
||||||
|
|
||||||
# ---- decoder function for color
|
# ---- decoder function for color
|
||||||
decoder_color_class_type: str = "MLPDecoder"
|
decoder_color_class_type: str = "MLPDecoder"
|
||||||
|
# pyre-fixme[13]: Attribute `decoder_color` is never initialized.
|
||||||
decoder_color: DecoderFunctionBase
|
decoder_color: DecoderFunctionBase
|
||||||
|
|
||||||
# ---- cuda streams
|
# ---- cuda streams
|
||||||
|
@ -69,7 +69,7 @@ IMPLICIT_FUNCTION_ARGS_TO_REMOVE: List[str] = [
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class OverfitModel(ImplicitronModelBase): # pyre-ignore: 13
|
class OverfitModel(ImplicitronModelBase):
|
||||||
"""
|
"""
|
||||||
OverfitModel is a wrapper for the neural implicit
|
OverfitModel is a wrapper for the neural implicit
|
||||||
rendering and reconstruction pipeline which consists
|
rendering and reconstruction pipeline which consists
|
||||||
@ -198,27 +198,34 @@ class OverfitModel(ImplicitronModelBase): # pyre-ignore: 13
|
|||||||
|
|
||||||
# ---- global encoder settings
|
# ---- global encoder settings
|
||||||
global_encoder_class_type: Optional[str] = None
|
global_encoder_class_type: Optional[str] = None
|
||||||
|
# pyre-fixme[13]: Attribute `global_encoder` is never initialized.
|
||||||
global_encoder: Optional[GlobalEncoderBase]
|
global_encoder: Optional[GlobalEncoderBase]
|
||||||
|
|
||||||
# ---- raysampler
|
# ---- raysampler
|
||||||
raysampler_class_type: str = "AdaptiveRaySampler"
|
raysampler_class_type: str = "AdaptiveRaySampler"
|
||||||
|
# pyre-fixme[13]: Attribute `raysampler` is never initialized.
|
||||||
raysampler: RaySamplerBase
|
raysampler: RaySamplerBase
|
||||||
|
|
||||||
# ---- renderer configs
|
# ---- renderer configs
|
||||||
renderer_class_type: str = "MultiPassEmissionAbsorptionRenderer"
|
renderer_class_type: str = "MultiPassEmissionAbsorptionRenderer"
|
||||||
|
# pyre-fixme[13]: Attribute `renderer` is never initialized.
|
||||||
renderer: BaseRenderer
|
renderer: BaseRenderer
|
||||||
|
|
||||||
# ---- implicit function settings
|
# ---- implicit function settings
|
||||||
share_implicit_function_across_passes: bool = False
|
share_implicit_function_across_passes: bool = False
|
||||||
implicit_function_class_type: str = "NeuralRadianceFieldImplicitFunction"
|
implicit_function_class_type: str = "NeuralRadianceFieldImplicitFunction"
|
||||||
|
# pyre-fixme[13]: Attribute `implicit_function` is never initialized.
|
||||||
implicit_function: ImplicitFunctionBase
|
implicit_function: ImplicitFunctionBase
|
||||||
coarse_implicit_function_class_type: Optional[str] = None
|
coarse_implicit_function_class_type: Optional[str] = None
|
||||||
|
# pyre-fixme[13]: Attribute `coarse_implicit_function` is never initialized.
|
||||||
coarse_implicit_function: Optional[ImplicitFunctionBase]
|
coarse_implicit_function: Optional[ImplicitFunctionBase]
|
||||||
|
|
||||||
# ----- metrics
|
# ----- metrics
|
||||||
|
# pyre-fixme[13]: Attribute `view_metrics` is never initialized.
|
||||||
view_metrics: ViewMetricsBase
|
view_metrics: ViewMetricsBase
|
||||||
view_metrics_class_type: str = "ViewMetrics"
|
view_metrics_class_type: str = "ViewMetrics"
|
||||||
|
|
||||||
|
# pyre-fixme[13]: Attribute `regularization_metrics` is never initialized.
|
||||||
regularization_metrics: RegularizationMetricsBase
|
regularization_metrics: RegularizationMetricsBase
|
||||||
regularization_metrics_class_type: str = "RegularizationMetrics"
|
regularization_metrics_class_type: str = "RegularizationMetrics"
|
||||||
|
|
||||||
|
@ -18,9 +18,7 @@ from .raymarcher import RaymarcherBase
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
|
class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
|
||||||
BaseRenderer, torch.nn.Module
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Implements the multi-pass rendering function, in particular,
|
Implements the multi-pass rendering function, in particular,
|
||||||
with emission-absorption ray marching used in NeRF [1]. First, it evaluates
|
with emission-absorption ray marching used in NeRF [1]. First, it evaluates
|
||||||
@ -86,6 +84,7 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
raymarcher_class_type: str = "EmissionAbsorptionRaymarcher"
|
raymarcher_class_type: str = "EmissionAbsorptionRaymarcher"
|
||||||
|
# pyre-fixme[13]: Attribute `raymarcher` is never initialized.
|
||||||
raymarcher: RaymarcherBase
|
raymarcher: RaymarcherBase
|
||||||
|
|
||||||
n_pts_per_ray_fine_training: int = 64
|
n_pts_per_ray_fine_training: int = 64
|
||||||
|
@ -16,8 +16,6 @@ from pytorch3d.renderer.implicit.sample_pdf import sample_pdf
|
|||||||
|
|
||||||
|
|
||||||
@expand_args_fields
|
@expand_args_fields
|
||||||
# pyre-fixme[13]: Attribute `n_pts_per_ray` is never initialized.
|
|
||||||
# pyre-fixme[13]: Attribute `random_sampling` is never initialized.
|
|
||||||
class RayPointRefiner(Configurable, torch.nn.Module):
|
class RayPointRefiner(Configurable, torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
Implements the importance sampling of points along rays.
|
Implements the importance sampling of points along rays.
|
||||||
@ -45,7 +43,9 @@ class RayPointRefiner(Configurable, torch.nn.Module):
|
|||||||
for Anti-Aliasing Neural Radiance Fields." ICCV 2021.
|
for Anti-Aliasing Neural Radiance Fields." ICCV 2021.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyre-fixme[13]: Attribute `n_pts_per_ray` is never initialized.
|
||||||
n_pts_per_ray: int
|
n_pts_per_ray: int
|
||||||
|
# pyre-fixme[13]: Attribute `random_sampling` is never initialized.
|
||||||
random_sampling: bool
|
random_sampling: bool
|
||||||
add_input_samples: bool = True
|
add_input_samples: bool = True
|
||||||
blurpool_weights: bool = False
|
blurpool_weights: bool = False
|
||||||
|
@ -24,9 +24,10 @@ from .rgb_net import RayNormalColoringNetwork
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ignore[13]
|
class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
||||||
render_features_dimensions: int = 3
|
render_features_dimensions: int = 3
|
||||||
object_bounding_sphere: float = 1.0
|
object_bounding_sphere: float = 1.0
|
||||||
|
# pyre-fixme[13]: Attribute `ray_tracer` is never initialized.
|
||||||
ray_tracer: RayTracing
|
ray_tracer: RayTracing
|
||||||
ray_normal_coloring_network_args: DictConfig = get_default_args_field(
|
ray_normal_coloring_network_args: DictConfig = get_default_args_field(
|
||||||
RayNormalColoringNetwork
|
RayNormalColoringNetwork
|
||||||
|
@ -16,7 +16,6 @@ from .feature_aggregator import FeatureAggregatorBase
|
|||||||
from .view_sampler import ViewSampler
|
from .view_sampler import ViewSampler
|
||||||
|
|
||||||
|
|
||||||
# pyre-ignore: 13
|
|
||||||
class ViewPooler(Configurable, torch.nn.Module):
|
class ViewPooler(Configurable, torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
Implements sampling of image-based features at the 2d projections of a set
|
Implements sampling of image-based features at the 2d projections of a set
|
||||||
@ -35,8 +34,10 @@ class ViewPooler(Configurable, torch.nn.Module):
|
|||||||
from a set of source images. FeatureAggregator executes step (4) above.
|
from a set of source images. FeatureAggregator executes step (4) above.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyre-fixme[13]: Attribute `view_sampler` is never initialized.
|
||||||
view_sampler: ViewSampler
|
view_sampler: ViewSampler
|
||||||
feature_aggregator_class_type: str = "AngleWeightedReductionFeatureAggregator"
|
feature_aggregator_class_type: str = "AngleWeightedReductionFeatureAggregator"
|
||||||
|
# pyre-fixme[13]: Attribute `feature_aggregator` is never initialized.
|
||||||
feature_aggregator: FeatureAggregatorBase
|
feature_aggregator: FeatureAggregatorBase
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
@ -156,7 +156,6 @@ def render_point_cloud_pytorch3d(
|
|||||||
cumprod = torch.cat((torch.ones_like(cumprod[..., :1]), cumprod[..., :-1]), dim=-1)
|
cumprod = torch.cat((torch.ones_like(cumprod[..., :1]), cumprod[..., :-1]), dim=-1)
|
||||||
depths = (weights * cumprod * fragments.zbuf).sum(dim=-1)
|
depths = (weights * cumprod * fragments.zbuf).sum(dim=-1)
|
||||||
# add the rendering mask
|
# add the rendering mask
|
||||||
# pyre-fixme[6]: For 1st param expected `Tensor` but got `float`.
|
|
||||||
render_mask = -torch.prod(1.0 - weights, dim=-1) + 1.0
|
render_mask = -torch.prod(1.0 - weights, dim=-1) + 1.0
|
||||||
|
|
||||||
# cat depths and render mask
|
# cat depths and render mask
|
||||||
|
@ -409,6 +409,7 @@ def _parse_mtl(
|
|||||||
texture_files = {}
|
texture_files = {}
|
||||||
material_name = ""
|
material_name = ""
|
||||||
|
|
||||||
|
# pyre-fixme[9]: f has type `str`; used as `IO[typing.Any]`.
|
||||||
with _open_file(f, path_manager, "r") as f:
|
with _open_file(f, path_manager, "r") as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
tokens = line.strip().split()
|
tokens = line.strip().split()
|
||||||
|
@ -756,10 +756,13 @@ def save_obj(
|
|||||||
output_path = Path(f)
|
output_path = Path(f)
|
||||||
|
|
||||||
# Save the .obj file
|
# Save the .obj file
|
||||||
|
# pyre-fixme[9]: f has type `Union[Path, str]`; used as `IO[typing.Any]`.
|
||||||
with _open_file(f, path_manager, "w") as f:
|
with _open_file(f, path_manager, "w") as f:
|
||||||
if save_texture:
|
if save_texture:
|
||||||
# Add the header required for the texture info to be loaded correctly
|
# Add the header required for the texture info to be loaded correctly
|
||||||
obj_header = "\nmtllib {0}.mtl\nusemtl mesh\n\n".format(output_path.stem)
|
obj_header = "\nmtllib {0}.mtl\nusemtl mesh\n\n".format(output_path.stem)
|
||||||
|
# pyre-fixme[16]: Item `Path` of `Union[Path, str]` has no attribute
|
||||||
|
# `write`.
|
||||||
f.write(obj_header)
|
f.write(obj_header)
|
||||||
_save(
|
_save(
|
||||||
f,
|
f,
|
||||||
|
@ -617,6 +617,7 @@ def _splat_points_to_volumes(
|
|||||||
w = wX * wY * wZ
|
w = wX * wY * wZ
|
||||||
|
|
||||||
# valid - binary indicators of votes that fall into the volume
|
# valid - binary indicators of votes that fall into the volume
|
||||||
|
# pyre-fixme[16]: `int` has no attribute `long`.
|
||||||
valid = (
|
valid = (
|
||||||
(0 <= X_)
|
(0 <= X_)
|
||||||
* (X_ < grid_sizes_xyz[:, None, 0:1])
|
* (X_ < grid_sizes_xyz[:, None, 0:1])
|
||||||
@ -635,14 +636,19 @@ def _splat_points_to_volumes(
|
|||||||
idx_valid = idx * valid + rand_idx * (1 - valid)
|
idx_valid = idx * valid + rand_idx * (1 - valid)
|
||||||
w_valid = w * valid.type_as(w)
|
w_valid = w * valid.type_as(w)
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
|
# pyre-fixme[6]: For 1st argument expected `Tensor` but got `int`.
|
||||||
w_valid = w_valid * mask.type_as(w)[:, :, None]
|
w_valid = w_valid * mask.type_as(w)[:, :, None]
|
||||||
|
|
||||||
# scatter add casts the votes into the weight accumulator
|
# scatter add casts the votes into the weight accumulator
|
||||||
# and the feature accumulator
|
# and the feature accumulator
|
||||||
|
# pyre-fixme[6]: For 3rd argument expected `Tensor` but got
|
||||||
|
# `Union[int, Tensor]`.
|
||||||
volume_densities.scatter_add_(1, idx_valid, w_valid)
|
volume_densities.scatter_add_(1, idx_valid, w_valid)
|
||||||
|
|
||||||
# reshape idx_valid -> (minibatch, feature_dim, n_points)
|
# reshape idx_valid -> (minibatch, feature_dim, n_points)
|
||||||
idx_valid = idx_valid.view(ba, 1, n_points).expand_as(points_features)
|
idx_valid = idx_valid.view(ba, 1, n_points).expand_as(points_features)
|
||||||
|
# pyre-fixme[16]: Item `int` of `Union[int, Tensor]` has no
|
||||||
|
# attribute `view`.
|
||||||
w_valid = w_valid.view(ba, 1, n_points)
|
w_valid = w_valid.view(ba, 1, n_points)
|
||||||
|
|
||||||
# volume_features of shape (minibatch, feature_dim, n_voxels)
|
# volume_features of shape (minibatch, feature_dim, n_voxels)
|
||||||
@ -724,6 +730,7 @@ def _round_points_to_volumes(
|
|||||||
# valid - binary indicators of votes that fall into the volume
|
# valid - binary indicators of votes that fall into the volume
|
||||||
# pyre-fixme[9]: grid_sizes has type `LongTensor`; used as `Tensor`.
|
# pyre-fixme[9]: grid_sizes has type `LongTensor`; used as `Tensor`.
|
||||||
grid_sizes = grid_sizes.type_as(XYZ)
|
grid_sizes = grid_sizes.type_as(XYZ)
|
||||||
|
# pyre-fixme[16]: `int` has no attribute `long`.
|
||||||
valid = (
|
valid = (
|
||||||
(0 <= X)
|
(0 <= X)
|
||||||
* (X < grid_sizes_xyz[:, None, 0:1])
|
* (X < grid_sizes_xyz[:, None, 0:1])
|
||||||
|
@ -497,6 +497,7 @@ def clip_faces(
|
|||||||
faces_case3 = face_verts_unclipped[case3_unclipped_idx]
|
faces_case3 = face_verts_unclipped[case3_unclipped_idx]
|
||||||
|
|
||||||
# index (0, 1, or 2) of the vertex in front of the clipping plane
|
# index (0, 1, or 2) of the vertex in front of the clipping plane
|
||||||
|
# pyre-fixme[61]: `faces_clipped_verts` is undefined, or not always defined.
|
||||||
p1_face_ind = torch.where(~faces_clipped_verts[case3_unclipped_idx])[1]
|
p1_face_ind = torch.where(~faces_clipped_verts[case3_unclipped_idx])[1]
|
||||||
|
|
||||||
# Solve for the points p4, p5 that intersect the clipping plane
|
# Solve for the points p4, p5 that intersect the clipping plane
|
||||||
@ -540,6 +541,7 @@ def clip_faces(
|
|||||||
faces_case4 = face_verts_unclipped[case4_unclipped_idx]
|
faces_case4 = face_verts_unclipped[case4_unclipped_idx]
|
||||||
|
|
||||||
# index (0, 1, or 2) of the vertex behind the clipping plane
|
# index (0, 1, or 2) of the vertex behind the clipping plane
|
||||||
|
# pyre-fixme[61]: `faces_clipped_verts` is undefined, or not always defined.
|
||||||
p1_face_ind = torch.where(faces_clipped_verts[case4_unclipped_idx])[1]
|
p1_face_ind = torch.where(faces_clipped_verts[case4_unclipped_idx])[1]
|
||||||
|
|
||||||
# Solve for the points p4, p5 that intersect the clipping plane
|
# Solve for the points p4, p5 that intersect the clipping plane
|
||||||
|
@ -369,6 +369,7 @@ def plot_scene(
|
|||||||
# update camera viewpoint if provided
|
# update camera viewpoint if provided
|
||||||
if viewpoints_eye_at_up_world is not None:
|
if viewpoints_eye_at_up_world is not None:
|
||||||
# Use camera params for batch index or the first camera if only one provided.
|
# Use camera params for batch index or the first camera if only one provided.
|
||||||
|
# pyre-fixme[61]: `n_viewpoint_cameras` is undefined, or not always defined.
|
||||||
viewpoint_idx = min(n_viewpoint_cameras - 1, subplot_idx)
|
viewpoint_idx = min(n_viewpoint_cameras - 1, subplot_idx)
|
||||||
|
|
||||||
eye, at, up = (i[viewpoint_idx] for i in viewpoints_eye_at_up_world)
|
eye, at, up = (i[viewpoint_idx] for i in viewpoints_eye_at_up_world)
|
||||||
@ -627,7 +628,7 @@ def _add_struct_from_batch(
|
|||||||
|
|
||||||
|
|
||||||
def _add_mesh_trace(
|
def _add_mesh_trace(
|
||||||
fig: go.Figure, # pyre-ignore[11]
|
fig: go.Figure,
|
||||||
meshes: Meshes,
|
meshes: Meshes,
|
||||||
trace_name: str,
|
trace_name: str,
|
||||||
subplot_idx: int,
|
subplot_idx: int,
|
||||||
@ -673,6 +674,7 @@ def _add_mesh_trace(
|
|||||||
verts[~verts_used] = verts_center
|
verts[~verts_used] = verts_center
|
||||||
|
|
||||||
row, col = subplot_idx // ncols + 1, subplot_idx % ncols + 1
|
row, col = subplot_idx // ncols + 1, subplot_idx % ncols + 1
|
||||||
|
# pyre-fixme[16]: `Figure` has no attribute `add_trace`.
|
||||||
fig.add_trace(
|
fig.add_trace(
|
||||||
go.Mesh3d(
|
go.Mesh3d(
|
||||||
x=verts[:, 0],
|
x=verts[:, 0],
|
||||||
@ -739,6 +741,7 @@ def _add_pointcloud_trace(
|
|||||||
|
|
||||||
row = subplot_idx // ncols + 1
|
row = subplot_idx // ncols + 1
|
||||||
col = subplot_idx % ncols + 1
|
col = subplot_idx % ncols + 1
|
||||||
|
# pyre-fixme[16]: `Figure` has no attribute `add_trace`.
|
||||||
fig.add_trace(
|
fig.add_trace(
|
||||||
go.Scatter3d(
|
go.Scatter3d(
|
||||||
x=verts[:, 0],
|
x=verts[:, 0],
|
||||||
@ -800,6 +803,7 @@ def _add_camera_trace(
|
|||||||
x, y, z = all_cam_wires.detach().cpu().numpy().T.astype(float)
|
x, y, z = all_cam_wires.detach().cpu().numpy().T.astype(float)
|
||||||
|
|
||||||
row, col = subplot_idx // ncols + 1, subplot_idx % ncols + 1
|
row, col = subplot_idx // ncols + 1, subplot_idx % ncols + 1
|
||||||
|
# pyre-fixme[16]: `Figure` has no attribute `add_trace`.
|
||||||
fig.add_trace(
|
fig.add_trace(
|
||||||
go.Scatter3d(x=x, y=y, z=z, marker={"size": 1}, name=trace_name),
|
go.Scatter3d(x=x, y=y, z=z, marker={"size": 1}, name=trace_name),
|
||||||
row=row,
|
row=row,
|
||||||
@ -894,6 +898,7 @@ def _add_ray_bundle_trace(
|
|||||||
ray_lines = torch.cat((ray_lines, nan_tensor, ray_line))
|
ray_lines = torch.cat((ray_lines, nan_tensor, ray_line))
|
||||||
x, y, z = ray_lines.detach().cpu().numpy().T.astype(float)
|
x, y, z = ray_lines.detach().cpu().numpy().T.astype(float)
|
||||||
row, col = subplot_idx // ncols + 1, subplot_idx % ncols + 1
|
row, col = subplot_idx // ncols + 1, subplot_idx % ncols + 1
|
||||||
|
# pyre-fixme[16]: `Figure` has no attribute `add_trace`.
|
||||||
fig.add_trace(
|
fig.add_trace(
|
||||||
go.Scatter3d(
|
go.Scatter3d(
|
||||||
x=x,
|
x=x,
|
||||||
@ -988,7 +993,7 @@ def _gen_fig_with_subplots(
|
|||||||
def _update_axes_bounds(
|
def _update_axes_bounds(
|
||||||
verts_center: torch.Tensor,
|
verts_center: torch.Tensor,
|
||||||
max_expand: float,
|
max_expand: float,
|
||||||
current_layout: go.Scene, # pyre-ignore[11]
|
current_layout: go.Scene,
|
||||||
) -> None: # pragma: no cover
|
) -> None: # pragma: no cover
|
||||||
"""
|
"""
|
||||||
Takes in the vertices' center point and max spread, and the current plotly figure
|
Takes in the vertices' center point and max spread, and the current plotly figure
|
||||||
@ -1005,6 +1010,7 @@ def _update_axes_bounds(
|
|||||||
|
|
||||||
# Ensure that within a subplot, the bounds capture all traces
|
# Ensure that within a subplot, the bounds capture all traces
|
||||||
old_xrange, old_yrange, old_zrange = (
|
old_xrange, old_yrange, old_zrange = (
|
||||||
|
# pyre-fixme[16]: `Scene` has no attribute `__getitem__`.
|
||||||
current_layout["xaxis"]["range"],
|
current_layout["xaxis"]["range"],
|
||||||
current_layout["yaxis"]["range"],
|
current_layout["yaxis"]["range"],
|
||||||
current_layout["zaxis"]["range"],
|
current_layout["zaxis"]["range"],
|
||||||
@ -1023,6 +1029,7 @@ def _update_axes_bounds(
|
|||||||
xaxis = {"range": x_range}
|
xaxis = {"range": x_range}
|
||||||
yaxis = {"range": y_range}
|
yaxis = {"range": y_range}
|
||||||
zaxis = {"range": z_range}
|
zaxis = {"range": z_range}
|
||||||
|
# pyre-fixme[16]: `Scene` has no attribute `update`.
|
||||||
current_layout.update({"xaxis": xaxis, "yaxis": yaxis, "zaxis": zaxis})
|
current_layout.update({"xaxis": xaxis, "yaxis": yaxis, "zaxis": zaxis})
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user