Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-28 19:24:20 +08:00)
[v1] Support fused moe kernel for qwen3vlmoe model. (#9532)
This commit is contained in: parent 2b6f16f261, commit 2c4fb3c97e
examples/ascend/qwen3vlmoe_lora_sft_fsdp.yaml (new file, 42 lines)
@@ -0,0 +1,42 @@
### model
model_name_or_path: Qwen/Qwen3-VL-30B-A3B-Instruct
image_max_pixels: 262144
video_max_pixels: 16384
trust_remote_code: true
use_kernels: true  # replaced kernels: [NpuRMSNormKernel, NpuRoPEKernel, NpuQwen3VLMoEFusedMoEKernel]

### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
disable_gradient_checkpointing: false
flash_attn: disabled

### dataset
dataset: alpaca_zh_demo, alpaca_en_demo
template: qwen3_vl
cutoff_len: 1024
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4

### output
output_dir: saves/qwen3vlmoe/lora/sft
logging_steps: 1
plot_loss: true
overwrite_output_dir: true
save_only_model: true
report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]

### train
per_device_train_batch_size: 8
gradient_accumulation_steps: 1
learning_rate: 1.0e-4
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
seed: 1234
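A config file like this is normally passed straight to the LLaMA-Factory CLI. A launch sketch (the FSDP run on Ascend may additionally require an accelerate or torchrun launcher, depending on setup):

llamafactory-cli train examples/ascend/qwen3vlmoe_lora_sft_fsdp.yaml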
@@ -11,3 +11,103 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import types

import torch
import torch_npu

from .....extras.types import HFModel
from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
from ..constants import DeviceType, KernelType
from ..registry import MetaMoEKernel

class GmmFunction(torch.autograd.Function):
    """Differentiable wrapper around the NPU grouped matmul (GMM) op."""

    @staticmethod
    def forward(ctx, x, weight, group_list):
        ctx.save_for_backward(x, weight)
        ctx.group_list = group_list

        # group_list_type=1: group_list holds per-expert token counts.
        fwd_output = torch_npu.npu_grouped_matmul(
            [x], [weight], bias=None, group_list=group_list, split_item=2, group_type=0, group_list_type=1
        )[0]
        return fwd_output

    @staticmethod
    def backward(ctx, grad_output):
        input_tensor, weight = ctx.saved_tensors
        group_list = ctx.group_list

        # dX = dY @ W^T, computed per expert group.
        weight = torch.transpose(weight, 1, 2)
        grad_input = torch_npu.npu_grouped_matmul(
            [grad_output], [weight], bias=None, group_list=group_list, split_item=2, group_type=0, group_list_type=1
        )[0]
        # dW = X^T @ dY, accumulated per expert group.
        grad_weight = torch_npu.npu_grouped_matmul(
            [input_tensor.T],
            [grad_output],
            bias=None,
            group_list=group_list,
            split_item=3,
            group_type=2,
            group_list_type=1,
        )[0]
        # No gradient flows to group_list.
        return grad_input, grad_weight, None


def npu_group_gemm(x, weight, group_list):
    output = GmmFunction.apply(x, weight, group_list)
    return output

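# Illustrative sketch, not part of this patch: because group_list here holds
# per-expert token counts and the rows of x arrive sorted by expert, the
# grouped matmul above is equivalent to one plain GEMM per expert over
# contiguous row slices:
#
#     def group_gemm_reference(x, weight, group_list):
#         # x: [total_tokens, in_dim], weight: [num_experts, in_dim, out_dim]
#         outputs, start = [], 0
#         for expert_id, count in enumerate(group_list.tolist()):
#             end = start + int(count)
#             outputs.append(x[start:end] @ weight[expert_id])  # [count, out_dim]
#             start = end
#         return torch.cat(outputs, dim=0)
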
def npu_experts_qwen3vlmoe_forward(
    self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
) -> torch.Tensor:
    batch_size = hidden_states.shape[0]
    hidden_states = hidden_states.reshape(-1, self.hidden_size)
    # Reorder token copies so rows routed to the same expert become contiguous.
    permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(
        hidden_states, router_indices.to(torch.int32)
    )
    tokens_per_expert = torch.histc(router_indices, bins=self.num_experts, min=0, max=self.num_experts)
    # Fused per-expert FFN: grouped gate/up projection, SwiGLU, grouped down projection.
    intermediate_hidden_states = npu_group_gemm(permuted_hidden_states, self.gate_up_proj, tokens_per_expert)
    intermediate_activations = torch_npu.npu_swiglu(intermediate_hidden_states, dim=-1)
    output = npu_group_gemm(intermediate_activations, self.down_proj, tokens_per_expert)
    # Restore original token order and combine expert outputs by routing weight.
    next_states = torch_npu.npu_moe_token_unpermute(output, row_ids_map, probs=routing_weights)
    next_states = next_states.view(batch_size, -1, self.hidden_size)
    return next_states

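# Illustrative sketch, not part of this patch: the permute step above makes
# top_k copies of each token and sorts them by expert id so the grouped
# matmul sees contiguous per-expert rows. A rough eager-mode equivalent:
#
#     flat_indices = router_indices.reshape(-1)            # [tokens * top_k]
#     sorted_order = torch.argsort(flat_indices, stable=True)
#     permuted = hidden_states.repeat_interleave(top_k, dim=0)[sorted_order]
#     # sorted_order plays the role of row_ids_map for the later unpermute.
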
def npu_moe_block_qwen3vlmoe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    batch_size = hidden_states.shape[0]
    hidden_states = hidden_states.reshape(-1, self.hidden_size)
    router_logits = self.gate(hidden_states)
    # Softmax over all experts, then keep the top_k and renormalize their weights.
    routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
    routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1)
    routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
    routing_weights = routing_weights.to(hidden_states.dtype)
    hidden_states = hidden_states.reshape(batch_size, -1, self.hidden_size)
    routed_out = self.experts(hidden_states, routing_weights, router_indices)
    return routed_out

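# Worked routing example (illustration only): with 4 experts, top_k = 2, and
# softmax probabilities [0.40, 0.30, 0.20, 0.10] for one token, topk keeps
# experts 0 and 1 with weights [0.40, 0.30]; dividing by their sum (0.70)
# renormalizes them to [0.571, 0.429], so the kept weights again sum to 1.
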
class NpuQwen3VLMoEFusedMoEKernel(MetaMoEKernel):
    type = KernelType.MOE
    device = DeviceType.NPU
    npu_experts_kernel = npu_experts_qwen3vlmoe_forward
    npu_moe_block_kernel = npu_moe_block_qwen3vlmoe_forward

    @classmethod
    def apply(cls, model, **kwargs) -> HFModel:
        if not is_torch_npu_available():
            return model

        npu_experts_pattern = re.compile("Qwen3VLMoeTextExperts", re.IGNORECASE)
        npu_moe_block_pattern = re.compile("Qwen3VLMoeTextSparseMoeBlock", re.IGNORECASE)

        # Monkey-patch the forward of matching modules with the fused NPU kernels.
        for _, module in model.named_modules():
            if re.search(npu_experts_pattern, module.__class__.__name__):
                module.forward = types.MethodType(cls.npu_experts_kernel, module)
            elif re.search(npu_moe_block_pattern, module.__class__.__name__):
                module.forward = types.MethodType(cls.npu_moe_block_kernel, module)
        return model
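A rough usage sketch (hypothetical standalone invocation; in the trainer this replacement is driven by use_kernels: true in the YAML above, and loading Qwen3-VL via AutoModelForImageTextToText is an assumption):

import torch
from transformers import AutoModelForImageTextToText

model = AutoModelForImageTextToText.from_pretrained(
    "Qwen/Qwen3-VL-30B-A3B-Instruct", torch_dtype=torch.bfloat16
)
# On an NPU host this swaps in the fused MoE forwards; elsewhere it is a no-op.
model = NpuQwen3VLMoEFusedMoEKernel.apply(model)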