3 Commits

Author SHA1 Message Date
浮梦
2322bf1cc2 [v1] add cuda fused moe kernel, implementing with triton (#10481) 2026-05-20 20:49:42 +08:00
浮梦
368c48968f [callback] add torch profiler callback (#10463) 2026-05-20 20:47:52 +08:00
浮梦
8b5ea65770 [v1] support reward training stage (#10431)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-05-20 20:46:52 +08:00
20 changed files with 1393 additions and 33 deletions

View File

@@ -47,7 +47,13 @@ logger = logging.get_logger(__name__)
check_dependencies() check_dependencies()
_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments] _TRAIN_ARGS = [
ModelArguments,
DataArguments,
TrainingArguments,
FinetuningArguments,
GeneratingArguments,
]
_TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments] _TRAIN_CLS = tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] _INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] _INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
@@ -57,9 +63,19 @@ _EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, Finetuning
if is_mcore_adapter_available() and is_env_enabled("USE_MCA"): if is_mcore_adapter_available() and is_env_enabled("USE_MCA"):
from mcore_adapter import TrainingArguments as McaTrainingArguments from mcore_adapter import TrainingArguments as McaTrainingArguments
_TRAIN_MCA_ARGS = [ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments] _TRAIN_MCA_ARGS = [
ModelArguments,
DataArguments,
McaTrainingArguments,
FinetuningArguments,
GeneratingArguments,
]
_TRAIN_MCA_CLS = tuple[ _TRAIN_MCA_CLS = tuple[
ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments ModelArguments,
DataArguments,
McaTrainingArguments,
FinetuningArguments,
GeneratingArguments,
] ]
else: else:
_TRAIN_MCA_ARGS = [] _TRAIN_MCA_ARGS = []

View File

@@ -14,6 +14,7 @@
import json import json
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional
from transformers import Seq2SeqTrainingArguments from transformers import Seq2SeqTrainingArguments
from transformers.training_args import _convert_str_dict from transformers.training_args import _convert_str_dict
@@ -63,6 +64,58 @@ class RayArguments:
self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs)) self.ray_init_kwargs = _convert_str_dict(json.loads(self.ray_init_kwargs))
@dataclass
class ProfilerArguments:
    r"""Arguments for torch profiler configuration.

    ``enable_torch_profiler`` and the ``profiler_*`` fields configure trace
    collection with ``torch.profiler``; ``profile_modules`` selects modules for
    per-module forward/backward timing. Every field has a default, so profiling
    is off unless explicitly requested.
    """

    # Master switch for torch.profiler trace collection (off by default).
    enable_torch_profiler: bool = field(
        default=False,
        metadata={"help": "Whether to enable torch profiler for collecting performance traces."},
    )
    # Trace directory; ``None`` means "<output_dir>/profiler" (see the help text).
    profiler_output_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Directory to write profiler traces. Defaults to <output_dir>/profiler if not set."},
    )
    # Schedule: steps skipped at the start of each profiling cycle.
    profiler_wait_steps: int = field(
        default=1,
        metadata={"help": "Number of steps to skip at the start of each profiling cycle."},
    )
    # Schedule: warm-up steps per cycle.
    profiler_warmup_steps: int = field(
        default=1,
        metadata={"help": "Number of profiler warm-up steps per cycle."},
    )
    # Schedule: steps actively recorded per cycle.
    profiler_active_steps: int = field(
        default=1,
        metadata={"help": "Number of steps to actively record per cycle."},
    )
    # Schedule: number of wait/warmup/active cycles; 0 means profile continuously.
    profiler_repeat: int = field(
        default=1,
        metadata={"help": "Number of profiling cycles. Set to 0 for continuous profiling."},
    )
    # Whether tensor shapes are recorded in the trace.
    profiler_record_shapes: bool = field(
        default=True,
        metadata={"help": "Whether to record tensor shapes during profiling."},
    )
    # Whether memory usage is profiled.
    profiler_profile_memory: bool = field(
        default=True,
        metadata={"help": "Whether to profile memory usage."},
    )
    # Whether stack traces are recorded.
    profiler_with_stack: bool = field(
        default=True,
        metadata={"help": "Whether to record stack traces during profiling."},
    )
    # Comma-separated fnmatch patterns of module names to time; ``None`` disables it.
    profile_modules: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Comma-separated list of module name patterns to profile with CUDA events. "
                "Supports fnmatch wildcards (e.g. 'model.layers.0.self_attn,model.layers.*.mlp'). "
                "Reports per-module forward/backward timing statistics at each logging step."
            )
        },
    )
@dataclass @dataclass
class Fp8Arguments: class Fp8Arguments:
r"""Arguments pertaining to the FP8 training.""" r"""Arguments pertaining to the FP8 training."""
@@ -87,7 +140,7 @@ class Fp8Arguments:
@dataclass @dataclass
class TrainingArguments(Fp8Arguments, RayArguments, BaseTrainingArguments): class TrainingArguments(ProfilerArguments, Fp8Arguments, RayArguments, BaseTrainingArguments):
r"""Arguments pertaining to the trainer.""" r"""Arguments pertaining to the trainer."""
overwrite_output_dir: bool = field( overwrite_output_dir: bool = field(

View File

@@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import fnmatch
import json import json
import os import os
import signal import signal
import sys import sys
import time import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from datetime import timedelta from datetime import timedelta
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
@@ -31,7 +34,7 @@ from typing_extensions import override
from ..extras import logging from ..extras import logging
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.misc import get_peak_memory, is_env_enabled, use_ray from ..extras.misc import get_peak_memory, is_env_enabled, is_torch_cuda_available, is_torch_npu_available, use_ray
from ..extras.packages import is_safetensors_available from ..extras.packages import is_safetensors_available
@@ -338,6 +341,96 @@ class LogCallback(TrainerCallback):
self.thread_pool.submit(self._write_log, args.output_dir, logs) self.thread_pool.submit(self._write_log, args.output_dir, logs)
class TorchProfilerCallback(TrainerCallback):
    r"""A callback for collecting torch.profiler traces during training.

    Activated by setting ``enable_torch_profiler: true`` in the YAML config.

    Configuration fields (in YAML):
        profiler_output_dir      where to write traces (default: <output_dir>/profiler)
        profiler_wait_steps      steps to skip at start of each cycle (default: 1)
        profiler_warmup_steps    profiler warm-up steps per cycle (default: 1)
        profiler_active_steps    steps to record per cycle (default: 1)
        profiler_repeat          number of cycles; 0 = forever (default: 1)
        profiler_record_shapes   record tensor shapes (default: true)
        profiler_profile_memory  profile memory usage (default: true)
        profiler_with_stack      record stack traces (default: true)

    Trace files (one per rank, Chrome / TensorBoard JSON format) are written to
    ``<profiler_output_dir>/rank_<N>/``.
    """

    def __init__(self, training_args: "TrainingArguments") -> None:
        # The profiler itself is created in ``on_train_begin``; only the
        # configuration is captured here.
        self.profiler = None
        self.profiler_args = training_args

    @staticmethod
    def _get_rank() -> int:
        """Return the distributed rank, or 0 when torch.distributed is not initialized."""
        import torch.distributed as dist

        if dist.is_available() and dist.is_initialized():
            return dist.get_rank()

        return 0

    @staticmethod
    def _collect_activities() -> list:
        """Return the profiler activities to record: CPU plus any detected accelerator.

        Detection stays best-effort (a failure never aborts training), but a
        failure is now reported instead of being silently swallowed, so a
        missing accelerator timeline in the trace can be diagnosed.
        """
        activities = [torch.profiler.ProfilerActivity.CPU]
        try:
            if is_torch_cuda_available():
                activities.append(torch.profiler.ProfilerActivity.CUDA)

            if is_torch_npu_available():
                # Raises AttributeError when this torch build exposes no NPU activity.
                activities.append(torch.profiler.ProfilerActivity.NPU)
        except Exception as exc:
            logger.warning_rank0(
                f"TorchProfiler: accelerator activity detection failed ({exc!r}); "
                f"continuing with activities {activities}."
            )

        return activities

    @override
    def on_train_begin(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
    ) -> None:
        # A profiler left over from a previous train() call must be stopped
        # before a new one is started.
        if self.profiler is not None:
            self.profiler.stop()
            self.profiler = None

        pa = self.profiler_args
        output_dir = pa.profiler_output_dir or os.path.join(args.output_dir, "profiler")
        rank = self._get_rank()
        # One sub-directory per rank so traces from different ranks do not collide.
        trace_dir = os.path.join(output_dir, f"rank_{rank}")
        os.makedirs(trace_dir, exist_ok=True)

        self.profiler = torch.profiler.profile(
            activities=self._collect_activities(),
            schedule=torch.profiler.schedule(
                wait=pa.profiler_wait_steps,
                warmup=pa.profiler_warmup_steps,
                active=pa.profiler_active_steps,
                repeat=pa.profiler_repeat,
            ),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
            record_shapes=pa.profiler_record_shapes,
            profile_memory=pa.profiler_profile_memory,
            with_stack=pa.profiler_with_stack,
        )
        self.profiler.start()
        logger.info_rank0(
            f"TorchProfiler started — schedule: wait={pa.profiler_wait_steps}, warmup={pa.profiler_warmup_steps}, "
            f"active={pa.profiler_active_steps}, repeat={pa.profiler_repeat}. Traces → {output_dir}"
        )

    @override
    def on_step_end(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
    ) -> None:
        # Advance the wait/warmup/active schedule by one training step.
        if self.profiler is not None:
            self.profiler.step()

    @override
    def on_train_end(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
    ) -> None:
        if self.profiler is not None:
            self.profiler.stop()
            self.profiler = None
            logger.info_rank0("TorchProfiler stopped.")
class ReporterCallback(TrainerCallback): class ReporterCallback(TrainerCallback):
r"""A callback for reporting training status to external logger.""" r"""A callback for reporting training status to external logger."""
@@ -394,3 +487,143 @@ class ReporterCallback(TrainerCallback):
"generating_args": self.generating_args.to_dict(), "generating_args": self.generating_args.to_dict(),
} }
) )
class ModuleProfilerCallback(TrainerCallback):
    r"""Profile forward/backward time of specified modules using accelerator events.

    Hooks are registered on modules matching the user-provided name patterns.
    Timing statistics are logged at each trainer logging step.

    Every forward/backward call of a matched module is timed, so several calls
    within one optimizer step (e.g. gradient accumulation) all contribute to
    the reported mean rather than only the last one.

    Usage in YAML config:
        profile_modules: "*.layers.0.self_attn,*.layers.0.mlp"

    Supports fnmatch wildcards:
        profile_modules: "*.layers.*.self_attn,*.layers.*.mlp.experts"
    """

    @staticmethod
    def _get_accelerator():
        """Detect available accelerator and return (event_factory, synchronize_fn)."""
        if is_torch_cuda_available():
            return torch.cuda.Event, torch.cuda.synchronize

        if is_torch_npu_available():
            return torch.npu.Event, torch.npu.synchronize

        return None, None

    def __init__(self, profile_modules: str) -> None:
        # Empty entries (e.g. from a trailing comma) are discarded.
        self.patterns = [p.strip() for p in profile_modules.split(",") if p.strip()]
        self._create_event, self._synchronize = self._get_accelerator()
        self._handles: list[Any] = []
        # Elapsed milliseconds per module name, reset after every log.
        self._forward_times: dict[str, list[float]] = defaultdict(list)
        self._backward_times: dict[str, list[float]] = defaultdict(list)
        # In-flight (start, end) pairs: start recorded, end not yet recorded.
        self._pending_forward: dict[str, tuple] = {}
        self._pending_backward: dict[str, tuple] = {}
        # Finished (name, start, end) triples waiting for the next synchronize.
        self._completed_forward: list[tuple] = []
        self._completed_backward: list[tuple] = []

    @property
    def enabled(self) -> bool:
        """Whether a supported accelerator (and thus an event type) is available."""
        return self._create_event is not None

    def _match(self, name: str) -> bool:
        """Return True when ``name`` matches any configured fnmatch pattern."""
        return any(fnmatch.fnmatch(name, pat) for pat in self.patterns)

    def _begin_timing(self, pending: dict[str, tuple], name: str) -> None:
        """Record a start event for ``name`` and park the (start, end) pair as in-flight."""
        start = self._create_event(enable_timing=True)
        end = self._create_event(enable_timing=True)
        start.record()
        pending[name] = (start, end)

    @staticmethod
    def _finish_timing(pending: dict[str, tuple], completed: list[tuple], name: str) -> None:
        """Record the end event for ``name`` and move the pair to the completed list."""
        pair = pending.pop(name, None)
        if pair is not None:
            pair[1].record()
            completed.append((name, pair[0], pair[1]))

    @staticmethod
    def _drain(completed: list[tuple], times: dict[str, list[float]]) -> None:
        """Convert completed event pairs into elapsed milliseconds and empty the list."""
        for name, start, end in completed:
            times[name].append(start.elapsed_time(end))

        completed.clear()

    def _remove_hooks(self) -> None:
        """Remove every registered hook handle."""
        for handle in self._handles:
            handle.remove()

        self._handles.clear()

    def _make_forward_pre_hook(self, name: str):
        def hook(module, input):
            self._begin_timing(self._pending_forward, name)

        return hook

    def _make_forward_hook(self, name: str):
        def hook(module, input, output):
            self._finish_timing(self._pending_forward, self._completed_forward, name)

        return hook

    def _make_backward_pre_hook(self, name: str):
        def hook(module, grad_output):
            self._begin_timing(self._pending_backward, name)

        return hook

    def _make_backward_hook(self, name: str):
        def hook(module, grad_input, grad_output):
            self._finish_timing(self._pending_backward, self._completed_backward, name)

        return hook

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not self.enabled:
            logger.warning_rank0("ModuleProfiler: no supported accelerator (CUDA/NPU) found, profiling disabled.")
            return

        model = kwargs.get("model")
        if model is None:
            return

        # A second train() call must not stack a duplicate set of hooks on the model.
        self._remove_hooks()

        matched = []
        for name, module in model.named_modules():
            # The root module has an empty name and is never profiled.
            if not name or not self._match(name):
                continue

            self._handles.append(module.register_forward_pre_hook(self._make_forward_pre_hook(name)))
            self._handles.append(module.register_forward_hook(self._make_forward_hook(name)))
            self._handles.append(module.register_full_backward_pre_hook(self._make_backward_pre_hook(name)))
            self._handles.append(module.register_full_backward_hook(self._make_backward_hook(name)))
            matched.append(name)

        if matched:
            logger.info_rank0(
                f"ModuleProfiler: registered hooks on {len(matched)} modules: {matched[:5]}"
                + (f" ... (+{len(matched)-5} more)" if len(matched) > 5 else "")
            )
        else:
            logger.warning_rank0(f"ModuleProfiler: no modules matched patterns {self.patterns}")

    @override
    def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not self.enabled:
            return

        # Elapsed times are only read after a device synchronize.
        self._synchronize()
        self._drain(self._completed_forward, self._forward_times)
        self._drain(self._completed_backward, self._backward_times)
        # Pairs still in flight never had their end event recorded; querying
        # them would raise, so they are dropped instead.
        self._pending_forward.clear()
        self._pending_backward.clear()

    @override
    def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not self._forward_times and not self._backward_times:
            return

        lines = ["[ModuleProfiler] Timing (ms):"]
        for name in sorted(self._forward_times.keys() | self._backward_times.keys()):
            # ``.get`` avoids creating empty entries in the defaultdicts.
            fwd = self._forward_times.get(name, [])
            bwd = self._backward_times.get(name, [])
            fwd_mean = sum(fwd) / len(fwd) if fwd else 0.0
            bwd_mean = sum(bwd) / len(bwd) if bwd else 0.0
            lines.append(f" {name}: fwd={fwd_mean:.3f}, bwd={bwd_mean:.3f}, total={fwd_mean+bwd_mean:.3f}")

        logger.info_rank0("\n".join(lines))
        self._forward_times.clear()
        self._backward_times.clear()

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        self._remove_hooks()
        # Release any accelerator events that were never drained.
        self._pending_forward.clear()
        self._pending_backward.clear()
        self._completed_forward.clear()
        self._completed_backward.clear()

View File

@@ -123,10 +123,10 @@ class CustomDPOTrainer(DPOTrainer):
self.running = RunningMoments(self.accelerator) self.running = RunningMoments(self.accelerator)
@override @override
def create_optimizer(self) -> "torch.optim.Optimizer": def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
if self.optimizer is None: if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args) self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer() return super().create_optimizer(*args, **kwargs)
@override @override
def create_scheduler( def create_scheduler(

View File

@@ -120,10 +120,10 @@ class CustomKTOTrainer(KTOTrainer):
self.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
@override @override
def create_optimizer(self) -> "torch.optim.Optimizer": def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
if self.optimizer is None: if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args) self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer() return super().create_optimizer(*args, **kwargs)
@override @override
def create_scheduler( def create_scheduler(

View File

@@ -69,10 +69,10 @@ class CustomTrainer(Trainer):
verify_fp8_status(self.accelerator, training_args) verify_fp8_status(self.accelerator, training_args)
@override @override
def create_optimizer(self) -> "torch.optim.Optimizer": def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
if self.optimizer is None: if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args) self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer() return super().create_optimizer(*args, **kwargs)
@override @override
def create_scheduler( def create_scheduler(

View File

@@ -65,10 +65,10 @@ class PairwiseTrainer(Trainer):
self.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
@override @override
def create_optimizer(self) -> "torch.optim.Optimizer": def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
if self.optimizer is None: if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args) self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer() return super().create_optimizer(*args, **kwargs)
@override @override
def create_scheduler( def create_scheduler(

View File

@@ -128,10 +128,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
verify_fp8_status(self.accelerator, training_args) verify_fp8_status(self.accelerator, training_args)
@override @override
def create_optimizer(self) -> "torch.optim.Optimizer": def create_optimizer(self, *args, **kwargs) -> "torch.optim.Optimizer":
if self.optimizer is None: if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args) self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer() return super().create_optimizer(*args, **kwargs)
@override @override
def create_scheduler( def create_scheduler(

View File

@@ -32,7 +32,13 @@ from ..extras.packages import (
) )
from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer from ..model import load_model, load_tokenizer
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback from .callbacks import (
LogCallback,
ModuleProfilerCallback,
PissaConvertCallback,
ReporterCallback,
TorchProfilerCallback,
)
from .dpo import run_dpo from .dpo import run_dpo
from .kto import run_kto from .kto import run_kto
from .ppo import run_ppo from .ppo import run_ppo
@@ -74,6 +80,12 @@ def _training_function(config: dict[str, Any]) -> None:
if finetuning_args.early_stopping_steps is not None: if finetuning_args.early_stopping_steps is not None:
callbacks.append(EarlyStoppingCallback(early_stopping_patience=finetuning_args.early_stopping_steps)) callbacks.append(EarlyStoppingCallback(early_stopping_patience=finetuning_args.early_stopping_steps))
if training_args.enable_torch_profiler:
callbacks.append(TorchProfilerCallback(training_args))
if training_args.profile_modules:
callbacks.append(ModuleProfilerCallback(training_args.profile_modules))
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel: if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel:

View File

@@ -134,6 +134,9 @@ class BaseTrainer:
global_step=self.global_step, global_step=self.global_step,
epoch=self._resume_epoch, epoch=self._resume_epoch,
) )
# Keep callback state aligned with checkpoint-resumed trainer counters.
self.state.global_step = self.global_step
self.state.epoch = self._resume_epoch
if self.args.dist_config is not None and self.args.dist_config.get("cp_size", 1) > 1: if self.args.dist_config is not None and self.args.dist_config.get("cp_size", 1) > 1:
# qwen3.5 is not supported because of the different attention implementation, which will be supported in the future. # qwen3.5 is not supported because of the different attention implementation, which will be supported in the future.
@@ -303,7 +306,7 @@ class BaseTrainer:
if self.global_step % self.args.logging_steps == 0: if self.global_step % self.args.logging_steps == 0:
logs = { logs = {
"epoch": epoch, "epoch": epoch,
"step": self.global_step, "step": self.state.global_step,
"loss": step_loss, "loss": step_loss,
"grad_norm": grad_norm, "grad_norm": grad_norm,
"learning_rate": current_lr, "learning_rate": current_lr,
@@ -335,7 +338,9 @@ class BaseTrainer:
) )
else: else:
model_to_save = self.model.module if hasattr(self.model, "module") else self.model model_to_save = self.model.module if hasattr(self.model, "module") else self.model
model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB") model_to_save.save_pretrained(
self.args.output_dir, state_dict=model_to_save.state_dict(), max_shard_size="4GB"
)
self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB") self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
logger.info_rank0(f"Model saved to {self.args.output_dir}") logger.info_rank0(f"Model saved to {self.args.output_dir}")

View File

@@ -143,6 +143,12 @@ class ModelEngine:
elif self.args.model_class == ModelClass.CLS: elif self.args.model_class == ModelClass.CLS:
from transformers import AutoModelForTokenClassification from transformers import AutoModelForTokenClassification
self.model_config.num_labels = 1
self.model_config.classifier_dropout = 0.0
text_config = getattr(self.model_config, "text_config", None)
if text_config is not None:
text_config.num_labels = 1
text_config.classifier_dropout = 0.0
AutoClass = AutoModelForTokenClassification AutoClass = AutoModelForTokenClassification
else: else:
from transformers import AutoModel from transformers import AutoModel

View File

@@ -137,8 +137,8 @@ class BatchGenerator(Iterator):
else: else:
raise NotImplementedError("Iterable dataset is not supported yet.") raise NotImplementedError("Iterable dataset is not supported yet.")
generato_seed = torch.Generator() generator_seed = torch.Generator()
generato_seed.manual_seed(self.seed) generator_seed.manual_seed(self.seed)
self._data_provider = StatefulDataLoader( self._data_provider = StatefulDataLoader(
self.dataset, self.dataset,
@@ -149,7 +149,7 @@ class BatchGenerator(Iterator):
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
pin_memory_device=DistributedInterface().current_device.type, pin_memory_device=DistributedInterface().current_device.type,
drop_last=self.drop_last, drop_last=self.drop_last,
generator=generato_seed, generator=generator_seed,
) )
if self.batching_strategy == BatchingStrategy.NORMAL: if self.batching_strategy == BatchingStrategy.NORMAL:
self._length = len(self._data_provider) self._length = len(self._data_provider)

View File

@@ -172,7 +172,7 @@ def _save_standard_training_states(
if rank == 0: if rank == 0:
model_to_save = model.module if hasattr(model, "module") else model model_to_save = model.module if hasattr(model, "module") else model
model_dir = os.path.join(ckpt_dir, "model") model_dir = os.path.join(ckpt_dir, "model")
model_to_save.save_pretrained(model_dir, max_shard_size="4GB") model_to_save.save_pretrained(model_dir, state_dict=model_to_save.state_dict(), max_shard_size="4GB")
processor.save_pretrained(model_dir) processor.save_pretrained(model_dir)
os.makedirs(os.path.join(ckpt_dir, "optimizer"), exist_ok=True) os.makedirs(os.path.join(ckpt_dir, "optimizer"), exist_ok=True)
@@ -212,7 +212,11 @@ def _load_standard_training_states(
for f in sorted(glob.glob(os.path.join(model_dir, "*.bin"))): for f in sorted(glob.glob(os.path.join(model_dir, "*.bin"))):
state_dict.update(torch.load(f, map_location="cpu", weights_only=True)) state_dict.update(torch.load(f, map_location="cpu", weights_only=True))
if state_dict: if state_dict:
model_to_load.load_state_dict(state_dict) incompatible_keys = model_to_load.load_state_dict(state_dict, strict=False)
if incompatible_keys.missing_keys:
raise RuntimeError(
f"Unexpected missing keys when loading checkpoint model weights: {incompatible_keys.missing_keys}."
)
else: else:
logger.warning_rank0(f"No model weights found in {model_dir}, skipping model state restore.") logger.warning_rank0(f"No model weights found in {model_dir}, skipping model state restore.")

View File

@@ -148,7 +148,9 @@ def launch():
elif command == "dpo": elif command == "dpo":
raise NotImplementedError("DPO trainer is not implemented yet.") raise NotImplementedError("DPO trainer is not implemented yet.")
elif command == "rm": elif command == "rm":
raise NotImplementedError("RM trainer is not implemented yet.") from llamafactory.v1.trainers.rm_trainer import run_rm
run_rm()
else: else:
print(f"Unknown command: {command}.\n{USAGE}") print(f"Unknown command: {command}.\n{USAGE}")
@@ -175,9 +177,9 @@ def main():
# run_dpo() # run_dpo()
raise NotImplementedError("DPO trainer is not implemented yet.") raise NotImplementedError("DPO trainer is not implemented yet.")
elif command == "rm": elif command == "rm":
# from llamafactory.v1.trainers.rm_trainer import run_rm from llamafactory.v1.trainers.rm_trainer import run_rm
# run_rm()
raise NotImplementedError("RM trainer is not implemented yet.") run_rm()
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -0,0 +1,429 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pure-Triton Fused MoE Kernel for NVIDIA GPUs.
Replaces the HuggingFace per-expert Python loop with a fully fused Triton pipeline:
- Forward: scatter → grouped GEMM fc1 → SiLU·gate → apply routing → grouped GEMM fc2 → gather
- Backward: all dX via grouped GEMM, all dW via grouped GEMM (no Python loops)
Supported models: Mixtral, Qwen3-MoE, Qwen3.5-MoE.
"""
import logging
import types
import torch
import torch.nn.functional as F
from ......accelerator.helper import DeviceType
from ......utils.types import HFModel
from ...base import BaseKernel
from ...registry import register_kernel
from .triton_grouped_gemm import (
group_gemm_same_mn,
group_gemm_same_nk,
moe_gather,
moe_scatter,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Autograd Function: Full Triton MoE forward + backward
# ---------------------------------------------------------------------------
class TritonFusedMoeFunction(torch.autograd.Function):
    """Fused MoE expert computation using Triton grouped GEMMs.

    Forward: scatter → fc1 (group GEMM) → SiLU·gate → weight → fc2 (group GEMM) → gather
    Backward: all gradients computed via grouped GEMMs in single kernel launches.
    """

    @staticmethod
    def forward(
        ctx,
        num_experts,
        gate_weights,
        expert_index,
        hidden_states,
        fc1_weight,
        fc2_weight,
    ):
        """Forward pass.

        Args:
            ctx: autograd context
            num_experts: int
            gate_weights: (num_tokens, top_k) routing weights
            expert_index: (num_tokens, top_k) expert assignments
            hidden_states: (num_tokens, hidden_dim)
            fc1_weight: (E, 2*inter, hidden) merged gate+up weight
            fc2_weight: (E, hidden, inter) down projection weight

        Returns:
            The output of ``moe_gather`` applied to the fc2 result, i.e. expert
            outputs gathered back to the original token positions.
        """
        # Compute scatter index: maps (token, topk) → position in sorted buffer.
        # The double argsort turns the stable sort order into a rank per entry.
        scatter_index = expert_index.flatten().argsort(stable=True).argsort().int().view(expert_index.shape)

        # Token counts per expert and cumulative boundaries
        splits = torch.zeros(num_experts, dtype=torch.int32, device=hidden_states.device)
        flat_experts = expert_index.flatten().int()
        splits.scatter_add_(0, flat_experts.long(), torch.ones_like(flat_experts))
        cumsum_t = torch.cumsum(splits, dim=0)

        # Scatter hidden states to sorted expert buffer
        scatter_output = moe_scatter(hidden_states, scatter_index)

        # FC1: grouped GEMM (scatter_output @ fc1_weight.T)
        # ``.item()`` reads the value back to the host; it is done once here and
        # reused by backward through ``ctx.max_M``.
        max_M = int(splits.max().item())
        fc1_output = group_gemm_same_nk(
            a=scatter_output,
            b=fc1_weight,
            cumsum_M=cumsum_t,
            max_M=max_M,
            transpose_b=True,
        )

        # SiLU gate activation
        fc1_1_output, fc1_2_output = fc1_output.chunk(2, dim=-1)
        fc1_1_activation = torch.nn.functional.silu(fc1_1_output)
        fc1_activation = fc1_1_activation * fc1_2_output

        # Apply routing weights before fc2 (mathematically equivalent to after)
        reshaped_gate_weight = gate_weights.reshape(-1, 1)
        scattered_gate_weight = torch.empty_like(reshaped_gate_weight)
        scattered_gate_weight[scatter_index.flatten().long()] = reshaped_gate_weight
        fc1_weighted_output = fc1_activation * scattered_gate_weight

        # FC2: grouped GEMM (fc1_weighted @ fc2_weight.T)
        fc2_output = group_gemm_same_nk(
            a=fc1_weighted_output,
            b=fc2_weight,
            cumsum_M=cumsum_t,
            max_M=max_M,
            transpose_b=True,
        )

        # Gather back to original token positions (sum over topk)
        expert_output = moe_gather(fc2_output, scatter_index)

        ctx.num_experts = num_experts
        # Plain Python int: stored on ctx directly so backward does not have to
        # rebuild the per-expert counts and sync with the device again.
        ctx.max_M = max_M
        ctx.save_for_backward(
            gate_weights,
            fc1_weight,
            fc2_weight,
            hidden_states,
            scatter_index,
            scatter_output,
            cumsum_t,
            fc1_1_output,
            fc1_2_output,
            fc1_activation,
            scattered_gate_weight,
            fc1_weighted_output,
        )
        return expert_output

    @staticmethod
    def backward(ctx, grad_output):
        """Backward pass.

        Args:
            ctx: autograd context populated by ``forward``
            grad_output: gradient w.r.t. the forward output

        Returns:
            A 6-tuple aligned with the ``forward`` inputs:
            ``(None, grad_gate_weights, None, grad_hidden_states, grad_fc1_weight, grad_fc2_weight)``.
            The weight gradients are ``None`` when the weight does not require grad.
        """
        (
            gate_weights,
            fc1_weight,
            fc2_weight,
            hidden_states,
            scatter_index,
            scatter_output,
            cumsum_t,
            fc1_1_output,
            fc1_2_output,
            fc1_activation,
            scattered_gate_weight,
            fc1_weighted_output,
        ) = ctx.saved_tensors

        hidden_dim = grad_output.shape[-1]
        grad_output = grad_output.reshape(-1, hidden_dim).contiguous()

        # Largest per-expert token count, computed once in forward.
        max_M = ctx.max_M

        # Step 1: Scatter grad_output to expert buffer
        grad_fc2_output = moe_scatter(grad_output, scatter_index)

        # Step 2: FC2 backward
        # dX for fc2: grad_fc2_output @ fc2_weight (transpose_b=False since fc2 is (E, hidden, inter))
        grad_fc1_weighted_output = group_gemm_same_nk(
            a=grad_fc2_output,
            b=fc2_weight,
            cumsum_M=cumsum_t,
            max_M=max_M,
            transpose_b=False,
        )

        # dW for fc2: grad_fc2_output.T @ fc1_weighted_output
        grad_fc2_weight = None
        if fc2_weight.requires_grad:
            grad_fc2_weight = torch.empty_like(fc2_weight)
            group_gemm_same_mn(
                a=grad_fc2_output,
                b=fc1_weighted_output,
                c=grad_fc2_weight,
                cumsum_K=cumsum_t,
            )

        # Step 3: Routing weight backward
        grad_fc1_activation = grad_fc1_weighted_output * scattered_gate_weight
        grad_scattered_gate_weight = torch.sum(fc1_activation * grad_fc1_weighted_output, dim=-1)
        # Indexing with scatter_index undoes the scatter applied to the gate weights in forward.
        grad_gate_weight = grad_scattered_gate_weight[scatter_index.flatten().long()]
        grad_gate_weight = grad_gate_weight.reshape(gate_weights.shape)

        # Recompute silu activation for backward
        fc1_1_activation = torch.nn.functional.silu(fc1_1_output)

        # Step 4: SiLU gate backward
        grad_fc1_1_activation = grad_fc1_activation * fc1_2_output
        grad_fc1_2_output = fc1_1_activation * grad_fc1_activation

        # SiLU backward: d/dx[x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
        grad_fc1_1_output = torch.ops.aten.silu_backward(grad_fc1_1_activation, fc1_1_output)

        # Merge fc1 gradients back to (total_M, 2*inter)
        grad_fc1_output = torch.cat([grad_fc1_1_output, grad_fc1_2_output], dim=-1)

        # Step 5: FC1 backward
        # dX for fc1: grad_fc1_output @ fc1_weight (transpose_b=False)
        grad_scatter_output = group_gemm_same_nk(
            a=grad_fc1_output,
            b=fc1_weight,
            cumsum_M=cumsum_t,
            max_M=max_M,
            transpose_b=False,
        )

        # dW for fc1: grad_fc1_output.T @ scatter_output
        grad_fc1_weight = None
        if fc1_weight.requires_grad:
            grad_fc1_weight = torch.empty_like(fc1_weight)
            group_gemm_same_mn(
                a=grad_fc1_output,
                b=scatter_output,
                c=grad_fc1_weight,
                cumsum_K=cumsum_t,
            )

        # Step 6: Gather gradients back to original positions
        grad_hidden_states = moe_gather(grad_scatter_output, scatter_index)
        grad_hidden_states = grad_hidden_states.reshape(hidden_states.shape)

        return (
            None,  # num_experts
            grad_gate_weight,  # gate_weights
            None,  # expert_index
            grad_hidden_states,  # hidden_states
            grad_fc1_weight,  # fc1_weight
            grad_fc2_weight,  # fc2_weight
        )
# ---------------------------------------------------------------------------
# Patched forward functions
# ---------------------------------------------------------------------------
def _triton_moe_experts_forward(
    self,
    hidden_states: torch.Tensor,
    top_k_index: torch.Tensor,
    top_k_weights: torch.Tensor,
) -> torch.Tensor:
    """Forward replacement for expert modules that already hold stacked 3D weights.

    Bound onto a module exposing ``num_experts``, ``gate_up_proj`` and ``down_proj``;
    the whole expert computation is delegated to ``TritonFusedMoeFunction``.

    Args:
        hidden_states: token activations fed to the experts.
        top_k_index: expert ids selected for each token.
        top_k_weights: routing weights matching ``top_k_index``.

    Returns:
        The output of ``TritonFusedMoeFunction.apply``.
    """
    num_experts = self.num_experts
    # Routing weights are cast to the activation dtype before entering the fused op.
    gate_weights = top_k_weights.to(hidden_states.dtype)
    fused_inputs = (
        num_experts,
        gate_weights,
        top_k_index,
        hidden_states,
        self.gate_up_proj,
        self.down_proj,
    )
    return TritonFusedMoeFunction.apply(*fused_inputs)
# ---------------------------------------------------------------------------
# Legacy (transformers < 5.0) support: weight stacking + SparseMoeBlock patch
# ---------------------------------------------------------------------------
class _StackedExpertWeights(torch.nn.Module):
"""Lightweight container holding stacked 3D expert weight tensors."""
def __init__(self, gate_up_proj: torch.Tensor, down_proj: torch.Tensor, num_experts: int):
super().__init__()
self.gate_up_proj = torch.nn.Parameter(gate_up_proj)
self.down_proj = torch.nn.Parameter(down_proj)
self.num_experts = num_experts
def _stack_expert_weights(module: torch.nn.Module) -> None:
    """Replace ``module.experts`` (a list of per-expert MLPs) with stacked 3D parameters.

    Each expert must expose ``gate_proj``, ``up_proj`` and ``down_proj`` sub-modules with a
    ``weight`` tensor. Gate and up weights are concatenated along dim 0 per expert, then
    all experts are stacked along a new leading dimension.

    The trainability of the original weights is carried over: a stacked parameter
    requires grad only if at least one of the weights it was built from did. Without
    this, frozen experts would silently become trainable because a new ``nn.Parameter``
    defaults to ``requires_grad=True``.

    Args:
        module: MoE block whose ``experts`` attribute is replaced in place.
    """
    experts = module.experts
    num_experts = len(experts)

    # Record trainability before the original expert modules are dropped.
    gate_up_trainable = any(
        expert.gate_proj.weight.requires_grad or expert.up_proj.weight.requires_grad for expert in experts
    )
    down_trainable = any(expert.down_proj.weight.requires_grad for expert in experts)

    gate_up_list = [
        torch.cat([expert.gate_proj.weight.data, expert.up_proj.weight.data], dim=0)  # (2*inter, hidden)
        for expert in experts
    ]
    gate_up_proj = torch.stack(gate_up_list, dim=0)  # (E, 2*inter, hidden)
    down_proj = torch.stack([expert.down_proj.weight.data for expert in experts], dim=0)  # (E, hidden, inter)

    stacked = _StackedExpertWeights(gate_up_proj, down_proj, num_experts)
    stacked.gate_up_proj.requires_grad_(gate_up_trainable)
    stacked.down_proj.requires_grad_(down_trainable)
    module.experts = stacked

    logger.info(
        f"cuda_fused_moe: Stacked {num_experts} expert weights into "
        f"gate_up_proj {tuple(gate_up_proj.shape)}, down_proj {tuple(down_proj.shape)}"
    )
def _triton_moe_sparse_block_forward(self, hidden_states: torch.Tensor):
    """Forward replacement for a legacy sparse MoE block; routing is done inline.

    Reads ``gate``, ``top_k``, ``norm_topk_prob``, ``num_experts`` and the stacked
    ``experts.gate_up_proj`` / ``experts.down_proj`` from the bound module.

    Args:
        hidden_states: 3D input unpacked as (batch, sequence, hidden).

    Returns:
        A tuple ``(output, router_logits)`` where ``output`` has the input's 3D shape and
        ``router_logits`` is the raw gate output over the flattened tokens.
    """
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    flat_states = hidden_states.view(-1, hidden_dim)

    # Router: softmax in float32, then keep the top-k experts per token.
    router_logits = self.gate(flat_states)
    expert_probs = F.softmax(router_logits, dim=1, dtype=torch.float)
    topk_weights, topk_experts = torch.topk(expert_probs, self.top_k, dim=-1)
    if self.norm_topk_prob:
        # Renormalise in place so the kept weights sum to one per token.
        topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
    topk_weights = topk_weights.to(flat_states.dtype)

    fused_out = TritonFusedMoeFunction.apply(
        self.num_experts,
        topk_weights,
        topk_experts,
        flat_states,
        self.experts.gate_up_proj,
        self.experts.down_proj,
    )
    return fused_out.reshape(batch_size, sequence_length, hidden_dim), router_logits
# ---------------------------------------------------------------------------
# Module mapping
# ---------------------------------------------------------------------------
# Outer key: model architecture name, matched against ``model.config.architectures``.
# Inner dict: class name of a module to patch -> replacement forward function that
# gets bound onto matching module instances.
_TRITON_MOE_MAPPING: dict[str, dict[str, object]] = {
    "MixtralForCausalLM": {
        "MixtralExperts": _triton_moe_experts_forward,
    },
    "Qwen3MoeForCausalLM": {
        # Stacked-weight experts module and the legacy sparse block are both listed;
        # which one is actually patched depends on the attributes found at patch time.
        "Qwen3MoeExperts": _triton_moe_experts_forward,
        "Qwen3MoeSparseMoeBlock": _triton_moe_sparse_block_forward,
    },
    "Qwen3_5MoeForCausalLM": {
        "Qwen3_5MoeExperts": _triton_moe_experts_forward,
    },
    "Qwen3_5MoeForConditionalGeneration": {
        "Qwen3_5MoeExperts": _triton_moe_experts_forward,
    },
}
# ---------------------------------------------------------------------------
# Kernel registration
# ---------------------------------------------------------------------------
@register_kernel
class CudaFusedMoEKernel(BaseKernel):
    """Fused MoE kernel plugin for CUDA devices, implemented with Triton.

    When applied, the forward of supported MoE modules is swapped for a version that
    routes through ``TritonFusedMoeFunction`` instead of the stock implementation.

    Requires Triton to be importable in addition to the base kernel dependencies.
    """

    _kernel_id = "cuda_fused_moe"
    _device = DeviceType.CUDA

    @classmethod
    def check_deps(cls) -> bool:
        """Return True when the base dependencies are met and Triton can be imported."""
        if not super().check_deps():
            return False

        try:
            import triton  # noqa: F401
        except ImportError:
            logger.info("cuda_fused_moe: Triton not available, kernel disabled.")
            return False

        return True

    @classmethod
    def apply(cls, **kwargs) -> HFModel:
        """Patch the supported MoE modules of ``kwargs["model"]`` in place.

        Args:
            **kwargs: must contain ``model``, the model whose modules are patched.

        Returns:
            The same model object, patched or untouched.

        Raises:
            ValueError: if no ``model`` keyword argument is supplied.
        """
        model = kwargs.get("model")
        if model is None:
            raise ValueError(f"HFModel instance is required for {cls.__name__}.")

        if not cls.check_deps():
            logger.warning("cuda_fused_moe: Dependencies not met. Skipping kernel application.")
            return model

        archs = getattr(model.config, "architectures", None) or []
        # The first declared architecture with a mapping entry wins.
        matched_arch = next((arch for arch in archs if arch in _TRITON_MOE_MAPPING), None)
        if matched_arch is None:
            logger.info(
                f"cuda_fused_moe: Model architecture {archs} not supported. "
                f"Supported: {list(_TRITON_MOE_MAPPING)}"
            )
            return model
        target_mapping = _TRITON_MOE_MAPPING[matched_arch]

        patched_count = 0
        for module in model.modules():
            target_fn = target_mapping.get(module.__class__.__name__)
            if target_fn is None:
                continue

            already_stacked = hasattr(module, "gate_up_proj") and hasattr(module, "down_proj")
            if not already_stacked:
                legacy_block = (
                    hasattr(module, "experts")
                    and isinstance(module.experts, torch.nn.ModuleList)
                    and hasattr(module, "gate")
                )
                if not legacy_block:
                    continue
                # Legacy layout: fold the per-expert modules into stacked tensors first.
                _stack_expert_weights(module)

            module.forward = types.MethodType(target_fn, module)
            patched_count += 1

        if patched_count > 0:
            logger.info(f"cuda_fused_moe: Patched {patched_count} MoE expert modules with pure Triton pipeline.")
        else:
            logger.warning("cuda_fused_moe: No MoE expert modules found to patch.")

        return model

View File

@@ -0,0 +1,417 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Pure-Triton grouped GEMM and MoE scatter/gather kernels.
# Design adapted from VeOmni (ByteDance-Seed/VeOmni) group_gemm kernels.
"""Pure-Triton MoE kernels: grouped GEMM, scatter, and gather.
Provides four kernel types for fused MoE forward+backward without Python loops:
- group_gemm_same_nk: Variable-M grouped GEMM (forward & backward dX)
- group_gemm_same_mn: Variable-K grouped GEMM (backward dW)
- moe_scatter: Token dispatch to sorted expert buffers
- moe_gather: Token reduction from expert buffers
"""
import torch
import triton
import triton.language as tl
# ---------------------------------------------------------------------------
# Triton helper: grouped tile indexing with L2 cache-friendly swizzle
# ---------------------------------------------------------------------------
@triton.jit
def _get_pid_mn(pid, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, GROUP_SIZE: tl.constexpr):
    # Map a linear program id onto a (pid_m, pid_n) tile coordinate of an (M, N)
    # output tiled by BLOCK_M x BLOCK_N. Tiles are visited in bands of GROUP_SIZE
    # tile-rows: consecutive ids walk down the rows of a band before moving to the
    # next tile-column.
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    # Number of program ids consumed by one full band of GROUP_SIZE tile-rows.
    num_pid_in_group = GROUP_SIZE * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE
    # The last band may contain fewer than GROUP_SIZE tile-rows.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n
# ---------------------------------------------------------------------------
# group_gemm_same_nk: All experts share same N, K; variable M per expert
# Used for: forward (x @ W.T) and backward dX (grad @ W)
# ---------------------------------------------------------------------------
# Autotuned over tile shapes; the chosen config is cached per (N, K) pair.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP": 8}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP": 8}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP": 8}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP": 8}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64, "GROUP": 8}, num_warps=8, num_stages=3),
    ],
    key=["N", "K"],
)
@triton.jit
def _group_gemm_same_nk_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    cumsum_M,
    max_M,
    G: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    TRANSPOSE_B: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP: tl.constexpr,
):
    # Grid axis 0 enumerates output tiles (sized for the largest group, max_M rows);
    # grid axis 1 is the group (expert) index.
    pid_m, pid_n = _get_pid_mn(tl.program_id(0), max_M, N, BLOCK_M, BLOCK_N, GROUP)
    gid = tl.program_id(1).to(tl.int64)
    # cumsum_M is an inclusive running total: group gid owns rows
    # [cumsum_M[gid - 1], cumsum_M[gid]), with the first group starting at 0.
    gtid_start = tl.load(cumsum_M + gid - 1, mask=gid > 0, other=0).to(tl.int64)
    gtid_end = tl.load(cumsum_M + gid).to(tl.int64)
    m_size = gtid_end - gtid_start
    # Tiles past this group's row count have nothing to do.
    if pid_m * BLOCK_M >= m_size:
        return
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    # a is (total_M, K) row-major, offset by expert start
    a_base = a_ptr + gtid_start * K
    # b is (G, N, K) if TRANSPOSE_B else (G, K, N)
    b_base = b_ptr + gid * K * N
    # c is (total_M, N) row-major
    c_base = c_ptr + gtid_start * N
    if TRANSPOSE_B:
        # b layout: (G, N, K), we compute a @ b.T = a(M,K) @ b(N,K).T -> (M,N)
        a_ptrs = a_base + offs_m[:, None] * K + offs_k[None, :]
        b_ptrs = b_base + offs_n[:, None] * K + offs_k[None, :]
    else:
        # b layout: (G, K, N), we compute a @ b = a(M,K) @ b(K,N) -> (M,N)
        a_ptrs = a_base + offs_m[:, None] * K + offs_k[None, :]
        b_ptrs = b_base + offs_k[:, None] * N + offs_n[None, :]
    # Accumulate in float32 regardless of the input dtype.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k_start in range(0, K, BLOCK_K):
        k_offs = k_start + offs_k
        # Guards the final partial K block when K is not a multiple of BLOCK_K.
        k_mask = k_offs < K
        a_block = tl.load(a_ptrs, mask=(offs_m[:, None] < m_size) & k_mask[None, :], other=0.0)
        if TRANSPOSE_B:
            b_block = tl.load(b_ptrs, mask=(offs_n[:, None] < N) & k_mask[None, :], other=0.0)
            acc += tl.dot(a_block, tl.trans(b_block))
        else:
            b_block = tl.load(b_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < N), other=0.0)
            acc += tl.dot(a_block, b_block)
        # Advance both operands by one K block; the stride of K in b depends on its layout.
        if TRANSPOSE_B:
            a_ptrs += BLOCK_K
            b_ptrs += BLOCK_K
        else:
            a_ptrs += BLOCK_K
            b_ptrs += BLOCK_K * N
    c_ptrs = c_base + offs_m[:, None] * N + offs_n[None, :]
    # Only rows belonging to this group and columns inside N are written.
    c_mask = (offs_m[:, None] < m_size) & (offs_n[None, :] < N)
    tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)
def group_gemm_same_nk(
    a: torch.Tensor,
    b: torch.Tensor,
    cumsum_M: torch.Tensor,
    max_M: int,
    transpose_b: bool = False,
) -> torch.Tensor:
    """Grouped GEMM where all groups share the same N, K dimensions but have variable M.

    The kernel addresses its operands with raw row-major pointer arithmetic, so the
    inputs are made contiguous here (a no-op when they already are). This matters for
    tensors such as autograd gradients, which are not guaranteed to be contiguous.

    Args:
        a: (total_M, K) input tensor, rows grouped by expert
        b: (G, N, K) if transpose_b else (G, K, N) weight tensor
        cumsum_M: (G,) cumulative token counts per expert
        max_M: maximum tokens any single expert has
        transpose_b: if True, compute a @ b.T; else compute a @ b

    Returns:
        c: (total_M, N) output tensor

    Raises:
        ValueError: if ``a`` is not 2D, ``b`` is not 3D, the K dimensions of ``a`` and
            ``b`` differ, or ``cumsum_M`` holds fewer than G entries.
    """
    if a.dim() != 2 or b.dim() != 3:
        raise ValueError(
            f"group_gemm_same_nk expects a 2D `a` and a 3D `b`, got {tuple(a.shape)} and {tuple(b.shape)}."
        )
    if transpose_b:
        G, N, K = b.shape
    else:
        G, K, N = b.shape
    if a.shape[1] != K:
        raise ValueError(f"group_gemm_same_nk: `a` has K={a.shape[1]} but `b` implies K={K}.")
    if cumsum_M.numel() < G:
        raise ValueError(f"group_gemm_same_nk: `cumsum_M` has {cumsum_M.numel()} entries but `b` has {G} groups.")

    # The kernel assumes dense row-major storage for every operand.
    a = a.contiguous()
    b = b.contiguous()
    cumsum_M = cumsum_M.contiguous()

    c = torch.empty((a.shape[0], N), dtype=a.dtype, device=a.device)
    if a.shape[0] == 0:
        # Nothing to compute; skip the launch (and any autotune benchmarking on empty work).
        return c

    _group_gemm_same_nk_kernel[
        (lambda meta: (triton.cdiv(max_M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]), G))
    ](
        a_ptr=a,
        b_ptr=b,
        c_ptr=c,
        cumsum_M=cumsum_M,
        max_M=max_M,
        G=G,
        N=N,
        K=K,
        TRANSPOSE_B=transpose_b,
    )
    return c
# ---------------------------------------------------------------------------
# group_gemm_same_mn: All experts share same M, N (weight dims); variable K
# Used for: backward dW (grad.T @ input)
# ---------------------------------------------------------------------------
# Autotuned over tile shapes; the chosen config is cached per (M, N) pair.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP": 8}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP": 8}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP": 8}, num_warps=8, num_stages=3),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP": 8}, num_warps=8, num_stages=3),
    ],
    key=["M", "N"],
)
@triton.jit
def _group_gemm_same_mn_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    cumsum_K,
    G: tl.constexpr,
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP: tl.constexpr,
):
    # Grid axis 0 enumerates (M, N) output tiles; grid axis 1 is the group index.
    pid_m, pid_n = _get_pid_mn(tl.program_id(0), M, N, BLOCK_M, BLOCK_N, GROUP)
    gid = tl.program_id(1).to(tl.int64)
    # cumsum_K is an inclusive running total: group gid reduces over rows
    # [cumsum_K[gid - 1], cumsum_K[gid]), with the first group starting at 0.
    gtid_start = tl.load(cumsum_K + gid - 1, mask=gid > 0, other=0).to(tl.int64)
    gtid_end = tl.load(cumsum_K + gid).to(tl.int64)
    k_size = gtid_end - gtid_start
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # c is (G, M, N)
    c_base = c_ptr + gid * M * N
    if k_size == 0:
        # A group with no rows still owns an output slice: write explicit zeros so the
        # pre-allocated (uninitialised) buffer does not leak garbage.
        c_ptrs = c_base + offs_m[:, None] * N + offs_n[None, :]
        c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
        tl.store(c_ptrs, tl.zeros((BLOCK_M, BLOCK_N), dtype=c_ptr.dtype.element_ty), mask=c_mask)
        return
    # Accumulate in float32 regardless of the input dtype.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    offs_k = tl.arange(0, BLOCK_K)
    # a is (total_K, M), compute a.T @ b -> (M, N)
    # b is (total_K, N)
    a_base = a_ptr + gtid_start * M
    b_base = b_ptr + gtid_start * N
    for k_start in range(0, k_size, BLOCK_K):
        k_offs = k_start + offs_k
        # Guards the final partial block of this group's rows.
        k_mask = k_offs < k_size
        # Load a block as (BLOCK_K, BLOCK_M) and transpose it to (BLOCK_M, BLOCK_K).
        a_ptrs = a_base + k_offs[:, None] * M + offs_m[None, :]
        a_block_t = tl.trans(tl.load(a_ptrs, mask=k_mask[:, None] & (offs_m[None, :] < M), other=0.0))
        # Load b block: (BLOCK_K, BLOCK_N)
        b_ptrs = b_base + k_offs[:, None] * N + offs_n[None, :]
        b_block = tl.load(b_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < N), other=0.0)
        acc += tl.dot(a_block_t, b_block)
    c_ptrs = c_base + offs_m[:, None] * N + offs_n[None, :]
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)
def group_gemm_same_mn(
    a: torch.Tensor,
    b: torch.Tensor,
    c: torch.Tensor,
    cumsum_K: torch.Tensor,
) -> None:
    """Grouped GEMM where all groups produce the same (M, N) output; variable K reduction.

    Computes: c[g] = a[s:e].T @ b[s:e] for each group g,
    where s, e are defined by cumsum_K boundaries.

    The kernel addresses every operand with raw row-major pointer arithmetic. Inputs are
    therefore made contiguous (a no-op when they already are), and a non-contiguous
    ``c`` is filled through a contiguous scratch buffer and copied back, so the result
    is correct for any stride layout of the pre-allocated output.

    Args:
        a: (total_K, M) input tensor grouped by expert
        b: (total_K, N) input tensor grouped by expert
        c: (G, M, N) output tensor (pre-allocated), written in place
        cumsum_K: (G,) cumulative token counts per expert

    Raises:
        ValueError: if the operand ranks or shapes are inconsistent with the layout above.
    """
    if a.dim() != 2 or b.dim() != 2 or c.dim() != 3:
        raise ValueError(
            "group_gemm_same_mn expects 2D `a`/`b` and a 3D `c`, got "
            f"{tuple(a.shape)}, {tuple(b.shape)} and {tuple(c.shape)}."
        )
    G, M, N = c.shape
    if a.shape[1] != M or b.shape[1] != N or a.shape[0] != b.shape[0]:
        raise ValueError(
            f"group_gemm_same_mn: shapes a={tuple(a.shape)}, b={tuple(b.shape)} do not match c={tuple(c.shape)}."
        )
    if cumsum_K.numel() < G:
        raise ValueError(f"group_gemm_same_mn: `cumsum_K` has {cumsum_K.numel()} entries but `c` has {G} groups.")

    a = a.contiguous()
    b = b.contiguous()
    cumsum_K = cumsum_K.contiguous()
    # The kernel writes c as dense (G, M, N); use a scratch buffer if c is strided.
    out = c if c.is_contiguous() else torch.empty(c.shape, dtype=c.dtype, device=c.device)

    _group_gemm_same_mn_kernel[(lambda meta: (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]), G))](
        a_ptr=a,
        b_ptr=b,
        c_ptr=out,
        cumsum_K=cumsum_K,
        G=G,
        M=M,
        N=N,
    )
    if out is not c:
        c.copy_(out)
# ---------------------------------------------------------------------------
# moe_scatter: Dispatch tokens to sorted expert buffer positions
# ---------------------------------------------------------------------------
@triton.jit
def _moe_scatter_kernel(
    x_ptr,
    out_ptr,
    index_ptr,
    M,
    N: tl.constexpr,
    TOPK: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Scatter: for each token i, copy x[i] to out[index[i, k]] for k in 0..topk-1."""
    # Grid axis 0 is the source row (token); grid axis 1 is the column block.
    pid_m = tl.program_id(0).to(tl.int64)
    pid_n = tl.program_id(1)
    if pid_m >= M:
        return
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    n_mask = offs_n < N
    # Load input row
    x_ptrs = x_ptr + pid_m * N + offs_n
    x_vals = tl.load(x_ptrs, mask=n_mask, other=0.0)
    # Store to each topk destination
    for k in tl.static_range(TOPK):
        # index is read as a dense (M, TOPK) row-major array.
        dst_idx = tl.load(index_ptr + pid_m * TOPK + k).to(tl.int64)
        out_ptrs = out_ptr + dst_idx * N + offs_n
        tl.store(out_ptrs, x_vals, mask=n_mask)
def moe_scatter(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Scatter tokens to sorted expert buffer.

    For each token i and topk slot k, copies x[i] to output[index[i, k]].

    The kernel reads ``x`` and ``index`` as dense row-major arrays, so both are made
    contiguous here (a no-op when they already are). Output rows that no index entry
    points at are left uninitialised, because the buffer comes from ``torch.empty``.

    Args:
        x: (M, N) input hidden states
        index: (M, topk) scatter indices

    Returns:
        out: (M * topk, N) scattered output

    Raises:
        ValueError: if ``x`` or ``index`` is not 2D, or their first dimensions differ.
    """
    if x.dim() != 2 or index.dim() != 2 or index.shape[0] != x.shape[0]:
        raise ValueError(
            f"moe_scatter expects x of shape (M, N) and index of shape (M, topk), "
            f"got {tuple(x.shape)} and {tuple(index.shape)}."
        )
    x = x.contiguous()
    index = index.contiguous()

    M, N = x.shape
    topk = index.shape[1]
    out = torch.empty(M * topk, N, dtype=x.dtype, device=x.device)
    if out.numel() == 0:
        # No tokens, no slots or no columns: nothing to copy, skip the launch.
        return out

    BLOCK_N = min(triton.next_power_of_2(N), 1024)
    grid = (M, triton.cdiv(N, BLOCK_N))
    _moe_scatter_kernel[grid](
        x_ptr=x,
        out_ptr=out,
        index_ptr=index,
        M=M,
        N=N,
        TOPK=topk,
        BLOCK_N=BLOCK_N,
    )
    return out
# ---------------------------------------------------------------------------
# moe_gather: Reduce expert outputs back to token positions (sum over topk)
# ---------------------------------------------------------------------------
@triton.jit
def _moe_gather_kernel(
    x_ptr,
    out_ptr,
    index_ptr,
    M,
    N: tl.constexpr,
    TOPK: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Gather: for each token i, out[i] = sum_k(x[index[i, k]]) over topk."""
    # Grid axis 0 is the destination row (token); grid axis 1 is the column block.
    pid_m = tl.program_id(0).to(tl.int64)
    pid_n = tl.program_id(1)
    if pid_m >= M:
        return
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    n_mask = offs_n < N
    # Sum the topk source rows in float32, then cast once on store.
    acc = tl.zeros([BLOCK_N], dtype=tl.float32)
    for k in tl.static_range(TOPK):
        # index is read as a dense (M, TOPK) row-major array.
        src_idx = tl.load(index_ptr + pid_m * TOPK + k).to(tl.int64)
        x_ptrs = x_ptr + src_idx * N + offs_n
        x_vals = tl.load(x_ptrs, mask=n_mask, other=0.0).to(tl.float32)
        acc += x_vals
    out_ptrs = out_ptr + pid_m * N + offs_n
    tl.store(out_ptrs, acc.to(out_ptr.dtype.element_ty), mask=n_mask)
def moe_gather(x: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Gather and reduce expert outputs back to original token positions.

    For each token i, sums x[index[i, k]] over all topk slots.

    The kernel reads ``x`` and ``index`` as dense row-major arrays, so both are made
    contiguous here (a no-op when they already are). This matters when ``x`` is an
    autograd gradient, which is not guaranteed to be contiguous.

    Args:
        x: (M * topk, N) expert outputs in sorted buffer
        index: (M, topk) scatter indices (same as used in moe_scatter)

    Returns:
        out: (M, N) gathered output

    Raises:
        ValueError: if ``x`` or ``index`` is not 2D.
    """
    if x.dim() != 2 or index.dim() != 2:
        raise ValueError(
            f"moe_gather expects 2D `x` and `index`, got {tuple(x.shape)} and {tuple(index.shape)}."
        )
    x = x.contiguous()
    index = index.contiguous()

    M, topk = index.shape
    N = x.shape[1]
    out = torch.empty(M, N, dtype=x.dtype, device=x.device)
    if out.numel() == 0:
        # No tokens or no columns: nothing to reduce, skip the launch.
        return out

    BLOCK_N = min(triton.next_power_of_2(N), 1024)
    grid = (M, triton.cdiv(N, BLOCK_N))
    _moe_gather_kernel[grid](
        x_ptr=x,
        out_ptr=out,
        index_ptr=index,
        M=M,
        N=N,
        TOPK=topk,
        BLOCK_N=BLOCK_N,
    )
    return out

View File

@@ -381,7 +381,7 @@ class FSDP2Engine:
with torch.no_grad(): with torch.no_grad():
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
if isinstance(grad_norm, torch.distributed._tensor.DTensor): if isinstance(grad_norm, torch.distributed.tensor.DTensor):
grad_norm = grad_norm.full_tensor() grad_norm = grad_norm.full_tensor()
for param in model.parameters(): for param in model.parameters():

View File

@@ -78,14 +78,14 @@ def save_model_deepspeed(model: HFModel, output_dir: str, processor: Processor)
@DistributedPlugin("deepspeed").register("save_checkpoint") @DistributedPlugin("deepspeed").register("save_checkpoint")
def save_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str) -> None: def save_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
from .deepspeed import save_checkpoint from .deepspeed import save_checkpoint
return save_checkpoint(model, optimizer, ckpt_dir) return save_checkpoint(model, optimizer, ckpt_dir, **kwargs)
@DistributedPlugin("deepspeed").register("load_checkpoint") @DistributedPlugin("deepspeed").register("load_checkpoint")
def load_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str) -> None: def load_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
from .deepspeed import load_checkpoint from .deepspeed import load_checkpoint
return load_checkpoint(model, optimizer, ckpt_dir) return load_checkpoint(model, optimizer, ckpt_dir, **kwargs)

View File

@@ -0,0 +1,183 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from ..accelerator.interface import Dim, DistributedInterface
from ..config import InputArgument, TrainingArguments, get_args
from ..config.arg_utils import ModelClass
from ..core.base_trainer import BaseTrainer
from ..core.data_engine import DataEngine
from ..core.model_engine import ModelEngine
from ..utils import logging
from ..utils.types import BatchInput, HFModel, Tensor
logger = logging.get_logger(__name__)
def _validate_rm_dataset_format(train_dataset: DataEngine, dataset_path: str) -> None:
"""Validate RM dataset format early for clearer error messages."""
if len(train_dataset) == 0:
raise ValueError(f"RM training dataset is empty: {dataset_path}")
sample = train_dataset[0]
if "chosen_messages" in sample and "rejected_messages" in sample:
return
dataset_name = sample.get("_dataset_name", "unknown")
sample_keys = sorted(sample.keys())
raise ValueError(
"RM training requires pair-format samples containing chosen/rejected responses. "
f"First sample from dataset '{dataset_name}' has keys: {sample_keys}. "
"Please use pair data (e.g. a dataset with chosen_messages/rejected_messages, "
"or set converter='pair' for raw chosen/rejected fields)."
)
def _init_score_head(model: HFModel) -> None:
    """Re-initialise the reward ``score`` head with small Gaussian weights.

    The weight is drawn from N(0, std^2) with ``std = 1 / (10 * hidden_size)`` and the
    bias, when present, is zeroed. Small non-zero weights give each parameter a distinct
    value while keeping initial scores tiny, so the starting loss sits near ln(2).

    Does nothing when the (unwrapped) model has no ``score`` attribute or the head has
    no ``weight``.

    Args:
        model: the model, optionally wrapped in an object exposing ``.module``.
    """
    inner_model = getattr(model, "module", model)
    head = getattr(inner_model, "score", None)
    if head is None or not hasattr(head, "weight"):
        return

    hidden_size = head.weight.shape[-1]
    std = 1.0 / (hidden_size * 10)
    # In-place initialisation of parameters must not be recorded by autograd.
    with torch.no_grad():
        head.weight.normal_(mean=0.0, std=std)
        if head.bias is not None:
            head.bias.zero_()

    logger.info_rank0(f"Initialized score head with Gaussian (std={std:.6f}): {head.weight.shape}")
class RMTrainer(BaseTrainer):
    """Trainer for pairwise reward-model training.

    Each sequence packs a chosen and a rejected response, told apart by
    ``token_type_ids`` (1 = chosen, 2 = rejected, 0 = padding). The loss is the mean
    negative log-sigmoid of the score difference taken at the last token of each response.
    """

    def __init__(
        self,
        args: TrainingArguments,
        model: HFModel,
        renderer,
        train_dataset,
        callbacks=None,
    ) -> None:
        """Reject unsupported context-parallel configs, then defer to the base trainer.

        Raises:
            NotImplementedError: if ``args.dist_config`` requests ``cp_size`` > 1.
        """
        dist_config = args.dist_config
        cp_size = 1 if dist_config is None else dist_config.get("cp_size", 1)
        if cp_size > 1:
            raise NotImplementedError("RM trainer currently only supports cp_size == 1.")

        super().__init__(args, model, renderer, train_dataset, callbacks)

    def _shard_model(self) -> None:
        """Wrap the model in DDP when no dist config is given; otherwise use the base logic."""
        if self.args.dist_config is not None:
            super()._shard_model()
            return

        if DistributedInterface().get_world_size(Dim.DP) <= 1:
            return

        from torch.nn.parallel import DistributedDataParallel as DDP

        device_ids = [self.device.index] if self.device.type != "cpu" else None
        self.model = DDP(self.model, device_ids=device_ids, find_unused_parameters=True)

    @property
    def _unwrapped_model(self):
        """The underlying model, looking through a wrapper that exposes ``.module``."""
        wrapped = self.model
        return wrapped.module if hasattr(wrapped, "module") else wrapped

    def compute_loss(self, batch: BatchInput) -> Tensor:
        """Compute the pairwise ranking loss for one micro-batch.

        Args:
            batch: mapping with ``input_ids`` and ``token_type_ids`` tensors.

        Returns:
            Scalar tensor: mean of ``-logsigmoid(chosen_score - rejected_score)`` over the
            rows that contain both a chosen and a rejected segment.

        Raises:
            ValueError: if ``token_type_ids`` is missing, or no row in the batch contains
                both segments.
        """
        input_ids = batch["input_ids"].to(self.device, non_blocking=True)
        token_type_ids = batch.get("token_type_ids")
        if token_type_ids is None:
            raise ValueError(
                "RM training requires pair data with token_type_ids. "
                "Ensure the dataset has chosen_messages/rejected_messages."
            )
        token_type_ids = token_type_ids.to(self.device, non_blocking=True)

        is_chosen = token_type_ids == 1
        is_rejected = token_type_ids == 2

        # Position ids restart at the rejected segment: chosen tokens keep their absolute
        # index, rejected tokens are shifted back by the number of chosen tokens in the row.
        batch_size, seq_len = token_type_ids.shape
        token_pos = torch.arange(seq_len, device=self.device).unsqueeze(0).expand(batch_size, -1)
        num_chosen = is_chosen.sum(dim=1, keepdim=True)
        position_ids = torch.zeros_like(token_type_ids)
        position_ids[is_chosen] = token_pos[is_chosen]
        position_ids[is_rejected] = (token_pos - num_chosen)[is_rejected]

        # token_type_ids doubles as a document-index attention mask (1=chosen, 2=rejected,
        # 0=padding); per the original author's note, Transformers v5 models build the
        # block-diagonal causal mask from this format in _update_causal_mask.
        model_output = self.model(
            input_ids=input_ids,
            attention_mask=token_type_ids,
            position_ids=position_ids,
            use_cache=False,
            return_dict=True,
        )
        rewards = model_output.logits.float().squeeze(-1)

        # Keep only rows where both segments survived (e.g. were not truncated away).
        has_both = is_chosen.any(dim=-1) & is_rejected.any(dim=-1)
        if not torch.any(has_both):
            raise ValueError(
                "No valid RM pairs found in this micro-batch. "
                "This is usually caused by cutoff_len being too small and truncating chosen/rejected tokens."
            )
        rewards = rewards[has_both]
        is_chosen = is_chosen[has_both]
        is_rejected = is_rejected[has_both]

        column_index = torch.arange(rewards.size(-1), device=self.device).unsqueeze(0)

        def _score_at_last_token(segment_mask):
            # Largest column index inside the segment = its last token.
            last_idx = (column_index * segment_mask.long()).max(dim=-1).values
            return rewards.gather(dim=1, index=last_idx.unsqueeze(-1)).squeeze(-1)

        chosen_scores = _score_at_last_token(is_chosen)
        rejected_scores = _score_at_last_token(is_rejected)
        return -F.logsigmoid(chosen_scores - rejected_scores).mean()
def run_rm(args: InputArgument = None):
    """Entry point for reward-model training: build data and model, train, save.

    Args:
        args: raw input arguments forwarded to ``get_args``.
    """
    model_args, data_args, training_args, _ = get_args(args)
    # A reward model is loaded through the classification model class.
    model_args.model_class = ModelClass.CLS

    DistributedInterface(training_args.dist_config)

    rm_dataset = DataEngine(data_args.train_dataset)
    _validate_rm_dataset_format(rm_dataset, data_args.train_dataset)

    engine = ModelEngine(model_args, is_train=True)
    _init_score_head(engine.model)

    rm_trainer = RMTrainer(
        args=training_args,
        model=engine.model,
        renderer=engine.renderer,
        train_dataset=rm_dataset,
    )
    rm_trainer.fit()
    rm_trainer.save_model()

    DistributedInterface().destroy()
# Script entry point: run RM training with arguments resolved by run_rm's default.
if __name__ == "__main__":
    run_rm()

View File

@@ -53,7 +53,7 @@ class LoggingCallback(TrainerCallback):
return return
# Human-readable output to stdout # Human-readable output to stdout
display_logs = {**logs, "total_steps": state.num_training_steps} display_logs = {**logs, "step": state.global_step, "total_steps": state.num_training_steps}
parts = ", ".join(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}" for k, v in display_logs.items()) parts = ", ".join(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}" for k, v in display_logs.items())
logger.info_rank0(parts) logger.info_rank0(parts)