Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2026-05-05 07:38:55 +08:00)
[packing] fix GDN crash when meeting dummy image (#10453)
@@ -165,8 +165,11 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 "video_grid_thw": mm_inputs.get("video_grid_thw"),
                 "attention_mask": (features["attention_mask"] >= 1).float(),
             }
-            if features["attention_mask"].sum() == 0:
-                features["position_ids"] = torch.zeros((3, *features["input_ids"].shape))
+            if features["attention_mask"].sum() == 0:  # for pad tokens
+                seq_len = features["input_ids"].shape[-1]
+                features["position_ids"] = (
+                    torch.arange(seq_len).view(1, 1, seq_len).expand(3, *features["input_ids"].shape).contiguous()
+                )
                 features["rope_deltas"] = torch.zeros(features["input_ids"].shape[0])
                 return

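Note: this branch covers batches that consist entirely of pad tokens (a fully dummy sample). Instead of all-zero mrope position ids, it now emits monotonically increasing positions plus a zero rope_deltas. A minimal standalone sketch of the new behaviour, assuming only torch and illustrative shapes (bsz=2, seq_len=8); the variable names mirror the diff but are local stand-ins, not the collator's features dict:

    import torch

    bsz, seq_len = 2, 8
    input_ids = torch.zeros((bsz, seq_len), dtype=torch.long)       # stand-in for features["input_ids"]
    attention_mask = torch.zeros((bsz, seq_len), dtype=torch.long)  # all pad tokens -> sum() == 0

    if attention_mask.sum() == 0:  # for pad tokens
        # new behaviour: monotonically increasing mrope position ids instead of all zeros
        position_ids = (
            torch.arange(seq_len).view(1, 1, seq_len).expand(3, *input_ids.shape).contiguous()
        )
        rope_deltas = torch.zeros(input_ids.shape[0])

    print(position_ids.shape)  # torch.Size([3, 2, 8])
    print(position_ids[0, 0])  # tensor([0, 1, 2, 3, 4, 5, 6, 7])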
@@ -220,7 +223,13 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         unpadded_length = int(features["attention_mask"][0].bool().sum().item())
         right_padding_length = int((packing_params_list[0] or {}).get("right_padding_length") or 0)
         fake_input_padding_length = max(0, seq_len - unpadded_length - right_padding_length)
-        dummy_image_right_padding_mrope = torch.zeros((3, bsz, fake_input_padding_length))
+        # avoid continual cuseqlens breaking varlen attention @kuangdd
+        # https://github.com/hiyouga/LlamaFactory/issues/10452
+        dummy_image_right_padding_mrope = (
+            torch.arange(fake_input_padding_length)
+            .view(1, 1, fake_input_padding_length)
+            .expand(3, bsz, fake_input_padding_length)
+        )
         dummy_image_right_padding_attention_mask = torch.zeros((bsz, fake_input_padding_length))
         assert self.tokenizer.padding_side == "right", "padding_side should be right when fake image is injected"
         dummy_mm_inputs = copy.deepcopy(mm_inputs)
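Note: the same zeros-to-arange change is applied to the right-padding mrope that accompanies an injected dummy image. Per the linked issue, a run of identical (zero) position ids can be interpreted as repeated sequence starts when cu_seqlens are derived for varlen attention, so strictly increasing ids are used instead. A standalone sketch assuming torch and made-up lengths (seq_len=16, unpadded_length=10, right_padding_length=2); names mirror the diff but run outside the collator:

    import torch

    bsz = 2
    seq_len, unpadded_length, right_padding_length = 16, 10, 2
    fake_input_padding_length = max(0, seq_len - unpadded_length - right_padding_length)  # 4

    # old: torch.zeros((3, bsz, fake_input_padding_length)) gave every padded position id 0,
    # which (per the issue) breaks varlen attention; new: strictly increasing ids over the padding.
    dummy_image_right_padding_mrope = (
        torch.arange(fake_input_padding_length)
        .view(1, 1, fake_input_padding_length)
        .expand(3, bsz, fake_input_padding_length)
    )
    dummy_image_right_padding_attention_mask = torch.zeros((bsz, fake_input_padding_length))

    print(dummy_image_right_padding_mrope.shape)  # torch.Size([3, 2, 4])
    print(dummy_image_right_padding_mrope[0, 0])  # tensor([0, 1, 2, 3])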
@@ -280,6 +289,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 )
+                self._compute_rope_position_ids(subseq_features, mm_inputs_for_subseq)
                 sample_position_ids.append(subseq_features["position_ids"])

             all_position_ids.append(torch.cat(sample_position_ids, dim=-1))

             batch_dim_for_position_ids = 1 if all_position_ids[0].dim() == 3 else 0
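Note: the added line recomputes rope position ids per packed sub-sequence before its position_ids are appended and concatenated along the sequence dimension. A small sketch of that concatenation and of the batch-dim detection, using a hypothetical compute_position_ids helper in place of self._compute_rope_position_ids:

    import torch

    def compute_position_ids(length: int) -> torch.Tensor:
        # hypothetical stand-in: mrope-style ids of shape (3, 1, length),
        # restarting from 0 for every packed sub-sequence
        return torch.arange(length).view(1, 1, length).expand(3, 1, length)

    sample_position_ids = [compute_position_ids(n) for n in (5, 3, 4)]  # three sub-sequences of one sample
    all_position_ids = [torch.cat(sample_position_ids, dim=-1)]         # shape (3, 1, 12)

    # 3-D tensors (mrope) keep the batch on dim 1; 2-D rope tensors keep it on dim 0
    batch_dim_for_position_ids = 1 if all_position_ids[0].dim() == 3 else 0
    print(all_position_ids[0].shape, batch_dim_for_position_ids)  # torch.Size([3, 1, 12]) 1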
@@ -418,14 +428,14 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         boundaries_list = [p.get("sequence_boundaries") if p is not None else None for p in packing_params_list]
         has_packing = any(b is not None and len(b) > 2 for b in boundaries_list)
         if has_dummy_image and has_packing:
-            # FIXME: too tricky, need to be refactored
+            # FIXME: too tricky, need to be refactored @kuangdd
             features["has_dummy_image"] = True

         # When fake image/audio was injected, sequence_boundaries no longer match the tensor; use non-packing path.
         if not has_packing:
             self._compute_rope_position_ids(features, mm_inputs)
         else:
-            if is_omni:
+            if is_omni:  # TODO: support omni models for packed sequences @kuangdd
                 raise RuntimeError("Omni models are not supported for packed sequences for now.")

             self._compute_rope_position_ids_with_packing(
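Note: has_packing is inferred from sequence_boundaries (more than two boundary values means at least two packed sub-sequences); when a dummy image/audio was injected, the boundaries no longer match the tensor, so the non-packing rope path is taken. A plain-Python sketch of the detection, with illustrative packing_params_list entries:

    # illustrative entries; the real list is built by the collator
    packing_params_list = [
        {"sequence_boundaries": [0, 5, 12]},  # one packed sample made of two sub-sequences
        None,                                 # one sample without packing metadata
    ]

    boundaries_list = [p.get("sequence_boundaries") if p is not None else None for p in packing_params_list]
    has_packing = any(b is not None and len(b) > 2 for b in boundaries_list)
    print(has_packing)  # True -> packed-sequence rope path, unless a dummy image forces the non-packing path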