Cleaner integration.
Former-commit-id: 5c2ff1b749a265dd3c979189ec491d8ac911a6f6
This commit is contained in:
parent bc1c082bc2
commit c779899f7b
@@ -215,11 +215,8 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     ):
         if finetuning_args.badam_mode == "ratio":
             raise ValueError("Ratio-wise BAdam does not yet support distributed training, use layer-wise BAdam: --badam_mode layer")
-        if (finetuning_args.badam_mode == "layer"
-            and training_args.deepspeed_plugin is not None
-            and training_args.deepspeed_plugin.zero_stage < 3
-        ):
-            raise ValueError(f"Layer-wise BAdam only supports DeepSpeed ZeRO 3 stage, got stage {training_args.deepspeed_plugin.zero_stage}")
+        if finetuning_args.badam_mode == "layer" and (not is_deepspeed_zero3_enabled()):
+            raise ValueError(f"Layer-wise BAdam only supports DeepSpeed ZeRO 3 stage.")
 
     if (finetuning_args.use_galore) and training_args.deepspeed is not None:
         raise ValueError("GaLore are incompatible with DeepSpeed yet.")

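For context, the replacement check asks transformers whether DeepSpeed ZeRO-3 is active instead of inspecting training_args.deepspeed_plugin by hand. A minimal standalone sketch of the new guard, assuming the same badam_mode values as above (the validate_badam_distributed wrapper and its signature are illustrative, not part of the commit):

# Standalone sketch of the distributed-training guard above; only the
# is_deepspeed_zero3_enabled() helper and the error messages come from the diff.
from transformers.integrations import is_deepspeed_zero3_enabled


def validate_badam_distributed(badam_mode: str) -> None:
    if badam_mode == "ratio":
        raise ValueError(
            "Ratio-wise BAdam does not yet support distributed training, "
            "use layer-wise BAdam: --badam_mode layer"
        )
    if badam_mode == "layer" and not is_deepspeed_zero3_enabled():
        # ZeRO-3 detection now goes through transformers instead of reading
        # training_args.deepspeed_plugin.zero_stage directly.
        raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO 3 stage.")
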
@@ -96,15 +96,9 @@ class CustomDPOTrainer(DPOTrainer):
             self.save_model(os.path.join(self.args.output_dir, "pissa_init"))
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:

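The same replacement is applied to the KTO, PPO, pre-training, reward-model, and seq2seq trainers in the hunks that follow. A hedged sketch of the consolidated hook-up in isolation (the BAdamFriendlyTrainer class and its constructor are invented for illustration; the badam imports and the two patched attributes are exactly what the diff adds):

# Sketch only: how the new-style BAdam hook-up attaches to a Hugging Face
# Trainer subclass. The class name and constructor signature are illustrative,
# not LLaMA-Factory's.
from types import MethodType

from transformers import Trainer


class BAdamFriendlyTrainer(Trainer):
    def __init__(self, finetuning_args, **kwargs):
        super().__init__(**kwargs)
        if finetuning_args.use_badam:
            from badam import clip_grad_norm_old_version, BAdamCallback

            # Swap in BAdam's gradient-clipping implementation and register
            # BAdamCallback; the per-trainer ZeRO-3 branch that registered
            # BAdamZeRO3Callback by hand is gone.
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.callback_handler.add_callback(BAdamCallback)

With the callback taking over the DeepSpeed-specific bookkeeping, each trainer no longer needs to inspect self.args.deepspeed_plugin at all.
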
@@ -91,15 +91,9 @@ class CustomKTOTrainer(KTOTrainer):
                 self.ref_model.eval()
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:

@@ -166,15 +166,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                 self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
         r"""

@@ -48,15 +48,9 @@ class CustomTrainer(Trainer):
             self.save_model(os.path.join(self.args.output_dir, "pissa_init"))
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:

@@ -72,15 +72,9 @@ class PairwiseTrainer(Trainer):
         self.processor = processor
         self.can_return_loss = True  # override property to return eval_loss
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:

@@ -56,14 +56,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             self.save_model(os.path.join(self.args.output_dir, "pissa_init"))
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:

@@ -371,11 +371,8 @@ def _create_badam_optimizer(
         dict(params=decay_params, weight_decay=training_args.weight_decay),
     ]
 
-    ds_zero3_enabled = False
-    if hasattr(training_args, "deepspeed_plugin") and training_args.deepspeed_plugin is not None:
-        assert training_args.deepspeed_plugin.zero_stage == 3, f"BAdam only supports deepspeed ZeRO-3 stage, got {training_args.deepspeed_plugin.zero_stage}"
-        assert finetuning_args.badam_mode == "layer", "BAdam only supports layer-wise update in ZeRO-3 stage"
-        ds_zero3_enabled = True
+    from transformers.integrations import is_deepspeed_zero3_enabled
+    ds_zero3_enabled = is_deepspeed_zero3_enabled()
 
     if finetuning_args.badam_mode == "layer":
         from badam import BlockOptimizer

@@ -400,6 +397,7 @@ def _create_badam_optimizer(
     elif finetuning_args.badam_mode == "ratio":
         from badam import BlockOptimizerRatio
 
+        assert not ds_zero3_enabled, "BAdam with ratio-based update does not support Deepspeed ZeRO-3 yet, use layer-wise update instead: --badam_mode layer."
         assert finetuning_args.badam_update_ratio > 1e-6
         optimizer = BlockOptimizerRatio(
             param_groups=param_groups,

@@ -411,7 +409,7 @@ def _create_badam_optimizer(
             **optim_kwargs,
         )
         logger.info(
-            f"Using BAdam optimizer with ratio-wise update, update ratio is {finetuning_args.badam_update_ratio}, "
+            f"Using BAdam optimizer with ratio-based update, update ratio is {finetuning_args.badam_update_ratio}, "
             f"mask mode is {finetuning_args.badam_mask_mode}"
         )
 

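On the optimizer side, the old hand-rolled deepspeed_plugin inspection collapses into a single flag that the ratio branch asserts against. A minimal sketch of the resulting guard, assuming the rest of _create_badam_optimizer is unchanged (the resolve_badam_zero3_flag helper and its signature are illustrative; is_deepspeed_zero3_enabled, the badam_mode values, and the assertion messages come from the diff):

# Sketch of how the ZeRO-3 flag is derived and checked after this commit.
# Building BlockOptimizer / BlockOptimizerRatio is unchanged and elided here.
from transformers.integrations import is_deepspeed_zero3_enabled


def resolve_badam_zero3_flag(finetuning_args) -> bool:
    # One transformers helper replaces the old hasattr/zero_stage checks.
    ds_zero3_enabled = is_deepspeed_zero3_enabled()

    if finetuning_args.badam_mode == "ratio":
        # Mirrors the assertion added in the hunk above.
        assert not ds_zero3_enabled, (
            "BAdam with ratio-based update does not support Deepspeed ZeRO-3 yet, "
            "use layer-wise update instead: --badam_mode layer."
        )
        assert finetuning_args.badam_update_ratio > 1e-6

    return ds_zero3_enabled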