Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-12-18 12:50:38 +08:00
refactor model_dtype, fix PPO trainer
@@ -24,7 +24,7 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v
 from transformers.deepspeed import is_deepspeed_zero3_enabled

 from llmtuner.extras.logging import reset_logging, get_logger
-from llmtuner.extras.misc import count_parameters
+from llmtuner.extras.misc import count_parameters, infer_optim_dtype
 from llmtuner.extras.patches import llama_patch as LlamaPatches
 from llmtuner.extras.save_and_load import load_valuehead_params
 from llmtuner.hparams import FinetuningArguments
@@ -86,11 +86,17 @@ def load_model_and_tokenizer(
     if getattr(config, "model_type", None) == "chatglm":
         tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)

+    # Set model dtype
+    if model_args.compute_dtype is not None:
+        setattr(config, "torch_dtype", model_args.compute_dtype)
+    else: # priority: bf16 > fp16 > fp32
+        optim_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
+        setattr(config, "torch_dtype", optim_dtype)
+
     # Fix config (for Qwen)
     if getattr(config, "model_type", None) == "qwen":
-        setattr(config, "fp16", model_args.compute_dtype == torch.float16)
-        setattr(config, "bf16", model_args.compute_dtype == torch.bfloat16)
-        setattr(config, "fp32", model_args.compute_dtype == torch.float32)
+        for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
+            setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)

     # Set RoPE scaling
     if model_args.rope_scaling is not None:
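Note: the else-branch above delegates the choice to infer_optim_dtype from llmtuner.extras.misc, whose body is not part of this diff. Below is a minimal, hypothetical sketch of such a helper, assuming it follows the stated bf16 > fp16 > fp32 priority and uses only standard torch capability checks; the actual implementation may differ.

# Hypothetical sketch only -- not the actual llmtuner.extras.misc.infer_optim_dtype.
from typing import Optional

import torch


def infer_optim_dtype_sketch(model_dtype: Optional[torch.dtype]) -> torch.dtype:
    # Keep bf16 only if the checkpoint already uses it and the GPU supports it,
    # otherwise prefer fp16 on CUDA devices, and fall back to fp32 on CPU.
    if model_dtype == torch.bfloat16 and torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    if torch.cuda.is_available():
        return torch.float16
    return torch.float32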
@@ -131,9 +137,7 @@ def load_model_and_tokenizer(
     if model_args.flash_attn:
         if getattr(config, "model_type", None) == "llama":
             LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
-            LlamaModule.LlamaModel._prepare_decoder_attention_mask = (
-                LlamaPatches._prepare_decoder_attention_mask
-            )
+            LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
             logger.info("Using FlashAttention-2 for faster training and inference.")
         elif getattr(config, "model_type", None) == "qwen":
             logger.info("Qwen models automatically enable FlashAttention if installed.")
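The assignments above work by module-level monkey patching: from_pretrained instantiates whatever class is currently bound to LlamaModule.LlamaAttention, so swapping the attribute before loading switches every layer to the patched implementation. A self-contained toy illustration of that pattern follows; the classes and names are stand-ins, not the real transformers objects.

# Toy illustration of attribute-level monkey patching; all names are stand-ins.
from types import SimpleNamespace


class EagerAttention:
    def describe(self) -> str:
        return "eager attention"


class FlashAttention2(EagerAttention):
    def describe(self) -> str:
        return "flash attention 2"


modeling_llama = SimpleNamespace(LlamaAttention=EagerAttention)  # plays the role of the imported module


def build_model():
    # Anything constructed after the patch picks up the replacement class.
    return modeling_llama.LlamaAttention()


modeling_llama.LlamaAttention = FlashAttention2  # the one-line patch, applied before model construction
print(build_model().describe())  # -> "flash attention 2"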
@@ -180,7 +184,6 @@ def load_model_and_tokenizer(
     model = AutoModelForCausalLM.from_pretrained(
         model_to_load,
         config=config,
-        torch_dtype=model_args.compute_dtype,
         low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
         **config_kwargs
     )
@@ -203,7 +206,7 @@ def load_model_and_tokenizer(

     # Initialize adapters
     if is_trainable:
-        model = prepare_model_for_training(model, model_args.layernorm_dtype, finetuning_args.finetuning_type)
+        model = prepare_model_for_training(model, model_args.upcast_layernorm, finetuning_args.finetuning_type)
     model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
     model = model.train() if is_trainable else model.eval()
@@ -8,16 +8,6 @@ from transformers import HfArgumentParser, Seq2SeqTrainingArguments
 from transformers.utils.versions import require_version
 from transformers.trainer_utils import get_last_checkpoint

-try:
-    from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available, is_torch_cuda_available
-    is_fp16_available = is_torch_cuda_available()
-    is_bf16_available = is_torch_bf16_gpu_available()
-    is_npu_available = is_torch_npu_available()
-except ImportError:
-    is_fp16_available = torch.cuda.is_available()
-    is_bf16_available = torch.cuda.is_bf16_supported()
-    is_npu_available = False
-
 from llmtuner.extras.logging import get_logger
 from llmtuner.hparams import (
     ModelArguments,
@@ -31,17 +21,6 @@ from llmtuner.hparams import (
 logger = get_logger(__name__)


-def _infer_dtype() -> torch.dtype:
-    if is_npu_available:
-        return torch.float16
-    elif is_bf16_available:
-        return torch.bfloat16
-    elif is_fp16_available:
-        return torch.float16
-    else:
-        return torch.float32
-
-
 def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
     if args is not None:
         return parser.parse_dict(args)
@@ -178,12 +157,15 @@ def get_train_args(
         if not finetuning_args.resume_lora_training:
             raise ValueError("Quantized model cannot create new LoRA weight. Merge them first.")

-    if model_args.quantization_bit is not None and (not training_args.do_train):
-        logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
+    if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
+        logger.warning("We recommend enable `upcast_layernorm` in quantized training.")

     if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
         logger.warning("We recommend enable mixed precision training.")

+    if (not training_args.do_train) and model_args.quantization_bit is not None:
+        logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
+
     # postprocess data_args
     if data_args.max_samples is not None and data_args.streaming:
         logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
@@ -206,10 +188,9 @@ def get_train_args(
         and os.path.isdir(training_args.output_dir)
         and not training_args.overwrite_output_dir
     ):
-        require_version("transformers>=4.31.0", "Resuming training requires transformers>=4.31.0.")
         last_checkpoint = get_last_checkpoint(training_args.output_dir)
         if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
-            raise ValueError("Output directory already exists and is not empty. Use `overwrite_output_dir`.")
+            raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")

         if last_checkpoint is not None:
             training_args_dict = training_args.to_dict()
@@ -220,26 +201,7 @@ def get_train_args(
             )

     # postprocess model_args
-    if training_args.bf16:
-        if not is_bf16_available:
-            raise ValueError("Current device does not support bf16 training.")
-        model_args.compute_dtype = torch.bfloat16
-    elif training_args.fp16:
-        model_args.compute_dtype = torch.float16
-    else:
-        model_args.compute_dtype = _infer_dtype()
-
-    if model_args.layernorm_dtype == "bf16":
-        if not is_bf16_available:
-            raise ValueError("Current device does not support bf16 type.")
-        model_args.layernorm_dtype = torch.bfloat16
-    elif model_args.layernorm_dtype == "fp16":
-        model_args.layernorm_dtype = torch.float16
-    elif model_args.layernorm_dtype == "fp32":
-        model_args.layernorm_dtype = torch.float32
-    else:
-        model_args.layernorm_dtype = model_args.compute_dtype
-
+    model_args.compute_dtype = torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
     model_args.model_max_length = data_args.cutoff_len

     # Log on each process the small summary:
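With this hunk, compute_dtype is pinned only when `--bf16` or `--fp16` is passed; otherwise it stays None and the loader hunk above falls back to infer_optim_dtype on the checkpoint's torch_dtype. A small stand-alone sketch of that resolution order follows; the flag names are illustrative stand-ins for training_args.bf16 / training_args.fp16.

# Stand-alone sketch of the new compute_dtype resolution; flag names are stand-ins.
from typing import Optional

import torch


def resolve_compute_dtype(bf16: bool, fp16: bool) -> Optional[torch.dtype]:
    # Mirrors the one-liner added in this hunk: explicit flags win, otherwise None,
    # which later lets the loader infer the dtype from config.torch_dtype.
    return torch.bfloat16 if bf16 else (torch.float16 if fp16 else None)


print(resolve_compute_dtype(bf16=True, fp16=False))   # torch.bfloat16
print(resolve_compute_dtype(bf16=False, fp16=False))  # None -> decided later in the loader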
@@ -278,7 +240,4 @@ def get_infer_args(
     if model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
         raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")

-    # auto-detect cuda capability
-    model_args.compute_dtype = _infer_dtype()
-
     return model_args, data_args, finetuning_args, generating_args
@@ -31,11 +31,11 @@ def find_all_linear_modules(

 def prepare_model_for_training(
     model: "PreTrainedModel",
-    layernorm_dtype: torch.dtype,
+    upcast_layernorm: bool,
     finetuning_type: str,
     output_layer_name: Optional[str] = "lm_head",
     use_gradient_checkpointing: Optional[bool] = True,
-    layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
+    layernorm_names: Optional[List[str]] = LAYERNORM_NAMES
 ) -> "PreTrainedModel":
     r"""
     Includes:
@@ -44,9 +44,10 @@ def prepare_model_for_training(
         (3) upcast the lm_head to fp32
     Inspired by: https://github.com/huggingface/peft/blob/v0.2.0/src/peft/utils/other.py#L33
     """
-    for name, param in model.named_parameters():
-        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
-            param.data = param.data.to(layernorm_dtype)
+    if upcast_layernorm:
+        for name, param in model.named_parameters():
+            if param.ndim == 1 and any(ln_name in name for ln_name in layernorm_names):
+                param.data = param.data.to(torch.float32)

     if use_gradient_checkpointing:
         if hasattr(model, "enable_input_require_grads"):
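The dtype-specific layernorm cast is thus replaced by a single boolean: when upcast_layernorm is true, every 1-D parameter whose name matches an entry of LAYERNORM_NAMES is promoted to fp32. Below is a self-contained toy demonstration of that loop, assuming for illustration that LAYERNORM_NAMES contains substrings such as "norm".

# Toy demonstration of the upcast loop; "norm" stands in for the LAYERNORM_NAMES entries.
import torch
import torch.nn as nn


class Block(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.input_layernorm = nn.LayerNorm(8)


model = Block().to(torch.float16)
layernorm_names = ["norm"]
upcast_layernorm = True

if upcast_layernorm:
    for name, param in model.named_parameters():
        if param.ndim == 1 and any(ln_name in name for ln_name in layernorm_names):
            param.data = param.data.to(torch.float32)

print({name: param.dtype for name, param in model.named_parameters()})
# proj.weight stays fp16; proj.bias is 1-D but its name lacks "norm", so it also stays fp16;
# input_layernorm.weight and input_layernorm.bias are upcast to fp32.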