Summary: Use logging instead of printing in the internals of implicitron.

Reviewed By: davnov134

Differential Revision: D35247581

fbshipit-source-id: be5ddad5efe1409adbae0575d35ade6112b3be63
This commit is contained in:
Jeremy Reizenstein 2022-04-04 06:53:16 -07:00 committed by Facebook GitHub Bot
parent 6473aa316c
commit 199309fcf7
10 changed files with 65 additions and 35 deletions

View File

@ -8,6 +8,7 @@ import functools
import gzip
import hashlib
import json
import logging
import os
import random
import warnings
@ -43,6 +44,9 @@ from pytorch3d.structures.pointclouds import Pointclouds, join_pointclouds_as_ba
from . import types
logger = logging.getLogger(__name__)
@dataclass
class FrameData(Mapping[str, Any]):
"""
@ -398,7 +402,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
self._sort_frames()
self._load_subset_lists()
self._filter_db() # also computes sequence indices
print(str(self))
logger.info(str(self))
def seq_frame_index_to_dataset_index(
self,
@ -674,7 +678,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
)
def _load_frames(self) -> None:
print(f"Loading Co3D frames from {self.frame_annotations_file}.")
logger.info(f"Loading Co3D frames from {self.frame_annotations_file}.")
local_file = self._local_path(self.frame_annotations_file)
with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
frame_annots_list = types.load_dataclass(
@ -687,7 +691,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
]
def _load_sequences(self) -> None:
print(f"Loading Co3D sequences from {self.sequence_annotations_file}.")
logger.info(f"Loading Co3D sequences from {self.sequence_annotations_file}.")
local_file = self._local_path(self.sequence_annotations_file)
with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
seq_annots = types.load_dataclass(zipfile, List[types.SequenceAnnotation])
@ -696,7 +700,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
self.seq_annots = {entry.sequence_name: entry for entry in seq_annots}
def _load_subset_lists(self) -> None:
print(f"Loading Co3D subset lists from {self.subset_lists_file}.")
logger.info(f"Loading Co3D subset lists from {self.subset_lists_file}.")
if not self.subset_lists_file:
return
@ -731,7 +735,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
def _filter_db(self) -> None:
if self.remove_empty_masks:
print("Removing images with empty masks.")
logger.info("Removing images with empty masks.")
old_len = len(self.frame_annots)
msg = "remove_empty_masks needs every MaskAnnotation.mass to be set."
@ -749,7 +753,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
for frame in self.frame_annots
if positive_mass(frame["frame_annotation"])
]
print("... filtered %d -> %d" % (old_len, len(self.frame_annots)))
logger.info("... filtered %d -> %d" % (old_len, len(self.frame_annots)))
# this has to be called after joining with categories!!
subsets = self.subsets
@ -759,7 +763,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
"Subset filter is on but subset_lists_file was not given"
)
print(f"Limitting Co3D dataset to the '{subsets}' subsets.")
logger.info(f"Limiting Co3D dataset to the '{subsets}' subsets.")
# truncate the list of subsets to the valid one
self.frame_annots = [
@ -771,7 +775,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
self._invalidate_indexes(filter_seq_annots=True)
if len(self.limit_category_to) > 0:
print(f"Limitting dataset to categories: {self.limit_category_to}")
logger.info(f"Limiting dataset to categories: {self.limit_category_to}")
self.seq_annots = {
name: entry
for name, entry in self.seq_annots.items()
@ -784,13 +788,13 @@ class ImplicitronDataset(ImplicitronDatasetBase):
attr = f"{prefix}_sequence"
arr = getattr(self, attr)
if len(arr) > 0:
print(f"{attr}: {str(arr)}")
logger.info(f"{attr}: {str(arr)}")
self.seq_annots = {
name: entry
for name, entry in self.seq_annots.items()
if (name in arr) == (prefix == "pick")
}
print("... filtered %d -> %d" % (orig_len, len(self.seq_annots)))
logger.info("... filtered %d -> %d" % (orig_len, len(self.seq_annots)))
if self.limit_sequences_to > 0:
self.seq_annots = dict(
@ -807,7 +811,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
self._invalidate_indexes()
if self.n_frames_per_sequence > 0:
print(f"Taking max {self.n_frames_per_sequence} per sequence.")
logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.")
keep_idx = []
for seq, seq_indices in self._seq_to_idx.items():
# infer the seed from the sequence name, this is reproducible
@ -818,13 +822,15 @@ class ImplicitronDataset(ImplicitronDatasetBase):
)
keep_idx.extend(seq_idx_shuffled[: self.n_frames_per_sequence])
print("... filtered %d -> %d" % (len(self.frame_annots), len(keep_idx)))
logger.info(
"... filtered %d -> %d" % (len(self.frame_annots), len(keep_idx))
)
self.frame_annots = [self.frame_annots[i] for i in keep_idx]
self._invalidate_indexes(filter_seq_annots=False)
# sequences are not decimated, so self.seq_annots is valid
if self.limit_to > 0 and self.limit_to < len(self.frame_annots):
print(
logger.info(
"limit_to: filtered %d -> %d" % (len(self.frame_annots), self.limit_to)
)
self.frame_annots = self.frame_annots[: self.limit_to]

View File

@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
import logging
import math
import warnings
from dataclasses import field
@ -52,6 +53,7 @@ from .view_pooling.view_sampling import ViewSampler
STD_LOG_VARS = ["objective", "epoch", "sec/it"]
logger = logging.getLogger(__name__)
# pyre-ignore: 13
@ -274,7 +276,7 @@ class GenericModel(Configurable, torch.nn.Module):
self._implicit_functions = self._construct_implicit_functions()
self.print_loss_weights()
self.log_loss_weights()
def forward(
self,
@ -507,7 +509,7 @@ class GenericModel(Configurable, torch.nn.Module):
prefix: prepended to the names of images
"""
if not viz.check_connection():
print("no visdom server! -> skipping batch vis")
logger.info("no visdom server! -> skipping batch vis")
return
idx_image = 0
@ -662,14 +664,16 @@ class GenericModel(Configurable, torch.nn.Module):
]
return torch.nn.ModuleList(implicit_functions_list)
def print_loss_weights(self) -> None:
def log_loss_weights(self) -> None:
"""
Print a table of the loss weights.
"""
print("-------\nloss_weights:")
for k, w in self.loss_weights.items():
print(f"{k:40s}: {w:1.2e}")
print("-------")
loss_weights_message = (
"-------\nloss_weights:\n"
+ "\n".join(f"{k:40s}: {w:1.2e}" for k, w in self.loss_weights.items())
+ "-------"
)
logger.info(loss_weights_message)
def _preprocess_input(
self,

View File

@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import field
from typing import List, Optional
@ -18,6 +19,9 @@ from .base import ImplicitFunctionBase
from .utils import create_embeddings_for_implicit_function
logger = logging.getLogger(__name__)
class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
n_harmonic_functions_xyz: int = 10
n_harmonic_functions_dir: int = 4
@ -384,7 +388,7 @@ class TransformerWithInputSkips(torch.nn.Module):
for layeri in range(n_layers):
dimin = int(round(hidden_dim / (dim_down_factor ** layeri)))
dimout = int(round(hidden_dim / (dim_down_factor ** (layeri + 1))))
print(f"Tr: {dimin} -> {dimout}")
logger.info(f"Tr: {dimin} -> {dimout}")
for _i, l in enumerate((layers_pool, layers_ray)):
l.append(
TransformerEncoderLayer(

View File

@ -183,7 +183,6 @@ def _rgb_metrics(
rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True)
rgb_loss = utils.huber(rgb_squared, scaling=0.03)
crop_mass = masks_crop.sum().clamp(1.0)
# print("IMAGE:", images.mean().item(), images_pred.mean().item()) # TEMP
preds = {
"rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass,
"rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass,

View File

@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import List, Optional, Tuple
import torch
@ -13,6 +14,9 @@ from pytorch3d.renderer import RayBundle
from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput
logger = logging.getLogger(__name__)
@registry.register
class LSTMRenderer(BaseRenderer, torch.nn.Module):
"""
@ -28,7 +32,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
hidden_size: The dimensionality of the LSTM's hidden state.
n_feature_channels: The number of feature channels returned by the
implicit_function evaluated at each raymarching step.
verbose: If `True`, prints raymarching debug info.
verbose: If `True`, logs raymarching debug info.
References:
[1] Sitzmann, V. and Zollhöfer, M. and Wetzstein, G..
@ -110,8 +114,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
raymarch_features=None,
)
if self.verbose:
# print some stats
print(
msg = (
f"{t}: mu={float(signed_distance.mean()):1.2e};"
+ f" std={float(signed_distance.std()):1.2e};"
# pyre-fixme[6]: Expected `Union[bytearray, bytes, str,
@ -123,6 +126,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
# param but got `Tensor`.
+ f" std_d={float(ray_bundle_t.lengths.std()):1.2e};"
)
logger.info(msg)
if t == self.num_raymarch_steps:
break

View File

@ -3,6 +3,7 @@
# https://github.com/lioryariv/idr/
# Copyright (c) 2020 Lior Yariv
import logging
from typing import List, Tuple
import torch
@ -10,6 +11,9 @@ from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle
from torch import nn
logger = logging.getLogger(__name__)
class RayNormalColoringNetwork(torch.nn.Module):
def __init__(
self,
@ -36,7 +40,7 @@ class RayNormalColoringNetwork(torch.nn.Module):
dims_full[0] += self.embedview_fn.get_output_dim() - 3
if pooled_feature_dim > 0:
print("Pooled features in rendering network.")
logger.info("Pooled features in rendering network.")
dims_full[0] += pooled_feature_dim
self.num_layers = len(dims_full)

View File

@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
import glob
import logging
import os
import shutil
import tempfile
@ -12,13 +13,16 @@ import tempfile
import torch
logger = logging.getLogger(__name__)
def load_stats(flstats):
from pytorch3d.implicitron.tools.stats import Stats
try:
stats = Stats.load(flstats)
except:
print("Cant load stats! %s" % flstats)
logger.info("Cant load stats! %s" % flstats)
stats = None
return stats
@ -59,7 +63,7 @@ def safe_save_model(model, stats, fl, optimizer=None, cfg=None) -> None:
the moves. It is however quite improbable that a crash would occur right at
this time.
"""
print(f"saving model files safely to {fl}")
logger.info(f"saving model files safely to {fl}")
# first store everything to a tmpdir
with tempfile.TemporaryDirectory() as tmpdir:
tmpfl = os.path.join(tmpdir, os.path.split(fl)[-1])
@ -76,21 +80,20 @@ def safe_save_model(model, stats, fl, optimizer=None, cfg=None) -> None:
for tmpfl, tgt_fl in zip(stored_tmp_fls, tgt_fls):
if tgt_fl is None:
continue
# print(f'Moving {tmpfl} --> {tgt_fl}\n')
shutil.move(tmpfl, tgt_fl)
def save_model(model, stats, fl, optimizer=None, cfg=None):
flstats = get_stats_path(fl)
flmodel = get_model_path(fl)
print("saving model to %s" % flmodel)
logger.info("saving model to %s" % flmodel)
torch.save(model.state_dict(), flmodel)
flopt = None
if optimizer is not None:
flopt = get_optimizer_path(fl)
print("saving optimizer to %s" % flopt)
logger.info("saving optimizer to %s" % flopt)
torch.save(optimizer.state_dict(), flopt)
print("saving model stats to %s" % flstats)
logger.info("saving model stats to %s" % flstats)
stats.save(flstats)
return flstats, flmodel, flopt
@ -159,5 +162,5 @@ def purge_epoch(exp_dir, epoch) -> None:
get_stats_path(model_path),
]:
if os.path.isfile(file_path):
print("deleting %s" % file_path)
logger.info("deleting %s" % file_path)
os.remove(file_path)

View File

@ -4,12 +4,16 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Any, Dict, List
import torch
from visdom import Visdom
logger = logging.getLogger(__name__)
def get_visdom_env(cfg):
"""
Parse out visdom environment name from the input config.
@ -80,7 +84,7 @@ def visualize_basics(
imout = {}
for k in visualize_preds_keys:
if k not in preds or preds[k] is None:
print(f"cant show {k}")
logger.info(f"cant show {k}")
continue
v = preds[k].cpu().detach().clone()
if k.startswith("depth"):
@ -154,7 +158,7 @@ def make_depth_image(
for d, m in zip(depths, masks):
ok = (d.view(-1) > 1e-6) * (m.view(-1) > 0.5)
if ok.sum() <= 1:
print("empty depth!")
logger.info("empty depth!")
normfacs.append(torch.zeros(2).type_as(depths))
continue
dok = d.view(-1)[ok].view(-1)

View File

@ -10,6 +10,7 @@ import sys
import unittest
import unittest.mock
if os.environ.get("FB_TEST", False):
from common_testing import get_pytorch3d_dir
else:

View File

@ -24,6 +24,7 @@ from pytorch3d.implicitron.models.model_dbir import ModelDBIR
from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth
from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_
if os.environ.get("FB_TEST", False):
from .common_resources import get_skateboard_data, provide_lpips_vgg
else: