From 739954910a4db837b9a7f7b03e48b71898bad810 Mon Sep 17 00:00:00 2001
From: tangefly <124695565+tangefly@users.noreply.github.com>
Date: Mon, 8 Dec 2025 01:13:32 +0800
Subject: [PATCH] [deps] Update for Transformers v5 (#9569)

---
 src/llamafactory/train/sft/workflow.py    | 15 ++++++++++++---
 tests/model/model_utils/test_attention.py |  9 ++++++++-
 2 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py
index aab1cb13..7acf7595 100644
--- a/src/llamafactory/train/sft/workflow.py
+++ b/src/llamafactory/train/sft/workflow.py
@@ -21,13 +21,13 @@ from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_templat
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
 from ...extras.misc import calculate_tps
+from ...extras.packages import is_transformers_version_greater_than
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..trainer_utils import create_modelcard_and_push
 from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
 from .trainer import CustomSeq2SeqTrainer
 
-
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
 
@@ -75,7 +75,16 @@ def run_sft(
 
     # Keyword arguments for `model.generate`
     gen_kwargs = generating_args.to_dict(obey_generation_config=True)
-    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
+
+    # Compatible with Transformers v4 and Transformers v5
+    if is_transformers_version_greater_than("5.0.0RC0"):
+        extra_special_tokens = getattr(tokenizer, "_extra_special_tokens", [])
+        extra_ids = tokenizer.convert_tokens_to_ids(extra_special_tokens)
+        all_eos_ids = [tokenizer.eos_token_id] + [i for i in extra_ids if i != -1]
+        unique_eos_ids = list(dict.fromkeys(all_eos_ids))
+        gen_kwargs["eos_token_id"] = unique_eos_ids
+    else:
+        gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
     gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
 
     # Initialize our Trainer
@@ -132,4 +141,4 @@ def run_sft(
         trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)
 
     # Create model card
-    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
+    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
\ No newline at end of file
diff --git a/tests/model/model_utils/test_attention.py b/tests/model/model_utils/test_attention.py
index 446d8063..6bf51639 100644
--- a/tests/model/model_utils/test_attention.py
+++ b/tests/model/model_utils/test_attention.py
@@ -15,7 +15,14 @@
 import os
 
 import pytest
-from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
+from transformers.utils import is_flash_attn_2_available
+
+# Compatible with Transformers v4 and Transformers v5
+try:
+    from transformers.utils import is_torch_sdpa_available
+except ImportError:
+    def is_torch_sdpa_available():
+        return True
 
 from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.train.test_utils import load_infer_model
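
For context on the workflow.py hunk above: the patch assumes that on Transformers v5 the tokenizer no longer exposes the extra end-of-sequence ids through `additional_special_tokens_ids`, so it rebuilds them from the tokenizer's special tokens and de-duplicates the result. Below is a minimal standalone sketch of that id-collection logic. It probes attributes with `hasattr` instead of the repository's `is_transformers_version_greater_than` check so it runs without LLaMA-Factory installed, always de-duplicates (the patch only de-duplicates on the v5 branch), and `collect_eos_token_ids` is a hypothetical helper name used only for illustration. The `_extra_special_tokens` attribute is the private v5 field the patch itself relies on.

# Illustrative sketch only; not part of the patch.
def collect_eos_token_ids(tokenizer):  # hypothetical helper name
    """Gather eos_token_id plus any extra special-token ids, v4/v5 compatible."""
    if hasattr(tokenizer, "additional_special_tokens_ids"):
        # Transformers v4 path: the public property already returns token ids.
        extra_ids = list(tokenizer.additional_special_tokens_ids)
    else:
        # Transformers v5 path (as assumed by the patch): resolve the private
        # `_extra_special_tokens` tokens to ids and drop unknown (-1) entries.
        extra_tokens = getattr(tokenizer, "_extra_special_tokens", [])
        extra_ids = [i for i in tokenizer.convert_tokens_to_ids(extra_tokens) if i != -1]
    # Keep eos_token_id first and de-duplicate while preserving order,
    # mirroring list(dict.fromkeys(...)) in the patch.
    return list(dict.fromkeys([tokenizer.eos_token_id] + extra_ids))

Usage would look like gen_kwargs["eos_token_id"] = collect_eos_token_ids(tokenizer), which should match what the patched workflow produces on either Transformers version, apart from the extra de-duplication on the v4 path noted above.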