diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py
index 5d9a85f87..38ddb90dc 100644
--- a/src/llamafactory/train/tuner.py
+++ b/src/llamafactory/train/tuner.py
@@ -24,7 +24,7 @@ from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.misc import find_available_port, get_device_name, get_torch_device, infer_optim_dtype
-from ..extras.packages import is_mcore_adapter_available, is_ray_available
+from ..extras.packages import is_mcore_adapter_available, is_ray_available, is_transformers_version_greater_than
 from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
 from ..model import load_model, load_tokenizer
 from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
@@ -160,17 +160,28 @@ def export_model(args: Optional[dict[str, Any]] = None) -> None:
         model = model.to(output_dtype)
         logger.info_rank0(f"Convert model dtype to: {output_dtype}.")
 
-    model.save_pretrained(
-        save_directory=model_args.export_dir,
-        max_shard_size=f"{model_args.export_size}GB",
-        safe_serialization=(not model_args.export_legacy_format),
-    )
+    # Prepare save arguments (safe_serialization removed in transformers v5.0.0)
+    save_kwargs = {
+        "save_directory": model_args.export_dir,
+        "max_shard_size": f"{model_args.export_size}GB",
+    }
+    if not is_transformers_version_greater_than("5.0.0"):
+        save_kwargs["safe_serialization"] = not model_args.export_legacy_format
+
+    model.save_pretrained(**save_kwargs)
+
     if model_args.export_hub_model_id is not None:
+        # Prepare push arguments (safe_serialization removed in transformers v5.0.0)
+        push_kwargs = {
+            "max_shard_size": f"{model_args.export_size}GB",
+        }
+        if not is_transformers_version_greater_than("5.0.0"):
+            push_kwargs["safe_serialization"] = not model_args.export_legacy_format
+
         model.push_to_hub(
             model_args.export_hub_model_id,
             token=model_args.hf_hub_token,
-            max_shard_size=f"{model_args.export_size}GB",
-            safe_serialization=(not model_args.export_legacy_format),
+            **push_kwargs,
         )
 
     if finetuning_args.stage == "rm":
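
The diff gates the `safe_serialization` keyword on the installed transformers version via `is_transformers_version_greater_than` from `..extras.packages`. Building the kwargs dict first and only adding `safe_serialization` on pre-5.0.0 versions avoids passing an argument that `save_pretrained`/`push_to_hub` no longer accept, which would otherwise raise a `TypeError` under transformers v5. Below is a minimal sketch of how such a version-check helper could work, assuming it compares the installed package version against a threshold using `packaging`; the actual helper lives in `src/llamafactory/extras/packages.py` and may differ in detail:

```python
from functools import lru_cache
from importlib.metadata import version as installed_version

from packaging import version


@lru_cache
def is_transformers_version_greater_than(content: str) -> bool:
    # Cache the comparison: the installed transformers version cannot
    # change within a single process, so repeated gate checks are free.
    return version.parse(installed_version("transformers")) >= version.parse(content)
```

With a helper like this, `is_transformers_version_greater_than("5.0.0")` returns False on a 4.x install, so the `safe_serialization` key is included exactly when the running transformers version still supports it.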