mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 04:32:50 +08:00
fix kv cache
Former-commit-id: 96ce76cd2753bc91c781ad13aa8f7a972abe815a
This commit is contained in:
parent
bbf272f96e
commit
a74426df0f
@ -101,6 +101,10 @@ class ModelArguments:
|
||||
default="offload",
|
||||
metadata={"help": "Path to offload model weights."},
|
||||
)
|
||||
use_cache: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether or not to use KV cache in generation."},
|
||||
)
|
||||
hf_hub_token: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Auth token to log in with Hugging Face Hub."},
|
||||
|
@ -115,6 +115,9 @@ def _configure_attn_implementation(model_args: "ModelArguments", init_kwargs: Di
|
||||
|
||||
|
||||
def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
if model_args.rope_scaling is None:
|
||||
return
|
||||
|
||||
if not hasattr(config, "rope_scaling"):
|
||||
logger.warning("Current model does not support RoPE scaling.")
|
||||
return
|
||||
@ -141,7 +144,10 @@ def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is
|
||||
)
|
||||
|
||||
|
||||
def _configure_longlora(config: "PretrainedConfig") -> None:
|
||||
def _configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
if not is_trainable or not model_args.shift_attn:
|
||||
return
|
||||
|
||||
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
|
||||
setattr(config, "group_size_ratio", 0.25)
|
||||
apply_llama_patch()
|
||||
@ -242,7 +248,7 @@ def _prepare_model_for_training(
|
||||
# According to: https://github.com/huggingface/transformers/issues/28339
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
|
||||
model.enable_input_require_grads()
|
||||
model.config.use_cache = False # turn off when gradient checkpointing is enabled
|
||||
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
|
||||
logger.info("Gradient checkpointing enabled.")
|
||||
|
||||
if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
|
||||
@ -276,15 +282,14 @@ def patch_config(
|
||||
setattr(config, dtype_name, model_args.compute_dtype == dtype)
|
||||
|
||||
_configure_attn_implementation(model_args, init_kwargs)
|
||||
|
||||
if model_args.rope_scaling is not None:
|
||||
_configure_rope(config, model_args, is_trainable)
|
||||
|
||||
if is_trainable and model_args.shift_attn:
|
||||
_configure_longlora(config)
|
||||
|
||||
_configure_rope(config, model_args, is_trainable)
|
||||
_configure_longlora(config, model_args, is_trainable)
|
||||
_configure_quantization(config, tokenizer, model_args, init_kwargs)
|
||||
|
||||
if model_args.use_cache and not is_trainable:
|
||||
setattr(config, "use_cache", True)
|
||||
logger.info("Using KV cache for faster generation.")
|
||||
|
||||
init_kwargs["torch_dtype"] = model_args.compute_dtype
|
||||
if not is_deepspeed_zero3_enabled():
|
||||
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
|
||||
|
Loading…
x
Reference in New Issue
Block a user