[v1] support reward training stage (#10431)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
浮梦
2026-05-20 20:46:52 +08:00
committed by GitHub
parent 40e786d016
commit 8b5ea65770
9 changed files with 217 additions and 17 deletions

View File

@@ -134,6 +134,9 @@ class BaseTrainer:
global_step=self.global_step,
epoch=self._resume_epoch,
)
# Keep callback state aligned with checkpoint-resumed trainer counters.
self.state.global_step = self.global_step
self.state.epoch = self._resume_epoch
if self.args.dist_config is not None and self.args.dist_config.get("cp_size", 1) > 1:
# Qwen3.5 is not supported yet because it uses a different attention implementation; support is planned.
@@ -303,7 +306,7 @@ class BaseTrainer:
if self.global_step % self.args.logging_steps == 0:
logs = {
"epoch": epoch,
"step": self.global_step,
"step": self.state.global_step,
"loss": step_loss,
"grad_norm": grad_norm,
"learning_rate": current_lr,
@@ -335,7 +338,9 @@ class BaseTrainer:
)
else:
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB")
model_to_save.save_pretrained(
self.args.output_dir, state_dict=model_to_save.state_dict(), max_shard_size="4GB"
)
self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
logger.info_rank0(f"Model saved to {self.args.output_dir}")

View File

@@ -143,6 +143,12 @@ class ModelEngine:
elif self.args.model_class == ModelClass.CLS:
from transformers import AutoModelForTokenClassification
self.model_config.num_labels = 1
self.model_config.classifier_dropout = 0.0
text_config = getattr(self.model_config, "text_config", None)
if text_config is not None:
text_config.num_labels = 1
text_config.classifier_dropout = 0.0
AutoClass = AutoModelForTokenClassification
else:
from transformers import AutoModel

View File

@@ -137,8 +137,8 @@ class BatchGenerator(Iterator):
else:
raise NotImplementedError("Iterable dataset is not supported yet.")
generato_seed = torch.Generator()
generato_seed.manual_seed(self.seed)
generator_seed = torch.Generator()
generator_seed.manual_seed(self.seed)
self._data_provider = StatefulDataLoader(
self.dataset,
@@ -149,7 +149,7 @@ class BatchGenerator(Iterator):
pin_memory=self.pin_memory,
pin_memory_device=DistributedInterface().current_device.type,
drop_last=self.drop_last,
generator=generato_seed,
generator=generator_seed,
)
if self.batching_strategy == BatchingStrategy.NORMAL:
self._length = len(self._data_provider)

View File

@@ -172,7 +172,7 @@ def _save_standard_training_states(
if rank == 0:
model_to_save = model.module if hasattr(model, "module") else model
model_dir = os.path.join(ckpt_dir, "model")
model_to_save.save_pretrained(model_dir, max_shard_size="4GB")
model_to_save.save_pretrained(model_dir, state_dict=model_to_save.state_dict(), max_shard_size="4GB")
processor.save_pretrained(model_dir)
os.makedirs(os.path.join(ckpt_dir, "optimizer"), exist_ok=True)
@@ -212,7 +212,11 @@ def _load_standard_training_states(
for f in sorted(glob.glob(os.path.join(model_dir, "*.bin"))):
state_dict.update(torch.load(f, map_location="cpu", weights_only=True))
if state_dict:
model_to_load.load_state_dict(state_dict)
incompatible_keys = model_to_load.load_state_dict(state_dict, strict=False)
if incompatible_keys.missing_keys:
raise RuntimeError(
f"Unexpected missing keys when loading checkpoint model weights: {incompatible_keys.missing_keys}."
)
else:
logger.warning_rank0(f"No model weights found in {model_dir}, skipping model state restore.")

View File

@@ -148,7 +148,9 @@ def launch():
elif command == "dpo":
raise NotImplementedError("DPO trainer is not implemented yet.")
elif command == "rm":
raise NotImplementedError("RM trainer is not implemented yet.")
from llamafactory.v1.trainers.rm_trainer import run_rm
run_rm()
else:
print(f"Unknown command: {command}.\n{USAGE}")
@@ -175,9 +177,9 @@ def main():
# run_dpo()
raise NotImplementedError("DPO trainer is not implemented yet.")
elif command == "rm":
# from llamafactory.v1.trainers.rm_trainer import run_rm
# run_rm()
raise NotImplementedError("RM trainer is not implemented yet.")
from llamafactory.v1.trainers.rm_trainer import run_rm
run_rm()
if __name__ == "__main__":

View File

@@ -381,7 +381,7 @@ class FSDP2Engine:
with torch.no_grad():
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
if isinstance(grad_norm, torch.distributed._tensor.DTensor):
if isinstance(grad_norm, torch.distributed.tensor.DTensor):
grad_norm = grad_norm.full_tensor()
for param in model.parameters():

View File

@@ -78,14 +78,14 @@ def save_model_deepspeed(model: HFModel, output_dir: str, processor: Processor)
@DistributedPlugin("deepspeed").register("save_checkpoint")
def save_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str) -> None:
def save_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
from .deepspeed import save_checkpoint
return save_checkpoint(model, optimizer, ckpt_dir)
return save_checkpoint(model, optimizer, ckpt_dir, **kwargs)
@DistributedPlugin("deepspeed").register("load_checkpoint")
def load_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str) -> None:
def load_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
from .deepspeed import load_checkpoint
return load_checkpoint(model, optimizer, ckpt_dir)
return load_checkpoint(model, optimizer, ckpt_dir, **kwargs)

View File

@@ -0,0 +1,183 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from ..accelerator.interface import Dim, DistributedInterface
from ..config import InputArgument, TrainingArguments, get_args
from ..config.arg_utils import ModelClass
from ..core.base_trainer import BaseTrainer
from ..core.data_engine import DataEngine
from ..core.model_engine import ModelEngine
from ..utils import logging
from ..utils.types import BatchInput, HFModel, Tensor
logger = logging.get_logger(__name__)
def _validate_rm_dataset_format(train_dataset: DataEngine, dataset_path: str) -> None:
"""Validate RM dataset format early for clearer error messages."""
if len(train_dataset) == 0:
raise ValueError(f"RM training dataset is empty: {dataset_path}")
sample = train_dataset[0]
if "chosen_messages" in sample and "rejected_messages" in sample:
return
dataset_name = sample.get("_dataset_name", "unknown")
sample_keys = sorted(sample.keys())
raise ValueError(
"RM training requires pair-format samples containing chosen/rejected responses. "
f"First sample from dataset '{dataset_name}' has keys: {sample_keys}. "
"Please use pair data (e.g. a dataset with chosen_messages/rejected_messages, "
"or set converter='pair' for raw chosen/rejected fields)."
)
def _init_score_head(model: HFModel) -> None:
"""Initialize the score head for RM training with small Gaussian weights.
Uses Gaussian initialization so that different parameters have distinct values,
providing better gradient flow than zero initialization while keeping initial
scores small enough that the starting loss is close to ln(2).
"""
unwrapped = model.module if hasattr(model, "module") else model
score = getattr(unwrapped, "score", None)
if score is not None and hasattr(score, "weight"):
hidden_size = score.weight.shape[-1]
std = 1.0 / (hidden_size * 10)
with torch.no_grad():
score.weight.normal_(mean=0.0, std=std)
if score.bias is not None:
score.bias.zero_()
logger.info_rank0(f"Initialized score head with Gaussian (std={std:.6f}): {score.weight.shape}")
class RMTrainer(BaseTrainer):
    """Trainer for pairwise reward modelling.

    Every batch row is expected to hold one chosen and one rejected response,
    told apart by ``token_type_ids`` (1 = chosen, 2 = rejected, 0 = padding).
    The loss is ``-logsigmoid(r_chosen - r_rejected)`` averaged over the rows
    that contain both responses, where each reward is the model output at the
    last token of the corresponding response.
    """

    def __init__(
        self,
        args: TrainingArguments,
        model: HFModel,
        renderer,
        train_dataset,
        callbacks=None,
    ) -> None:
        """Reject unsupported context-parallel setups, then defer to ``BaseTrainer``.

        Raises:
            NotImplementedError: If ``args.dist_config`` requests ``cp_size > 1``.
        """
        cp_size = args.dist_config.get("cp_size", 1) if args.dist_config is not None else 1
        if cp_size > 1:
            raise NotImplementedError("RM trainer currently only supports cp_size == 1.")

        super().__init__(args, model, renderer, train_dataset, callbacks)

    def _shard_model(self) -> None:
        """Wrap the model for distributed training.

        Without a ``dist_config`` the model is wrapped in DDP (with
        ``find_unused_parameters=True``) when the data-parallel world size is
        greater than one, and left untouched otherwise. With a ``dist_config``
        the base-class implementation is used.
        """
        if self.args.dist_config is None:
            if DistributedInterface().get_world_size(Dim.DP) > 1:
                from torch.nn.parallel import DistributedDataParallel as DDP

                device_ids = None if self.device.type == "cpu" else [self.device.index]
                self.model = DDP(self.model, device_ids=device_ids, find_unused_parameters=True)
        else:
            super()._shard_model()

    @property
    def _unwrapped_model(self):
        """Access the underlying model, unwrapping DDP/FSDP wrappers if present."""
        model = self.model
        if hasattr(model, "module"):
            model = model.module

        return model

    def compute_loss(self, batch: BatchInput) -> Tensor:
        """Compute the pairwise ranking loss for one micro-batch.

        Args:
            batch: Mapping with ``input_ids`` and ``token_type_ids`` tensors of
                shape ``(batch, seq_len)``.

        Returns:
            Scalar tensor: mean of ``-logsigmoid(chosen - rejected)`` over the
            rows that contain both a chosen and a rejected segment.

        Raises:
            ValueError: If ``token_type_ids`` is missing, or if no row of the
                micro-batch contains both a chosen and a rejected token.
        """
        input_ids = batch["input_ids"].to(self.device, non_blocking=True)
        token_type_ids = batch.get("token_type_ids")
        if token_type_ids is None:
            raise ValueError(
                "RM training requires pair data with token_type_ids. "
                "Ensure the dataset has chosen_messages/rejected_messages."
            )

        token_type_ids = token_type_ids.to(self.device, non_blocking=True)
        chosen_mask = token_type_ids == 1
        rejected_mask = token_type_ids == 2

        # The validity of a pair depends only on token_type_ids, so check it
        # before spending a forward pass on an unusable micro-batch.
        valid_pair_mask = chosen_mask.any(dim=-1) & rejected_mask.any(dim=-1)
        if not torch.any(valid_pair_mask):
            raise ValueError(
                "No valid RM pairs found in this micro-batch. "
                "This is usually caused by cutoff_len being too small and truncating chosen/rejected tokens."
            )

        # Build position_ids that reset at each document boundary: chosen tokens
        # keep their absolute index, rejected tokens are shifted back by the
        # number of chosen tokens in the same row.
        batch_size, seq_len = token_type_ids.shape
        position_index = torch.arange(seq_len, device=self.device).unsqueeze(0)
        arange = position_index.expand(batch_size, -1)
        chosen_lens = chosen_mask.sum(dim=1, keepdim=True)
        position_ids = torch.zeros_like(token_type_ids)
        position_ids[chosen_mask] = arange[chosen_mask]
        position_ids[rejected_mask] = (arange - chosen_lens)[rejected_mask]

        # Use token_type_ids as document-index attention mask (values: 1=chosen, 2=rejected, 0=padding).
        # Transformers v5 models natively support this format in _update_causal_mask,
        # constructing the correct block-diagonal causal mask internally for all attention backends.
        model_output = self.model(
            input_ids=input_ids,
            attention_mask=token_type_ids,
            position_ids=position_ids,
            use_cache=False,
            return_dict=True,
        )
        rewards = model_output.logits.float().squeeze(-1)

        # Keep only the rows that still contain both halves of the pair.
        rewards = rewards[valid_pair_mask]
        chosen_mask = chosen_mask[valid_pair_mask]
        rejected_mask = rejected_mask[valid_pair_mask]

        # Index of the last token of each segment: the largest masked position.
        chosen_last_idx = (position_index * chosen_mask.long()).max(dim=-1).values
        rejected_last_idx = (position_index * rejected_mask.long()).max(dim=-1).values
        chosen_scores = rewards.gather(dim=1, index=chosen_last_idx.unsqueeze(-1)).squeeze(-1)
        rejected_scores = rewards.gather(dim=1, index=rejected_last_idx.unsqueeze(-1)).squeeze(-1)
        return -F.logsigmoid(chosen_scores - rejected_scores).mean()
def run_rm(args: InputArgument = None):
    """Run reward-model training end to end.

    Parses the arguments, forces the classification model class, sets up the
    distributed interface, validates the pair dataset, initializes the score
    head, then trains, saves the model and tears the distributed interface
    down.
    """
    model_args, data_args, training_args, _ = get_args(args)

    # Reward models are always loaded through the classification model class.
    model_args.model_class = ModelClass.CLS
    DistributedInterface(training_args.dist_config)

    dataset_path = data_args.train_dataset
    rm_dataset = DataEngine(dataset_path)
    _validate_rm_dataset_format(rm_dataset, data_args.train_dataset)

    engine = ModelEngine(model_args, is_train=True)
    _init_score_head(engine.model)

    rm_trainer = RMTrainer(
        args=training_args,
        model=engine.model,
        renderer=engine.renderer,
        train_dataset=rm_dataset,
    )
    rm_trainer.fit()
    rm_trainer.save_model()

    DistributedInterface().destroy()


if __name__ == "__main__":
    run_rm()

View File

@@ -53,7 +53,7 @@ class LoggingCallback(TrainerCallback):
return
# Human-readable output to stdout
display_logs = {**logs, "total_steps": state.num_training_steps}
display_logs = {**logs, "step": state.global_step, "total_steps": state.num_training_steps}
parts = ", ".join(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}" for k, v in display_logs.items())
logger.info_rank0(parts)