mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
logging
Summary: Use logging instead of printing in the internals of implicitron. Reviewed By: davnov134 Differential Revision: D35247581 fbshipit-source-id: be5ddad5efe1409adbae0575d35ade6112b3be63
This commit is contained in:
parent
6473aa316c
commit
199309fcf7
@ -8,6 +8,7 @@ import functools
|
||||
import gzip
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import warnings
|
||||
@ -43,6 +44,9 @@ from pytorch3d.structures.pointclouds import Pointclouds, join_pointclouds_as_ba
|
||||
from . import types
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FrameData(Mapping[str, Any]):
|
||||
"""
|
||||
@ -398,7 +402,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
self._sort_frames()
|
||||
self._load_subset_lists()
|
||||
self._filter_db() # also computes sequence indices
|
||||
print(str(self))
|
||||
logger.info(str(self))
|
||||
|
||||
def seq_frame_index_to_dataset_index(
|
||||
self,
|
||||
@ -674,7 +678,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
)
|
||||
|
||||
def _load_frames(self) -> None:
|
||||
print(f"Loading Co3D frames from {self.frame_annotations_file}.")
|
||||
logger.info(f"Loading Co3D frames from {self.frame_annotations_file}.")
|
||||
local_file = self._local_path(self.frame_annotations_file)
|
||||
with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
|
||||
frame_annots_list = types.load_dataclass(
|
||||
@ -687,7 +691,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
]
|
||||
|
||||
def _load_sequences(self) -> None:
|
||||
print(f"Loading Co3D sequences from {self.sequence_annotations_file}.")
|
||||
logger.info(f"Loading Co3D sequences from {self.sequence_annotations_file}.")
|
||||
local_file = self._local_path(self.sequence_annotations_file)
|
||||
with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
|
||||
seq_annots = types.load_dataclass(zipfile, List[types.SequenceAnnotation])
|
||||
@ -696,7 +700,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
self.seq_annots = {entry.sequence_name: entry for entry in seq_annots}
|
||||
|
||||
def _load_subset_lists(self) -> None:
|
||||
print(f"Loading Co3D subset lists from {self.subset_lists_file}.")
|
||||
logger.info(f"Loading Co3D subset lists from {self.subset_lists_file}.")
|
||||
if not self.subset_lists_file:
|
||||
return
|
||||
|
||||
@ -731,7 +735,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
|
||||
def _filter_db(self) -> None:
|
||||
if self.remove_empty_masks:
|
||||
print("Removing images with empty masks.")
|
||||
logger.info("Removing images with empty masks.")
|
||||
old_len = len(self.frame_annots)
|
||||
|
||||
msg = "remove_empty_masks needs every MaskAnnotation.mass to be set."
|
||||
@ -749,7 +753,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
for frame in self.frame_annots
|
||||
if positive_mass(frame["frame_annotation"])
|
||||
]
|
||||
print("... filtered %d -> %d" % (old_len, len(self.frame_annots)))
|
||||
logger.info("... filtered %d -> %d" % (old_len, len(self.frame_annots)))
|
||||
|
||||
# this has to be called after joining with categories!!
|
||||
subsets = self.subsets
|
||||
@ -759,7 +763,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
"Subset filter is on but subset_lists_file was not given"
|
||||
)
|
||||
|
||||
print(f"Limitting Co3D dataset to the '{subsets}' subsets.")
|
||||
logger.info(f"Limiting Co3D dataset to the '{subsets}' subsets.")
|
||||
|
||||
# truncate the list of subsets to the valid one
|
||||
self.frame_annots = [
|
||||
@ -771,7 +775,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
self._invalidate_indexes(filter_seq_annots=True)
|
||||
|
||||
if len(self.limit_category_to) > 0:
|
||||
print(f"Limitting dataset to categories: {self.limit_category_to}")
|
||||
logger.info(f"Limiting dataset to categories: {self.limit_category_to}")
|
||||
self.seq_annots = {
|
||||
name: entry
|
||||
for name, entry in self.seq_annots.items()
|
||||
@ -784,13 +788,13 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
attr = f"{prefix}_sequence"
|
||||
arr = getattr(self, attr)
|
||||
if len(arr) > 0:
|
||||
print(f"{attr}: {str(arr)}")
|
||||
logger.info(f"{attr}: {str(arr)}")
|
||||
self.seq_annots = {
|
||||
name: entry
|
||||
for name, entry in self.seq_annots.items()
|
||||
if (name in arr) == (prefix == "pick")
|
||||
}
|
||||
print("... filtered %d -> %d" % (orig_len, len(self.seq_annots)))
|
||||
logger.info("... filtered %d -> %d" % (orig_len, len(self.seq_annots)))
|
||||
|
||||
if self.limit_sequences_to > 0:
|
||||
self.seq_annots = dict(
|
||||
@ -807,7 +811,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
self._invalidate_indexes()
|
||||
|
||||
if self.n_frames_per_sequence > 0:
|
||||
print(f"Taking max {self.n_frames_per_sequence} per sequence.")
|
||||
logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.")
|
||||
keep_idx = []
|
||||
for seq, seq_indices in self._seq_to_idx.items():
|
||||
# infer the seed from the sequence name, this is reproducible
|
||||
@ -818,13 +822,15 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
)
|
||||
keep_idx.extend(seq_idx_shuffled[: self.n_frames_per_sequence])
|
||||
|
||||
print("... filtered %d -> %d" % (len(self.frame_annots), len(keep_idx)))
|
||||
logger.info(
|
||||
"... filtered %d -> %d" % (len(self.frame_annots), len(keep_idx))
|
||||
)
|
||||
self.frame_annots = [self.frame_annots[i] for i in keep_idx]
|
||||
self._invalidate_indexes(filter_seq_annots=False)
|
||||
# sequences are not decimated, so self.seq_annots is valid
|
||||
|
||||
if self.limit_to > 0 and self.limit_to < len(self.frame_annots):
|
||||
print(
|
||||
logger.info(
|
||||
"limit_to: filtered %d -> %d" % (len(self.frame_annots), self.limit_to)
|
||||
)
|
||||
self.frame_annots = self.frame_annots[: self.limit_to]
|
||||
|
@ -5,6 +5,7 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import logging
|
||||
import math
|
||||
import warnings
|
||||
from dataclasses import field
|
||||
@ -52,6 +53,7 @@ from .view_pooling.view_sampling import ViewSampler
|
||||
|
||||
|
||||
STD_LOG_VARS = ["objective", "epoch", "sec/it"]
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# pyre-ignore: 13
|
||||
@ -274,7 +276,7 @@ class GenericModel(Configurable, torch.nn.Module):
|
||||
|
||||
self._implicit_functions = self._construct_implicit_functions()
|
||||
|
||||
self.print_loss_weights()
|
||||
self.log_loss_weights()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -507,7 +509,7 @@ class GenericModel(Configurable, torch.nn.Module):
|
||||
prefix: prepended to the names of images
|
||||
"""
|
||||
if not viz.check_connection():
|
||||
print("no visdom server! -> skipping batch vis")
|
||||
logger.info("no visdom server! -> skipping batch vis")
|
||||
return
|
||||
|
||||
idx_image = 0
|
||||
@ -662,14 +664,16 @@ class GenericModel(Configurable, torch.nn.Module):
|
||||
]
|
||||
return torch.nn.ModuleList(implicit_functions_list)
|
||||
|
||||
def print_loss_weights(self) -> None:
|
||||
def log_loss_weights(self) -> None:
|
||||
"""
|
||||
Print a table of the loss weights.
|
||||
"""
|
||||
print("-------\nloss_weights:")
|
||||
for k, w in self.loss_weights.items():
|
||||
print(f"{k:40s}: {w:1.2e}")
|
||||
print("-------")
|
||||
loss_weights_message = (
|
||||
"-------\nloss_weights:\n"
|
||||
+ "\n".join(f"{k:40s}: {w:1.2e}" for k, w in self.loss_weights.items())
|
||||
+ "-------"
|
||||
)
|
||||
logger.info(loss_weights_message)
|
||||
|
||||
def _preprocess_input(
|
||||
self,
|
||||
|
@ -4,6 +4,7 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from dataclasses import field
|
||||
from typing import List, Optional
|
||||
|
||||
@ -18,6 +19,9 @@ from .base import ImplicitFunctionBase
|
||||
from .utils import create_embeddings_for_implicit_function
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
||||
n_harmonic_functions_xyz: int = 10
|
||||
n_harmonic_functions_dir: int = 4
|
||||
@ -384,7 +388,7 @@ class TransformerWithInputSkips(torch.nn.Module):
|
||||
for layeri in range(n_layers):
|
||||
dimin = int(round(hidden_dim / (dim_down_factor ** layeri)))
|
||||
dimout = int(round(hidden_dim / (dim_down_factor ** (layeri + 1))))
|
||||
print(f"Tr: {dimin} -> {dimout}")
|
||||
logger.info(f"Tr: {dimin} -> {dimout}")
|
||||
for _i, l in enumerate((layers_pool, layers_ray)):
|
||||
l.append(
|
||||
TransformerEncoderLayer(
|
||||
|
@ -183,7 +183,6 @@ def _rgb_metrics(
|
||||
rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True)
|
||||
rgb_loss = utils.huber(rgb_squared, scaling=0.03)
|
||||
crop_mass = masks_crop.sum().clamp(1.0)
|
||||
# print("IMAGE:", images.mean().item(), images_pred.mean().item()) # TEMP
|
||||
preds = {
|
||||
"rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass,
|
||||
"rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass,
|
||||
|
@ -4,6 +4,7 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@ -13,6 +14,9 @@ from pytorch3d.renderer import RayBundle
|
||||
from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@registry.register
|
||||
class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
||||
"""
|
||||
@ -28,7 +32,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
||||
hidden_size: The dimensionality of the LSTM's hidden state.
|
||||
n_feature_channels: The number of feature channels returned by the
|
||||
implicit_function evaluated at each raymarching step.
|
||||
verbose: If `True`, prints raymarching debug info.
|
||||
verbose: If `True`, logs raymarching debug info.
|
||||
|
||||
References:
|
||||
[1] Sitzmann, V. and Zollhöfer, M. and Wetzstein, G..
|
||||
@ -110,8 +114,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
||||
raymarch_features=None,
|
||||
)
|
||||
if self.verbose:
|
||||
# print some stats
|
||||
print(
|
||||
msg = (
|
||||
f"{t}: mu={float(signed_distance.mean()):1.2e};"
|
||||
+ f" std={float(signed_distance.std()):1.2e};"
|
||||
# pyre-fixme[6]: Expected `Union[bytearray, bytes, str,
|
||||
@ -123,6 +126,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
||||
# param but got `Tensor`.
|
||||
+ f" std_d={float(ray_bundle_t.lengths.std()):1.2e};"
|
||||
)
|
||||
logger.info(msg)
|
||||
if t == self.num_raymarch_steps:
|
||||
break
|
||||
|
||||
|
@ -3,6 +3,7 @@
|
||||
# https://github.com/lioryariv/idr/
|
||||
# Copyright (c) 2020 Lior Yariv
|
||||
|
||||
import logging
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
@ -10,6 +11,9 @@ from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle
|
||||
from torch import nn
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RayNormalColoringNetwork(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -36,7 +40,7 @@ class RayNormalColoringNetwork(torch.nn.Module):
|
||||
dims_full[0] += self.embedview_fn.get_output_dim() - 3
|
||||
|
||||
if pooled_feature_dim > 0:
|
||||
print("Pooled features in rendering network.")
|
||||
logger.info("Pooled features in rendering network.")
|
||||
dims_full[0] += pooled_feature_dim
|
||||
|
||||
self.num_layers = len(dims_full)
|
||||
|
@ -5,6 +5,7 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
@ -12,13 +13,16 @@ import tempfile
|
||||
import torch
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_stats(flstats):
|
||||
from pytorch3d.implicitron.tools.stats import Stats
|
||||
|
||||
try:
|
||||
stats = Stats.load(flstats)
|
||||
except:
|
||||
print("Cant load stats! %s" % flstats)
|
||||
logger.info("Cant load stats! %s" % flstats)
|
||||
stats = None
|
||||
return stats
|
||||
|
||||
@ -59,7 +63,7 @@ def safe_save_model(model, stats, fl, optimizer=None, cfg=None) -> None:
|
||||
the moves. It is however quite improbable that a crash would occur right at
|
||||
this time.
|
||||
"""
|
||||
print(f"saving model files safely to {fl}")
|
||||
logger.info(f"saving model files safely to {fl}")
|
||||
# first store everything to a tmpdir
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpfl = os.path.join(tmpdir, os.path.split(fl)[-1])
|
||||
@ -76,21 +80,20 @@ def safe_save_model(model, stats, fl, optimizer=None, cfg=None) -> None:
|
||||
for tmpfl, tgt_fl in zip(stored_tmp_fls, tgt_fls):
|
||||
if tgt_fl is None:
|
||||
continue
|
||||
# print(f'Moving {tmpfl} --> {tgt_fl}\n')
|
||||
shutil.move(tmpfl, tgt_fl)
|
||||
|
||||
|
||||
def save_model(model, stats, fl, optimizer=None, cfg=None):
|
||||
flstats = get_stats_path(fl)
|
||||
flmodel = get_model_path(fl)
|
||||
print("saving model to %s" % flmodel)
|
||||
logger.info("saving model to %s" % flmodel)
|
||||
torch.save(model.state_dict(), flmodel)
|
||||
flopt = None
|
||||
if optimizer is not None:
|
||||
flopt = get_optimizer_path(fl)
|
||||
print("saving optimizer to %s" % flopt)
|
||||
logger.info("saving optimizer to %s" % flopt)
|
||||
torch.save(optimizer.state_dict(), flopt)
|
||||
print("saving model stats to %s" % flstats)
|
||||
logger.info("saving model stats to %s" % flstats)
|
||||
stats.save(flstats)
|
||||
|
||||
return flstats, flmodel, flopt
|
||||
@ -159,5 +162,5 @@ def purge_epoch(exp_dir, epoch) -> None:
|
||||
get_stats_path(model_path),
|
||||
]:
|
||||
if os.path.isfile(file_path):
|
||||
print("deleting %s" % file_path)
|
||||
logger.info("deleting %s" % file_path)
|
||||
os.remove(file_path)
|
||||
|
@ -4,12 +4,16 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
from visdom import Visdom
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_visdom_env(cfg):
|
||||
"""
|
||||
Parse out visdom environment name from the input config.
|
||||
@ -80,7 +84,7 @@ def visualize_basics(
|
||||
imout = {}
|
||||
for k in visualize_preds_keys:
|
||||
if k not in preds or preds[k] is None:
|
||||
print(f"cant show {k}")
|
||||
logger.info(f"cant show {k}")
|
||||
continue
|
||||
v = preds[k].cpu().detach().clone()
|
||||
if k.startswith("depth"):
|
||||
@ -154,7 +158,7 @@ def make_depth_image(
|
||||
for d, m in zip(depths, masks):
|
||||
ok = (d.view(-1) > 1e-6) * (m.view(-1) > 0.5)
|
||||
if ok.sum() <= 1:
|
||||
print("empty depth!")
|
||||
logger.info("empty depth!")
|
||||
normfacs.append(torch.zeros(2).type_as(depths))
|
||||
continue
|
||||
dok = d.view(-1)[ok].view(-1)
|
||||
|
@ -10,6 +10,7 @@ import sys
|
||||
import unittest
|
||||
import unittest.mock
|
||||
|
||||
|
||||
if os.environ.get("FB_TEST", False):
|
||||
from common_testing import get_pytorch3d_dir
|
||||
else:
|
||||
|
@ -24,6 +24,7 @@ from pytorch3d.implicitron.models.model_dbir import ModelDBIR
|
||||
from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth
|
||||
from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_
|
||||
|
||||
|
||||
if os.environ.get("FB_TEST", False):
|
||||
from .common_resources import get_skateboard_data, provide_lpips_vgg
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user