resnet34 weights for remote executor

Summary: Like vgg16 for lpips, internally we need resnet34 weights for coming feature extractor tests.

Reviewed By: davnov134

Differential Revision: D36349361

fbshipit-source-id: 1c33009c904766fcc15e7e31cd15d0f820c57354
This commit is contained in:
Jeremy Reizenstein 2022-05-12 16:57:16 -07:00 committed by Facebook GitHub Bot
parent e767c4b548
commit 9e57b994ca

View File

@ -84,10 +84,9 @@ def get_skateboard_data(
yield CO3D_MANIFOLD_PATH, get_path_manager()
def provide_lpips_vgg():
def _provide_torchvision_weights(par_path: str, filename: str) -> None:
"""
Ensure the weights files are available for lpips.LPIPS(net="vgg")
to be called. Specifically, torchvision's vgg16
Ensure the weights files are available for a torchvision model.
"""
# In OSS, torchvision looks for vgg16 weights in
# https://download.pytorch.org/models/vgg16-397923af.pth
@ -126,11 +125,31 @@ def provide_lpips_vgg():
os.environ["FVCORE_CACHE"] = "iopath_cache"
par_path = "vgg_weights_for_lpips"
source = Path(get_file_path(par_path))
assert source.is_file()
dest = Path("iopath_cache/manifold_cache/tree/models")
if not dest.exists():
dest.mkdir(parents=True)
(dest / "vgg16-397923af.pth").symlink_to(source)
if not (dest / filename).is_symlink():
(dest / filename).symlink_to(source)
def provide_lpips_vgg() -> None:
"""
Ensure the weights files are available for lpips.LPIPS(net="vgg")
to be called. Specifically, torchvision's vgg16.
"""
_provide_torchvision_weights("vgg_weights_for_lpips", "vgg16-397923af.pth")
def provide_resnet34() -> None:
"""
Ensure the weights files are available for
torchvision.models.resnet34(pretrained=True)
to be called.
"""
_provide_torchvision_weights("resnet34_weights", "resnet34-b627a593.pth")