mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-05-28 02:48:54 +08:00
[v1] support reward training stage (#10431)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -134,6 +134,9 @@ class BaseTrainer:
|
||||
global_step=self.global_step,
|
||||
epoch=self._resume_epoch,
|
||||
)
|
||||
# Keep callback state aligned with checkpoint-resumed trainer counters.
|
||||
self.state.global_step = self.global_step
|
||||
self.state.epoch = self._resume_epoch
|
||||
|
||||
if self.args.dist_config is not None and self.args.dist_config.get("cp_size", 1) > 1:
|
||||
# qwen3.5 is not supported because of the different attention implementation, which will be supported in the future.
|
||||
@@ -303,7 +306,7 @@ class BaseTrainer:
|
||||
if self.global_step % self.args.logging_steps == 0:
|
||||
logs = {
|
||||
"epoch": epoch,
|
||||
"step": self.global_step,
|
||||
"step": self.state.global_step,
|
||||
"loss": step_loss,
|
||||
"grad_norm": grad_norm,
|
||||
"learning_rate": current_lr,
|
||||
@@ -335,7 +338,9 @@ class BaseTrainer:
|
||||
)
|
||||
else:
|
||||
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
|
||||
model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB")
|
||||
model_to_save.save_pretrained(
|
||||
self.args.output_dir, state_dict=model_to_save.state_dict(), max_shard_size="4GB"
|
||||
)
|
||||
self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
|
||||
logger.info_rank0(f"Model saved to {self.args.output_dir}")
|
||||
|
||||
|
||||
@@ -143,6 +143,12 @@ class ModelEngine:
|
||||
elif self.args.model_class == ModelClass.CLS:
|
||||
from transformers import AutoModelForTokenClassification
|
||||
|
||||
self.model_config.num_labels = 1
|
||||
self.model_config.classifier_dropout = 0.0
|
||||
text_config = getattr(self.model_config, "text_config", None)
|
||||
if text_config is not None:
|
||||
text_config.num_labels = 1
|
||||
text_config.classifier_dropout = 0.0
|
||||
AutoClass = AutoModelForTokenClassification
|
||||
else:
|
||||
from transformers import AutoModel
|
||||
|
||||
@@ -137,8 +137,8 @@ class BatchGenerator(Iterator):
|
||||
else:
|
||||
raise NotImplementedError("Iterable dataset is not supported yet.")
|
||||
|
||||
generato_seed = torch.Generator()
|
||||
generato_seed.manual_seed(self.seed)
|
||||
generator_seed = torch.Generator()
|
||||
generator_seed.manual_seed(self.seed)
|
||||
|
||||
self._data_provider = StatefulDataLoader(
|
||||
self.dataset,
|
||||
@@ -149,7 +149,7 @@ class BatchGenerator(Iterator):
|
||||
pin_memory=self.pin_memory,
|
||||
pin_memory_device=DistributedInterface().current_device.type,
|
||||
drop_last=self.drop_last,
|
||||
generator=generato_seed,
|
||||
generator=generator_seed,
|
||||
)
|
||||
if self.batching_strategy == BatchingStrategy.NORMAL:
|
||||
self._length = len(self._data_provider)
|
||||
|
||||
@@ -172,7 +172,7 @@ def _save_standard_training_states(
|
||||
if rank == 0:
|
||||
model_to_save = model.module if hasattr(model, "module") else model
|
||||
model_dir = os.path.join(ckpt_dir, "model")
|
||||
model_to_save.save_pretrained(model_dir, max_shard_size="4GB")
|
||||
model_to_save.save_pretrained(model_dir, state_dict=model_to_save.state_dict(), max_shard_size="4GB")
|
||||
processor.save_pretrained(model_dir)
|
||||
|
||||
os.makedirs(os.path.join(ckpt_dir, "optimizer"), exist_ok=True)
|
||||
@@ -212,7 +212,11 @@ def _load_standard_training_states(
|
||||
for f in sorted(glob.glob(os.path.join(model_dir, "*.bin"))):
|
||||
state_dict.update(torch.load(f, map_location="cpu", weights_only=True))
|
||||
if state_dict:
|
||||
model_to_load.load_state_dict(state_dict)
|
||||
incompatible_keys = model_to_load.load_state_dict(state_dict, strict=False)
|
||||
if incompatible_keys.missing_keys:
|
||||
raise RuntimeError(
|
||||
f"Unexpected missing keys when loading checkpoint model weights: {incompatible_keys.missing_keys}."
|
||||
)
|
||||
else:
|
||||
logger.warning_rank0(f"No model weights found in {model_dir}, skipping model state restore.")
|
||||
|
||||
|
||||
@@ -148,7 +148,9 @@ def launch():
|
||||
elif command == "dpo":
|
||||
raise NotImplementedError("DPO trainer is not implemented yet.")
|
||||
elif command == "rm":
|
||||
raise NotImplementedError("RM trainer is not implemented yet.")
|
||||
from llamafactory.v1.trainers.rm_trainer import run_rm
|
||||
|
||||
run_rm()
|
||||
|
||||
else:
|
||||
print(f"Unknown command: {command}.\n{USAGE}")
|
||||
@@ -175,9 +177,9 @@ def main():
|
||||
# run_dpo()
|
||||
raise NotImplementedError("DPO trainer is not implemented yet.")
|
||||
elif command == "rm":
|
||||
# from llamafactory.v1.trainers.rm_trainer import run_rm
|
||||
# run_rm()
|
||||
raise NotImplementedError("RM trainer is not implemented yet.")
|
||||
from llamafactory.v1.trainers.rm_trainer import run_rm
|
||||
|
||||
run_rm()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -381,7 +381,7 @@ class FSDP2Engine:
|
||||
|
||||
with torch.no_grad():
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
||||
if isinstance(grad_norm, torch.distributed._tensor.DTensor):
|
||||
if isinstance(grad_norm, torch.distributed.tensor.DTensor):
|
||||
grad_norm = grad_norm.full_tensor()
|
||||
|
||||
for param in model.parameters():
|
||||
|
||||
@@ -78,14 +78,14 @@ def save_model_deepspeed(model: HFModel, output_dir: str, processor: Processor)
|
||||
|
||||
|
||||
@DistributedPlugin("deepspeed").register("save_checkpoint")
|
||||
def save_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str) -> None:
|
||||
def save_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
|
||||
from .deepspeed import save_checkpoint
|
||||
|
||||
return save_checkpoint(model, optimizer, ckpt_dir)
|
||||
return save_checkpoint(model, optimizer, ckpt_dir, **kwargs)
|
||||
|
||||
|
||||
@DistributedPlugin("deepspeed").register("load_checkpoint")
|
||||
def load_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str) -> None:
|
||||
def load_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
|
||||
from .deepspeed import load_checkpoint
|
||||
|
||||
return load_checkpoint(model, optimizer, ckpt_dir)
|
||||
return load_checkpoint(model, optimizer, ckpt_dir, **kwargs)
|
||||
|
||||
@@ -0,0 +1,183 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..accelerator.interface import Dim, DistributedInterface
|
||||
from ..config import InputArgument, TrainingArguments, get_args
|
||||
from ..config.arg_utils import ModelClass
|
||||
from ..core.base_trainer import BaseTrainer
|
||||
from ..core.data_engine import DataEngine
|
||||
from ..core.model_engine import ModelEngine
|
||||
from ..utils import logging
|
||||
from ..utils.types import BatchInput, HFModel, Tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _validate_rm_dataset_format(train_dataset: DataEngine, dataset_path: str) -> None:
|
||||
"""Validate RM dataset format early for clearer error messages."""
|
||||
if len(train_dataset) == 0:
|
||||
raise ValueError(f"RM training dataset is empty: {dataset_path}")
|
||||
|
||||
sample = train_dataset[0]
|
||||
if "chosen_messages" in sample and "rejected_messages" in sample:
|
||||
return
|
||||
|
||||
dataset_name = sample.get("_dataset_name", "unknown")
|
||||
sample_keys = sorted(sample.keys())
|
||||
raise ValueError(
|
||||
"RM training requires pair-format samples containing chosen/rejected responses. "
|
||||
f"First sample from dataset '{dataset_name}' has keys: {sample_keys}. "
|
||||
"Please use pair data (e.g. a dataset with chosen_messages/rejected_messages, "
|
||||
"or set converter='pair' for raw chosen/rejected fields)."
|
||||
)
|
||||
|
||||
|
||||
def _init_score_head(model: HFModel) -> None:
|
||||
"""Initialize the score head for RM training with small Gaussian weights.
|
||||
|
||||
Uses Gaussian initialization so that different parameters have distinct values,
|
||||
providing better gradient flow than zero initialization while keeping initial
|
||||
scores small enough that the starting loss is close to ln(2).
|
||||
"""
|
||||
unwrapped = model.module if hasattr(model, "module") else model
|
||||
score = getattr(unwrapped, "score", None)
|
||||
if score is not None and hasattr(score, "weight"):
|
||||
hidden_size = score.weight.shape[-1]
|
||||
std = 1.0 / (hidden_size * 10)
|
||||
with torch.no_grad():
|
||||
score.weight.normal_(mean=0.0, std=std)
|
||||
if score.bias is not None:
|
||||
score.bias.zero_()
|
||||
logger.info_rank0(f"Initialized score head with Gaussian (std={std:.6f}): {score.weight.shape}")
|
||||
|
||||
|
||||
class RMTrainer(BaseTrainer):
|
||||
def __init__(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
model: HFModel,
|
||||
renderer,
|
||||
train_dataset,
|
||||
callbacks=None,
|
||||
) -> None:
|
||||
cp_size = args.dist_config.get("cp_size", 1) if args.dist_config is not None else 1
|
||||
if cp_size > 1:
|
||||
raise NotImplementedError("RM trainer currently only supports cp_size == 1.")
|
||||
|
||||
super().__init__(args, model, renderer, train_dataset, callbacks)
|
||||
|
||||
def _shard_model(self) -> None:
|
||||
if self.args.dist_config is None:
|
||||
if DistributedInterface().get_world_size(Dim.DP) > 1:
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
device_ids = None if self.device.type == "cpu" else [self.device.index]
|
||||
self.model = DDP(self.model, device_ids=device_ids, find_unused_parameters=True)
|
||||
else:
|
||||
super()._shard_model()
|
||||
|
||||
@property
|
||||
def _unwrapped_model(self):
|
||||
"""Access the underlying model, unwrapping DDP/FSDP wrappers if present."""
|
||||
model = self.model
|
||||
if hasattr(model, "module"):
|
||||
model = model.module
|
||||
return model
|
||||
|
||||
def compute_loss(self, batch: BatchInput) -> Tensor:
|
||||
input_ids = batch["input_ids"].to(self.device, non_blocking=True)
|
||||
|
||||
token_type_ids = batch.get("token_type_ids")
|
||||
if token_type_ids is None:
|
||||
raise ValueError(
|
||||
"RM training requires pair data with token_type_ids. "
|
||||
"Ensure the dataset has chosen_messages/rejected_messages."
|
||||
)
|
||||
token_type_ids = token_type_ids.to(self.device, non_blocking=True)
|
||||
|
||||
# Use token_type_ids as document-index attention mask (values: 1=chosen, 2=rejected, 0=padding).
|
||||
# Transformers v5 models natively support this format in _update_causal_mask,
|
||||
# constructing the correct block-diagonal causal mask internally for all attention backends.
|
||||
model_attention_mask = token_type_ids
|
||||
|
||||
# Build position_ids that reset at each document boundary.
|
||||
batch_size, seq_len = token_type_ids.shape
|
||||
arange = torch.arange(seq_len, device=self.device).unsqueeze(0).expand(batch_size, -1)
|
||||
chosen_mask = token_type_ids == 1
|
||||
rejected_mask = token_type_ids == 2
|
||||
chosen_lens = chosen_mask.sum(dim=1, keepdim=True)
|
||||
position_ids = torch.zeros_like(token_type_ids)
|
||||
position_ids[chosen_mask] = arange[chosen_mask]
|
||||
position_ids[rejected_mask] = (arange - chosen_lens)[rejected_mask]
|
||||
|
||||
model_output = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=model_attention_mask,
|
||||
position_ids=position_ids,
|
||||
use_cache=False,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
rewards = model_output.logits.float().squeeze(-1)
|
||||
|
||||
chosen_mask = token_type_ids == 1
|
||||
rejected_mask = token_type_ids == 2
|
||||
|
||||
valid_pair_mask = chosen_mask.any(dim=-1) & rejected_mask.any(dim=-1)
|
||||
if not torch.any(valid_pair_mask):
|
||||
raise ValueError(
|
||||
"No valid RM pairs found in this micro-batch. "
|
||||
"This is usually caused by cutoff_len being too small and truncating chosen/rejected tokens."
|
||||
)
|
||||
|
||||
rewards = rewards[valid_pair_mask]
|
||||
chosen_mask = chosen_mask[valid_pair_mask]
|
||||
rejected_mask = rejected_mask[valid_pair_mask]
|
||||
|
||||
seq_len = rewards.size(-1)
|
||||
position_index = torch.arange(seq_len, device=self.device).unsqueeze(0)
|
||||
chosen_last_idx = (position_index * chosen_mask.long()).max(dim=-1).values
|
||||
rejected_last_idx = (position_index * rejected_mask.long()).max(dim=-1).values
|
||||
|
||||
chosen_scores = rewards.gather(dim=1, index=chosen_last_idx.unsqueeze(-1)).squeeze(-1)
|
||||
rejected_scores = rewards.gather(dim=1, index=rejected_last_idx.unsqueeze(-1)).squeeze(-1)
|
||||
return -F.logsigmoid(chosen_scores - rejected_scores).mean()
|
||||
|
||||
|
||||
def run_rm(args: InputArgument = None):
|
||||
model_args, data_args, training_args, _ = get_args(args)
|
||||
model_args.model_class = ModelClass.CLS
|
||||
DistributedInterface(training_args.dist_config)
|
||||
train_dataset = DataEngine(data_args.train_dataset)
|
||||
_validate_rm_dataset_format(train_dataset, data_args.train_dataset)
|
||||
model_engine = ModelEngine(model_args, is_train=True)
|
||||
_init_score_head(model_engine.model)
|
||||
trainer = RMTrainer(
|
||||
args=training_args,
|
||||
model=model_engine.model,
|
||||
renderer=model_engine.renderer,
|
||||
train_dataset=train_dataset,
|
||||
)
|
||||
trainer.fit()
|
||||
trainer.save_model()
|
||||
DistributedInterface().destroy()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_rm()
|
||||
|
||||
@@ -53,7 +53,7 @@ class LoggingCallback(TrainerCallback):
|
||||
return
|
||||
|
||||
# Human-readable output to stdout
|
||||
display_logs = {**logs, "total_steps": state.num_training_steps}
|
||||
display_logs = {**logs, "step": state.global_step, "total_steps": state.num_training_steps}
|
||||
parts = ", ".join(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}" for k, v in display_logs.items())
|
||||
logger.info_rank0(parts)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user