diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py index db9849cf..379b0c48 100644 --- a/src/llmtuner/model/patcher.py +++ b/src/llmtuner/model/patcher.py @@ -336,7 +336,7 @@ def patch_model( if is_trainable and getattr(model.config, "model_type", None) == "qwen2" and model_args.flash_attn: setattr(model.config, "use_cache", False) # qwen2 does not support use_cache when using flashattn - if is_trainable and model_args.resize_vocab: + if model_args.resize_vocab: _resize_embedding_layer(model, tokenizer) if is_trainable: