apply Black 2024 style in fbcode (4/16)

Summary:
Formats the covered files with pyfmt.

paintitblack

Reviewed By: aleivag

Differential Revision: D54447727

fbshipit-source-id: 8844b1caa08de94d04ac4df3c768dbf8c865fd2f
This commit is contained in:
Amethyst Reese 2024-03-02 17:31:19 -08:00 committed by Facebook GitHub Bot
parent f34104cf6e
commit 3da7703c5a
31 changed files with 130 additions and 106 deletions

View File

@ -343,12 +343,14 @@ class RadianceFieldRenderer(torch.nn.Module):
# For a full render pass concatenate the output chunks, # For a full render pass concatenate the output chunks,
# and reshape to image size. # and reshape to image size.
out = { out = {
k: torch.cat( k: (
torch.cat(
[ch_o[k] for ch_o in chunk_outputs], [ch_o[k] for ch_o in chunk_outputs],
dim=1, dim=1,
).view(-1, *self._image_size, 3) ).view(-1, *self._image_size, 3)
if chunk_outputs[0][k] is not None if chunk_outputs[0][k] is not None
else None else None
)
for k in ("rgb_fine", "rgb_coarse", "rgb_gt") for k in ("rgb_fine", "rgb_coarse", "rgb_gt")
} }
else: else:

View File

@ -576,11 +576,11 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
camera_quality_score=safe_as_tensor( camera_quality_score=safe_as_tensor(
sequence_annotation.viewpoint_quality_score, torch.float sequence_annotation.viewpoint_quality_score, torch.float
), ),
point_cloud_quality_score=safe_as_tensor( point_cloud_quality_score=(
point_cloud.quality_score, torch.float safe_as_tensor(point_cloud.quality_score, torch.float)
)
if point_cloud is not None if point_cloud is not None
else None, else None
),
) )
fg_mask_np: Optional[np.ndarray] = None fg_mask_np: Optional[np.ndarray] = None

View File

@ -124,9 +124,9 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
dimension of the cropping bounding box, relative to box size. dimension of the cropping bounding box, relative to box size.
""" """
frame_annotations_type: ClassVar[ frame_annotations_type: ClassVar[Type[types.FrameAnnotation]] = (
Type[types.FrameAnnotation] types.FrameAnnotation
] = types.FrameAnnotation )
path_manager: Any = None path_manager: Any = None
frame_annotations_file: str = "" frame_annotations_file: str = ""

View File

@ -88,9 +88,11 @@ def get_implicitron_sequence_pointcloud(
frame_data.camera, frame_data.camera,
frame_data.image_rgb, frame_data.image_rgb,
frame_data.depth_map, frame_data.depth_map,
(
(cast(torch.Tensor, frame_data.fg_probability) > 0.5).float() (cast(torch.Tensor, frame_data.fg_probability) > 0.5).float()
if mask_points and frame_data.fg_probability is not None if mask_points and frame_data.fg_probability is not None
else None, else None
),
) )
return point_cloud, frame_data return point_cloud, frame_data

View File

@ -282,9 +282,9 @@ def eval_batch(
image_rgb_masked=image_rgb_masked, image_rgb_masked=image_rgb_masked,
depth_render=cloned_render["depth_render"], depth_render=cloned_render["depth_render"],
depth_map=frame_data.depth_map, depth_map=frame_data.depth_map,
depth_mask=frame_data.depth_mask[:1] depth_mask=(
if frame_data.depth_mask is not None frame_data.depth_mask[:1] if frame_data.depth_mask is not None else None
else None, ),
visdom_env=visualize_visdom_env, visdom_env=visualize_visdom_env,
) )

View File

@ -395,10 +395,12 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
n_targets = ( n_targets = (
1 1
if evaluation_mode == EvaluationMode.EVALUATION if evaluation_mode == EvaluationMode.EVALUATION
else batch_size else (
batch_size
if self.n_train_target_views <= 0 if self.n_train_target_views <= 0
else min(self.n_train_target_views, batch_size) else min(self.n_train_target_views, batch_size)
) )
)
# A helper function for selecting n_target first elements from the input # A helper function for selecting n_target first elements from the input
# where the latter can be None. # where the latter can be None.
@ -422,9 +424,12 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
ray_bundle: ImplicitronRayBundle = self.raysampler( ray_bundle: ImplicitronRayBundle = self.raysampler(
target_cameras, target_cameras,
evaluation_mode, evaluation_mode,
mask=mask_crop[:n_targets] mask=(
if mask_crop is not None and sampling_mode == RenderSamplingMode.MASK_SAMPLE mask_crop[:n_targets]
else None, if mask_crop is not None
and sampling_mode == RenderSamplingMode.MASK_SAMPLE
else None
),
) )
# custom_args hold additional arguments to the implicit function. # custom_args hold additional arguments to the implicit function.

View File

@ -102,9 +102,7 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
elif self.n_harmonic_functions_xyz >= 0 and layer_idx == 0: elif self.n_harmonic_functions_xyz >= 0 and layer_idx == 0:
torch.nn.init.constant_(lin.bias, 0.0) torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.constant_(lin.weight[:, 3:], 0.0) torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
torch.nn.init.normal_( torch.nn.init.normal_(lin.weight[:, :3], 0.0, 2**0.5 / out_dim**0.5)
lin.weight[:, :3], 0.0, 2**0.5 / out_dim**0.5
)
elif self.n_harmonic_functions_xyz >= 0 and layer_idx in self.skip_in: elif self.n_harmonic_functions_xyz >= 0 and layer_idx in self.skip_in:
torch.nn.init.constant_(lin.bias, 0.0) torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.normal_(lin.weight, 0.0, 2**0.5 / out_dim**0.5) torch.nn.init.normal_(lin.weight, 0.0, 2**0.5 / out_dim**0.5)

View File

@ -193,9 +193,9 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
embeds = create_embeddings_for_implicit_function( embeds = create_embeddings_for_implicit_function(
xyz_world=rays_points_world, xyz_world=rays_points_world,
# for 2nd param but got `Union[None, torch.Tensor, torch.nn.Module]`. # for 2nd param but got `Union[None, torch.Tensor, torch.nn.Module]`.
xyz_embedding_function=self.harmonic_embedding_xyz xyz_embedding_function=(
if self.input_xyz self.harmonic_embedding_xyz if self.input_xyz else None
else None, ),
global_code=global_code, global_code=global_code,
fun_viewpool=fun_viewpool, fun_viewpool=fun_viewpool,
xyz_in_camera_coords=self.xyz_ray_dir_in_camera_coords, xyz_in_camera_coords=self.xyz_ray_dir_in_camera_coords,

View File

@ -356,9 +356,12 @@ class OverfitModel(ImplicitronModelBase): # pyre-ignore: 13
ray_bundle: ImplicitronRayBundle = self.raysampler( ray_bundle: ImplicitronRayBundle = self.raysampler(
camera, camera,
evaluation_mode, evaluation_mode,
mask=mask_crop mask=(
if mask_crop is not None and sampling_mode == RenderSamplingMode.MASK_SAMPLE mask_crop
else None, if mask_crop is not None
and sampling_mode == RenderSamplingMode.MASK_SAMPLE
else None
),
) )
inputs_to_be_chunked = {} inputs_to_be_chunked = {}
@ -381,11 +384,13 @@ class OverfitModel(ImplicitronModelBase): # pyre-ignore: 13
frame_timestamp=frame_timestamp, frame_timestamp=frame_timestamp,
) )
implicit_functions = [ implicit_functions = [
(
functools.partial(implicit_function, global_code=global_code) functools.partial(implicit_function, global_code=global_code)
if isinstance(implicit_function, Callable) if isinstance(implicit_function, Callable)
else functools.partial( else functools.partial(
implicit_function.forward, global_code=global_code implicit_function.forward, global_code=global_code
) )
)
for implicit_function in implicit_functions for implicit_function in implicit_functions
] ]
rendered = self._render( rendered = self._render(

View File

@ -145,10 +145,12 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
n_pts_per_ray=n_pts_per_ray_training, n_pts_per_ray=n_pts_per_ray_training,
min_depth=0.0, min_depth=0.0,
max_depth=0.0, max_depth=0.0,
n_rays_per_image=self.n_rays_per_image_sampled_from_mask n_rays_per_image=(
self.n_rays_per_image_sampled_from_mask
if self._sampling_mode[EvaluationMode.TRAINING] if self._sampling_mode[EvaluationMode.TRAINING]
== RenderSamplingMode.MASK_SAMPLE == RenderSamplingMode.MASK_SAMPLE
else None, else None
),
n_rays_total=self.n_rays_total_training, n_rays_total=self.n_rays_total_training,
unit_directions=True, unit_directions=True,
stratified_sampling=self.stratified_point_sampling_training, stratified_sampling=self.stratified_point_sampling_training,
@ -160,10 +162,12 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
n_pts_per_ray=n_pts_per_ray_evaluation, n_pts_per_ray=n_pts_per_ray_evaluation,
min_depth=0.0, min_depth=0.0,
max_depth=0.0, max_depth=0.0,
n_rays_per_image=self.n_rays_per_image_sampled_from_mask n_rays_per_image=(
self.n_rays_per_image_sampled_from_mask
if self._sampling_mode[EvaluationMode.EVALUATION] if self._sampling_mode[EvaluationMode.EVALUATION]
== RenderSamplingMode.MASK_SAMPLE == RenderSamplingMode.MASK_SAMPLE
else None, else None
),
unit_directions=True, unit_directions=True,
stratified_sampling=self.stratified_point_sampling_evaluation, stratified_sampling=self.stratified_point_sampling_evaluation,
) )

View File

@ -415,7 +415,7 @@ class RayTracing(Configurable, nn.Module):
] ]
sampler_dists[mask_intersect_idx[p_out_mask]] = pts_intervals[ sampler_dists[mask_intersect_idx[p_out_mask]] = pts_intervals[
p_out_mask, p_out_mask,
: :,
# pyre-fixme[6]: For 1st param expected `Union[bool, float, int]` but # pyre-fixme[6]: For 1st param expected `Union[bool, float, int]` but
# got `Tensor`. # got `Tensor`.
][torch.arange(n_p_out), out_pts_idx] ][torch.arange(n_p_out), out_pts_idx]

View File

@ -43,9 +43,9 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
run_auto_creation(self) run_auto_creation(self)
self.ray_normal_coloring_network_args[ self.ray_normal_coloring_network_args["feature_vector_size"] = (
"feature_vector_size" render_features_dimensions
] = render_features_dimensions )
self._rgb_network = RayNormalColoringNetwork( self._rgb_network = RayNormalColoringNetwork(
**self.ray_normal_coloring_network_args **self.ray_normal_coloring_network_args
) )
@ -201,9 +201,8 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
None, :, 0, : None, :, 0, :
] ]
normals_full.view(-1, 3)[surface_mask] = normals normals_full.view(-1, 3)[surface_mask] = normals
render_full.view(-1, self.render_features_dimensions)[ render_full.view(-1, self.render_features_dimensions)[surface_mask] = (
surface_mask self._rgb_network(
] = self._rgb_network(
features, features,
differentiable_surface_points[None], differentiable_surface_points[None],
normals, normals,
@ -211,6 +210,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
surface_mask[None, :, None], surface_mask[None, :, None],
pooling_fn=None, # TODO pooling_fn=None, # TODO
) )
)
mask_full.view(-1, 1)[~surface_mask] = torch.sigmoid( mask_full.view(-1, 1)[~surface_mask] = torch.sigmoid(
# pyre-fixme[6]: For 1st param expected `Tensor` but got `float`. # pyre-fixme[6]: For 1st param expected `Tensor` but got `float`.
-self.soft_mask_alpha -self.soft_mask_alpha

View File

@ -241,9 +241,9 @@ class _Registry:
""" """
def __init__(self) -> None: def __init__(self) -> None:
self._mapping: Dict[ self._mapping: Dict[Type[ReplaceableBase], Dict[str, Type[ReplaceableBase]]] = (
Type[ReplaceableBase], Dict[str, Type[ReplaceableBase]] defaultdict(dict)
] = defaultdict(dict) )
def register(self, some_class: Type[_X]) -> Type[_X]: def register(self, some_class: Type[_X]) -> Type[_X]:
""" """

View File

@ -139,9 +139,11 @@ def generate_eval_video_cameras(
fit = fit_circle_in_3d( fit = fit_circle_in_3d(
cam_centers, cam_centers,
angles=angle, angles=angle,
offset=angle.new_tensor(traj_offset_canonical) offset=(
angle.new_tensor(traj_offset_canonical)
if traj_offset_canonical is not None if traj_offset_canonical is not None
else None, else None
),
up=angle.new_tensor(up), up=angle.new_tensor(up),
) )
traj = fit.generated_points traj = fit.generated_points

View File

@ -146,9 +146,11 @@ def cat_dataclass(batch, tensor_collator: Callable):
) )
elif isinstance(elem_f, collections.abc.Mapping): elif isinstance(elem_f, collections.abc.Mapping):
collated[f.name] = { collated[f.name] = {
k: tensor_collator([getattr(e, f.name)[k] for e in batch]) k: (
tensor_collator([getattr(e, f.name)[k] for e in batch])
if elem_f[k] is not None if elem_f[k] is not None
else None else None
)
for k in elem_f for k in elem_f
} }
else: else:

View File

@ -81,7 +81,6 @@ class FishEyeCameras(CamerasBase):
device: Device = "cpu", device: Device = "cpu",
image_size: Optional[Union[List, Tuple, torch.Tensor]] = None, image_size: Optional[Union[List, Tuple, torch.Tensor]] = None,
) -> None: ) -> None:
""" """
Args: Args:

View File

@ -712,9 +712,9 @@ def convert_clipped_rasterization_to_original_faces(
) )
bary_coords_unclipped_subset = bary_coords_unclipped_subset.reshape([N * 3]) bary_coords_unclipped_subset = bary_coords_unclipped_subset.reshape([N * 3])
bary_coords_unclipped[ bary_coords_unclipped[faces_to_convert_mask_expanded] = (
faces_to_convert_mask_expanded bary_coords_unclipped_subset
] = bary_coords_unclipped_subset )
# dists for case 4 faces will be handled in the rasterizer # dists for case 4 faces will be handled in the rasterizer
# so no need to modify them here. # so no need to modify them here.

View File

@ -605,7 +605,10 @@ def rasterize_meshes_python( # noqa: C901
# If faces were clipped, map the rasterization result to be in terms of the # If faces were clipped, map the rasterization result to be in terms of the
# original unclipped faces. This may involve converting barycentric # original unclipped faces. This may involve converting barycentric
# coordinates # coordinates
(face_idxs, bary_coords,) = convert_clipped_rasterization_to_original_faces( (
face_idxs,
bary_coords,
) = convert_clipped_rasterization_to_original_faces(
face_idxs, face_idxs,
bary_coords, bary_coords,
# pyre-fixme[61]: `clipped_faces` may not be initialized here. # pyre-fixme[61]: `clipped_faces` may not be initialized here.

View File

@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# If we can access EGL, import MeshRasterizerOpenGL. # If we can access EGL, import MeshRasterizerOpenGL.
def _can_import_egl_and_pycuda(): def _can_import_egl_and_pycuda():
import os import os

View File

@ -292,9 +292,11 @@ class _OpenGLMachinery:
pix_to_face, bary_coord, zbuf = self._rasterize_mesh( pix_to_face, bary_coord, zbuf = self._rasterize_mesh(
mesh, mesh,
image_size, image_size,
projection_matrix=projection_matrix[mesh_id] projection_matrix=(
projection_matrix[mesh_id]
if projection_matrix.shape[0] > 1 if projection_matrix.shape[0] > 1
else None, else None
),
) )
pix_to_faces.append(pix_to_face) pix_to_faces.append(pix_to_face)
bary_coords.append(bary_coord) bary_coords.append(bary_coord)

View File

@ -61,9 +61,9 @@ class ExtendedSqlFrameAnnotation(SqlFrameAnnotation):
class ExtendedSqlIndexDataset(SqlIndexDataset): class ExtendedSqlIndexDataset(SqlIndexDataset):
frame_annotations_type: ClassVar[ frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = (
Type[SqlFrameAnnotation] ExtendedSqlFrameAnnotation
] = ExtendedSqlFrameAnnotation )
class CanineFrameData(FrameData): class CanineFrameData(FrameData):
@ -96,9 +96,9 @@ class CanineFrameDataBuilder(
class CanineSqlIndexDataset(SqlIndexDataset): class CanineSqlIndexDataset(SqlIndexDataset):
frame_annotations_type: ClassVar[ frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = (
Type[SqlFrameAnnotation] ExtendedSqlFrameAnnotation
] = ExtendedSqlFrameAnnotation )
frame_data_builder_class_type: str = "CanineFrameDataBuilder" frame_data_builder_class_type: str = "CanineFrameDataBuilder"

View File

@ -85,11 +85,11 @@ class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase):
camera_quality_score=safe_as_tensor( camera_quality_score=safe_as_tensor(
self.seq_annotation.viewpoint_quality_score, torch.float self.seq_annotation.viewpoint_quality_score, torch.float
), ),
point_cloud_quality_score=safe_as_tensor( point_cloud_quality_score=(
point_cloud.quality_score, torch.float safe_as_tensor(point_cloud.quality_score, torch.float)
)
if point_cloud is not None if point_cloud is not None
else None, else None
),
) )
def test_frame_data_builder_args(self): def test_frame_data_builder_args(self):

View File

@ -168,7 +168,10 @@ def _make_random_json_dataset_map_provider_v2_data(
mask_path = os.path.join(maskdir, f"frame{i:05d}.png") mask_path = os.path.join(maskdir, f"frame{i:05d}.png")
mask = np.zeros((H, W)) mask = np.zeros((H, W))
mask[H // 2 :, W // 2 :] = 1 mask[H // 2 :, W // 2 :] = 1
Image.fromarray((mask * 255.0).astype(np.uint8), mode="L",).convert( Image.fromarray(
(mask * 255.0).astype(np.uint8),
mode="L",
).convert(
"L" "L"
).save(mask_path) ).save(mask_path)

View File

@ -222,10 +222,7 @@ class TestRendererBase(TestCaseMixin, unittest.TestCase):
np.testing.assert_allclose( np.testing.assert_allclose(
(delta**2) / 3 (delta**2) / 3
- (4 / 15) - (4 / 15)
* ( * ((delta**4 * (12 * mu**2 - delta**2)) / (3 * mu**2 + delta**2) ** 2),
(delta**4 * (12 * mu**2 - delta**2))
/ (3 * mu**2 + delta**2) ** 2
),
t_var.numpy(), t_var.numpy(),
) )
np.testing.assert_allclose( np.testing.assert_allclose(

View File

@ -983,7 +983,7 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
verts_list = [] verts_list = []
faces_list = [] faces_list = []
verts_faces = [(10, 100), (20, 200)] verts_faces = [(10, 100), (20, 200)]
for (V, F) in verts_faces: for V, F in verts_faces:
verts = torch.rand((V, 3), dtype=torch.float32, device=device) verts = torch.rand((V, 3), dtype=torch.float32, device=device)
faces = torch.randint(V, size=(F, 3), dtype=torch.int64, device=device) faces = torch.randint(V, size=(F, 3), dtype=torch.int64, device=device)
verts_list.append(verts) verts_list.append(verts)
@ -1007,7 +1007,7 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
device = torch.device("cuda:0") device = torch.device("cuda:0")
verts_list = [] verts_list = []
faces_list = [] faces_list = []
for (V, F) in [(10, 100)]: for V, F in [(10, 100)]:
verts = torch.rand((V, 3), dtype=torch.float32, device=device) verts = torch.rand((V, 3), dtype=torch.float32, device=device)
faces = torch.randint(V, size=(F, 3), dtype=torch.int64, device=device) faces = torch.randint(V, size=(F, 3), dtype=torch.int64, device=device)
verts_list.append(verts) verts_list.append(verts)
@ -1025,7 +1025,7 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
verts_list = [] verts_list = []
faces_list = [] faces_list = []
verts_faces = [(10, 100), (20, 200), (30, 300)] verts_faces = [(10, 100), (20, 200), (30, 300)]
for (V, F) in verts_faces: for V, F in verts_faces:
verts = torch.rand((V, 3), dtype=torch.float32, device=device) verts = torch.rand((V, 3), dtype=torch.float32, device=device)
faces = torch.randint(V, size=(F, 3), dtype=torch.int64, device=device) faces = torch.randint(V, size=(F, 3), dtype=torch.int64, device=device)
verts_list.append(verts) verts_list.append(verts)
@ -1047,7 +1047,7 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
verts_list = [] verts_list = []
faces_list = [] faces_list = []
verts_faces = [(10, 100), (20, 200), (30, 300)] verts_faces = [(10, 100), (20, 200), (30, 300)]
for (V, F) in verts_faces: for V, F in verts_faces:
verts = torch.rand((V, 3), dtype=torch.float32, device=device) verts = torch.rand((V, 3), dtype=torch.float32, device=device)
faces = torch.randint(V, size=(F, 3), dtype=torch.int64, device=device) faces = torch.randint(V, size=(F, 3), dtype=torch.int64, device=device)
verts_list.append(verts) verts_list.append(verts)

View File

@ -284,7 +284,7 @@ class TestRenderImplicit(TestCaseMixin, unittest.TestCase):
os.makedirs(outdir, exist_ok=True) os.makedirs(outdir, exist_ok=True)
frames = [] frames = []
for (image_opacity, image_opacity_mesh) in zip( for image_opacity, image_opacity_mesh in zip(
images_opacities, images_opacities_meshes images_opacities, images_opacities_meshes
): ):
image, opacity = image_opacity.split([3, 1], dim=-1) image, opacity = image_opacity.split([3, 1], dim=-1)

View File

@ -303,7 +303,6 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
self.test_simple_sphere(check_depth=True) self.test_simple_sphere(check_depth=True)
def test_simple_sphere_screen(self): def test_simple_sphere_screen(self):
""" """
Test output when rendering with PerspectiveCameras & OrthographicCameras Test output when rendering with PerspectiveCameras & OrthographicCameras
in NDC vs screen space. in NDC vs screen space.
@ -1221,7 +1220,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
"flat": HardFlatShader, "flat": HardFlatShader,
"splatter": SplatterPhongShader, "splatter": SplatterPhongShader,
} }
for (name, shader_init) in shaders.items(): for name, shader_init in shaders.items():
if rasterizer_type == MeshRasterizerOpenGL and name != "splatter": if rasterizer_type == MeshRasterizerOpenGL and name != "splatter":
continue continue
if rasterizer_type == MeshRasterizer and name == "splatter": if rasterizer_type == MeshRasterizer and name == "splatter":

View File

@ -620,7 +620,7 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
plane into a quadrilateral, there shouldn't be duplicates indices of plane into a quadrilateral, there shouldn't be duplicates indices of
the face in the pix_to_face output of rasterization. the face in the pix_to_face output of rasterization.
""" """
for (device, bin_size) in [("cpu", 0), ("cuda:0", 0), ("cuda:0", None)]: for device, bin_size in [("cpu", 0), ("cuda:0", 0), ("cuda:0", None)]:
verts = torch.tensor( verts = torch.tensor(
[[0.0, -10.0, 1.0], [-1.0, 2.0, -2.0], [1.0, 5.0, -10.0]], [[0.0, -10.0, 1.0], [-1.0, 2.0, -2.0], [1.0, 5.0, -10.0]],
dtype=torch.float32, dtype=torch.float32,
@ -673,7 +673,7 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
device = "cuda:0" device = "cuda:0"
mesh1 = torus(20.0, 85.0, 32, 16, device=device) mesh1 = torus(20.0, 85.0, 32, 16, device=device)
mesh2 = torus(2.0, 3.0, 32, 16, device=device) mesh2 = torus(2.0, 3.0, 32, 16, device=device)
for (mesh, z_clip) in [(mesh1, None), (mesh2, 5.0)]: for mesh, z_clip in [(mesh1, None), (mesh2, 5.0)]:
tex = TexturesVertex(verts_features=torch.rand_like(mesh.verts_padded())) tex = TexturesVertex(verts_features=torch.rand_like(mesh.verts_padded()))
mesh.textures = tex mesh.textures = tex
raster_settings = RasterizationSettings( raster_settings = RasterizationSettings(

View File

@ -384,7 +384,7 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase):
(AlphaCompositor, alpha_composite), (AlphaCompositor, alpha_composite),
] ]
for (compositor_class, composite_func) in compositor_funcs: for compositor_class, composite_func in compositor_funcs:
compositor = compositor_class(background_color) compositor = compositor_class(background_color)
@ -435,7 +435,7 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase):
(AlphaCompositor, alpha_composite), (AlphaCompositor, alpha_composite),
] ]
for (compositor_class, composite_func) in compositor_funcs: for compositor_class, composite_func in compositor_funcs:
compositor = compositor_class(background_color) compositor = compositor_class(background_color)

View File

@ -392,7 +392,7 @@ class TestRenderVolumes(TestCaseMixin, unittest.TestCase):
os.makedirs(outdir, exist_ok=True) os.makedirs(outdir, exist_ok=True)
frames = [] frames = []
for (image, image_pts) in zip(images, images_pts): for image, image_pts in zip(images, images_pts):
diff_image = ( diff_image = (
((image - image_pts) * 0.5 + 0.5) ((image - image_pts) * 0.5 + 0.5)
.mean(dim=2, keepdim=True) .mean(dim=2, keepdim=True)

View File

@ -100,7 +100,7 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase):
def init_feats(batch_size: int = 10, num_channels: int = 256, device: str = "cuda"): def init_feats(batch_size: int = 10, num_channels: int = 256, device: str = "cuda"):
H, W = [14, 28], [14, 28] H, W = [14, 28], [14, 28]
feats = [] feats = []
for (h, w) in zip(H, W): for h, w in zip(H, W):
feats.append(torch.rand((batch_size, num_channels, h, w), device=device)) feats.append(torch.rand((batch_size, num_channels, h, w), device=device))
return feats return feats