mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-06-17 12:48:55 +08:00
Compare commits
3 Commits
40e786d016
...
2322bf1cc2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2322bf1cc2 | ||
|
|
368c48968f | ||
|
|
8b5ea65770 |
@@ -47,7 +47,13 @@ logger = logging.get_logger(__name__)
|
|||||||
check_dependencies()
|
check_dependencies()
|
||||||
|
|
||||||
|
|
||||||
_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
|
_TRAIN_ARGS = [
|
||||||
|
ModelArguments,
|
||||||
|
DataArguments,
|
||||||
|
TrainingArguments,
|
||||||
|
FinetuningArguments,
|
||||||
|
GeneratingArguments,
|
||||||
|
]
|
||||||
_TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
|
_TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||||
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||||
_INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
_INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||||
@@ -57,9 +63,19 @@ _EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, Finetuning
|
|||||||
if is_mcore_adapter_available() and is_env_enabled("USE_MCA"):
|
if is_mcore_adapter_available() and is_env_enabled("USE_MCA"):
|
||||||
from mcore_adapter import TrainingArguments as McaTrainingArguments
|
from mcore_adapter import TrainingArguments as McaTrainingArguments
|
||||||
|
|
||||||
_TRAIN_MCA_ARGS = [ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
|
_TRAIN_MCA_ARGS = [
|
||||||
|
ModelArguments,
|
||||||
|
DataArguments,
|
||||||
|
McaTrainingArguments,
|
||||||
|
FinetuningArguments,
|
||||||
|
GeneratingArguments,
|
||||||
|
]
|
||||||
_TRAIN_MCA_CLS = tuple[
|
_TRAIN_MCA_CLS = tuple[
|
||||||
ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments
|
ModelArguments,
|
||||||
|
DataArguments,
|
||||||
|
McaTrainingArguments,
|
||||||
|
FinetuningArguments,
|
||||||
|
GeneratingArguments,
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
_TRAIN_MCA_ARGS = []
|
_TRAIN_MCA_ARGS = []
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from transformers import Seq2SeqTrainingArguments
|
from transformers import Seq2SeqTrainingArguments
|
||||||
from transformers.training_args import _convert_str_dict
|
from transformers.training_args import _convert_str_dict
|
||||||
@@ -63,6 +64,58 @@ class RayArguments:
|
|||||||
self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs))
|
self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProfilerArguments:
|
||||||
|
r"""Arguments for torch profiler configuration."""
|
||||||
|
|
||||||
|
enable_torch_profiler: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether to enable torch profiler for collecting performance traces."},
|
||||||
|
)
|
||||||
|
profiler_output_dir: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Directory to write profiler traces. Defaults to <output_dir>/profiler if not set."},
|
||||||
|
)
|
||||||
|
profiler_wait_steps: int = field(
|
||||||
|
default=1,
|
||||||
|
metadata={"help": "Number of steps to skip at the start of each profiling cycle."},
|
||||||
|
)
|
||||||
|
profiler_warmup_steps: int = field(
|
||||||
|
default=1,
|
||||||
|
metadata={"help": "Number of profiler warm-up steps per cycle."},
|
||||||
|
)
|
||||||
|
profiler_active_steps: int = field(
|
||||||
|
default=1,
|
||||||
|
metadata={"help": "Number of steps to actively record per cycle."},
|
||||||
|
)
|
||||||
|
profiler_repeat: int = field(
|
||||||
|
default=1,
|
||||||
|
metadata={"help": "Number of profiling cycles. Set to 0 for continuous profiling."},
|
||||||
|
)
|
||||||
|
profiler_record_shapes: bool = field(
|
||||||
|
default=True,
|
||||||
|
metadata={"help": "Whether to record tensor shapes during profiling."},
|
||||||
|
)
|
||||||
|
profiler_profile_memory: bool = field(
|
||||||
|
default=True,
|
||||||
|
metadata={"help": "Whether to profile memory usage."},
|
||||||
|
)
|
||||||
|
profiler_with_stack: bool = field(
|
||||||
|
default=True,
|
||||||
|
metadata={"help": "Whether to record stack traces during profiling."},
|
||||||
|
)
|
||||||
|
profile_modules: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": (
|
||||||
|
"Comma-separated list of module name patterns to profile with CUDA events. "
|
||||||
|
"Supports fnmatch wildcards (e.g. 'model.layers.0.self_attn,model.layers.*.mlp'). "
|
||||||
|
"Reports per-module forward/backward timing statistics at each logging step."
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Fp8Arguments:
|
class Fp8Arguments:
|
||||||
r"""Arguments pertaining to the FP8 training."""
|
r"""Arguments pertaining to the FP8 training."""
|
||||||
@@ -87,7 +140,7 @@ class Fp8Arguments:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainingArguments(Fp8Arguments, RayArguments, BaseTrainingArguments):
|
class TrainingArguments(ProfilerArguments, Fp8Arguments, RayArguments, BaseTrainingArguments):
|
||||||
r"""Arguments pertaining to the trainer."""
|
r"""Arguments pertaining to the trainer."""
|
||||||
|
|
||||||
overwrite_output_dir: bool = field(
|
overwrite_output_dir: bool = field(
|
||||||
|
|||||||
@@ -12,12 +12,15 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import fnmatch
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
@@ -31,7 +34,7 @@ from typing_extensions import override
|
|||||||
|
|
||||||
from ..extras import logging
|
from ..extras import logging
|
||||||
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||||
from ..extras.misc import get_peak_memory, is_env_enabled, use_ray
|
from ..extras.misc import get_peak_memory, is_env_enabled, is_torch_cuda_available, is_torch_npu_available, use_ray
|
||||||
from ..extras.packages import is_safetensors_available
|
from ..extras.packages import is_safetensors_available
|
||||||
|
|
||||||
|
|
||||||
@@ -338,6 +341,96 @@ class LogCallback(TrainerCallback):
|
|||||||
self.thread_pool.submit(self._write_log, args.output_dir, logs)
|
self.thread_pool.submit(self._write_log, args.output_dir, logs)
|
||||||
|
|
||||||
|
|
||||||
|
class TorchProfilerCallback(TrainerCallback):
|
||||||
|
r"""A callback for collecting torch.profiler traces during training.
|
||||||
|
|
||||||
|
Activated by setting ``enable_torch_profiler: true`` in the YAML config.
|
||||||
|
|
||||||
|
Configuration fields (in YAML):
|
||||||
|
profiler_output_dir – where to write traces (default: <output_dir>/profiler)
|
||||||
|
profiler_wait_steps – steps to skip at start of each cycle (default: 1)
|
||||||
|
profiler_warmup_steps – profiler warm-up steps per cycle (default: 1)
|
||||||
|
profiler_active_steps – steps to record per cycle (default: 1)
|
||||||
|
profiler_repeat – number of cycles; 0 = forever (default: 1)
|
||||||
|
profiler_record_shapes – record tensor shapes (default: true)
|
||||||
|
profiler_profile_memory – profile memory usage (default: true)
|
||||||
|
profiler_with_stack – record stack traces (default: true)
|
||||||
|
|
||||||
|
Trace files (one per rank, Chrome / TensorBoard JSON format) are written to
|
||||||
|
``<profiler_output_dir>/rank_<N>/``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, training_args: "TrainingArguments") -> None:
|
||||||
|
self.profiler = None
|
||||||
|
self.profiler_args = training_args
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_rank() -> int:
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
if dist.is_available() and dist.is_initialized():
|
||||||
|
return dist.get_rank()
|
||||||
|
return 0
|
||||||
|
|
||||||
|
@override
|
||||||
|
def on_train_begin(
|
||||||
|
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
|
||||||
|
) -> None:
|
||||||
|
if self.profiler is not None:
|
||||||
|
self.profiler.stop()
|
||||||
|
self.profiler = None
|
||||||
|
|
||||||
|
pa = self.profiler_args
|
||||||
|
output_dir = pa.profiler_output_dir or os.path.join(args.output_dir, "profiler")
|
||||||
|
rank = self._get_rank()
|
||||||
|
trace_dir = os.path.join(output_dir, f"rank_{rank}")
|
||||||
|
os.makedirs(trace_dir, exist_ok=True)
|
||||||
|
|
||||||
|
activities = [torch.profiler.ProfilerActivity.CPU]
|
||||||
|
try:
|
||||||
|
if is_torch_cuda_available():
|
||||||
|
activities.append(torch.profiler.ProfilerActivity.CUDA)
|
||||||
|
if is_torch_npu_available():
|
||||||
|
activities.append(torch.profiler.ProfilerActivity.NPU)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.profiler = torch.profiler.profile(
|
||||||
|
activities=activities,
|
||||||
|
schedule=torch.profiler.schedule(
|
||||||
|
wait=pa.profiler_wait_steps,
|
||||||
|
warmup=pa.profiler_warmup_steps,
|
||||||
|
active=pa.profiler_active_steps,
|
||||||
|
repeat=pa.profiler_repeat,
|
||||||
|
),
|
||||||
|
on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
|
||||||
|
record_shapes=pa.profiler_record_shapes,
|
||||||
|
profile_memory=pa.profiler_profile_memory,
|
||||||
|
with_stack=pa.profiler_with_stack,
|
||||||
|
)
|
||||||
|
self.profiler.start()
|
||||||
|
logger.info_rank0(
|
||||||
|
f"TorchProfiler started — schedule: wait={pa.profiler_wait_steps}, warmup={pa.profiler_warmup_steps}, "
|
||||||
|
f"active={pa.profiler_active_steps}, repeat={pa.profiler_repeat}. Traces → {output_dir}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def on_step_end(
|
||||||
|
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
|
||||||
|
) -> None:
|
||||||
|
if self.profiler is not None:
|
||||||
|
self.profiler.step()
|
||||||
|
|
||||||
|
@override
|
||||||
|
def on_train_end(
|
||||||
|
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
|
||||||
|
) -> None:
|
||||||
|
if self.profiler is not None:
|
||||||
|
self.profiler.stop()
|
||||||
|
self.profiler = None
|
||||||
|
logger.info_rank0("TorchProfiler stopped.")
|
||||||
|
|
||||||
|
|
||||||
class ReporterCallback(TrainerCallback):
|
class ReporterCallback(TrainerCallback):
|
||||||
r"""A callback for reporting training status to external logger."""
|
r"""A callback for reporting training status to external logger."""
|
||||||
|
|
||||||
@@ -394,3 +487,143 @@ class ReporterCallback(TrainerCallback):
|
|||||||
"generating_args": self.generating_args.to_dict(),
|
"generating_args": self.generating_args.to_dict(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleProfilerCallback(TrainerCallback):
|
||||||
|
r"""Profile forward/backward time of specified modules using accelerator events.
|
||||||
|
|
||||||
|
Hooks are registered on modules matching the user-provided name patterns.
|
||||||
|
Timing statistics are logged at each trainer logging step.
|
||||||
|
|
||||||
|
Usage in YAML config:
|
||||||
|
profile_modules: "*.layers.0.self_attn,*.layers.0.mlp"
|
||||||
|
|
||||||
|
Supports fnmatch wildcards:
|
||||||
|
profile_modules: "*.layers.*.self_attn,*.layers.*.mlp.experts"
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_accelerator():
|
||||||
|
"""Detect available accelerator and return (event_factory, synchronize_fn)."""
|
||||||
|
if is_torch_cuda_available():
|
||||||
|
return torch.cuda.Event, torch.cuda.synchronize
|
||||||
|
if is_torch_npu_available():
|
||||||
|
return torch.npu.Event, torch.npu.synchronize
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
def __init__(self, profile_modules: str) -> None:
|
||||||
|
self.patterns = [p.strip() for p in profile_modules.split(",") if p.strip()]
|
||||||
|
self._create_event, self._synchronize = self._get_accelerator()
|
||||||
|
self._handles: list[Any] = []
|
||||||
|
self._forward_times: dict[str, list[float]] = defaultdict(list)
|
||||||
|
self._backward_times: dict[str, list[float]] = defaultdict(list)
|
||||||
|
self._pending_forward: dict[str, tuple] = {}
|
||||||
|
self._pending_backward: dict[str, tuple] = {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def enabled(self) -> bool:
|
||||||
|
return self._create_event is not None
|
||||||
|
|
||||||
|
def _match(self, name: str) -> bool:
|
||||||
|
return any(fnmatch.fnmatch(name, pat) for pat in self.patterns)
|
||||||
|
|
||||||
|
def _make_forward_pre_hook(self, name: str):
|
||||||
|
def hook(module, input):
|
||||||
|
start = self._create_event(enable_timing=True)
|
||||||
|
end = self._create_event(enable_timing=True)
|
||||||
|
start.record()
|
||||||
|
self._pending_forward[name] = (start, end)
|
||||||
|
|
||||||
|
return hook
|
||||||
|
|
||||||
|
def _make_forward_hook(self, name: str):
|
||||||
|
def hook(module, input, output):
|
||||||
|
pair = self._pending_forward.get(name)
|
||||||
|
if pair is not None:
|
||||||
|
pair[1].record()
|
||||||
|
|
||||||
|
return hook
|
||||||
|
|
||||||
|
def _make_backward_pre_hook(self, name: str):
|
||||||
|
def hook(module, grad_output):
|
||||||
|
start = self._create_event(enable_timing=True)
|
||||||
|
end = self._create_event(enable_timing=True)
|
||||||
|
start.record()
|
||||||
|
self._pending_backward[name] = (start, end)
|
||||||
|
|
||||||
|
return hook
|
||||||
|
|
||||||
|
def _make_backward_hook(self, name: str):
|
||||||
|
def hook(module, grad_input, grad_output):
|
||||||
|
pair = self._pending_backward.get(name)
|
||||||
|
if pair is not None:
|
||||||
|
pair[1].record()
|
||||||
|
|
||||||
|
return hook
|
||||||
|
|
||||||
|
@override
|
||||||
|
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||||
|
if not self.enabled:
|
||||||
|
logger.warning_rank0("ModuleProfiler: no supported accelerator (CUDA/NPU) found, profiling disabled.")
|
||||||
|
return
|
||||||
|
|
||||||
|
model = kwargs.get("model")
|
||||||
|
if model is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
matched = []
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if not name or not self._match(name):
|
||||||
|
continue
|
||||||
|
self._handles.append(module.register_forward_pre_hook(self._make_forward_pre_hook(name)))
|
||||||
|
self._handles.append(module.register_forward_hook(self._make_forward_hook(name)))
|
||||||
|
self._handles.append(module.register_full_backward_pre_hook(self._make_backward_pre_hook(name)))
|
||||||
|
self._handles.append(module.register_full_backward_hook(self._make_backward_hook(name)))
|
||||||
|
matched.append(name)
|
||||||
|
|
||||||
|
if matched:
|
||||||
|
logger.info_rank0(
|
||||||
|
f"ModuleProfiler: registered hooks on {len(matched)} modules: {matched[:5]}"
|
||||||
|
+ (f" ... (+{len(matched)-5} more)" if len(matched) > 5 else "")
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning_rank0(f"ModuleProfiler: no modules matched patterns {self.patterns}")
|
||||||
|
|
||||||
|
@override
|
||||||
|
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||||
|
if not self.enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._synchronize()
|
||||||
|
|
||||||
|
for name, (start, end) in self._pending_forward.items():
|
||||||
|
self._forward_times[name].append(start.elapsed_time(end))
|
||||||
|
self._pending_forward.clear()
|
||||||
|
|
||||||
|
for name, (start, end) in self._pending_backward.items():
|
||||||
|
self._backward_times[name].append(start.elapsed_time(end))
|
||||||
|
self._pending_backward.clear()
|
||||||
|
|
||||||
|
@override
|
||||||
|
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||||
|
if not self._forward_times and not self._backward_times:
|
||||||
|
return
|
||||||
|
|
||||||
|
lines = ["[ModuleProfiler] Timing (ms):"]
|
||||||
|
all_names = sorted(set(list(self._forward_times.keys()) + list(self._backward_times.keys())))
|
||||||
|
for name in all_names:
|
||||||
|
fwd = self._forward_times.get(name, [])
|
||||||
|
bwd = self._backward_times.get(name, [])
|
||||||
|
fwd_mean = sum(fwd) / len(fwd) if fwd else 0.0
|
||||||
|
bwd_mean = sum(bwd) / len(bwd) if bwd else 0.0
|
||||||
|
lines.append(f" {name}: fwd={fwd_mean:.3f}, bwd={bwd_mean:.3f}, total={fwd_mean+bwd_mean:.3f}")
|
||||||
|
|
||||||
|
logger.info_rank0("\n".join(lines))
|
||||||
|
self._forward_times.clear()
|
||||||
|
self._backward_times.clear()
|
||||||
|
|
||||||
|
@override
|
||||||
|
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||||
|
for handle in self._handles:
|
||||||
|
handle.remove()
|
||||||
|
self._handles.clear()
|
||||||
|
|||||||
@@ -123,10 +123,10 @@ class CustomDPOTrainer(DPOTrainer):
|
|||||||
self.running = RunningMoments(self.accelerator)
|
self.running = RunningMoments(self.accelerator)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
|
||||||
if self.optimizer is None:
|
if self.optimizer is None:
|
||||||
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
|
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
|
||||||
return super().create_optimizer()
|
return super().create_optimizer(*args, **kwargs)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def create_scheduler(
|
def create_scheduler(
|
||||||
|
|||||||
@@ -120,10 +120,10 @@ class CustomKTOTrainer(KTOTrainer):
|
|||||||
self.add_callback(BAdamCallback)
|
self.add_callback(BAdamCallback)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
|
||||||
if self.optimizer is None:
|
if self.optimizer is None:
|
||||||
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
|
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
|
||||||
return super().create_optimizer()
|
return super().create_optimizer(*args, **kwargs)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def create_scheduler(
|
def create_scheduler(
|
||||||
|
|||||||
@@ -69,10 +69,10 @@ class CustomTrainer(Trainer):
|
|||||||
verify_fp8_status(self.accelerator, training_args)
|
verify_fp8_status(self.accelerator, training_args)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
|
||||||
if self.optimizer is None:
|
if self.optimizer is None:
|
||||||
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
|
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
|
||||||
return super().create_optimizer()
|
return super().create_optimizer(*args, **kwargs)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def create_scheduler(
|
def create_scheduler(
|
||||||
|
|||||||
@@ -65,10 +65,10 @@ class PairwiseTrainer(Trainer):
|
|||||||
self.add_callback(BAdamCallback)
|
self.add_callback(BAdamCallback)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
|
||||||
if self.optimizer is None:
|
if self.optimizer is None:
|
||||||
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
|
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
|
||||||
return super().create_optimizer()
|
return super().create_optimizer(*args, **kwargs)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def create_scheduler(
|
def create_scheduler(
|
||||||
|
|||||||
@@ -128,10 +128,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
verify_fp8_status(self.accelerator, training_args)
|
verify_fp8_status(self.accelerator, training_args)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
|
||||||
if self.optimizer is None:
|
if self.optimizer is None:
|
||||||
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
|
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
|
||||||
return super().create_optimizer()
|
return super().create_optimizer(*args, **kwargs)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def create_scheduler(
|
def create_scheduler(
|
||||||
|
|||||||
@@ -32,7 +32,13 @@ from ..extras.packages import (
|
|||||||
)
|
)
|
||||||
from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
|
from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
|
||||||
from ..model import load_model, load_tokenizer
|
from ..model import load_model, load_tokenizer
|
||||||
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
|
from .callbacks import (
|
||||||
|
LogCallback,
|
||||||
|
ModuleProfilerCallback,
|
||||||
|
PissaConvertCallback,
|
||||||
|
ReporterCallback,
|
||||||
|
TorchProfilerCallback,
|
||||||
|
)
|
||||||
from .dpo import run_dpo
|
from .dpo import run_dpo
|
||||||
from .kto import run_kto
|
from .kto import run_kto
|
||||||
from .ppo import run_ppo
|
from .ppo import run_ppo
|
||||||
@@ -74,6 +80,12 @@ def _training_function(config: dict[str, Any]) -> None:
|
|||||||
if finetuning_args.early_stopping_steps is not None:
|
if finetuning_args.early_stopping_steps is not None:
|
||||||
callbacks.append(EarlyStoppingCallback(early_stopping_patience=finetuning_args.early_stopping_steps))
|
callbacks.append(EarlyStoppingCallback(early_stopping_patience=finetuning_args.early_stopping_steps))
|
||||||
|
|
||||||
|
if training_args.enable_torch_profiler:
|
||||||
|
callbacks.append(TorchProfilerCallback(training_args))
|
||||||
|
|
||||||
|
if training_args.profile_modules:
|
||||||
|
callbacks.append(ModuleProfilerCallback(training_args.profile_modules))
|
||||||
|
|
||||||
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
|
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
|
||||||
|
|
||||||
if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel:
|
if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel:
|
||||||
|
|||||||
@@ -134,6 +134,9 @@ class BaseTrainer:
|
|||||||
global_step=self.global_step,
|
global_step=self.global_step,
|
||||||
epoch=self._resume_epoch,
|
epoch=self._resume_epoch,
|
||||||
)
|
)
|
||||||
|
# Keep callback state aligned with checkpoint-resumed trainer counters.
|
||||||
|
self.state.global_step = self.global_step
|
||||||
|
self.state.epoch = self._resume_epoch
|
||||||
|
|
||||||
if self.args.dist_config is not None and self.args.dist_config.get("cp_size", 1) > 1:
|
if self.args.dist_config is not None and self.args.dist_config.get("cp_size", 1) > 1:
|
||||||
# qwen3.5 is not supported because of the different attention implementation, which will be supported in the future.
|
# qwen3.5 is not supported because of the different attention implementation, which will be supported in the future.
|
||||||
@@ -303,7 +306,7 @@ class BaseTrainer:
|
|||||||
if self.global_step % self.args.logging_steps == 0:
|
if self.global_step % self.args.logging_steps == 0:
|
||||||
logs = {
|
logs = {
|
||||||
"epoch": epoch,
|
"epoch": epoch,
|
||||||
"step": self.global_step,
|
"step": self.state.global_step,
|
||||||
"loss": step_loss,
|
"loss": step_loss,
|
||||||
"grad_norm": grad_norm,
|
"grad_norm": grad_norm,
|
||||||
"learning_rate": current_lr,
|
"learning_rate": current_lr,
|
||||||
@@ -335,7 +338,9 @@ class BaseTrainer:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
|
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
|
||||||
model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB")
|
model_to_save.save_pretrained(
|
||||||
|
self.args.output_dir, state_dict=model_to_save.state_dict(), max_shard_size="4GB"
|
||||||
|
)
|
||||||
self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
|
self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
|
||||||
logger.info_rank0(f"Model saved to {self.args.output_dir}")
|
logger.info_rank0(f"Model saved to {self.args.output_dir}")
|
||||||
|
|
||||||
|
|||||||
@@ -143,6 +143,12 @@ class ModelEngine:
|
|||||||
elif self.args.model_class == ModelClass.CLS:
|
elif self.args.model_class == ModelClass.CLS:
|
||||||
from transformers import AutoModelForTokenClassification
|
from transformers import AutoModelForTokenClassification
|
||||||
|
|
||||||
|
self.model_config.num_labels = 1
|
||||||
|
self.model_config.classifier_dropout = 0.0
|
||||||
|
text_config = getattr(self.model_config, "text_config", None)
|
||||||
|
if text_config is not None:
|
||||||
|
text_config.num_labels = 1
|
||||||
|
text_config.classifier_dropout = 0.0
|
||||||
AutoClass = AutoModelForTokenClassification
|
AutoClass = AutoModelForTokenClassification
|
||||||
else:
|
else:
|
||||||
from transformers import AutoModel
|
from transformers import AutoModel
|
||||||
|
|||||||
@@ -137,8 +137,8 @@ class BatchGenerator(Iterator):
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError("Iterable dataset is not supported yet.")
|
raise NotImplementedError("Iterable dataset is not supported yet.")
|
||||||
|
|
||||||
generato_seed = torch.Generator()
|
generator_seed = torch.Generator()
|
||||||
generato_seed.manual_seed(self.seed)
|
generator_seed.manual_seed(self.seed)
|
||||||
|
|
||||||
self._data_provider = StatefulDataLoader(
|
self._data_provider = StatefulDataLoader(
|
||||||
self.dataset,
|
self.dataset,
|
||||||
@@ -149,7 +149,7 @@ class BatchGenerator(Iterator):
|
|||||||
pin_memory=self.pin_memory,
|
pin_memory=self.pin_memory,
|
||||||
pin_memory_device=DistributedInterface().current_device.type,
|
pin_memory_device=DistributedInterface().current_device.type,
|
||||||
drop_last=self.drop_last,
|
drop_last=self.drop_last,
|
||||||
generator=generato_seed,
|
generator=generator_seed,
|
||||||
)
|
)
|
||||||
if self.batching_strategy == BatchingStrategy.NORMAL:
|
if self.batching_strategy == BatchingStrategy.NORMAL:
|
||||||
self._length = len(self._data_provider)
|
self._length = len(self._data_provider)
|
||||||
|
|||||||
@@ -172,7 +172,7 @@ def _save_standard_training_states(
|
|||||||
if rank == 0:
|
if rank == 0:
|
||||||
model_to_save = model.module if hasattr(model, "module") else model
|
model_to_save = model.module if hasattr(model, "module") else model
|
||||||
model_dir = os.path.join(ckpt_dir, "model")
|
model_dir = os.path.join(ckpt_dir, "model")
|
||||||
model_to_save.save_pretrained(model_dir, max_shard_size="4GB")
|
model_to_save.save_pretrained(model_dir, state_dict=model_to_save.state_dict(), max_shard_size="4GB")
|
||||||
processor.save_pretrained(model_dir)
|
processor.save_pretrained(model_dir)
|
||||||
|
|
||||||
os.makedirs(os.path.join(ckpt_dir, "optimizer"), exist_ok=True)
|
os.makedirs(os.path.join(ckpt_dir, "optimizer"), exist_ok=True)
|
||||||
@@ -212,7 +212,11 @@ def _load_standard_training_states(
|
|||||||
for f in sorted(glob.glob(os.path.join(model_dir, "*.bin"))):
|
for f in sorted(glob.glob(os.path.join(model_dir, "*.bin"))):
|
||||||
state_dict.update(torch.load(f, map_location="cpu", weights_only=True))
|
state_dict.update(torch.load(f, map_location="cpu", weights_only=True))
|
||||||
if state_dict:
|
if state_dict:
|
||||||
model_to_load.load_state_dict(state_dict)
|
incompatible_keys = model_to_load.load_state_dict(state_dict, strict=False)
|
||||||
|
if incompatible_keys.missing_keys:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Unexpected missing keys when loading checkpoint model weights: {incompatible_keys.missing_keys}."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning_rank0(f"No model weights found in {model_dir}, skipping model state restore.")
|
logger.warning_rank0(f"No model weights found in {model_dir}, skipping model state restore.")
|
||||||
|
|
||||||
|
|||||||
@@ -148,7 +148,9 @@ def launch():
|
|||||||
elif command == "dpo":
|
elif command == "dpo":
|
||||||
raise NotImplementedError("DPO trainer is not implemented yet.")
|
raise NotImplementedError("DPO trainer is not implemented yet.")
|
||||||
elif command == "rm":
|
elif command == "rm":
|
||||||
raise NotImplementedError("RM trainer is not implemented yet.")
|
from llamafactory.v1.trainers.rm_trainer import run_rm
|
||||||
|
|
||||||
|
run_rm()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
print(f"Unknown command: {command}.\n{USAGE}")
|
print(f"Unknown command: {command}.\n{USAGE}")
|
||||||
@@ -175,9 +177,9 @@ def main():
|
|||||||
# run_dpo()
|
# run_dpo()
|
||||||
raise NotImplementedError("DPO trainer is not implemented yet.")
|
raise NotImplementedError("DPO trainer is not implemented yet.")
|
||||||
elif command == "rm":
|
elif command == "rm":
|
||||||
# from llamafactory.v1.trainers.rm_trainer import run_rm
|
from llamafactory.v1.trainers.rm_trainer import run_rm
|
||||||
# run_rm()
|
|
||||||
raise NotImplementedError("RM trainer is not implemented yet.")
|
run_rm()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -0,0 +1,429 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Pure-Triton Fused MoE Kernel for NVIDIA GPUs.
|
||||||
|
|
||||||
|
Replaces the HuggingFace per-expert Python loop with a fully fused Triton pipeline:
|
||||||
|
- Forward: scatter → grouped GEMM fc1 → SiLU·gate → apply routing → grouped GEMM fc2 → gather
|
||||||
|
- Backward: all dX via grouped GEMM, all dW via grouped GEMM (no Python loops)
|
||||||
|
|
||||||
|
Supported models: Mixtral, Qwen3-MoE, Qwen3.5-MoE.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import types
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from ......accelerator.helper import DeviceType
|
||||||
|
from ......utils.types import HFModel
|
||||||
|
from ...base import BaseKernel
|
||||||
|
from ...registry import register_kernel
|
||||||
|
from .triton_grouped_gemm import (
|
||||||
|
group_gemm_same_mn,
|
||||||
|
group_gemm_same_nk,
|
||||||
|
moe_gather,
|
||||||
|
moe_scatter,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Autograd Function: Full Triton MoE forward + backward
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TritonFusedMoeFunction(torch.autograd.Function):
|
||||||
|
"""Fused MoE expert computation using Triton grouped GEMMs.
|
||||||
|
|
||||||
|
Forward: scatter → fc1 (group GEMM) → SiLU·gate → weight → fc2 (group GEMM) → gather
|
||||||
|
Backward: all gradients computed via grouped GEMMs in single kernel launches.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(
|
||||||
|
ctx,
|
||||||
|
num_experts,
|
||||||
|
gate_weights,
|
||||||
|
expert_index,
|
||||||
|
hidden_states,
|
||||||
|
fc1_weight,
|
||||||
|
fc2_weight,
|
||||||
|
):
|
||||||
|
"""Forward pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx: autograd context
|
||||||
|
num_experts: int
|
||||||
|
gate_weights: (num_tokens, top_k) routing weights
|
||||||
|
expert_index: (num_tokens, top_k) expert assignments
|
||||||
|
hidden_states: (num_tokens, hidden_dim)
|
||||||
|
fc1_weight: (E, 2*inter, hidden) merged gate+up weight
|
||||||
|
fc2_weight: (E, hidden, inter) down projection weight
|
||||||
|
"""
|
||||||
|
# Compute scatter index: maps (token, topk) → position in sorted buffer
|
||||||
|
scatter_index = expert_index.flatten().argsort(stable=True).argsort().int().view(expert_index.shape)
|
||||||
|
|
||||||
|
# Token counts per expert and cumulative boundaries
|
||||||
|
splits = torch.zeros(num_experts, dtype=torch.int32, device=hidden_states.device)
|
||||||
|
flat_experts = expert_index.flatten().int()
|
||||||
|
splits.scatter_add_(0, flat_experts.long(), torch.ones_like(flat_experts))
|
||||||
|
cumsum_t = torch.cumsum(splits, dim=0)
|
||||||
|
|
||||||
|
# Scatter hidden states to sorted expert buffer
|
||||||
|
scatter_output = moe_scatter(hidden_states, scatter_index)
|
||||||
|
|
||||||
|
# FC1: grouped GEMM (scatter_output @ fc1_weight.T)
|
||||||
|
max_M = int(splits.max().item())
|
||||||
|
fc1_output = group_gemm_same_nk(
|
||||||
|
a=scatter_output,
|
||||||
|
b=fc1_weight,
|
||||||
|
cumsum_M=cumsum_t,
|
||||||
|
max_M=max_M,
|
||||||
|
transpose_b=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# SiLU gate activation
|
||||||
|
fc1_1_output, fc1_2_output = fc1_output.chunk(2, dim=-1)
|
||||||
|
fc1_1_activation = torch.nn.functional.silu(fc1_1_output)
|
||||||
|
fc1_activation = fc1_1_activation * fc1_2_output
|
||||||
|
|
||||||
|
# Apply routing weights before fc2 (mathematically equivalent to after)
|
||||||
|
reshaped_gate_weight = gate_weights.reshape(-1, 1)
|
||||||
|
scattered_gate_weight = torch.empty_like(reshaped_gate_weight)
|
||||||
|
scattered_gate_weight[scatter_index.flatten().long()] = reshaped_gate_weight
|
||||||
|
fc1_weighted_output = fc1_activation * scattered_gate_weight
|
||||||
|
|
||||||
|
# FC2: grouped GEMM (fc1_weighted @ fc2_weight.T)
|
||||||
|
fc2_output = group_gemm_same_nk(
|
||||||
|
a=fc1_weighted_output,
|
||||||
|
b=fc2_weight,
|
||||||
|
cumsum_M=cumsum_t,
|
||||||
|
max_M=max_M,
|
||||||
|
transpose_b=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Gather back to original token positions (sum over topk)
|
||||||
|
expert_output = moe_gather(fc2_output, scatter_index)
|
||||||
|
|
||||||
|
ctx.num_experts = num_experts
|
||||||
|
ctx.save_for_backward(
|
||||||
|
gate_weights,
|
||||||
|
fc1_weight,
|
||||||
|
fc2_weight,
|
||||||
|
hidden_states,
|
||||||
|
scatter_index,
|
||||||
|
scatter_output,
|
||||||
|
cumsum_t,
|
||||||
|
fc1_1_output,
|
||||||
|
fc1_2_output,
|
||||||
|
fc1_activation,
|
||||||
|
scattered_gate_weight,
|
||||||
|
fc1_weighted_output,
|
||||||
|
)
|
||||||
|
|
||||||
|
return expert_output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
(
|
||||||
|
gate_weights,
|
||||||
|
fc1_weight,
|
||||||
|
fc2_weight,
|
||||||
|
hidden_states,
|
||||||
|
scatter_index,
|
||||||
|
scatter_output,
|
||||||
|
cumsum_t,
|
||||||
|
fc1_1_output,
|
||||||
|
fc1_2_output,
|
||||||
|
fc1_activation,
|
||||||
|
scattered_gate_weight,
|
||||||
|
fc1_weighted_output,
|
||||||
|
) = ctx.saved_tensors
|
||||||
|
num_experts = ctx.num_experts
|
||||||
|
hidden_dim = grad_output.shape[-1]
|
||||||
|
grad_output = grad_output.reshape(-1, hidden_dim).contiguous()
|
||||||
|
|
||||||
|
# Recompute max_M from cumsum
|
||||||
|
splits = torch.zeros(num_experts, dtype=cumsum_t.dtype, device=cumsum_t.device)
|
||||||
|
splits[0] = cumsum_t[0]
|
||||||
|
splits[1:] = cumsum_t[1:] - cumsum_t[:-1]
|
||||||
|
max_M = int(splits.max().item())
|
||||||
|
|
||||||
|
# Step 1: Scatter grad_output to expert buffer
|
||||||
|
grad_fc2_output = moe_scatter(grad_output, scatter_index)
|
||||||
|
|
||||||
|
# Step 2: FC2 backward
|
||||||
|
# dX for fc2: grad_fc2_output @ fc2_weight (transpose_b=False since fc2 is (E, hidden, inter))
|
||||||
|
grad_fc1_weighted_output = group_gemm_same_nk(
|
||||||
|
a=grad_fc2_output,
|
||||||
|
b=fc2_weight,
|
||||||
|
cumsum_M=cumsum_t,
|
||||||
|
max_M=max_M,
|
||||||
|
transpose_b=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# dW for fc2: grad_fc2_output.T @ fc1_weighted_output
|
||||||
|
grad_fc2_weight = None
|
||||||
|
if fc2_weight.requires_grad:
|
||||||
|
grad_fc2_weight = torch.empty_like(fc2_weight)
|
||||||
|
group_gemm_same_mn(
|
||||||
|
a=grad_fc2_output,
|
||||||
|
b=fc1_weighted_output,
|
||||||
|
c=grad_fc2_weight,
|
||||||
|
cumsum_K=cumsum_t,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 3: Routing weight backward
|
||||||
|
grad_fc1_activation = grad_fc1_weighted_output * scattered_gate_weight
|
||||||
|
grad_scattered_gate_weight = torch.sum(fc1_activation * grad_fc1_weighted_output, dim=-1)
|
||||||
|
grad_gate_weight = grad_scattered_gate_weight[scatter_index.flatten().long()]
|
||||||
|
grad_gate_weight = grad_gate_weight.reshape(gate_weights.shape)
|
||||||
|
|
||||||
|
# Recompute silu activation for backward
|
||||||
|
fc1_1_activation = torch.nn.functional.silu(fc1_1_output)
|
||||||
|
|
||||||
|
# Step 4: SiLU gate backward
|
||||||
|
grad_fc1_1_activation = grad_fc1_activation * fc1_2_output
|
||||||
|
grad_fc1_2_output = fc1_1_activation * grad_fc1_activation
|
||||||
|
|
||||||
|
# SiLU backward: d/dx[x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
|
||||||
|
grad_fc1_1_output = torch.ops.aten.silu_backward(grad_fc1_1_activation, fc1_1_output)
|
||||||
|
|
||||||
|
# Merge fc1 gradients back to (total_M, 2*inter)
|
||||||
|
grad_fc1_output = torch.cat([grad_fc1_1_output, grad_fc1_2_output], dim=-1)
|
||||||
|
|
||||||
|
# Step 5: FC1 backward
|
||||||
|
# dX for fc1: grad_fc1_output @ fc1_weight (transpose_b=False)
|
||||||
|
grad_scatter_output = group_gemm_same_nk(
|
||||||
|
a=grad_fc1_output,
|
||||||
|
b=fc1_weight,
|
||||||
|
cumsum_M=cumsum_t,
|
||||||
|
max_M=max_M,
|
||||||
|
transpose_b=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# dW for fc1: grad_fc1_output.T @ scatter_output
|
||||||
|
grad_fc1_weight = None
|
||||||
|
if fc1_weight.requires_grad:
|
||||||
|
grad_fc1_weight = torch.empty_like(fc1_weight)
|
||||||
|
group_gemm_same_mn(
|
||||||
|
a=grad_fc1_output,
|
||||||
|
b=scatter_output,
|
||||||
|
c=grad_fc1_weight,
|
||||||
|
cumsum_K=cumsum_t,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 6: Gather gradients back to original positions
|
||||||
|
grad_hidden_states = moe_gather(grad_scatter_output, scatter_index)
|
||||||
|
grad_hidden_states = grad_hidden_states.reshape(hidden_states.shape)
|
||||||
|
|
||||||
|
return (
|
||||||
|
None, # num_experts
|
||||||
|
grad_gate_weight, # gate_weights
|
||||||
|
None, # expert_index
|
||||||
|
grad_hidden_states, # hidden_states
|
||||||
|
grad_fc1_weight, # fc1_weight
|
||||||
|
grad_fc2_weight, # fc2_weight
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Patched forward functions
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _triton_moe_experts_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
top_k_index: torch.Tensor,
|
||||||
|
top_k_weights: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Replacement forward for v5+ MoE expert modules with stacked 3D weights."""
|
||||||
|
return TritonFusedMoeFunction.apply(
|
||||||
|
self.num_experts,
|
||||||
|
top_k_weights.to(hidden_states.dtype),
|
||||||
|
top_k_index,
|
||||||
|
hidden_states,
|
||||||
|
self.gate_up_proj,
|
||||||
|
self.down_proj,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Legacy (transformers < 5.0) support: weight stacking + SparseMoeBlock patch
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _StackedExpertWeights(torch.nn.Module):
|
||||||
|
"""Lightweight container holding stacked 3D expert weight tensors."""
|
||||||
|
|
||||||
|
def __init__(self, gate_up_proj: torch.Tensor, down_proj: torch.Tensor, num_experts: int):
|
||||||
|
super().__init__()
|
||||||
|
self.gate_up_proj = torch.nn.Parameter(gate_up_proj)
|
||||||
|
self.down_proj = torch.nn.Parameter(down_proj)
|
||||||
|
self.num_experts = num_experts
|
||||||
|
|
||||||
|
|
||||||
|
def _stack_expert_weights(module: torch.nn.Module) -> None:
|
||||||
|
"""Replace nn.ModuleList of individual experts with stacked 3D parameter tensors."""
|
||||||
|
experts = module.experts
|
||||||
|
num_experts = len(experts)
|
||||||
|
|
||||||
|
gate_up_list = []
|
||||||
|
for expert in experts:
|
||||||
|
gate_w = expert.gate_proj.weight.data # (inter, hidden)
|
||||||
|
up_w = expert.up_proj.weight.data # (inter, hidden)
|
||||||
|
gate_up_list.append(torch.cat([gate_w, up_w], dim=0)) # (2*inter, hidden)
|
||||||
|
gate_up_proj = torch.stack(gate_up_list, dim=0) # (E, 2*inter, hidden)
|
||||||
|
|
||||||
|
down_proj = torch.stack([e.down_proj.weight.data for e in experts], dim=0) # (E, hidden, inter)
|
||||||
|
|
||||||
|
module.experts = _StackedExpertWeights(gate_up_proj, down_proj, num_experts)
|
||||||
|
logger.info(
|
||||||
|
f"cuda_fused_moe: Stacked {num_experts} expert weights into "
|
||||||
|
f"gate_up_proj {tuple(gate_up_proj.shape)}, down_proj {tuple(down_proj.shape)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _triton_moe_sparse_block_forward(self, hidden_states: torch.Tensor):
|
||||||
|
"""Replacement forward for legacy SparseMoeBlock with inline routing."""
|
||||||
|
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||||
|
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||||
|
|
||||||
|
router_logits = self.gate(hidden_states)
|
||||||
|
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||||
|
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
||||||
|
if self.norm_topk_prob:
|
||||||
|
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||||
|
routing_weights = routing_weights.to(hidden_states.dtype)
|
||||||
|
|
||||||
|
final_hidden_states = TritonFusedMoeFunction.apply(
|
||||||
|
self.num_experts,
|
||||||
|
routing_weights,
|
||||||
|
selected_experts,
|
||||||
|
hidden_states,
|
||||||
|
self.experts.gate_up_proj,
|
||||||
|
self.experts.down_proj,
|
||||||
|
)
|
||||||
|
|
||||||
|
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
||||||
|
return final_hidden_states, router_logits
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Module mapping
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_TRITON_MOE_MAPPING: dict[str, dict[str, object]] = {
|
||||||
|
"MixtralForCausalLM": {
|
||||||
|
"MixtralExperts": _triton_moe_experts_forward,
|
||||||
|
},
|
||||||
|
"Qwen3MoeForCausalLM": {
|
||||||
|
"Qwen3MoeExperts": _triton_moe_experts_forward,
|
||||||
|
"Qwen3MoeSparseMoeBlock": _triton_moe_sparse_block_forward,
|
||||||
|
},
|
||||||
|
"Qwen3_5MoeForCausalLM": {
|
||||||
|
"Qwen3_5MoeExperts": _triton_moe_experts_forward,
|
||||||
|
},
|
||||||
|
"Qwen3_5MoeForConditionalGeneration": {
|
||||||
|
"Qwen3_5MoeExperts": _triton_moe_experts_forward,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Kernel registration
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@register_kernel
|
||||||
|
class CudaFusedMoEKernel(BaseKernel):
|
||||||
|
"""Pure-Triton fused MoE kernel for NVIDIA CUDA GPUs.
|
||||||
|
|
||||||
|
Replaces HuggingFace per-expert Python loops with a fully fused Triton pipeline:
|
||||||
|
- Forward: scatter + grouped GEMMs + gather (single kernel per GEMM)
|
||||||
|
- Backward: all dX and dW via grouped GEMMs (no Python loops)
|
||||||
|
|
||||||
|
Requires: CUDA GPU + Triton
|
||||||
|
"""
|
||||||
|
|
||||||
|
_kernel_id = "cuda_fused_moe"
|
||||||
|
_device = DeviceType.CUDA
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def check_deps(cls) -> bool:
|
||||||
|
if not super().check_deps():
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
import triton # noqa: F401
|
||||||
|
|
||||||
|
return True
|
||||||
|
except ImportError:
|
||||||
|
logger.info("cuda_fused_moe: Triton not available, kernel disabled.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def apply(cls, **kwargs) -> HFModel:
|
||||||
|
model = kwargs.get("model")
|
||||||
|
if model is None:
|
||||||
|
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
|
||||||
|
|
||||||
|
if not cls.check_deps():
|
||||||
|
logger.warning("cuda_fused_moe: Dependencies not met. Skipping kernel application.")
|
||||||
|
return model
|
||||||
|
|
||||||
|
archs = getattr(model.config, "architectures", None) or []
|
||||||
|
target_mapping = None
|
||||||
|
for arch in archs:
|
||||||
|
if arch in _TRITON_MOE_MAPPING:
|
||||||
|
target_mapping = _TRITON_MOE_MAPPING[arch]
|
||||||
|
break
|
||||||
|
|
||||||
|
if target_mapping is None:
|
||||||
|
logger.info(
|
||||||
|
f"cuda_fused_moe: Model architecture {archs} not supported. "
|
||||||
|
f"Supported: {list(_TRITON_MOE_MAPPING.keys())}"
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
patched_count = 0
|
||||||
|
for module in model.modules():
|
||||||
|
class_name = module.__class__.__name__
|
||||||
|
if class_name not in target_mapping:
|
||||||
|
continue
|
||||||
|
|
||||||
|
target_fn = target_mapping[class_name]
|
||||||
|
|
||||||
|
if hasattr(module, "gate_up_proj") and hasattr(module, "down_proj"):
|
||||||
|
module.forward = types.MethodType(target_fn, module)
|
||||||
|
patched_count += 1
|
||||||
|
elif (
|
||||||
|
hasattr(module, "experts")
|
||||||
|
and isinstance(module.experts, torch.nn.ModuleList)
|
||||||
|
and hasattr(module, "gate")
|
||||||
|
):
|
||||||
|
_stack_expert_weights(module)
|
||||||
|
module.forward = types.MethodType(target_fn, module)
|
||||||
|
patched_count += 1
|
||||||
|
|
||||||
|
if patched_count > 0:
|
||||||
|
logger.info(f"cuda_fused_moe: Patched {patched_count} MoE expert modules with pure Triton pipeline.")
|
||||||
|
else:
|
||||||
|
logger.warning("cuda_fused_moe: No MoE expert modules found to patch.")
|
||||||
|
|
||||||
|
return model
|
||||||
@@ -0,0 +1,417 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# Pure-Triton grouped GEMM and MoE scatter/gather kernels.
|
||||||
|
# Design adapted from VeOmni (ByteDance-Seed/VeOmni) group_gemm kernels.
|
||||||
|
|
||||||
|
"""Pure-Triton MoE kernels: grouped GEMM, scatter, and gather.
|
||||||
|
|
||||||
|
Provides four kernel types for fused MoE forward+backward without Python loops:
|
||||||
|
- group_gemm_same_nk: Variable-M grouped GEMM (forward & backward dX)
|
||||||
|
- group_gemm_same_mn: Variable-K grouped GEMM (backward dW)
|
||||||
|
- moe_scatter: Token dispatch to sorted expert buffers
|
||||||
|
- moe_gather: Token reduction from expert buffers
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Triton helper: grouped tile indexing with L2 cache-friendly swizzle
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _get_pid_mn(pid, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, GROUP_SIZE: tl.constexpr):
|
||||||
|
num_pid_m = tl.cdiv(M, BLOCK_M)
|
||||||
|
num_pid_n = tl.cdiv(N, BLOCK_N)
|
||||||
|
num_pid_in_group = GROUP_SIZE * num_pid_n
|
||||||
|
group_id = pid // num_pid_in_group
|
||||||
|
first_pid_m = group_id * GROUP_SIZE
|
||||||
|
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE)
|
||||||
|
pid_m = first_pid_m + (pid % group_size_m)
|
||||||
|
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||||
|
return pid_m, pid_n
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# group_gemm_same_nk: All experts share same N, K; variable M per expert
|
||||||
|
# Used for: forward (x @ W.T) and backward dX (grad @ W)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@triton.autotune(
|
||||||
|
configs=[
|
||||||
|
triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP": 8}, num_warps=4, num_stages=3),
|
||||||
|
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP": 8}, num_warps=4, num_stages=3),
|
||||||
|
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP": 8}, num_warps=8, num_stages=3),
|
||||||
|
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP": 8}, num_warps=8, num_stages=3),
|
||||||
|
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64, "GROUP": 8}, num_warps=8, num_stages=3),
|
||||||
|
],
|
||||||
|
key=["N", "K"],
|
||||||
|
)
|
||||||
|
@triton.jit
|
||||||
|
def _group_gemm_same_nk_kernel(
|
||||||
|
a_ptr,
|
||||||
|
b_ptr,
|
||||||
|
c_ptr,
|
||||||
|
cumsum_M,
|
||||||
|
max_M,
|
||||||
|
G: tl.constexpr,
|
||||||
|
N: tl.constexpr,
|
||||||
|
K: tl.constexpr,
|
||||||
|
TRANSPOSE_B: tl.constexpr,
|
||||||
|
BLOCK_M: tl.constexpr,
|
||||||
|
BLOCK_N: tl.constexpr,
|
||||||
|
BLOCK_K: tl.constexpr,
|
||||||
|
GROUP: tl.constexpr,
|
||||||
|
):
|
||||||
|
pid_m, pid_n = _get_pid_mn(tl.program_id(0), max_M, N, BLOCK_M, BLOCK_N, GROUP)
|
||||||
|
gid = tl.program_id(1).to(tl.int64)
|
||||||
|
|
||||||
|
gtid_start = tl.load(cumsum_M + gid - 1, mask=gid > 0, other=0).to(tl.int64)
|
||||||
|
gtid_end = tl.load(cumsum_M + gid).to(tl.int64)
|
||||||
|
m_size = gtid_end - gtid_start
|
||||||
|
|
||||||
|
if pid_m * BLOCK_M >= m_size:
|
||||||
|
return
|
||||||
|
|
||||||
|
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||||
|
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||||
|
offs_k = tl.arange(0, BLOCK_K)
|
||||||
|
|
||||||
|
# a is (total_M, K) row-major, offset by expert start
|
||||||
|
a_base = a_ptr + gtid_start * K
|
||||||
|
# b is (G, N, K) if TRANSPOSE_B else (G, K, N)
|
||||||
|
b_base = b_ptr + gid * K * N
|
||||||
|
# c is (total_M, N) row-major
|
||||||
|
c_base = c_ptr + gtid_start * N
|
||||||
|
|
||||||
|
if TRANSPOSE_B:
|
||||||
|
# b layout: (G, N, K), we compute a @ b.T = a(M,K) @ b(N,K).T -> (M,N)
|
||||||
|
a_ptrs = a_base + offs_m[:, None] * K + offs_k[None, :]
|
||||||
|
b_ptrs = b_base + offs_n[:, None] * K + offs_k[None, :]
|
||||||
|
else:
|
||||||
|
# b layout: (G, K, N), we compute a @ b = a(M,K) @ b(K,N) -> (M,N)
|
||||||
|
a_ptrs = a_base + offs_m[:, None] * K + offs_k[None, :]
|
||||||
|
b_ptrs = b_base + offs_k[:, None] * N + offs_n[None, :]
|
||||||
|
|
||||||
|
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
||||||
|
|
||||||
|
for k_start in range(0, K, BLOCK_K):
|
||||||
|
k_offs = k_start + offs_k
|
||||||
|
k_mask = k_offs < K
|
||||||
|
|
||||||
|
a_block = tl.load(a_ptrs, mask=(offs_m[:, None] < m_size) & k_mask[None, :], other=0.0)
|
||||||
|
|
||||||
|
if TRANSPOSE_B:
|
||||||
|
b_block = tl.load(b_ptrs, mask=(offs_n[:, None] < N) & k_mask[None, :], other=0.0)
|
||||||
|
acc += tl.dot(a_block, tl.trans(b_block))
|
||||||
|
else:
|
||||||
|
b_block = tl.load(b_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < N), other=0.0)
|
||||||
|
acc += tl.dot(a_block, b_block)
|
||||||
|
|
||||||
|
if TRANSPOSE_B:
|
||||||
|
a_ptrs += BLOCK_K
|
||||||
|
b_ptrs += BLOCK_K
|
||||||
|
else:
|
||||||
|
a_ptrs += BLOCK_K
|
||||||
|
b_ptrs += BLOCK_K * N
|
||||||
|
|
||||||
|
c_ptrs = c_base + offs_m[:, None] * N + offs_n[None, :]
|
||||||
|
c_mask = (offs_m[:, None] < m_size) & (offs_n[None, :] < N)
|
||||||
|
tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def group_gemm_same_nk(
|
||||||
|
a: torch.Tensor,
|
||||||
|
b: torch.Tensor,
|
||||||
|
cumsum_M: torch.Tensor,
|
||||||
|
max_M: int,
|
||||||
|
transpose_b: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Grouped GEMM where all groups share same N, K dimensions but variable M.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a: (total_M, K) input tensor, rows grouped by expert
|
||||||
|
b: (G, N, K) if transpose_b else (G, K, N) weight tensor
|
||||||
|
cumsum_M: (G,) cumulative token counts per expert
|
||||||
|
max_M: maximum tokens any single expert has
|
||||||
|
transpose_b: if True, compute a @ b.T; else compute a @ b
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
c: (total_M, N) output tensor
|
||||||
|
"""
|
||||||
|
if transpose_b:
|
||||||
|
G, N, K = b.shape
|
||||||
|
else:
|
||||||
|
G, K, N = b.shape
|
||||||
|
|
||||||
|
c = torch.empty((a.shape[0], N), dtype=a.dtype, device=a.device)
|
||||||
|
|
||||||
|
_group_gemm_same_nk_kernel[
|
||||||
|
(lambda meta: (triton.cdiv(max_M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]), G))
|
||||||
|
](
|
||||||
|
a_ptr=a,
|
||||||
|
b_ptr=b,
|
||||||
|
c_ptr=c,
|
||||||
|
cumsum_M=cumsum_M,
|
||||||
|
max_M=max_M,
|
||||||
|
G=G,
|
||||||
|
N=N,
|
||||||
|
K=K,
|
||||||
|
TRANSPOSE_B=transpose_b,
|
||||||
|
)
|
||||||
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# group_gemm_same_mn: All experts share same M, N (weight dims); variable K
|
||||||
|
# Used for: backward dW (grad.T @ input)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@triton.autotune(
|
||||||
|
configs=[
|
||||||
|
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP": 8}, num_warps=4, num_stages=3),
|
||||||
|
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP": 8}, num_warps=4, num_stages=3),
|
||||||
|
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP": 8}, num_warps=8, num_stages=3),
|
||||||
|
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP": 8}, num_warps=8, num_stages=3),
|
||||||
|
],
|
||||||
|
key=["M", "N"],
|
||||||
|
)
|
||||||
|
@triton.jit
|
||||||
|
def _group_gemm_same_mn_kernel(
|
||||||
|
a_ptr,
|
||||||
|
b_ptr,
|
||||||
|
c_ptr,
|
||||||
|
cumsum_K,
|
||||||
|
G: tl.constexpr,
|
||||||
|
M: tl.constexpr,
|
||||||
|
N: tl.constexpr,
|
||||||
|
BLOCK_M: tl.constexpr,
|
||||||
|
BLOCK_N: tl.constexpr,
|
||||||
|
BLOCK_K: tl.constexpr,
|
||||||
|
GROUP: tl.constexpr,
|
||||||
|
):
|
||||||
|
pid_m, pid_n = _get_pid_mn(tl.program_id(0), M, N, BLOCK_M, BLOCK_N, GROUP)
|
||||||
|
gid = tl.program_id(1).to(tl.int64)
|
||||||
|
|
||||||
|
gtid_start = tl.load(cumsum_K + gid - 1, mask=gid > 0, other=0).to(tl.int64)
|
||||||
|
gtid_end = tl.load(cumsum_K + gid).to(tl.int64)
|
||||||
|
k_size = gtid_end - gtid_start
|
||||||
|
|
||||||
|
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||||
|
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||||
|
|
||||||
|
# c is (G, M, N)
|
||||||
|
c_base = c_ptr + gid * M * N
|
||||||
|
|
||||||
|
if k_size == 0:
|
||||||
|
c_ptrs = c_base + offs_m[:, None] * N + offs_n[None, :]
|
||||||
|
c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
|
||||||
|
tl.store(c_ptrs, tl.zeros((BLOCK_M, BLOCK_N), dtype=c_ptr.dtype.element_ty), mask=c_mask)
|
||||||
|
return
|
||||||
|
|
||||||
|
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
||||||
|
offs_k = tl.arange(0, BLOCK_K)
|
||||||
|
|
||||||
|
# a is (total_K, M), compute a.T @ b -> (M, N)
|
||||||
|
# b is (total_K, N)
|
||||||
|
a_base = a_ptr + gtid_start * M
|
||||||
|
b_base = b_ptr + gtid_start * N
|
||||||
|
|
||||||
|
for k_start in range(0, k_size, BLOCK_K):
|
||||||
|
k_offs = k_start + offs_k
|
||||||
|
k_mask = k_offs < k_size
|
||||||
|
|
||||||
|
a_ptrs = a_base + k_offs[:, None] * M + offs_m[None, :]
|
||||||
|
a_block_t = tl.trans(tl.load(a_ptrs, mask=k_mask[:, None] & (offs_m[None, :] < M), other=0.0))
|
||||||
|
|
||||||
|
# Load b block: (BLOCK_K, BLOCK_N)
|
||||||
|
b_ptrs = b_base + k_offs[:, None] * N + offs_n[None, :]
|
||||||
|
b_block = tl.load(b_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < N), other=0.0)
|
||||||
|
|
||||||
|
acc += tl.dot(a_block_t, b_block)
|
||||||
|
|
||||||
|
c_ptrs = c_base + offs_m[:, None] * N + offs_n[None, :]
|
||||||
|
c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
|
||||||
|
tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def group_gemm_same_mn(
|
||||||
|
a: torch.Tensor,
|
||||||
|
b: torch.Tensor,
|
||||||
|
c: torch.Tensor,
|
||||||
|
cumsum_K: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
"""Grouped GEMM where all groups produce same (M, N) output; variable K reduction.
|
||||||
|
|
||||||
|
Computes: c[g] = a[s:e].T @ b[s:e] for each group g,
|
||||||
|
where s, e are defined by cumsum_K boundaries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a: (total_K, M) input tensor grouped by expert
|
||||||
|
b: (total_K, N) input tensor grouped by expert
|
||||||
|
c: (G, M, N) output tensor (pre-allocated)
|
||||||
|
cumsum_K: (G,) cumulative token counts per expert
|
||||||
|
"""
|
||||||
|
G, M, N = c.shape
|
||||||
|
|
||||||
|
_group_gemm_same_mn_kernel[(lambda meta: (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]), G))](
|
||||||
|
a_ptr=a,
|
||||||
|
b_ptr=b,
|
||||||
|
c_ptr=c,
|
||||||
|
cumsum_K=cumsum_K,
|
||||||
|
G=G,
|
||||||
|
M=M,
|
||||||
|
N=N,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# moe_scatter: Dispatch tokens to sorted expert buffer positions
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _moe_scatter_kernel(
|
||||||
|
x_ptr,
|
||||||
|
out_ptr,
|
||||||
|
index_ptr,
|
||||||
|
M,
|
||||||
|
N: tl.constexpr,
|
||||||
|
TOPK: tl.constexpr,
|
||||||
|
BLOCK_N: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""Scatter: for each token i, copy x[i] to out[index[i, k]] for k in 0..topk-1."""
|
||||||
|
pid_m = tl.program_id(0).to(tl.int64)
|
||||||
|
pid_n = tl.program_id(1)
|
||||||
|
|
||||||
|
if pid_m >= M:
|
||||||
|
return
|
||||||
|
|
||||||
|
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||||
|
n_mask = offs_n < N
|
||||||
|
|
||||||
|
# Load input row
|
||||||
|
x_ptrs = x_ptr + pid_m * N + offs_n
|
||||||
|
x_vals = tl.load(x_ptrs, mask=n_mask, other=0.0)
|
||||||
|
|
||||||
|
# Store to each topk destination
|
||||||
|
for k in tl.static_range(TOPK):
|
||||||
|
dst_idx = tl.load(index_ptr + pid_m * TOPK + k).to(tl.int64)
|
||||||
|
out_ptrs = out_ptr + dst_idx * N + offs_n
|
||||||
|
tl.store(out_ptrs, x_vals, mask=n_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def moe_scatter(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Scatter tokens to sorted expert buffer.
|
||||||
|
|
||||||
|
For each token i and topk slot k, copies x[i] to output[index[i, k]].
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: (M, N) input hidden states
|
||||||
|
index: (M, topk) scatter indices
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
out: (M * topk, N) scattered output
|
||||||
|
"""
|
||||||
|
M, N = x.shape
|
||||||
|
topk = index.shape[1]
|
||||||
|
out = torch.empty(M * topk, N, dtype=x.dtype, device=x.device)
|
||||||
|
|
||||||
|
BLOCK_N = min(triton.next_power_of_2(N), 1024)
|
||||||
|
grid = (M, triton.cdiv(N, BLOCK_N))
|
||||||
|
|
||||||
|
_moe_scatter_kernel[grid](
|
||||||
|
x_ptr=x,
|
||||||
|
out_ptr=out,
|
||||||
|
index_ptr=index,
|
||||||
|
M=M,
|
||||||
|
N=N,
|
||||||
|
TOPK=topk,
|
||||||
|
BLOCK_N=BLOCK_N,
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# moe_gather: Reduce expert outputs back to token positions (sum over topk)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _moe_gather_kernel(
|
||||||
|
x_ptr,
|
||||||
|
out_ptr,
|
||||||
|
index_ptr,
|
||||||
|
M,
|
||||||
|
N: tl.constexpr,
|
||||||
|
TOPK: tl.constexpr,
|
||||||
|
BLOCK_N: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""Gather: for each token i, out[i] = sum_k(x[index[i, k]]) over topk."""
|
||||||
|
pid_m = tl.program_id(0).to(tl.int64)
|
||||||
|
pid_n = tl.program_id(1)
|
||||||
|
|
||||||
|
if pid_m >= M:
|
||||||
|
return
|
||||||
|
|
||||||
|
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||||
|
n_mask = offs_n < N
|
||||||
|
|
||||||
|
acc = tl.zeros([BLOCK_N], dtype=tl.float32)
|
||||||
|
|
||||||
|
for k in tl.static_range(TOPK):
|
||||||
|
src_idx = tl.load(index_ptr + pid_m * TOPK + k).to(tl.int64)
|
||||||
|
x_ptrs = x_ptr + src_idx * N + offs_n
|
||||||
|
x_vals = tl.load(x_ptrs, mask=n_mask, other=0.0).to(tl.float32)
|
||||||
|
acc += x_vals
|
||||||
|
|
||||||
|
out_ptrs = out_ptr + pid_m * N + offs_n
|
||||||
|
tl.store(out_ptrs, acc.to(out_ptr.dtype.element_ty), mask=n_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def moe_gather(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Gather and reduce expert outputs back to original token positions.
|
||||||
|
|
||||||
|
For each token i, sums x[index[i, k]] over all topk slots.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: (M * topk, N) expert outputs in sorted buffer
|
||||||
|
index: (M, topk) scatter indices (same as used in moe_scatter)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
out: (M, N) gathered output
|
||||||
|
"""
|
||||||
|
M, topk = index.shape
|
||||||
|
N = x.shape[1]
|
||||||
|
out = torch.empty(M, N, dtype=x.dtype, device=x.device)
|
||||||
|
|
||||||
|
BLOCK_N = min(triton.next_power_of_2(N), 1024)
|
||||||
|
grid = (M, triton.cdiv(N, BLOCK_N))
|
||||||
|
|
||||||
|
_moe_gather_kernel[grid](
|
||||||
|
x_ptr=x,
|
||||||
|
out_ptr=out,
|
||||||
|
index_ptr=index,
|
||||||
|
M=M,
|
||||||
|
N=N,
|
||||||
|
TOPK=topk,
|
||||||
|
BLOCK_N=BLOCK_N,
|
||||||
|
)
|
||||||
|
return out
|
||||||
@@ -381,7 +381,7 @@ class FSDP2Engine:
|
|||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
||||||
if isinstance(grad_norm, torch.distributed._tensor.DTensor):
|
if isinstance(grad_norm, torch.distributed.tensor.DTensor):
|
||||||
grad_norm = grad_norm.full_tensor()
|
grad_norm = grad_norm.full_tensor()
|
||||||
|
|
||||||
for param in model.parameters():
|
for param in model.parameters():
|
||||||
|
|||||||
@@ -78,14 +78,14 @@ def save_model_deepspeed(model: HFModel, output_dir: str, processor: Processor)
|
|||||||
|
|
||||||
|
|
||||||
@DistributedPlugin("deepspeed").register("save_checkpoint")
|
@DistributedPlugin("deepspeed").register("save_checkpoint")
|
||||||
def save_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str) -> None:
|
def save_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
|
||||||
from .deepspeed import save_checkpoint
|
from .deepspeed import save_checkpoint
|
||||||
|
|
||||||
return save_checkpoint(model, optimizer, ckpt_dir)
|
return save_checkpoint(model, optimizer, ckpt_dir, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@DistributedPlugin("deepspeed").register("load_checkpoint")
|
@DistributedPlugin("deepspeed").register("load_checkpoint")
|
||||||
def load_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str) -> None:
|
def load_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
|
||||||
from .deepspeed import load_checkpoint
|
from .deepspeed import load_checkpoint
|
||||||
|
|
||||||
return load_checkpoint(model, optimizer, ckpt_dir)
|
return load_checkpoint(model, optimizer, ckpt_dir, **kwargs)
|
||||||
|
|||||||
@@ -0,0 +1,183 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from ..accelerator.interface import Dim, DistributedInterface
|
||||||
|
from ..config import InputArgument, TrainingArguments, get_args
|
||||||
|
from ..config.arg_utils import ModelClass
|
||||||
|
from ..core.base_trainer import BaseTrainer
|
||||||
|
from ..core.data_engine import DataEngine
|
||||||
|
from ..core.model_engine import ModelEngine
|
||||||
|
from ..utils import logging
|
||||||
|
from ..utils.types import BatchInput, HFModel, Tensor
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_rm_dataset_format(train_dataset: DataEngine, dataset_path: str) -> None:
|
||||||
|
"""Validate RM dataset format early for clearer error messages."""
|
||||||
|
if len(train_dataset) == 0:
|
||||||
|
raise ValueError(f"RM training dataset is empty: {dataset_path}")
|
||||||
|
|
||||||
|
sample = train_dataset[0]
|
||||||
|
if "chosen_messages" in sample and "rejected_messages" in sample:
|
||||||
|
return
|
||||||
|
|
||||||
|
dataset_name = sample.get("_dataset_name", "unknown")
|
||||||
|
sample_keys = sorted(sample.keys())
|
||||||
|
raise ValueError(
|
||||||
|
"RM training requires pair-format samples containing chosen/rejected responses. "
|
||||||
|
f"First sample from dataset '{dataset_name}' has keys: {sample_keys}. "
|
||||||
|
"Please use pair data (e.g. a dataset with chosen_messages/rejected_messages, "
|
||||||
|
"or set converter='pair' for raw chosen/rejected fields)."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _init_score_head(model: HFModel) -> None:
|
||||||
|
"""Initialize the score head for RM training with small Gaussian weights.
|
||||||
|
|
||||||
|
Uses Gaussian initialization so that different parameters have distinct values,
|
||||||
|
providing better gradient flow than zero initialization while keeping initial
|
||||||
|
scores small enough that the starting loss is close to ln(2).
|
||||||
|
"""
|
||||||
|
unwrapped = model.module if hasattr(model, "module") else model
|
||||||
|
score = getattr(unwrapped, "score", None)
|
||||||
|
if score is not None and hasattr(score, "weight"):
|
||||||
|
hidden_size = score.weight.shape[-1]
|
||||||
|
std = 1.0 / (hidden_size * 10)
|
||||||
|
with torch.no_grad():
|
||||||
|
score.weight.normal_(mean=0.0, std=std)
|
||||||
|
if score.bias is not None:
|
||||||
|
score.bias.zero_()
|
||||||
|
logger.info_rank0(f"Initialized score head with Gaussian (std={std:.6f}): {score.weight.shape}")
|
||||||
|
|
||||||
|
|
||||||
|
class RMTrainer(BaseTrainer):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
args: TrainingArguments,
|
||||||
|
model: HFModel,
|
||||||
|
renderer,
|
||||||
|
train_dataset,
|
||||||
|
callbacks=None,
|
||||||
|
) -> None:
|
||||||
|
cp_size = args.dist_config.get("cp_size", 1) if args.dist_config is not None else 1
|
||||||
|
if cp_size > 1:
|
||||||
|
raise NotImplementedError("RM trainer currently only supports cp_size == 1.")
|
||||||
|
|
||||||
|
super().__init__(args, model, renderer, train_dataset, callbacks)
|
||||||
|
|
||||||
|
def _shard_model(self) -> None:
|
||||||
|
if self.args.dist_config is None:
|
||||||
|
if DistributedInterface().get_world_size(Dim.DP) > 1:
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
|
device_ids = None if self.device.type == "cpu" else [self.device.index]
|
||||||
|
self.model = DDP(self.model, device_ids=device_ids, find_unused_parameters=True)
|
||||||
|
else:
|
||||||
|
super()._shard_model()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _unwrapped_model(self):
|
||||||
|
"""Access the underlying model, unwrapping DDP/FSDP wrappers if present."""
|
||||||
|
model = self.model
|
||||||
|
if hasattr(model, "module"):
|
||||||
|
model = model.module
|
||||||
|
return model
|
||||||
|
|
||||||
|
def compute_loss(self, batch: BatchInput) -> Tensor:
|
||||||
|
input_ids = batch["input_ids"].to(self.device, non_blocking=True)
|
||||||
|
|
||||||
|
token_type_ids = batch.get("token_type_ids")
|
||||||
|
if token_type_ids is None:
|
||||||
|
raise ValueError(
|
||||||
|
"RM training requires pair data with token_type_ids. "
|
||||||
|
"Ensure the dataset has chosen_messages/rejected_messages."
|
||||||
|
)
|
||||||
|
token_type_ids = token_type_ids.to(self.device, non_blocking=True)
|
||||||
|
|
||||||
|
# Use token_type_ids as document-index attention mask (values: 1=chosen, 2=rejected, 0=padding).
|
||||||
|
# Transformers v5 models natively support this format in _update_causal_mask,
|
||||||
|
# constructing the correct block-diagonal causal mask internally for all attention backends.
|
||||||
|
model_attention_mask = token_type_ids
|
||||||
|
|
||||||
|
# Build position_ids that reset at each document boundary.
|
||||||
|
batch_size, seq_len = token_type_ids.shape
|
||||||
|
arange = torch.arange(seq_len, device=self.device).unsqueeze(0).expand(batch_size, -1)
|
||||||
|
chosen_mask = token_type_ids == 1
|
||||||
|
rejected_mask = token_type_ids == 2
|
||||||
|
chosen_lens = chosen_mask.sum(dim=1, keepdim=True)
|
||||||
|
position_ids = torch.zeros_like(token_type_ids)
|
||||||
|
position_ids[chosen_mask] = arange[chosen_mask]
|
||||||
|
position_ids[rejected_mask] = (arange - chosen_lens)[rejected_mask]
|
||||||
|
|
||||||
|
model_output = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=model_attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
use_cache=False,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
rewards = model_output.logits.float().squeeze(-1)
|
||||||
|
|
||||||
|
chosen_mask = token_type_ids == 1
|
||||||
|
rejected_mask = token_type_ids == 2
|
||||||
|
|
||||||
|
valid_pair_mask = chosen_mask.any(dim=-1) & rejected_mask.any(dim=-1)
|
||||||
|
if not torch.any(valid_pair_mask):
|
||||||
|
raise ValueError(
|
||||||
|
"No valid RM pairs found in this micro-batch. "
|
||||||
|
"This is usually caused by cutoff_len being too small and truncating chosen/rejected tokens."
|
||||||
|
)
|
||||||
|
|
||||||
|
rewards = rewards[valid_pair_mask]
|
||||||
|
chosen_mask = chosen_mask[valid_pair_mask]
|
||||||
|
rejected_mask = rejected_mask[valid_pair_mask]
|
||||||
|
|
||||||
|
seq_len = rewards.size(-1)
|
||||||
|
position_index = torch.arange(seq_len, device=self.device).unsqueeze(0)
|
||||||
|
chosen_last_idx = (position_index * chosen_mask.long()).max(dim=-1).values
|
||||||
|
rejected_last_idx = (position_index * rejected_mask.long()).max(dim=-1).values
|
||||||
|
|
||||||
|
chosen_scores = rewards.gather(dim=1, index=chosen_last_idx.unsqueeze(-1)).squeeze(-1)
|
||||||
|
rejected_scores = rewards.gather(dim=1, index=rejected_last_idx.unsqueeze(-1)).squeeze(-1)
|
||||||
|
return -F.logsigmoid(chosen_scores - rejected_scores).mean()
|
||||||
|
|
||||||
|
|
||||||
|
def run_rm(args: InputArgument = None):
|
||||||
|
model_args, data_args, training_args, _ = get_args(args)
|
||||||
|
model_args.model_class = ModelClass.CLS
|
||||||
|
DistributedInterface(training_args.dist_config)
|
||||||
|
train_dataset = DataEngine(data_args.train_dataset)
|
||||||
|
_validate_rm_dataset_format(train_dataset, data_args.train_dataset)
|
||||||
|
model_engine = ModelEngine(model_args, is_train=True)
|
||||||
|
_init_score_head(model_engine.model)
|
||||||
|
trainer = RMTrainer(
|
||||||
|
args=training_args,
|
||||||
|
model=model_engine.model,
|
||||||
|
renderer=model_engine.renderer,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
)
|
||||||
|
trainer.fit()
|
||||||
|
trainer.save_model()
|
||||||
|
DistributedInterface().destroy()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run_rm()
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ class LoggingCallback(TrainerCallback):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Human-readable output to stdout
|
# Human-readable output to stdout
|
||||||
display_logs = {**logs, "total_steps": state.num_training_steps}
|
display_logs = {**logs, "step": state.global_step, "total_steps": state.num_training_steps}
|
||||||
parts = ", ".join(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}" for k, v in display_logs.items())
|
parts = ", ".join(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}" for k, v in display_logs.items())
|
||||||
logger.info_rank0(parts)
|
logger.info_rank0(parts)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user