mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-02-26 07:45:59 +08:00
[misc] remove safe_serialization arg for transformers v5 compatibility (#10208)
Co-authored-by: P. Clawmogorov <262173731+Alm0stSurely@users.noreply.github.com>
This commit is contained in:
@@ -24,7 +24,7 @@ from ..data import get_template_and_fix_tokenizer
|
|||||||
from ..extras import logging
|
from ..extras import logging
|
||||||
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||||
from ..extras.misc import find_available_port, get_device_name, get_torch_device, infer_optim_dtype
|
from ..extras.misc import find_available_port, get_device_name, get_torch_device, infer_optim_dtype
|
||||||
from ..extras.packages import is_mcore_adapter_available, is_ray_available
|
from ..extras.packages import is_mcore_adapter_available, is_ray_available, is_transformers_version_greater_than
|
||||||
from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
|
from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
|
||||||
from ..model import load_model, load_tokenizer
|
from ..model import load_model, load_tokenizer
|
||||||
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
|
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
|
||||||
@@ -160,17 +160,28 @@ def export_model(args: Optional[dict[str, Any]] = None) -> None:
|
|||||||
model = model.to(output_dtype)
|
model = model.to(output_dtype)
|
||||||
logger.info_rank0(f"Convert model dtype to: {output_dtype}.")
|
logger.info_rank0(f"Convert model dtype to: {output_dtype}.")
|
||||||
|
|
||||||
model.save_pretrained(
|
# Prepare save arguments (safe_serialization removed in transformers v5.0.0)
|
||||||
save_directory=model_args.export_dir,
|
save_kwargs = {
|
||||||
max_shard_size=f"{model_args.export_size}GB",
|
"save_directory": model_args.export_dir,
|
||||||
safe_serialization=(not model_args.export_legacy_format),
|
"max_shard_size": f"{model_args.export_size}GB",
|
||||||
)
|
}
|
||||||
|
if not is_transformers_version_greater_than("5.0.0"):
|
||||||
|
save_kwargs["safe_serialization"] = not model_args.export_legacy_format
|
||||||
|
|
||||||
|
model.save_pretrained(**save_kwargs)
|
||||||
|
|
||||||
if model_args.export_hub_model_id is not None:
|
if model_args.export_hub_model_id is not None:
|
||||||
|
# Prepare push arguments (safe_serialization removed in transformers v5.0.0)
|
||||||
|
push_kwargs = {
|
||||||
|
"max_shard_size": f"{model_args.export_size}GB",
|
||||||
|
}
|
||||||
|
if not is_transformers_version_greater_than("5.0.0"):
|
||||||
|
push_kwargs["safe_serialization"] = not model_args.export_legacy_format
|
||||||
|
|
||||||
model.push_to_hub(
|
model.push_to_hub(
|
||||||
model_args.export_hub_model_id,
|
model_args.export_hub_model_id,
|
||||||
token=model_args.hf_hub_token,
|
token=model_args.hf_hub_token,
|
||||||
max_shard_size=f"{model_args.export_size}GB",
|
**push_kwargs,
|
||||||
safe_serialization=(not model_args.export_legacy_format),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if finetuning_args.stage == "rm":
|
if finetuning_args.stage == "rm":
|
||||||
|
|||||||
Reference in New Issue
Block a user