mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-05-28 02:48:54 +08:00
[v1] support reward training stage (#10431)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -134,6 +134,9 @@ class BaseTrainer:
|
|||||||
global_step=self.global_step,
|
global_step=self.global_step,
|
||||||
epoch=self._resume_epoch,
|
epoch=self._resume_epoch,
|
||||||
)
|
)
|
||||||
|
# Keep callback state aligned with checkpoint-resumed trainer counters.
|
||||||
|
self.state.global_step = self.global_step
|
||||||
|
self.state.epoch = self._resume_epoch
|
||||||
|
|
||||||
if self.args.dist_config is not None and self.args.dist_config.get("cp_size", 1) > 1:
|
if self.args.dist_config is not None and self.args.dist_config.get("cp_size", 1) > 1:
|
||||||
# qwen3.5 is not supported because of the different attention implementation, which will be supported in the future.
|
# qwen3.5 is not supported because of the different attention implementation, which will be supported in the future.
|
||||||
@@ -303,7 +306,7 @@ class BaseTrainer:
|
|||||||
if self.global_step % self.args.logging_steps == 0:
|
if self.global_step % self.args.logging_steps == 0:
|
||||||
logs = {
|
logs = {
|
||||||
"epoch": epoch,
|
"epoch": epoch,
|
||||||
"step": self.global_step,
|
"step": self.state.global_step,
|
||||||
"loss": step_loss,
|
"loss": step_loss,
|
||||||
"grad_norm": grad_norm,
|
"grad_norm": grad_norm,
|
||||||
"learning_rate": current_lr,
|
"learning_rate": current_lr,
|
||||||
@@ -335,7 +338,9 @@ class BaseTrainer:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
|
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
|
||||||
model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB")
|
model_to_save.save_pretrained(
|
||||||
|
self.args.output_dir, state_dict=model_to_save.state_dict(), max_shard_size="4GB"
|
||||||
|
)
|
||||||
self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
|
self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
|
||||||
logger.info_rank0(f"Model saved to {self.args.output_dir}")
|
logger.info_rank0(f"Model saved to {self.args.output_dir}")
|
||||||
|
|
||||||
|
|||||||
@@ -143,6 +143,12 @@ class ModelEngine:
|
|||||||
elif self.args.model_class == ModelClass.CLS:
|
elif self.args.model_class == ModelClass.CLS:
|
||||||
from transformers import AutoModelForTokenClassification
|
from transformers import AutoModelForTokenClassification
|
||||||
|
|
||||||
|
self.model_config.num_labels = 1
|
||||||
|
self.model_config.classifier_dropout = 0.0
|
||||||
|
text_config = getattr(self.model_config, "text_config", None)
|
||||||
|
if text_config is not None:
|
||||||
|
text_config.num_labels = 1
|
||||||
|
text_config.classifier_dropout = 0.0
|
||||||
AutoClass = AutoModelForTokenClassification
|
AutoClass = AutoModelForTokenClassification
|
||||||
else:
|
else:
|
||||||
from transformers import AutoModel
|
from transformers import AutoModel
|
||||||
|
|||||||
@@ -137,8 +137,8 @@ class BatchGenerator(Iterator):
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError("Iterable dataset is not supported yet.")
|
raise NotImplementedError("Iterable dataset is not supported yet.")
|
||||||
|
|
||||||
generato_seed = torch.Generator()
|
generator_seed = torch.Generator()
|
||||||
generato_seed.manual_seed(self.seed)
|
generator_seed.manual_seed(self.seed)
|
||||||
|
|
||||||
self._data_provider = StatefulDataLoader(
|
self._data_provider = StatefulDataLoader(
|
||||||
self.dataset,
|
self.dataset,
|
||||||
@@ -149,7 +149,7 @@ class BatchGenerator(Iterator):
|
|||||||
pin_memory=self.pin_memory,
|
pin_memory=self.pin_memory,
|
||||||
pin_memory_device=DistributedInterface().current_device.type,
|
pin_memory_device=DistributedInterface().current_device.type,
|
||||||
drop_last=self.drop_last,
|
drop_last=self.drop_last,
|
||||||
generator=generato_seed,
|
generator=generator_seed,
|
||||||
)
|
)
|
||||||
if self.batching_strategy == BatchingStrategy.NORMAL:
|
if self.batching_strategy == BatchingStrategy.NORMAL:
|
||||||
self._length = len(self._data_provider)
|
self._length = len(self._data_provider)
|
||||||
|
|||||||
@@ -172,7 +172,7 @@ def _save_standard_training_states(
|
|||||||
if rank == 0:
|
if rank == 0:
|
||||||
model_to_save = model.module if hasattr(model, "module") else model
|
model_to_save = model.module if hasattr(model, "module") else model
|
||||||
model_dir = os.path.join(ckpt_dir, "model")
|
model_dir = os.path.join(ckpt_dir, "model")
|
||||||
model_to_save.save_pretrained(model_dir, max_shard_size="4GB")
|
model_to_save.save_pretrained(model_dir, state_dict=model_to_save.state_dict(), max_shard_size="4GB")
|
||||||
processor.save_pretrained(model_dir)
|
processor.save_pretrained(model_dir)
|
||||||
|
|
||||||
os.makedirs(os.path.join(ckpt_dir, "optimizer"), exist_ok=True)
|
os.makedirs(os.path.join(ckpt_dir, "optimizer"), exist_ok=True)
|
||||||
@@ -212,7 +212,11 @@ def _load_standard_training_states(
|
|||||||
for f in sorted(glob.glob(os.path.join(model_dir, "*.bin"))):
|
for f in sorted(glob.glob(os.path.join(model_dir, "*.bin"))):
|
||||||
state_dict.update(torch.load(f, map_location="cpu", weights_only=True))
|
state_dict.update(torch.load(f, map_location="cpu", weights_only=True))
|
||||||
if state_dict:
|
if state_dict:
|
||||||
model_to_load.load_state_dict(state_dict)
|
incompatible_keys = model_to_load.load_state_dict(state_dict, strict=False)
|
||||||
|
if incompatible_keys.missing_keys:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Unexpected missing keys when loading checkpoint model weights: {incompatible_keys.missing_keys}."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning_rank0(f"No model weights found in {model_dir}, skipping model state restore.")
|
logger.warning_rank0(f"No model weights found in {model_dir}, skipping model state restore.")
|
||||||
|
|
||||||
|
|||||||
@@ -148,7 +148,9 @@ def launch():
|
|||||||
elif command == "dpo":
|
elif command == "dpo":
|
||||||
raise NotImplementedError("DPO trainer is not implemented yet.")
|
raise NotImplementedError("DPO trainer is not implemented yet.")
|
||||||
elif command == "rm":
|
elif command == "rm":
|
||||||
raise NotImplementedError("RM trainer is not implemented yet.")
|
from llamafactory.v1.trainers.rm_trainer import run_rm
|
||||||
|
|
||||||
|
run_rm()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
print(f"Unknown command: {command}.\n{USAGE}")
|
print(f"Unknown command: {command}.\n{USAGE}")
|
||||||
@@ -175,9 +177,9 @@ def main():
|
|||||||
# run_dpo()
|
# run_dpo()
|
||||||
raise NotImplementedError("DPO trainer is not implemented yet.")
|
raise NotImplementedError("DPO trainer is not implemented yet.")
|
||||||
elif command == "rm":
|
elif command == "rm":
|
||||||
# from llamafactory.v1.trainers.rm_trainer import run_rm
|
from llamafactory.v1.trainers.rm_trainer import run_rm
|
||||||
# run_rm()
|
|
||||||
raise NotImplementedError("RM trainer is not implemented yet.")
|
run_rm()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -381,7 +381,7 @@ class FSDP2Engine:
|
|||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
||||||
if isinstance(grad_norm, torch.distributed._tensor.DTensor):
|
if isinstance(grad_norm, torch.distributed.tensor.DTensor):
|
||||||
grad_norm = grad_norm.full_tensor()
|
grad_norm = grad_norm.full_tensor()
|
||||||
|
|
||||||
for param in model.parameters():
|
for param in model.parameters():
|
||||||
|
|||||||
@@ -78,14 +78,14 @@ def save_model_deepspeed(model: HFModel, output_dir: str, processor: Processor)
|
|||||||
|
|
||||||
|
|
||||||
@DistributedPlugin("deepspeed").register("save_checkpoint")
|
@DistributedPlugin("deepspeed").register("save_checkpoint")
|
||||||
def save_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str) -> None:
|
def save_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
|
||||||
from .deepspeed import save_checkpoint
|
from .deepspeed import save_checkpoint
|
||||||
|
|
||||||
return save_checkpoint(model, optimizer, ckpt_dir)
|
return save_checkpoint(model, optimizer, ckpt_dir, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@DistributedPlugin("deepspeed").register("load_checkpoint")
|
@DistributedPlugin("deepspeed").register("load_checkpoint")
|
||||||
def load_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str) -> None:
|
def load_checkpoint_deepspeed(model: HFModel, optimizer: torch.optim.Optimizer, ckpt_dir: str, **kwargs) -> None:
|
||||||
from .deepspeed import load_checkpoint
|
from .deepspeed import load_checkpoint
|
||||||
|
|
||||||
return load_checkpoint(model, optimizer, ckpt_dir)
|
return load_checkpoint(model, optimizer, ckpt_dir, **kwargs)
|
||||||
|
|||||||
@@ -0,0 +1,183 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from ..accelerator.interface import Dim, DistributedInterface
|
||||||
|
from ..config import InputArgument, TrainingArguments, get_args
|
||||||
|
from ..config.arg_utils import ModelClass
|
||||||
|
from ..core.base_trainer import BaseTrainer
|
||||||
|
from ..core.data_engine import DataEngine
|
||||||
|
from ..core.model_engine import ModelEngine
|
||||||
|
from ..utils import logging
|
||||||
|
from ..utils.types import BatchInput, HFModel, Tensor
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level logger, obtained from the project's logging utilities and keyed by this module's name.
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_rm_dataset_format(train_dataset: DataEngine, dataset_path: str) -> None:
    """Reject datasets that cannot be used for reward-model training.

    Only the first sample is inspected; it must expose both the
    ``chosen_messages`` and ``rejected_messages`` keys.

    Args:
        train_dataset: Dataset to check; must support ``len()`` and integer indexing.
        dataset_path: Dataset location, used only in the "empty dataset" error message.

    Raises:
        ValueError: If the dataset is empty, or if its first sample is missing
            either of the two pair keys.
    """
    if len(train_dataset) == 0:
        raise ValueError(f"RM training dataset is empty: {dataset_path}")

    sample = train_dataset[0]
    has_pair_fields = "chosen_messages" in sample and "rejected_messages" in sample
    if not has_pair_fields:
        # Gather diagnostics so the error tells the user what the sample looks like.
        dataset_name = sample.get("_dataset_name", "unknown")
        sample_keys = sorted(sample.keys())
        raise ValueError(
            "RM training requires pair-format samples containing chosen/rejected responses. "
            f"First sample from dataset '{dataset_name}' has keys: {sample_keys}. "
            "Please use pair data (e.g. a dataset with chosen_messages/rejected_messages, "
            "or set converter='pair' for raw chosen/rejected fields)."
        )
|
||||||
|
|
||||||
|
|
||||||
|
def _init_score_head(model: HFModel) -> None:
    """Initialize the score head for RM training with small Gaussian weights.

    Uses Gaussian initialization so that different parameters have distinct values,
    providing better gradient flow than zero initialization while keeping initial
    scores small enough that the starting loss is close to ln(2).

    The function is a no-op when the (unwrapped) model has no ``score`` attribute
    or when that attribute has no ``weight``.

    Args:
        model: The model, optionally wrapped in an object exposing ``.module``.
    """
    unwrapped = model.module if hasattr(model, "module") else model
    score = getattr(unwrapped, "score", None)
    if score is None or not hasattr(score, "weight"):
        return

    hidden_size = score.weight.shape[-1]
    std = 1.0 / (hidden_size * 10)
    # The guard above only guarantees `weight`; look up `bias` defensively so a
    # head without a `bias` attribute does not raise AttributeError.
    bias = getattr(score, "bias", None)
    with torch.no_grad():
        score.weight.normal_(mean=0.0, std=std)
        if bias is not None:
            bias.zero_()

    logger.info_rank0(f"Initialized score head with Gaussian (std={std:.6f}): {score.weight.shape}")
|
||||||
|
|
||||||
|
|
||||||
|
class RMTrainer(BaseTrainer):
    """Trainer for the reward-model stage using a pairwise chosen/rejected loss."""

    def __init__(
        self,
        args: TrainingArguments,
        model: HFModel,
        renderer,
        train_dataset,
        callbacks=None,
    ) -> None:
        """Validate the distributed config, then defer to ``BaseTrainer.__init__``.

        Raises:
            NotImplementedError: If ``args.dist_config`` requests ``cp_size > 1``.
        """
        cp_size = args.dist_config.get("cp_size", 1) if args.dist_config is not None else 1
        if cp_size > 1:
            raise NotImplementedError("RM trainer currently only supports cp_size == 1.")

        super().__init__(args, model, renderer, train_dataset, callbacks)

    def _shard_model(self) -> None:
        """Wrap the model for data parallelism, or delegate to the base implementation.

        Without a ``dist_config`` the model is wrapped in DDP when the DP world size
        exceeds one; otherwise sharding is left to ``BaseTrainer._shard_model``.
        """
        if self.args.dist_config is None:
            if DistributedInterface().get_world_size(Dim.DP) > 1:
                from torch.nn.parallel import DistributedDataParallel as DDP

                device_ids = None if self.device.type == "cpu" else [self.device.index]
                # NOTE(review): find_unused_parameters=True presumably accounts for
                # parameters that do not contribute to the reward loss -- confirm.
                self.model = DDP(self.model, device_ids=device_ids, find_unused_parameters=True)
        else:
            super()._shard_model()

    @property
    def _unwrapped_model(self):
        """Access the underlying model, unwrapping DDP/FSDP wrappers if present."""
        model = self.model
        if hasattr(model, "module"):
            model = model.module
        return model

    def compute_loss(self, batch: BatchInput) -> Tensor:
        """Compute the pairwise reward loss ``-logsigmoid(chosen - rejected)``.

        Each row of the batch holds a chosen and a rejected response, distinguished
        by ``token_type_ids`` (1 = chosen, 2 = rejected, 0 = padding). The score of a
        response is the model logit at the last token belonging to that response.
        Rows lacking either response are dropped before averaging.

        Args:
            batch: Mapping providing ``input_ids`` and ``token_type_ids`` tensors.

        Returns:
            Scalar tensor: mean loss over the rows that contain a complete pair.

        Raises:
            ValueError: If ``token_type_ids`` is absent, or if no row in the
                micro-batch contains both chosen and rejected tokens.
        """
        input_ids = batch["input_ids"].to(self.device, non_blocking=True)

        token_type_ids = batch.get("token_type_ids")
        if token_type_ids is None:
            raise ValueError(
                "RM training requires pair data with token_type_ids. "
                "Ensure the dataset has chosen_messages/rejected_messages."
            )
        token_type_ids = token_type_ids.to(self.device, non_blocking=True)

        # Use token_type_ids as document-index attention mask (values: 1=chosen, 2=rejected, 0=padding).
        # Transformers v5 models natively support this format in _update_causal_mask,
        # constructing the correct block-diagonal causal mask internally for all attention backends.
        model_attention_mask = token_type_ids

        # Build position_ids that reset at each document boundary.
        # NOTE(review): subtracting chosen_lens assumes the chosen tokens occupy the
        # leading positions of each row, directly followed by the rejected tokens -- confirm.
        batch_size, seq_len = token_type_ids.shape
        arange = torch.arange(seq_len, device=self.device).unsqueeze(0).expand(batch_size, -1)
        chosen_mask = token_type_ids == 1
        rejected_mask = token_type_ids == 2
        chosen_lens = chosen_mask.sum(dim=1, keepdim=True)
        position_ids = torch.zeros_like(token_type_ids)
        position_ids[chosen_mask] = arange[chosen_mask]
        position_ids[rejected_mask] = (arange - chosen_lens)[rejected_mask]

        model_output = self.model(
            input_ids=input_ids,
            attention_mask=model_attention_mask,
            position_ids=position_ids,
            use_cache=False,
            return_dict=True,
        )

        rewards = model_output.logits.float().squeeze(-1)

        # chosen_mask / rejected_mask are still valid here: token_type_ids is unchanged
        # since they were built, so they are reused rather than recomputed.
        valid_pair_mask = chosen_mask.any(dim=-1) & rejected_mask.any(dim=-1)
        if not torch.any(valid_pair_mask):
            raise ValueError(
                "No valid RM pairs found in this micro-batch. "
                "This is usually caused by cutoff_len being too small and truncating chosen/rejected tokens."
            )

        # Keep only rows that contain both a chosen and a rejected response.
        rewards = rewards[valid_pair_mask]
        chosen_mask = chosen_mask[valid_pair_mask]
        rejected_mask = rejected_mask[valid_pair_mask]

        # Index of the last token of each response: positions outside the response
        # are zeroed by the mask, so the row-wise maximum is the last in-mask position.
        seq_len = rewards.size(-1)
        position_index = torch.arange(seq_len, device=self.device).unsqueeze(0)
        chosen_last_idx = (position_index * chosen_mask.long()).max(dim=-1).values
        rejected_last_idx = (position_index * rejected_mask.long()).max(dim=-1).values

        chosen_scores = rewards.gather(dim=1, index=chosen_last_idx.unsqueeze(-1)).squeeze(-1)
        rejected_scores = rewards.gather(dim=1, index=rejected_last_idx.unsqueeze(-1)).squeeze(-1)
        return -F.logsigmoid(chosen_scores - rejected_scores).mean()
|
||||||
|
|
||||||
|
|
||||||
|
def run_rm(args: InputArgument = None):
    """Run the reward-model training stage end to end.

    Parses the arguments, forces the classification model class, initializes the
    distributed interface, validates the pair dataset, initializes the score head,
    then trains, saves the model and tears the distributed interface down.

    Args:
        args: Optional raw arguments forwarded to ``get_args``.
    """
    model_args, data_args, training_args, _ = get_args(args)
    # Reward models are loaded through the classification model class.
    model_args.model_class = ModelClass.CLS
    DistributedInterface(training_args.dist_config)

    dataset = DataEngine(data_args.train_dataset)
    _validate_rm_dataset_format(dataset, data_args.train_dataset)

    engine = ModelEngine(model_args, is_train=True)
    _init_score_head(engine.model)

    rm_trainer = RMTrainer(
        args=training_args,
        model=engine.model,
        renderer=engine.renderer,
        train_dataset=dataset,
    )
    rm_trainer.fit()
    rm_trainer.save_model()
    DistributedInterface().destroy()
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running this module directly as a script: starts RM training with default argument parsing.
if __name__ == "__main__":
    run_rm()
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ class LoggingCallback(TrainerCallback):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Human-readable output to stdout
|
# Human-readable output to stdout
|
||||||
display_logs = {**logs, "total_steps": state.num_training_steps}
|
display_logs = {**logs, "step": state.global_step, "total_steps": state.num_training_steps}
|
||||||
parts = ", ".join(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}" for k, v in display_logs.items())
|
parts = ", ".join(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}" for k, v in display_logs.items())
|
||||||
logger.info_rank0(parts)
|
logger.info_rank0(parts)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user