[core deps] upgrade TRL to be between 0.18 and 0.24 (#9617)

Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
Author: Username_Full
Date: 2025-12-31 20:54:27 +08:00
Committed by: GitHub
Parent: c8d7e85b3e
Commit: 000526908a

7 changed files with 60 additions and 29 deletions

View File

@@ -33,17 +33,17 @@ jobs:
- "windows-latest" - "windows-latest"
- "macos-latest" - "macos-latest"
transformers: transformers:
- null - ""
include: # test backward compatibility include: # test backward compatibility
- python: "3.11"
os: "ubuntu-latest"
transformers: "4.49.0"
- python: "3.11" - python: "3.11"
os: "ubuntu-latest" os: "ubuntu-latest"
transformers: "4.51.0" transformers: "4.51.0"
- python: "3.11" - python: "3.11"
os: "ubuntu-latest" os: "ubuntu-latest"
transformers: "4.53.0" transformers: "4.53.0"
- python: "3.11"
os: "ubuntu-latest"
transformers: "4.55.0"
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}

View File

@@ -41,12 +41,12 @@ dependencies = [
"torch>=2.4.0", "torch>=2.4.0",
"torchvision>=0.19.0", "torchvision>=0.19.0",
"torchaudio>=2.4.0", "torchaudio>=2.4.0",
"transformers>=4.49.0,<=4.56.2,!=4.52.0; python_version < '3.10'", "transformers>=4.51.0,<=4.56.2,!=4.52.0; python_version < '3.10'",
"transformers>=4.49.0,<=4.57.1,!=4.52.0,!=4.57.0; python_version >= '3.10'", "transformers>=4.51.0,<=4.57.1,!=4.52.0,!=4.57.0; python_version >= '3.10'",
"datasets>=2.16.0,<=4.0.0", "datasets>=2.16.0,<=4.0.0",
"accelerate>=1.3.0,<=1.11.0", "accelerate>=1.3.0,<=1.11.0",
"peft>=0.14.0,<=0.17.1", "peft>=0.14.0,<=0.17.1",
"trl>=0.8.6,<=0.9.6", "trl>=0.18.0,<=0.24.0",
"torchdata>=0.10.0,<=0.11.0", "torchdata>=0.10.0,<=0.11.0",
# gui # gui
"gradio>=4.38.0,<=5.50.0", "gradio>=4.38.0,<=5.50.0",

View File

@@ -94,11 +94,11 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version("transformers>=4.49.0,<=4.57.1")
+    check_version("transformers>=4.51.0,<=4.57.1")
     check_version("datasets>=2.16.0,<=4.0.0")
     check_version("accelerate>=1.3.0,<=1.11.0")
     check_version("peft>=0.14.0,<=0.17.1")
-    check_version("trl>=0.8.6,<=0.9.6")
+    check_version("trl>=0.18.0,<=0.24.0")


 def calculate_tps(dataset: list[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float:
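
For a quick sanity check of an existing environment against the widened TRL range, something like the following works (illustrative only, not how check_version is implemented):

    from importlib.metadata import version
    from packaging.specifiers import SpecifierSet

    installed = version("trl")  # e.g. "0.23.0"
    ok = SpecifierSet(">=0.18.0,<=0.24.0").contains(installed, prereleases=True)
    print(f"trl {installed} satisfies the new pin: {ok}")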

View File

@@ -26,6 +26,7 @@ import torch.nn.functional as F
 from transformers import Trainer
 from trl import DPOTrainer
 from trl.trainer import disable_dropout_in_model
+from trl.trainer.utils import prepare_deepspeed
 from typing_extensions import override

 from ...extras.constants import IGNORE_INDEX
@@ -95,7 +96,7 @@ class CustomDPOTrainer(DPOTrainer):
             if not (
                 getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
             ):  # quantized models are already set on the correct device
-                self.ref_model = self._prepare_deepspeed(self.ref_model)
+                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
                 self.ref_model.eval()
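
The reference model is now wrapped via the public trl.trainer.utils.prepare_deepspeed(model, accelerator) helper rather than the private self._prepare_deepspeed call it replaces. A hypothetical defensive import (not part of the diff) for environments where that import path might be missing:

    try:
        from trl.trainer.utils import prepare_deepspeed  # location used in this diff
    except ImportError:
        prepare_deepspeed = None  # caller can fall back to accelerator.prepare_model(..., evaluation_mode=True)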
@@ -210,7 +211,7 @@ class CustomDPOTrainer(DPOTrainer):
     @override
     def concatenated_forward(
         self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], is_ref_model: bool = False
-    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+    ) -> dict[str, "torch.Tensor"]:
         r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.

         Otherwise the average log probabilities.
@@ -230,11 +231,18 @@ class CustomDPOTrainer(DPOTrainer):
         chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
         chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
         chosen_length, _ = valid_length.split(batch_size, dim=0)

         if self.loss_type in ["ipo", "orpo", "simpo"]:
-            return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps
+            chosen_logps_avg = chosen_logps
         else:
-            return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
+            chosen_logps_avg = chosen_logps / chosen_length
+
+        return {
+            "chosen_logps": chosen_logps,
+            "rejected_logps": rejected_logps,
+            "chosen_logits": chosen_logits,
+            "rejected_logits": rejected_logits,
+            "chosen_logps_avg": chosen_logps_avg,
+        }

     @override
     def compute_reference_log_probs(
@@ -252,9 +260,9 @@ class CustomDPOTrainer(DPOTrainer):
             ref_context = nullcontext()

         with torch.no_grad(), ref_context:
-            reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(
-                ref_model, batch, is_ref_model=True
-            )
+            ref_output = self.concatenated_forward(ref_model, batch, is_ref_model=True)
+            reference_chosen_logps = ref_output["chosen_logps"]
+            reference_rejected_logps = ref_output["rejected_logps"]

         return reference_chosen_logps, reference_rejected_logps
@@ -267,13 +275,13 @@ class CustomDPOTrainer(DPOTrainer):
     ) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
         r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
         metrics = {}
-        (
-            policy_chosen_logps,
-            policy_rejected_logps,
-            policy_chosen_logits,
-            policy_rejected_logits,
-            policy_chosen_logps_avg,
-        ) = self.concatenated_forward(model, batch)
+        model_output = self.concatenated_forward(model, batch)
+        policy_chosen_logps = model_output["chosen_logps"]
+        policy_rejected_logps = model_output["rejected_logps"]
+        policy_chosen_logits = model_output["chosen_logits"]
+        policy_rejected_logits = model_output["rejected_logits"]
+        policy_chosen_logps_avg = model_output["chosen_logps_avg"]

         reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(model, batch)
         losses, chosen_rewards, rejected_rewards = self.compute_preference_loss(
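
The per-key lookups above keep the change mechanical; a purely illustrative alternative is to unpack the same dict in one statement:

    from operator import itemgetter

    (
        policy_chosen_logps,
        policy_rejected_logps,
        policy_chosen_logits,
        policy_rejected_logits,
        policy_chosen_logps_avg,
    ) = itemgetter(
        "chosen_logps", "rejected_logps", "chosen_logits", "rejected_logits", "chosen_logps_avg"
    )(model_output)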

View File

@@ -25,6 +25,7 @@ import torch
 from transformers import Trainer
 from trl import KTOTrainer
 from trl.trainer import disable_dropout_in_model
+from trl.trainer.utils import prepare_deepspeed
 from typing_extensions import override

 from ...extras.constants import IGNORE_INDEX
@@ -77,6 +78,13 @@ class CustomKTOTrainer(KTOTrainer):
         self.desirable_weight = finetuning_args.kto_chosen_weight
         self.undesirable_weight = finetuning_args.kto_rejected_weight
         self.ftx_gamma = finetuning_args.pref_ftx
+        # trl: not all losses require a KL calculation
+        self.calculate_KL = True
+        if hasattr(self, "loss_type") and self.loss_type in ["apo_zero_unpaired"]:
+            self.calculate_KL = False
+        else:
+            self.loss_type = "kto"

         Trainer.__init__(self, model=model, **kwargs)
         self.model_accepts_loss_kwargs = False  # overwrite trainer's default behavior
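
Because this subclass calls Trainer.__init__ directly rather than KTOTrainer.__init__, it has to set calculate_KL and loss_type itself; newer TRL code paths read both attributes. A standalone restatement of the added branch (illustrative, the function name is hypothetical):

    from typing import Optional

    def resolve_kto_settings(loss_type: Optional[str]) -> tuple[str, bool]:
        # Returns (loss_type, calculate_KL): only unpaired APO-zero skips the KL term;
        # any other or missing loss_type is coerced back to plain "kto".
        if loss_type == "apo_zero_unpaired":
            return loss_type, False
        return "kto", True

    assert resolve_kto_settings("apo_zero_unpaired") == ("apo_zero_unpaired", False)
    assert resolve_kto_settings("kto") == ("kto", True)
    assert resolve_kto_settings(None) == ("kto", True)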
@@ -90,7 +98,7 @@ class CustomKTOTrainer(KTOTrainer):
             if not (
                 getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
             ):  # quantized models are already set on the correct device
-                self.ref_model = self._prepare_deepspeed(self.ref_model)
+                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
                 self.ref_model.eval()

View File

@@ -33,12 +33,12 @@ from transformers.trainer_pt_utils import remove_dummy_checkpoint
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
 from trl import PPOConfig, PPOTrainer
-from trl.core import PPODecorators, logprobs_from_logits
+from trl import __version__ as trl_version
 from trl.models.utils import unwrap_model_for_generation
 from typing_extensions import override

 from ...extras import logging
-from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
+from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor, torch_gc
 from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
@@ -83,6 +83,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         if eval_dataset is not None:
             raise NotImplementedError("PPOTrainer does not support eval dataset yet.")

+        # Check if TRL version is compatible (0.8.6 <= version <= 0.9.6)
+        try:
+            from transformers.utils.versions import require_version
+
+            require_version(
+                "trl>=0.8.6,<=0.9.6",
+                "Incompatible TRL version detected. LLaMA-Factory ppo requires TRL version >=0.8.6,<=0.9.6. "
+                f"Found version {trl_version}. Please install the correct version with: `pip install trl>=0.8.6,<=0.9.6`\n"
+                "To fix: run `DISABLE_VERSION_CHECK=1 llamafactory-cli train example_ppo.yaml`\n",
+            )
+        except ImportError as e:
+            raise e
+
         backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
         ppo_config = PPOConfig(
             model_name=model_args.model_name_or_path,
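
For reference, transformers' require_version raises ImportError when the installed version falls outside the requested range, so with trl 0.18 to 0.24 installed the PPO entry point fails fast with the hint above. A minimal sketch of that behavior (the hint string here is illustrative):

    from transformers.utils.versions import require_version

    try:
        require_version("trl>=0.8.6,<=0.9.6", "The PPO path still targets the legacy TRL PPOTrainer API.")
    except ImportError as err:
        print(err)  # mentions the requirement, the installed version, and the hint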
@@ -406,7 +419,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         return rewards.float().detach()  # use fp32 type

     @override
-    @PPODecorators.empty_device_cache()
     def batched_forward_pass(
         self,
         model: "AutoModelForCausalLMWithValueHead",
@@ -420,6 +432,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         Subclass and override to inject custom behavior.
         """
+        from trl.core import logprobs_from_logits
+
+        torch_gc()
         bs = len(queries)
         fbs = self.config.mini_batch_size
         all_logprobs = []
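
logprobs_from_logits is now imported lazily inside the method, presumably so the module can still be imported when trl.core no longer exposes it (it is only needed on the PPO path, which is version-gated above). The computation itself is small; a sketch of the standard formulation, not necessarily identical to the upstream helper:

    import torch
    import torch.nn.functional as F

    def logprobs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # log-probability of each label token under the logits:
        # logits (..., seq, vocab), labels (..., seq) -> (..., seq)
        logp = F.log_softmax(logits, dim=-1)
        return torch.gather(logp, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)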

View File

@@ -108,7 +108,7 @@ def create_modelcard_and_push(
     elif training_args.push_to_hub:
         trainer.push_to_hub(**kwargs)
     else:
-        trainer.create_model_card(license="other", **kwargs)  # prevent from connecting to hub
+        Trainer.create_model_card(trainer, license="other", **kwargs)  # prevent from connecting to hub


 def create_ref_model(
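
Calling Trainer.create_model_card unbound pins the call to the base-class implementation, so any create_model_card override defined on a subclass (for example by a TRL trainer) is bypassed, matching the "prevent from connecting to hub" intent in the comment. A toy illustration of the dispatch difference (class names hypothetical):

    class Base:
        def create_model_card(self, **kwargs):
            return "base card, written locally"

    class Subclass(Base):
        def create_model_card(self, **kwargs):
            return "override that might reach the Hub"

    trainer = Subclass()
    print(trainer.create_model_card())      # normal dispatch uses the override
    print(Base.create_model_card(trainer))  # unbound call pins the base implementation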