Summary: lint issues (mostly flake) in implicitron

Reviewed By: patricklabatut

Differential Revision: D37920948

fbshipit-source-id: 8cb3c2a2838d111c80a211c98a404c210d4649ed
This commit is contained in:
Jeremy Reizenstein 2022-07-21 13:33:49 -07:00 committed by Facebook GitHub Bot
parent 8597d4c5c1
commit b2dc520210
9 changed files with 18 additions and 26 deletions

View File

@ -833,7 +833,7 @@ def _load_1bit_png_mask(file: str) -> np.ndarray:
return mask return mask
def _load_depth_mask(path) -> np.ndarray: def _load_depth_mask(path: str) -> np.ndarray:
if not path.lower().endswith(".png"): if not path.lower().endswith(".png"):
raise ValueError('unsupported depth mask file name "%s"' % path) raise ValueError('unsupported depth mask file name "%s"' % path)
m = _load_1bit_png_mask(path) m = _load_1bit_png_mask(path)

View File

@ -5,8 +5,7 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import logging import logging
from dataclasses import field from typing import Optional, Tuple
from typing import List, Optional
import torch import torch
from pytorch3d.common.linear_with_repeat import LinearWithRepeat from pytorch3d.common.linear_with_repeat import LinearWithRepeat
@ -206,7 +205,7 @@ class NeuralRadianceFieldImplicitFunction(NeuralRadianceFieldBase):
transformer_dim_down_factor: float = 1.0 transformer_dim_down_factor: float = 1.0
n_hidden_neurons_xyz: int = 256 n_hidden_neurons_xyz: int = 256
n_layers_xyz: int = 8 n_layers_xyz: int = 8
append_xyz: List[int] = field(default_factory=lambda: [5]) append_xyz: Tuple[int, ...] = (5,)
def _construct_xyz_encoder(self, input_dim: int): def _construct_xyz_encoder(self, input_dim: int):
return MLPWithInputSkips( return MLPWithInputSkips(
@ -224,7 +223,7 @@ class NeRFormerImplicitFunction(NeuralRadianceFieldBase):
transformer_dim_down_factor: float = 2.0 transformer_dim_down_factor: float = 2.0
n_hidden_neurons_xyz: int = 80 n_hidden_neurons_xyz: int = 80
n_layers_xyz: int = 2 n_layers_xyz: int = 2
append_xyz: List[int] = field(default_factory=lambda: [1]) append_xyz: Tuple[int, ...] = (1,)
def _construct_xyz_encoder(self, input_dim: int): def _construct_xyz_encoder(self, input_dim: int):
return TransformerWithInputSkips( return TransformerWithInputSkips(
@ -286,7 +285,7 @@ class MLPWithInputSkips(torch.nn.Module):
output_dim: int = 256, output_dim: int = 256,
skip_dim: int = 39, skip_dim: int = 39,
hidden_dim: int = 256, hidden_dim: int = 256,
input_skips: List[int] = [5], input_skips: Tuple[int, ...] = (5,),
skip_affine_trans: bool = False, skip_affine_trans: bool = False,
no_last_relu=False, no_last_relu=False,
): ):
@ -362,7 +361,7 @@ class TransformerWithInputSkips(torch.nn.Module):
output_dim: int = 256, output_dim: int = 256,
skip_dim: int = 39, skip_dim: int = 39,
hidden_dim: int = 64, hidden_dim: int = 64,
input_skips: List[int] = [5], input_skips: Tuple[int, ...] = (5,),
dim_down_factor: float = 1, dim_down_factor: float = 1,
): ):
""" """

View File

@ -7,11 +7,10 @@
from typing import List from typing import List
import torch import torch
from pytorch3d.implicitron.models.renderer.base import ImplicitFunctionWrapper
from pytorch3d.implicitron.tools.config import registry, run_auto_creation from pytorch3d.implicitron.tools.config import registry, run_auto_creation
from pytorch3d.renderer import RayBundle from pytorch3d.renderer import RayBundle
from .base import BaseRenderer, EvaluationMode, RendererOutput from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput
from .ray_point_refiner import RayPointRefiner from .ray_point_refiner import RayPointRefiner
from .raymarcher import RaymarcherBase from .raymarcher import RaymarcherBase
@ -107,7 +106,7 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
def forward( def forward(
self, self,
ray_bundle: RayBundle, ray_bundle: RayBundle,
implicit_functions: List[ImplicitFunctionWrapper] = [], implicit_functions: List[ImplicitFunctionWrapper],
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION, evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
**kwargs, **kwargs,
) -> RendererOutput: ) -> RendererOutput:

View File

@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from dataclasses import field
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch

View File

@ -59,7 +59,7 @@ def cleanup_eval_depth(
good_df_thr = std * sigma good_df_thr = std * sigma
good_depth = (df <= good_df_thr).float() * pcl_mask good_depth = (df <= good_df_thr).float() * pcl_mask
perc_kept = good_depth.sum(dim=1) / pcl_mask.sum(dim=1).clamp(1) # perc_kept = good_depth.sum(dim=1) / pcl_mask.sum(dim=1).clamp(1)
# print(f'Kept {100.0 * perc_kept.mean():1.3f} % points') # print(f'Kept {100.0 * perc_kept.mean():1.3f} % points')
good_depth_raster = torch.zeros_like(depth).view(ba, -1) good_depth_raster = torch.zeros_like(depth).view(ba, -1)

View File

@ -200,9 +200,6 @@ def _visdom_plot_scene(
viz = Visdom() viz = Visdom()
viz.plotlyplot(p, env="cam_traj_dbg", win="cam_trajs") viz.plotlyplot(p, env="cam_traj_dbg", win="cam_trajs")
import pdb
pdb.set_trace()
def _figure_eight_knot(t: torch.Tensor, z_scale: float = 0.5): def _figure_eight_knot(t: torch.Tensor, z_scale: float = 0.5):

View File

@ -202,7 +202,7 @@ def neg_iou_loss(
return 1.0 - iou(predict, target, mask=mask) return 1.0 - iou(predict, target, mask=mask)
def safe_sqrt(A: torch.Tensor, eps: float = float(1e-4)) -> torch.Tensor: def safe_sqrt(A: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
""" """
performs safe differentiable sqrt performs safe differentiable sqrt
""" """

View File

@ -20,12 +20,10 @@ logger = logging.getLogger(__name__)
def load_stats(flstats): def load_stats(flstats):
from pytorch3d.implicitron.tools.stats import Stats from pytorch3d.implicitron.tools.stats import Stats
try: if not os.path.isfile(flstats):
stats = Stats.load(flstats) return None
except:
logger.info("Cant load stats! %s" % flstats) return Stats.load(flstats)
stats = None
return stats
def get_model_path(fl) -> str: def get_model_path(fl) -> str:
@ -40,7 +38,7 @@ def get_optimizer_path(fl) -> str:
return flopt return flopt
def get_stats_path(fl, eval_results: bool = False): def get_stats_path(fl, eval_results: bool = False) -> str:
fl = os.path.splitext(fl)[0] fl = os.path.splitext(fl)[0]
if eval_results: if eval_results:
for postfix in ("_2", ""): for postfix in ("_2", ""):

View File

@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import logging import logging
from typing import Any, Dict, List from typing import Any, Dict, Tuple
import torch import torch
from visdom import Visdom from visdom import Visdom
@ -60,14 +60,14 @@ def visualize_basics(
preds: Dict[str, Any], preds: Dict[str, Any],
visdom_env_imgs: str, visdom_env_imgs: str,
title: str = "", title: str = "",
visualize_preds_keys: List[str] = [ visualize_preds_keys: Tuple[str, ...] = (
"image_rgb", "image_rgb",
"images_render", "images_render",
"fg_probability", "fg_probability",
"masks_render", "masks_render",
"depths_render", "depths_render",
"depth_map", "depth_map",
], ),
store_history: bool = False, store_history: bool = False,
) -> None: ) -> None:
""" """