diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index 656bfa6d..b02a4560 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -96,11 +96,6 @@ def load_model_and_tokenizer(
         **config_kwargs,
     )
 
-    if getattr(config, "model_type", None) == "mistral" and is_deepspeed_zero3_enabled():
-        from deepspeed.utils import set_z3_leaf_modules
-        from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
-        set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
-
     patch_model(model, tokenizer, model_args, is_trainable)
     register_autoclass(config, model, tokenizer)
 
diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 5f67f618..52690e68 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -284,6 +284,11 @@ def patch_model(
     if is_trainable:
         _prepare_model_for_training(model, model_args)
 
+    if getattr(config, "model_type", None) == "mixtral" and is_deepspeed_zero3_enabled():
+        from deepspeed.utils import set_z3_leaf_modules
+        from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+        set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
+
 
 def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
     def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None:
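
For context (not part of the patch): the ZeRO-3 leaf-module registration is moved from `load_model_and_tokenizer` into `patch_model`, and the `model_type` check is corrected from `"mistral"` to `"mixtral"`. Under DeepSpeed ZeRO-3, parameters are partitioned and gathered on demand per module; because Mixtral's router activates only a subset of experts per token, marking `MixtralSparseMoeBlock` as a ZeRO-3 leaf via `set_z3_leaf_modules` makes DeepSpeed gather the whole MoE block's parameters at once instead of hooking each expert individually. A minimal standalone sketch of the same call, assuming a DeepSpeed ZeRO-3 run; the checkpoint path is illustrative only:

```python
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

model_path = "mistralai/Mixtral-8x7B-v0.1"  # hypothetical example checkpoint
config = AutoConfig.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, config=config)

# Only Mixtral-style MoE models need the leaf-module hint, and only under ZeRO-3.
if getattr(config, "model_type", None) == "mixtral" and is_deepspeed_zero3_enabled():
    from deepspeed.utils import set_z3_leaf_modules

    # Treat each sparse MoE block as a single ZeRO-3 "leaf": its parameters are
    # gathered together rather than per expert, avoiding stalls when the router
    # activates only a few experts for a given batch.
    set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
```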