mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-05-28 10:58:54 +08:00
[callback] add torch profiler callback (#10463)
This commit is contained in:
@@ -47,7 +47,13 @@ logger = logging.get_logger(__name__)
|
||||
check_dependencies()
|
||||
|
||||
|
||||
_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_TRAIN_ARGS = [
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
TrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
]
|
||||
_TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||
_INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||
@@ -57,9 +63,19 @@ _EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, Finetuning
|
||||
if is_mcore_adapter_available() and is_env_enabled("USE_MCA"):
|
||||
from mcore_adapter import TrainingArguments as McaTrainingArguments
|
||||
|
||||
_TRAIN_MCA_ARGS = [ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_TRAIN_MCA_ARGS = [
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
McaTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
]
|
||||
_TRAIN_MCA_CLS = tuple[
|
||||
ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
McaTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
]
|
||||
else:
|
||||
_TRAIN_MCA_ARGS = []
|
||||
|
||||
Reference in New Issue
Block a user