mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-05-28 02:48:54 +08:00
[callback] add torch profiler callback (#10463)
This commit is contained in:
@@ -47,7 +47,13 @@ logger = logging.get_logger(__name__)
|
|||||||
check_dependencies()
|
check_dependencies()
|
||||||
|
|
||||||
|
|
||||||
_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
|
_TRAIN_ARGS = [
|
||||||
|
ModelArguments,
|
||||||
|
DataArguments,
|
||||||
|
TrainingArguments,
|
||||||
|
FinetuningArguments,
|
||||||
|
GeneratingArguments,
|
||||||
|
]
|
||||||
_TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
|
_TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||||
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||||
_INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
_INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||||
@@ -57,9 +63,19 @@ _EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, Finetuning
|
|||||||
if is_mcore_adapter_available() and is_env_enabled("USE_MCA"):
|
if is_mcore_adapter_available() and is_env_enabled("USE_MCA"):
|
||||||
from mcore_adapter import TrainingArguments as McaTrainingArguments
|
from mcore_adapter import TrainingArguments as McaTrainingArguments
|
||||||
|
|
||||||
_TRAIN_MCA_ARGS = [ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
|
_TRAIN_MCA_ARGS = [
|
||||||
|
ModelArguments,
|
||||||
|
DataArguments,
|
||||||
|
McaTrainingArguments,
|
||||||
|
FinetuningArguments,
|
||||||
|
GeneratingArguments,
|
||||||
|
]
|
||||||
_TRAIN_MCA_CLS = tuple[
|
_TRAIN_MCA_CLS = tuple[
|
||||||
ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments
|
ModelArguments,
|
||||||
|
DataArguments,
|
||||||
|
McaTrainingArguments,
|
||||||
|
FinetuningArguments,
|
||||||
|
GeneratingArguments,
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
_TRAIN_MCA_ARGS = []
|
_TRAIN_MCA_ARGS = []
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from transformers import Seq2SeqTrainingArguments
|
from transformers import Seq2SeqTrainingArguments
|
||||||
from transformers.training_args import _convert_str_dict
|
from transformers.training_args import _convert_str_dict
|
||||||
@@ -63,6 +64,58 @@ class RayArguments:
|
|||||||
self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs))
|
self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProfilerArguments:
|
||||||
|
r"""Arguments for torch profiler configuration."""
|
||||||
|
|
||||||
|
enable_torch_profiler: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether to enable torch profiler for collecting performance traces."},
|
||||||
|
)
|
||||||
|
profiler_output_dir: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Directory to write profiler traces. Defaults to <output_dir>/profiler if not set."},
|
||||||
|
)
|
||||||
|
profiler_wait_steps: int = field(
|
||||||
|
default=1,
|
||||||
|
metadata={"help": "Number of steps to skip at the start of each profiling cycle."},
|
||||||
|
)
|
||||||
|
profiler_warmup_steps: int = field(
|
||||||
|
default=1,
|
||||||
|
metadata={"help": "Number of profiler warm-up steps per cycle."},
|
||||||
|
)
|
||||||
|
profiler_active_steps: int = field(
|
||||||
|
default=1,
|
||||||
|
metadata={"help": "Number of steps to actively record per cycle."},
|
||||||
|
)
|
||||||
|
profiler_repeat: int = field(
|
||||||
|
default=1,
|
||||||
|
metadata={"help": "Number of profiling cycles. Set to 0 for continuous profiling."},
|
||||||
|
)
|
||||||
|
profiler_record_shapes: bool = field(
|
||||||
|
default=True,
|
||||||
|
metadata={"help": "Whether to record tensor shapes during profiling."},
|
||||||
|
)
|
||||||
|
profiler_profile_memory: bool = field(
|
||||||
|
default=True,
|
||||||
|
metadata={"help": "Whether to profile memory usage."},
|
||||||
|
)
|
||||||
|
profiler_with_stack: bool = field(
|
||||||
|
default=True,
|
||||||
|
metadata={"help": "Whether to record stack traces during profiling."},
|
||||||
|
)
|
||||||
|
profile_modules: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": (
|
||||||
|
"Comma-separated list of module name patterns to profile with CUDA events. "
|
||||||
|
"Supports fnmatch wildcards (e.g. 'model.layers.0.self_attn,model.layers.*.mlp'). "
|
||||||
|
"Reports per-module forward/backward timing statistics at each logging step."
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Fp8Arguments:
|
class Fp8Arguments:
|
||||||
r"""Arguments pertaining to the FP8 training."""
|
r"""Arguments pertaining to the FP8 training."""
|
||||||
@@ -87,7 +140,7 @@ class Fp8Arguments:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainingArguments(Fp8Arguments, RayArguments, BaseTrainingArguments):
|
class TrainingArguments(ProfilerArguments, Fp8Arguments, RayArguments, BaseTrainingArguments):
|
||||||
r"""Arguments pertaining to the trainer."""
|
r"""Arguments pertaining to the trainer."""
|
||||||
|
|
||||||
overwrite_output_dir: bool = field(
|
overwrite_output_dir: bool = field(
|
||||||
|
|||||||
@@ -12,12 +12,15 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import fnmatch
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
@@ -31,7 +34,7 @@ from typing_extensions import override
|
|||||||
|
|
||||||
from ..extras import logging
|
from ..extras import logging
|
||||||
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||||
from ..extras.misc import get_peak_memory, is_env_enabled, use_ray
|
from ..extras.misc import get_peak_memory, is_env_enabled, is_torch_cuda_available, is_torch_npu_available, use_ray
|
||||||
from ..extras.packages import is_safetensors_available
|
from ..extras.packages import is_safetensors_available
|
||||||
|
|
||||||
|
|
||||||
@@ -338,6 +341,96 @@ class LogCallback(TrainerCallback):
|
|||||||
self.thread_pool.submit(self._write_log, args.output_dir, logs)
|
self.thread_pool.submit(self._write_log, args.output_dir, logs)
|
||||||
|
|
||||||
|
|
||||||
|
class TorchProfilerCallback(TrainerCallback):
|
||||||
|
r"""A callback for collecting torch.profiler traces during training.
|
||||||
|
|
||||||
|
Activated by setting ``enable_torch_profiler: true`` in the YAML config.
|
||||||
|
|
||||||
|
Configuration fields (in YAML):
|
||||||
|
profiler_output_dir – where to write traces (default: <output_dir>/profiler)
|
||||||
|
profiler_wait_steps – steps to skip at start of each cycle (default: 1)
|
||||||
|
profiler_warmup_steps – profiler warm-up steps per cycle (default: 1)
|
||||||
|
profiler_active_steps – steps to record per cycle (default: 1)
|
||||||
|
profiler_repeat – number of cycles; 0 = forever (default: 1)
|
||||||
|
profiler_record_shapes – record tensor shapes (default: true)
|
||||||
|
profiler_profile_memory – profile memory usage (default: true)
|
||||||
|
profiler_with_stack – record stack traces (default: true)
|
||||||
|
|
||||||
|
Trace files (one per rank, Chrome / TensorBoard JSON format) are written to
|
||||||
|
``<profiler_output_dir>/rank_<N>/``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, training_args: "TrainingArguments") -> None:
|
||||||
|
self.profiler = None
|
||||||
|
self.profiler_args = training_args
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_rank() -> int:
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
if dist.is_available() and dist.is_initialized():
|
||||||
|
return dist.get_rank()
|
||||||
|
return 0
|
||||||
|
|
||||||
|
@override
|
||||||
|
def on_train_begin(
|
||||||
|
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
|
||||||
|
) -> None:
|
||||||
|
if self.profiler is not None:
|
||||||
|
self.profiler.stop()
|
||||||
|
self.profiler = None
|
||||||
|
|
||||||
|
pa = self.profiler_args
|
||||||
|
output_dir = pa.profiler_output_dir or os.path.join(args.output_dir, "profiler")
|
||||||
|
rank = self._get_rank()
|
||||||
|
trace_dir = os.path.join(output_dir, f"rank_{rank}")
|
||||||
|
os.makedirs(trace_dir, exist_ok=True)
|
||||||
|
|
||||||
|
activities = [torch.profiler.ProfilerActivity.CPU]
|
||||||
|
try:
|
||||||
|
if is_torch_cuda_available():
|
||||||
|
activities.append(torch.profiler.ProfilerActivity.CUDA)
|
||||||
|
if is_torch_npu_available():
|
||||||
|
activities.append(torch.profiler.ProfilerActivity.NPU)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.profiler = torch.profiler.profile(
|
||||||
|
activities=activities,
|
||||||
|
schedule=torch.profiler.schedule(
|
||||||
|
wait=pa.profiler_wait_steps,
|
||||||
|
warmup=pa.profiler_warmup_steps,
|
||||||
|
active=pa.profiler_active_steps,
|
||||||
|
repeat=pa.profiler_repeat,
|
||||||
|
),
|
||||||
|
on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
|
||||||
|
record_shapes=pa.profiler_record_shapes,
|
||||||
|
profile_memory=pa.profiler_profile_memory,
|
||||||
|
with_stack=pa.profiler_with_stack,
|
||||||
|
)
|
||||||
|
self.profiler.start()
|
||||||
|
logger.info_rank0(
|
||||||
|
f"TorchProfiler started — schedule: wait={pa.profiler_wait_steps}, warmup={pa.profiler_warmup_steps}, "
|
||||||
|
f"active={pa.profiler_active_steps}, repeat={pa.profiler_repeat}. Traces → {output_dir}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def on_step_end(
|
||||||
|
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
|
||||||
|
) -> None:
|
||||||
|
if self.profiler is not None:
|
||||||
|
self.profiler.step()
|
||||||
|
|
||||||
|
@override
|
||||||
|
def on_train_end(
|
||||||
|
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
|
||||||
|
) -> None:
|
||||||
|
if self.profiler is not None:
|
||||||
|
self.profiler.stop()
|
||||||
|
self.profiler = None
|
||||||
|
logger.info_rank0("TorchProfiler stopped.")
|
||||||
|
|
||||||
|
|
||||||
class ReporterCallback(TrainerCallback):
|
class ReporterCallback(TrainerCallback):
|
||||||
r"""A callback for reporting training status to external logger."""
|
r"""A callback for reporting training status to external logger."""
|
||||||
|
|
||||||
@@ -394,3 +487,143 @@ class ReporterCallback(TrainerCallback):
|
|||||||
"generating_args": self.generating_args.to_dict(),
|
"generating_args": self.generating_args.to_dict(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleProfilerCallback(TrainerCallback):
|
||||||
|
r"""Profile forward/backward time of specified modules using accelerator events.
|
||||||
|
|
||||||
|
Hooks are registered on modules matching the user-provided name patterns.
|
||||||
|
Timing statistics are logged at each trainer logging step.
|
||||||
|
|
||||||
|
Usage in YAML config:
|
||||||
|
profile_modules: "*.layers.0.self_attn,*.layers.0.mlp"
|
||||||
|
|
||||||
|
Supports fnmatch wildcards:
|
||||||
|
profile_modules: "*.layers.*.self_attn,*.layers.*.mlp.experts"
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_accelerator():
|
||||||
|
"""Detect available accelerator and return (event_factory, synchronize_fn)."""
|
||||||
|
if is_torch_cuda_available():
|
||||||
|
return torch.cuda.Event, torch.cuda.synchronize
|
||||||
|
if is_torch_npu_available():
|
||||||
|
return torch.npu.Event, torch.npu.synchronize
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
def __init__(self, profile_modules: str) -> None:
|
||||||
|
self.patterns = [p.strip() for p in profile_modules.split(",") if p.strip()]
|
||||||
|
self._create_event, self._synchronize = self._get_accelerator()
|
||||||
|
self._handles: list[Any] = []
|
||||||
|
self._forward_times: dict[str, list[float]] = defaultdict(list)
|
||||||
|
self._backward_times: dict[str, list[float]] = defaultdict(list)
|
||||||
|
self._pending_forward: dict[str, tuple] = {}
|
||||||
|
self._pending_backward: dict[str, tuple] = {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def enabled(self) -> bool:
|
||||||
|
return self._create_event is not None
|
||||||
|
|
||||||
|
def _match(self, name: str) -> bool:
|
||||||
|
return any(fnmatch.fnmatch(name, pat) for pat in self.patterns)
|
||||||
|
|
||||||
|
def _make_forward_pre_hook(self, name: str):
|
||||||
|
def hook(module, input):
|
||||||
|
start = self._create_event(enable_timing=True)
|
||||||
|
end = self._create_event(enable_timing=True)
|
||||||
|
start.record()
|
||||||
|
self._pending_forward[name] = (start, end)
|
||||||
|
|
||||||
|
return hook
|
||||||
|
|
||||||
|
def _make_forward_hook(self, name: str):
|
||||||
|
def hook(module, input, output):
|
||||||
|
pair = self._pending_forward.get(name)
|
||||||
|
if pair is not None:
|
||||||
|
pair[1].record()
|
||||||
|
|
||||||
|
return hook
|
||||||
|
|
||||||
|
def _make_backward_pre_hook(self, name: str):
|
||||||
|
def hook(module, grad_output):
|
||||||
|
start = self._create_event(enable_timing=True)
|
||||||
|
end = self._create_event(enable_timing=True)
|
||||||
|
start.record()
|
||||||
|
self._pending_backward[name] = (start, end)
|
||||||
|
|
||||||
|
return hook
|
||||||
|
|
||||||
|
def _make_backward_hook(self, name: str):
|
||||||
|
def hook(module, grad_input, grad_output):
|
||||||
|
pair = self._pending_backward.get(name)
|
||||||
|
if pair is not None:
|
||||||
|
pair[1].record()
|
||||||
|
|
||||||
|
return hook
|
||||||
|
|
||||||
|
@override
|
||||||
|
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||||
|
if not self.enabled:
|
||||||
|
logger.warning_rank0("ModuleProfiler: no supported accelerator (CUDA/NPU) found, profiling disabled.")
|
||||||
|
return
|
||||||
|
|
||||||
|
model = kwargs.get("model")
|
||||||
|
if model is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
matched = []
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if not name or not self._match(name):
|
||||||
|
continue
|
||||||
|
self._handles.append(module.register_forward_pre_hook(self._make_forward_pre_hook(name)))
|
||||||
|
self._handles.append(module.register_forward_hook(self._make_forward_hook(name)))
|
||||||
|
self._handles.append(module.register_full_backward_pre_hook(self._make_backward_pre_hook(name)))
|
||||||
|
self._handles.append(module.register_full_backward_hook(self._make_backward_hook(name)))
|
||||||
|
matched.append(name)
|
||||||
|
|
||||||
|
if matched:
|
||||||
|
logger.info_rank0(
|
||||||
|
f"ModuleProfiler: registered hooks on {len(matched)} modules: {matched[:5]}"
|
||||||
|
+ (f" ... (+{len(matched)-5} more)" if len(matched) > 5 else "")
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning_rank0(f"ModuleProfiler: no modules matched patterns {self.patterns}")
|
||||||
|
|
||||||
|
@override
|
||||||
|
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||||
|
if not self.enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._synchronize()
|
||||||
|
|
||||||
|
for name, (start, end) in self._pending_forward.items():
|
||||||
|
self._forward_times[name].append(start.elapsed_time(end))
|
||||||
|
self._pending_forward.clear()
|
||||||
|
|
||||||
|
for name, (start, end) in self._pending_backward.items():
|
||||||
|
self._backward_times[name].append(start.elapsed_time(end))
|
||||||
|
self._pending_backward.clear()
|
||||||
|
|
||||||
|
@override
|
||||||
|
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||||
|
if not self._forward_times and not self._backward_times:
|
||||||
|
return
|
||||||
|
|
||||||
|
lines = ["[ModuleProfiler] Timing (ms):"]
|
||||||
|
all_names = sorted(set(list(self._forward_times.keys()) + list(self._backward_times.keys())))
|
||||||
|
for name in all_names:
|
||||||
|
fwd = self._forward_times.get(name, [])
|
||||||
|
bwd = self._backward_times.get(name, [])
|
||||||
|
fwd_mean = sum(fwd) / len(fwd) if fwd else 0.0
|
||||||
|
bwd_mean = sum(bwd) / len(bwd) if bwd else 0.0
|
||||||
|
lines.append(f" {name}: fwd={fwd_mean:.3f}, bwd={bwd_mean:.3f}, total={fwd_mean+bwd_mean:.3f}")
|
||||||
|
|
||||||
|
logger.info_rank0("\n".join(lines))
|
||||||
|
self._forward_times.clear()
|
||||||
|
self._backward_times.clear()
|
||||||
|
|
||||||
|
@override
|
||||||
|
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||||
|
for handle in self._handles:
|
||||||
|
handle.remove()
|
||||||
|
self._handles.clear()
|
||||||
|
|||||||
@@ -32,7 +32,13 @@ from ..extras.packages import (
|
|||||||
)
|
)
|
||||||
from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
|
from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
|
||||||
from ..model import load_model, load_tokenizer
|
from ..model import load_model, load_tokenizer
|
||||||
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
|
from .callbacks import (
|
||||||
|
LogCallback,
|
||||||
|
ModuleProfilerCallback,
|
||||||
|
PissaConvertCallback,
|
||||||
|
ReporterCallback,
|
||||||
|
TorchProfilerCallback,
|
||||||
|
)
|
||||||
from .dpo import run_dpo
|
from .dpo import run_dpo
|
||||||
from .kto import run_kto
|
from .kto import run_kto
|
||||||
from .ppo import run_ppo
|
from .ppo import run_ppo
|
||||||
@@ -74,6 +80,12 @@ def _training_function(config: dict[str, Any]) -> None:
|
|||||||
if finetuning_args.early_stopping_steps is not None:
|
if finetuning_args.early_stopping_steps is not None:
|
||||||
callbacks.append(EarlyStoppingCallback(early_stopping_patience=finetuning_args.early_stopping_steps))
|
callbacks.append(EarlyStoppingCallback(early_stopping_patience=finetuning_args.early_stopping_steps))
|
||||||
|
|
||||||
|
if training_args.enable_torch_profiler:
|
||||||
|
callbacks.append(TorchProfilerCallback(training_args))
|
||||||
|
|
||||||
|
if training_args.profile_modules:
|
||||||
|
callbacks.append(ModuleProfilerCallback(training_args.profile_modules))
|
||||||
|
|
||||||
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
|
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
|
||||||
|
|
||||||
if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel:
|
if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel:
|
||||||
|
|||||||
Reference in New Issue
Block a user