Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-29 10:10:35 +08:00)
[feat] support Megatron-LM training via mcore_adapter (#9237)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
@@ -24,7 +24,7 @@ from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.misc import infer_optim_dtype
-from ..extras.packages import is_ray_available
+from ..extras.packages import is_mcore_adapter_available, is_ray_available
 from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
 from ..model import load_model, load_tokenizer
 from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
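In short, this hunk imports an availability check for the optional mcore_adapter package, which the second hunk below uses to route the pt/sft/dpo stages to Megatron-based runners. The diff only names `is_mcore_adapter_available`; a minimal sketch of what such a helper in `extras/packages.py` presumably looks like, following the usual importlib-based pattern (the body here is an assumption, not the repository's actual code):

import importlib.util


def is_mcore_adapter_available() -> bool:
    # True when the optional mcore_adapter package is importable;
    # find_spec only locates the module, it does not import it.
    return importlib.util.find_spec("mcore_adapter") is not None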
@@ -66,7 +66,19 @@ def _training_function(config: dict[str, Any]) -> None:
 
     callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args))  # add to last
 
-    if finetuning_args.stage == "pt":
+    if finetuning_args.stage in ["pt", "sft", "dpo"] and finetuning_args.use_mca:
+        if not is_mcore_adapter_available():
+            raise ImportError("mcore_adapter is not installed. Please install it with `pip install mcore-adapter`.")
+        if finetuning_args.stage == "pt":
+            from .mca import run_pt as run_pt_mca
+            run_pt_mca(model_args, data_args, training_args, finetuning_args, callbacks)
+        elif finetuning_args.stage == "sft":
+            from .mca import run_sft as run_sft_mca
+            run_sft_mca(model_args, data_args, training_args, finetuning_args, callbacks)
+        else:  # dpo
+            from .mca import run_dpo as run_dpo_mca
+            run_dpo_mca(model_args, data_args, training_args, finetuning_args, callbacks)
+    elif finetuning_args.stage == "pt":
         run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
     elif finetuning_args.stage == "sft":
         run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
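With this guard, mcore_adapter stays an optional dependency: the `.mca` runners are imported lazily, and only after the availability check passes. Users opt in through `finetuning_args.use_mca`, which this hunk reads but does not define. A hedged sketch of how such a flag is typically declared in an HF-style arguments dataclass (the field name comes from the diff; the declaration itself is an assumption, and the real class in `..hparams` has many more fields):

from dataclasses import dataclass, field


@dataclass
class FinetuningArguments:
    # Hypothetical declaration of the opt-in flag read by _training_function.
    use_mca: bool = field(
        default=False,
        metadata={"help": "Whether to run pt/sft/dpo training through Megatron-LM via mcore_adapter."},
    )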