From 8b5ea65770396e9c79718d9ffbd7c7de14b37ed7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B5=AE=E6=A2=A6?= <46097299+frozenleaves@users.noreply.github.com> Date: Wed, 20 May 2026 20:46:52 +0800 Subject: [PATCH] [v1] support reward training stage (#10431) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/llamafactory/v1/core/base_trainer.py | 9 +- src/llamafactory/v1/core/model_engine.py | 6 + src/llamafactory/v1/core/utils/batching.py | 6 +- src/llamafactory/v1/core/utils/checkpoint.py | 8 +- src/llamafactory/v1/launcher.py | 10 +- .../trainer_plugins/distributed/fsdp2.py | 2 +- .../trainer_plugins/distributed/hub.py | 8 +- src/llamafactory/v1/trainers/rm_trainer.py | 183 ++++++++++++++++++ .../v1/utils/callbacks/logging_callback.py | 2 +- 9 files changed, 217 insertions(+), 17 deletions(-) diff --git a/src/llamafactory/v1/core/base_trainer.py b/src/llamafactory/v1/core/base_trainer.py index 97289c94a..ff1a60539 100644 --- a/src/llamafactory/v1/core/base_trainer.py +++ b/src/llamafactory/v1/core/base_trainer.py @@ -134,6 +134,9 @@ class BaseTrainer: global_step=self.global_step, epoch=self._resume_epoch, ) + # Keep callback state aligned with checkpoint-resumed trainer counters. + self.state.global_step = self.global_step + self.state.epoch = self._resume_epoch if self.args.dist_config is not None and self.args.dist_config.get("cp_size", 1) > 1: # qwen3.5 is not supported because of the different attention implementation, which will be supported in the future. @@ -303,7 +306,7 @@ class BaseTrainer: if self.global_step % self.args.logging_steps == 0: logs = { "epoch": epoch, - "step": self.global_step, + "step": self.state.global_step, "loss": step_loss, "grad_norm": grad_norm, "learning_rate": current_lr, @@ -335,7 +338,9 @@ class BaseTrainer: ) else: model_to_save = self.model.module if hasattr(self.model, "module") else self.model - model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB") + model_to_save.save_pretrained( + self.args.output_dir, state_dict=model_to_save.state_dict(), max_shard_size="4GB" + ) self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB") logger.info_rank0(f"Model saved to {self.args.output_dir}") diff --git a/src/llamafactory/v1/core/model_engine.py b/src/llamafactory/v1/core/model_engine.py index 8e16e5363..629e76f5b 100644 --- a/src/llamafactory/v1/core/model_engine.py +++ b/src/llamafactory/v1/core/model_engine.py @@ -143,6 +143,12 @@ class ModelEngine: elif self.args.model_class == ModelClass.CLS: from transformers import AutoModelForTokenClassification + self.model_config.num_labels = 1 + self.model_config.classifier_dropout = 0.0 + text_config = getattr(self.model_config, "text_config", None) + if text_config is not None: + text_config.num_labels = 1 + text_config.classifier_dropout = 0.0 AutoClass = AutoModelForTokenClassification else: from transformers import AutoModel diff --git a/src/llamafactory/v1/core/utils/batching.py b/src/llamafactory/v1/core/utils/batching.py index 25a626b94..3b1ea5c8f 100644 --- a/src/llamafactory/v1/core/utils/batching.py +++ b/src/llamafactory/v1/core/utils/batching.py @@ -137,8 +137,8 @@ class BatchGenerator(Iterator): else: raise NotImplementedError("Iterable dataset is not supported yet.") - generato_seed = torch.Generator() - generato_seed.manual_seed(self.seed) + generator_seed = torch.Generator() + generator_seed.manual_seed(self.seed) self._data_provider = StatefulDataLoader( self.dataset, @@ -149,7 +149,7 @@ class BatchGenerator(Iterator): pin_memory=self.pin_memory, pin_memory_device=DistributedInterface().current_device.type, drop_last=self.drop_last, - generator=generato_seed, + generator=generator_seed, ) if self.batching_strategy == BatchingStrategy.NORMAL: self._length = len(self._data_provider) diff --git a/src/llamafactory/v1/core/utils/checkpoint.py b/src/llamafactory/v1/core/utils/checkpoint.py index b9cfcd911..40ed4f01c 100644 --- a/src/llamafactory/v1/core/utils/checkpoint.py +++ b/src/llamafactory/v1/core/utils/checkpoint.py @@ -172,7 +172,7 @@ def _save_standard_training_states( if rank == 0: model_to_save = model.module if hasattr(model, "module") else model model_dir = os.path.join(ckpt_dir, "model") - model_to_save.save_pretrained(model_dir, max_shard_size="4GB") + model_to_save.save_pretrained(model_dir, state_dict=model_to_save.state_dict(), max_shard_size="4GB") processor.save_pretrained(model_dir) os.makedirs(os.path.join(ckpt_dir, "optimizer"), exist_ok=True) @@ -212,7 +212,11 @@ def _load_standard_training_states( for f in sorted(glob.glob(os.path.join(model_dir, "*.bin"))): state_dict.update(torch.load(f, map_location="cpu", weights_only=True)) if state_dict: - model_to_load.load_state_dict(state_dict) + incompatible_keys = model_to_load.load_state_dict(state_dict, strict=False) + if incompatible_keys.missing_keys: + raise RuntimeError( + f"Unexpected missing keys when loading checkpoint model weights: {incompatible_keys.missing_keys}." + ) else: logger.warning_rank0(f"No model weights found in {model_dir}, skipping model state restore.") diff --git a/src/llamafactory/v1/launcher.py b/src/llamafactory/v1/launcher.py index 4d78787d7..f0481d4e7 100644 --- a/src/llamafactory/v1/launcher.py +++ b/src/llamafactory/v1/launcher.py @@ -148,7 +148,9 @@ def launch(): elif command == "dpo": raise NotImplementedError("DPO trainer is not implemented yet.") elif command == "rm": - raise NotImplementedError("RM trainer is not implemented yet.") + from llamafactory.v1.trainers.rm_trainer import run_rm + + run_rm() else: print(f"Unknown command: {command}.\n{USAGE}") @@ -175,9 +177,9 @@ def main(): # run_dpo() raise NotImplementedError("DPO trainer is not implemented yet.") elif command == "rm": - # from llamafactory.v1.trainers.rm_trainer import run_rm - # run_rm() - raise NotImplementedError("RM trainer is not implemented yet.") + from llamafactory.v1.trainers.rm_trainer import run_rm + + run_rm() if __name__ == "__main__": diff --git a/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py b/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py index 88bdb4f4e..4fbb5b61f 100644 --- a/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py +++ b/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py @@ -381,7 +381,7 @@ class FSDP2Engine: with torch.no_grad(): grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) - if isinstance(grad_norm, torch.distributed._tensor.DTensor): + if isinstance(grad_norm, torch.distributed.tensor.DTensor): grad_norm = grad_norm.full_tensor() for param in model.parameters(): diff --git a/src/llamafactory/v1/plugins/trainer_plugins/distributed/hub.py b/src/llamafactory/v1/plugins/trainer_plugins/distributed/hub.py index f7389b28d..2a4a9e392 100644 --- a/src/llamafactory/v1/plugins/trainer_plugins/distributed/hub.py +++ b/src/llamafactory/v1/plugins/trainer_plugins/distributed/hub.py @@ -78,14 +78,14 @@ def save_model_deepspeed(model: HFModel, output_dir: str, processor: Processor) @DistributedPlugin("deepspeed").register("save_checkpoint") -def save_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str) -> None: +def save_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None: from .deepspeed import save_checkpoint - return save_checkpoint(model, optimizer, ckpt_dir) + return save_checkpoint(model, optimizer, ckpt_dir, **kwargs) @DistributedPlugin("deepspeed").register("load_checkpoint") -def load_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str) -> None: +def load_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None: from .deepspeed import load_checkpoint - return load_checkpoint(model, optimizer, ckpt_dir) + return load_checkpoint(model, optimizer, ckpt_dir, **kwargs) diff --git a/src/llamafactory/v1/trainers/rm_trainer.py b/src/llamafactory/v1/trainers/rm_trainer.py index e69de29bb..6c78fa6fa 100644 --- a/src/llamafactory/v1/trainers/rm_trainer.py +++ b/src/llamafactory/v1/trainers/rm_trainer.py @@ -0,0 +1,183 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn.functional as F + +from ..accelerator.interface import Dim, DistributedInterface +from ..config import InputArgument, TrainingArguments, get_args +from ..config.arg_utils import ModelClass +from ..core.base_trainer import BaseTrainer +from ..core.data_engine import DataEngine +from ..core.model_engine import ModelEngine +from ..utils import logging +from ..utils.types import BatchInput, HFModel, Tensor + + +logger = logging.get_logger(__name__) + + +def _validate_rm_dataset_format(train_dataset: DataEngine, dataset_path: str) -> None: + """Validate RM dataset format early for clearer error messages.""" + if len(train_dataset) == 0: + raise ValueError(f"RM training dataset is empty: {dataset_path}") + + sample = train_dataset[0] + if "chosen_messages" in sample and "rejected_messages" in sample: + return + + dataset_name = sample.get("_dataset_name", "unknown") + sample_keys = sorted(sample.keys()) + raise ValueError( + "RM training requires pair-format samples containing chosen/rejected responses. " + f"First sample from dataset '{dataset_name}' has keys: {sample_keys}. " + "Please use pair data (e.g. a dataset with chosen_messages/rejected_messages, " + "or set converter='pair' for raw chosen/rejected fields)." + ) + + +def _init_score_head(model: HFModel) -> None: + """Initialize the score head for RM training with small Gaussian weights. + + Uses Gaussian initialization so that different parameters have distinct values, + providing better gradient flow than zero initialization while keeping initial + scores small enough that the starting loss is close to ln(2). + """ + unwrapped = model.module if hasattr(model, "module") else model + score = getattr(unwrapped, "score", None) + if score is not None and hasattr(score, "weight"): + hidden_size = score.weight.shape[-1] + std = 1.0 / (hidden_size * 10) + with torch.no_grad(): + score.weight.normal_(mean=0.0, std=std) + if score.bias is not None: + score.bias.zero_() + logger.info_rank0(f"Initialized score head with Gaussian (std={std:.6f}): {score.weight.shape}") + + +class RMTrainer(BaseTrainer): + def __init__( + self, + args: TrainingArguments, + model: HFModel, + renderer, + train_dataset, + callbacks=None, + ) -> None: + cp_size = args.dist_config.get("cp_size", 1) if args.dist_config is not None else 1 + if cp_size > 1: + raise NotImplementedError("RM trainer currently only supports cp_size == 1.") + + super().__init__(args, model, renderer, train_dataset, callbacks) + + def _shard_model(self) -> None: + if self.args.dist_config is None: + if DistributedInterface().get_world_size(Dim.DP) > 1: + from torch.nn.parallel import DistributedDataParallel as DDP + + device_ids = None if self.device.type == "cpu" else [self.device.index] + self.model = DDP(self.model, device_ids=device_ids, find_unused_parameters=True) + else: + super()._shard_model() + + @property + def _unwrapped_model(self): + """Access the underlying model, unwrapping DDP/FSDP wrappers if present.""" + model = self.model + if hasattr(model, "module"): + model = model.module + return model + + def compute_loss(self, batch: BatchInput) -> Tensor: + input_ids = batch["input_ids"].to(self.device, non_blocking=True) + + token_type_ids = batch.get("token_type_ids") + if token_type_ids is None: + raise ValueError( + "RM training requires pair data with token_type_ids. " + "Ensure the dataset has chosen_messages/rejected_messages." + ) + token_type_ids = token_type_ids.to(self.device, non_blocking=True) + + # Use token_type_ids as document-index attention mask (values: 1=chosen, 2=rejected, 0=padding). + # Transformers v5 models natively support this format in _update_causal_mask, + # constructing the correct block-diagonal causal mask internally for all attention backends. + model_attention_mask = token_type_ids + + # Build position_ids that reset at each document boundary. + batch_size, seq_len = token_type_ids.shape + arange = torch.arange(seq_len, device=self.device).unsqueeze(0).expand(batch_size, -1) + chosen_mask = token_type_ids == 1 + rejected_mask = token_type_ids == 2 + chosen_lens = chosen_mask.sum(dim=1, keepdim=True) + position_ids = torch.zeros_like(token_type_ids) + position_ids[chosen_mask] = arange[chosen_mask] + position_ids[rejected_mask] = (arange - chosen_lens)[rejected_mask] + + model_output = self.model( + input_ids=input_ids, + attention_mask=model_attention_mask, + position_ids=position_ids, + use_cache=False, + return_dict=True, + ) + + rewards = model_output.logits.float().squeeze(-1) + + chosen_mask = token_type_ids == 1 + rejected_mask = token_type_ids == 2 + + valid_pair_mask = chosen_mask.any(dim=-1) & rejected_mask.any(dim=-1) + if not torch.any(valid_pair_mask): + raise ValueError( + "No valid RM pairs found in this micro-batch. " + "This is usually caused by cutoff_len being too small and truncating chosen/rejected tokens." + ) + + rewards = rewards[valid_pair_mask] + chosen_mask = chosen_mask[valid_pair_mask] + rejected_mask = rejected_mask[valid_pair_mask] + + seq_len = rewards.size(-1) + position_index = torch.arange(seq_len, device=self.device).unsqueeze(0) + chosen_last_idx = (position_index * chosen_mask.long()).max(dim=-1).values + rejected_last_idx = (position_index * rejected_mask.long()).max(dim=-1).values + + chosen_scores = rewards.gather(dim=1, index=chosen_last_idx.unsqueeze(-1)).squeeze(-1) + rejected_scores = rewards.gather(dim=1, index=rejected_last_idx.unsqueeze(-1)).squeeze(-1) + return -F.logsigmoid(chosen_scores - rejected_scores).mean() + + +def run_rm(args: InputArgument = None): + model_args, data_args, training_args, _ = get_args(args) + model_args.model_class = ModelClass.CLS + DistributedInterface(training_args.dist_config) + train_dataset = DataEngine(data_args.train_dataset) + _validate_rm_dataset_format(train_dataset, data_args.train_dataset) + model_engine = ModelEngine(model_args, is_train=True) + _init_score_head(model_engine.model) + trainer = RMTrainer( + args=training_args, + model=model_engine.model, + renderer=model_engine.renderer, + train_dataset=train_dataset, + ) + trainer.fit() + trainer.save_model() + DistributedInterface().destroy() + + +if __name__ == "__main__": + run_rm() diff --git a/src/llamafactory/v1/utils/callbacks/logging_callback.py b/src/llamafactory/v1/utils/callbacks/logging_callback.py index d6bdba604..f1674cfbb 100644 --- a/src/llamafactory/v1/utils/callbacks/logging_callback.py +++ b/src/llamafactory/v1/utils/callbacks/logging_callback.py @@ -53,7 +53,7 @@ class LoggingCallback(TrainerCallback): return # Human-readable output to stdout - display_logs = {**logs, "total_steps": state.num_training_steps} + display_logs = {**logs, "step": state.global_step, "total_steps": state.num_training_steps} parts = ", ".join(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}" for k, v in display_logs.items()) logger.info_rank0(parts)