[packing] fix GDN crash when encountering a dummy image (#10453)

This commit is contained in:
Kingsley
2026-05-01 12:10:13 +08:00
committed by GitHub
parent 887ee2b121
commit 468723c5d9

View File

@@ -165,8 +165,11 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"video_grid_thw": mm_inputs.get("video_grid_thw"), "video_grid_thw": mm_inputs.get("video_grid_thw"),
"attention_mask": (features["attention_mask"] >= 1).float(), "attention_mask": (features["attention_mask"] >= 1).float(),
} }
if features["attention_mask"].sum() == 0: if features["attention_mask"].sum() == 0: # for pad tokens
features["position_ids"] = torch.zeros((3, *features["input_ids"].shape)) seq_len = features["input_ids"].shape[-1]
features["position_ids"] = (
torch.arange(seq_len).view(1, 1, seq_len).expand(3, *features["input_ids"].shape).contiguous()
)
features["rope_deltas"] = torch.zeros(features["input_ids"].shape[0]) features["rope_deltas"] = torch.zeros(features["input_ids"].shape[0])
return return
@@ -220,7 +223,13 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
unpadded_length = int(features["attention_mask"][0].bool().sum().item()) unpadded_length = int(features["attention_mask"][0].bool().sum().item())
right_padding_length = int((packing_params_list[0] or {}).get("right_padding_length") or 0) right_padding_length = int((packing_params_list[0] or {}).get("right_padding_length") or 0)
fake_input_padding_length = max(0, seq_len - unpadded_length - right_padding_length) fake_input_padding_length = max(0, seq_len - unpadded_length - right_padding_length)
dummy_image_right_padding_mrope = torch.zeros((3, bsz, fake_input_padding_length)) # avoid continual cuseqlens breaking varlen attention @kuangdd
# https://github.com/hiyouga/LlamaFactory/issues/10452
dummy_image_right_padding_mrope = (
torch.arange(fake_input_padding_length)
.view(1, 1, fake_input_padding_length)
.expand(3, bsz, fake_input_padding_length)
)
dummy_image_right_padding_attention_mask = torch.zeros((bsz, fake_input_padding_length)) dummy_image_right_padding_attention_mask = torch.zeros((bsz, fake_input_padding_length))
assert self.tokenizer.padding_side == "right", "padding_side should be right when fake image is injected" assert self.tokenizer.padding_side == "right", "padding_side should be right when fake image is injected"
dummy_mm_inputs = copy.deepcopy(mm_inputs) dummy_mm_inputs = copy.deepcopy(mm_inputs)
@@ -280,6 +289,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
) )
self._compute_rope_position_ids(subseq_features, mm_inputs_for_subseq) self._compute_rope_position_ids(subseq_features, mm_inputs_for_subseq)
sample_position_ids.append(subseq_features["position_ids"]) sample_position_ids.append(subseq_features["position_ids"])
all_position_ids.append(torch.cat(sample_position_ids, dim=-1)) all_position_ids.append(torch.cat(sample_position_ids, dim=-1))
batch_dim_for_position_ids = 1 if all_position_ids[0].dim() == 3 else 0 batch_dim_for_position_ids = 1 if all_position_ids[0].dim() == 3 else 0
@@ -418,14 +428,14 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
boundaries_list = [p.get("sequence_boundaries") if p is not None else None for p in packing_params_list] boundaries_list = [p.get("sequence_boundaries") if p is not None else None for p in packing_params_list]
has_packing = any(b is not None and len(b) > 2 for b in boundaries_list) has_packing = any(b is not None and len(b) > 2 for b in boundaries_list)
if has_dummy_image and has_packing: if has_dummy_image and has_packing:
# FIXME: too tricky, need to be refactored # FIXME: too tricky, need to be refactored @kuangdd
features["has_dummy_image"] = True features["has_dummy_image"] = True
# When fake image/audio was injected, sequence_boundaries no longer match the tensor; use non-packing path. # When fake image/audio was injected, sequence_boundaries no longer match the tensor; use non-packing path.
if not has_packing: if not has_packing:
self._compute_rope_position_ids(features, mm_inputs) self._compute_rope_position_ids(features, mm_inputs)
else: else:
if is_omni: if is_omni: # TODO: support omni models for packed sequences @kuangdd
raise RuntimeError("Omni models are not supported for packed sequences for now.") raise RuntimeError("Omni models are not supported for packed sequences for now.")
self._compute_rope_position_ids_with_packing( self._compute_rope_position_ids_with_packing(