diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py index 2c7c14b3..0c5d4470 100644 --- a/src/llmtuner/model/patcher.py +++ b/src/llmtuner/model/patcher.py @@ -269,6 +269,7 @@ def patch_config( def patch_mixtral_replace_moe_impl() -> None: + import torch.nn.functional as F def mlp_forward(self, hidden_states): current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) current_hidden_states = self.w2(current_hidden_states)