mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-17 04:10:36 +08:00)
@@ -7,9 +7,9 @@ from ...extras.constants import IGNORE_INDEX
 from ...extras.ploting import plot_loss
 from ...hparams import ModelArguments
 from ...model import load_model, load_tokenizer
-from ...train.dpo.collator import DPODataCollatorWithPadding
-from ...train.dpo.trainer import CustomDPOTrainer
-from ...train.utils import create_modelcard_and_push, create_ref_model
+from ..utils import create_custom_optimzer, create_modelcard_and_push, create_ref_model
+from .collator import DPODataCollatorWithPadding
+from .trainer import CustomDPOTrainer


 if TYPE_CHECKING:
@@ -44,6 +44,7 @@ def run_dpo(
     training_args.remove_unused_columns = False  # important for pairwise dataset

     # Initialize our Trainer
+    optimizer = create_custom_optimzer(model, training_args, finetuning_args)
     trainer = CustomDPOTrainer(
         beta=finetuning_args.dpo_beta,
         loss_type=finetuning_args.dpo_loss,
@@ -54,6 +55,7 @@ def run_dpo(
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
+        optimizers=(optimizer, None),
         **split_dataset(dataset, data_args, training_args),
     )
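The same two-line wiring repeats in each workflow below: build an optional custom optimizer, then hand it to the trainer through the optimizers tuple. A minimal sketch of that contract, assuming only that transformers.Trainer treats a None optimizer like its default (None, None) and builds its own AdamW; the helper name and learning rate here are illustrative, not the repository's code:

from typing import Optional

import torch


def maybe_custom_optimizer(model: torch.nn.Module, use_galore: bool) -> Optional[torch.optim.Optimizer]:
    # Stand-in for create_custom_optimzer(): return None unless a custom optimizer is requested,
    # so the caller can fall back to the trainer's default optimizer.
    if not use_galore:
        return None
    return torch.optim.AdamW(model.parameters(), lr=1e-4)


model = torch.nn.Linear(8, 8)
optimizer = maybe_custom_optimizer(model, use_galore=False)
print(optimizer)  # None here, so Trainer(..., optimizers=(optimizer, None)) would create its usual AdamW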
@@ -13,8 +13,8 @@ from ...extras.callbacks import FixValueHeadModelCallback
 from ...extras.misc import fix_valuehead_checkpoint
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
-from ...train.ppo.trainer import CustomPPOTrainer
-from ...train.utils import create_ref_model, create_reward_model
+from ..utils import create_custom_optimzer, create_ref_model, create_reward_model
+from .trainer import CustomPPOTrainer


 if TYPE_CHECKING:
@@ -64,7 +64,10 @@ def run_ppo(
     )

     # Create optimizer and scheduler
-    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
+    optimizer = create_custom_optimzer(model, training_args, finetuning_args)
+    if optimizer is None:
+        optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
+
     if training_args.max_steps > 0:
         num_training_steps = training_args.max_steps
     else:
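Unlike the workflows that pass optimizers=(optimizer, None) to a Trainer, the PPO path builds its optimizer eagerly, hence the explicit AdamW fallback over trainable parameters only. A small self-contained illustration of that filtering; the toy model and learning rate are assumptions, not the repository's code:

import torch
from torch.optim import AdamW

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 4))
model[0].weight.requires_grad_(False)  # pretend part of the model is frozen

optimizer = None  # what create_custom_optimzer returns when GaLore is not enabled
if optimizer is None:
    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-5)

num_optimized = sum(p.numel() for group in optimizer.param_groups for p in group["params"])
print(num_optimized)  # counts only parameters that still require gradients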
@@ -8,7 +8,7 @@ from transformers import DataCollatorForLanguageModeling, Trainer
 from ...data import get_dataset, split_dataset
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
-from ...train.utils import create_modelcard_and_push
+from ..utils import create_custom_optimzer, create_modelcard_and_push


 if TYPE_CHECKING:
@@ -30,12 +30,14 @@ def run_pt(
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

     # Initialize our Trainer
+    optimizer = create_custom_optimzer(model, training_args, finetuning_args)
     trainer = Trainer(
         model=model,
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
+        optimizers=(optimizer, None),
         **split_dataset(dataset, data_args, training_args),
     )
@@ -7,10 +7,10 @@ from ...extras.callbacks import FixValueHeadModelCallback
 from ...extras.misc import fix_valuehead_checkpoint
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
-from ...train.rm.collator import PairwiseDataCollatorWithPadding
-from ...train.rm.metric import compute_accuracy
-from ...train.rm.trainer import PairwiseTrainer
-from ...train.utils import create_modelcard_and_push
+from ..utils import create_custom_optimzer, create_modelcard_and_push
+from .collator import PairwiseDataCollatorWithPadding
+from .metric import compute_accuracy
+from .trainer import PairwiseTrainer


 if TYPE_CHECKING:
@@ -35,12 +35,14 @@ def run_rm(
     training_args.remove_unused_columns = False  # important for pairwise dataset

     # Initialize our Trainer
+    optimizer = create_custom_optimzer(model, training_args, finetuning_args)
     trainer = PairwiseTrainer(
         model=model,
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks + [FixValueHeadModelCallback()],
+        optimizers=(optimizer, None),
         compute_metrics=compute_accuracy,
         **split_dataset(dataset, data_args, training_args),
     )
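For context, compute_accuracy is the reward-model metric passed above; conceptually it is a pairwise accuracy, i.e. the fraction of pairs where the chosen response outscores the rejected one. A hypothetical sketch of that idea, not the repository's implementation:

import numpy as np


def pairwise_accuracy(chosen_scores: np.ndarray, rejected_scores: np.ndarray) -> float:
    # Fraction of preference pairs ranked correctly by the reward model.
    return float(np.mean(chosen_scores > rejected_scores))


print(pairwise_accuracy(np.array([1.2, 0.3, 2.0]), np.array([0.5, 0.9, 1.1])))  # 2 of 3 pairs correct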
@@ -12,6 +12,7 @@ from ...model import load_model, load_tokenizer
 from ...train.sft.metric import ComputeMetrics
 from ...train.sft.trainer import CustomSeq2SeqTrainer
 from ...train.utils import create_modelcard_and_push
+from ..utils import create_custom_optimzer


 if TYPE_CHECKING:
@@ -49,12 +50,14 @@ def run_sft(
     training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams

     # Initialize our Trainer
+    optimizer = create_custom_optimzer(model, training_args, finetuning_args)
     trainer = CustomSeq2SeqTrainer(
         model=model,
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
+        optimizers=(optimizer, None),
         compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
         **split_dataset(dataset, data_args, training_args),
     )
@@ -3,10 +3,15 @@ from typing import TYPE_CHECKING, Optional, Union
 import torch

 from ..extras.logging import get_logger
+from ..extras.packages import is_galore_available
 from ..hparams import FinetuningArguments, ModelArguments
 from ..model import load_model_and_tokenizer, load_valuehead_params


+if is_galore_available():
+    from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
+
+
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, Trainer
     from transformers.modeling_utils import PreTrainedModel
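The is_galore_available() guard keeps galore_torch an optional dependency: the GaLore optimizer classes are only imported when the package is installed. A rough stand-in for such a check (the real helper lives in ..extras.packages; this importlib version is only an assumption about how it might look):

import importlib.util


def is_galore_available() -> bool:
    # True only when the optional galore_torch package can be found.
    return importlib.util.find_spec("galore_torch") is not None


if is_galore_available():
    from galore_torch import GaLoreAdamW  # safe: only imported when the package exists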
@@ -118,3 +123,45 @@ def create_reward_model(
         logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
         logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
         return reward_model
+
+
+def create_custom_optimzer(
+    model: "PreTrainedModel", training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments"
+) -> Optional["torch.optim.Optimizer"]:
+    if not finetuning_args.use_galore:
+        return None
+
+    galore_params = []
+    galore_targets = finetuning_args.galore_target.split(",")
+
+    for name, module in model.named_modules():
+        if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
+            galore_params += list(filter(lambda p: p.requires_grad, module.parameters()))
+
+    id_galore_params = [id(p) for p in galore_params]
+    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
+    non_galore_params = [p for p in trainable_params if id(p) not in id_galore_params]
+
+    # define param groups as galore_params and non_galore_params
+    param_groups = [
+        {"params": non_galore_params},
+        {
+            "params": galore_params,
+            "rank": finetuning_args.galore_rank,
+            "update_proj_gap": finetuning_args.galore_update_interval,
+            "scale": finetuning_args.galore_scale,
+            "proj_type": finetuning_args.galore_proj_type,
+        },
+    ]
+    if training_args.optim == "adamw_torch":
+        optimizer = GaLoreAdamW(param_groups, lr=training_args.learning_rate)
+    elif training_args.optim == "adamw_8bit":
+        optimizer = GaLoreAdamW8bit(param_groups, lr=training_args.learning_rate)
+    elif training_args.optim == "adafactor":
+        optimizer = GaLoreAdafactor(param_groups, lr=training_args.learning_rate)
+    else:
+        raise NotImplementedError("Unknow optim: {}".format(training_args.optim))
+
+    logger.info("Used the GaLore optimizer, may cause hanging at the start of training, wait patiently.")
+
+    return optimizer
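A rough, self-contained illustration of what the GaLore branch above does: collect the parameters of Linear modules whose names match the configured targets, give them GaLore-specific hyperparameters, and put everything else in a plain group. The toy model, target names, and hyperparameter values are assumptions for illustration only; it falls back to a plain AdamW when galore_torch is not installed.

import torch

# Toy model with named Linear layers standing in for attention/MLP projections.
model = torch.nn.ModuleDict(
    {
        "q_proj": torch.nn.Linear(32, 32),
        "v_proj": torch.nn.Linear(32, 32),
        "lm_head": torch.nn.Linear(32, 8),
    }
)
galore_targets = ["q_proj", "v_proj"]  # plays the role of finetuning_args.galore_target

galore_params, other_params = [], []
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        bucket = galore_params if any(t in name for t in galore_targets) else other_params
        bucket.extend(p for p in module.parameters() if p.requires_grad)

param_groups = [
    {"params": other_params},
    {"params": galore_params, "rank": 16, "update_proj_gap": 200, "scale": 0.25, "proj_type": "std"},
]

try:
    from galore_torch import GaLoreAdamW  # optional dependency

    optimizer = GaLoreAdamW(param_groups, lr=1e-4)
except ImportError:
    # Without galore_torch, drop the GaLore-specific keys and use a plain AdamW instead.
    optimizer = torch.optim.AdamW([{"params": g["params"]} for g in param_groups], lr=1e-4)

print(type(optimizer).__name__, [len(g["params"]) for g in optimizer.param_groups])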