mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-14 23:58:11 +08:00
[deps] upgrade transformers to 4.50.0 (#7437)
* upgrade transformers
* fix hf cache
* fix dpo trainer
This commit is contained in:
parent
919415dba9
commit
05b19d6952
@ -1,5 +1,5 @@
|
|||||||
transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.*;python_version<'3.10'
|
transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.*;python_version<'3.10'
|
||||||
transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.0;python_version>='3.10'
|
transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.0;python_version>='3.10'
|
||||||
datasets>=2.16.0,<=3.3.2
|
datasets>=2.16.0,<=3.3.2
|
||||||
accelerate>=0.34.0,<=1.4.0
|
accelerate>=0.34.0,<=1.4.0
|
||||||
peft>=0.11.1,<=0.12.0
|
peft>=0.11.1,<=0.12.0
|
||||||
|
@ -19,7 +19,7 @@ Level:
|
|||||||
|
|
||||||
Dependency graph:
|
Dependency graph:
|
||||||
main:
|
main:
|
||||||
transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.0
|
transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.0
|
||||||
datasets>=2.16.0,<=3.3.2
|
datasets>=2.16.0,<=3.3.2
|
||||||
accelerate>=0.34.0,<=1.4.0
|
accelerate>=0.34.0,<=1.4.0
|
||||||
peft>=0.11.1,<=0.12.0
|
peft>=0.11.1,<=0.12.0
|
||||||
|
@ -17,7 +17,7 @@ import os
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Literal, Optional
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
from transformers.utils import cached_file
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
from ..extras.constants import DATA_CONFIG
|
from ..extras.constants import DATA_CONFIG
|
||||||
from ..extras.misc import use_modelscope, use_openmind
|
from ..extras.misc import use_modelscope, use_openmind
|
||||||
@ -99,7 +99,7 @@ def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> li
|
|||||||
dataset_info = None
|
dataset_info = None
|
||||||
else:
|
else:
|
||||||
if dataset_dir.startswith("REMOTE:"):
|
if dataset_dir.startswith("REMOTE:"):
|
||||||
config_path = cached_file(path_or_repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
|
config_path = hf_hub_download(repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
|
||||||
else:
|
else:
|
||||||
config_path = os.path.join(dataset_dir, DATA_CONFIG)
|
config_path = os.path.join(dataset_dir, DATA_CONFIG)
|
||||||
|
|
||||||
|
@ -88,7 +88,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
|
|||||||
|
|
||||||
def check_dependencies() -> None:
|
def check_dependencies() -> None:
|
||||||
r"""Check the version of the required packages."""
|
r"""Check the version of the required packages."""
|
||||||
check_version("transformers>=4.41.2,<=4.49.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
|
check_version("transformers>=4.41.2,<=4.50.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
|
||||||
check_version("datasets>=2.16.0,<=3.3.2")
|
check_version("datasets>=2.16.0,<=3.3.2")
|
||||||
check_version("accelerate>=0.34.0,<=1.4.0")
|
check_version("accelerate>=0.34.0,<=1.4.0")
|
||||||
check_version("peft>=0.11.1,<=0.12.0")
|
check_version("peft>=0.11.1,<=0.12.0")
|
||||||
|
@ -128,9 +128,9 @@ class CustomDPOTrainer(DPOTrainer):
|
|||||||
return super()._get_train_sampler()
|
return super()._get_train_sampler()
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def get_batch_samples(self, epoch_iterator, num_batches):
|
def get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs):
|
||||||
r"""Replace the method of DPO Trainer with the one of the standard Trainer."""
|
r"""Replace the method of DPO Trainer with the one of the standard Trainer."""
|
||||||
return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
|
return Trainer.get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs)
|
||||||
|
|
||||||
def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
|
def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
|
||||||
r"""Compute ORPO's odds ratio (OR) loss for batched log probabilities of the policy model."""
|
r"""Compute ORPO's odds ratio (OR) loss for batched log probabilities of the policy model."""
|
||||||
|
@ -127,9 +127,9 @@ class CustomKTOTrainer(KTOTrainer):
|
|||||||
return Trainer._get_train_sampler(self)
|
return Trainer._get_train_sampler(self)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def get_batch_samples(self, epoch_iterator, num_batches):
|
def get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs):
|
||||||
r"""Replace the method of KTO Trainer with the one of the standard Trainer."""
|
r"""Replace the method of KTO Trainer with the one of the standard Trainer."""
|
||||||
return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
|
return Trainer.get_batch_samples(self, epoch_iterator, num_batches, *args, **kwargs)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def forward(
|
def forward(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user