[feat] support HyperParallel PT training and activation optimization (#10370)

This commit is contained in:
Cui-yshoho
2026-06-02 22:39:32 +08:00
committed by GitHub
parent a98a1ef101
commit 053d43c0ac
5 changed files with 326 additions and 44 deletions

View File

@@ -88,12 +88,19 @@ def _training_function(config: dict[str, Any]) -> None:
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel:
if finetuning_args.stage in ["pt", "sft"] and finetuning_args.use_hyper_parallel:
if not is_hyper_parallel_available():
raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")
from .hyper_parallel import run_sft as run_sft_hp
raise ImportError(
"hyper_parallel is not installed. Please install it with `pip install hyper_parallel`."
)
if finetuning_args.stage == "pt":
from .hyper_parallel import run_pt as run_pt_hp
run_sft_hp(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
run_pt_hp(model_args, data_args, training_args, finetuning_args, callbacks)
else:
from .hyper_parallel import run_sft as run_sft_hp
run_sft_hp(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
elif finetuning_args.stage in ["pt", "sft", "dpo"] and finetuning_args.use_mca:
if not is_mcore_adapter_available():