Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-17)
[deps] Update for Transformers v5 (#9569)
@@ -21,13 +21,13 @@ from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_templat
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
 from ...extras.misc import calculate_tps
+from ...extras.packages import is_transformers_version_greater_than
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..trainer_utils import create_modelcard_and_push
 from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
 from .trainer import CustomSeq2SeqTrainer
 
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
 
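The hunk above adds the version gate that the generation-argument change below relies on. As a rough, self-contained illustration of what such a helper does (an assumption for illustration, not LLaMA-Factory's actual implementation in llamafactory.extras.packages), it compares the installed transformers version against a threshold string:

# Illustrative stand-in for is_transformers_version_greater_than;
# details may differ from the real helper in llamafactory.extras.packages.
from functools import lru_cache

import transformers
from packaging import version


@lru_cache
def transformers_version_greater_than(content: str) -> bool:
    # "5.0.0RC0" parses as the PEP 440 pre-release 5.0.0rc0.
    return version.parse(transformers.__version__) >= version.parse(content)


print(transformers_version_greater_than("5.0.0RC0"))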
@@ -75,7 +75,16 @@ def run_sft(
 
     # Keyword arguments for `model.generate`
     gen_kwargs = generating_args.to_dict(obey_generation_config=True)
-    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
+    # Compatible with Transformers v4 and Transformers v5
+    if is_transformers_version_greater_than("5.0.0RC0"):
+        extra_special_tokens = getattr(tokenizer, "_extra_special_tokens", [])
+        extra_ids = tokenizer.convert_tokens_to_ids(extra_special_tokens)
+        all_eos_ids = [tokenizer.eos_token_id] + [i for i in extra_ids if i != -1]
+        unique_eos_ids = list(dict.fromkeys(all_eos_ids))
+        gen_kwargs["eos_token_id"] = unique_eos_ids
+    else:
+        gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
     gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
 
     # Initialize our Trainer
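The v5 branch above collects the primary EOS id plus the ids of any extra special tokens, drops tokens the tokenizer could not resolve (id -1), and deduplicates while preserving order, so gen_kwargs["eos_token_id"] keeps the same meaning it had with additional_special_tokens_ids on v4. A minimal sketch of that collection step, with made-up ids standing in for a real tokenizer:

# Stand-in values; a real run would take these from tokenizer.eos_token_id and
# tokenizer.convert_tokens_to_ids(extra_special_tokens).
eos_token_id = 2
extra_ids = [32000, -1, 2, 32001]  # -1 marks a token the tokenizer could not resolve

all_eos_ids = [eos_token_id] + [i for i in extra_ids if i != -1]  # drop unresolved ids
unique_eos_ids = list(dict.fromkeys(all_eos_ids))  # dedupe, keep first-seen order

print(unique_eos_ids)  # [2, 32000, 32001]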
@@ -15,7 +15,14 @@
 import os
 
 import pytest
-from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
+from transformers.utils import is_flash_attn_2_available
+
+# Compatible with Transformers v4 and Transformers v5
+try:
+    from transformers.utils import is_torch_sdpa_available
+except ImportError:
+    def is_torch_sdpa_available():
+        return True
 
 from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.train.test_utils import load_infer_model
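The test hunk keeps the module importable on both major versions: is_torch_sdpa_available is imported from transformers.utils when it exists and stubbed to return True when Transformers v5 no longer exposes it. A hedged sketch of how such a guard is typically consumed in a test (the test below is a placeholder; the repository's tests use load_infer_model instead):

import pytest

# Same fallback pattern as in the diff above: stub the helper when it is gone.
try:
    from transformers.utils import is_torch_sdpa_available
except ImportError:
    def is_torch_sdpa_available() -> bool:
        return True  # assume SDPA is available when the helper no longer exists


@pytest.mark.skipif(not is_torch_sdpa_available(), reason="SDPA not available.")
def test_sdpa_guard_placeholder():
    assert is_torch_sdpa_available()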