[feat] support LlamaFactory SFT training with the HyperParallel FSDP2 backend (#10289)

This commit is contained in:
Cui-yshoho
2026-03-30 10:47:20 +08:00
committed by GitHub
parent b5afabe3d2
commit 97433c53b6
5 changed files with 235 additions and 2 deletions

View File

@@ -24,7 +24,12 @@ from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.misc import find_available_port, get_device_name, get_torch_device, infer_optim_dtype
from ..extras.packages import is_mcore_adapter_available, is_ray_available, is_transformers_version_greater_than
from ..extras.packages import (
is_hyper_parallel_available,
is_mcore_adapter_available,
is_ray_available,
is_transformers_version_greater_than,
)
from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
@@ -71,7 +76,16 @@ def _training_function(config: dict[str, Any]) -> None:
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
if finetuning_args.stage in ["pt", "sft", "dpo"] and finetuning_args.use_mca:
if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel:
if not is_hyper_parallel_available():
raise ImportError(
"hyper_parallel is not installed. Please install it with `pip install hyper_parallel`."
)
from .hyper_parallel import run_sft as run_sft_hp
run_sft_hp(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
elif finetuning_args.stage in ["pt", "sft", "dpo"] and finetuning_args.use_mca:
if not is_mcore_adapter_available():
raise ImportError("mcore_adapter is not installed. Please install it with `pip install mcore-adapter`.")
if finetuning_args.stage == "pt":