Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-24 23:02:49 +08:00)

commit 756566342d (parent db569a2d61)

    adapt for badam with ds zero3

    Former-commit-id: 33b437277846d4f0b64c13a0bc892ef4f345a21e
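This commit wires layer-wise BAdam into DeepSpeed ZeRO-3 runs. For orientation, here is a minimal sketch of the setup it targets; the DeepSpeed config follows the standard ZeRO-3 format, and the BAdam fields mirror the finetuning_args attributes touched in the diff (use_badam, badam_mode, badam_switch_mode, badam_start_block, badam_verbose). All concrete values are illustrative assumptions, not part of the patch.

# Illustrative sketch only: a ZeRO-3 DeepSpeed config plus the BAdam settings this patch expects.
ds_zero3_config = {
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "zero_optimization": {"stage": 3},      # stage 3 shards parameters across ranks
}

badam_settings = {
    "use_badam": True,
    "badam_mode": "layer",                  # the ZeRO-3 path below asserts layer-wise mode
    "badam_switch_mode": "ascending",       # illustrative value
    "badam_start_block": 0,                 # illustrative value
    "badam_verbose": 1,                     # illustrative value
}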
@@ -184,12 +184,12 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     ):
         raise ValueError("Distributed training does not support layer-wise GaLore.")
 
-    if (
-        finetuning_args.use_badam
-        and finetuning_args.badam_mode == "layer"
-        and training_args.parallel_mode.value == "distributed"
-    ):
-        raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
+    # if (
+    #     finetuning_args.use_badam
+    #     and finetuning_args.badam_mode == "layer"
+    #     and training_args.parallel_mode.value == "distributed"
+    # ):
+    #     raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
 
     if (finetuning_args.use_galore or finetuning_args.use_badam) and training_args.deepspeed is not None:
         raise ValueError("GaLore and BAdam are incompatible with DeepSpeed yet.")
@@ -55,6 +55,21 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             output_dir = output_dir if output_dir is not None else self.args.output_dir
             getattr(self.processor, "image_processor").save_pretrained(output_dir)
 
+    def training_step(self, *args, **kwargs):
+        r"""
+        Update the reference to the deepspeed optimizer.
+        """
+        if self.finetuning_args.use_badam and \
+            self.args.deepspeed_plugin is not None and \
+            self.args.deepspeed_plugin.zero_stage == 3:
+
+            ds_optim = self.optimizer.optimizer
+            badam_optim = ds_optim.optimizer
+            badam_optim.ds_optimizer = ds_optim
+
+        return super().training_step(*args, **kwargs)
+
+
     def prediction_step(
         self,
         model: "torch.nn.Module",
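The new training_step relies on two layers of optimizer wrapping: the Trainer's self.optimizer wraps DeepSpeed's ZeRO-3 optimizer, which in turn wraps BAdam's BlockOptimizer, and the patch hands the inner optimizer a back-reference to the ZeRO-3 wrapper. A minimal, self-contained sketch of that unwrapping, using stand-in classes in place of the real accelerate/deepspeed/badam objects:

# Stand-in classes; the real objects come from accelerate, deepspeed, and badam.
class BlockOptimizerStub:            # innermost: BAdam's block-wise optimizer
    def __init__(self):
        self.ds_optimizer = None     # back-reference set by the patched training_step

class ZeroStage3OptimizerStub:       # middle: DeepSpeed's ZeRO-3 wrapper
    def __init__(self, inner):
        self.optimizer = inner

class AcceleratedOptimizerStub:      # outermost: what the Trainer stores in self.optimizer
    def __init__(self, inner):
        self.optimizer = inner

badam = BlockOptimizerStub()
trainer_optimizer = AcceleratedOptimizerStub(ZeroStage3OptimizerStub(badam))

# The same two-step unwrap performed in CustomSeq2SeqTrainer.training_step:
ds_optim = trainer_optimizer.optimizer   # ZeRO-3 wrapper
badam_optim = ds_optim.optimizer         # BAdam BlockOptimizer
badam_optim.ds_optimizer = ds_optim      # let BAdam reach the ZeRO-3 optimizer
assert badam.ds_optimizer is ds_optim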
@@ -309,6 +309,12 @@ def _create_badam_optimizer(
         dict(params=decay_params, weight_decay=training_args.weight_decay),
     ]
 
+    ds_zero3_enabled = False
+    if hasattr(training_args, "deepspeed_plugin") and training_args.deepspeed_plugin is not None:
+        assert training_args.deepspeed_plugin.zero_stage == 3, f"BAdam only supports deepspeed ZeRO-3 stage, got {training_args.deepspeed_plugin.zero_stage}"
+        assert finetuning_args.badam_mode == "layer", "BAdam only supports layer-wise update in ZeRO-3 stage"
+        ds_zero3_enabled = True
+
     if finetuning_args.badam_mode == "layer":
         from badam import BlockOptimizer
 
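Read as a standalone predicate, the guard above means: with no DeepSpeed plugin, ds_zero3_enabled stays False; a plugin with any stage other than 3, or a non-layer badam_mode, aborts. The hypothetical helper below only restates those checks (with explicit raises instead of assert) and is not part of the patch:

def badam_zero3_enabled(training_args, finetuning_args) -> bool:
    # Hypothetical restatement of the guard in _create_badam_optimizer.
    plugin = getattr(training_args, "deepspeed_plugin", None)
    if plugin is None:
        return False
    if plugin.zero_stage != 3:
        raise ValueError(f"BAdam only supports DeepSpeed ZeRO stage 3, got {plugin.zero_stage}")
    if finetuning_args.badam_mode != "layer":
        raise ValueError("BAdam only supports layer-wise update under ZeRO-3")
    return True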
@@ -321,6 +327,7 @@ def _create_badam_optimizer(
             start_block=finetuning_args.badam_start_block,
             switch_mode=finetuning_args.badam_switch_mode,
             verbose=finetuning_args.badam_verbose,
+            ds_zero3_enabled=ds_zero3_enabled
         )
         logger.info(
             f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "