[model] gemma4 (#10346)

Kingsley authored on 2026-04-05 12:10:28 +08:00 (committed by GitHub)
parent acac63ef35
commit eae6f0b541
8 changed files with 576 additions and 7 deletions


@@ -48,7 +48,10 @@ def run_sft(
             "hyper_parallel is not installed. Please install it with `pip install hyper_parallel`."
         )
-    from hyper_parallel.integration.llamafactory import HyperParallelArguments, HyperParallelTrainer  # pylint: disable=C0415
+    from hyper_parallel.integration.llamafactory import (  # pylint: disable=C0415
+        HyperParallelArguments,
+        HyperParallelTrainer,
+    )
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
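
For readers unfamiliar with the guarded, deferred import shown in this hunk, here is a minimal sketch of the pattern. It is written under assumptions: the importlib.util.find_spec check and the require_hyper_parallel helper are illustrative only and are not part of this commit, which may rely on its own availability helper.

# Sketch of the optional-dependency pattern from the hunk above.
# ASSUMPTION: find_spec() is used purely for illustration; the repo may
# wrap the availability check in its own helper.
import importlib.util


def require_hyper_parallel():
    if importlib.util.find_spec("hyper_parallel") is None:
        raise ImportError(
            "hyper_parallel is not installed. Please install it with `pip install hyper_parallel`."
        )
    # Import is deferred so the package stays optional at module import time.
    from hyper_parallel.integration.llamafactory import (  # pylint: disable=C0415
        HyperParallelArguments,
        HyperParallelTrainer,
    )
    return HyperParallelArguments, HyperParallelTrainer
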
@@ -128,9 +131,10 @@ def run_sft(
     )
     if finetuning_args.use_badam:
-        from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore[import]
+        from types import MethodType
+        from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore[import]
         trainer.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, trainer.accelerator)
         trainer.add_callback(BAdamCallback)
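
The badam hunk works by rebinding a plain function as a bound method of an existing object via types.MethodType, so that clip_grad_norm_old_version receives trainer.accelerator as self. A minimal, self-contained sketch of that pattern follows; the Accelerator stub and the legacy_clip_grad_norm name are illustrative stand-ins, not code from the patch.

# Sketch of the MethodType rebinding used above.
# ASSUMPTION: Accelerator and legacy_clip_grad_norm stand in for
# trainer.accelerator and badam's clip_grad_norm_old_version.
from types import MethodType


class Accelerator:
    def clip_grad_norm_(self, parameters, max_norm):
        print("original clipping")


def legacy_clip_grad_norm(self, parameters, max_norm):
    # Once rebound, this receives the accelerator instance as `self`.
    print(f"legacy clipping with max_norm={max_norm}")


accelerator = Accelerator()
# Rebind the replacement onto this one instance; other instances keep the original method.
accelerator.clip_grad_norm_ = MethodType(legacy_clip_grad_norm, accelerator)
accelerator.clip_grad_norm_([], max_norm=1.0)  # prints "legacy clipping with max_norm=1.0"
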