diff --git a/src/llmtuner/model/parser.py b/src/llmtuner/model/parser.py
index aa32a0ae..f3626f69 100644
--- a/src/llmtuner/model/parser.py
+++ b/src/llmtuner/model/parser.py
@@ -192,7 +192,11 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
             training_args.resume_from_checkpoint
         ))

-    if finetuning_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
+    if (
+        finetuning_args.stage in ["rm", "ppo"]
+        and finetuning_args.finetuning_type == "lora"
+        and training_args.resume_from_checkpoint is not None
+    ):
         logger.warning("Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
             training_args.resume_from_checkpoint
         ))
diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index d3ec0bb1..381436d2 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -83,46 +83,47 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod

 def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
-    if model_args.rope_scaling is not None:
-        if not hasattr(config, "rope_scaling"):
-            logger.warning("Current model does not support RoPE scaling.")
-        else:
-            if is_trainable:
-                if model_args.rope_scaling == "dynamic":
-                    logger.warning(
-                        "Dynamic NTK scaling may not work well with fine-tuning. "
-                        "See: https://github.com/huggingface/transformers/pull/24653"
-                    )
-
-                current_max_length = getattr(config, "max_position_embeddings", None)
-                if current_max_length and model_args.model_max_length > current_max_length:
-                    scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
-                else:
-                    logger.warning("Input length is smaller than max length. Consider increase input length.")
-                    scaling_factor = 1.0
-            else:
-                scaling_factor = 2.0
-
-            setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
-            logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
-                model_args.rope_scaling, scaling_factor
-            ))
+    if not hasattr(config, "rope_scaling"):
+        logger.warning("Current model does not support RoPE scaling.")
+        return
+
+    if is_trainable:
+        if model_args.rope_scaling == "dynamic":
+            logger.warning(
+                "Dynamic NTK scaling may not work well with fine-tuning. "
+                "See: https://github.com/huggingface/transformers/pull/24653"
+            )
+
+        current_max_length = getattr(config, "max_position_embeddings", None)
+        if current_max_length and model_args.model_max_length > current_max_length:
+            scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
+        else:
+            logger.warning("Input length is smaller than max length. Consider increasing the input length.")
+            scaling_factor = 1.0
+    else:
+        scaling_factor = 2.0
+
+    setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
+    logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
+        model_args.rope_scaling, scaling_factor
+    ))
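For reference, here is the scaling-factor logic from the rewritten `_configure_rope`, extracted as a runnable sketch. The 2048/8192 values are illustrative assumptions, not taken from the diff:

```python
import math

# Assumed example values: a model pre-trained on 2048 positions,
# fine-tuned with 8192-token inputs.
max_position_embeddings = 2048
model_max_length = 8192
is_trainable = True

if is_trainable:
    # Round the required expansion up to a whole number, as the patch does.
    if model_max_length > max_position_embeddings:
        scaling_factor = float(math.ceil(model_max_length / max_position_embeddings))
    else:
        scaling_factor = 1.0  # inputs already fit; no real scaling needed
else:
    scaling_factor = 2.0  # inference-only loading uses a fixed factor

rope_scaling = {"type": "linear", "factor": scaling_factor}
print(rope_scaling)  # -> {'type': 'linear', 'factor': 4.0}
```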
-def _configure_flashattn(model_args: "ModelArguments", config_kwargs: Dict[str, Any]) -> None:
-    if model_args.flash_attn and is_flash_attn2_available():
-        config_kwargs["use_flash_attention_2"] = True
-        config_kwargs["torch_dtype"] = model_args.compute_dtype
-        logger.info("Using FlashAttention-2 for faster training and inference.")
+def _configure_flashattn(config_kwargs: Dict[str, Any]) -> None:
+    if not is_flash_attn2_available():
+        logger.warning("FlashAttention2 is not installed.")
+        return
+
+    config_kwargs["use_flash_attention_2"] = True
+    logger.info("Using FlashAttention-2 for faster training and inference.")


-def _configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
-    if is_trainable and model_args.shift_attn:
-        if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
-            setattr(config, "group_size_ratio", 0.25)
-            logger.info("Using shift short attention with group_size_ratio=1/4.")
-        else:
-            logger.warning("Current model does not support shift short attention.")
+def _configure_longlora(config: "PretrainedConfig") -> None:
+    if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
+        setattr(config, "group_size_ratio", 0.25)
+        logger.info("Using shift short attention with group_size_ratio=1/4.")
+    else:
+        logger.warning("Current model does not support shift short attention.")


 def _configure_quantization(
@@ -132,9 +133,9 @@ def _configure_quantization(
     config_kwargs: Dict[str, Any]
 ) -> None:
     r"""
-    Priority: Pre-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
+    Priority: GPTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
     """
-    if getattr(config, "quantization_config", None):  # gptq or awq
+    if getattr(config, "quantization_config", None):  # gptq
         if is_deepspeed_zero3_enabled():
             raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")

@@ -142,9 +143,9 @@ def _configure_quantization(
         quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
         if quantization_config.get("quant_method", None) == "gptq" and quantization_config.get("bits", -1) == 4:
             quantization_config["use_exllama"] = False  # disable exllama
-        logger.info("Loading {}-bit pre-quantized model.".format(quantization_config.get("bits", -1)))
+        logger.info("Loading {}-bit GPTQ-quantized model.".format(quantization_config.get("bits", -1)))

-    elif model_args.export_quantization_bit is not None:  # gptq
+    elif model_args.export_quantization_bit is not None:  # auto-gptq
         require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
         require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
         from accelerate.utils import get_max_memory
@@ -232,15 +233,20 @@ def patch_config(
 ) -> None:
     if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
-    setattr(config, "torch_dtype", model_args.compute_dtype)

     if getattr(config, "model_type", None) == "qwen":
         for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
-            setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)
+            setattr(config, dtype_name, model_args.compute_dtype == dtype)
+
+    if model_args.rope_scaling is not None:
+        _configure_rope(config, model_args, is_trainable)
+
+    if model_args.flash_attn:
+        _configure_flashattn(config_kwargs)
+
+    if is_trainable and model_args.shift_attn:
+        _configure_longlora(config)

-    _configure_rope(config, model_args, is_trainable)
-    _configure_flashattn(model_args, config_kwargs)
-    _configure_longlora(config, model_args, is_trainable)
     _configure_quantization(config, tokenizer, model_args, config_kwargs)
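These helper changes matter because `config_kwargs` is ultimately forwarded to `AutoModelForCausalLM.from_pretrained` elsewhere in the loader. A minimal sketch of that consumption, with a placeholder model id (the call site itself is outside this diff):

```python
from typing import Any, Dict

import torch
from transformers import AutoModelForCausalLM

config_kwargs: Dict[str, Any] = {}

# What _configure_flashattn(config_kwargs) leaves behind on success.
# Note the refactor no longer forces config_kwargs["torch_dtype"] here;
# dtype selection is handled separately via model_args.compute_dtype.
config_kwargs["use_flash_attention_2"] = True

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder model id
    torch_dtype=torch.bfloat16,  # stands in for model_args.compute_dtype
    **config_kwargs,  # accepted by the transformers versions targeted here (~4.34-4.36);
                      # newer releases use attn_implementation="flash_attention_2"
)
```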
diff --git a/src/llmtuner/train/tuner.py b/src/llmtuner/train/tuner.py
index 63813edd..049f04ea 100644
--- a/src/llmtuner/train/tuner.py
+++ b/src/llmtuner/train/tuner.py
@@ -1,3 +1,4 @@
+import torch
 from typing import TYPE_CHECKING, Any, Dict, List, Optional

 from llmtuner.extras.callbacks import LogCallback
@@ -46,7 +47,12 @@ def export_model(args: Optional[Dict[str, Any]] = None):
         logger.warning("Cannot merge adapters to a quantized model.")

     model.config.use_cache = True
-    model = model.to("cpu")
+    if getattr(model.config, "torch_dtype", None) == "bfloat16":
+        model = model.to(torch.bfloat16).to("cpu")
+    else:
+        model = model.to(torch.float16).to("cpu")
+        setattr(model.config, "torch_dtype", "float16")
+
     model.save_pretrained(
         save_directory=model_args.export_dir,
         max_shard_size="{}GB".format(model_args.export_size),
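A quick round-trip sketch of the new export behavior (paths are placeholders; note the diff compares `torch_dtype` against the string "bfloat16", while a freshly loaded config may hold a `torch.dtype` instead, so a robust adaptation checks both spellings):

```python
import torch
from transformers import AutoModelForCausalLM

# Placeholder path standing in for the merged model produced by export_model.
model = AutoModelForCausalLM.from_pretrained("path/to/merged_model")

dtype = getattr(model.config, "torch_dtype", None)
if dtype in ("bfloat16", torch.bfloat16):  # tolerate string or torch.dtype
    model = model.to(torch.bfloat16).to("cpu")  # keep bf16 checkpoints in bf16
else:
    model = model.to(torch.float16).to("cpu")   # everything else exports as fp16
    setattr(model.config, "torch_dtype", "float16")

model.save_pretrained("path/to/export_dir", max_shard_size="2GB")
```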