Summary: Use logging instead of printing in the internals of implicitron.

Reviewed By: davnov134

Differential Revision: D35247581

fbshipit-source-id: be5ddad5efe1409adbae0575d35ade6112b3be63
This commit is contained in:
Jeremy Reizenstein 2022-04-04 06:53:16 -07:00 committed by Facebook GitHub Bot
parent 6473aa316c
commit 199309fcf7
10 changed files with 65 additions and 35 deletions

View File

@ -8,6 +8,7 @@ import functools
import gzip import gzip
import hashlib import hashlib
import json import json
import logging
import os import os
import random import random
import warnings import warnings
@ -43,6 +44,9 @@ from pytorch3d.structures.pointclouds import Pointclouds, join_pointclouds_as_ba
from . import types from . import types
logger = logging.getLogger(__name__)
@dataclass @dataclass
class FrameData(Mapping[str, Any]): class FrameData(Mapping[str, Any]):
""" """
@ -398,7 +402,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
self._sort_frames() self._sort_frames()
self._load_subset_lists() self._load_subset_lists()
self._filter_db() # also computes sequence indices self._filter_db() # also computes sequence indices
print(str(self)) logger.info(str(self))
def seq_frame_index_to_dataset_index( def seq_frame_index_to_dataset_index(
self, self,
@ -674,7 +678,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
) )
def _load_frames(self) -> None: def _load_frames(self) -> None:
print(f"Loading Co3D frames from {self.frame_annotations_file}.") logger.info(f"Loading Co3D frames from {self.frame_annotations_file}.")
local_file = self._local_path(self.frame_annotations_file) local_file = self._local_path(self.frame_annotations_file)
with gzip.open(local_file, "rt", encoding="utf8") as zipfile: with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
frame_annots_list = types.load_dataclass( frame_annots_list = types.load_dataclass(
@ -687,7 +691,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
] ]
def _load_sequences(self) -> None: def _load_sequences(self) -> None:
print(f"Loading Co3D sequences from {self.sequence_annotations_file}.") logger.info(f"Loading Co3D sequences from {self.sequence_annotations_file}.")
local_file = self._local_path(self.sequence_annotations_file) local_file = self._local_path(self.sequence_annotations_file)
with gzip.open(local_file, "rt", encoding="utf8") as zipfile: with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
seq_annots = types.load_dataclass(zipfile, List[types.SequenceAnnotation]) seq_annots = types.load_dataclass(zipfile, List[types.SequenceAnnotation])
@ -696,7 +700,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
self.seq_annots = {entry.sequence_name: entry for entry in seq_annots} self.seq_annots = {entry.sequence_name: entry for entry in seq_annots}
def _load_subset_lists(self) -> None: def _load_subset_lists(self) -> None:
print(f"Loading Co3D subset lists from {self.subset_lists_file}.") logger.info(f"Loading Co3D subset lists from {self.subset_lists_file}.")
if not self.subset_lists_file: if not self.subset_lists_file:
return return
@ -731,7 +735,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
def _filter_db(self) -> None: def _filter_db(self) -> None:
if self.remove_empty_masks: if self.remove_empty_masks:
print("Removing images with empty masks.") logger.info("Removing images with empty masks.")
old_len = len(self.frame_annots) old_len = len(self.frame_annots)
msg = "remove_empty_masks needs every MaskAnnotation.mass to be set." msg = "remove_empty_masks needs every MaskAnnotation.mass to be set."
@ -749,7 +753,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
for frame in self.frame_annots for frame in self.frame_annots
if positive_mass(frame["frame_annotation"]) if positive_mass(frame["frame_annotation"])
] ]
print("... filtered %d -> %d" % (old_len, len(self.frame_annots))) logger.info("... filtered %d -> %d" % (old_len, len(self.frame_annots)))
# this has to be called after joining with categories!! # this has to be called after joining with categories!!
subsets = self.subsets subsets = self.subsets
@ -759,7 +763,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
"Subset filter is on but subset_lists_file was not given" "Subset filter is on but subset_lists_file was not given"
) )
print(f"Limitting Co3D dataset to the '{subsets}' subsets.") logger.info(f"Limiting Co3D dataset to the '{subsets}' subsets.")
# truncate the list of subsets to the valid one # truncate the list of subsets to the valid one
self.frame_annots = [ self.frame_annots = [
@ -771,7 +775,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
self._invalidate_indexes(filter_seq_annots=True) self._invalidate_indexes(filter_seq_annots=True)
if len(self.limit_category_to) > 0: if len(self.limit_category_to) > 0:
print(f"Limitting dataset to categories: {self.limit_category_to}") logger.info(f"Limiting dataset to categories: {self.limit_category_to}")
self.seq_annots = { self.seq_annots = {
name: entry name: entry
for name, entry in self.seq_annots.items() for name, entry in self.seq_annots.items()
@ -784,13 +788,13 @@ class ImplicitronDataset(ImplicitronDatasetBase):
attr = f"{prefix}_sequence" attr = f"{prefix}_sequence"
arr = getattr(self, attr) arr = getattr(self, attr)
if len(arr) > 0: if len(arr) > 0:
print(f"{attr}: {str(arr)}") logger.info(f"{attr}: {str(arr)}")
self.seq_annots = { self.seq_annots = {
name: entry name: entry
for name, entry in self.seq_annots.items() for name, entry in self.seq_annots.items()
if (name in arr) == (prefix == "pick") if (name in arr) == (prefix == "pick")
} }
print("... filtered %d -> %d" % (orig_len, len(self.seq_annots))) logger.info("... filtered %d -> %d" % (orig_len, len(self.seq_annots)))
if self.limit_sequences_to > 0: if self.limit_sequences_to > 0:
self.seq_annots = dict( self.seq_annots = dict(
@ -807,7 +811,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
self._invalidate_indexes() self._invalidate_indexes()
if self.n_frames_per_sequence > 0: if self.n_frames_per_sequence > 0:
print(f"Taking max {self.n_frames_per_sequence} per sequence.") logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.")
keep_idx = [] keep_idx = []
for seq, seq_indices in self._seq_to_idx.items(): for seq, seq_indices in self._seq_to_idx.items():
# infer the seed from the sequence name, this is reproducible # infer the seed from the sequence name, this is reproducible
@ -818,13 +822,15 @@ class ImplicitronDataset(ImplicitronDatasetBase):
) )
keep_idx.extend(seq_idx_shuffled[: self.n_frames_per_sequence]) keep_idx.extend(seq_idx_shuffled[: self.n_frames_per_sequence])
print("... filtered %d -> %d" % (len(self.frame_annots), len(keep_idx))) logger.info(
"... filtered %d -> %d" % (len(self.frame_annots), len(keep_idx))
)
self.frame_annots = [self.frame_annots[i] for i in keep_idx] self.frame_annots = [self.frame_annots[i] for i in keep_idx]
self._invalidate_indexes(filter_seq_annots=False) self._invalidate_indexes(filter_seq_annots=False)
# sequences are not decimated, so self.seq_annots is valid # sequences are not decimated, so self.seq_annots is valid
if self.limit_to > 0 and self.limit_to < len(self.frame_annots): if self.limit_to > 0 and self.limit_to < len(self.frame_annots):
print( logger.info(
"limit_to: filtered %d -> %d" % (len(self.frame_annots), self.limit_to) "limit_to: filtered %d -> %d" % (len(self.frame_annots), self.limit_to)
) )
self.frame_annots = self.frame_annots[: self.limit_to] self.frame_annots = self.frame_annots[: self.limit_to]

View File

@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import logging
import math import math
import warnings import warnings
from dataclasses import field from dataclasses import field
@ -52,6 +53,7 @@ from .view_pooling.view_sampling import ViewSampler
STD_LOG_VARS = ["objective", "epoch", "sec/it"] STD_LOG_VARS = ["objective", "epoch", "sec/it"]
logger = logging.getLogger(__name__)
# pyre-ignore: 13 # pyre-ignore: 13
@ -274,7 +276,7 @@ class GenericModel(Configurable, torch.nn.Module):
self._implicit_functions = self._construct_implicit_functions() self._implicit_functions = self._construct_implicit_functions()
self.print_loss_weights() self.log_loss_weights()
def forward( def forward(
self, self,
@ -507,7 +509,7 @@ class GenericModel(Configurable, torch.nn.Module):
prefix: prepended to the names of images prefix: prepended to the names of images
""" """
if not viz.check_connection(): if not viz.check_connection():
print("no visdom server! -> skipping batch vis") logger.info("no visdom server! -> skipping batch vis")
return return
idx_image = 0 idx_image = 0
@ -662,14 +664,16 @@ class GenericModel(Configurable, torch.nn.Module):
] ]
return torch.nn.ModuleList(implicit_functions_list) return torch.nn.ModuleList(implicit_functions_list)
def print_loss_weights(self) -> None: def log_loss_weights(self) -> None:
""" """
Print a table of the loss weights. Print a table of the loss weights.
""" """
print("-------\nloss_weights:") loss_weights_message = (
for k, w in self.loss_weights.items(): "-------\nloss_weights:\n"
print(f"{k:40s}: {w:1.2e}") + "\n".join(f"{k:40s}: {w:1.2e}" for k, w in self.loss_weights.items())
print("-------") + "-------"
)
logger.info(loss_weights_message)
def _preprocess_input( def _preprocess_input(
self, self,

View File

@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import logging
from dataclasses import field from dataclasses import field
from typing import List, Optional from typing import List, Optional
@ -18,6 +19,9 @@ from .base import ImplicitFunctionBase
from .utils import create_embeddings_for_implicit_function from .utils import create_embeddings_for_implicit_function
logger = logging.getLogger(__name__)
class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module): class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
n_harmonic_functions_xyz: int = 10 n_harmonic_functions_xyz: int = 10
n_harmonic_functions_dir: int = 4 n_harmonic_functions_dir: int = 4
@ -384,7 +388,7 @@ class TransformerWithInputSkips(torch.nn.Module):
for layeri in range(n_layers): for layeri in range(n_layers):
dimin = int(round(hidden_dim / (dim_down_factor ** layeri))) dimin = int(round(hidden_dim / (dim_down_factor ** layeri)))
dimout = int(round(hidden_dim / (dim_down_factor ** (layeri + 1)))) dimout = int(round(hidden_dim / (dim_down_factor ** (layeri + 1))))
print(f"Tr: {dimin} -> {dimout}") logger.info(f"Tr: {dimin} -> {dimout}")
for _i, l in enumerate((layers_pool, layers_ray)): for _i, l in enumerate((layers_pool, layers_ray)):
l.append( l.append(
TransformerEncoderLayer( TransformerEncoderLayer(

View File

@ -183,7 +183,6 @@ def _rgb_metrics(
rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True) rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True)
rgb_loss = utils.huber(rgb_squared, scaling=0.03) rgb_loss = utils.huber(rgb_squared, scaling=0.03)
crop_mass = masks_crop.sum().clamp(1.0) crop_mass = masks_crop.sum().clamp(1.0)
# print("IMAGE:", images.mean().item(), images_pred.mean().item()) # TEMP
preds = { preds = {
"rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass, "rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass,
"rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass, "rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass,

View File

@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import logging
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
@ -13,6 +14,9 @@ from pytorch3d.renderer import RayBundle
from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput
logger = logging.getLogger(__name__)
@registry.register @registry.register
class LSTMRenderer(BaseRenderer, torch.nn.Module): class LSTMRenderer(BaseRenderer, torch.nn.Module):
""" """
@ -28,7 +32,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
hidden_size: The dimensionality of the LSTM's hidden state. hidden_size: The dimensionality of the LSTM's hidden state.
n_feature_channels: The number of feature channels returned by the n_feature_channels: The number of feature channels returned by the
implicit_function evaluated at each raymarching step. implicit_function evaluated at each raymarching step.
verbose: If `True`, prints raymarching debug info. verbose: If `True`, logs raymarching debug info.
References: References:
[1] Sitzmann, V. and Zollhöfer, M. and Wetzstein, G.. [1] Sitzmann, V. and Zollhöfer, M. and Wetzstein, G..
@ -110,8 +114,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
raymarch_features=None, raymarch_features=None,
) )
if self.verbose: if self.verbose:
# print some stats msg = (
print(
f"{t}: mu={float(signed_distance.mean()):1.2e};" f"{t}: mu={float(signed_distance.mean()):1.2e};"
+ f" std={float(signed_distance.std()):1.2e};" + f" std={float(signed_distance.std()):1.2e};"
# pyre-fixme[6]: Expected `Union[bytearray, bytes, str, # pyre-fixme[6]: Expected `Union[bytearray, bytes, str,
@ -123,6 +126,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
# param but got `Tensor`. # param but got `Tensor`.
+ f" std_d={float(ray_bundle_t.lengths.std()):1.2e};" + f" std_d={float(ray_bundle_t.lengths.std()):1.2e};"
) )
logger.info(msg)
if t == self.num_raymarch_steps: if t == self.num_raymarch_steps:
break break

View File

@ -3,6 +3,7 @@
# https://github.com/lioryariv/idr/ # https://github.com/lioryariv/idr/
# Copyright (c) 2020 Lior Yariv # Copyright (c) 2020 Lior Yariv
import logging
from typing import List, Tuple from typing import List, Tuple
import torch import torch
@ -10,6 +11,9 @@ from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle
from torch import nn from torch import nn
logger = logging.getLogger(__name__)
class RayNormalColoringNetwork(torch.nn.Module): class RayNormalColoringNetwork(torch.nn.Module):
def __init__( def __init__(
self, self,
@ -36,7 +40,7 @@ class RayNormalColoringNetwork(torch.nn.Module):
dims_full[0] += self.embedview_fn.get_output_dim() - 3 dims_full[0] += self.embedview_fn.get_output_dim() - 3
if pooled_feature_dim > 0: if pooled_feature_dim > 0:
print("Pooled features in rendering network.") logger.info("Pooled features in rendering network.")
dims_full[0] += pooled_feature_dim dims_full[0] += pooled_feature_dim
self.num_layers = len(dims_full) self.num_layers = len(dims_full)

View File

@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import glob import glob
import logging
import os import os
import shutil import shutil
import tempfile import tempfile
@ -12,13 +13,16 @@ import tempfile
import torch import torch
logger = logging.getLogger(__name__)
def load_stats(flstats): def load_stats(flstats):
from pytorch3d.implicitron.tools.stats import Stats from pytorch3d.implicitron.tools.stats import Stats
try: try:
stats = Stats.load(flstats) stats = Stats.load(flstats)
except: except:
print("Cant load stats! %s" % flstats) logger.info("Cant load stats! %s" % flstats)
stats = None stats = None
return stats return stats
@ -59,7 +63,7 @@ def safe_save_model(model, stats, fl, optimizer=None, cfg=None) -> None:
the moves. It is however quite improbable that a crash would occur right at the moves. It is however quite improbable that a crash would occur right at
this time. this time.
""" """
print(f"saving model files safely to {fl}") logger.info(f"saving model files safely to {fl}")
# first store everything to a tmpdir # first store everything to a tmpdir
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
tmpfl = os.path.join(tmpdir, os.path.split(fl)[-1]) tmpfl = os.path.join(tmpdir, os.path.split(fl)[-1])
@ -76,21 +80,20 @@ def safe_save_model(model, stats, fl, optimizer=None, cfg=None) -> None:
for tmpfl, tgt_fl in zip(stored_tmp_fls, tgt_fls): for tmpfl, tgt_fl in zip(stored_tmp_fls, tgt_fls):
if tgt_fl is None: if tgt_fl is None:
continue continue
# print(f'Moving {tmpfl} --> {tgt_fl}\n')
shutil.move(tmpfl, tgt_fl) shutil.move(tmpfl, tgt_fl)
def save_model(model, stats, fl, optimizer=None, cfg=None): def save_model(model, stats, fl, optimizer=None, cfg=None):
flstats = get_stats_path(fl) flstats = get_stats_path(fl)
flmodel = get_model_path(fl) flmodel = get_model_path(fl)
print("saving model to %s" % flmodel) logger.info("saving model to %s" % flmodel)
torch.save(model.state_dict(), flmodel) torch.save(model.state_dict(), flmodel)
flopt = None flopt = None
if optimizer is not None: if optimizer is not None:
flopt = get_optimizer_path(fl) flopt = get_optimizer_path(fl)
print("saving optimizer to %s" % flopt) logger.info("saving optimizer to %s" % flopt)
torch.save(optimizer.state_dict(), flopt) torch.save(optimizer.state_dict(), flopt)
print("saving model stats to %s" % flstats) logger.info("saving model stats to %s" % flstats)
stats.save(flstats) stats.save(flstats)
return flstats, flmodel, flopt return flstats, flmodel, flopt
@ -159,5 +162,5 @@ def purge_epoch(exp_dir, epoch) -> None:
get_stats_path(model_path), get_stats_path(model_path),
]: ]:
if os.path.isfile(file_path): if os.path.isfile(file_path):
print("deleting %s" % file_path) logger.info("deleting %s" % file_path)
os.remove(file_path) os.remove(file_path)

View File

@ -4,12 +4,16 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import logging
from typing import Any, Dict, List from typing import Any, Dict, List
import torch import torch
from visdom import Visdom from visdom import Visdom
logger = logging.getLogger(__name__)
def get_visdom_env(cfg): def get_visdom_env(cfg):
""" """
Parse out visdom environment name from the input config. Parse out visdom environment name from the input config.
@ -80,7 +84,7 @@ def visualize_basics(
imout = {} imout = {}
for k in visualize_preds_keys: for k in visualize_preds_keys:
if k not in preds or preds[k] is None: if k not in preds or preds[k] is None:
print(f"cant show {k}") logger.info(f"cant show {k}")
continue continue
v = preds[k].cpu().detach().clone() v = preds[k].cpu().detach().clone()
if k.startswith("depth"): if k.startswith("depth"):
@ -154,7 +158,7 @@ def make_depth_image(
for d, m in zip(depths, masks): for d, m in zip(depths, masks):
ok = (d.view(-1) > 1e-6) * (m.view(-1) > 0.5) ok = (d.view(-1) > 1e-6) * (m.view(-1) > 0.5)
if ok.sum() <= 1: if ok.sum() <= 1:
print("empty depth!") logger.info("empty depth!")
normfacs.append(torch.zeros(2).type_as(depths)) normfacs.append(torch.zeros(2).type_as(depths))
continue continue
dok = d.view(-1)[ok].view(-1) dok = d.view(-1)[ok].view(-1)

View File

@ -10,6 +10,7 @@ import sys
import unittest import unittest
import unittest.mock import unittest.mock
if os.environ.get("FB_TEST", False): if os.environ.get("FB_TEST", False):
from common_testing import get_pytorch3d_dir from common_testing import get_pytorch3d_dir
else: else:

View File

@ -24,6 +24,7 @@ from pytorch3d.implicitron.models.model_dbir import ModelDBIR
from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth
from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_ from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_
if os.environ.get("FB_TEST", False): if os.environ.get("FB_TEST", False):
from .common_resources import get_skateboard_data, provide_lpips_vgg from .common_resources import get_skateboard_data, provide_lpips_vgg
else: else: