refactor model_dtype, fix PPO trainer

This commit is contained in:
hiyouga
2023-10-11 23:16:01 +08:00
parent 5310e4d182
commit 2818af0b09
10 changed files with 104 additions and 119 deletions

View File

@@ -24,7 +24,7 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v
from transformers.deepspeed import is_deepspeed_zero3_enabled
from llmtuner.extras.logging import reset_logging, get_logger
from llmtuner.extras.misc import count_parameters
from llmtuner.extras.misc import count_parameters, infer_optim_dtype
from llmtuner.extras.patches import llama_patch as LlamaPatches
from llmtuner.extras.save_and_load import load_valuehead_params
from llmtuner.hparams import FinetuningArguments
@@ -86,11 +86,17 @@ def load_model_and_tokenizer(
if getattr(config, "model_type", None) == "chatglm":
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
# Set model dtype
if model_args.compute_dtype is not None:
setattr(config, "torch_dtype", model_args.compute_dtype)
else: # priority: bf16 > fp16 > fp32
optim_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
setattr(config, "torch_dtype", optim_dtype)
# Fix config (for Qwen)
if getattr(config, "model_type", None) == "qwen":
setattr(config, "fp16", model_args.compute_dtype == torch.float16)
setattr(config, "bf16", model_args.compute_dtype == torch.bfloat16)
setattr(config, "fp32", model_args.compute_dtype == torch.float32)
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)
# Set RoPE scaling
if model_args.rope_scaling is not None:
@@ -131,9 +137,7 @@ def load_model_and_tokenizer(
if model_args.flash_attn:
if getattr(config, "model_type", None) == "llama":
LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
LlamaModule.LlamaModel._prepare_decoder_attention_mask = (
LlamaPatches._prepare_decoder_attention_mask
)
LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
logger.info("Using FlashAttention-2 for faster training and inference.")
elif getattr(config, "model_type", None) == "qwen":
logger.info("Qwen models automatically enable FlashAttention if installed.")
@@ -180,7 +184,6 @@ def load_model_and_tokenizer(
model = AutoModelForCausalLM.from_pretrained(
model_to_load,
config=config,
torch_dtype=model_args.compute_dtype,
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
**config_kwargs
)
@@ -203,7 +206,7 @@ def load_model_and_tokenizer(
# Initialize adapters
if is_trainable:
model = prepare_model_for_training(model, model_args.layernorm_dtype, finetuning_args.finetuning_type)
model = prepare_model_for_training(model, model_args.upcast_layernorm, finetuning_args.finetuning_type)
model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
model = model.train() if is_trainable else model.eval()

View File

@@ -8,16 +8,6 @@ from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.utils.versions import require_version
from transformers.trainer_utils import get_last_checkpoint
try:
from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available, is_torch_cuda_available
is_fp16_available = is_torch_cuda_available()
is_bf16_available = is_torch_bf16_gpu_available()
is_npu_available = is_torch_npu_available()
except ImportError:
is_fp16_available = torch.cuda.is_available()
is_bf16_available = torch.cuda.is_bf16_supported()
is_npu_available = False
from llmtuner.extras.logging import get_logger
from llmtuner.hparams import (
ModelArguments,
@@ -31,17 +21,6 @@ from llmtuner.hparams import (
logger = get_logger(__name__)
def _infer_dtype() -> torch.dtype:
if is_npu_available:
return torch.float16
elif is_bf16_available:
return torch.bfloat16
elif is_fp16_available:
return torch.float16
else:
return torch.float32
def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
if args is not None:
return parser.parse_dict(args)
@@ -178,12 +157,15 @@ def get_train_args(
if not finetuning_args.resume_lora_training:
raise ValueError("Quantized model cannot create new LoRA weight. Merge them first.")
if model_args.quantization_bit is not None and (not training_args.do_train):
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
logger.warning("We recommend enable mixed precision training.")
if (not training_args.do_train) and model_args.quantization_bit is not None:
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
# postprocess data_args
if data_args.max_samples is not None and data_args.streaming:
logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
@@ -206,10 +188,9 @@ def get_train_args(
and os.path.isdir(training_args.output_dir)
and not training_args.overwrite_output_dir
):
require_version("transformers>=4.31.0", "Resuming training requires transformers>=4.31.0.")
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError("Output directory already exists and is not empty. Use `overwrite_output_dir`.")
raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
if last_checkpoint is not None:
training_args_dict = training_args.to_dict()
@@ -220,26 +201,7 @@ def get_train_args(
)
# postprocess model_args
if training_args.bf16:
if not is_bf16_available:
raise ValueError("Current device does not support bf16 training.")
model_args.compute_dtype = torch.bfloat16
elif training_args.fp16:
model_args.compute_dtype = torch.float16
else:
model_args.compute_dtype = _infer_dtype()
if model_args.layernorm_dtype == "bf16":
if not is_bf16_available:
raise ValueError("Current device does not support bf16 type.")
model_args.layernorm_dtype = torch.bfloat16
elif model_args.layernorm_dtype == "fp16":
model_args.layernorm_dtype = torch.float16
elif model_args.layernorm_dtype == "fp32":
model_args.layernorm_dtype = torch.float32
else:
model_args.layernorm_dtype = model_args.compute_dtype
model_args.compute_dtype = torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
model_args.model_max_length = data_args.cutoff_len
# Log on each process the small summary:
@@ -278,7 +240,4 @@ def get_infer_args(
if model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
# auto-detect cuda capability
model_args.compute_dtype = _infer_dtype()
return model_args, data_args, finetuning_args, generating_args

View File

@@ -31,11 +31,11 @@ def find_all_linear_modules(
def prepare_model_for_training(
model: "PreTrainedModel",
layernorm_dtype: torch.dtype,
upcast_layernorm: bool,
finetuning_type: str,
output_layer_name: Optional[str] = "lm_head",
use_gradient_checkpointing: Optional[bool] = True,
layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
layernorm_names: Optional[List[str]] = LAYERNORM_NAMES
) -> "PreTrainedModel":
r"""
Includes:
@@ -44,9 +44,10 @@ def prepare_model_for_training(
(3) upcast the lm_head to fp32
Inspired by: https://github.com/huggingface/peft/blob/v0.2.0/src/peft/utils/other.py#L33
"""
for name, param in model.named_parameters():
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
param.data = param.data.to(layernorm_dtype)
if upcast_layernorm:
for name, param in model.named_parameters():
if param.ndim == 1 and any(ln_name in name for ln_name in layernorm_names):
param.data = param.data.to(torch.float32)
if use_gradient_checkpointing:
if hasattr(model, "enable_input_require_grads"):