Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-11-28 11:14:18 +08:00
[v1] Support fused moe kernel for qwen3vlmoe model. (#9532)
parent 2b6f16f261 · commit 2c4fb3c97e
examples/ascend/qwen3vlmoe_lora_sft_fsdp.yaml (new file, 42 lines)

@@ -0,0 +1,42 @@
### model
model_name_or_path: Qwen/Qwen3-VL-30B-A3B-Instruct
image_max_pixels: 262144
video_max_pixels: 16384
trust_remote_code: true
use_kernels: true  # replaced kernels: [NpuRMSNormKernel, NpuRoPEKernel, NpuQwen3VLMoEFusedMoEKernel]

### method
stage: sft
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
disable_gradient_checkpointing: false
flash_attn: disabled

### dataset
dataset: alpaca_zh_demo, alpaca_en_demo
template: qwen3_vl
cutoff_len: 1024
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 4

### output
output_dir: saves/qwen3vlmoe/lora/sft
logging_steps: 1
plot_loss: true
overwrite_output_dir: true
save_only_model: true
report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]

### train
per_device_train_batch_size: 8
gradient_accumulation_steps: 1
learning_rate: 1.0e-4
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
seed: 1234
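With `use_kernels: true`, the v1 kernel plugin replaces the matching modules with the NPU kernels listed in the comment above (RMSNorm, RoPE, and the fused MoE kernel added by this commit). A config like this is normally launched with `llamafactory-cli train examples/ascend/qwen3vlmoe_lora_sft_fsdp.yaml`; the exact FSDP launch settings on Ascend may differ.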
@@ -11,3 +11,103 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import types

import torch
import torch_npu

from .....extras.types import HFModel
from ....trainer_plugins.distributed.accelerate import is_torch_npu_available
from ..constants import DeviceType, KernelType
from ..registry import MetaMoEKernel

class GmmFunction(torch.autograd.Function):
    """Grouped matmul with a hand-written backward, so the fused NPU op stays differentiable."""

    @staticmethod
    def forward(ctx, x, weight, group_list):
        ctx.save_for_backward(x, weight)
        ctx.group_list = group_list

        fwd_output = torch_npu.npu_grouped_matmul(
            [x], [weight], bias=None, group_list=group_list, split_item=2, group_type=0, group_list_type=1
        )[0]
        return fwd_output

    @staticmethod
    def backward(ctx, grad_output):
        input_tensor, weight = ctx.saved_tensors
        group_list = ctx.group_list

        # dL/dx = dL/dy @ W^T, computed per expert group.
        weight = torch.transpose(weight, 1, 2)
        grad_input = torch_npu.npu_grouped_matmul(
            [grad_output], [weight], bias=None, group_list=group_list, split_item=2, group_type=0, group_list_type=1
        )[0]
        # dL/dW = x^T @ dL/dy, computed per expert group.
        grad_weight = torch_npu.npu_grouped_matmul(
            [input_tensor.T],
            [grad_output],
            bias=None,
            group_list=group_list,
            split_item=3,
            group_type=2,
            group_list_type=1,
        )[0]
        return grad_input, grad_weight, None


def npu_group_gemm(x, weight, group_list):
    output = GmmFunction.apply(x, weight, group_list)
    return output

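# For intuition, a pure-PyTorch reference of what npu_group_gemm computes
# (an illustrative sketch, assuming group_list holds per-expert token counts,
# as produced by the torch.histc call below, and weight is stacked as
# [num_experts, in_features, out_features]):
def group_gemm_reference(x: torch.Tensor, weight: torch.Tensor, group_list: torch.Tensor) -> torch.Tensor:
    outputs, start = [], 0
    for i, count in enumerate(group_list.long().tolist()):
        # Each expert multiplies its contiguous slice of tokens by its own weight matrix.
        outputs.append(x[start : start + count] @ weight[i])
        start += count
    return torch.cat(outputs, dim=0)
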
def npu_experts_qwen3vlmoe_forward(
    self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
) -> torch.Tensor:
    batch_size = hidden_states.shape[0]
    hidden_states = hidden_states.reshape(-1, self.hidden_size)
    # Sort tokens so that each expert's tokens are contiguous in memory.
    permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(
        hidden_states, router_indices.to(torch.int32)
    )
    tokens_per_expert = torch.histc(router_indices, bins=self.num_experts, min=0, max=self.num_experts)
    intermediate_hidden_states = npu_group_gemm(permuted_hidden_states, self.gate_up_proj, tokens_per_expert)
    intermediate_activations = torch_npu.npu_swiglu(intermediate_hidden_states, dim=-1)
    output = npu_group_gemm(intermediate_activations, self.down_proj, tokens_per_expert)
    # Restore the original token order and apply the routing weights.
    next_states = torch_npu.npu_moe_token_unpermute(output, row_ids_map, probs=routing_weights)
    next_states = next_states.view(batch_size, -1, self.hidden_size)
    return next_states

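# Shape sketch for the expert pipeline above (illustrative, with k = top_k,
# T = batch * seq tokens, H = hidden_size, I = expert intermediate size):
# [T, H] --permute--> [T*k, H] grouped by expert --gate_up--> [T*k, 2*I]
# --swiglu--> [T*k, I] --down--> [T*k, H] --unpermute (weighted sum over k)--> [T, H]
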
def npu_moe_block_qwen3vlmoe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    batch_size = hidden_states.shape[0]
    hidden_states = hidden_states.reshape(-1, self.hidden_size)
    router_logits = self.gate(hidden_states)
    # Route in float32, keep the top-k experts, and renormalize their weights.
    routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
    routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1)
    routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
    routing_weights = routing_weights.to(hidden_states.dtype)
    hidden_states = hidden_states.reshape(batch_size, -1, self.hidden_size)
    routed_out = self.experts(hidden_states, routing_weights, router_indices)
    return routed_out

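# Worked example of the top-k renormalization above (illustrative): with
# softmax probs [0.5, 0.3, 0.2] and top_k = 2, topk keeps [0.5, 0.3], and the
# division by their sum rescales them to [0.625, 0.375], so the weights of the
# selected experts sum to 1.
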
class NpuQwen3VLMoEFusedMoEKernel(MetaMoEKernel):
    type = KernelType.MOE
    device = DeviceType.NPU
    npu_experts_kernel = npu_experts_qwen3vlmoe_forward
    npu_moe_block_kernel = npu_moe_block_qwen3vlmoe_forward

    @classmethod
    def apply(cls, model, **kwargs) -> HFModel:
        if not is_torch_npu_available():
            return model

        npu_experts_pattern = re.compile("Qwen3VLMoeTextExperts", re.IGNORECASE)
        npu_moe_block_pattern = re.compile("Qwen3VLMoeTextSparseMoeBlock", re.IGNORECASE)

        # Rebind forward on every matching module so the NPU kernels take over.
        for _, module in model.named_modules():
            if re.search(npu_experts_pattern, module.__class__.__name__):
                module.forward = types.MethodType(cls.npu_experts_kernel, module)
            elif re.search(npu_moe_block_pattern, module.__class__.__name__):
                module.forward = types.MethodType(cls.npu_moe_block_kernel, module)
        return model
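# Hypothetical usage sketch (in practice the kernel is applied through the v1
# kernel registry when `use_kernels: true` is set in the training config; the
# direct call is shown only for illustration):
#
#     model = NpuQwen3VLMoEFusedMoEKernel.apply(model)  # patches matching modules in place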