remove get_task

Summary: Remove the dataset's need to provide the task type.

Reviewed By: davnov134, kjchalup

Differential Revision: D38314000

fbshipit-source-id: 3805d885b5d4528abdc78c0da03247edb9abf3f7
This commit is contained in:
Jeremy Reizenstein
2022-08-02 07:55:42 -07:00
committed by Facebook GitHub Bot
parent 37250a4326
commit f8bf528043
13 changed files with 36 additions and 83 deletions

View File

@@ -10,7 +10,6 @@ from typing import Any, Optional
import torch
from accelerate import Accelerator
from pytorch3d.implicitron.dataset.data_source import Task
from pytorch3d.implicitron.evaluation.evaluator import EvaluatorBase
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.models.generic_model import EvaluationMode
@@ -101,7 +100,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
exp_dir: str,
stats: Stats,
seed: int,
task: Task,
**kwargs,
):
"""
@@ -123,7 +121,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
epoch=stats.epoch,
exp_dir=exp_dir,
model=model,
task=task,
)
return
else:
@@ -179,7 +176,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
device=device,
dataloader=test_loader,
model=model,
task=task,
)
assert stats.epoch == epoch, "inconsistent stats!"
@@ -200,7 +196,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
exp_dir=exp_dir,
dataloader=test_loader,
model=model,
task=task,
)
else:
raise ValueError(