diff --git a/pytorch3d/renderer/mesh/rasterizer.py b/pytorch3d/renderer/mesh/rasterizer.py index 8c7b30c3..7971a630 100644 --- a/pytorch3d/renderer/mesh/rasterizer.py +++ b/pytorch3d/renderer/mesh/rasterizer.py @@ -1,8 +1,6 @@ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. - -from dataclasses import dataclass from typing import NamedTuple, Optional import torch import torch.nn as nn @@ -20,14 +18,31 @@ class Fragments(NamedTuple): # Class to store the mesh rasterization params with defaults -@dataclass class RasterizationSettings: - image_size: int = 256 - blur_radius: float = 0.0 - faces_per_pixel: int = 1 - bin_size: Optional[int] = None - max_faces_per_bin: Optional[int] = None - perspective_correct: bool = False + __slots__ = [ + "image_size", + "blur_radius", + "faces_per_pixel", + "bin_size", + "max_faces_per_bin", + "perspective_correct", + ] + + def __init__( + self, + image_size: int = 256, + blur_radius: float = 0.0, + faces_per_pixel: int = 1, + bin_size: Optional[int] = None, + max_faces_per_bin: Optional[int] = None, + perspective_correct: bool = False, + ): + self.image_size = image_size + self.blur_radius = blur_radius + self.faces_per_pixel = faces_per_pixel + self.bin_size = bin_size + self.max_faces_per_bin = max_faces_per_bin + self.perspective_correct = perspective_correct class MeshRasterizer(nn.Module):