Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-16 20:00:36 +08:00)
support galore
@@ -3,10 +3,15 @@ from typing import TYPE_CHECKING, Optional, Union
import torch

from ..extras.logging import get_logger
from ..extras.packages import is_galore_available
from ..hparams import FinetuningArguments, ModelArguments
from ..model import load_model_and_tokenizer, load_valuehead_params


if is_galore_available():
    from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit


if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, Trainer
    from transformers.modeling_utils import PreTrainedModel
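The guarded import keeps galore_torch an optional dependency: the GaLore optimizer classes are only imported when the package is installed. As a rough sketch of what an availability probe such as is_galore_available typically does (the real helper lives in ..extras.packages; this version is an assumption for illustration):

import importlib.util

def is_galore_available() -> bool:
    # the package counts as available if Python can locate its import spec
    return importlib.util.find_spec("galore_torch") is not None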
@@ -118,3 +123,45 @@ def create_reward_model(
    logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
    logger.warning("Please ensure the PPO model and reward model share the SAME tokenizer and vocabulary.")
    return reward_model


def create_custom_optimizer(
    model: "PreTrainedModel", training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments"
) -> Optional["torch.optim.Optimizer"]:
    if not finetuning_args.use_galore:
        return None

    galore_params = []
    galore_targets = finetuning_args.galore_target.split(",")

    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
            galore_params += list(filter(lambda p: p.requires_grad, module.parameters()))
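    # e.g. with galore_target="q_proj,v_proj", any trainable nn.Linear whose
    # dotted module path from named_modules() contains "q_proj" or "v_proj"
    # is routed to the GaLore group; the match is a plain substring test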
    id_galore_params = [id(p) for p in galore_params]
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    non_galore_params = [p for p in trainable_params if id(p) not in id_galore_params]

    # define param groups as galore_params and non_galore_params
    param_groups = [
        {"params": non_galore_params},
        {
            "params": galore_params,
            "rank": finetuning_args.galore_rank,
            "update_proj_gap": finetuning_args.galore_update_interval,
            "scale": finetuning_args.galore_scale,
            "proj_type": finetuning_args.galore_proj_type,
        },
    ]
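    # only the second group carries GaLore-specific keys: galore_torch projects
    # those gradients down to rank `rank`, refreshes the projection matrix every
    # `update_proj_gap` steps, and scales the low-rank update by `scale`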
    if training_args.optim == "adamw_torch":
        optimizer = GaLoreAdamW(param_groups, lr=training_args.learning_rate)
    elif training_args.optim == "adamw_8bit":
        optimizer = GaLoreAdamW8bit(param_groups, lr=training_args.learning_rate)
    elif training_args.optim == "adafactor":
        optimizer = GaLoreAdafactor(param_groups, lr=training_args.learning_rate)
    else:
        raise NotImplementedError("Unknown optim: {}".format(training_args.optim))

    logger.info("Using GaLore optimizer. Training may appear to hang at the start, please wait patiently.")

    return optimizer
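For context, a minimal sketch of how this factory could be wired into a transformers trainer. The surrounding objects (model, train_dataset, parsed argument instances) are assumptions based on the signature above, not part of this commit:

# hypothetical call site; LLaMA-Factory's workflows build these objects
# from the parsed CLI arguments
optimizer = create_custom_optimizer(model, training_args, finetuning_args)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    optimizers=(optimizer, None),  # None -> Trainer builds its default LR scheduler
)

Note that when use_galore is off the factory returns None, and (None, None) is exactly the Trainer default, so the trainer falls back to the optimizer named by training_args.optim.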