chunked_inputs

Summary: Make the method by which the SDF renderer uses the object mask more general, so that a renderer can be given per-pixel values.
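In practice, this means the model can be given a per-pixel foreground-probability map instead of a single binary object mask. Below is a minimal sketch of that usage; it mirrors the new test added in this commit and uses the same configuration names (the test itself runs this on a GPU, and the exact values are illustrative):

import torch
from pytorch3d.implicitron.models.base import GenericModel
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
from pytorch3d.implicitron.tools.config import get_default_args
from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras

# Select the SDF renderer and the IDR feature field implicit function.
args = get_default_args(GenericModel)
args.renderer_class_type = "SignedDistanceFunctionRenderer"
args.implicit_function_class_type = "IdrFeatureField"
args.implicit_function_IdrFeatureField_args.n_harmonic_functions_xyz = 6
model = GenericModel(**args)

batch = 2
R, T = look_at_view_transform(azim=torch.rand(batch) * 360)
cameras = PerspectiveCameras(R=R, T=T)

# Per-pixel foreground probabilities in [0, 1] of shape (batch, 1, H, W),
# supplied in place of a thresholded binary object mask.
H, W = model.render_image_height, model.render_image_width
image_rgb = torch.rand(batch, 3, H, W)
fg_probability = torch.rand(batch, 1, H, W)

preds = model(
    camera=cameras,
    evaluation_mode=EvaluationMode.TRAINING,
    image_rgb=image_rgb,
    fg_probability=fg_probability,
    depth_map=None,
    mask_crop=None,
    sequence_name=None,
)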

Reviewed By: shapovalov

Differential Revision: D35247412

fbshipit-source-id: 6aeccb1d0b5f1265a3f692a1453407a07e51a33c
Authored by Jeremy Reizenstein on 2022-04-26 08:01:45 -07:00
Committed by Facebook GitHub Bot
parent 41c594ca37
commit 2edb93d184
2 changed files with 56 additions and 17 deletions


@@ -9,13 +9,13 @@ import unittest
 import torch
 from pytorch3d.implicitron.models.base import GenericModel
 from pytorch3d.implicitron.models.renderer.base import EvaluationMode
-from pytorch3d.implicitron.tools.config import expand_args_fields
+from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
 from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras


 class TestGenericModel(unittest.TestCase):
     def test_gm(self):
-        # Simple test of a forward pass of the default GenericModel.
+        # Simple test of a forward and backward pass of the default GenericModel.
         device = torch.device("cuda:1")
         expand_args_fields(GenericModel)
         model = GenericModel()
@@ -51,6 +51,7 @@ class TestGenericModel(unittest.TestCase):
             **defaulted_args,
         )
         self.assertGreater(train_preds["objective"].item(), 0)
+        train_preds["objective"].backward()

         model.eval()
         with torch.no_grad():
@@ -65,3 +66,41 @@ class TestGenericModel(unittest.TestCase):
             eval_preds["images_render"].shape,
             (1, 3, model.render_image_height, model.render_image_width),
         )
+
+    def test_idr(self):
+        # Forward pass of GenericModel with IDR.
+        device = torch.device("cuda:1")
+        args = get_default_args(GenericModel)
+        args.renderer_class_type = "SignedDistanceFunctionRenderer"
+        args.implicit_function_class_type = "IdrFeatureField"
+        args.implicit_function_IdrFeatureField_args.n_harmonic_functions_xyz = 6
+        model = GenericModel(**args)
+        model.to(device)
+
+        n_train_cameras = 2
+        R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
+        cameras = PerspectiveCameras(R=R, T=T, device=device)
+
+        defaulted_args = {
+            "depth_map": None,
+            "mask_crop": None,
+            "sequence_name": None,
+        }
+
+        target_image_rgb = torch.rand(
+            (n_train_cameras, 3, model.render_image_height, model.render_image_width),
+            device=device,
+        )
+        fg_probability = torch.rand(
+            (n_train_cameras, 1, model.render_image_height, model.render_image_width),
+            device=device,
+        )
+
+        train_preds = model(
+            camera=cameras,
+            evaluation_mode=EvaluationMode.TRAINING,
+            image_rgb=target_image_rgb,
+            fg_probability=fg_probability,
+            **defaulted_args,
+        )
+        self.assertGreater(train_preds["objective"].item(), 0)