diff --git a/tests/implicitron/common_resources.py b/tests/implicitron/common_resources.py index 0ddad5cd..41a83575 100644 --- a/tests/implicitron/common_resources.py +++ b/tests/implicitron/common_resources.py @@ -84,10 +84,9 @@ def get_skateboard_data( yield CO3D_MANIFOLD_PATH, get_path_manager() -def provide_lpips_vgg(): +def _provide_torchvision_weights(par_path: str, filename: str) -> None: """ - Ensure the weights files are available for lpips.LPIPS(net="vgg") - to be called. Specifically, torchvision's vgg16 + Ensure the weights files are available for a torchvision model. """ # In OSS, torchvision looks for vgg16 weights in # https://download.pytorch.org/models/vgg16-397923af.pth @@ -126,11 +125,31 @@ def provide_lpips_vgg(): os.environ["FVCORE_CACHE"] = "iopath_cache" - par_path = "vgg_weights_for_lpips" source = Path(get_file_path(par_path)) assert source.is_file() dest = Path("iopath_cache/manifold_cache/tree/models") if not dest.exists(): dest.mkdir(parents=True) - (dest / "vgg16-397923af.pth").symlink_to(source) + + if not (dest / filename).is_symlink(): + (dest / filename).symlink_to(source) + + +def provide_lpips_vgg() -> None: + """ + Ensure the weights files are available for lpips.LPIPS(net="vgg") + to be called. Specifically, torchvision's vgg16. + """ + _provide_torchvision_weights("vgg_weights_for_lpips", "vgg16-397923af.pth") + + +def provide_resnet34() -> None: + """ + Ensure the weights files are available for + + torchvision.models.resnet34(pretrained=True) + + to be called. + """ + _provide_torchvision_weights("resnet34_weights", "resnet34-b627a593.pth")