[callback] add torch profiler callback (#10463)

This commit is contained in:
浮梦
2026-05-20 20:47:52 +08:00
committed by GitHub
parent 8b5ea65770
commit 368c48968f
4 changed files with 320 additions and 6 deletions

View File

@@ -47,7 +47,13 @@ logger = logging.get_logger(__name__)
check_dependencies()
_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_ARGS = [
ModelArguments,
DataArguments,
TrainingArguments,
FinetuningArguments,
GeneratingArguments,
]
_TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
@@ -57,9 +63,19 @@ _EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, Finetuning
if is_mcore_adapter_available() and is_env_enabled("USE_MCA"):
from mcore_adapter import TrainingArguments as McaTrainingArguments
_TRAIN_MCA_ARGS = [ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_MCA_ARGS = [
ModelArguments,
DataArguments,
McaTrainingArguments,
FinetuningArguments,
GeneratingArguments,
]
_TRAIN_MCA_CLS = tuple[
ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments
ModelArguments,
DataArguments,
McaTrainingArguments,
FinetuningArguments,
GeneratingArguments,
]
else:
_TRAIN_MCA_ARGS = []