remove get_task

Summary: Remove the dataset's need to provide the task type.

Reviewed By: davnov134, kjchalup

Differential Revision: D38314000

fbshipit-source-id: 3805d885b5d4528abdc78c0da03247edb9abf3f7
This commit is contained in:
Jeremy Reizenstein
2022-08-02 07:55:42 -07:00
committed by Facebook GitHub Bot
parent 37250a4326
commit f8bf528043
13 changed files with 36 additions and 83 deletions

View File

@@ -35,3 +35,4 @@ training_loop_ImplicitronTrainingLoop_args:
camera_difficulty_bin_breaks:
- 0.666667
- 0.833334
is_multisequence: true

View File

@@ -206,7 +206,6 @@ class Experiment(Configurable): # pyre-ignore: 13
val_loader,
) = accelerator.prepare(model, optimizer, train_loader, val_loader)
task = self.data_source.get_task()
all_train_cameras = self.data_source.all_train_cameras
# Enter the main training loop.
@@ -223,7 +222,6 @@ class Experiment(Configurable): # pyre-ignore: 13
exp_dir=self.exp_dir,
stats=stats,
seed=self.seed,
task=task,
)
def _check_config_consistent(self) -> None:

View File

@@ -10,7 +10,6 @@ from typing import Any, Optional
import torch
from accelerate import Accelerator
from pytorch3d.implicitron.dataset.data_source import Task
from pytorch3d.implicitron.evaluation.evaluator import EvaluatorBase
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.models.generic_model import EvaluationMode
@@ -101,7 +100,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
exp_dir: str,
stats: Stats,
seed: int,
task: Task,
**kwargs,
):
"""
@@ -123,7 +121,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
epoch=stats.epoch,
exp_dir=exp_dir,
model=model,
task=task,
)
return
else:
@@ -179,7 +176,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
device=device,
dataloader=test_loader,
model=model,
task=task,
)
assert stats.epoch == epoch, "inconsistent stats!"
@@ -200,7 +196,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
exp_dir=exp_dir,
dataloader=test_loader,
model=model,
task=task,
)
else:
raise ValueError(

View File

@@ -435,3 +435,4 @@ training_loop_ImplicitronTrainingLoop_args:
camera_difficulty_bin_breaks:
- 0.97
- 0.98
is_multisequence: false