Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-17 04:10:36 +08:00)
[model] add qwen2.5 vl models (#6779)
@@ -135,12 +135,16 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         features: Dict[str, "torch.Tensor"] = super().__call__(features)
 
         if self.model is not None and hasattr(self.model, "get_rope_index"):  # for qwen2vl mrope
-            features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(
-                input_ids=features["input_ids"],
-                image_grid_thw=mm_inputs.get("image_grid_thw", None),
-                video_grid_thw=mm_inputs.get("video_grid_thw", None),
-                attention_mask=features["attention_mask"],
-            )
+            rope_index_kwargs = {
+                "input_ids": features["input_ids"],
+                "image_grid_thw": mm_inputs.get("image_grid_thw"),
+                "video_grid_thw": mm_inputs.get("video_grid_thw"),
+                "attention_mask": features["attention_mask"],
+            }
+            if "second_per_grid_ts" in mm_inputs:
+                rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
+
+            features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)
 
         if "cross_attention_mask" in mm_inputs:  # for mllama inputs when pad_to_multiple_of is enabled
             cross_attention_mask = mm_inputs.pop("cross_attention_mask")
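Why the keyword call was rebuilt as a kwargs dict: the conditional in the new code suggests that not every model's get_rope_index accepts second_per_grid_ts (Qwen2.5-VL's does, per this commit, while earlier Qwen2-VL models predate it), so the collator forwards the key only when the processor actually produced it in mm_inputs. Below is a minimal runnable sketch of that pattern; the fake_* functions, shapes, and sample tensors are stand-ins invented for illustration, not the actual transformers implementations.

import torch

def fake_qwen2vl_get_rope_index(input_ids, image_grid_thw=None, video_grid_thw=None, attention_mask=None):
    # Stand-in for a Qwen2-VL-style signature: no second_per_grid_ts parameter,
    # so passing it unconditionally would raise a TypeError.
    return torch.zeros(3, *input_ids.shape, dtype=torch.long), torch.zeros(input_ids.size(0), 1)

def fake_qwen25vl_get_rope_index(input_ids, image_grid_thw=None, video_grid_thw=None,
                                 second_per_grid_ts=None, attention_mask=None):
    # Stand-in for a Qwen2.5-VL-style signature: accepts the extra per-grid
    # timestamp argument used for video temporal positions.
    return torch.zeros(3, *input_ids.shape, dtype=torch.long), torch.zeros(input_ids.size(0), 1)

features = {"input_ids": torch.ones(2, 8, dtype=torch.long), "attention_mask": torch.ones(2, 8, dtype=torch.long)}
mm_inputs = {"second_per_grid_ts": [1.0]}  # assumed here to come from a qwen2.5vl-style video processor

rope_index_kwargs = {
    "input_ids": features["input_ids"],
    "image_grid_thw": mm_inputs.get("image_grid_thw"),
    "video_grid_thw": mm_inputs.get("video_grid_thw"),
    "attention_mask": features["attention_mask"],
}
if "second_per_grid_ts" in mm_inputs:  # forward the key only when the processor emitted it
    rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")

position_ids, rope_deltas = fake_qwen25vl_get_rope_index(**rope_index_kwargs)
print(position_ids.shape)  # torch.Size([3, 2, 8])

The dict-building style keeps a single code path for both model families: the same collator calls either signature, and the optional argument only appears when the inputs warrant it.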