mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-14 19:36:23 +08:00
remove get_task
Summary: Remove the dataset's need to provide the task type. Reviewed By: davnov134, kjchalup Differential Revision: D38314000 fbshipit-source-id: 3805d885b5d4528abdc78c0da03247edb9abf3f7
This commit is contained in:
committed by
Facebook GitHub Bot
parent
37250a4326
commit
f8bf528043
@@ -35,3 +35,4 @@ training_loop_ImplicitronTrainingLoop_args:
|
||||
camera_difficulty_bin_breaks:
|
||||
- 0.666667
|
||||
- 0.833334
|
||||
is_multisequence: true
|
||||
|
||||
@@ -206,7 +206,6 @@ class Experiment(Configurable): # pyre-ignore: 13
|
||||
val_loader,
|
||||
) = accelerator.prepare(model, optimizer, train_loader, val_loader)
|
||||
|
||||
task = self.data_source.get_task()
|
||||
all_train_cameras = self.data_source.all_train_cameras
|
||||
|
||||
# Enter the main training loop.
|
||||
@@ -223,7 +222,6 @@ class Experiment(Configurable): # pyre-ignore: 13
|
||||
exp_dir=self.exp_dir,
|
||||
stats=stats,
|
||||
seed=self.seed,
|
||||
task=task,
|
||||
)
|
||||
|
||||
def _check_config_consistent(self) -> None:
|
||||
|
||||
@@ -10,7 +10,6 @@ from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from pytorch3d.implicitron.dataset.data_source import Task
|
||||
from pytorch3d.implicitron.evaluation.evaluator import EvaluatorBase
|
||||
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
|
||||
from pytorch3d.implicitron.models.generic_model import EvaluationMode
|
||||
@@ -101,7 +100,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
|
||||
exp_dir: str,
|
||||
stats: Stats,
|
||||
seed: int,
|
||||
task: Task,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -123,7 +121,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
|
||||
epoch=stats.epoch,
|
||||
exp_dir=exp_dir,
|
||||
model=model,
|
||||
task=task,
|
||||
)
|
||||
return
|
||||
else:
|
||||
@@ -179,7 +176,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
|
||||
device=device,
|
||||
dataloader=test_loader,
|
||||
model=model,
|
||||
task=task,
|
||||
)
|
||||
|
||||
assert stats.epoch == epoch, "inconsistent stats!"
|
||||
@@ -200,7 +196,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
|
||||
exp_dir=exp_dir,
|
||||
dataloader=test_loader,
|
||||
model=model,
|
||||
task=task,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
|
||||
@@ -435,3 +435,4 @@ training_loop_ImplicitronTrainingLoop_args:
|
||||
camera_difficulty_bin_breaks:
|
||||
- 0.97
|
||||
- 0.98
|
||||
is_multisequence: false
|
||||
|
||||
Reference in New Issue
Block a user