diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index f4f97feb..8539a879 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -22,6 +22,13 @@ import torch
 import torch.nn.functional as F
 from transformers import DataCollatorForSeq2Seq
 
+from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
+from ..extras.packages import is_pillow_available
+
+
+if is_pillow_available():
+    from PIL import Image
+
 
 if TYPE_CHECKING:
     from transformers import ProcessorMixin
@@ -73,7 +80,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     r"""
     Data collator that supports VLMs.
 
-    Features should contain input_ids, attention_mask, labels and images.
+    Features should contain input_ids, attention_mask, labels, and optionally contain images and videos.
     """
 
     template: Optional["Template"] = None
@@ -90,6 +97,17 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             batch_vidlens.append(len(videos))
             batch_input_ids.append(feature["input_ids"])
 
+        if self.processor is not None and sum(batch_imglens) == 0:  # avoid process hanging in zero3/fsdp case
+            fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
+            fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
+            fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
+            fake_input_ids = self.processor.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
+            features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
+            features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
+            features[0]["labels"] = features[0]["labels"] + [IGNORE_INDEX] * len(fake_input_ids)
+            batch_images = fake_images
+            batch_input_ids[0] = features[0]["input_ids"]
+
         mm_inputs = self.template.mm_plugin.get_mm_inputs(
             batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
         )
@@ -99,7 +117,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 feature["token_type_ids"] = token_type_ids[i]
 
         features: Dict[str, "torch.Tensor"] = super().__call__(features)
-        if "cross_attention_mask" in mm_inputs:  # for mllama inputs
+        if "cross_attention_mask" in mm_inputs:  # for mllama inputs when pad_to_multiple_of is enabled
             cross_attention_mask = mm_inputs.pop("cross_attention_mask")
             seq_len = features["input_ids"].size(1)
             orig_len = cross_attention_mask.size(1)
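Context for the collator change above: under DeepSpeed ZeRO-3 or FSDP every rank must execute the vision tower, so a rank whose batch happens to contain no images would skip those parameters and stall the collective ops (the "process hanging" the new comment refers to). The added branch therefore appends a single white 64x64 placeholder image, plus the matching placeholder tokens, to the first feature, with attention_mask 0 and labels IGNORE_INDEX so it affects neither attention nor the loss. A minimal usage sketch of the resulting behavior, reusing the helpers the new tests import; the model and template names follow the test file below and are not part of the patch:

# Illustrative sketch, not part of the patch: what the dummy-image branch buys us.
from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer

model_args, data_args, *_ = get_infer_args(
    {"model_name_or_path": "Qwen/Qwen2-VL-7B-Instruct", "template": "qwen2_vl"}
)
tokenizer_module = load_tokenizer(model_args)
template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
collator = MultiModalDataCollatorForSeq2Seq(
    template=template, label_pad_token_id=IGNORE_INDEX, **tokenizer_module
)

# A text-only batch: no feature carries images or videos.
batch = collator([{"input_ids": [0, 1, 2, 3], "attention_mask": [1, 1, 1, 1], "labels": [0, 1, 2, 3]}])

# The collator appended the white 64x64 placeholder image to the first feature, so the batch
# still contains vision inputs (e.g. "pixel_values" for Qwen2-VL) and every ZeRO-3/FSDP rank
# runs the vision tower, while the placeholder tokens are masked out of attention and loss.
print("pixel_values" in batch)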
diff --git a/tests/data/test_collator.py b/tests/data/test_collator.py
index 58035ac2..dcb53d6b 100644
--- a/tests/data/test_collator.py
+++ b/tests/data/test_collator.py
@@ -12,9 +12,105 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import torch
+import os
 
-from llamafactory.data.collator import prepare_4d_attention_mask
+import torch
+from PIL import Image
+
+from llamafactory.data import get_template_and_fix_tokenizer
+from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask
+from llamafactory.extras.constants import IGNORE_INDEX
+from llamafactory.hparams import get_infer_args
+from llamafactory.model import load_tokenizer
+
+
+TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+
+
+def test_base_collator():
+    model_args, data_args, *_ = get_infer_args({"model_name_or_path": TINY_LLAMA, "template": "default"})
+    tokenizer_module = load_tokenizer(model_args)
+    template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
+    data_collator = MultiModalDataCollatorForSeq2Seq(
+        template=template,
+        pad_to_multiple_of=8,
+        label_pad_token_id=IGNORE_INDEX,
+        **tokenizer_module,
+    )
+    p = tokenizer_module["tokenizer"].pad_token_id
+    q = IGNORE_INDEX
+    features = [
+        {
+            "input_ids": [0, 1, 2, 3, 4, 5],
+            "attention_mask": [1, 1, 1, 1, 1, 1],
+            "labels": [q, q, 2, 3, 4, 5],
+        },
+        {
+            "input_ids": [6, 7],
+            "attention_mask": [1, 1],
+            "labels": [q, 7],
+        },
+    ]
+    batch_input = data_collator(features)
+    expected_input = {
+        "input_ids": [
+            [0, 1, 2, 3, 4, 5, p, p],
+            [6, 7, p, p, p, p, p, p],
+        ],
+        "attention_mask": [
+            [1, 1, 1, 1, 1, 1, 0, 0],
+            [1, 1, 0, 0, 0, 0, 0, 0],
+        ],
+        "labels": [
+            [q, q, 2, 3, 4, 5, q, q],
+            [q, 7, q, q, q, q, q, q],
+        ],
+    }
+    for k in batch_input.keys():
+        assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
+
+
+def test_multimodal_collator():
+    model_args, data_args, *_ = get_infer_args(
+        {"model_name_or_path": "Qwen/Qwen2-VL-7B-Instruct", "template": "qwen2_vl"}
+    )
+    tokenizer_module = load_tokenizer(model_args)
+    template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
+    data_collator = MultiModalDataCollatorForSeq2Seq(
+        template=template,
+        pad_to_multiple_of=4,
+        label_pad_token_id=IGNORE_INDEX,
+        **tokenizer_module,
+    )
+    p = tokenizer_module["tokenizer"].pad_token_id
+    q = IGNORE_INDEX
+    s = tokenizer_module["tokenizer"].convert_tokens_to_ids("<|vision_start|>")
+    e = tokenizer_module["tokenizer"].convert_tokens_to_ids("<|vision_end|>")
+    m = tokenizer_module["tokenizer"].convert_tokens_to_ids("<|image_pad|>")
+    fake_image = Image.new("RGB", (64, 64), (255, 255, 255))
+
+    features = [
+        {
+            "input_ids": [0, 1, 2, 3],
+            "attention_mask": [1, 1, 1, 1],
+            "labels": [0, 1, 2, 3],
+        },
+    ]
+    batch_input = data_collator(features)
+    expected_input = {
+        "input_ids": [
+            [0, 1, 2, 3, s, m, m, m, m, e, p, p],
+        ],
+        "attention_mask": [
+            [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+        ],
+        "labels": [
+            [0, 1, 2, 3, q, q, q, q, q, q, q, q],
+        ],
+        **tokenizer_module["processor"].image_processor(fake_image),
+    }
+    for k in batch_input.keys():
+        assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
 
 
 def test_4d_attention_mask():
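A quick note on the row lengths asserted in test_multimodal_collator above: the single text-only feature has 4 tokens, the dummy image contributes <|vision_start|>, four <|image_pad|> tokens, and <|vision_end|>, and pad_to_multiple_of=4 rounds the result up. A back-of-the-envelope check (illustrative only; the count of four <|image_pad|> tokens for the 64x64 dummy image is taken from the test's own expected_input, not derived here):

# Length arithmetic behind the 12-column rows in expected_input (illustrative only).
prompt_len = 4                   # the original feature: [0, 1, 2, 3]
vision_len = 1 + 4 + 1           # <|vision_start|> + 4 x <|image_pad|> + <|vision_end|>
pad_to_multiple_of = 4

total = prompt_len + vision_len  # 10 tokens after the dummy image is appended
padded = ((total + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
assert padded == 12              # matches the padded rows the test compares against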
diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py
index 21d0ebad..c9084af0 100644
--- a/tests/data/test_mm_plugin.py
+++ b/tests/data/test_mm_plugin.py
@@ -13,14 +13,14 @@
 # limitations under the License.
 
 import os
-from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Sequence
 
 import pytest
 import torch
 from PIL import Image
 
 from llamafactory.data.mm_plugin import get_mm_plugin
-from llamafactory.hparams import ModelArguments
+from llamafactory.hparams import get_infer_args
 from llamafactory.model import load_tokenizer
 
 
@@ -29,6 +29,7 @@ if TYPE_CHECKING:
     from transformers.image_processing_utils import BaseImageProcessor
 
     from llamafactory.data.mm_plugin import BasePlugin
+    from llamafactory.model.loader import TokenizerModule
 
 
 HF_TOKEN = os.getenv("HF_TOKEN")
@@ -82,10 +83,9 @@ def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
         assert batch_a[key] == batch_b[key]
 
 
-def _load_tokenizer_module(model_name_or_path: str) -> Tuple["PreTrainedTokenizer", "ProcessorMixin"]:
-    model_args = ModelArguments(model_name_or_path=model_name_or_path)
-    tokenizer_module = load_tokenizer(model_args)
-    return tokenizer_module["tokenizer"], tokenizer_module["processor"]
+def _load_tokenizer_module(model_name_or_path: str) -> "TokenizerModule":
+    model_args, *_ = get_infer_args({"model_name_or_path": model_name_or_path, "template": "default"})
+    return load_tokenizer(model_args)
 
 
 def _check_plugin(
@@ -121,73 +121,75 @@ def _check_plugin(
 
 
 def test_base_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
+    tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
     base_plugin = get_mm_plugin(name="base", image_token="<image>")
-    check_inputs = {"plugin": base_plugin, "tokenizer": tokenizer, "processor": processor}
+    check_inputs = {"plugin": base_plugin, **tokenizer_module}
     _check_plugin(**check_inputs)
 
 
 def test_llava_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
-    llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
     image_seqlen = 576
-    check_inputs = {"plugin": llava_plugin, "tokenizer": tokenizer, "processor": processor}
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
+    llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
+    check_inputs = {"plugin": llava_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
         for message in MM_MESSAGES
     ]
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
    _check_plugin(**check_inputs)
 
 
 def test_llava_next_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf")
-    llava_next_plugin = get_mm_plugin(name="llava_next", image_token="<image>")
-    check_inputs = {"plugin": llava_next_plugin, "tokenizer": tokenizer, "processor": processor}
     image_seqlen = 1176
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf")
+    llava_next_plugin = get_mm_plugin(name="llava_next", image_token="<image>")
+    check_inputs = {"plugin": llava_next_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
         for message in MM_MESSAGES
     ]
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     _check_plugin(**check_inputs)
 
 
 def test_llava_next_video_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf")
-    llava_next_video_plugin = get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="