From 05b19d695250194ab5f98512439b00d03b4367dc Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Sun, 23 Mar 2025 17:44:27 +0800 Subject: [PATCH] [deps] upgrade transformers to 4.50.0 (#7437) * upgrade transformers * fix hf cache * fix dpo trainer --- requirements.txt | 4 ++-- src/llamafactory/__init__.py | 2 +- src/llamafactory/data/parser.py | 4 ++-- src/llamafactory/extras/misc.py | 2 +- src/llamafactory/train/dpo/trainer.py | 4 ++-- src/llamafactory/train/kto/trainer.py | 4 ++-- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/requirements.txt b/requirements.txt index 81de13e8..8b666ce3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.*;python_version<'3.10' -transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.0;python_version>='3.10' +transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.*;python_version<'3.10' +transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.0;python_version>='3.10' datasets>=2.16.0,<=3.3.2 accelerate>=0.34.0,<=1.4.0 peft>=0.11.1,<=0.12.0 diff --git a/src/llamafactory/__init__.py b/src/llamafactory/__init__.py index 47405aeb..b23f3120 100644 --- a/src/llamafactory/__init__.py +++ b/src/llamafactory/__init__.py @@ -19,7 +19,7 @@ Level: Dependency graph: main: - transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.0 + transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.0 datasets>=2.16.0,<=3.3.2 accelerate>=0.34.0,<=1.4.0 peft>=0.11.1,<=0.12.0 diff --git a/src/llamafactory/data/parser.py b/src/llamafactory/data/parser.py index ccc1bdcd..27bff26a 100644 --- a/src/llamafactory/data/parser.py +++ b/src/llamafactory/data/parser.py @@ -17,7 +17,7 @@ import os from dataclasses import dataclass from typing import Any, Literal, Optional -from transformers.utils import cached_file +from huggingface_hub import hf_hub_download from ..extras.constants import DATA_CONFIG from ..extras.misc import use_modelscope, use_openmind @@ -99,7 +99,7 @@ def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> li dataset_info = None else: if dataset_dir.startswith("REMOTE:"): - config_path = cached_file(path_or_repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset") + config_path = hf_hub_download(repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset") else: config_path = os.path.join(dataset_dir, DATA_CONFIG) diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py index 6a56606a..7310800a 100644 --- a/src/llamafactory/extras/misc.py +++ b/src/llamafactory/extras/misc.py @@ -88,7 +88,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None: def check_dependencies() -> None: r"""Check the version of the required packages.""" - check_version("transformers>=4.41.2,<=4.49.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0") + check_version("transformers>=4.41.2,<=4.50.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0") check_version("datasets>=2.16.0,<=3.3.2") check_version("accelerate>=0.34.0,<=1.4.0") check_version("peft>=0.11.1,<=0.12.0") diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py index 98c22022..7d5aad9f 100644 --- a/src/llamafactory/train/dpo/trainer.py +++ b/src/llamafactory/train/dpo/trainer.py @@ -128,9 +128,9 @@ class CustomDPOTrainer(DPOTrainer): return super()._get_train_sampler() @override - def get_batch_samples(self, epoch_iterator, num_batches): + def get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs): r"""Replace the method of DPO Trainer with the one of the standard Trainer.""" - return Trainer.get_batch_samples(self, epoch_iterator, num_batches) + return Trainer.get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs) def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor": r"""Compute ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.""" diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py index 0409c305..5f620b18 100644 --- a/src/llamafactory/train/kto/trainer.py +++ b/src/llamafactory/train/kto/trainer.py @@ -127,9 +127,9 @@ class CustomKTOTrainer(KTOTrainer): return Trainer._get_train_sampler(self) @override - def get_batch_samples(self, epoch_iterator, num_batches): + def get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs): r"""Replace the method of KTO Trainer with the one of the standard Trainer.""" - return Trainer.get_batch_samples(self, epoch_iterator, num_batches) + return Trainer.get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs) @override def forward(