fix bug in galore optimizer

Former-commit-id: 5c62881c5a59cfcc5a76d365263c8ad8c817ce49
hiyouga 2024-04-21 18:53:22 +08:00
parent ec81d45d27
commit 9e45f82be7
2 changed files with 7 additions and 13 deletions


@@ -11,8 +11,8 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
     --use_galore \
     --galore_layerwise \
     --galore_target mlp,self_attn \
-    --galore_scale 2.0 \
     --galore_rank 128 \
+    --galore_scale 2.0 \
     --output_dir ../../../saves/LLaMA2-7B/galore/sft \
     --overwrite_cache \
     --overwrite_output_dir \
@@ -29,8 +29,8 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
     --evaluation_strategy steps \
     --load_best_model_at_end \
     --learning_rate 5e-5 \
-    --num_train_epochs 30.0 \
-    --max_samples 300 \
+    --num_train_epochs 3.0 \
+    --max_samples 3000 \
     --val_size 0.1 \
     --plot_loss \
     --pure_bf16


@@ -234,14 +234,6 @@ def _create_galore_optimizer(
             param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **galore_kwargs)]
             optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)

-        def optimizer_hook(param: "torch.nn.Parameter"):
-            if param.grad is not None:
-                optimizer_dict[param].step()
-                optimizer_dict[param].zero_grad()
-
-        for param in trainable_params:
-            param.register_post_accumulate_grad_hook(optimizer_hook)
-
         optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
     else:
         param_groups = [
@@ -391,9 +383,11 @@ def create_custom_scheduler(
                 num_training_steps=num_training_steps * 2,
             )

-        def scheduler_hook(param: "torch.nn.Parameter"):
+        def optimizer_hook(param: "torch.nn.Parameter"):
             if param.grad is not None:
+                optimizer_dict[param].step()
+                optimizer_dict[param].zero_grad()
                 scheduler_dict[param].step()

         for param in optimizer_dict.keys():
-            param.register_post_accumulate_grad_hook(scheduler_hook)
+            param.register_post_accumulate_grad_hook(optimizer_hook)
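
Net effect of the two Python hunks: the standalone optimizer_hook in _create_galore_optimizer is removed, and the hook registered in create_custom_scheduler now performs the per-parameter optimizer step and zero_grad before stepping that parameter's scheduler, so a single post-accumulate-grad hook drives both updates. Below is a minimal, self-contained sketch of that layer-wise pattern, not the project's code: torch.optim.AdamW stands in for the GaLore optimizer class, and the toy model, learning rate, and constant LambdaLR schedule are placeholder assumptions (register_post_accumulate_grad_hook needs PyTorch >= 2.1).

# Sketch of the fixed layer-wise update pattern (assumptions: AdamW stands in
# for the GaLore optimizer; model and hyperparameters are placeholders).
import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(16, 16)
trainable_params = [p for p in model.parameters() if p.requires_grad]

# One optimizer and one scheduler per parameter, as in the layer-wise setup.
optimizer_dict = {p: torch.optim.AdamW([p], lr=5e-5) for p in trainable_params}
scheduler_dict = {p: LambdaLR(optimizer_dict[p], lambda step: 1.0) for p in trainable_params}

def optimizer_hook(param: "torch.nn.Parameter"):
    # Runs right after this parameter's gradient has been accumulated:
    # update the parameter, drop its gradient, then advance its LR schedule.
    if param.grad is not None:
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad()
        scheduler_dict[param].step()

for param in trainable_params:
    param.register_post_accumulate_grad_hook(optimizer_hook)

# The training loop only needs the backward pass; all updates happen inside it.
loss = model(torch.randn(4, 16)).sum()
loss.backward()

Because each parameter is updated and its gradient released while the backward pass is still running, the full set of gradients is never held in memory at once, which is the point of the layer-wise GaLore setup.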