support galore

hiyouga
2024-03-07 22:41:36 +08:00
parent 725f7cd70f
commit 28f7862188
12 changed files with 115 additions and 16 deletions

@@ -3,10 +3,15 @@ from typing import TYPE_CHECKING, Optional, Union

import torch

from ..extras.logging import get_logger
from ..extras.packages import is_galore_available
from ..hparams import FinetuningArguments, ModelArguments
from ..model import load_model_and_tokenizer, load_valuehead_params


if is_galore_available():
    from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit


if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, Trainer
    from transformers.modeling_utils import PreTrainedModel
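
For context, the guarded import above relies on a small package-availability helper. A minimal sketch of what such a check could look like, assuming it only needs to test whether galore_torch is importable (the actual helper in ..extras.packages may differ):

import importlib.util


def is_galore_available() -> bool:
    # Illustrative stand-in for the helper in ..extras.packages:
    # report whether the galore_torch package is installed.
    return importlib.util.find_spec("galore_torch") is not None
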
@@ -118,3 +123,45 @@ def create_reward_model(
        logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
        logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
        return reward_model


def create_custom_optimzer(
    model: "PreTrainedModel", training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments"
) -> Optional["torch.optim.Optimizer"]:
    if not finetuning_args.use_galore:
        return None

    # collect trainable parameters of the Linear layers matched by galore_target
    galore_params = []
    galore_targets = finetuning_args.galore_target.split(",")
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
            galore_params += list(filter(lambda p: p.requires_grad, module.parameters()))

    id_galore_params = [id(p) for p in galore_params]
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    non_galore_params = [p for p in trainable_params if id(p) not in id_galore_params]

    # define param groups as galore_params and non_galore_params
    param_groups = [
        {"params": non_galore_params},
        {
            "params": galore_params,
            "rank": finetuning_args.galore_rank,
            "update_proj_gap": finetuning_args.galore_update_interval,
            "scale": finetuning_args.galore_scale,
            "proj_type": finetuning_args.galore_proj_type,
        },
    ]

    if training_args.optim == "adamw_torch":
        optimizer = GaLoreAdamW(param_groups, lr=training_args.learning_rate)
    elif training_args.optim == "adamw_8bit":
        optimizer = GaLoreAdamW8bit(param_groups, lr=training_args.learning_rate)
    elif training_args.optim == "adafactor":
        optimizer = GaLoreAdafactor(param_groups, lr=training_args.learning_rate)
    else:
        raise NotImplementedError("Unknown optim: {}".format(training_args.optim))

    logger.info("Using GaLore optimizer; it may hang at the start of training, please wait patiently.")
    return optimizer
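
The new function only builds the optimizer; below is a hedged sketch of how its return value might be handed to a Hugging Face trainer. The build_trainer wrapper and its arguments are illustrative and not part of this commit; it relies on the documented transformers behavior that optimizers=(None, None) makes the Trainer create its own optimizer and scheduler, so a None return (GaLore disabled) falls back to the default.

from typing import Optional

import torch
from transformers import Seq2SeqTrainer


def build_trainer(model, training_args, finetuning_args, **trainer_kwargs) -> "Seq2SeqTrainer":
    # Hypothetical wiring: use the GaLore optimizer when enabled, otherwise let the
    # trainer construct its default optimizer; None for the scheduler keeps the
    # trainer's default LR schedule in both cases.
    optimizer: Optional["torch.optim.Optimizer"] = create_custom_optimzer(model, training_args, finetuning_args)
    return Seq2SeqTrainer(model=model, args=training_args, optimizers=(optimizer, None), **trainer_kwargs)
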