diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index fe707af7..c48df995 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -316,7 +316,7 @@ def patch_config(
     if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn:
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flashattn
 
-    if getattr(config, "model_type", None) == "qwen2_moe" and is_trainable:
+    if getattr(config, "model_type", None) in ["mixtral", "qwen2_moe"] and is_trainable:
         setattr(config, "output_router_logits", True)
 
     init_kwargs["torch_dtype"] = model_args.compute_dtype
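
For context, here is a minimal sketch (not part of the patch) of why this flag is set only when `is_trainable`: in Hugging Face Transformers, MoE models such as Mixtral return router logits when `output_router_logits=True`, and when labels are supplied they fold the router load-balancing auxiliary loss (scaled by `config.router_aux_loss_coef`) into the training loss. The tiny config sizes below are arbitrary, chosen only to keep the example small and runnable.

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Build a tiny Mixtral config; sizes here are illustrative, not from the patch.
config = AutoConfig.for_model(
    "mixtral",
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    num_local_experts=4,
    num_experts_per_tok=2,
    vocab_size=1000,
)
config.output_router_logits = True  # what the patched patch_config() now sets

model = AutoModelForCausalLM.from_config(config)
input_ids = torch.randint(0, 1000, (1, 8))
outputs = model(input_ids=input_ids, labels=input_ids)

# With the flag enabled, outputs.router_logits is populated and outputs.loss
# includes the auxiliary load-balancing term in addition to the LM loss.
print(outputs.loss, len(outputs.router_logits))
```

At inference time the router logits are not needed, which is why the patch gates the flag on `is_trainable` for both `mixtral` and `qwen2_moe` model types.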