Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 03:32:50 +08:00)

fix vlm zero3 training

Former-commit-id: dbb9e5b70efab37ed057b2d5822b9d0d23e99fb1
parent b34c3bb796
commit 0ef1dc4dd5
@@ -22,6 +22,13 @@ import torch
 import torch.nn.functional as F
 from transformers import DataCollatorForSeq2Seq
 
+from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
+from ..extras.packages import is_pillow_available
+
+
+if is_pillow_available():
+    from PIL import Image
+
 
 if TYPE_CHECKING:
     from transformers import ProcessorMixin
@@ -73,7 +80,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     r"""
     Data collator that supports VLMs.
 
-    Features should contain input_ids, attention_mask, labels and images.
+    Features should contain input_ids, attention_mask, labels, and optionally contain images and videos.
     """
 
     template: Optional["Template"] = None
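Note (not from the commit): a minimal sketch of the feature format the docstring describes. The keys mirror the collator tests added below in this commit; the commented-out "images" entry is an assumption about how a multimodal sample would attach its media.

# illustrative only -- one text-only sample; multimodal samples would also carry media keys
features = [
    {
        "input_ids": [0, 1, 2, 3],
        "attention_mask": [1, 1, 1, 1],
        "labels": [0, 1, 2, 3],
        # "images": [...],  # optional, handled by the template's mm_plugin (assumed key name)
    },
]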
@@ -90,6 +97,17 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             batch_vidlens.append(len(videos))
             batch_input_ids.append(feature["input_ids"])
 
+        if self.processor is not None and sum(batch_imglens) == 0:  # avoid process hanging in zero3/fsdp case
+            fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
+            fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
+            fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
+            fake_input_ids = self.processor.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
+            features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
+            features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
+            features[0]["labels"] = features[0]["labels"] + [IGNORE_INDEX] * len(fake_input_ids)
+            batch_images = fake_images
+            batch_input_ids[0] = features[0]["input_ids"]
+
         mm_inputs = self.template.mm_plugin.get_mm_inputs(
             batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
         )
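Note (not from the commit): under ZeRO-3/FSDP every rank must execute the same forward graph, so if one rank's batch contains no images, its vision tower never runs and the collectives that gather the sharded vision parameters can stall, which is the hang the comment above refers to. The added block keeps every rank on the multimodal path by injecting placeholder image tokens that are fully masked out. A standalone sketch of that masking pattern, using a hypothetical helper rather than the LLaMA-Factory API:

IGNORE_INDEX = -100  # label value that the loss ignores

def append_masked_tokens(feature: dict, fake_input_ids: list) -> dict:
    # extend the sequence but hide the new tokens from attention and from the loss
    feature["input_ids"] = feature["input_ids"] + fake_input_ids
    feature["attention_mask"] = feature["attention_mask"] + [0] * len(fake_input_ids)
    feature["labels"] = feature["labels"] + [IGNORE_INDEX] * len(fake_input_ids)
    return feature

example = append_masked_tokens(
    {"input_ids": [0, 1, 2], "attention_mask": [1, 1, 1], "labels": [0, 1, 2]},
    fake_input_ids=[9, 9],
)
assert example["attention_mask"] == [1, 1, 1, 0, 0]
assert example["labels"][-2:] == [IGNORE_INDEX, IGNORE_INDEX]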
@@ -99,7 +117,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 feature["token_type_ids"] = token_type_ids[i]
 
         features: Dict[str, "torch.Tensor"] = super().__call__(features)
-        if "cross_attention_mask" in mm_inputs:  # for mllama inputs
+        if "cross_attention_mask" in mm_inputs:  # for mllama inputs when pad_to_multiple_of is enabled
             cross_attention_mask = mm_inputs.pop("cross_attention_mask")
             seq_len = features["input_ids"].size(1)
             orig_len = cross_attention_mask.size(1)
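Note (not from the commit): when pad_to_multiple_of stretches input_ids past the length the processor produced, mllama's cross_attention_mask has to be padded along its sequence dimension to match, which is what the seq_len/orig_len bookkeeping above feeds into. A toy sketch with torch.nn.functional.pad; the 3-D shape is assumed purely for illustration:

import torch
import torch.nn.functional as F

cross_attention_mask = torch.ones(1, 6, 2)            # (batch, orig_len, tiles) -- toy shape
seq_len, orig_len = 8, cross_attention_mask.size(1)   # padded vs. original sequence length
# pad dim 1 (the sequence dimension) on the right with zeros
padded = F.pad(cross_attention_mask, (0, 0, 0, seq_len - orig_len))
assert padded.shape == (1, 8, 2)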
@@ -12,9 +12,105 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import torch
+import os
 
-from llamafactory.data.collator import prepare_4d_attention_mask
+import torch
+from PIL import Image
+
+from llamafactory.data import get_template_and_fix_tokenizer
+from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask
+from llamafactory.extras.constants import IGNORE_INDEX
+from llamafactory.hparams import get_infer_args
+from llamafactory.model import load_tokenizer
+
+
+TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+
+
+def test_base_collator():
+    model_args, data_args, *_ = get_infer_args({"model_name_or_path": TINY_LLAMA, "template": "default"})
+    tokenizer_module = load_tokenizer(model_args)
+    template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
+    data_collator = MultiModalDataCollatorForSeq2Seq(
+        template=template,
+        pad_to_multiple_of=8,
+        label_pad_token_id=IGNORE_INDEX,
+        **tokenizer_module,
+    )
+    p = tokenizer_module["tokenizer"].pad_token_id
+    q = IGNORE_INDEX
+    features = [
+        {
+            "input_ids": [0, 1, 2, 3, 4, 5],
+            "attention_mask": [1, 1, 1, 1, 1, 1],
+            "labels": [q, q, 2, 3, 4, 5],
+        },
+        {
+            "input_ids": [6, 7],
+            "attention_mask": [1, 1],
+            "labels": [q, 7],
+        },
+    ]
+    batch_input = data_collator(features)
+    expected_input = {
+        "input_ids": [
+            [0, 1, 2, 3, 4, 5, p, p],
+            [6, 7, p, p, p, p, p, p],
+        ],
+        "attention_mask": [
+            [1, 1, 1, 1, 1, 1, 0, 0],
+            [1, 1, 0, 0, 0, 0, 0, 0],
+        ],
+        "labels": [
+            [q, q, 2, 3, 4, 5, q, q],
+            [q, 7, q, q, q, q, q, q],
+        ],
+    }
+    for k in batch_input.keys():
+        assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
+
+
+def test_multimodal_collator():
+    model_args, data_args, *_ = get_infer_args(
+        {"model_name_or_path": "Qwen/Qwen2-VL-7B-Instruct", "template": "qwen2_vl"}
+    )
+    tokenizer_module = load_tokenizer(model_args)
+    template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
+    data_collator = MultiModalDataCollatorForSeq2Seq(
+        template=template,
+        pad_to_multiple_of=4,
+        label_pad_token_id=IGNORE_INDEX,
+        **tokenizer_module,
+    )
+    p = tokenizer_module["tokenizer"].pad_token_id
+    q = IGNORE_INDEX
+    s = tokenizer_module["tokenizer"].convert_tokens_to_ids("<|vision_start|>")
+    e = tokenizer_module["tokenizer"].convert_tokens_to_ids("<|vision_end|>")
+    m = tokenizer_module["tokenizer"].convert_tokens_to_ids("<|image_pad|>")
+    fake_image = Image.new("RGB", (64, 64), (255, 255, 255))
+
+    features = [
+        {
+            "input_ids": [0, 1, 2, 3],
+            "attention_mask": [1, 1, 1, 1],
+            "labels": [0, 1, 2, 3],
+        },
+    ]
+    batch_input = data_collator(features)
+    expected_input = {
+        "input_ids": [
+            [0, 1, 2, 3, s, m, m, m, m, e, p, p],
+        ],
+        "attention_mask": [
+            [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+        ],
+        "labels": [
+            [0, 1, 2, 3, q, q, q, q, q, q, q, q],
+        ],
+        **tokenizer_module["processor"].image_processor(fake_image),
+    }
+    for k in batch_input.keys():
+        assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
 
 
 def test_4d_attention_mask():
@@ -13,14 +13,14 @@
 # limitations under the License.
 
 import os
-from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Sequence
 
 import pytest
 import torch
 from PIL import Image
 
 from llamafactory.data.mm_plugin import get_mm_plugin
-from llamafactory.hparams import ModelArguments
+from llamafactory.hparams import get_infer_args
 from llamafactory.model import load_tokenizer
 
 
@@ -29,6 +29,7 @@ if TYPE_CHECKING:
     from transformers.image_processing_utils import BaseImageProcessor
 
     from llamafactory.data.mm_plugin import BasePlugin
+    from llamafactory.model.loader import TokenizerModule
 
 
 HF_TOKEN = os.getenv("HF_TOKEN")
@@ -82,10 +83,9 @@ def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
             assert batch_a[key] == batch_b[key]
 
 
-def _load_tokenizer_module(model_name_or_path: str) -> Tuple["PreTrainedTokenizer", "ProcessorMixin"]:
-    model_args = ModelArguments(model_name_or_path=model_name_or_path)
-    tokenizer_module = load_tokenizer(model_args)
-    return tokenizer_module["tokenizer"], tokenizer_module["processor"]
+def _load_tokenizer_module(model_name_or_path: str) -> "TokenizerModule":
+    model_args, *_ = get_infer_args({"model_name_or_path": model_name_or_path, "template": "default"})
+    return load_tokenizer(model_args)
 
 
 def _check_plugin(
@@ -121,73 +121,75 @@ def _check_plugin(
 
 
 def test_base_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
+    tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
     base_plugin = get_mm_plugin(name="base", image_token="<image>")
-    check_inputs = {"plugin": base_plugin, "tokenizer": tokenizer, "processor": processor}
+    check_inputs = {"plugin": base_plugin, **tokenizer_module}
     _check_plugin(**check_inputs)
 
 
 def test_llava_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
-    llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
     image_seqlen = 576
-    check_inputs = {"plugin": llava_plugin, "tokenizer": tokenizer, "processor": processor}
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
+    llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
+    check_inputs = {"plugin": llava_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
         for message in MM_MESSAGES
     ]
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     _check_plugin(**check_inputs)
 
 
 def test_llava_next_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf")
-    llava_next_plugin = get_mm_plugin(name="llava_next", image_token="<image>")
-    check_inputs = {"plugin": llava_next_plugin, "tokenizer": tokenizer, "processor": processor}
     image_seqlen = 1176
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf")
+    llava_next_plugin = get_mm_plugin(name="llava_next", image_token="<image>")
+    check_inputs = {"plugin": llava_next_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
        for message in MM_MESSAGES
     ]
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     _check_plugin(**check_inputs)
 
 
 def test_llava_next_video_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf")
-    llava_next_video_plugin = get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>")
-    check_inputs = {"plugin": llava_next_video_plugin, "tokenizer": tokenizer, "processor": processor}
     image_seqlen = 1176
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf")
+    llava_next_video_plugin = get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>")
+    check_inputs = {"plugin": llava_next_video_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
         for message in MM_MESSAGES
     ]
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     _check_plugin(**check_inputs)
 
 
 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
 def test_paligemma_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="google/paligemma-3b-pt-224")
-    paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
     image_seqlen = 256
-    check_inputs = {"plugin": paligemma_plugin, "tokenizer": tokenizer, "processor": processor}
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="google/paligemma-3b-pt-224")
+    paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
+    check_inputs = {"plugin": paligemma_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {key: value.replace("<image>", "") for key, value in message.items()} for message in MM_MESSAGES
     ]
-    check_inputs["expected_input_ids"] = [tokenizer.convert_tokens_to_ids("<image>")] * image_seqlen + INPUT_IDS
+    check_inputs["expected_input_ids"] = [
+        tokenizer_module["tokenizer"].convert_tokens_to_ids(paligemma_plugin.image_token)
+    ] * image_seqlen + INPUT_IDS
     check_inputs["expected_labels"] = [-100] * image_seqlen + LABELS
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     check_inputs["expected_mm_inputs"]["token_type_ids"] = [[0] * image_seqlen + [1] * (1024 - image_seqlen)]
     check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[1] * 1024]}
     _check_plugin(**check_inputs)
 
 
 def test_pixtral_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="mistral-community/pixtral-12b")
-    pixtral_plugin = get_mm_plugin(name="pixtral", image_token="[IMG]")
     image_slice_height, image_slice_width = 2, 2
-    check_inputs = {"plugin": pixtral_plugin, "tokenizer": tokenizer, "processor": processor}
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="mistral-community/pixtral-12b")
+    pixtral_plugin = get_mm_plugin(name="pixtral", image_token="[IMG]")
+    check_inputs = {"plugin": pixtral_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {
             key: value.replace(
@@ -199,17 +201,17 @@ def test_pixtral_plugin():
         }
         for message in MM_MESSAGES
     ]
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     check_inputs["expected_mm_inputs"].pop("image_sizes")
     check_inputs["expected_mm_inputs"]["pixel_values"] = check_inputs["expected_mm_inputs"]["pixel_values"][0]
     _check_plugin(**check_inputs)
 
 
 def test_qwen2_vl_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
-    qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
     image_seqlen = 4
-    check_inputs = {"plugin": qwen2_vl_plugin, "tokenizer": tokenizer, "processor": processor}
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
+    qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
+    check_inputs = {"plugin": qwen2_vl_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {
             key: value.replace("<image>", "<|vision_start|>{}<|vision_end|>".format("<|image_pad|>" * image_seqlen))
@@ -217,18 +219,18 @@ def test_qwen2_vl_plugin():
         }
         for message in MM_MESSAGES
     ]
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     _check_plugin(**check_inputs)
 
 
 def test_video_llava_plugin():
-    tokenizer, processor = _load_tokenizer_module(model_name_or_path="LanguageBind/Video-LLaVA-7B-hf")
-    video_llava_plugin = get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>")
-    check_inputs = {"plugin": video_llava_plugin, "tokenizer": tokenizer, "processor": processor}
     image_seqlen = 256
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="LanguageBind/Video-LLaVA-7B-hf")
+    video_llava_plugin = get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>")
+    check_inputs = {"plugin": video_llava_plugin, **tokenizer_module}
     check_inputs["expected_mm_messages"] = [
         {key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
         for message in MM_MESSAGES
     ]
-    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     _check_plugin(**check_inputs)