[deps] upgrade transformers to 4.50.0 (#7437)

* upgrade transformers

* fix hf cache

* fix dpo trainer
This commit is contained in:
hoshi-hiyouga 2025-03-23 17:44:27 +08:00 committed by GitHub
parent dfbe1391e9
commit b1b78daf06
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 10 additions and 10 deletions

View File

@@ -1,5 +1,5 @@
transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.*;python_version<'3.10'
transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.0;python_version>='3.10'
transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.*;python_version<'3.10'
transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.0;python_version>='3.10'
datasets>=2.16.0,<=3.3.2
accelerate>=0.34.0,<=1.4.0
peft>=0.11.1,<=0.12.0

View File

@@ -19,7 +19,7 @@ Level:
Dependency graph:
main:
transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.0
transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.0
datasets>=2.16.0,<=3.3.2
accelerate>=0.34.0,<=1.4.0
peft>=0.11.1,<=0.12.0

View File

@@ -17,7 +17,7 @@ import os
from dataclasses import dataclass
from typing import Any, Literal, Optional
from transformers.utils import cached_file
from huggingface_hub import hf_hub_download
from ..extras.constants import DATA_CONFIG
from ..extras.misc import use_modelscope, use_openmind
@@ -99,7 +99,7 @@ def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> li
dataset_info = None
else:
if dataset_dir.startswith("REMOTE:"):
config_path = cached_file(path_or_repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
config_path = hf_hub_download(repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
else:
config_path = os.path.join(dataset_dir, DATA_CONFIG)

View File

@@ -88,7 +88,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
def check_dependencies() -> None:
r"""Check the version of the required packages."""
check_version("transformers>=4.41.2,<=4.49.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
check_version("transformers>=4.41.2,<=4.50.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
check_version("datasets>=2.16.0,<=3.3.2")
check_version("accelerate>=0.34.0,<=1.4.0")
check_version("peft>=0.11.1,<=0.12.0")

View File

@@ -128,9 +128,9 @@ class CustomDPOTrainer(DPOTrainer):
return super()._get_train_sampler()
@override
def get_batch_samples(self, epoch_iterator, num_batches):
def get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs):
r"""Replace the method of DPO Trainer with the one of the standard Trainer."""
return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
return Trainer.get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs)
def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
r"""Compute ORPO's odds ratio (OR) loss for batched log probabilities of the policy model."""

View File

@@ -127,9 +127,9 @@ class CustomKTOTrainer(KTOTrainer):
return Trainer._get_train_sampler(self)
@override
def get_batch_samples(self, epoch_iterator, num_batches):
def get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs):
r"""Replace the method of KTO Trainer with the one of the standard Trainer."""
return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
return Trainer.get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs)
@override
def forward(