[deps] Update for Transformers v5 (#9569)

Author: tangefly
Date: 2025-12-08 01:13:32 +08:00
Committed by: GitHub
Parent: 109162dc56
Commit: 739954910a
2 changed files with 20 additions and 4 deletions

View File

@@ -21,13 +21,13 @@ from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_templat
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
 from ...extras.misc import calculate_tps
+from ...extras.packages import is_transformers_version_greater_than
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..trainer_utils import create_modelcard_and_push
 from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
 from .trainer import CustomSeq2SeqTrainer


 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
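
Context for the added import: the workflow gates the new generation-kwargs code path on the installed Transformers version. As a rough sketch, a helper like is_transformers_version_greater_than can be built on importlib.metadata and packaging (a hypothetical reimplementation for illustration, not necessarily the actual helper in llamafactory.extras.packages):

from functools import lru_cache
from importlib.metadata import version as installed_version

from packaging.version import parse


@lru_cache
def is_transformers_version_greater_than(content: str) -> bool:
    # Hypothetical sketch: packaging normalizes pre-release tags, so a
    # threshold like "5.0.0RC0" parses as 5.0.0rc0 and compares correctly.
    return parse(installed_version("transformers")) >= parse(content)

Parsing the threshold (rather than string-comparing it) is what makes a pre-release gate like "5.0.0RC0" match 5.x releases but not 4.x.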
@@ -75,7 +75,16 @@ def run_sft(
     # Keyword arguments for `model.generate`
     gen_kwargs = generating_args.to_dict(obey_generation_config=True)
-    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
+    # Compatible with Transformers v4 and Transformers v5
+    if is_transformers_version_greater_than("5.0.0RC0"):
+        extra_special_tokens = getattr(tokenizer, "_extra_special_tokens", [])
+        extra_ids = tokenizer.convert_tokens_to_ids(extra_special_tokens)
+        all_eos_ids = [tokenizer.eos_token_id] + [i for i in extra_ids if i != -1]
+        unique_eos_ids = list(dict.fromkeys(all_eos_ids))
+        gen_kwargs["eos_token_id"] = unique_eos_ids
+    else:
+        gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
     gen_kwargs["pad_token_id"] = tokenizer.pad_token_id

     # Initialize our Trainer
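
A note on the new v5 branch above: ids for tokens missing from the vocab are filtered out (the code treats -1 as the missing-token sentinel), and dict.fromkeys deduplicates while preserving insertion order, so tokenizer.eos_token_id stays first. A standalone illustration with made-up ids:

# Illustration of the eos-id handling above; all ids are made up.
eos_token_id = 2
extra_ids = [2, 32000, -1, 32001, 32000]  # -1 stands for a token not in the vocab

all_eos_ids = [eos_token_id] + [i for i in extra_ids if i != -1]  # drop sentinels
unique_eos_ids = list(dict.fromkeys(all_eos_ids))  # order-preserving dedup
assert unique_eos_ids == [2, 32000, 32001]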
@@ -132,4 +141,4 @@ def run_sft(
         trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)

     # Create model card
-    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
\ No newline at end of file
+    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)

View File

@@ -15,7 +15,14 @@
 import os

 import pytest
-from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
+from transformers.utils import is_flash_attn_2_available
+
+# Compatible with Transformers v4 and Transformers v5
+try:
+    from transformers.utils import is_torch_sdpa_available
+except ImportError:
+    def is_torch_sdpa_available():
+        return True

 from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.train.test_utils import load_infer_model
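
The try/except shim keeps the test module importable under both major versions: Transformers v4 still exports is_torch_sdpa_available, while the import fails on v5, where the fallback simply reports SDPA as available. An illustrative way such a shim gets consumed as a pytest guard (a hypothetical test case, not this file's actual tests):

import pytest

try:
    from transformers.utils import is_torch_sdpa_available
except ImportError:  # removed in Transformers v5
    def is_torch_sdpa_available() -> bool:
        return True  # assume sdpa is available on v5


@pytest.mark.skipif(not is_torch_sdpa_available(), reason="torch sdpa is unavailable")
def test_sdpa_attention_smoke():
    # Hypothetical placeholder body for illustration only.
    assert is_torch_sdpa_available()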