mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-15 03:10:35 +08:00)
[deps] Update for Transformers v5 (#9569)
@@ -21,13 +21,13 @@ from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_templat
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
 from ...extras.misc import calculate_tps
 from ...extras.packages import is_transformers_version_greater_than
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..trainer_utils import create_modelcard_and_push
 from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
 from .trainer import CustomSeq2SeqTrainer
 
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
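
The `is_transformers_version_greater_than` import above gates the v5-only path in the next hunk. As a rough standalone sketch (an assumption about the helper, whose real implementation lives in `llamafactory.extras.packages` and may differ), such a gate can be built on `packaging`:

# Hedged sketch of a version gate; illustrative, not the repo's implementation.
import importlib.metadata

from packaging import version

def is_transformers_version_greater_than(target: str) -> bool:
    # Assumes an at-or-above comparison; "5.0.0RC0" is normalized by
    # packaging to the 5.0.0rc0 prerelease, so v5 release candidates match.
    installed = version.parse(importlib.metadata.version("transformers"))
    return installed >= version.parse(target)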
@@ -75,7 +75,16 @@ def run_sft(
 
     # Keyword arguments for `model.generate`
     gen_kwargs = generating_args.to_dict(obey_generation_config=True)
-    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
+
+    # Compatible with Transformers v4 and Transformers v5
+    if is_transformers_version_greater_than("5.0.0RC0"):
+        extra_special_tokens = getattr(tokenizer, "_extra_special_tokens", [])
+        extra_ids = tokenizer.convert_tokens_to_ids(extra_special_tokens)
+        all_eos_ids = [tokenizer.eos_token_id] + [i for i in extra_ids if i != -1]
+        unique_eos_ids = list(dict.fromkeys(all_eos_ids))
+        gen_kwargs["eos_token_id"] = unique_eos_ids
+    else:
+        gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
     gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
 
     # Initialize our Trainer
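
On v5 the branch above reads the tokenizer's private `_extra_special_tokens` list instead of v4's `additional_special_tokens_ids`, drops unresolved ids (-1), and deduplicates while preserving order with `dict.fromkeys`. A standalone sketch of that dedup step, with made-up token ids:

# Illustrative values only; the real ids come from the tokenizer.
eos_token_id = 2
extra_ids = [2, 32000, -1, 32001]  # -1 is filtered out as unresolved

all_eos_ids = [eos_token_id] + [i for i in extra_ids if i != -1]
unique_eos_ids = list(dict.fromkeys(all_eos_ids))  # order-preserving dedup
assert unique_eos_ids == [2, 32000, 32001]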
@@ -132,4 +141,4 @@ def run_sft(
         trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)
 
     # Create model card
-    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
+    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
@@ -15,7 +15,14 @@
 import os
 
 import pytest
-from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
+from transformers.utils import is_flash_attn_2_available
+
+# Compatible with Transformers v4 and Transformers v5
+try:
+    from transformers.utils import is_torch_sdpa_available
+except ImportError:
+    def is_torch_sdpa_available():
+        return True
 
 from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.train.test_utils import load_infer_model
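
The try/except above is a general pattern for symbols removed between major versions: keep the real import where it exists and fall back to a stub with the same signature, so call sites need no version checks. A standalone sketch of the pattern (the True fallback follows the diff, which assumes SDPA is always available when the helper is gone):

try:
    from transformers.utils import is_torch_sdpa_available
except ImportError:  # removed upstream; the stub keeps call sites unchanged
    def is_torch_sdpa_available() -> bool:
        return True

# Call sites stay version-agnostic:
attn_implementation = "sdpa" if is_torch_sdpa_available() else "eager"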