mirror of https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00

commit 199309fcf7
parent 6473aa316c

logging

Summary: Use logging instead of printing in the internals of implicitron.

Reviewed By: davnov134

Differential Revision: D35247581

fbshipit-source-id: be5ddad5efe1409adbae0575d35ade6112b3be63
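Note: these messages are now emitted at INFO level through per-module loggers created with `logging.getLogger(__name__)`, so under Python's default WARNING threshold they are silent. A minimal sketch of how a downstream script could surface them; this configuration is illustrative and not part of the commit:

    import logging

    # Keep third-party chatter at the default WARNING level, but surface
    # the INFO records that implicitron now routes through `logging`.
    # Logger names follow module paths, so the "pytorch3d.implicitron"
    # parent logger covers the modules touched by this commit.
    logging.basicConfig(level=logging.WARNING)
    logging.getLogger("pytorch3d.implicitron").setLevel(logging.INFO)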
@@ -8,6 +8,7 @@ import functools
 import gzip
 import hashlib
 import json
+import logging
 import os
 import random
 import warnings
@@ -43,6 +44,9 @@ from pytorch3d.structures.pointclouds import Pointclouds, join_pointclouds_as_ba
 from . import types
 
 
+logger = logging.getLogger(__name__)
+
+
 @dataclass
 class FrameData(Mapping[str, Any]):
     """
@@ -398,7 +402,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
         self._sort_frames()
         self._load_subset_lists()
         self._filter_db()  # also computes sequence indices
-        print(str(self))
+        logger.info(str(self))
 
     def seq_frame_index_to_dataset_index(
         self,
@@ -674,7 +678,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
         )
 
     def _load_frames(self) -> None:
-        print(f"Loading Co3D frames from {self.frame_annotations_file}.")
+        logger.info(f"Loading Co3D frames from {self.frame_annotations_file}.")
         local_file = self._local_path(self.frame_annotations_file)
         with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
             frame_annots_list = types.load_dataclass(
@@ -687,7 +691,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
         ]
 
     def _load_sequences(self) -> None:
-        print(f"Loading Co3D sequences from {self.sequence_annotations_file}.")
+        logger.info(f"Loading Co3D sequences from {self.sequence_annotations_file}.")
         local_file = self._local_path(self.sequence_annotations_file)
         with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
             seq_annots = types.load_dataclass(zipfile, List[types.SequenceAnnotation])
@@ -696,7 +700,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
         self.seq_annots = {entry.sequence_name: entry for entry in seq_annots}
 
     def _load_subset_lists(self) -> None:
-        print(f"Loading Co3D subset lists from {self.subset_lists_file}.")
+        logger.info(f"Loading Co3D subset lists from {self.subset_lists_file}.")
         if not self.subset_lists_file:
             return
 
@@ -731,7 +735,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
 
     def _filter_db(self) -> None:
        if self.remove_empty_masks:
-            print("Removing images with empty masks.")
+            logger.info("Removing images with empty masks.")
             old_len = len(self.frame_annots)
 
             msg = "remove_empty_masks needs every MaskAnnotation.mass to be set."
@@ -749,7 +753,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
                 for frame in self.frame_annots
                 if positive_mass(frame["frame_annotation"])
             ]
-            print("... filtered %d -> %d" % (old_len, len(self.frame_annots)))
+            logger.info("... filtered %d -> %d" % (old_len, len(self.frame_annots)))
 
         # this has to be called after joining with categories!!
         subsets = self.subsets
@@ -759,7 +763,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
                     "Subset filter is on but subset_lists_file was not given"
                 )
 
-            print(f"Limitting Co3D dataset to the '{subsets}' subsets.")
+            logger.info(f"Limiting Co3D dataset to the '{subsets}' subsets.")
 
             # truncate the list of subsets to the valid one
             self.frame_annots = [
@@ -771,7 +775,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
             self._invalidate_indexes(filter_seq_annots=True)
 
         if len(self.limit_category_to) > 0:
-            print(f"Limitting dataset to categories: {self.limit_category_to}")
+            logger.info(f"Limiting dataset to categories: {self.limit_category_to}")
             self.seq_annots = {
                 name: entry
                 for name, entry in self.seq_annots.items()
@@ -784,13 +788,13 @@ class ImplicitronDataset(ImplicitronDatasetBase):
             attr = f"{prefix}_sequence"
             arr = getattr(self, attr)
             if len(arr) > 0:
-                print(f"{attr}: {str(arr)}")
+                logger.info(f"{attr}: {str(arr)}")
                 self.seq_annots = {
                     name: entry
                     for name, entry in self.seq_annots.items()
                     if (name in arr) == (prefix == "pick")
                 }
-                print("... filtered %d -> %d" % (orig_len, len(self.seq_annots)))
+                logger.info("... filtered %d -> %d" % (orig_len, len(self.seq_annots)))
 
         if self.limit_sequences_to > 0:
             self.seq_annots = dict(
@@ -807,7 +811,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
             self._invalidate_indexes()
 
         if self.n_frames_per_sequence > 0:
-            print(f"Taking max {self.n_frames_per_sequence} per sequence.")
+            logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.")
             keep_idx = []
             for seq, seq_indices in self._seq_to_idx.items():
                 # infer the seed from the sequence name, this is reproducible
@@ -818,13 +822,15 @@ class ImplicitronDataset(ImplicitronDatasetBase):
                 )
                 keep_idx.extend(seq_idx_shuffled[: self.n_frames_per_sequence])
 
-            print("... filtered %d -> %d" % (len(self.frame_annots), len(keep_idx)))
+            logger.info(
+                "... filtered %d -> %d" % (len(self.frame_annots), len(keep_idx))
+            )
             self.frame_annots = [self.frame_annots[i] for i in keep_idx]
             self._invalidate_indexes(filter_seq_annots=False)
             # sequences are not decimated, so self.seq_annots is valid
 
         if self.limit_to > 0 and self.limit_to < len(self.frame_annots):
-            print(
+            logger.info(
                 "limit_to: filtered %d -> %d" % (len(self.frame_annots), self.limit_to)
             )
             self.frame_annots = self.frame_annots[: self.limit_to]
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 
+import logging
 import math
 import warnings
 from dataclasses import field
@@ -52,6 +53,7 @@ from .view_pooling.view_sampling import ViewSampler
 
 
 STD_LOG_VARS = ["objective", "epoch", "sec/it"]
+logger = logging.getLogger(__name__)
 
 
 # pyre-ignore: 13
@@ -274,7 +276,7 @@ class GenericModel(Configurable, torch.nn.Module):
 
         self._implicit_functions = self._construct_implicit_functions()
 
-        self.print_loss_weights()
+        self.log_loss_weights()
 
     def forward(
         self,
@@ -507,7 +509,7 @@ class GenericModel(Configurable, torch.nn.Module):
             prefix: prepended to the names of images
         """
         if not viz.check_connection():
-            print("no visdom server! -> skipping batch vis")
+            logger.info("no visdom server! -> skipping batch vis")
             return
 
         idx_image = 0
@@ -662,14 +664,16 @@ class GenericModel(Configurable, torch.nn.Module):
         ]
         return torch.nn.ModuleList(implicit_functions_list)
 
-    def print_loss_weights(self) -> None:
+    def log_loss_weights(self) -> None:
         """
         Print a table of the loss weights.
         """
-        print("-------\nloss_weights:")
-        for k, w in self.loss_weights.items():
-            print(f"{k:40s}: {w:1.2e}")
-        print("-------")
+        loss_weights_message = (
+            "-------\nloss_weights:\n"
+            + "\n".join(f"{k:40s}: {w:1.2e}" for k, w in self.loss_weights.items())
+            + "-------"
+        )
+        logger.info(loss_weights_message)
 
     def _preprocess_input(
         self,
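The `log_loss_weights` change above also switches the emission pattern: instead of one `print` per table row, the whole table is assembled into a single string and sent in one `logger.info` call, so the rows arrive as one log record and cannot be interleaved with output from other loggers. A standalone sketch of the same pattern, with hypothetical weight names and values:

    import logging

    logger = logging.getLogger(__name__)
    loss_weights = {"loss_rgb_huber": 1.0, "loss_mask_bce": 0.1}  # hypothetical

    # Build the multi-line message first, then emit it as one record.
    message = "-------\nloss_weights:\n" + "\n".join(
        f"{k:40s}: {w:1.2e}" for k, w in loss_weights.items()
    )
    logger.info(message)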
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 from dataclasses import field
 from typing import List, Optional
 
@@ -18,6 +19,9 @@ from .base import ImplicitFunctionBase
 from .utils import create_embeddings_for_implicit_function
 
 
+logger = logging.getLogger(__name__)
+
+
 class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
     n_harmonic_functions_xyz: int = 10
     n_harmonic_functions_dir: int = 4
@@ -384,7 +388,7 @@ class TransformerWithInputSkips(torch.nn.Module):
         for layeri in range(n_layers):
             dimin = int(round(hidden_dim / (dim_down_factor ** layeri)))
             dimout = int(round(hidden_dim / (dim_down_factor ** (layeri + 1))))
-            print(f"Tr: {dimin} -> {dimout}")
+            logger.info(f"Tr: {dimin} -> {dimout}")
             for _i, l in enumerate((layers_pool, layers_ray)):
                 l.append(
                     TransformerEncoderLayer(
@@ -183,7 +183,6 @@ def _rgb_metrics(
     rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True)
     rgb_loss = utils.huber(rgb_squared, scaling=0.03)
     crop_mass = masks_crop.sum().clamp(1.0)
-    # print("IMAGE:", images.mean().item(), images_pred.mean().item()) # TEMP
     preds = {
         "rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass,
         "rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass,
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 from typing import List, Optional, Tuple
 
 import torch
@@ -13,6 +14,9 @@ from pytorch3d.renderer import RayBundle
 from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput
 
 
+logger = logging.getLogger(__name__)
+
+
 @registry.register
 class LSTMRenderer(BaseRenderer, torch.nn.Module):
     """
@@ -28,7 +32,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
         hidden_size: The dimensionality of the LSTM's hidden state.
         n_feature_channels: The number of feature channels returned by the
             implicit_function evaluated at each raymarching step.
-        verbose: If `True`, prints raymarching debug info.
+        verbose: If `True`, logs raymarching debug info.
 
     References:
         [1] Sitzmann, V. and Zollhöfer, M. and Wetzstein, G..
@@ -110,8 +114,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
                 raymarch_features=None,
             )
             if self.verbose:
-                # print some stats
-                print(
+                msg = (
                     f"{t}: mu={float(signed_distance.mean()):1.2e};"
                     + f" std={float(signed_distance.std()):1.2e};"
                     # pyre-fixme[6]: Expected `Union[bytearray, bytes, str,
@@ -123,6 +126,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
                     # param but got `Tensor`.
                     + f" std_d={float(ray_bundle_t.lengths.std()):1.2e};"
                 )
+                logger.info(msg)
             if t == self.num_raymarch_steps:
                 break
 
@@ -3,6 +3,7 @@
 # https://github.com/lioryariv/idr/
 # Copyright (c) 2020 Lior Yariv
 
+import logging
 from typing import List, Tuple
 
 import torch
@@ -10,6 +11,9 @@ from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle
 from torch import nn
 
 
+logger = logging.getLogger(__name__)
+
+
 class RayNormalColoringNetwork(torch.nn.Module):
     def __init__(
         self,
@@ -36,7 +40,7 @@ class RayNormalColoringNetwork(torch.nn.Module):
             dims_full[0] += self.embedview_fn.get_output_dim() - 3
 
         if pooled_feature_dim > 0:
-            print("Pooled features in rendering network.")
+            logger.info("Pooled features in rendering network.")
             dims_full[0] += pooled_feature_dim
 
         self.num_layers = len(dims_full)
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import glob
+import logging
 import os
 import shutil
 import tempfile
@@ -12,13 +13,16 @@ import tempfile
 import torch
 
 
+logger = logging.getLogger(__name__)
+
+
 def load_stats(flstats):
     from pytorch3d.implicitron.tools.stats import Stats
 
     try:
         stats = Stats.load(flstats)
     except:
-        print("Cant load stats! %s" % flstats)
+        logger.info("Cant load stats! %s" % flstats)
         stats = None
     return stats
 
@@ -59,7 +63,7 @@ def safe_save_model(model, stats, fl, optimizer=None, cfg=None) -> None:
     the moves. It is however quite improbable that a crash would occur right at
     this time.
     """
-    print(f"saving model files safely to {fl}")
+    logger.info(f"saving model files safely to {fl}")
     # first store everything to a tmpdir
     with tempfile.TemporaryDirectory() as tmpdir:
         tmpfl = os.path.join(tmpdir, os.path.split(fl)[-1])
@@ -76,21 +80,20 @@ def safe_save_model(model, stats, fl, optimizer=None, cfg=None) -> None:
     for tmpfl, tgt_fl in zip(stored_tmp_fls, tgt_fls):
         if tgt_fl is None:
             continue
-        # print(f'Moving {tmpfl} --> {tgt_fl}\n')
         shutil.move(tmpfl, tgt_fl)
 
 
 def save_model(model, stats, fl, optimizer=None, cfg=None):
     flstats = get_stats_path(fl)
     flmodel = get_model_path(fl)
-    print("saving model to %s" % flmodel)
+    logger.info("saving model to %s" % flmodel)
     torch.save(model.state_dict(), flmodel)
     flopt = None
     if optimizer is not None:
         flopt = get_optimizer_path(fl)
-        print("saving optimizer to %s" % flopt)
+        logger.info("saving optimizer to %s" % flopt)
         torch.save(optimizer.state_dict(), flopt)
-    print("saving model stats to %s" % flstats)
+    logger.info("saving model stats to %s" % flstats)
     stats.save(flstats)
 
     return flstats, flmodel, flopt
@@ -159,5 +162,5 @@ def purge_epoch(exp_dir, epoch) -> None:
         get_stats_path(model_path),
     ]:
         if os.path.isfile(file_path):
-            print("deleting %s" % file_path)
+            logger.info("deleting %s" % file_path)
             os.remove(file_path)
@@ -4,12 +4,16 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 from typing import Any, Dict, List
 
 import torch
 from visdom import Visdom
 
 
+logger = logging.getLogger(__name__)
+
+
 def get_visdom_env(cfg):
     """
     Parse out visdom environment name from the input config.
@@ -80,7 +84,7 @@ def visualize_basics(
     imout = {}
     for k in visualize_preds_keys:
         if k not in preds or preds[k] is None:
-            print(f"cant show {k}")
+            logger.info(f"cant show {k}")
             continue
         v = preds[k].cpu().detach().clone()
         if k.startswith("depth"):
@@ -154,7 +158,7 @@ def make_depth_image(
     for d, m in zip(depths, masks):
         ok = (d.view(-1) > 1e-6) * (m.view(-1) > 0.5)
         if ok.sum() <= 1:
-            print("empty depth!")
+            logger.info("empty depth!")
             normfacs.append(torch.zeros(2).type_as(depths))
             continue
         dok = d.view(-1)[ok].view(-1)
@@ -10,6 +10,7 @@ import sys
 import unittest
 import unittest.mock
 
+
 if os.environ.get("FB_TEST", False):
     from common_testing import get_pytorch3d_dir
 else:
@@ -24,6 +24,7 @@ from pytorch3d.implicitron.models.model_dbir import ModelDBIR
 from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth
 from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_
 
+
 if os.environ.get("FB_TEST", False):
     from .common_resources import get_skateboard_data, provide_lpips_vgg
 else:
Loading…
x
Reference in New Issue
Block a user