Summary: lint issues (mostly flake) in implicitron

Reviewed By: patricklabatut

Differential Revision: D37920948

fbshipit-source-id: 8cb3c2a2838d111c80a211c98a404c210d4649ed
This commit is contained in:
Jeremy Reizenstein 2022-07-21 13:33:49 -07:00 committed by Facebook GitHub Bot
parent 8597d4c5c1
commit b2dc520210
9 changed files with 18 additions and 26 deletions

View File

@ -833,7 +833,7 @@ def _load_1bit_png_mask(file: str) -> np.ndarray:
return mask
def _load_depth_mask(path) -> np.ndarray:
def _load_depth_mask(path: str) -> np.ndarray:
if not path.lower().endswith(".png"):
raise ValueError('unsupported depth mask file name "%s"' % path)
m = _load_1bit_png_mask(path)

View File

@ -5,8 +5,7 @@
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import field
from typing import List, Optional
from typing import Optional, Tuple
import torch
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
@ -206,7 +205,7 @@ class NeuralRadianceFieldImplicitFunction(NeuralRadianceFieldBase):
transformer_dim_down_factor: float = 1.0
n_hidden_neurons_xyz: int = 256
n_layers_xyz: int = 8
append_xyz: List[int] = field(default_factory=lambda: [5])
append_xyz: Tuple[int, ...] = (5,)
def _construct_xyz_encoder(self, input_dim: int):
return MLPWithInputSkips(
@ -224,7 +223,7 @@ class NeRFormerImplicitFunction(NeuralRadianceFieldBase):
transformer_dim_down_factor: float = 2.0
n_hidden_neurons_xyz: int = 80
n_layers_xyz: int = 2
append_xyz: List[int] = field(default_factory=lambda: [1])
append_xyz: Tuple[int, ...] = (1,)
def _construct_xyz_encoder(self, input_dim: int):
return TransformerWithInputSkips(
@ -286,7 +285,7 @@ class MLPWithInputSkips(torch.nn.Module):
output_dim: int = 256,
skip_dim: int = 39,
hidden_dim: int = 256,
input_skips: List[int] = [5],
input_skips: Tuple[int, ...] = (5,),
skip_affine_trans: bool = False,
no_last_relu=False,
):
@ -362,7 +361,7 @@ class TransformerWithInputSkips(torch.nn.Module):
output_dim: int = 256,
skip_dim: int = 39,
hidden_dim: int = 64,
input_skips: List[int] = [5],
input_skips: Tuple[int, ...] = (5,),
dim_down_factor: float = 1,
):
"""

View File

@ -7,11 +7,10 @@
from typing import List
import torch
from pytorch3d.implicitron.models.renderer.base import ImplicitFunctionWrapper
from pytorch3d.implicitron.tools.config import registry, run_auto_creation
from pytorch3d.renderer import RayBundle
from .base import BaseRenderer, EvaluationMode, RendererOutput
from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput
from .ray_point_refiner import RayPointRefiner
from .raymarcher import RaymarcherBase
@ -107,7 +106,7 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
def forward(
self,
ray_bundle: RayBundle,
implicit_functions: List[ImplicitFunctionWrapper] = [],
implicit_functions: List[ImplicitFunctionWrapper],
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
**kwargs,
) -> RendererOutput:

View File

@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import field
from typing import Optional, Tuple
import torch

View File

@ -59,7 +59,7 @@ def cleanup_eval_depth(
good_df_thr = std * sigma
good_depth = (df <= good_df_thr).float() * pcl_mask
perc_kept = good_depth.sum(dim=1) / pcl_mask.sum(dim=1).clamp(1)
# perc_kept = good_depth.sum(dim=1) / pcl_mask.sum(dim=1).clamp(1)
# print(f'Kept {100.0 * perc_kept.mean():1.3f} % points')
good_depth_raster = torch.zeros_like(depth).view(ba, -1)

View File

@ -200,9 +200,6 @@ def _visdom_plot_scene(
viz = Visdom()
viz.plotlyplot(p, env="cam_traj_dbg", win="cam_trajs")
import pdb
pdb.set_trace()
def _figure_eight_knot(t: torch.Tensor, z_scale: float = 0.5):

View File

@ -202,7 +202,7 @@ def neg_iou_loss(
return 1.0 - iou(predict, target, mask=mask)
def safe_sqrt(A: torch.Tensor, eps: float = float(1e-4)) -> torch.Tensor:
def safe_sqrt(A: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
"""
performs safe differentiable sqrt
"""

View File

@ -20,12 +20,10 @@ logger = logging.getLogger(__name__)
def load_stats(flstats):
from pytorch3d.implicitron.tools.stats import Stats
try:
stats = Stats.load(flstats)
except:
logger.info("Cant load stats! %s" % flstats)
stats = None
return stats
if not os.path.isfile(flstats):
return None
return Stats.load(flstats)
def get_model_path(fl) -> str:
@ -40,7 +38,7 @@ def get_optimizer_path(fl) -> str:
return flopt
def get_stats_path(fl, eval_results: bool = False):
def get_stats_path(fl, eval_results: bool = False) -> str:
fl = os.path.splitext(fl)[0]
if eval_results:
for postfix in ("_2", ""):

View File

@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.
import logging
from typing import Any, Dict, List
from typing import Any, Dict, Tuple
import torch
from visdom import Visdom
@ -60,14 +60,14 @@ def visualize_basics(
preds: Dict[str, Any],
visdom_env_imgs: str,
title: str = "",
visualize_preds_keys: List[str] = [
visualize_preds_keys: Tuple[str, ...] = (
"image_rgb",
"images_render",
"fg_probability",
"masks_render",
"depths_render",
"depth_map",
],
),
store_history: bool = False,
) -> None:
"""