Add patch_mixtral_replace_moe_impl for full training Mitral using DeepSpeed Zero3.

Signed-off-by: ldwang <ftgreat@gmail.com>

Former-commit-id: 5f50c02f0e425737cd80abdf8fde9e25abf13083
This commit is contained in:
ldwang 2024-01-24 15:25:31 +08:00
parent 36ac14a566
commit 786a2f1103

View File

@ -269,6 +269,7 @@ def patch_config(
def patch_mixtral_replace_moe_impl() -> None:
import torch.nn.functional as F
def mlp_forward(self, hidden_states):
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
current_hidden_states = self.w2(current_hidden_states)