mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 03:32:50 +08:00
[deps] upgrade transformers to 4.50.0 (#7437)
* upgrade transformers * fix hf cache * fix dpo trainer
This commit is contained in:
parent
dfbe1391e9
commit
b1b78daf06
@ -1,5 +1,5 @@
|
||||
transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.*;python_version<'3.10'
|
||||
transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.0;python_version>='3.10'
|
||||
transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.*;python_version<'3.10'
|
||||
transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.0;python_version>='3.10'
|
||||
datasets>=2.16.0,<=3.3.2
|
||||
accelerate>=0.34.0,<=1.4.0
|
||||
peft>=0.11.1,<=0.12.0
|
||||
|
@ -19,7 +19,7 @@ Level:
|
||||
|
||||
Dependency graph:
|
||||
main:
|
||||
transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.0
|
||||
transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.0
|
||||
datasets>=2.16.0,<=3.3.2
|
||||
accelerate>=0.34.0,<=1.4.0
|
||||
peft>=0.11.1,<=0.12.0
|
||||
|
@ -17,7 +17,7 @@ import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from transformers.utils import cached_file
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from ..extras.constants import DATA_CONFIG
|
||||
from ..extras.misc import use_modelscope, use_openmind
|
||||
@ -99,7 +99,7 @@ def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> li
|
||||
dataset_info = None
|
||||
else:
|
||||
if dataset_dir.startswith("REMOTE:"):
|
||||
config_path = cached_file(path_or_repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
|
||||
config_path = hf_hub_download(repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
|
||||
else:
|
||||
config_path = os.path.join(dataset_dir, DATA_CONFIG)
|
||||
|
||||
|
@ -88,7 +88,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
|
||||
|
||||
def check_dependencies() -> None:
|
||||
r"""Check the version of the required packages."""
|
||||
check_version("transformers>=4.41.2,<=4.49.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
|
||||
check_version("transformers>=4.41.2,<=4.50.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
|
||||
check_version("datasets>=2.16.0,<=3.3.2")
|
||||
check_version("accelerate>=0.34.0,<=1.4.0")
|
||||
check_version("peft>=0.11.1,<=0.12.0")
|
||||
|
@ -128,9 +128,9 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
return super()._get_train_sampler()
|
||||
|
||||
@override
|
||||
def get_batch_samples(self, epoch_iterator, num_batches):
|
||||
def get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs):
|
||||
r"""Replace the method of DPO Trainer with the one of the standard Trainer."""
|
||||
return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
|
||||
return Trainer.get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs)
|
||||
|
||||
def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
|
||||
r"""Compute ORPO's odds ratio (OR) loss for batched log probabilities of the policy model."""
|
||||
|
@ -127,9 +127,9 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
return Trainer._get_train_sampler(self)
|
||||
|
||||
@override
|
||||
def get_batch_samples(self, epoch_iterator, num_batches):
|
||||
def get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs):
|
||||
r"""Replace the method of KTO Trainer with the one of the standard Trainer."""
|
||||
return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
|
||||
return Trainer.get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs)
|
||||
|
||||
@override
|
||||
def forward(
|
||||
|
Loading…
x
Reference in New Issue
Block a user