mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-05-28 02:48:54 +08:00
[callback] add torch profiler callback (#10463)
This commit is contained in:
@@ -32,7 +32,13 @@ from ..extras.packages import (
|
||||
)
|
||||
from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
|
||||
from ..model import load_model, load_tokenizer
|
||||
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
|
||||
from .callbacks import (
|
||||
LogCallback,
|
||||
ModuleProfilerCallback,
|
||||
PissaConvertCallback,
|
||||
ReporterCallback,
|
||||
TorchProfilerCallback,
|
||||
)
|
||||
from .dpo import run_dpo
|
||||
from .kto import run_kto
|
||||
from .ppo import run_ppo
|
||||
@@ -74,6 +80,12 @@ def _training_function(config: dict[str, Any]) -> None:
|
||||
if finetuning_args.early_stopping_steps is not None:
|
||||
callbacks.append(EarlyStoppingCallback(early_stopping_patience=finetuning_args.early_stopping_steps))
|
||||
|
||||
if training_args.enable_torch_profiler:
|
||||
callbacks.append(TorchProfilerCallback(training_args))
|
||||
|
||||
if training_args.profile_modules:
|
||||
callbacks.append(ModuleProfilerCallback(training_args.profile_modules))
|
||||
|
||||
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
|
||||
|
||||
if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel:
|
||||
|
||||
Reference in New Issue
Block a user