mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-05-28 10:58:54 +08:00
[callback] add torch profiler callback (#10463)
This commit is contained in:
@@ -47,7 +47,13 @@ logger = logging.get_logger(__name__)
|
||||
check_dependencies()
|
||||
|
||||
|
||||
_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_TRAIN_ARGS = [
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
TrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
]
|
||||
_TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||
_INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||
@@ -57,9 +63,19 @@ _EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, Finetuning
|
||||
if is_mcore_adapter_available() and is_env_enabled("USE_MCA"):
|
||||
from mcore_adapter import TrainingArguments as McaTrainingArguments
|
||||
|
||||
_TRAIN_MCA_ARGS = [ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_TRAIN_MCA_ARGS = [
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
McaTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
]
|
||||
_TRAIN_MCA_CLS = tuple[
|
||||
ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
McaTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
]
|
||||
else:
|
||||
_TRAIN_MCA_ARGS = []
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers.training_args import _convert_str_dict
|
||||
@@ -63,6 +64,58 @@ class RayArguments:
|
||||
self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs))
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProfilerArguments:
|
||||
r"""Arguments for torch profiler configuration."""
|
||||
|
||||
enable_torch_profiler: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to enable torch profiler for collecting performance traces."},
|
||||
)
|
||||
profiler_output_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Directory to write profiler traces. Defaults to <output_dir>/profiler if not set."},
|
||||
)
|
||||
profiler_wait_steps: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of steps to skip at the start of each profiling cycle."},
|
||||
)
|
||||
profiler_warmup_steps: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of profiler warm-up steps per cycle."},
|
||||
)
|
||||
profiler_active_steps: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of steps to actively record per cycle."},
|
||||
)
|
||||
profiler_repeat: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of profiling cycles. Set to 0 for continuous profiling."},
|
||||
)
|
||||
profiler_record_shapes: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to record tensor shapes during profiling."},
|
||||
)
|
||||
profiler_profile_memory: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to profile memory usage."},
|
||||
)
|
||||
profiler_with_stack: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to record stack traces during profiling."},
|
||||
)
|
||||
profile_modules: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"Comma-separated list of module name patterns to profile with CUDA events. "
|
||||
"Supports fnmatch wildcards (e.g. 'model.layers.0.self_attn,model.layers.*.mlp'). "
|
||||
"Reports per-module forward/backward timing statistics at each logging step."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Fp8Arguments:
|
||||
r"""Arguments pertaining to the FP8 training."""
|
||||
@@ -87,7 +140,7 @@ class Fp8Arguments:
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(Fp8Arguments, RayArguments, BaseTrainingArguments):
|
||||
class TrainingArguments(ProfilerArguments, Fp8Arguments, RayArguments, BaseTrainingArguments):
|
||||
r"""Arguments pertaining to the trainer."""
|
||||
|
||||
overwrite_output_dir: bool = field(
|
||||
|
||||
Reference in New Issue
Block a user