Update patcher.py

Former-commit-id: 33556cc6b0b65cc6db02e66f4f6e75112c33d966
This commit is contained in:
hoshi-hiyouga 2024-01-22 23:27:39 +08:00 committed by GitHub
parent 882a6a1d51
commit b36c4b99cc

View File

@ -285,6 +285,7 @@ def patch_model(
_prepare_model_for_training(model, model_args)
if getattr(model.config, "model_type", None) == "mixtral" and is_deepspeed_zero3_enabled():
require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
from deepspeed.utils import set_z3_leaf_modules
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])