mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-06-17 12:38:53 +08:00
Enable Pyrefly in fbcode/vision/fair
Summary: Automated migration to enable Pyrefly type checking for `fbcode/vision/fair`. - Added `python.set_pyrefly(True)` to PACKAGE file - Suppressed pre-existing type errors Pyrefly is Meta's next-generation Python type checker, replacing Pyre. If you encounter issues, you can revert the PACKAGE change by removing the `python.set_pyrefly(True)` line. #pyreupgrade Differential Revision: D107142434 fbshipit-source-id: 25929bb3d5a310d00dab11a46c5395df94357feb
This commit is contained in:
committed by
meta-codesync[bot]
parent
b73d735ecf
commit
05025bf005
@@ -193,6 +193,7 @@ class Experiment(Configurable):
|
|||||||
last_epoch=start_epoch,
|
last_epoch=start_epoch,
|
||||||
model=model,
|
model=model,
|
||||||
resume=self.model_factory.resume,
|
resume=self.model_factory.resume,
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
resume_epoch=self.model_factory.resume_epoch,
|
resume_epoch=self.model_factory.resume_epoch,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -212,6 +213,7 @@ class Experiment(Configurable):
|
|||||||
|
|
||||||
# Enter the main training loop.
|
# Enter the main training loop.
|
||||||
self.training_loop.run(
|
self.training_loop.run(
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
train_loader=train_loader,
|
train_loader=train_loader,
|
||||||
val_loader=val_loader,
|
val_loader=val_loader,
|
||||||
test_loader=test_loader,
|
test_loader=test_loader,
|
||||||
|
|||||||
@@ -173,6 +173,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
|
|||||||
scheduler = torch.optim.lr_scheduler.LambdaLR(
|
scheduler = torch.optim.lr_scheduler.LambdaLR(
|
||||||
optimizer,
|
optimizer,
|
||||||
lambda epoch: self.gamma ** (epoch / self.exponential_lr_step_size),
|
lambda epoch: self.gamma ** (epoch / self.exponential_lr_step_size),
|
||||||
|
# pyrefly: ignore [unexpected-keyword]
|
||||||
verbose=False,
|
verbose=False,
|
||||||
)
|
)
|
||||||
elif self.lr_policy.casefold() == "LinearExponential".casefold():
|
elif self.lr_policy.casefold() == "LinearExponential".casefold():
|
||||||
@@ -191,7 +192,11 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
|
|||||||
|
|
||||||
# pyre-fixme[28]: Unexpected keyword argument `verbose`.
|
# pyre-fixme[28]: Unexpected keyword argument `verbose`.
|
||||||
scheduler = torch.optim.lr_scheduler.LambdaLR(
|
scheduler = torch.optim.lr_scheduler.LambdaLR(
|
||||||
optimizer, _get_lr, verbose=False
|
# pyrefly: ignore [unexpected-keyword]
|
||||||
|
optimizer,
|
||||||
|
_get_lr,
|
||||||
|
# pyrefly: ignore [unexpected-keyword]
|
||||||
|
verbose=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("no such lr policy %s" % self.lr_policy)
|
raise ValueError("no such lr policy %s" % self.lr_policy)
|
||||||
|
|||||||
@@ -199,6 +199,7 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
|
|||||||
and self.test_interval > 0
|
and self.test_interval > 0
|
||||||
and epoch % self.test_interval == 0
|
and epoch % self.test_interval == 0
|
||||||
):
|
):
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
self.evaluator.run(
|
self.evaluator.run(
|
||||||
device=device,
|
device=device,
|
||||||
dataloader=test_loader,
|
dataloader=test_loader,
|
||||||
@@ -215,6 +216,7 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
|
|||||||
|
|
||||||
if self.test_when_finished:
|
if self.test_when_finished:
|
||||||
if test_loader is not None:
|
if test_loader is not None:
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
self.evaluator.run(
|
self.evaluator.run(
|
||||||
device=device,
|
device=device,
|
||||||
dump_to_json=True,
|
dump_to_json=True,
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import torch
|
|||||||
|
|
||||||
|
|
||||||
def seed_all_random_engines(seed: int) -> None:
|
def seed_all_random_engines(seed: int) -> None:
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
|
|||||||
@@ -75,6 +75,7 @@ def visualize_reconstruction(
|
|||||||
|
|
||||||
# Set the random seeds
|
# Set the random seeds
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
|
|
||||||
# Get the config from the experiment_directory,
|
# Get the config from the experiment_directory,
|
||||||
@@ -135,6 +136,7 @@ def visualize_reconstruction(
|
|||||||
"device": device,
|
"device": device,
|
||||||
**render_flyaround_kwargs,
|
**render_flyaround_kwargs,
|
||||||
}
|
}
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
render_flyaround(**render_kwargs)
|
render_flyaround(**render_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -90,6 +90,7 @@ class _SymEig3x3(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
p = torch.sqrt(p2 / 6.0)
|
p = torch.sqrt(p2 / 6.0)
|
||||||
|
# pyrefly: ignore [unsupported-operation]
|
||||||
B = (inputs - q[..., None, None] * self._identity) / p[..., None, None]
|
B = (inputs - q[..., None, None] * self._identity) / p[..., None, None]
|
||||||
|
|
||||||
r = torch.det(B) / 2.0
|
r = torch.det(B) / 2.0
|
||||||
@@ -174,8 +175,10 @@ class _SymEig3x3(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Find the eigenvector corresponding to alpha0, its eigenvalue is distinct
|
# Find the eigenvector corresponding to alpha0, its eigenvalue is distinct
|
||||||
|
# pyrefly: ignore [unsupported-operation]
|
||||||
ev0 = self._get_ev0(inputs - alpha0[..., None, None] * self._identity)
|
ev0 = self._get_ev0(inputs - alpha0[..., None, None] * self._identity)
|
||||||
u, v = self._get_uv(ev0)
|
u, v = self._get_uv(ev0)
|
||||||
|
# pyrefly: ignore [unsupported-operation]
|
||||||
ev1 = self._get_ev1(inputs - alpha1[..., None, None] * self._identity, u, v)
|
ev1 = self._get_ev1(inputs - alpha1[..., None, None] * self._identity, u, v)
|
||||||
# Third eigenvector is computed as the cross-product of the other two
|
# Third eigenvector is computed as the cross-product of the other two
|
||||||
ev2 = torch.cross(ev0, ev1, dim=-1)
|
ev2 = torch.cross(ev0, ev1, dim=-1)
|
||||||
@@ -250,6 +253,7 @@ class _SymEig3x3(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
min_idx = w.abs().argmin(dim=-1)
|
min_idx = w.abs().argmin(dim=-1)
|
||||||
|
# pyrefly: ignore [bad-index]
|
||||||
rotation_2d = self._rotations_3d[min_idx].to(w)
|
rotation_2d = self._rotations_3d[min_idx].to(w)
|
||||||
|
|
||||||
u = F.normalize((rotation_2d @ w[..., None])[..., 0], dim=-1)
|
u = F.normalize((rotation_2d @ w[..., None])[..., 0], dim=-1)
|
||||||
|
|||||||
@@ -200,6 +200,7 @@ class R2N2(ShapeNetBase): # pragma: no cover
|
|||||||
) % (shapenet_dir, ", ".join(synset_not_present))
|
) % (shapenet_dir, ", ".join(synset_not_present))
|
||||||
warnings.warn(msg)
|
warnings.warn(msg)
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def __getitem__(self, model_idx, view_idxs: Optional[List[int]] = None) -> Dict:
|
def __getitem__(self, model_idx, view_idxs: Optional[List[int]] = None) -> Dict:
|
||||||
"""
|
"""
|
||||||
Read a model by the given index.
|
Read a model by the given index.
|
||||||
@@ -370,6 +371,7 @@ class R2N2(ShapeNetBase): # pragma: no cover
|
|||||||
T = RT[3, :3]
|
T = RT[3, :3]
|
||||||
return R, T
|
return R, T
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def render(
|
def render(
|
||||||
self,
|
self,
|
||||||
model_ids: Optional[List[str]] = None,
|
model_ids: Optional[List[str]] = None,
|
||||||
|
|||||||
@@ -62,8 +62,10 @@ def collate_batched_R2N2(batch: List[Dict]): # pragma: no cover
|
|||||||
# all models have the same number of views V, stack the batches of
|
# all models have the same number of views V, stack the batches of
|
||||||
# views of each model into a new batch of shape (N, V, H, W, 3).
|
# views of each model into a new batch of shape (N, V, H, W, 3).
|
||||||
# Otherwise leave it as a list.
|
# Otherwise leave it as a list.
|
||||||
|
# pyrefly: ignore [not-iterable]
|
||||||
if "images" in collated_dict:
|
if "images" in collated_dict:
|
||||||
try:
|
try:
|
||||||
|
# pyrefly: ignore [unsupported-operation]
|
||||||
collated_dict["images"] = torch.stack(collated_dict["images"])
|
collated_dict["images"] = torch.stack(collated_dict["images"])
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
print(
|
print(
|
||||||
@@ -75,10 +77,14 @@ def collate_batched_R2N2(batch: List[Dict]): # pragma: no cover
|
|||||||
# matrices and that all models have the same number of views V, stack each
|
# matrices and that all models have the same number of views V, stack each
|
||||||
# type of matrices into a new batch of shape (N, V, ...).
|
# type of matrices into a new batch of shape (N, V, ...).
|
||||||
# Otherwise leave them as lists.
|
# Otherwise leave them as lists.
|
||||||
|
# pyrefly: ignore [not-iterable]
|
||||||
if all(x in collated_dict for x in ["R", "T", "K"]):
|
if all(x in collated_dict for x in ["R", "T", "K"]):
|
||||||
try:
|
try:
|
||||||
|
# pyrefly: ignore [unsupported-operation]
|
||||||
collated_dict["R"] = torch.stack(collated_dict["R"]) # (N, V, 3, 3)
|
collated_dict["R"] = torch.stack(collated_dict["R"]) # (N, V, 3, 3)
|
||||||
|
# pyrefly: ignore [unsupported-operation]
|
||||||
collated_dict["T"] = torch.stack(collated_dict["T"]) # (N, V, 3)
|
collated_dict["T"] = torch.stack(collated_dict["T"]) # (N, V, 3)
|
||||||
|
# pyrefly: ignore [unsupported-operation]
|
||||||
collated_dict["K"] = torch.stack(collated_dict["K"]) # (N, V, 4, 4)
|
collated_dict["K"] = torch.stack(collated_dict["K"]) # (N, V, 4, 4)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
print(
|
print(
|
||||||
@@ -89,8 +95,10 @@ def collate_batched_R2N2(batch: List[Dict]): # pragma: no cover
|
|||||||
# If collate_batched_meshes receives voxels and all models have the same
|
# If collate_batched_meshes receives voxels and all models have the same
|
||||||
# number of views V, stack the batches of voxels into a new batch of shape
|
# number of views V, stack the batches of voxels into a new batch of shape
|
||||||
# (N, V, S, S, S), where S is the voxel size.
|
# (N, V, S, S, S), where S is the voxel size.
|
||||||
|
# pyrefly: ignore [not-iterable]
|
||||||
if "voxels" in collated_dict:
|
if "voxels" in collated_dict:
|
||||||
try:
|
try:
|
||||||
|
# pyrefly: ignore [unsupported-operation]
|
||||||
collated_dict["voxels"] = torch.stack(collated_dict["voxels"])
|
collated_dict["voxels"] = torch.stack(collated_dict["voxels"])
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
print(
|
print(
|
||||||
@@ -458,6 +466,7 @@ class BlenderCamera(CamerasBase): # pragma: no cover
|
|||||||
|
|
||||||
def get_projection_transform(self, **kwargs) -> Transform3d:
|
def get_projection_transform(self, **kwargs) -> Transform3d:
|
||||||
transform = Transform3d(device=self.device)
|
transform = Transform3d(device=self.device)
|
||||||
|
# pyrefly: ignore [not-callable]
|
||||||
transform._matrix = self.K.transpose(1, 2).contiguous()
|
transform._matrix = self.K.transpose(1, 2).contiguous()
|
||||||
return transform
|
return transform
|
||||||
|
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ class ShapeNetBase(torch.utils.data.Dataset): # pragma: no cover
|
|||||||
"""
|
"""
|
||||||
return len(self.model_ids)
|
return len(self.model_ids)
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-override-param-name]
|
||||||
def __getitem__(self, idx) -> Dict:
|
def __getitem__(self, idx) -> Dict:
|
||||||
"""
|
"""
|
||||||
Read a model by the given index. Need to be implemented for every child class
|
Read a model by the given index. Need to be implemented for every child class
|
||||||
@@ -147,12 +148,17 @@ class ShapeNetBase(torch.utils.data.Dataset): # pragma: no cover
|
|||||||
idxs = self._handle_render_inputs(model_ids, categories, sample_nums, idxs)
|
idxs = self._handle_render_inputs(model_ids, categories, sample_nums, idxs)
|
||||||
# Use the getitem method which loads mesh + texture
|
# Use the getitem method which loads mesh + texture
|
||||||
models = [self[idx] for idx in idxs]
|
models = [self[idx] for idx in idxs]
|
||||||
|
# pyrefly: ignore [unsupported-operation]
|
||||||
meshes = collate_batched_meshes(models)["mesh"]
|
meshes = collate_batched_meshes(models)["mesh"]
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
if meshes.textures is None:
|
if meshes.textures is None:
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
meshes.textures = TexturesVertex(
|
meshes.textures = TexturesVertex(
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
verts_features=torch.ones_like(meshes.verts_padded(), device=device)
|
verts_features=torch.ones_like(meshes.verts_padded(), device=device)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
meshes = meshes.to(device)
|
meshes = meshes.to(device)
|
||||||
cameras = kwargs.get("cameras", FoVPerspectiveCameras()).to(device)
|
cameras = kwargs.get("cameras", FoVPerspectiveCameras()).to(device)
|
||||||
if len(cameras) != 1 and len(cameras) % len(meshes) != 0:
|
if len(cameras) != 1 and len(cameras) % len(meshes) != 0:
|
||||||
|
|||||||
@@ -34,12 +34,14 @@ def collate_batched_meshes(batch: List[Dict]): # pragma: no cover
|
|||||||
for k in batch[0].keys():
|
for k in batch[0].keys():
|
||||||
collated_dict[k] = [d[k] for d in batch]
|
collated_dict[k] = [d[k] for d in batch]
|
||||||
|
|
||||||
|
# pyrefly: ignore [unsupported-operation]
|
||||||
collated_dict["mesh"] = None
|
collated_dict["mesh"] = None
|
||||||
if {"verts", "faces"}.issubset(collated_dict.keys()):
|
if {"verts", "faces"}.issubset(collated_dict.keys()):
|
||||||
textures = None
|
textures = None
|
||||||
if "textures" in collated_dict:
|
if "textures" in collated_dict:
|
||||||
textures = TexturesAtlas(atlas=collated_dict["textures"])
|
textures = TexturesAtlas(atlas=collated_dict["textures"])
|
||||||
|
|
||||||
|
# pyrefly: ignore [unsupported-operation]
|
||||||
collated_dict["mesh"] = Meshes(
|
collated_dict["mesh"] = Meshes(
|
||||||
verts=collated_dict["verts"],
|
verts=collated_dict["verts"],
|
||||||
faces=collated_dict["faces"],
|
faces=collated_dict["faces"],
|
||||||
|
|||||||
@@ -101,6 +101,7 @@ class DatasetBase(GenericWorkaround, torch.utils.data.Dataset[FrameData]):
|
|||||||
# crashes without overriding __getitem__
|
# crashes without overriding __getitem__
|
||||||
sequence_category = self[first_frame_idx].sequence_category
|
sequence_category = self[first_frame_idx].sequence_category
|
||||||
c2seq[sequence_category].append(sequence_name)
|
c2seq[sequence_category].append(sequence_name)
|
||||||
|
# pyrefly: ignore [bad-return]
|
||||||
return dict(c2seq)
|
return dict(c2seq)
|
||||||
|
|
||||||
def sequence_frames_in_order(
|
def sequence_frames_in_order(
|
||||||
|
|||||||
@@ -297,7 +297,11 @@ class FrameData(Mapping[str, Any]):
|
|||||||
depth_map = self.depth_map
|
depth_map = self.depth_map
|
||||||
if depth_map is not None:
|
if depth_map is not None:
|
||||||
clamp_bbox_xyxy_depth = rescale_bbox(
|
clamp_bbox_xyxy_depth = rescale_bbox(
|
||||||
clamp_bbox_xyxy, tuple(depth_map.shape[-2:]), effective_image_size_hw
|
# pyrefly: ignore [bad-argument-type]
|
||||||
|
clamp_bbox_xyxy,
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
|
tuple(depth_map.shape[-2:]),
|
||||||
|
effective_image_size_hw,
|
||||||
).long()
|
).long()
|
||||||
self.depth_map = crop_around_box(
|
self.depth_map = crop_around_box(
|
||||||
depth_map,
|
depth_map,
|
||||||
@@ -308,7 +312,11 @@ class FrameData(Mapping[str, Any]):
|
|||||||
depth_mask = self.depth_mask
|
depth_mask = self.depth_mask
|
||||||
if depth_mask is not None:
|
if depth_mask is not None:
|
||||||
clamp_bbox_xyxy_depth = rescale_bbox(
|
clamp_bbox_xyxy_depth = rescale_bbox(
|
||||||
clamp_bbox_xyxy, tuple(depth_mask.shape[-2:]), effective_image_size_hw
|
# pyrefly: ignore [bad-argument-type]
|
||||||
|
clamp_bbox_xyxy,
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
|
tuple(depth_mask.shape[-2:]),
|
||||||
|
effective_image_size_hw,
|
||||||
).long()
|
).long()
|
||||||
self.depth_mask = crop_around_box(
|
self.depth_mask = crop_around_box(
|
||||||
depth_mask,
|
depth_mask,
|
||||||
@@ -453,6 +461,7 @@ class FrameDataBuilderBase(ReplaceableBase, Generic[FrameDataSubtype], ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# To be initialised to FrameDataSubtype
|
# To be initialised to FrameDataSubtype
|
||||||
|
# pyrefly: ignore [invalid-annotation]
|
||||||
frame_data_type: ClassVar[Type[FrameDataSubtype]]
|
frame_data_type: ClassVar[Type[FrameDataSubtype]]
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|||||||
@@ -425,6 +425,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
raise ValueError("subsets not loaded")
|
raise ValueError("subsets not loaded")
|
||||||
if is_known_frame_scalar(frame_type):
|
if is_known_frame_scalar(frame_type):
|
||||||
cameras.append(self[frame_idx].camera)
|
cameras.append(self[frame_idx].camera)
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
return join_cameras_as_batch(cameras)
|
return join_cameras_as_batch(cameras)
|
||||||
|
|
||||||
def __getitem__(self, index) -> FrameData:
|
def __getitem__(self, index) -> FrameData:
|
||||||
|
|||||||
@@ -311,6 +311,7 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):
|
|||||||
subset_mapping["test"],
|
subset_mapping["test"],
|
||||||
) = self._extend_test_data_with_known_views(
|
) = self._extend_test_data_with_known_views(
|
||||||
subset_mapping,
|
subset_mapping,
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
eval_batch_index,
|
eval_batch_index,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -322,6 +323,7 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):
|
|||||||
try:
|
try:
|
||||||
test_dataset.eval_batches = (
|
test_dataset.eval_batches = (
|
||||||
test_dataset.seq_frame_index_to_dataset_index(
|
test_dataset.seq_frame_index_to_dataset_index(
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
eval_batch_index,
|
eval_batch_index,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -335,6 +337,7 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):
|
|||||||
)
|
)
|
||||||
test_dataset.eval_batches = (
|
test_dataset.eval_batches = (
|
||||||
test_dataset.seq_frame_index_to_dataset_index(
|
test_dataset.seq_frame_index_to_dataset_index(
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
eval_batch_index,
|
eval_batch_index,
|
||||||
allow_missing_indices=True,
|
allow_missing_indices=True,
|
||||||
remove_missing_indices=True,
|
remove_missing_indices=True,
|
||||||
|
|||||||
@@ -90,6 +90,7 @@ def TupleTypeFactory(dtype=float, shape: Tuple[int, ...] = (2,)):
|
|||||||
impl = LargeBinary
|
impl = LargeBinary
|
||||||
_format = format_symbol * math.prod(shape)
|
_format = format_symbol * math.prod(shape)
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-override-param-name]
|
||||||
def process_bind_param(self, value, _):
|
def process_bind_param(self, value, _):
|
||||||
if value is None:
|
if value is None:
|
||||||
return None
|
return None
|
||||||
@@ -99,6 +100,7 @@ def TupleTypeFactory(dtype=float, shape: Tuple[int, ...] = (2,)):
|
|||||||
|
|
||||||
return struct.pack(TupleType._format, *value)
|
return struct.pack(TupleType._format, *value)
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-override-param-name]
|
||||||
def process_result_value(self, value, _):
|
def process_result_value(self, value, _):
|
||||||
if value is None:
|
if value is None:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -121,11 +121,17 @@ class RenderedMeshDatasetMapProvider(DatasetMapProviderBase):
|
|||||||
self.poses = poses.cpu()
|
self.poses = poses.cpu()
|
||||||
# pyre-ignore[16]
|
# pyre-ignore[16]
|
||||||
self.train_dataset = SingleSceneDataset( # pyre-ignore[28]
|
self.train_dataset = SingleSceneDataset( # pyre-ignore[28]
|
||||||
|
# pyrefly: ignore [unexpected-keyword]
|
||||||
object_name="cow",
|
object_name="cow",
|
||||||
|
# pyrefly: ignore [unexpected-keyword]
|
||||||
images=list(images.permute(0, 3, 1, 2).cpu()),
|
images=list(images.permute(0, 3, 1, 2).cpu()),
|
||||||
|
# pyrefly: ignore [unexpected-keyword]
|
||||||
fg_probabilities=list(masks[:, None].cpu()),
|
fg_probabilities=list(masks[:, None].cpu()),
|
||||||
|
# pyrefly: ignore [unexpected-keyword]
|
||||||
poses=[self.poses[i] for i in range(len(poses))],
|
poses=[self.poses[i] for i in range(len(poses))],
|
||||||
|
# pyrefly: ignore [unexpected-keyword]
|
||||||
frame_types=[DATASET_TYPE_KNOWN] * len(poses),
|
frame_types=[DATASET_TYPE_KNOWN] * len(poses),
|
||||||
|
# pyrefly: ignore [unexpected-keyword]
|
||||||
eval_batches=None,
|
eval_batches=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -131,10 +131,12 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
|
|||||||
None
|
None
|
||||||
# pyre-ignore[16]
|
# pyre-ignore[16]
|
||||||
if self.fg_probabilities is None
|
if self.fg_probabilities is None
|
||||||
|
# pyrefly: ignore [bad-index]
|
||||||
else self.fg_probabilities[split]
|
else self.fg_probabilities[split]
|
||||||
)
|
)
|
||||||
eval_batches = [[i] for i in range(len(split))]
|
eval_batches = [[i] for i in range(len(split))]
|
||||||
if split_idx != 0 and self.n_known_frames_for_test is not None:
|
if split_idx != 0 and self.n_known_frames_for_test is not None:
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
train_split = self.i_split[0]
|
train_split = self.i_split[0]
|
||||||
if set_eval_batches:
|
if set_eval_batches:
|
||||||
generator = np.random.default_rng(seed=0)
|
generator = np.random.default_rng(seed=0)
|
||||||
|
|||||||
@@ -221,6 +221,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return len(self._index)
|
return len(self._index)
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-override-param-name]
|
||||||
def __getitem__(self, frame_idx: Union[int, Tuple[str, int]]) -> FrameData:
|
def __getitem__(self, frame_idx: Union[int, Tuple[str, int]]) -> FrameData:
|
||||||
"""
|
"""
|
||||||
Fetches FrameData by either iloc in the index or by (sequence, frame_no) pair
|
Fetches FrameData by either iloc in the index or by (sequence, frame_no) pair
|
||||||
@@ -424,6 +425,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
|
|
||||||
# override
|
# override
|
||||||
@property
|
@property
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def frame_data_type(self) -> Type[FrameData]:
|
def frame_data_type(self) -> Type[FrameData]:
|
||||||
return self.frame_data_builder.frame_data_type
|
return self.frame_data_builder.frame_data_type
|
||||||
|
|
||||||
@@ -630,7 +632,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
# dev load: 94 s / 23 s (3.1M / 500K)
|
# dev load: 94 s / 23 s (3.1M / 500K)
|
||||||
pick_frames_criteria.append(
|
pick_frames_criteria.append(
|
||||||
sa.or_(
|
sa.or_(
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
self.frame_annotations_type._mask_mass.is_(None),
|
self.frame_annotations_type._mask_mass.is_(None),
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
self.frame_annotations_type._mask_mass != 0,
|
self.frame_annotations_type._mask_mass != 0,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -697,6 +701,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
where_conditions.append(
|
where_conditions.append(
|
||||||
sa.or_(
|
sa.or_(
|
||||||
self.frame_annotations_type._mask_mass.is_(None), # pyre-ignore[16]
|
self.frame_annotations_type._mask_mass.is_(None), # pyre-ignore[16]
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
self.frame_annotations_type._mask_mass != 0,
|
self.frame_annotations_type._mask_mass != 0,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -162,6 +162,7 @@ def load_dataclass(f: IO, cls: Type[_X], binary: bool = False) -> _X:
|
|||||||
else:
|
else:
|
||||||
res = _dataclass_from_dict(asdict, cls)
|
res = _dataclass_from_dict(asdict, cls)
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-return]
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -117,6 +117,7 @@ def crop_around_box(
|
|||||||
# bbox is xyxy, where the upper bound is corrected with +1
|
# bbox is xyxy, where the upper bound is corrected with +1
|
||||||
bbox = clamp_box_to_image_bounds_and_round(
|
bbox = clamp_box_to_image_bounds_and_round(
|
||||||
bbox,
|
bbox,
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
image_size_hw=tuple(tensor.shape[-2:]),
|
image_size_hw=tuple(tensor.shape[-2:]),
|
||||||
)
|
)
|
||||||
tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]]
|
tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]]
|
||||||
|
|||||||
@@ -312,6 +312,7 @@ def eval_batch(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if visualize:
|
if visualize:
|
||||||
|
# pyrefly: ignore [unbound-name]
|
||||||
visualizer.show_rgb(
|
visualizer.show_rgb(
|
||||||
results[metric_name].item(), metric_name, loss_mask_now
|
results[metric_name].item(), metric_name, loss_mask_now
|
||||||
)
|
)
|
||||||
@@ -330,6 +331,7 @@ def eval_batch(
|
|||||||
results["depth_abs" + name_postfix] = abs_.mean()
|
results["depth_abs" + name_postfix] = abs_.mean()
|
||||||
|
|
||||||
if visualize:
|
if visualize:
|
||||||
|
# pyrefly: ignore [unbound-name]
|
||||||
visualizer.show_depth(abs_.mean().item(), name_postfix, loss_mask_now)
|
visualizer.show_depth(abs_.mean().item(), name_postfix, loss_mask_now)
|
||||||
if break_after_visualising:
|
if break_after_visualising:
|
||||||
breakpoint() # noqa: B601
|
breakpoint() # noqa: B601
|
||||||
|
|||||||
@@ -472,6 +472,7 @@ class GenericModel(ImplicitronModelBase):
|
|||||||
sequence_name=safe_slice_targets(sequence_name),
|
sequence_name=safe_slice_targets(sequence_name),
|
||||||
frame_timestamp=safe_slice_targets(frame_timestamp),
|
frame_timestamp=safe_slice_targets(frame_timestamp),
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore [unsupported-operation]
|
||||||
custom_args["global_code"] = global_code
|
custom_args["global_code"] = global_code
|
||||||
|
|
||||||
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a
|
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a
|
||||||
|
|||||||
@@ -907,6 +907,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
# Torch Module to hold parameters since they can only be registered
|
# Torch Module to hold parameters since they can only be registered
|
||||||
# at object level.
|
# at object level.
|
||||||
|
# pyrefly: ignore [bad-assignment]
|
||||||
self.params = _RegistratedBufferDict(vars(params))
|
self.params = _RegistratedBufferDict(vars(params))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -476,6 +476,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
call_epochs = list(
|
call_epochs = list(
|
||||||
set(self.scaffold_calculating_epochs) | set(self.volume_cropping_epochs)
|
set(self.scaffold_calculating_epochs) | set(self.volume_cropping_epochs)
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore [bad-return]
|
||||||
return call_epochs, callback
|
return call_epochs, callback
|
||||||
|
|
||||||
def _crop(self, epoch: int) -> bool:
|
def _crop(self, epoch: int) -> bool:
|
||||||
@@ -581,6 +582,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
cls = registry.get(DecoderFunctionBase, type_)
|
cls = registry.get(DecoderFunctionBase, type_)
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
need_input_dim = any(field.name == "input_dim" for field in fields(cls))
|
need_input_dim = any(field.name == "input_dim" for field in fields(cls))
|
||||||
if need_input_dim:
|
if need_input_dim:
|
||||||
self.decoder_density = cls(input_dim=input_dim, **args)
|
self.decoder_density = cls(input_dim=input_dim, **args)
|
||||||
@@ -621,6 +623,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
input_dim = input_dim0 + input_dim1
|
input_dim = input_dim0 + input_dim1
|
||||||
|
|
||||||
cls = registry.get(DecoderFunctionBase, type_)
|
cls = registry.get(DecoderFunctionBase, type_)
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
need_input_dim = any(field.name == "input_dim" for field in fields(cls))
|
need_input_dim = any(field.name == "input_dim" for field in fields(cls))
|
||||||
if need_input_dim:
|
if need_input_dim:
|
||||||
self.decoder_color = cls(input_dim=input_dim, **args)
|
self.decoder_color = cls(input_dim=input_dim, **args)
|
||||||
|
|||||||
@@ -110,6 +110,7 @@ class ImplicitronRayBundle:
|
|||||||
# equivalent to: 0.5 * (bins[..., 1:] + bins[..., :-1]) but more efficient
|
# equivalent to: 0.5 * (bins[..., 1:] + bins[..., :-1]) but more efficient
|
||||||
# pyre-ignore
|
# pyre-ignore
|
||||||
return torch.lerp(self.bins[..., :-1], self.bins[..., 1:], 0.5)
|
return torch.lerp(self.bins[..., :-1], self.bins[..., 1:], 0.5)
|
||||||
|
# pyrefly: ignore [bad-return]
|
||||||
return self._lengths
|
return self._lengths
|
||||||
|
|
||||||
@lengths.setter
|
@lengths.setter
|
||||||
@@ -166,6 +167,7 @@ class ImplicitronRayBundle:
|
|||||||
)
|
)
|
||||||
num_inputs = camera_counts.sum().item()
|
num_inputs = camera_counts.sum().item()
|
||||||
max_size = torch.max(camera_counts).item()
|
max_size = torch.max(camera_counts).item()
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
xys = packed_to_padded(self.xys, first_idxs, max_size)
|
xys = packed_to_padded(self.xys, first_idxs, max_size)
|
||||||
# pyre-ignore [7] pytorch typeshed inaccuracy
|
# pyre-ignore [7] pytorch typeshed inaccuracy
|
||||||
return xys, first_idxs, num_inputs
|
return xys, first_idxs, num_inputs
|
||||||
|
|||||||
@@ -198,6 +198,7 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
|
|||||||
depth = (weights * ray_lengths)[..., None].sum(dim=-2)
|
depth = (weights * ray_lengths)[..., None].sum(dim=-2)
|
||||||
|
|
||||||
alpha = opacities if self.blend_output else 1
|
alpha = opacities if self.blend_output else 1
|
||||||
|
# pyrefly: ignore [bad-index]
|
||||||
if self._bg_color.shape[-1] not in [1, features.shape[-1]]:
|
if self._bg_color.shape[-1] not in [1, features.shape[-1]]:
|
||||||
raise ValueError("Wrong number of background color channels.")
|
raise ValueError("Wrong number of background color channels.")
|
||||||
# pyre-fixme[58]: `*` is not supported for operand types `int` and
|
# pyre-fixme[58]: `*` is not supported for operand types `int` and
|
||||||
|
|||||||
@@ -146,6 +146,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
# Sample points for the eikonal loss
|
# Sample points for the eikonal loss
|
||||||
eik_bounding_box: float = self.object_bounding_sphere
|
eik_bounding_box: float = self.object_bounding_sphere
|
||||||
n_eik_points = batch_size * num_pixels // 2
|
n_eik_points = batch_size * num_pixels // 2
|
||||||
|
# pyrefly: ignore [no-matching-overload]
|
||||||
eikonal_points = torch.empty(
|
eikonal_points = torch.empty(
|
||||||
n_eik_points,
|
n_eik_points,
|
||||||
3,
|
3,
|
||||||
|
|||||||
@@ -118,6 +118,7 @@ def weighted_sum_losses(
|
|||||||
return None
|
return None
|
||||||
loss = sum(losses_weighted)
|
loss = sum(losses_weighted)
|
||||||
assert torch.is_tensor(loss)
|
assert torch.is_tensor(loss)
|
||||||
|
# pyrefly: ignore [bad-return]
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -231,7 +231,9 @@ class Configurable:
|
|||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
# pyrefly: ignore [invalid-type-var]
|
||||||
_X = TypeVar("X", bound=ReplaceableBase)
|
_X = TypeVar("X", bound=ReplaceableBase)
|
||||||
|
# pyrefly: ignore [invalid-type-var]
|
||||||
_Y = TypeVar("Y", bound=Union[ReplaceableBase, Configurable])
|
_Y = TypeVar("Y", bound=Union[ReplaceableBase, Configurable])
|
||||||
|
|
||||||
|
|
||||||
@@ -890,10 +892,13 @@ def expand_args_fields(
|
|||||||
continue
|
continue
|
||||||
expand_args_fields(base, _do_not_process=_do_not_process)
|
expand_args_fields(base, _do_not_process=_do_not_process)
|
||||||
if "_creation_functions" in base.__dict__:
|
if "_creation_functions" in base.__dict__:
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
creation_functions.extend(base._creation_functions)
|
creation_functions.extend(base._creation_functions)
|
||||||
if "_known_implementations" in base.__dict__:
|
if "_known_implementations" in base.__dict__:
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
known_implementations.update(base._known_implementations)
|
known_implementations.update(base._known_implementations)
|
||||||
if "_processed_members" in base.__dict__:
|
if "_processed_members" in base.__dict__:
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
processed_members.update(base._processed_members)
|
processed_members.update(base._processed_members)
|
||||||
|
|
||||||
to_process: List[Tuple[str, Type, _ProcessType]] = []
|
to_process: List[Tuple[str, Type, _ProcessType]] = []
|
||||||
|
|||||||
@@ -62,9 +62,14 @@ def rasterize_sparse_ray_bundle(
|
|||||||
|
|
||||||
max_size = torch.max(camera_counts).item()
|
max_size = torch.max(camera_counts).item()
|
||||||
features_depth_ras = packed_to_padded(
|
features_depth_ras = packed_to_padded(
|
||||||
features_depth_ras[:, 0], first_idxs, max_size
|
# pyrefly: ignore [bad-argument-type]
|
||||||
|
features_depth_ras[:, 0],
|
||||||
|
first_idxs,
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
|
max_size,
|
||||||
)
|
)
|
||||||
if masks is not None:
|
if masks is not None:
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
padded_mask = packed_to_padded(masks.flatten(1, -1), first_idxs, max_size)
|
padded_mask = packed_to_padded(masks.flatten(1, -1), first_idxs, max_size)
|
||||||
masks_ras = padded_mask * masks_ras
|
masks_ras = padded_mask * masks_ras
|
||||||
|
|
||||||
|
|||||||
@@ -124,10 +124,12 @@ class VideoWriter:
|
|||||||
if im is not None:
|
if im is not None:
|
||||||
if resize is not None:
|
if resize is not None:
|
||||||
if isinstance(resize, float):
|
if isinstance(resize, float):
|
||||||
|
# pyrefly: ignore [bad-assignment]
|
||||||
resize = [int(resize * s) for s in im.size]
|
resize = [int(resize * s) for s in im.size]
|
||||||
else:
|
else:
|
||||||
resize = im.size
|
resize = im.size
|
||||||
# make sure size is divisible by 2
|
# make sure size is divisible by 2
|
||||||
|
# pyrefly: ignore [bad-assignment, bad-index, unsupported-operation]
|
||||||
resize = tuple([resize[i] + resize[i] % 2 for i in (0, 1)])
|
resize = tuple([resize[i] + resize[i] % 2 for i in (0, 1)])
|
||||||
|
|
||||||
im = im.resize(resize, Image.ANTIALIAS)
|
im = im.resize(resize, Image.ANTIALIAS)
|
||||||
|
|||||||
@@ -120,6 +120,7 @@ class _TargetType(IntEnum):
|
|||||||
|
|
||||||
|
|
||||||
class OurEncoder(json.JSONEncoder):
|
class OurEncoder(json.JSONEncoder):
|
||||||
|
# pyrefly: ignore [bad-override-param-name]
|
||||||
def default(self, obj):
|
def default(self, obj):
|
||||||
if isinstance(obj, np.int64):
|
if isinstance(obj, np.int64):
|
||||||
return str(obj)
|
return str(obj)
|
||||||
@@ -242,6 +243,7 @@ class _GLTFLoader:
|
|||||||
by _get_texture_map_image which caches it.
|
by _get_texture_map_image which caches it.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyrefly: ignore [unsupported-operation]
|
||||||
image_json = self._json_data["images"][image_index]
|
image_json = self._json_data["images"][image_index]
|
||||||
buffer_view = self._buffer_views[image_json["bufferView"]]
|
buffer_view = self._buffer_views[image_json["bufferView"]]
|
||||||
if "byteStride" in buffer_view:
|
if "byteStride" in buffer_view:
|
||||||
@@ -407,10 +409,12 @@ class _GLTFLoader:
|
|||||||
verts_uvs[:, 1] = 1 - verts_uvs[:, -1]
|
verts_uvs[:, 1] = 1 - verts_uvs[:, -1]
|
||||||
faces_uvs = indices
|
faces_uvs = indices
|
||||||
material_index = primitive.get("material", 0)
|
material_index = primitive.get("material", 0)
|
||||||
|
# pyrefly: ignore [unsupported-operation]
|
||||||
material = self._json_data["materials"][material_index]
|
material = self._json_data["materials"][material_index]
|
||||||
material_roughness = material["pbrMetallicRoughness"]
|
material_roughness = material["pbrMetallicRoughness"]
|
||||||
if "baseColorTexture" in material_roughness:
|
if "baseColorTexture" in material_roughness:
|
||||||
texture_index = material_roughness["baseColorTexture"]["index"]
|
texture_index = material_roughness["baseColorTexture"]["index"]
|
||||||
|
# pyrefly: ignore [unsupported-operation]
|
||||||
texture_json = self._json_data["textures"][texture_index]
|
texture_json = self._json_data["textures"][texture_index]
|
||||||
# Todo - include baseColorFactor when also given
|
# Todo - include baseColorFactor when also given
|
||||||
# Todo - look at the sampler
|
# Todo - look at the sampler
|
||||||
@@ -555,6 +559,7 @@ class _GLTFWriter:
|
|||||||
# pyre-fixme[6]: Incompatible parameter type
|
# pyre-fixme[6]: Incompatible parameter type
|
||||||
self._json_data["scene"] = scene_index
|
self._json_data["scene"] = scene_index
|
||||||
self._json_data["scenes"].append({"nodes": [scene_index]})
|
self._json_data["scenes"].append({"nodes": [scene_index]})
|
||||||
|
# pyrefly: ignore [unsupported-operation]
|
||||||
self._json_data["asset"] = {"version": "2.0"}
|
self._json_data["asset"] = {"version": "2.0"}
|
||||||
node = {"name": "Node", "mesh": 0}
|
node = {"name": "Node", "mesh": 0}
|
||||||
self._json_data["nodes"].append(node)
|
self._json_data["nodes"].append(node)
|
||||||
@@ -621,6 +626,7 @@ class _GLTFWriter:
|
|||||||
byte_per_element = 3 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]]
|
byte_per_element = 3 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]]
|
||||||
elif key == "texcoords":
|
elif key == "texcoords":
|
||||||
component_type = _ComponentType.FLOAT
|
component_type = _ComponentType.FLOAT
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
data = self.mesh.textures.verts_uvs_list()[0].cpu().numpy()
|
data = self.mesh.textures.verts_uvs_list()[0].cpu().numpy()
|
||||||
data[:, 1] = 1 - data[:, -1] # flip y tex-coordinate
|
data[:, 1] = 1 - data[:, -1] # flip y tex-coordinate
|
||||||
element_type = "VEC2"
|
element_type = "VEC2"
|
||||||
@@ -630,6 +636,7 @@ class _GLTFWriter:
|
|||||||
byte_per_element = 2 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]]
|
byte_per_element = 2 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]]
|
||||||
elif key == "texvertices":
|
elif key == "texvertices":
|
||||||
component_type = _ComponentType.FLOAT
|
component_type = _ComponentType.FLOAT
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
data = self.mesh.textures.verts_features_list()[0].cpu().numpy()
|
data = self.mesh.textures.verts_features_list()[0].cpu().numpy()
|
||||||
element_type = "VEC3"
|
element_type = "VEC3"
|
||||||
buffer_view = 2
|
buffer_view = 2
|
||||||
@@ -700,11 +707,14 @@ class _GLTFWriter:
|
|||||||
target = _TargetType.ELEMENT_ARRAY_BUFFER
|
target = _TargetType.ELEMENT_ARRAY_BUFFER
|
||||||
|
|
||||||
bufferview["target"] = target
|
bufferview["target"] = target
|
||||||
|
# pyrefly: ignore [bad-typed-dict-key]
|
||||||
bufferview["byteOffset"] = kwargs.get("offset")
|
bufferview["byteOffset"] = kwargs.get("offset")
|
||||||
|
# pyrefly: ignore [bad-typed-dict-key]
|
||||||
bufferview["byteLength"] = kwargs.get("byte_length")
|
bufferview["byteLength"] = kwargs.get("byte_length")
|
||||||
self._json_data["bufferViews"].append(bufferview)
|
self._json_data["bufferViews"].append(bufferview)
|
||||||
|
|
||||||
def _write_image_buffer(self, **kwargs) -> Tuple[int, bytes]:
|
def _write_image_buffer(self, **kwargs) -> Tuple[int, bytes]:
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
image_np = self.mesh.textures.maps_list()[0].cpu().numpy()
|
image_np = self.mesh.textures.maps_list()[0].cpu().numpy()
|
||||||
image_array = (image_np * 255.0).astype(np.uint8)
|
image_array = (image_np * 255.0).astype(np.uint8)
|
||||||
im = Image.fromarray(image_array)
|
im = Image.fromarray(image_array)
|
||||||
@@ -716,6 +726,7 @@ class _GLTFWriter:
|
|||||||
bufferview_image = {
|
bufferview_image = {
|
||||||
"buffer": 0,
|
"buffer": 0,
|
||||||
}
|
}
|
||||||
|
# pyrefly: ignore [bad-typed-dict-key]
|
||||||
bufferview_image["byteOffset"] = kwargs.get("offset")
|
bufferview_image["byteOffset"] = kwargs.get("offset")
|
||||||
bufferview_image["byteLength"] = image_data_byte_length
|
bufferview_image["byteLength"] = image_data_byte_length
|
||||||
self._json_data["bufferViews"].append(bufferview_image)
|
self._json_data["bufferViews"].append(bufferview_image)
|
||||||
|
|||||||
@@ -226,6 +226,7 @@ def load_obj(
|
|||||||
with _open_file(f, path_manager, "r") as f:
|
with _open_file(f, path_manager, "r") as f:
|
||||||
return _load_obj(
|
return _load_obj(
|
||||||
f,
|
f,
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
data_dir=data_dir,
|
data_dir=data_dir,
|
||||||
load_textures=load_textures,
|
load_textures=load_textures,
|
||||||
create_texture_atlas=create_texture_atlas,
|
create_texture_atlas=create_texture_atlas,
|
||||||
@@ -641,6 +642,7 @@ def _load_obj(
|
|||||||
material_names.append(next(iter(material_colors.keys())))
|
material_names.append(next(iter(material_colors.keys())))
|
||||||
# replace all -1 by 0 material idx
|
# replace all -1 by 0 material idx
|
||||||
if torch.is_tensor(faces_materials_idx):
|
if torch.is_tensor(faces_materials_idx):
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
faces_materials_idx.clamp_(min=0)
|
faces_materials_idx.clamp_(min=0)
|
||||||
|
|
||||||
if create_texture_atlas:
|
if create_texture_atlas:
|
||||||
@@ -649,14 +651,18 @@ def _load_obj(
|
|||||||
|
|
||||||
# Create an array of strings of material names for each face.
|
# Create an array of strings of material names for each face.
|
||||||
# If faces_materials_idx == -1 then that face doesn't have a material.
|
# If faces_materials_idx == -1 then that face doesn't have a material.
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
idx = faces_materials_idx.cpu().numpy()
|
idx = faces_materials_idx.cpu().numpy()
|
||||||
face_material_names = np.array([""] + material_names)[idx + 1] # (F,)
|
face_material_names = np.array([""] + material_names)[idx + 1] # (F,)
|
||||||
|
|
||||||
# Construct the atlas.
|
# Construct the atlas.
|
||||||
texture_atlas = make_mesh_texture_atlas(
|
texture_atlas = make_mesh_texture_atlas(
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
material_colors,
|
material_colors,
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
texture_images,
|
texture_images,
|
||||||
face_material_names,
|
face_material_names,
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
faces_textures_idx,
|
faces_textures_idx,
|
||||||
verts_uvs,
|
verts_uvs,
|
||||||
texture_atlas_size,
|
texture_atlas_size,
|
||||||
|
|||||||
@@ -878,10 +878,14 @@ def _get_verts_column_indices(
|
|||||||
):
|
):
|
||||||
color_scale = 1.0 / 255
|
color_scale = 1.0 / 255
|
||||||
return _VertsColumnIndices(
|
return _VertsColumnIndices(
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
point_idxs=point_idxs,
|
point_idxs=point_idxs,
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
color_idxs=None if None in color_idxs else color_idxs,
|
color_idxs=None if None in color_idxs else color_idxs,
|
||||||
color_scale=color_scale,
|
color_scale=color_scale,
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
normal_idxs=None if None in normal_idxs else normal_idxs,
|
normal_idxs=None if None in normal_idxs else normal_idxs,
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
texture_uv_idxs=None if None in texture_uv_idxs else texture_uv_idxs,
|
texture_uv_idxs=None if None in texture_uv_idxs else texture_uv_idxs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -81,6 +81,7 @@ class _PointFaceDistance(Function):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@once_differentiable
|
@once_differentiable
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def backward(ctx, grad_dists):
|
def backward(ctx, grad_dists):
|
||||||
grad_dists = grad_dists.contiguous()
|
grad_dists = grad_dists.contiguous()
|
||||||
points, tris, idxs = ctx.saved_tensors
|
points, tris, idxs = ctx.saved_tensors
|
||||||
@@ -143,6 +144,7 @@ class _FacePointDistance(Function):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@once_differentiable
|
@once_differentiable
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def backward(ctx, grad_dists):
|
def backward(ctx, grad_dists):
|
||||||
grad_dists = grad_dists.contiguous()
|
grad_dists = grad_dists.contiguous()
|
||||||
points, tris, idxs = ctx.saved_tensors
|
points, tris, idxs = ctx.saved_tensors
|
||||||
@@ -194,6 +196,7 @@ class _PointEdgeDistance(Function):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@once_differentiable
|
@once_differentiable
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def backward(ctx, grad_dists):
|
def backward(ctx, grad_dists):
|
||||||
grad_dists = grad_dists.contiguous()
|
grad_dists = grad_dists.contiguous()
|
||||||
points, segms, idxs = ctx.saved_tensors
|
points, segms, idxs = ctx.saved_tensors
|
||||||
@@ -244,6 +247,7 @@ class _EdgePointDistance(Function):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@once_differentiable
|
@once_differentiable
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def backward(ctx, grad_dists):
|
def backward(ctx, grad_dists):
|
||||||
grad_dists = grad_dists.contiguous()
|
grad_dists = grad_dists.contiguous()
|
||||||
points, segms, idxs = ctx.saved_tensors
|
points, segms, idxs = ctx.saved_tensors
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ class _ball_query(Function):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@once_differentiable
|
@once_differentiable
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def backward(ctx, grad_dists, grad_idx):
|
def backward(ctx, grad_dists, grad_idx):
|
||||||
p1, p2, lengths1, lengths2, idx = ctx.saved_tensors
|
p1, p2, lengths1, lengths2, idx = ctx.saved_tensors
|
||||||
# TODO(gkioxari) Change cast to floats once we add support for doubles.
|
# TODO(gkioxari) Change cast to floats once we add support for doubles.
|
||||||
|
|||||||
@@ -162,6 +162,7 @@ class GatherScatter(Function):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@once_differentiable
|
@once_differentiable
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
grad_output = grad_output.contiguous()
|
grad_output = grad_output.contiguous()
|
||||||
edges = ctx.saved_tensors[0]
|
edges = ctx.saved_tensors[0]
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ class _InterpFaceAttrs(Function):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@once_differentiable
|
@once_differentiable
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def backward(ctx, grad_pix_attrs):
|
def backward(ctx, grad_pix_attrs):
|
||||||
args = ctx.saved_tensors
|
args = ctx.saved_tensors
|
||||||
args = args + (grad_pix_attrs,)
|
args = args + (grad_pix_attrs,)
|
||||||
|
|||||||
@@ -106,6 +106,7 @@ class _box3d_overlap(Function):
|
|||||||
return vol, iou
|
return vol, iou
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def backward(ctx, grad_vol, grad_iou):
|
def backward(ctx, grad_vol, grad_iou):
|
||||||
raise ValueError("box3d_overlap backward is not supported")
|
raise ValueError("box3d_overlap backward is not supported")
|
||||||
|
|
||||||
|
|||||||
@@ -95,6 +95,7 @@ class _knn_points(Function):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@once_differentiable
|
@once_differentiable
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def backward(ctx, grad_dists, grad_idx):
|
def backward(ctx, grad_dists, grad_idx):
|
||||||
p1, p2, lengths1, lengths2, idx = ctx.saved_tensors
|
p1, p2, lengths1, lengths2, idx = ctx.saved_tensors
|
||||||
norm = ctx.norm
|
norm = ctx.norm
|
||||||
|
|||||||
@@ -247,6 +247,7 @@ class _marching_cubes(Function):
|
|||||||
return verts, faces, ids
|
return verts, faces, ids
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def backward(ctx, grad_verts, grad_faces):
|
def backward(ctx, grad_verts, grad_faces):
|
||||||
raise ValueError("marching_cubes backward is not supported")
|
raise ValueError("marching_cubes backward is not supported")
|
||||||
|
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ class _MeshFaceAreasNormals(Function):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@once_differentiable
|
@once_differentiable
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def backward(ctx, grad_areas, grad_normals):
|
def backward(ctx, grad_areas, grad_normals):
|
||||||
grad_areas = grad_areas.contiguous()
|
grad_areas = grad_areas.contiguous()
|
||||||
grad_normals = grad_normals.contiguous()
|
grad_normals = grad_normals.contiguous()
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ class _PackedToPadded(Function):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@once_differentiable
|
@once_differentiable
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
grad_output = grad_output.contiguous()
|
grad_output = grad_output.contiguous()
|
||||||
first_idxs = ctx.saved_tensors[0]
|
first_idxs = ctx.saved_tensors[0]
|
||||||
@@ -143,6 +144,7 @@ class _PaddedToPacked(Function):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@once_differentiable
|
@once_differentiable
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
grad_output = grad_output.contiguous()
|
grad_output = grad_output.contiguous()
|
||||||
first_idxs = ctx.saved_tensors[0]
|
first_idxs = ctx.saved_tensors[0]
|
||||||
|
|||||||
@@ -103,14 +103,17 @@ def iterative_closest_point(
|
|||||||
Xt, num_points_X = oputil.convert_pointclouds_to_tensor(X)
|
Xt, num_points_X = oputil.convert_pointclouds_to_tensor(X)
|
||||||
Yt, num_points_Y = oputil.convert_pointclouds_to_tensor(Y)
|
Yt, num_points_Y = oputil.convert_pointclouds_to_tensor(Y)
|
||||||
|
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
b, size_X, dim = Xt.shape
|
b, size_X, dim = Xt.shape
|
||||||
|
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
if (Xt.shape[2] != Yt.shape[2]) or (Xt.shape[0] != Yt.shape[0]):
|
if (Xt.shape[2] != Yt.shape[2]) or (Xt.shape[0] != Yt.shape[0]):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Point sets X and Y have to have the same "
|
"Point sets X and Y have to have the same "
|
||||||
+ "number of batches and data dimensions."
|
+ "number of batches and data dimensions."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
if ((num_points_Y < Yt.shape[1]).any() or (num_points_X < Xt.shape[1]).any()) and (
|
if ((num_points_Y < Yt.shape[1]).any() or (num_points_X < Xt.shape[1]).any()) and (
|
||||||
num_points_Y != num_points_X
|
num_points_Y != num_points_X
|
||||||
).any():
|
).any():
|
||||||
@@ -121,6 +124,7 @@ def iterative_closest_point(
|
|||||||
< num_points_X[:, None]
|
< num_points_X[:, None]
|
||||||
).type_as(Xt)
|
).type_as(Xt)
|
||||||
else:
|
else:
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
mask_X = Xt.new_ones(b, size_X)
|
mask_X = Xt.new_ones(b, size_X)
|
||||||
|
|
||||||
# clone the initial point cloud
|
# clone the initial point cloud
|
||||||
@@ -145,11 +149,15 @@ def iterative_closest_point(
|
|||||||
"of scalars of shape (minibatch,)."
|
"of scalars of shape (minibatch,)."
|
||||||
) from None
|
) from None
|
||||||
# apply the init transform to the input point cloud
|
# apply the init transform to the input point cloud
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
Xt = _apply_similarity_transform(Xt, R, T, s)
|
Xt = _apply_similarity_transform(Xt, R, T, s)
|
||||||
else:
|
else:
|
||||||
# initialize the transformation with identity
|
# initialize the transformation with identity
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
R = oputil.eyes(dim, b, device=Xt.device, dtype=Xt.dtype)
|
R = oputil.eyes(dim, b, device=Xt.device, dtype=Xt.dtype)
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
T = Xt.new_zeros((b, dim))
|
T = Xt.new_zeros((b, dim))
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
s = Xt.new_ones(b)
|
s = Xt.new_ones(b)
|
||||||
|
|
||||||
prev_rmse = None
|
prev_rmse = None
|
||||||
@@ -163,7 +171,14 @@ def iterative_closest_point(
|
|||||||
# the main loop over ICP iterations
|
# the main loop over ICP iterations
|
||||||
for iteration in range(max_iterations):
|
for iteration in range(max_iterations):
|
||||||
Xt_nn_points = knn_points(
|
Xt_nn_points = knn_points(
|
||||||
Xt, Yt, lengths1=num_points_X, lengths2=num_points_Y, K=1, return_nn=True
|
# pyrefly: ignore [bad-argument-type]
|
||||||
|
Xt,
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
|
Yt,
|
||||||
|
lengths1=num_points_X,
|
||||||
|
lengths2=num_points_Y,
|
||||||
|
K=1,
|
||||||
|
return_nn=True,
|
||||||
).knn[:, :, 0, :]
|
).knn[:, :, 0, :]
|
||||||
|
|
||||||
# get the alignment of the nearest neighbors from Yt with Xt_init
|
# get the alignment of the nearest neighbors from Yt with Xt_init
|
||||||
@@ -216,6 +231,7 @@ def iterative_closest_point(
|
|||||||
if oputil.is_pointclouds(X):
|
if oputil.is_pointclouds(X):
|
||||||
Xt = X.update_padded(Xt) # type: ignore
|
Xt = X.update_padded(Xt) # type: ignore
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
return ICPSolution(converged, rmse, Xt, SimilarityTransform(R, T, s), t_history)
|
return ICPSolution(converged, rmse, Xt, SimilarityTransform(R, T, s), t_history)
|
||||||
|
|
||||||
|
|
||||||
@@ -276,6 +292,7 @@ def corresponding_points_alignment(
|
|||||||
Xt, num_points = oputil.convert_pointclouds_to_tensor(X)
|
Xt, num_points = oputil.convert_pointclouds_to_tensor(X)
|
||||||
Yt, num_points_Y = oputil.convert_pointclouds_to_tensor(Y)
|
Yt, num_points_Y = oputil.convert_pointclouds_to_tensor(Y)
|
||||||
|
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
if (Xt.shape != Yt.shape) or (num_points != num_points_Y).any():
|
if (Xt.shape != Yt.shape) or (num_points != num_points_Y).any():
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Point sets X and Y have to have the same \
|
"Point sets X and Y have to have the same \
|
||||||
@@ -291,25 +308,33 @@ def corresponding_points_alignment(
|
|||||||
weights = [w[..., None] for w in weights]
|
weights = [w[..., None] for w in weights]
|
||||||
weights = strutil.list_to_padded(weights)[..., 0]
|
weights = strutil.list_to_padded(weights)[..., 0]
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-index]
|
||||||
if Xt.shape[:2] != weights.shape:
|
if Xt.shape[:2] != weights.shape:
|
||||||
raise ValueError("weights should have the same first two dimensions as X.")
|
raise ValueError("weights should have the same first two dimensions as X.")
|
||||||
|
|
||||||
|
# pyrefly: ignore [not-iterable]
|
||||||
b, n, dim = Xt.shape
|
b, n, dim = Xt.shape
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-index, missing-attribute]
|
||||||
if (num_points < Xt.shape[1]).any() or (num_points < Yt.shape[1]).any():
|
if (num_points < Xt.shape[1]).any() or (num_points < Yt.shape[1]).any():
|
||||||
# in case we got Pointclouds as input, mask the unused entries in Xc, Yc
|
# in case we got Pointclouds as input, mask the unused entries in Xc, Yc
|
||||||
mask = (
|
mask = (
|
||||||
torch.arange(n, dtype=torch.int64, device=Xt.device)[None]
|
torch.arange(n, dtype=torch.int64, device=Xt.device)[None]
|
||||||
< num_points[:, None]
|
< num_points[:, None]
|
||||||
).type_as(Xt)
|
).type_as(Xt)
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
weights = mask if weights is None else mask * weights.type_as(Xt)
|
weights = mask if weights is None else mask * weights.type_as(Xt)
|
||||||
|
|
||||||
# compute the centroids of the point sets
|
# compute the centroids of the point sets
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
Xmu = oputil.wmean(Xt, weight=weights, eps=eps)
|
Xmu = oputil.wmean(Xt, weight=weights, eps=eps)
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
Ymu = oputil.wmean(Yt, weight=weights, eps=eps)
|
Ymu = oputil.wmean(Yt, weight=weights, eps=eps)
|
||||||
|
|
||||||
# mean-center the point sets
|
# mean-center the point sets
|
||||||
|
# pyrefly: ignore [unsupported-operation]
|
||||||
Xc = Xt - Xmu
|
Xc = Xt - Xmu
|
||||||
|
# pyrefly: ignore [unsupported-operation]
|
||||||
Yc = Yt - Ymu
|
Yc = Yt - Ymu
|
||||||
|
|
||||||
total_weight = torch.clamp(num_points, 1)
|
total_weight = torch.clamp(num_points, 1)
|
||||||
|
|||||||
@@ -119,6 +119,7 @@ def estimate_pointcloud_local_coord_frames(
|
|||||||
|
|
||||||
points_padded, num_points = convert_pointclouds_to_tensor(pointclouds)
|
points_padded, num_points = convert_pointclouds_to_tensor(pointclouds)
|
||||||
|
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
ba, N, dim = points_padded.shape
|
ba, N, dim = points_padded.shape
|
||||||
if dim != 3:
|
if dim != 3:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -133,6 +134,7 @@ def estimate_pointcloud_local_coord_frames(
|
|||||||
|
|
||||||
# undo global mean for stability
|
# undo global mean for stability
|
||||||
# TODO: replace with tutil.wmean once landed
|
# TODO: replace with tutil.wmean once landed
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
pcl_mean = points_padded.sum(1) / num_points[:, None]
|
pcl_mean = points_padded.sum(1) / num_points[:, None]
|
||||||
points_centered = points_padded - pcl_mean[:, None, :]
|
points_centered = points_padded - pcl_mean[:, None, :]
|
||||||
|
|
||||||
@@ -154,17 +156,26 @@ def estimate_pointcloud_local_coord_frames(
|
|||||||
if disambiguate_directions:
|
if disambiguate_directions:
|
||||||
# disambiguate normal
|
# disambiguate normal
|
||||||
n = _disambiguate_vector_directions(
|
n = _disambiguate_vector_directions(
|
||||||
points_centered, knns, local_coord_frames[:, :, :, 0]
|
# pyrefly: ignore [unsupported-operation]
|
||||||
|
points_centered,
|
||||||
|
knns,
|
||||||
|
# pyrefly: ignore [unsupported-operation]
|
||||||
|
local_coord_frames[:, :, :, 0],
|
||||||
)
|
)
|
||||||
# disambiguate the main curvature
|
# disambiguate the main curvature
|
||||||
z = _disambiguate_vector_directions(
|
z = _disambiguate_vector_directions(
|
||||||
points_centered, knns, local_coord_frames[:, :, :, 2]
|
# pyrefly: ignore [unsupported-operation]
|
||||||
|
points_centered,
|
||||||
|
knns,
|
||||||
|
# pyrefly: ignore [unsupported-operation]
|
||||||
|
local_coord_frames[:, :, :, 2],
|
||||||
)
|
)
|
||||||
# the secondary curvature is just a cross between n and z
|
# the secondary curvature is just a cross between n and z
|
||||||
y = torch.cross(n, z, dim=2)
|
y = torch.cross(n, z, dim=2)
|
||||||
# cat to form the set of principal directions
|
# cat to form the set of principal directions
|
||||||
local_coord_frames = torch.stack((n, y, z), dim=3)
|
local_coord_frames = torch.stack((n, y, z), dim=3)
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-return]
|
||||||
return curvatures, local_coord_frames
|
return curvatures, local_coord_frames
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -141,6 +141,7 @@ class _points_to_volumes_function(Function):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@once_differentiable
|
@once_differentiable
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def backward(ctx, grad_volume_densities, grad_volume_features):
|
def backward(ctx, grad_volume_densities, grad_volume_features):
|
||||||
splat = ctx.splat
|
splat = ctx.splat
|
||||||
N, C = grad_volume_features.shape[:2]
|
N, C = grad_volume_features.shape[:2]
|
||||||
@@ -377,6 +378,7 @@ def add_points_features_to_volume_densities_features(
|
|||||||
if grid_sizes is None:
|
if grid_sizes is None:
|
||||||
# grid sizes shape (minibatch, 3)
|
# grid sizes shape (minibatch, 3)
|
||||||
grid_sizes = (
|
grid_sizes = (
|
||||||
|
# pyrefly: ignore [bad-assignment]
|
||||||
torch.LongTensor(list(volume_densities.shape[2:]))
|
torch.LongTensor(list(volume_densities.shape[2:]))
|
||||||
.to(volume_densities.device)
|
.to(volume_densities.device)
|
||||||
.expand(volume_densities.shape[0], 3)
|
.expand(volume_densities.shape[0], 3)
|
||||||
|
|||||||
@@ -141,6 +141,7 @@ def sample_farthest_points_naive(
|
|||||||
|
|
||||||
for n in range(N):
|
for n in range(N):
|
||||||
# Initialize an array for the sampled indices, shape: (max_K,)
|
# Initialize an array for the sampled indices, shape: (max_K,)
|
||||||
|
# pyrefly: ignore [no-matching-overload]
|
||||||
sample_idx_batch = torch.full(
|
sample_idx_batch = torch.full(
|
||||||
# pyre-fixme[6]: For 1st param expected `Union[List[int], Size,
|
# pyre-fixme[6]: For 1st param expected `Union[List[int], Size,
|
||||||
# typing.Tuple[int, ...]]` but got `Tuple[Tensor]`.
|
# typing.Tuple[int, ...]]` but got `Tuple[Tensor]`.
|
||||||
|
|||||||
@@ -143,6 +143,7 @@ def convert_pointclouds_to_tensor(pcl: Union[torch.Tensor, "Pointclouds"]):
|
|||||||
elif torch.is_tensor(pcl):
|
elif torch.is_tensor(pcl):
|
||||||
X = pcl
|
X = pcl
|
||||||
num_points = X.shape[1] * torch.ones( # type: ignore
|
num_points = X.shape[1] * torch.ones( # type: ignore
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
X.shape[0],
|
X.shape[0],
|
||||||
device=X.device,
|
device=X.device,
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
|
|||||||
@@ -101,6 +101,7 @@ class _SigmoidAlphaBlend(torch.autograd.Function):
|
|||||||
return alphas
|
return alphas
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def backward(ctx, grad_alphas):
|
def backward(ctx, grad_alphas):
|
||||||
dists, pix_to_face, alphas = ctx.saved_tensors
|
dists, pix_to_face, alphas = ctx.saved_tensors
|
||||||
sigma = ctx.sigma
|
sigma = ctx.sigma
|
||||||
@@ -212,8 +213,10 @@ def softmax_rgb_blend(
|
|||||||
|
|
||||||
# Reshape to be compatible with (N, H, W, K) values in fragments
|
# Reshape to be compatible with (N, H, W, K) values in fragments
|
||||||
if torch.is_tensor(zfar):
|
if torch.is_tensor(zfar):
|
||||||
|
# pyrefly: ignore [bad-index]
|
||||||
zfar = zfar[:, None, None, None]
|
zfar = zfar[:, None, None, None]
|
||||||
if torch.is_tensor(znear):
|
if torch.is_tensor(znear):
|
||||||
|
# pyrefly: ignore [bad-index]
|
||||||
znear = znear[:, None, None, None]
|
znear = znear[:, None, None, None]
|
||||||
|
|
||||||
# pyre-fixme[6]: Expected `float` but got `Union[float, Tensor]`
|
# pyre-fixme[6]: Expected `float` but got `Union[float, Tensor]`
|
||||||
|
|||||||
@@ -86,6 +86,7 @@ def _opencv_from_cameras_projection(
|
|||||||
scale = scale.expand(-1, 2)
|
scale = scale.expand(-1, 2)
|
||||||
c0 = image_size_wh / 2.0
|
c0 = image_size_wh / 2.0
|
||||||
|
|
||||||
|
# pyrefly: ignore [unsupported-operation]
|
||||||
principal_point = -p0_pytorch3d * scale + c0
|
principal_point = -p0_pytorch3d * scale + c0
|
||||||
focal_length = focal_pytorch3d * scale
|
focal_length = focal_pytorch3d * scale
|
||||||
|
|
||||||
|
|||||||
@@ -202,8 +202,10 @@ def join_cameras_as_batch(cameras_list: Sequence[CamerasBase]) -> CamerasBase:
|
|||||||
# In the init, all inputs will be converted to
|
# In the init, all inputs will be converted to
|
||||||
# batched tensors before set as attributes
|
# batched tensors before set as attributes
|
||||||
# Join as a tensor along the batch dimension
|
# Join as a tensor along the batch dimension
|
||||||
|
# pyrefly: ignore [unsupported-operation]
|
||||||
kwargs[field] = torch.cat(attrs_list, dim=0)
|
kwargs[field] = torch.cat(attrs_list, dim=0)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Field {field} type is not supported for batching")
|
raise ValueError(f"Field {field} type is not supported for batching")
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
return c0.__class__(**kwargs)
|
return c0.__class__(**kwargs)
|
||||||
|
|||||||
@@ -362,6 +362,7 @@ class CamerasBase(TensorProperties):
|
|||||||
self, with_xyflip=with_xyflip, image_size=image_size
|
self, with_xyflip=with_xyflip, image_size=image_size
|
||||||
).transform_points(points_ndc, eps=eps)
|
).transform_points(points_ndc, eps=eps)
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def clone(self):
|
def clone(self):
|
||||||
"""
|
"""
|
||||||
Returns a copy of `self`.
|
Returns a copy of `self`.
|
||||||
@@ -390,6 +391,7 @@ class CamerasBase(TensorProperties):
|
|||||||
"""
|
"""
|
||||||
return getattr(self, "image_size", None)
|
return getattr(self, "image_size", None)
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def __getitem__(
|
def __getitem__(
|
||||||
self, index: Union[int, List[int], torch.BoolTensor, torch.LongTensor]
|
self, index: Union[int, List[int], torch.BoolTensor, torch.LongTensor]
|
||||||
) -> "CamerasBase":
|
) -> "CamerasBase":
|
||||||
@@ -455,11 +457,14 @@ class CamerasBase(TensorProperties):
|
|||||||
elif isinstance(val, torch.Tensor):
|
elif isinstance(val, torch.Tensor):
|
||||||
# In the init, all inputs will be converted to
|
# In the init, all inputs will be converted to
|
||||||
# tensors before setting as attributes
|
# tensors before setting as attributes
|
||||||
|
# pyrefly: ignore [unsupported-operation]
|
||||||
kwargs[field] = val[index]
|
kwargs[field] = val[index]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Field {field} type is not supported for indexing")
|
raise ValueError(f"Field {field} type is not supported for indexing")
|
||||||
|
|
||||||
|
# pyrefly: ignore [unsupported-operation]
|
||||||
kwargs["device"] = self.device
|
kwargs["device"] = self.device
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
return self.__class__(**kwargs)
|
return self.__class__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
@@ -1741,7 +1746,14 @@ def look_at_view_transform(
|
|||||||
dist, elev, azim, at, up = broadcasted_args
|
dist, elev, azim, at, up = broadcasted_args
|
||||||
C = (
|
C = (
|
||||||
camera_position_from_spherical_angles(
|
camera_position_from_spherical_angles(
|
||||||
dist, elev, azim, degrees=degrees, device=device
|
# pyrefly: ignore [bad-argument-type]
|
||||||
|
dist,
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
|
elev,
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
|
azim,
|
||||||
|
degrees=degrees,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
+ at
|
+ at
|
||||||
)
|
)
|
||||||
@@ -1787,6 +1799,7 @@ def get_ndc_to_screen_transform(
|
|||||||
K = torch.zeros((cameras._N, 4, 4), device=cameras.device, dtype=torch.float32)
|
K = torch.zeros((cameras._N, 4, 4), device=cameras.device, dtype=torch.float32)
|
||||||
if not torch.is_tensor(image_size):
|
if not torch.is_tensor(image_size):
|
||||||
image_size = torch.tensor(image_size, device=cameras.device)
|
image_size = torch.tensor(image_size, device=cameras.device)
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
image_size = image_size.view(-1, 2) # of shape (1 or B)x2
|
image_size = image_size.view(-1, 2) # of shape (1 or B)x2
|
||||||
height, width = image_size.unbind(1)
|
height, width = image_size.unbind(1)
|
||||||
|
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ class _CompositeAlphaPoints(torch.autograd.Function):
|
|||||||
return pt_cld
|
return pt_cld
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
grad_features = None
|
grad_features = None
|
||||||
grad_alphas = None
|
grad_alphas = None
|
||||||
@@ -130,6 +131,7 @@ class _CompositeNormWeightedSumPoints(torch.autograd.Function):
|
|||||||
return pt_cld
|
return pt_cld
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
grad_features = None
|
grad_features = None
|
||||||
grad_alphas = None
|
grad_alphas = None
|
||||||
@@ -208,6 +210,7 @@ class _CompositeWeightedSumPoints(torch.autograd.Function):
|
|||||||
return pt_cld
|
return pt_cld
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
grad_features = None
|
grad_features = None
|
||||||
grad_alphas = None
|
grad_alphas = None
|
||||||
|
|||||||
@@ -124,6 +124,7 @@ class FishEyeCameras(CamerasBase):
|
|||||||
else:
|
else:
|
||||||
self.image_size = None
|
self.image_size = None
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-assignment]
|
||||||
self.device = device
|
self.device = device
|
||||||
self.focal = focal_length.to(self.device)
|
self.focal = focal_length.to(self.device)
|
||||||
self.principal_point = principal_point.to(self.device)
|
self.principal_point = principal_point.to(self.device)
|
||||||
|
|||||||
@@ -136,13 +136,16 @@ class HarmonicEmbedding(torch.nn.Module):
|
|||||||
[..., (n_harmonic_functions * 2 + int(append_input)) * num_points_per_ray]
|
[..., (n_harmonic_functions * 2 + int(append_input)) * num_points_per_ray]
|
||||||
"""
|
"""
|
||||||
# [..., dim, n_harmonic_functions]
|
# [..., dim, n_harmonic_functions]
|
||||||
|
# pyrefly: ignore [unsupported-operation]
|
||||||
embed = x[..., None] * self._frequencies
|
embed = x[..., None] * self._frequencies
|
||||||
# [..., 1, dim, n_harmonic_functions] + [2, 1, 1] => [..., 2, dim, n_harmonic_functions]
|
# [..., 1, dim, n_harmonic_functions] + [2, 1, 1] => [..., 2, dim, n_harmonic_functions]
|
||||||
|
# pyrefly: ignore [bad-index]
|
||||||
embed = embed[..., None, :, :] + self._zero_half_pi[..., None, None]
|
embed = embed[..., None, :, :] + self._zero_half_pi[..., None, None]
|
||||||
# Use the trig identity cos(x) = sin(x + pi/2)
|
# Use the trig identity cos(x) = sin(x + pi/2)
|
||||||
# and do one vectorized call to sin([x, x+pi/2]) instead of (sin(x), cos(x)).
|
# and do one vectorized call to sin([x, x+pi/2]) instead of (sin(x), cos(x)).
|
||||||
embed = embed.sin()
|
embed = embed.sin()
|
||||||
if diag_cov is not None:
|
if diag_cov is not None:
|
||||||
|
# pyrefly: ignore [no-matching-overload]
|
||||||
x_var = diag_cov[..., None] * torch.pow(self._frequencies, 2)
|
x_var = diag_cov[..., None] * torch.pow(self._frequencies, 2)
|
||||||
exp_var = torch.exp(-0.5 * x_var)
|
exp_var = torch.exp(-0.5 * x_var)
|
||||||
# [..., 2, dim, n_harmonic_functions]
|
# [..., 2, dim, n_harmonic_functions]
|
||||||
@@ -180,5 +183,9 @@ class HarmonicEmbedding(torch.nn.Module):
|
|||||||
so the input might be xyz.
|
so the input might be xyz.
|
||||||
"""
|
"""
|
||||||
return self.get_output_dim_static(
|
return self.get_output_dim_static(
|
||||||
input_dims, len(self._frequencies), self.append_input
|
# pyrefly: ignore [bad-argument-type]
|
||||||
|
input_dims,
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
|
len(self._frequencies),
|
||||||
|
self.append_input,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -226,6 +226,7 @@ def _check_raymarcher_inputs(
|
|||||||
if not z_can_be_none and rays_z.shape != rays_shape:
|
if not z_can_be_none and rays_z.shape != rays_shape:
|
||||||
raise ValueError("rays_z have to be of the same shape as rays_densities.")
|
raise ValueError("rays_z have to be of the same shape as rays_densities.")
|
||||||
|
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
if not features_can_be_none and rays_features.shape[:-1] != rays_shape:
|
if not features_can_be_none and rays_features.shape[:-1] != rays_shape:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The first to previous to last dimensions of rays_features"
|
"The first to previous to last dimensions of rays_features"
|
||||||
|
|||||||
@@ -197,6 +197,7 @@ class MultinomialRaysampler(torch.nn.Module):
|
|||||||
"`n_rays_total` and `n_rays_per_image` cannot both be defined."
|
"`n_rays_total` and `n_rays_per_image` cannot both be defined."
|
||||||
)
|
)
|
||||||
if n_rays_total:
|
if n_rays_total:
|
||||||
|
# pyrefly: ignore [bad-assignment]
|
||||||
(
|
(
|
||||||
cameras,
|
cameras,
|
||||||
mask,
|
mask,
|
||||||
@@ -221,6 +222,7 @@ class MultinomialRaysampler(torch.nn.Module):
|
|||||||
if mask is not None and n_rays_per_image is None:
|
if mask is not None and n_rays_per_image is None:
|
||||||
# if num rays not given, sample according to the smallest mask
|
# if num rays not given, sample according to the smallest mask
|
||||||
n_rays_per_image = (
|
n_rays_per_image = (
|
||||||
|
# pyrefly: ignore [bad-assignment]
|
||||||
n_rays_per_image or mask.sum(dim=(1, 2)).min().int().item()
|
n_rays_per_image or mask.sum(dim=(1, 2)).min().int().item()
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -453,6 +455,7 @@ class MonteCarloRaysampler(torch.nn.Module):
|
|||||||
# of shape (batch_size, n_rays_per_image, 2)
|
# of shape (batch_size, n_rays_per_image, 2)
|
||||||
rays_xy = torch.cat(
|
rays_xy = torch.cat(
|
||||||
[
|
[
|
||||||
|
# pyrefly: ignore [no-matching-overload]
|
||||||
torch.rand(
|
torch.rand(
|
||||||
size=(batch_size, n_rays_per_image, 1),
|
size=(batch_size, n_rays_per_image, 1),
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
|
|||||||
@@ -190,10 +190,12 @@ class DirectionalLights(TensorProperties):
|
|||||||
direction=direction,
|
direction=direction,
|
||||||
)
|
)
|
||||||
_validate_light_properties(self)
|
_validate_light_properties(self)
|
||||||
|
# pyrefly: ignore [bad-index]
|
||||||
if self.direction.shape[-1] != 3:
|
if self.direction.shape[-1] != 3:
|
||||||
msg = "Expected direction to have shape (N, 3); got %r"
|
msg = "Expected direction to have shape (N, 3); got %r"
|
||||||
raise ValueError(msg % repr(self.direction.shape))
|
raise ValueError(msg % repr(self.direction.shape))
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def clone(self):
|
def clone(self):
|
||||||
other = self.__class__(device=self.device)
|
other = self.__class__(device=self.device)
|
||||||
return super().clone(other)
|
return super().clone(other)
|
||||||
@@ -251,10 +253,12 @@ class PointLights(TensorProperties):
|
|||||||
location=location,
|
location=location,
|
||||||
)
|
)
|
||||||
_validate_light_properties(self)
|
_validate_light_properties(self)
|
||||||
|
# pyrefly: ignore [bad-index]
|
||||||
if self.location.shape[-1] != 3:
|
if self.location.shape[-1] != 3:
|
||||||
msg = "Expected location to have shape (N, 3); got %r"
|
msg = "Expected location to have shape (N, 3); got %r"
|
||||||
raise ValueError(msg % repr(self.location.shape))
|
raise ValueError(msg % repr(self.location.shape))
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def clone(self):
|
def clone(self):
|
||||||
other = self.__class__(device=self.device)
|
other = self.__class__(device=self.device)
|
||||||
return super().clone(other)
|
return super().clone(other)
|
||||||
@@ -319,6 +323,7 @@ class AmbientLights(TensorProperties):
|
|||||||
ambient_color = ((1.0, 1.0, 1.0),)
|
ambient_color = ((1.0, 1.0, 1.0),)
|
||||||
super().__init__(ambient_color=ambient_color, device=device)
|
super().__init__(ambient_color=ambient_color, device=device)
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def clone(self):
|
def clone(self):
|
||||||
other = self.__class__(device=self.device)
|
other = self.__class__(device=self.device)
|
||||||
return super().clone(other)
|
return super().clone(other)
|
||||||
@@ -330,7 +335,9 @@ class AmbientLights(TensorProperties):
|
|||||||
return self._zeros_channels(points)
|
return self._zeros_channels(points)
|
||||||
|
|
||||||
def _zeros_channels(self, points: torch.Tensor) -> torch.Tensor:
|
def _zeros_channels(self, points: torch.Tensor) -> torch.Tensor:
|
||||||
|
# pyrefly: ignore [bad-index]
|
||||||
ch = self.ambient_color.shape[-1]
|
ch = self.ambient_color.shape[-1]
|
||||||
|
# pyrefly: ignore [no-matching-overload]
|
||||||
return torch.zeros(*points.shape[:-1], ch, device=points.device)
|
return torch.zeros(*points.shape[:-1], ch, device=points.device)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ class Materials(TensorProperties):
|
|||||||
specular_color=specular_color,
|
specular_color=specular_color,
|
||||||
shininess=shininess,
|
shininess=shininess,
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore [bad-index]
|
||||||
C = self.ambient_color.shape[-1]
|
C = self.ambient_color.shape[-1]
|
||||||
for n in ["ambient_color", "diffuse_color", "specular_color"]:
|
for n in ["ambient_color", "diffuse_color", "specular_color"]:
|
||||||
t = getattr(self, n)
|
t = getattr(self, n)
|
||||||
@@ -62,6 +63,7 @@ class Materials(TensorProperties):
|
|||||||
msg = "shininess should have shape (N); got %r"
|
msg = "shininess should have shape (N); got %r"
|
||||||
raise ValueError(msg % repr(self.shininess.shape))
|
raise ValueError(msg % repr(self.shininess.shape))
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def clone(self):
|
def clone(self):
|
||||||
other = Materials(device=self.device)
|
other = Materials(device=self.device)
|
||||||
return super().clone(other)
|
return super().clone(other)
|
||||||
|
|||||||
@@ -496,7 +496,12 @@ def clip_faces(
|
|||||||
|
|
||||||
# Solve for the points p4, p5 that intersect the clipping plane
|
# Solve for the points p4, p5 that intersect the clipping plane
|
||||||
p, p_barycentric = _find_verts_intersecting_clipping_plane(
|
p, p_barycentric = _find_verts_intersecting_clipping_plane(
|
||||||
faces_case3, p1_face_ind, z_clip_value, perspective_correct
|
# pyrefly: ignore [bad-argument-type]
|
||||||
|
faces_case3,
|
||||||
|
p1_face_ind,
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
|
z_clip_value,
|
||||||
|
perspective_correct,
|
||||||
)
|
)
|
||||||
|
|
||||||
p1, _, _, p4, p5 = p
|
p1, _, _, p4, p5 = p
|
||||||
@@ -540,7 +545,12 @@ def clip_faces(
|
|||||||
|
|
||||||
# Solve for the points p4, p5 that intersect the clipping plane
|
# Solve for the points p4, p5 that intersect the clipping plane
|
||||||
p, p_barycentric = _find_verts_intersecting_clipping_plane(
|
p, p_barycentric = _find_verts_intersecting_clipping_plane(
|
||||||
faces_case4, p1_face_ind, z_clip_value, perspective_correct
|
# pyrefly: ignore [bad-argument-type]
|
||||||
|
faces_case4,
|
||||||
|
p1_face_ind,
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
|
z_clip_value,
|
||||||
|
perspective_correct,
|
||||||
)
|
)
|
||||||
_, p2, p3, p4, p5 = p
|
_, p2, p3, p4, p5 = p
|
||||||
_, p2_barycentric, p3_barycentric, p4_barycentric, p5_barycentric = p_barycentric
|
_, p2_barycentric, p3_barycentric, p4_barycentric, p5_barycentric = p_barycentric
|
||||||
@@ -682,6 +692,7 @@ def convert_clipped_rasterization_to_original_faces(
|
|||||||
# rasterized pixel.
|
# rasterized pixel.
|
||||||
pix_to_conversion_idx = torch.where(
|
pix_to_conversion_idx = torch.where(
|
||||||
pix_to_face_clipped != -1,
|
pix_to_face_clipped != -1,
|
||||||
|
# pyrefly: ignore [unsupported-operation]
|
||||||
faces_clipped_to_conversion_idx[pix_to_face_clipped],
|
faces_clipped_to_conversion_idx[pix_to_face_clipped],
|
||||||
empty,
|
empty,
|
||||||
)
|
)
|
||||||
@@ -709,6 +720,7 @@ def convert_clipped_rasterization_to_original_faces(
|
|||||||
bary_coords_clipped_subset
|
bary_coords_clipped_subset
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# pyrefly: ignore [no-matching-overload]
|
||||||
bary_coords_unclipped_subset = bary_coords_unclipped_subset.reshape([N * 3])
|
bary_coords_unclipped_subset = bary_coords_unclipped_subset.reshape([N * 3])
|
||||||
bary_coords_unclipped[faces_to_convert_mask_expanded] = (
|
bary_coords_unclipped[faces_to_convert_mask_expanded] = (
|
||||||
bary_coords_unclipped_subset
|
bary_coords_unclipped_subset
|
||||||
|
|||||||
@@ -316,6 +316,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
|
|||||||
return pix_to_face, zbuf, barycentric_coords, dists
|
return pix_to_face, zbuf, barycentric_coords, dists
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def backward(ctx, grad_pix_to_face, grad_zbuf, grad_barycentric_coords, grad_dists):
|
def backward(ctx, grad_pix_to_face, grad_zbuf, grad_barycentric_coords, grad_dists):
|
||||||
grad_face_verts = None
|
grad_face_verts = None
|
||||||
grad_mesh_to_face_first_idx = None
|
grad_mesh_to_face_first_idx = None
|
||||||
|
|||||||
@@ -161,6 +161,7 @@ class MeshRasterizer(nn.Module):
|
|||||||
self.cameras = cameras
|
self.cameras = cameras
|
||||||
self.raster_settings = raster_settings
|
self.raster_settings = raster_settings
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def to(self, device):
|
def to(self, device):
|
||||||
# Manually move to device cameras as it is not a subclass of nn.Module
|
# Manually move to device cameras as it is not a subclass of nn.Module
|
||||||
if self.cameras is not None:
|
if self.cameras is not None:
|
||||||
@@ -238,10 +239,12 @@ class MeshRasterizer(nn.Module):
|
|||||||
if raster_settings.perspective_correct is not None:
|
if raster_settings.perspective_correct is not None:
|
||||||
perspective_correct = raster_settings.perspective_correct
|
perspective_correct = raster_settings.perspective_correct
|
||||||
else:
|
else:
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
perspective_correct = cameras.is_perspective()
|
perspective_correct = cameras.is_perspective()
|
||||||
if raster_settings.z_clip_value is not None:
|
if raster_settings.z_clip_value is not None:
|
||||||
z_clip = raster_settings.z_clip_value
|
z_clip = raster_settings.z_clip_value
|
||||||
else:
|
else:
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
znear = cameras.get_znear()
|
znear = cameras.get_znear()
|
||||||
if isinstance(znear, torch.Tensor):
|
if isinstance(znear, torch.Tensor):
|
||||||
znear = znear.min().item()
|
znear = znear.min().item()
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ class MeshRenderer(nn.Module):
|
|||||||
self.rasterizer = rasterizer
|
self.rasterizer = rasterizer
|
||||||
self.shader = shader
|
self.shader = shader
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def to(self, device):
|
def to(self, device):
|
||||||
# Rasterizer and shader have submodules which are not of type nn.Module
|
# Rasterizer and shader have submodules which are not of type nn.Module
|
||||||
self.rasterizer.to(device)
|
self.rasterizer.to(device)
|
||||||
@@ -85,6 +86,7 @@ class MeshRendererWithFragments(nn.Module):
|
|||||||
self.rasterizer = rasterizer
|
self.rasterizer = rasterizer
|
||||||
self.shader = shader
|
self.shader = shader
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def to(self, device):
|
def to(self, device):
|
||||||
# Rasterizer and shader have submodules which are not of type nn.Module
|
# Rasterizer and shader have submodules which are not of type nn.Module
|
||||||
self.rasterizer.to(device)
|
self.rasterizer.to(device)
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ class ShaderBase(nn.Module):
|
|||||||
cameras = self.cameras
|
cameras = self.cameras
|
||||||
if cameras is not None:
|
if cameras is not None:
|
||||||
self.cameras = cameras.to(device)
|
self.cameras = cameras.to(device)
|
||||||
|
# pyrefly: ignore [bad-assignment]
|
||||||
self.materials = self.materials.to(device)
|
self.materials = self.materials.to(device)
|
||||||
self.lights = self.lights.to(device)
|
self.lights = self.lights.to(device)
|
||||||
return self
|
return self
|
||||||
|
|||||||
@@ -261,6 +261,7 @@ class TexturesBase:
|
|||||||
f"Property {p} has unsupported type {type(t)}."
|
f"Property {p} has unsupported type {type(t)}."
|
||||||
"Only tensors and lists are supported."
|
"Only tensors and lists are supported."
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore [bad-return]
|
||||||
return new_props
|
return new_props
|
||||||
|
|
||||||
def _getitem(self, index: Union[int, slice], props: List[str]):
|
def _getitem(self, index: Union[int, slice], props: List[str]):
|
||||||
@@ -275,6 +276,7 @@ class TexturesBase:
|
|||||||
t = t() # class method
|
t = t() # class method
|
||||||
new_props[p] = t[index] if t is not None else None
|
new_props[p] = t[index] if t is not None else None
|
||||||
elif isinstance(index, list):
|
elif isinstance(index, list):
|
||||||
|
# pyrefly: ignore [bad-assignment]
|
||||||
index = torch.tensor(index)
|
index = torch.tensor(index)
|
||||||
if isinstance(index, torch.Tensor):
|
if isinstance(index, torch.Tensor):
|
||||||
if index.dtype == torch.bool:
|
if index.dtype == torch.bool:
|
||||||
@@ -451,6 +453,7 @@ class TexturesAtlas(TexturesBase):
|
|||||||
msg = "Expected atlas to be of shape (N, F, R, R, C); got %r"
|
msg = "Expected atlas to be of shape (N, F, R, R, C); got %r"
|
||||||
raise ValueError(msg % repr(atlas.ndim))
|
raise ValueError(msg % repr(atlas.ndim))
|
||||||
self._atlas_padded = atlas
|
self._atlas_padded = atlas
|
||||||
|
# pyrefly: ignore [bad-assignment]
|
||||||
self._atlas_list = None
|
self._atlas_list = None
|
||||||
self.device = atlas.device
|
self.device = atlas.device
|
||||||
|
|
||||||
@@ -474,6 +477,7 @@ class TexturesAtlas(TexturesBase):
|
|||||||
if self._atlas_list is not None:
|
if self._atlas_list is not None:
|
||||||
tex._atlas_list = [atlas.clone() for atlas in self._atlas_list]
|
tex._atlas_list = [atlas.clone() for atlas in self._atlas_list]
|
||||||
num_faces = (
|
num_faces = (
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
self._num_faces_per_mesh.clone()
|
self._num_faces_per_mesh.clone()
|
||||||
if torch.is_tensor(self._num_faces_per_mesh)
|
if torch.is_tensor(self._num_faces_per_mesh)
|
||||||
else self._num_faces_per_mesh
|
else self._num_faces_per_mesh
|
||||||
@@ -487,6 +491,7 @@ class TexturesAtlas(TexturesBase):
|
|||||||
if self._atlas_list is not None:
|
if self._atlas_list is not None:
|
||||||
tex._atlas_list = [atlas.detach() for atlas in self._atlas_list]
|
tex._atlas_list = [atlas.detach() for atlas in self._atlas_list]
|
||||||
num_faces = (
|
num_faces = (
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
self._num_faces_per_mesh.detach()
|
self._num_faces_per_mesh.detach()
|
||||||
if torch.is_tensor(self._num_faces_per_mesh)
|
if torch.is_tensor(self._num_faces_per_mesh)
|
||||||
else self._num_faces_per_mesh
|
else self._num_faces_per_mesh
|
||||||
@@ -504,9 +509,11 @@ class TexturesAtlas(TexturesBase):
|
|||||||
new_tex = self.__class__(atlas=atlas)
|
new_tex = self.__class__(atlas=atlas)
|
||||||
elif torch.is_tensor(atlas):
|
elif torch.is_tensor(atlas):
|
||||||
# single element
|
# single element
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
new_tex = self.__class__(atlas=[atlas])
|
new_tex = self.__class__(atlas=[atlas])
|
||||||
else:
|
else:
|
||||||
raise ValueError("Not all values are provided in the correct format")
|
raise ValueError("Not all values are provided in the correct format")
|
||||||
|
# pyrefly: ignore [bad-assignment]
|
||||||
new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"]
|
new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"]
|
||||||
return new_tex
|
return new_tex
|
||||||
|
|
||||||
@@ -528,6 +535,7 @@ class TexturesAtlas(TexturesBase):
|
|||||||
self._atlas_padded = [
|
self._atlas_padded = [
|
||||||
torch.empty((0, 0, 0, 3), dtype=torch.float32, device=self.device)
|
torch.empty((0, 0, 0, 3), dtype=torch.float32, device=self.device)
|
||||||
] * self._N
|
] * self._N
|
||||||
|
# pyrefly: ignore [bad-assignment]
|
||||||
self._atlas_list = _padded_to_list_wrapper(
|
self._atlas_list = _padded_to_list_wrapper(
|
||||||
self._atlas_padded, split_size=self._num_faces_per_mesh
|
self._atlas_padded, split_size=self._num_faces_per_mesh
|
||||||
)
|
)
|
||||||
@@ -544,6 +552,7 @@ class TexturesAtlas(TexturesBase):
|
|||||||
def extend(self, N: int) -> "TexturesAtlas":
|
def extend(self, N: int) -> "TexturesAtlas":
|
||||||
new_props = self._extend(N, ["atlas_padded", "_num_faces_per_mesh"])
|
new_props = self._extend(N, ["atlas_padded", "_num_faces_per_mesh"])
|
||||||
new_tex = self.__class__(atlas=new_props["atlas_padded"])
|
new_tex = self.__class__(atlas=new_props["atlas_padded"])
|
||||||
|
# pyrefly: ignore [bad-assignment]
|
||||||
new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"]
|
new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"]
|
||||||
return new_tex
|
return new_tex
|
||||||
|
|
||||||
@@ -790,6 +799,7 @@ class TexturesUV(TexturesBase):
|
|||||||
msg = "Expected faces_uvs to be of shape (N, F, 3); got %r"
|
msg = "Expected faces_uvs to be of shape (N, F, 3); got %r"
|
||||||
raise ValueError(msg % repr(faces_uvs.shape))
|
raise ValueError(msg % repr(faces_uvs.shape))
|
||||||
self._faces_uvs_padded = faces_uvs
|
self._faces_uvs_padded = faces_uvs
|
||||||
|
# pyrefly: ignore [bad-assignment]
|
||||||
self._faces_uvs_list = None
|
self._faces_uvs_list = None
|
||||||
self.device = faces_uvs.device
|
self.device = faces_uvs.device
|
||||||
|
|
||||||
@@ -826,6 +836,7 @@ class TexturesUV(TexturesBase):
|
|||||||
msg = "Expected verts_uvs to be of shape (N, V, 2); got %r"
|
msg = "Expected verts_uvs to be of shape (N, V, 2); got %r"
|
||||||
raise ValueError(msg % repr(verts_uvs.shape))
|
raise ValueError(msg % repr(verts_uvs.shape))
|
||||||
self._verts_uvs_padded = verts_uvs
|
self._verts_uvs_padded = verts_uvs
|
||||||
|
# pyrefly: ignore [bad-assignment]
|
||||||
self._verts_uvs_list = None
|
self._verts_uvs_list = None
|
||||||
|
|
||||||
if verts_uvs.device != self.device:
|
if verts_uvs.device != self.device:
|
||||||
@@ -838,6 +849,7 @@ class TexturesUV(TexturesBase):
|
|||||||
if isinstance(maps, (list, tuple)):
|
if isinstance(maps, (list, tuple)):
|
||||||
self._maps_list = maps
|
self._maps_list = maps
|
||||||
else:
|
else:
|
||||||
|
# pyrefly: ignore [bad-assignment]
|
||||||
self._maps_list = None
|
self._maps_list = None
|
||||||
self._maps_padded = self._format_maps_padded(maps)
|
self._maps_padded = self._format_maps_padded(maps)
|
||||||
|
|
||||||
@@ -966,6 +978,7 @@ class TexturesUV(TexturesBase):
|
|||||||
if self._maps_ids_list is not None:
|
if self._maps_ids_list is not None:
|
||||||
tex._maps_ids_list = [f.clone() for f in self._maps_ids_list]
|
tex._maps_ids_list = [f.clone() for f in self._maps_ids_list]
|
||||||
num_faces = (
|
num_faces = (
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
self._num_faces_per_mesh.clone()
|
self._num_faces_per_mesh.clone()
|
||||||
if torch.is_tensor(self._num_faces_per_mesh)
|
if torch.is_tensor(self._num_faces_per_mesh)
|
||||||
else self._num_faces_per_mesh
|
else self._num_faces_per_mesh
|
||||||
@@ -997,6 +1010,7 @@ class TexturesUV(TexturesBase):
|
|||||||
if self._maps_ids_list is not None:
|
if self._maps_ids_list is not None:
|
||||||
tex._maps_ids_list = [mi.detach() for mi in self._maps_ids_list]
|
tex._maps_ids_list = [mi.detach() for mi in self._maps_ids_list]
|
||||||
num_faces = (
|
num_faces = (
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
self._num_faces_per_mesh.detach()
|
self._num_faces_per_mesh.detach()
|
||||||
if torch.is_tensor(self._num_faces_per_mesh)
|
if torch.is_tensor(self._num_faces_per_mesh)
|
||||||
else self._num_faces_per_mesh
|
else self._num_faces_per_mesh
|
||||||
@@ -1026,8 +1040,11 @@ class TexturesUV(TexturesBase):
|
|||||||
"Maps ids are not in the correct format expected list or tuple"
|
"Maps ids are not in the correct format expected list or tuple"
|
||||||
)
|
)
|
||||||
new_tex = self.__class__(
|
new_tex = self.__class__(
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
faces_uvs=faces_uvs,
|
faces_uvs=faces_uvs,
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
verts_uvs=verts_uvs,
|
verts_uvs=verts_uvs,
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
maps=maps,
|
maps=maps,
|
||||||
maps_ids=maps_ids,
|
maps_ids=maps_ids,
|
||||||
padding_mode=self.padding_mode,
|
padding_mode=self.padding_mode,
|
||||||
@@ -1040,8 +1057,11 @@ class TexturesUV(TexturesBase):
|
|||||||
"Maps ids are not in the correct format expected tensor"
|
"Maps ids are not in the correct format expected tensor"
|
||||||
)
|
)
|
||||||
new_tex = self.__class__(
|
new_tex = self.__class__(
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
faces_uvs=[faces_uvs],
|
faces_uvs=[faces_uvs],
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
verts_uvs=[verts_uvs],
|
verts_uvs=[verts_uvs],
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
maps=[maps],
|
maps=[maps],
|
||||||
maps_ids=[maps_ids] if maps_ids is not None else None,
|
maps_ids=[maps_ids] if maps_ids is not None else None,
|
||||||
padding_mode=self.padding_mode,
|
padding_mode=self.padding_mode,
|
||||||
@@ -1050,6 +1070,7 @@ class TexturesUV(TexturesBase):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Not all values are provided in the correct format")
|
raise ValueError("Not all values are provided in the correct format")
|
||||||
|
# pyrefly: ignore [bad-assignment]
|
||||||
new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"]
|
new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"]
|
||||||
return new_tex
|
return new_tex
|
||||||
|
|
||||||
@@ -1072,9 +1093,11 @@ class TexturesUV(TexturesBase):
|
|||||||
torch.empty((0, 3), dtype=torch.float32, device=self.device)
|
torch.empty((0, 3), dtype=torch.float32, device=self.device)
|
||||||
] * self._N
|
] * self._N
|
||||||
else:
|
else:
|
||||||
|
# pyrefly: ignore [bad-assignment]
|
||||||
self._faces_uvs_list = padded_to_list(
|
self._faces_uvs_list = padded_to_list(
|
||||||
self._faces_uvs_padded, split_size=self._num_faces_per_mesh
|
self._faces_uvs_padded, split_size=self._num_faces_per_mesh
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore [bad-return]
|
||||||
return self._faces_uvs_list
|
return self._faces_uvs_list
|
||||||
|
|
||||||
def verts_uvs_padded(self) -> torch.Tensor:
|
def verts_uvs_padded(self) -> torch.Tensor:
|
||||||
@@ -1099,7 +1122,9 @@ class TexturesUV(TexturesBase):
|
|||||||
# The number of vertices in the mesh and in verts_uvs can differ
|
# The number of vertices in the mesh and in verts_uvs can differ
|
||||||
# e.g. if a vertex is shared between 3 faces, it can
|
# e.g. if a vertex is shared between 3 faces, it can
|
||||||
# have up to 3 different uv coordinates.
|
# have up to 3 different uv coordinates.
|
||||||
|
# pyrefly: ignore [bad-assignment]
|
||||||
self._verts_uvs_list = list(self._verts_uvs_padded.unbind(0))
|
self._verts_uvs_list = list(self._verts_uvs_padded.unbind(0))
|
||||||
|
# pyrefly: ignore [bad-return]
|
||||||
return self._verts_uvs_list
|
return self._verts_uvs_list
|
||||||
|
|
||||||
def maps_ids_padded(self) -> Optional[torch.Tensor]:
|
def maps_ids_padded(self) -> Optional[torch.Tensor]:
|
||||||
@@ -1107,8 +1132,10 @@ class TexturesUV(TexturesBase):
|
|||||||
|
|
||||||
def maps_ids_list(self) -> Optional[List[torch.Tensor]]:
|
def maps_ids_list(self) -> Optional[List[torch.Tensor]]:
|
||||||
if self._maps_ids_list is not None:
|
if self._maps_ids_list is not None:
|
||||||
|
# pyrefly: ignore [bad-return]
|
||||||
return self._maps_ids_list
|
return self._maps_ids_list
|
||||||
elif self._maps_ids_padded is not None:
|
elif self._maps_ids_padded is not None:
|
||||||
|
# pyrefly: ignore [bad-return]
|
||||||
return self._maps_ids_padded.unbind(0)
|
return self._maps_ids_padded.unbind(0)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
@@ -1143,6 +1170,7 @@ class TexturesUV(TexturesBase):
|
|||||||
sampling_mode=self.sampling_mode,
|
sampling_mode=self.sampling_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-assignment]
|
||||||
new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"]
|
new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"]
|
||||||
return new_tex
|
return new_tex
|
||||||
|
|
||||||
@@ -1716,6 +1744,7 @@ class TexturesVertex(TexturesBase):
|
|||||||
msg = "Expected verts_features to be of shape (N, V, C); got %r"
|
msg = "Expected verts_features to be of shape (N, V, C); got %r"
|
||||||
raise ValueError(msg % repr(verts_features.shape))
|
raise ValueError(msg % repr(verts_features.shape))
|
||||||
self._verts_features_padded = verts_features
|
self._verts_features_padded = verts_features
|
||||||
|
# pyrefly: ignore [bad-assignment]
|
||||||
self._verts_features_list = None
|
self._verts_features_list = None
|
||||||
self.device = verts_features.device
|
self.device = verts_features.device
|
||||||
|
|
||||||
@@ -1763,9 +1792,11 @@ class TexturesVertex(TexturesBase):
|
|||||||
)
|
)
|
||||||
new_tex = self.__class__(verts_features=verts_features)
|
new_tex = self.__class__(verts_features=verts_features)
|
||||||
elif torch.is_tensor(verts_features):
|
elif torch.is_tensor(verts_features):
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
new_tex = self.__class__(verts_features=[verts_features])
|
new_tex = self.__class__(verts_features=[verts_features])
|
||||||
else:
|
else:
|
||||||
raise ValueError("Not all values are provided in the correct format")
|
raise ValueError("Not all values are provided in the correct format")
|
||||||
|
# pyrefly: ignore [bad-assignment]
|
||||||
new_tex._num_verts_per_mesh = new_props["_num_verts_per_mesh"]
|
new_tex._num_verts_per_mesh = new_props["_num_verts_per_mesh"]
|
||||||
return new_tex
|
return new_tex
|
||||||
|
|
||||||
@@ -1788,9 +1819,11 @@ class TexturesVertex(TexturesBase):
|
|||||||
torch.empty((0, 3), dtype=torch.float32, device=self.device)
|
torch.empty((0, 3), dtype=torch.float32, device=self.device)
|
||||||
] * self._N
|
] * self._N
|
||||||
else:
|
else:
|
||||||
|
# pyrefly: ignore [bad-assignment]
|
||||||
self._verts_features_list = padded_to_list(
|
self._verts_features_list = padded_to_list(
|
||||||
self._verts_features_padded, split_size=self._num_verts_per_mesh
|
self._verts_features_padded, split_size=self._num_verts_per_mesh
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore [bad-return]
|
||||||
return self._verts_features_list
|
return self._verts_features_list
|
||||||
|
|
||||||
def verts_features_packed(self) -> torch.Tensor:
|
def verts_features_packed(self) -> torch.Tensor:
|
||||||
@@ -1802,6 +1835,7 @@ class TexturesVertex(TexturesBase):
|
|||||||
def extend(self, N: int) -> "TexturesVertex":
|
def extend(self, N: int) -> "TexturesVertex":
|
||||||
new_props = self._extend(N, ["verts_features_padded", "_num_verts_per_mesh"])
|
new_props = self._extend(N, ["verts_features_padded", "_num_verts_per_mesh"])
|
||||||
new_tex = self.__class__(verts_features=new_props["verts_features_padded"])
|
new_tex = self.__class__(verts_features=new_props["verts_features_padded"])
|
||||||
|
# pyrefly: ignore [bad-assignment]
|
||||||
new_tex._num_verts_per_mesh = new_props["_num_verts_per_mesh"]
|
new_tex._num_verts_per_mesh = new_props["_num_verts_per_mesh"]
|
||||||
return new_tex
|
return new_tex
|
||||||
|
|
||||||
|
|||||||
@@ -114,6 +114,7 @@ def _get_cuda_device(requested_device_id: int):
|
|||||||
|
|
||||||
# Iterate over all the EGL devices, and check if their CUDA ID matches the request.
|
# Iterate over all the EGL devices, and check if their CUDA ID matches the request.
|
||||||
for device in devices:
|
for device in devices:
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
available_device_id = egl.EGLAttrib(ctypes.c_int(-1))
|
available_device_id = egl.EGLAttrib(ctypes.c_int(-1))
|
||||||
# pyre-ignore Undefined attribute [16]
|
# pyre-ignore Undefined attribute [16]
|
||||||
egl.eglQueryDeviceAttribEXT(device, EGL_CUDA_DEVICE_NV, available_device_id)
|
egl.eglQueryDeviceAttribEXT(device, EGL_CUDA_DEVICE_NV, available_device_id)
|
||||||
|
|||||||
@@ -213,6 +213,7 @@ class MeshRasterizerOpenGL(nn.Module):
|
|||||||
dists=None,
|
dists=None,
|
||||||
).detach()
|
).detach()
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def to(self, device):
|
def to(self, device):
|
||||||
# Manually move to device cameras as it is not a subclass of nn.Module
|
# Manually move to device cameras as it is not a subclass of nn.Module
|
||||||
if self.cameras is not None:
|
if self.cameras is not None:
|
||||||
@@ -276,6 +277,7 @@ class _OpenGLMachinery:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
self.initialize_device_data(meshes_gl_ndc.device)
|
self.initialize_device_data(meshes_gl_ndc.device)
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
with self.egl_context.active_and_locked():
|
with self.egl_context.active_and_locked():
|
||||||
# Perspective projection happens in OpenGL. Move the matrix over if there's only
|
# Perspective projection happens in OpenGL. Move the matrix over if there's only
|
||||||
# a single camera shared by all the meshes.
|
# a single camera shared by all the meshes.
|
||||||
@@ -370,11 +372,15 @@ class _OpenGLMachinery:
|
|||||||
"""
|
"""
|
||||||
# Finish all current operations.
|
# Finish all current operations.
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
self.cuda_context.synchronize()
|
self.cuda_context.synchronize()
|
||||||
|
|
||||||
# Free pycuda resources.
|
# Free pycuda resources.
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
self.cuda_context.push()
|
self.cuda_context.push()
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
self.cuda_buffer.unregister()
|
self.cuda_buffer.unregister()
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
self.cuda_context.pop()
|
self.cuda_context.pop()
|
||||||
|
|
||||||
# Free GL resources.
|
# Free GL resources.
|
||||||
@@ -391,6 +397,7 @@ class _OpenGLMachinery:
|
|||||||
del self.mesh_buffer_object
|
del self.mesh_buffer_object
|
||||||
|
|
||||||
gl.glDeleteProgram(self.program)
|
gl.glDeleteProgram(self.program)
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
self.egl_context.release()
|
self.egl_context.release()
|
||||||
|
|
||||||
def _projection_matrix_to_opengl(self, projection_matrix: torch.Tensor) -> None:
|
def _projection_matrix_to_opengl(self, projection_matrix: torch.Tensor) -> None:
|
||||||
|
|||||||
@@ -171,6 +171,7 @@ class _Render(torch.autograd.Function):
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def backward(ctx, grad_im, *args):
|
def backward(ctx, grad_im, *args):
|
||||||
global GAMMA_WARNING_EMITTED
|
global GAMMA_WARNING_EMITTED
|
||||||
(
|
(
|
||||||
|
|||||||
@@ -121,7 +121,9 @@ class PulsarPointsRenderer(nn.Module):
|
|||||||
"gamma is a required keyword argument for the PulsarPointsRenderer!"
|
"gamma is a required keyword argument for the PulsarPointsRenderer!"
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
len(point_clouds) != len(self.rasterizer.cameras)
|
len(point_clouds) != len(self.rasterizer.cameras)
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
and len(self.rasterizer.cameras) != 1
|
and len(self.rasterizer.cameras) != 1
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -132,6 +134,7 @@ class PulsarPointsRenderer(nn.Module):
|
|||||||
)
|
)
|
||||||
% (
|
% (
|
||||||
len(point_clouds),
|
len(point_clouds),
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
len(self.rasterizer.cameras),
|
len(self.rasterizer.cameras),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -141,6 +144,7 @@ class PulsarPointsRenderer(nn.Module):
|
|||||||
self.rasterizer.cameras, (FoVOrthographicCameras, OrthographicCameras)
|
self.rasterizer.cameras, (FoVOrthographicCameras, OrthographicCameras)
|
||||||
)
|
)
|
||||||
if orthogonal_projection != self.renderer._renderer.orthogonal:
|
if orthogonal_projection != self.renderer._renderer.orthogonal:
|
||||||
|
# pyrefly: ignore [unsupported-operation]
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The camera type can not be changed after renderer initialization! "
|
"The camera type can not be changed after renderer initialization! "
|
||||||
"Current camera orthogonal: %r. Original orthogonal: %r."
|
"Current camera orthogonal: %r. Original orthogonal: %r."
|
||||||
@@ -219,6 +223,7 @@ class PulsarPointsRenderer(nn.Module):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Currently, this means it must be an 'OrthographicCameras' object.
|
# Currently, this means it must be an 'OrthographicCameras' object.
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
focal_length_conf = kwargs.get("focal_length", cameras.focal_length)[
|
focal_length_conf = kwargs.get("focal_length", cameras.focal_length)[
|
||||||
cloud_idx
|
cloud_idx
|
||||||
]
|
]
|
||||||
@@ -249,11 +254,13 @@ class PulsarPointsRenderer(nn.Module):
|
|||||||
znear = kwargs["znear"][cloud_idx]
|
znear = kwargs["znear"][cloud_idx]
|
||||||
zfar = kwargs["zfar"][cloud_idx]
|
zfar = kwargs["zfar"][cloud_idx]
|
||||||
principal_point_x = (
|
principal_point_x = (
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
kwargs.get("principal_point", cameras.principal_point)[cloud_idx][0]
|
kwargs.get("principal_point", cameras.principal_point)[cloud_idx][0]
|
||||||
* 0.5
|
* 0.5
|
||||||
* self.renderer._renderer.width
|
* self.renderer._renderer.width
|
||||||
)
|
)
|
||||||
principal_point_y = (
|
principal_point_y = (
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
kwargs.get("principal_point", cameras.principal_point)[cloud_idx][1]
|
kwargs.get("principal_point", cameras.principal_point)[cloud_idx][1]
|
||||||
* 0.5
|
* 0.5
|
||||||
* self.renderer._renderer.height
|
* self.renderer._renderer.height
|
||||||
@@ -261,20 +268,26 @@ class PulsarPointsRenderer(nn.Module):
|
|||||||
else:
|
else:
|
||||||
if not isinstance(cameras, PerspectiveCameras):
|
if not isinstance(cameras, PerspectiveCameras):
|
||||||
# Create a virtual focal length that is closer than znear.
|
# Create a virtual focal length that is closer than znear.
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
znear = kwargs.get("znear", cameras.znear)[cloud_idx]
|
znear = kwargs.get("znear", cameras.znear)[cloud_idx]
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
zfar = kwargs.get("zfar", cameras.zfar)[cloud_idx]
|
zfar = kwargs.get("zfar", cameras.zfar)[cloud_idx]
|
||||||
focal_length = znear - 1e-6
|
focal_length = znear - 1e-6
|
||||||
# Create a sensor size that matches the expected fov assuming this f.
|
# Create a sensor size that matches the expected fov assuming this f.
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
afov = kwargs.get("fov", cameras.fov)[cloud_idx]
|
afov = kwargs.get("fov", cameras.fov)[cloud_idx]
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
if kwargs.get("degrees", cameras.degrees):
|
if kwargs.get("degrees", cameras.degrees):
|
||||||
afov *= math.pi / 180.0
|
afov *= math.pi / 180.0
|
||||||
sensor_width = math.tan(afov / 2.0) * 2.0 * focal_length
|
sensor_width = math.tan(afov / 2.0) * 2.0 * focal_length
|
||||||
if not (
|
if not (
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
kwargs.get("aspect_ratio", cameras.aspect_ratio)[cloud_idx]
|
kwargs.get("aspect_ratio", cameras.aspect_ratio)[cloud_idx]
|
||||||
- self.renderer._renderer.width / self.renderer._renderer.height
|
- self.renderer._renderer.width / self.renderer._renderer.height
|
||||||
< 1e-6
|
< 1e-6
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
"The aspect ratio ("
|
"The aspect ratio ("
|
||||||
f"{kwargs.get('aspect_ratio', cameras.aspect_ratio)[cloud_idx]}) "
|
f"{kwargs.get('aspect_ratio', cameras.aspect_ratio)[cloud_idx]}) "
|
||||||
"must agree with the resolution width / height ("
|
"must agree with the resolution width / height ("
|
||||||
@@ -361,7 +374,9 @@ class PulsarPointsRenderer(nn.Module):
|
|||||||
"""
|
"""
|
||||||
# Shorthand:
|
# Shorthand:
|
||||||
cameras = self.rasterizer.cameras
|
cameras = self.rasterizer.cameras
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
R = kwargs.get("R", cameras.R)[cloud_idx]
|
R = kwargs.get("R", cameras.R)[cloud_idx]
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
T = kwargs.get("T", cameras.T)[cloud_idx]
|
T = kwargs.get("T", cameras.T)[cloud_idx]
|
||||||
tmp_cams = PerspectiveCameras(
|
tmp_cams = PerspectiveCameras(
|
||||||
R=R.unsqueeze(0), T=T.unsqueeze(0), device=R.device
|
R=R.unsqueeze(0), T=T.unsqueeze(0), device=R.device
|
||||||
@@ -388,6 +403,7 @@ class PulsarPointsRenderer(nn.Module):
|
|||||||
# or itself a tensor.
|
# or itself a tensor.
|
||||||
raster_rad = self.rasterizer.raster_settings.radius
|
raster_rad = self.rasterizer.raster_settings.radius
|
||||||
if kwargs.get("radius_world", False):
|
if kwargs.get("radius_world", False):
|
||||||
|
# pyrefly: ignore [bad-return]
|
||||||
return raster_rad
|
return raster_rad
|
||||||
if (
|
if (
|
||||||
isinstance(raster_rad, torch.Tensor)
|
isinstance(raster_rad, torch.Tensor)
|
||||||
|
|||||||
@@ -216,6 +216,7 @@ class _RasterizePoints(torch.autograd.Function):
|
|||||||
return idx, zbuf, dists
|
return idx, zbuf, dists
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def backward(ctx, grad_idx, grad_zbuf, grad_dists):
|
def backward(ctx, grad_idx, grad_zbuf, grad_dists):
|
||||||
grad_points = None
|
grad_points = None
|
||||||
grad_cloud_to_packed_first_idx = None
|
grad_cloud_to_packed_first_idx = None
|
||||||
|
|||||||
@@ -143,6 +143,7 @@ class PointsRasterizer(nn.Module):
|
|||||||
point_clouds = point_clouds.update_padded(pts_ndc)
|
point_clouds = point_clouds.update_padded(pts_ndc)
|
||||||
return point_clouds
|
return point_clouds
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def to(self, device):
|
def to(self, device):
|
||||||
# Manually move to device cameras as it is not a subclass of nn.Module
|
# Manually move to device cameras as it is not a subclass of nn.Module
|
||||||
if self.cameras is not None:
|
if self.cameras is not None:
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ class PointsRenderer(nn.Module):
|
|||||||
self.rasterizer = rasterizer
|
self.rasterizer = rasterizer
|
||||||
self.compositor = compositor
|
self.compositor = compositor
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def to(self, device):
|
def to(self, device):
|
||||||
# Manually move to device rasterizer as the cameras
|
# Manually move to device rasterizer as the cameras
|
||||||
# within the class are not of type nn.Module
|
# within the class are not of type nn.Module
|
||||||
|
|||||||
@@ -464,6 +464,7 @@ class SplatterBlender(torch.nn.Module):
|
|||||||
input_shape, device
|
input_shape, device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def to(self, device):
|
def to(self, device):
|
||||||
self.offsets = self.offsets.to(device)
|
self.offsets = self.offsets.to(device)
|
||||||
self.crop_ids_h = self.crop_ids_h.to(device)
|
self.crop_ids_h = self.crop_ids_h.to(device)
|
||||||
|
|||||||
@@ -67,12 +67,15 @@ class TensorAccessor(nn.Module):
|
|||||||
if (
|
if (
|
||||||
v.dim() == 0
|
v.dim() == 0
|
||||||
and isinstance(self.index, slice)
|
and isinstance(self.index, slice)
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
and len(value) != len(self.index)
|
and len(value) != len(self.index)
|
||||||
):
|
):
|
||||||
msg = "Expected value to have len %r; got %r"
|
msg = "Expected value to have len %r; got %r"
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
raise ValueError(msg % (len(self.index), len(value)))
|
raise ValueError(msg % (len(self.index), len(value)))
|
||||||
self.class_object.__dict__[name][self.index] = value
|
self.class_object.__dict__[name][self.index] = value
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-override]
|
||||||
def __getattr__(self, name: str):
|
def __getattr__(self, name: str):
|
||||||
"""
|
"""
|
||||||
Return the value of the attribute given by "name" on self.class_object
|
Return the value of the attribute given by "name" on self.class_object
|
||||||
@@ -85,6 +88,7 @@ class TensorAccessor(nn.Module):
|
|||||||
return self.class_object.__dict__[name][self.index]
|
return self.class_object.__dict__[name][self.index]
|
||||||
else:
|
else:
|
||||||
msg = "Attribute %s not found on %r"
|
msg = "Attribute %s not found on %r"
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
return AttributeError(msg % (name, self.class_object.__name__))
|
return AttributeError(msg % (name, self.class_object.__name__))
|
||||||
|
|
||||||
|
|
||||||
@@ -456,4 +460,5 @@ def parse_image_size(
|
|||||||
raise ValueError("Image sizes must be greater than 0; got %d, %d" % image_size)
|
raise ValueError("Image sizes must be greater than 0; got %d, %d" % image_size)
|
||||||
if not all(isinstance(i, int) for i in image_size):
|
if not all(isinstance(i, int) for i in image_size):
|
||||||
raise ValueError("Image sizes must be integers; got %f, %f" % image_size)
|
raise ValueError("Image sizes must be integers; got %f, %f" % image_size)
|
||||||
|
# pyrefly: ignore [bad-return]
|
||||||
return tuple(image_size)
|
return tuple(image_size)
|
||||||
|
|||||||
@@ -432,13 +432,17 @@ class Meshes:
|
|||||||
|
|
||||||
# Set the num verts/faces on the textures if present.
|
# Set the num verts/faces on the textures if present.
|
||||||
if textures is not None:
|
if textures is not None:
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
shape_ok = self.textures.check_shapes(self._N, self._V, self._F)
|
shape_ok = self.textures.check_shapes(self._N, self._V, self._F)
|
||||||
if not shape_ok:
|
if not shape_ok:
|
||||||
msg = "Textures do not match the dimensions of Meshes."
|
msg = "Textures do not match the dimensions of Meshes."
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
self.textures._num_faces_per_mesh = self._num_faces_per_mesh.tolist()
|
self.textures._num_faces_per_mesh = self._num_faces_per_mesh.tolist()
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
self.textures._num_verts_per_mesh = self._num_verts_per_mesh.tolist()
|
self.textures._num_verts_per_mesh = self._num_verts_per_mesh.tolist()
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
self.textures.valid = self.valid
|
self.textures.valid = self.valid
|
||||||
|
|
||||||
if verts_normals is not None:
|
if verts_normals is not None:
|
||||||
@@ -449,6 +453,7 @@ class Meshes:
|
|||||||
if len(verts_normals) != self._N:
|
if len(verts_normals) != self._N:
|
||||||
raise ValueError("Invalid verts_normals input")
|
raise ValueError("Invalid verts_normals input")
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
for item, n_verts in zip(verts_normals, self._num_verts_per_mesh):
|
for item, n_verts in zip(verts_normals, self._num_verts_per_mesh):
|
||||||
if (
|
if (
|
||||||
not isinstance(item, torch.Tensor)
|
not isinstance(item, torch.Tensor)
|
||||||
@@ -466,7 +471,10 @@ class Meshes:
|
|||||||
):
|
):
|
||||||
raise ValueError("Vertex normals tensor has incorrect dimensions.")
|
raise ValueError("Vertex normals tensor has incorrect dimensions.")
|
||||||
self._verts_normals_packed = struct_utils.padded_to_packed(
|
self._verts_normals_packed = struct_utils.padded_to_packed(
|
||||||
verts_normals, split_size=self._num_verts_per_mesh.tolist()
|
# pyrefly: ignore [missing-attribute]
|
||||||
|
verts_normals,
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
|
split_size=self._num_verts_per_mesh.tolist(),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("verts_normals must be a list or tensor")
|
raise ValueError("verts_normals must be a list or tensor")
|
||||||
@@ -497,8 +505,11 @@ class Meshes:
|
|||||||
# NOTE consider converting index to cpu for efficiency
|
# NOTE consider converting index to cpu for efficiency
|
||||||
if index.dtype == torch.bool:
|
if index.dtype == torch.bool:
|
||||||
# advanced indexing on a single dimension
|
# advanced indexing on a single dimension
|
||||||
|
# pyrefly: ignore [bad-assignment]
|
||||||
index = index.nonzero()
|
index = index.nonzero()
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
index = index.squeeze(1) if index.numel() > 0 else index
|
index = index.squeeze(1) if index.numel() > 0 else index
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
index = index.tolist()
|
index = index.tolist()
|
||||||
verts = [self.verts_list()[i] for i in index]
|
verts = [self.verts_list()[i] for i in index]
|
||||||
faces = [self.faces_list()[i] for i in index]
|
faces = [self.faces_list()[i] for i in index]
|
||||||
@@ -521,6 +532,7 @@ class Meshes:
|
|||||||
Returns:
|
Returns:
|
||||||
bool indicating whether there is any data.
|
bool indicating whether there is any data.
|
||||||
"""
|
"""
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
return self._N == 0 or self.valid.eq(False).all()
|
return self._N == 0 or self.valid.eq(False).all()
|
||||||
|
|
||||||
def verts_list(self):
|
def verts_list(self):
|
||||||
@@ -1058,6 +1070,7 @@ class Meshes:
|
|||||||
|
|
||||||
# All edges including duplicates.
|
# All edges including duplicates.
|
||||||
edges = torch.cat([e12, e20, e01], dim=0) # (sum(F_n)*3, 2)
|
edges = torch.cat([e12, e20, e01], dim=0) # (sum(F_n)*3, 2)
|
||||||
|
# pyrefly: ignore [no-matching-overload]
|
||||||
edge_to_mesh = torch.cat(
|
edge_to_mesh = torch.cat(
|
||||||
[
|
[
|
||||||
self._faces_packed_to_mesh_idx,
|
self._faces_packed_to_mesh_idx,
|
||||||
@@ -1082,6 +1095,7 @@ class Meshes:
|
|||||||
# unique_edges[inverse_idxs] == edges
|
# unique_edges[inverse_idxs] == edges
|
||||||
# i.e. inverse_idxs[i] == j means that edges[i] == unique_edges[j]
|
# i.e. inverse_idxs[i] == j means that edges[i] == unique_edges[j]
|
||||||
|
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
V = self._verts_packed.shape[0]
|
V = self._verts_packed.shape[0]
|
||||||
edges_hash = V * edges[:, 0] + edges[:, 1]
|
edges_hash = V * edges[:, 0] + edges[:, 1]
|
||||||
u, inverse_idxs = torch.unique(edges_hash, return_inverse=True)
|
u, inverse_idxs = torch.unique(edges_hash, return_inverse=True)
|
||||||
@@ -1699,6 +1713,7 @@ def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True) ->
|
|||||||
if not tex_types_same:
|
if not tex_types_same:
|
||||||
raise ValueError("All meshes in the batch must have the same type of texture.")
|
raise ValueError("All meshes in the batch must have the same type of texture.")
|
||||||
|
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
tex = first.join_batch(all_textures[1:])
|
tex = first.join_batch(all_textures[1:])
|
||||||
return Meshes(verts=verts, faces=faces, textures=tex)
|
return Meshes(verts=verts, faces=faces, textures=tex)
|
||||||
|
|
||||||
|
|||||||
@@ -319,6 +319,7 @@ class Pointclouds:
|
|||||||
|
|
||||||
if len(aux_input) != self._N:
|
if len(aux_input) != self._N:
|
||||||
raise ValueError("Points and auxiliary input must be the same length.")
|
raise ValueError("Points and auxiliary input must be the same length.")
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
for p, d in zip(self._num_points_per_cloud, aux_input):
|
for p, d in zip(self._num_points_per_cloud, aux_input):
|
||||||
valid_but_empty = p == 0 and d is not None and d.ndim == 2
|
valid_but_empty = p == 0 and d is not None and d.ndim == 2
|
||||||
if p > 0 or valid_but_empty:
|
if p > 0 or valid_but_empty:
|
||||||
@@ -350,6 +351,7 @@ class Pointclouds:
|
|||||||
if good_empty is None:
|
if good_empty is None:
|
||||||
good_empty = torch.zeros((0, aux_input_C), device=self.device)
|
good_empty = torch.zeros((0, aux_input_C), device=self.device)
|
||||||
aux_input_out = []
|
aux_input_out = []
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
for p, d in zip(self._num_points_per_cloud, aux_input):
|
for p, d in zip(self._num_points_per_cloud, aux_input):
|
||||||
valid_but_empty = p == 0 and d is not None and d.ndim == 2
|
valid_but_empty = p == 0 and d is not None and d.ndim == 2
|
||||||
if p > 0 or valid_but_empty:
|
if p > 0 or valid_but_empty:
|
||||||
@@ -403,8 +405,11 @@ class Pointclouds:
|
|||||||
# NOTE consider converting index to cpu for efficiency
|
# NOTE consider converting index to cpu for efficiency
|
||||||
if index.dtype == torch.bool:
|
if index.dtype == torch.bool:
|
||||||
# advanced indexing on a single dimension
|
# advanced indexing on a single dimension
|
||||||
|
# pyrefly: ignore [bad-assignment]
|
||||||
index = index.nonzero()
|
index = index.nonzero()
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
index = index.squeeze(1) if index.numel() > 0 else index
|
index = index.squeeze(1) if index.numel() > 0 else index
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
index = index.tolist()
|
index = index.tolist()
|
||||||
points = [self.points_list()[i] for i in index]
|
points = [self.points_list()[i] for i in index]
|
||||||
if normals_list is not None:
|
if normals_list is not None:
|
||||||
@@ -423,6 +428,7 @@ class Pointclouds:
|
|||||||
Returns:
|
Returns:
|
||||||
bool indicating whether there is any data.
|
bool indicating whether there is any data.
|
||||||
"""
|
"""
|
||||||
|
# pyrefly: ignore [missing-attribute]
|
||||||
return self._N == 0 or self.valid.eq(False).all()
|
return self._N == 0 or self.valid.eq(False).all()
|
||||||
|
|
||||||
def points_list(self) -> List[torch.Tensor]:
|
def points_list(self) -> List[torch.Tensor]:
|
||||||
@@ -486,6 +492,7 @@ class Pointclouds:
|
|||||||
tensor of points of shape (sum(P_n), 3).
|
tensor of points of shape (sum(P_n), 3).
|
||||||
"""
|
"""
|
||||||
self._compute_packed()
|
self._compute_packed()
|
||||||
|
# pyrefly: ignore [bad-return]
|
||||||
return self._points_packed
|
return self._points_packed
|
||||||
|
|
||||||
def normals_packed(self) -> Optional[torch.Tensor]:
|
def normals_packed(self) -> Optional[torch.Tensor]:
|
||||||
@@ -541,6 +548,7 @@ class Pointclouds:
|
|||||||
Returns:
|
Returns:
|
||||||
1D tensor of sizes.
|
1D tensor of sizes.
|
||||||
"""
|
"""
|
||||||
|
# pyrefly: ignore [bad-return]
|
||||||
return self._num_points_per_cloud
|
return self._num_points_per_cloud
|
||||||
|
|
||||||
def points_padded(self) -> torch.Tensor:
|
def points_padded(self) -> torch.Tensor:
|
||||||
@@ -551,6 +559,7 @@ class Pointclouds:
|
|||||||
tensor of points of shape (N, max(P_n), 3).
|
tensor of points of shape (N, max(P_n), 3).
|
||||||
"""
|
"""
|
||||||
self._compute_padded()
|
self._compute_padded()
|
||||||
|
# pyrefly: ignore [bad-return]
|
||||||
return self._points_padded
|
return self._points_padded
|
||||||
|
|
||||||
def normals_padded(self) -> Optional[torch.Tensor]:
|
def normals_padded(self) -> Optional[torch.Tensor]:
|
||||||
@@ -636,6 +645,7 @@ class Pointclouds:
|
|||||||
if features_list is not None:
|
if features_list is not None:
|
||||||
self._features_padded = struct_utils.list_to_padded(
|
self._features_padded = struct_utils.list_to_padded(
|
||||||
features_list,
|
features_list,
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
(self._P, self._C),
|
(self._P, self._C),
|
||||||
pad_value=0.0,
|
pad_value=0.0,
|
||||||
equisized=self.equisized,
|
equisized=self.equisized,
|
||||||
@@ -686,6 +696,7 @@ class Pointclouds:
|
|||||||
|
|
||||||
points_list_to_packed = struct_utils.list_to_packed(points_list)
|
points_list_to_packed = struct_utils.list_to_packed(points_list)
|
||||||
self._points_packed = points_list_to_packed[0]
|
self._points_packed = points_list_to_packed[0]
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
if not torch.allclose(self._num_points_per_cloud, points_list_to_packed[1]):
|
if not torch.allclose(self._num_points_per_cloud, points_list_to_packed[1]):
|
||||||
raise ValueError("Inconsistent list to packed conversion")
|
raise ValueError("Inconsistent list to packed conversion")
|
||||||
self._cloud_to_packed_first_idx = points_list_to_packed[2]
|
self._cloud_to_packed_first_idx = points_list_to_packed[2]
|
||||||
@@ -1066,6 +1077,7 @@ class Pointclouds:
|
|||||||
self.normals_list()
|
self.normals_list()
|
||||||
if self._points_packed is not None:
|
if self._points_packed is not None:
|
||||||
# update self._normals_packed
|
# update self._normals_packed
|
||||||
|
# pyrefly: ignore [no-matching-overload]
|
||||||
self._normals_packed = torch.cat(self._normals_list, dim=0)
|
self._normals_packed = torch.cat(self._normals_list, dim=0)
|
||||||
|
|
||||||
return normals_est
|
return normals_est
|
||||||
|
|||||||
@@ -1010,6 +1010,7 @@ class VolumeLocator:
|
|||||||
Defaults to all items (`:`).
|
Defaults to all items (`:`).
|
||||||
"""
|
"""
|
||||||
device = device if device is not None else self.device
|
device = device if device is not None else self.device
|
||||||
|
# pyrefly: ignore [bad-assignment]
|
||||||
other._grid_sizes = self._grid_sizes[index].to(device)
|
other._grid_sizes = self._grid_sizes[index].to(device)
|
||||||
other._local_to_world_transform = self.get_local_to_world_coords_transform()[
|
other._local_to_world_transform = self.get_local_to_world_coords_transform()[
|
||||||
# pyre-fixme[6]: For 1st param expected `Union[List[int], int, slice,
|
# pyre-fixme[6]: For 1st param expected `Union[List[int], int, slice,
|
||||||
@@ -1114,6 +1115,7 @@ class VolumeLocator:
|
|||||||
return other
|
return other
|
||||||
|
|
||||||
other.device = device_
|
other.device = device_
|
||||||
|
# pyrefly: ignore [bad-assignment]
|
||||||
other._grid_sizes = self._grid_sizes.to(device_)
|
other._grid_sizes = self._grid_sizes.to(device_)
|
||||||
other._local_to_world_transform = self.get_local_to_world_coords_transform().to(
|
other._local_to_world_transform = self.get_local_to_world_coords_transform().to(
|
||||||
device
|
device
|
||||||
|
|||||||
@@ -628,6 +628,7 @@ def _add_struct_from_batch(
|
|||||||
# pyre-ignore[16]
|
# pyre-ignore[16]
|
||||||
struct = batched_struct[struct_idx]
|
struct = batched_struct[struct_idx]
|
||||||
trace_name = "trace{}-{}".format(scene_num + 1, trace_idx)
|
trace_name = "trace{}-{}".format(scene_num + 1, trace_idx)
|
||||||
|
# pyrefly: ignore [unsupported-operation]
|
||||||
scene_dictionary[subplot_title][trace_name] = struct
|
scene_dictionary[subplot_title][trace_name] = struct
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user