From b8a827faebb230fbedd04014e0548d4a8e830a81 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Wed, 24 Jan 2024 16:19:18 +0800
Subject: [PATCH] fix #2320

Former-commit-id: 2bc30763e9a40a82484c27b9a472425fdb9b3bd8
---
 src/llmtuner/extras/patches/mixtral_patch.py | 38 +++++++++++++++++++
 src/llmtuner/model/patcher.py                | 39 +-------------------
 src/llmtuner/train/tuner.py                  |  4 +-
 3 files changed, 43 insertions(+), 38 deletions(-)
 create mode 100644 src/llmtuner/extras/patches/mixtral_patch.py

diff --git a/src/llmtuner/extras/patches/mixtral_patch.py b/src/llmtuner/extras/patches/mixtral_patch.py
new file mode 100644
index 00000000..382492e0
--- /dev/null
+++ b/src/llmtuner/extras/patches/mixtral_patch.py
@@ -0,0 +1,38 @@
+import torch
+import torch.nn.functional as F
+from transformers.models.mixtral.modeling_mixtral import MixtralBLockSparseTop2MLP, MixtralSparseMoeBlock
+
+
+def mlp_forward(self: "MixtralBLockSparseTop2MLP", hidden_states: torch.Tensor) -> torch.Tensor:
+    current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
+    current_hidden_states = self.w2(current_hidden_states)
+    return current_hidden_states
+
+
+# Modified from: https://huggingface.co/deepseek-ai/deepseek-moe-16b-base/blob/main/modeling_deepseek.py
+def moe_forward(self: "MixtralSparseMoeBlock", hidden_states: torch.Tensor) -> torch.Tensor:
+    batch_size, sequence_length, hidden_dim = hidden_states.shape
+    hidden_states = hidden_states.view(-1, hidden_dim)
+    # router_logits: (batch * sequence_length, n_experts)
+    router_logits = self.gate(hidden_states)
+
+    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+    topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
+    topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
+    # we cast back to the input dtype
+    topk_weight = topk_weight.to(hidden_states.dtype)
+
+    hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
+    y = torch.empty_like(hidden_states)
+    flat_topk_idx = topk_idx.view(-1)
+    for i in range(self.num_experts):
+        expert = self.experts[i]
+        y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
+    y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
+    final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
+    return final_hidden_states, router_logits
+
+
+def patch_mixtral_replace_moe_impl() -> None:
+    MixtralBLockSparseTop2MLP.forward = mlp_forward
+    MixtralSparseMoeBlock.forward = moe_forward
diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 0c5d4470..477d267e 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -16,6 +16,7 @@ from ..extras.logging import get_logger
 from ..extras.misc import get_current_device, infer_optim_dtype
 from ..extras.packages import is_flash_attn2_available
 from ..extras.patches.llama_patch import apply_llama_patch
+from ..extras.patches.mixtral_patch import patch_mixtral_replace_moe_impl
 
 
 if TYPE_CHECKING:
@@ -268,43 +269,6 @@ def patch_config(
     _configure_quantization(config, tokenizer, model_args, config_kwargs)
 
 
-def patch_mixtral_replace_moe_impl() -> None:
-    import torch.nn.functional as F
-    def mlp_forward(self, hidden_states):
-        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
-        current_hidden_states = self.w2(current_hidden_states)
-        return current_hidden_states
-
-    ## Ref. https://huggingface.co/deepseek-ai/deepseek-moe-16b-base/blob/main/modeling_deepseek.py
-    def moe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        batch_size, sequence_length, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
-        # router_logits: (batch * sequence_length, n_experts)
-        router_logits = self.gate(hidden_states)
-
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
-        topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
-        # we cast back to the input dtype
-        topk_weight = topk_weight.to(hidden_states.dtype)
-
-        hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
-        y = torch.empty_like(hidden_states)
-        flat_topk_idx = topk_idx.view(-1)
-        for i in range(self.num_experts):
-            expert = self.experts[i]
-            y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
-        y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
-        final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
-        return final_hidden_states, router_logits
-
-    from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
-    from transformers.models.mixtral.modeling_mixtral import MixtralBLockSparseTop2MLP
-
-    MixtralBLockSparseTop2MLP.forward = mlp_forward
-    MixtralSparseMoeBlock.forward = moe_forward
-
-
 def patch_model(
     model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool
 ) -> None:
@@ -325,6 +289,7 @@ def patch_model(
         require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
         from deepspeed.utils import set_z3_leaf_modules  # type: ignore
         from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+
         set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
 
     if is_trainable:
diff --git a/src/llmtuner/train/tuner.py b/src/llmtuner/train/tuner.py
index c24e5eac..decacbce 100644
--- a/src/llmtuner/train/tuner.py
+++ b/src/llmtuner/train/tuner.py
@@ -56,7 +56,9 @@ def export_model(args: Optional[Dict[str, Any]] = None):
     if not isinstance(model, PreTrainedModel):
         raise ValueError("The model is not a `PreTrainedModel`, export aborted.")
 
-    if hasattr(model.config, "torch_dtype"):
+    if getattr(model, "quantization_method", None):
+        model = model.to("cpu")
+    elif hasattr(model.config, "torch_dtype"):
         model = model.to(getattr(model.config, "torch_dtype")).to("cpu")
     else:
         model = model.to(torch.float16).to("cpu")
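
Usage sketch (illustrative only, not part of the patch): the snippet below shows how the relocated patch_mixtral_replace_moe_impl and the set_z3_leaf_modules call seen in patch_model fit together when preparing a Mixtral model for DeepSpeed ZeRO-3 training. The model ID and the bare from_pretrained call are assumptions made for this example; only the two patch calls come from the diff above.

    from transformers import AutoModelForCausalLM
    from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

    from llmtuner.extras.patches.mixtral_patch import patch_mixtral_replace_moe_impl

    # Reassign the forward methods of Mixtral's MLP and sparse-MoE classes to the
    # DeepSeek-style implementations defined in the new module; as a class-level
    # monkey patch, this affects every Mixtral instance.
    patch_mixtral_replace_moe_impl()

    # Hypothetical checkpoint; any Mixtral model is loaded the same way.
    model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

    # Under DeepSpeed ZeRO-3 (deepspeed>=0.13.0), also register the sparse-MoE block
    # as a ZeRO-3 leaf module, mirroring the set_z3_leaf_modules call in patch_model.
    from deepspeed.utils import set_z3_leaf_modules  # type: ignore

    set_z3_leaf_modules(model, [MixtralSparseMoeBlock])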