Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-03 04:02:49 +08:00
fix bug in galore optimizer

Former-commit-id: 5c62881c5a59cfcc5a76d365263c8ad8c817ce49
parent: ec81d45d27
commit: 9e45f82be7

@@ -11,8 +11,8 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
     --use_galore \
     --galore_layerwise \
     --galore_target mlp,self_attn \
-    --galore_scale 2.0 \
     --galore_rank 128 \
+    --galore_scale 2.0 \
     --output_dir ../../../saves/LLaMA2-7B/galore/sft \
     --overwrite_cache \
     --overwrite_output_dir \

@@ -29,8 +29,8 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
     --evaluation_strategy steps \
     --load_best_model_at_end \
     --learning_rate 5e-5 \
-    --num_train_epochs 30.0 \
-    --max_samples 300 \
+    --num_train_epochs 3.0 \
+    --max_samples 3000 \
     --val_size 0.1 \
     --plot_loss \
     --pure_bf16
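
For orientation, here is a rough sketch of how the script's GaLore flags could map onto optimizer arguments. It assumes the galore_torch package's documented interface (GaLoreAdamW with per-group keys "rank", "scale", "update_proj_gap", and "proj_type"); the module filter, the update_proj_gap value, and the helper name build_galore_optimizer are illustrative assumptions, not LLaMA-Factory's actual wiring.

import torch
from galore_torch import GaLoreAdamW  # assumed: pip install galore-torch


def build_galore_optimizer(model: torch.nn.Module, lr: float = 5e-5) -> GaLoreAdamW:
    # --galore_target mlp,self_attn: pick linear weights in matching modules
    # (illustrative name-based filter, not the project's actual selection logic).
    galore_params = [
        module.weight
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear) and ("mlp" in name or "self_attn" in name)
    ]
    galore_ids = {id(p) for p in galore_params}
    regular_params = [p for p in model.parameters() if id(p) not in galore_ids]

    param_groups = [
        {"params": regular_params},
        {
            "params": galore_params,
            "rank": 128,             # --galore_rank 128
            "scale": 2.0,            # --galore_scale 2.0
            "update_proj_gap": 200,  # assumed value; the script does not set it
            "proj_type": "std",
        },
    ]
    return GaLoreAdamW(param_groups, lr=lr)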

@@ -234,14 +234,6 @@ def _create_galore_optimizer(
             param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **galore_kwargs)]
             optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
 
-        def optimizer_hook(param: "torch.nn.Parameter"):
-            if param.grad is not None:
-                optimizer_dict[param].step()
-                optimizer_dict[param].zero_grad()
-
-        for param in trainable_params:
-            param.register_post_accumulate_grad_hook(optimizer_hook)
-
         optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
     else:
         param_groups = [

@@ -391,9 +383,11 @@ def create_custom_scheduler(
                 num_training_steps=num_training_steps * 2,
             )
 
-        def scheduler_hook(param: "torch.nn.Parameter"):
+        def optimizer_hook(param: "torch.nn.Parameter"):
             if param.grad is not None:
+                optimizer_dict[param].step()
+                optimizer_dict[param].zero_grad()
                 scheduler_dict[param].step()
 
         for param in optimizer_dict.keys():
-            param.register_post_accumulate_grad_hook(scheduler_hook)
+            param.register_post_accumulate_grad_hook(optimizer_hook)
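
Taken together, the two hunks above show the shape of the fix: the layer-wise path previously registered one post-accumulate hook in _create_galore_optimizer (optimizer step plus zero_grad) and a second one in create_custom_scheduler (scheduler step only); after the change, only create_custom_scheduler registers a hook, a single optimizer_hook that steps the per-parameter optimizer, clears its gradient, and steps the per-parameter scheduler. Below is a minimal, self-contained sketch of that consolidated pattern; it is not LLaMA-Factory's code: plain AdamW and LinearLR stand in for the GaLore optimizer and the real scheduler, and the toy model and hyperparameters are assumptions.

import torch
from torch import nn

# Toy stand-ins: any model and any per-parameter optimizer/scheduler work here.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
trainable_params = [p for p in model.parameters() if p.requires_grad]

optimizer_dict = {p: torch.optim.AdamW([p], lr=5e-5) for p in trainable_params}
scheduler_dict = {
    p: torch.optim.lr_scheduler.LinearLR(optimizer_dict[p], total_iters=100)
    for p in trainable_params
}


def optimizer_hook(param: "torch.nn.Parameter"):
    # Runs right after param's gradient has been accumulated: update the
    # parameter, free its gradient, and advance its learning-rate schedule.
    if param.grad is not None:
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad()
        scheduler_dict[param].step()


for param in trainable_params:
    # Requires PyTorch >= 2.1 for register_post_accumulate_grad_hook.
    param.register_post_accumulate_grad_hook(optimizer_hook)

# The training loop only calls backward(); all updates happen inside the hook,
# so each parameter's gradient can be released as soon as it is consumed.
inputs, targets = torch.randn(4, 8), torch.randn(4, 1)
for _ in range(3):
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()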