Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-04 12:42:51 +08:00)
Parent: 32545bd6d9
Commit: 682d81caa9
@@ -1,4 +1,3 @@
-import os
 import math
 import torch
 from types import MethodType
@@ -202,6 +201,7 @@ def load_model_and_tokenizer(
     # Prepare model with valuehead for RLHF
     if stage in ["rm", "ppo"]:
         model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
+        setattr(model, "_keys_to_ignore_on_save", [name for name, _ in model.named_parameters() if "pretrained_model" in name])
         vhead_path = (
             model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
         )
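For context, the added `setattr` line collects every parameter name under the wrapped `pretrained_model` so those keys are skipped when the value-head checkpoint is written, leaving only the `v_head` weights to be saved. A minimal sketch with a hypothetical toy module (not the real `AutoModelForCausalLMWithValueHead`):

# Illustrative sketch only: a toy module stands in for the value-head model to show
# which parameter names the list comprehension collects.
import torch.nn as nn

class ToyValueHeadModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.pretrained_model = nn.Linear(4, 4)  # stands in for the wrapped base model
        self.v_head = nn.Linear(4, 1)            # the extra value head that should be saved

model = ToyValueHeadModel()
ignore_keys = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
setattr(model, "_keys_to_ignore_on_save", ignore_keys)
print(ignore_keys)  # ['pretrained_model.weight', 'pretrained_model.bias'] -- v_head.* is kept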
@@ -40,6 +40,18 @@ _EVAL_CLS = Tuple[
 ]
 
 
+def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None:
+    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+        raise ValueError("Quantization is only compatible with the LoRA method.")
+
+    if (
+        model_args.checkpoint_dir is not None
+        and len(model_args.checkpoint_dir) != 1
+        and finetuning_args.finetuning_type != "lora"
+    ):
+        raise ValueError("Multiple checkpoints are only available for LoRA tuning.")
+
+
 def parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     parser = HfArgumentParser(_TRAIN_ARGS)
     return parse_args(parser, args)
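As a usage sketch, the helper rejects two configurations: quantized training without LoRA, and multiple checkpoints without LoRA. A minimal, self-contained illustration with hypothetical stand-in dataclasses (the real `ModelArguments`/`FinetuningArguments` have many more fields):

# Illustrative sketch: stand-in dataclasses, not the project's real argument classes.
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ModelArguments:
    quantization_bit: Optional[int] = None
    checkpoint_dir: Optional[List[str]] = None

@dataclass
class FinetuningArguments:
    finetuning_type: str = "lora"

def _verify_model_args(model_args: ModelArguments, finetuning_args: FinetuningArguments) -> None:
    # Same checks as in the hunk above.
    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
        raise ValueError("Quantization is only compatible with the LoRA method.")
    if (
        model_args.checkpoint_dir is not None
        and len(model_args.checkpoint_dir) != 1
        and finetuning_args.finetuning_type != "lora"
    ):
        raise ValueError("Multiple checkpoints are only available for LoRA tuning.")

# Accepted: LoRA with 4-bit quantization.
_verify_model_args(ModelArguments(quantization_bit=4), FinetuningArguments("lora"))

# Rejected: quantization with full-parameter tuning.
try:
    _verify_model_args(ModelArguments(quantization_bit=4), FinetuningArguments("full"))
except ValueError as err:
    print(err)

# Rejected: two checkpoints with freeze tuning.
try:
    _verify_model_args(ModelArguments(checkpoint_dir=["ckpt-a", "ckpt-b"]), FinetuningArguments("freeze"))
except ValueError as err:
    print(err)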
@@ -81,19 +93,14 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
         raise ValueError("Please enable `predict_with_generate` to save model predictions.")
 
-    if finetuning_args.stage in ["rm", "ppo"]:
-        if training_args.resume_from_checkpoint is not None:
-            raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
-
-        if training_args.load_best_model_at_end:
-            raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
+    if finetuning_args.stage in ["rm", "ppo"] and training_args.load_best_model_at_end:
+        raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
 
     if finetuning_args.stage == "ppo" and not training_args.do_train:
         raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
 
-    if finetuning_args.stage in ["rm", "dpo"]:
-        for dataset_attr in data_args.dataset_list:
-            if not dataset_attr.ranking:
-                raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
+    if finetuning_args.stage in ["rm", "dpo"] and (not all([data_attr.ranking for data_attr in data_args.dataset_list])):
+        raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
 
     if finetuning_args.stage == "ppo" and model_args.shift_attn:
         raise ValueError("PPO training is incompatible with S^2-Attn.")
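The consolidated dataset check reads as: every dataset used for the `rm` or `dpo` stage must be a ranked dataset. A tiny hedged sketch with hypothetical dataset attributes:

# Illustrative sketch: hypothetical dataset attributes, only the `ranking` flag matters here.
from types import SimpleNamespace

dataset_list = [SimpleNamespace(ranking=True), SimpleNamespace(ranking=False)]

stage = "dpo"
if stage in ["rm", "dpo"] and (not all([data_attr.ranking for data_attr in dataset_list])):
    print("Please use ranked datasets for reward modeling or DPO training.")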
@@ -107,15 +114,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
         raise ValueError("Please specify `lora_target` in LoRA training.")
 
-    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
-        raise ValueError("Quantization is only compatible with the LoRA method.")
-
-    if (
-        model_args.checkpoint_dir is not None
-        and len(model_args.checkpoint_dir) != 1
-        and finetuning_args.finetuning_type != "lora"
-    ):
-        raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
+    _verify_model_args(model_args, finetuning_args)
 
     if training_args.do_train and model_args.quantization_bit is not None and (not finetuning_args.upcast_layernorm):
         logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
@@ -154,9 +153,14 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         training_args_dict = training_args.to_dict()
         training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
         training_args = Seq2SeqTrainingArguments(**training_args_dict)
-        logger.info(
-            "Resuming from checkpoint. Change `output_dir` or use `overwrite_output_dir` to avoid."
-        )
+        logger.info("Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format(
+            training_args.resume_from_checkpoint
+        ))
+
+    if finetuning_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
+        logger.warning("Add {} to `checkpoint_dir` to resume training from checkpoint.".format(
+            training_args.resume_from_checkpoint
+        ))
 
     # postprocess model_args
     model_args.compute_dtype = (
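The two log paths differ in intent: the info message reports the checkpoint actually being resumed, while for the `rm` and `ppo` stages the warning points the user at `checkpoint_dir` instead, since those stages do not consume `resume_from_checkpoint` directly. A hedged sketch of just the warning branch, with placeholder values and standard logging in place of the project's logger:

# Illustrative sketch: placeholder values; standard logging stands in for the project's logger.
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

stage = "ppo"                               # placeholder stage
resume_from_checkpoint = "saves/ckpt-1000"  # hypothetical placeholder path

if stage in ["rm", "ppo"] and resume_from_checkpoint is not None:
    logger.warning("Add {} to `checkpoint_dir` to resume training from checkpoint.".format(
        resume_from_checkpoint
    ))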
@@ -183,15 +187,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
     if data_args.template is None:
         raise ValueError("Please specify which `template` to use.")
 
-    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
-        raise ValueError("Quantization is only compatible with the LoRA method.")
-
-    if (
-        model_args.checkpoint_dir is not None
-        and len(model_args.checkpoint_dir) != 1
-        and finetuning_args.finetuning_type != "lora"
-    ):
-        raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
+    _verify_model_args(model_args, finetuning_args)
 
     return model_args, data_args, finetuning_args, generating_args
 
@@ -202,8 +198,7 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
     if data_args.template is None:
         raise ValueError("Please specify which `template` to use.")
 
-    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
-        raise ValueError("Quantization is only compatible with the LoRA method.")
+    _verify_model_args(model_args, finetuning_args)
 
     transformers.set_seed(eval_args.seed)
 
@@ -74,10 +74,13 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         else:
             self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
 
-    def ppo_train(self) -> None:
+    def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
         r"""
         Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
         """
+        if resume_from_checkpoint is not None:
+            raise ValueError("`resume_from_checkpoint` will be supported in the future version.")
+
         total_train_batch_size = (
             self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size
         )
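With the new keyword argument, `ppo_train` keeps a Trainer-like call shape but explicitly rejects resumption for now. A minimal sketch of the guard in isolation (a hypothetical free function, not the real method):

# Illustrative sketch: the guard as a hypothetical free function, mirroring the method above.
from typing import Optional

def ppo_train(resume_from_checkpoint: Optional[str] = None) -> None:
    if resume_from_checkpoint is not None:
        raise ValueError("`resume_from_checkpoint` will be supported in the future version.")
    # ... the PPO training loop would run here ...

ppo_train()  # fine: no checkpoint requested
try:
    ppo_train(resume_from_checkpoint="saves/ckpt-1000")  # hypothetical path
except ValueError as err:
    print(err)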
@@ -94,7 +94,7 @@ def run_ppo(
 
     # Training
     if training_args.do_train:
-        ppo_trainer.ppo_train()
+        ppo_trainer.ppo_train(resume_from_checkpoint=training_args.resume_from_checkpoint)
         ppo_trainer.save_model()
         ppo_trainer.save_state() # must be called after save_model to have a folder
         if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
@@ -47,7 +47,7 @@ def run_rm(
 
     # Training
    if training_args.do_train:
-        train_result = trainer.train()
+        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
         trainer.save_model()
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)