[model] support LiquidAI's LFM2.5-VL vision-language model (#9729)
@@ -2092,6 +2092,73 @@ class VideoLlavaPlugin(BasePlugin):
        return messages


@dataclass
class LFMVLPlugin(BasePlugin):
    r"""Plugin for LFM2.5-VL vision-language models.

    LFM2.5-VL uses dynamic image token counts based on image resolution.
    The image processor returns a spatial_shapes tensor with [height, width] grid dimensions.
    Token count per image = (spatial_h * spatial_w) / (downsample_factor^2)
    """

    @override
    def _get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: "MMProcessor",
    ) -> dict[str, "torch.Tensor"]:
        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
        mm_inputs = {}
        if len(images) != 0:
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )["images"]
            mm_inputs.update(image_processor(images, return_tensors="pt"))

        return mm_inputs

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        num_image_tokens = 0
        messages = deepcopy(messages)
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
        downsample_factor: int = getattr(image_processor, "downsample_factor", 2)

        if self.expand_mm_tokens and len(images) > 0:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            spatial_shapes = mm_inputs.get("spatial_shapes", [])
        else:
            spatial_shapes = []

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                if self.expand_mm_tokens and len(spatial_shapes) > num_image_tokens:
                    h, w = spatial_shapes[num_image_tokens].tolist()
                    image_seqlen = (h * w) // (downsample_factor * downsample_factor)
                else:
                    image_seqlen = 1

                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
                num_image_tokens += 1

            message["content"] = content.replace("{{image}}", self.image_token)

        return messages


PLUGINS = {
    "base": BasePlugin,
    "ernie_vl": ErnieVLPlugin,
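A quick illustration of the token-count rule stated in the LFMVLPlugin docstring above. The spatial shapes below are hypothetical values, not output from the real image processor; the downsample factor of 2 matches the default the plugin falls back to.

import torch

# Hypothetical [height, width] grids for two images; real values come from the image processor.
spatial_shapes = torch.tensor([[16, 24], [8, 8]])
downsample_factor = 2  # plugin default when the processor does not define one

for shape in spatial_shapes:
    h, w = shape.tolist()
    image_seqlen = (h * w) // (downsample_factor * downsample_factor)
    print(f"{h}x{w} grid -> {image_seqlen} image tokens")  # 16x24 -> 96, 8x8 -> 16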
@@ -2104,6 +2171,7 @@ PLUGINS = {
    "llava": LlavaPlugin,
    "llava_next": LlavaNextPlugin,
    "llava_next_video": LlavaNextVideoPlugin,
    "lfm2_vl": LFMVLPlugin,
    "minicpm_v": MiniCPMVPlugin,
    "mllama": MllamaPlugin,
    "paligemma": PaliGemmaPlugin,
@@ -1350,6 +1350,27 @@ register_template(
)


register_template(
    name="lfm2_vl",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm"),
    format_observation=StringFormatter(
        slots=[
            "<|im_start|>tool\n<|tool_response_start|>{{content}}<|tool_response_end|><|im_end|>\n"
            "<|im_start|>assistant\n"
        ]
    ),
    format_tools=ToolFormatter(tool_format="lfm"),
    default_system="You are a helpful multimodal assistant by Liquid AI.",
    stop_words=["<|im_end|>"],
    tool_call_words=("<|tool_call_start|>", "<|tool_call_end|>"),
    replace_eos=True,
    mm_plugin=get_mm_plugin(name="lfm2_vl", image_token="<image>"),
)


register_template(
    name="llama2",
    format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
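For a concrete sense of the lfm2_vl chat format registered in this hunk, here is a hand-assembled single-turn prompt built directly from the slot strings above; the real rendering goes through LLaMA-Factory's template machinery, so treat this as an approximation only.

# Hand-assembled from the format_system / format_user / format_assistant slots above (illustrative only).
system = "You are a helpful multimodal assistant by Liquid AI."
user = "<image>What is in this picture?"
assistant = "A cat sitting on a windowsill."

prompt = (
    f"<|im_start|>system\n{system}<|im_end|>\n"
    f"<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n"
    f"{assistant}<|im_end|>\n"
)
print(prompt)
# The single <image> placeholder in the user turn is later expanded by LFMVLPlugin
# into one image token per downsampled grid cell, according to spatial_shapes.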
@@ -1506,6 +1506,17 @@ register_model_group(
)


register_model_group(
    models={
        "LFM2.5-VL-1.6B": {
            DownloadSource.DEFAULT: "LiquidAI/LFM2.5-VL-1.6B",
        },
    },
    template="lfm2_vl",
    multimodal=True,
)


register_model_group(
    models={
        "Llama-7B": {
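Outside LLaMA-Factory, the registered checkpoint can be inspected directly with transformers. A minimal sketch, assuming the LiquidAI/LFM2.5-VL-1.6B repo is reachable on the Hugging Face Hub and that its processor exposes the downsample_factor attribute read by LFMVLPlugin (both are assumptions, not guaranteed by this commit):

from transformers import AutoProcessor

# Requires transformers >= 4.58.0 (see the patch_config change below) and network access to the Hub.
processor = AutoProcessor.from_pretrained("LiquidAI/LFM2.5-VL-1.6B")
image_processor = getattr(processor, "image_processor", None)
print(type(processor).__name__, getattr(image_processor, "downsample_factor", "not set"))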
@@ -151,6 +151,12 @@ def patch_config(
    if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):
        raise RuntimeError("InternLM3 model requires transformers>=4.47.1, please upgrade it.")

    if getattr(config, "model_type", None) == "lfm2_vl" and not is_transformers_version_greater_than("4.58.0"):
        raise RuntimeError(
            "LFM2.5-VL model requires transformers>=4.58.0 or install from commit: "
            "pip install git+https://github.com/huggingface/transformers.git@3c2517727ce28a30f5044e01663ee204deb1cdbe"
        )

    if getattr(config, "model_type", None) == "qwen3_omni_moe":
        patch_qwen3_omni_moe_thinker_text_sparse_moe_block()
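If this check fires, the quickest diagnostic is to compare the installed transformers version against the requirement. A small standalone sketch (not part of the patch) using packaging, which transformers already depends on:

from importlib.metadata import version
from packaging.version import parse

installed = parse(version("transformers"))
required = parse("4.58.0")
if installed < required:
    print(f"transformers {installed} is too old for LFM2.5-VL; upgrade to >= {required}")
else:
    print(f"transformers {installed} satisfies the LFM2.5-VL requirement")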
@@ -419,3 +419,15 @@ def test_video_llava_plugin():
    ]
    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
    _check_plugin(**check_inputs)


@pytest.mark.runs_on(["cpu", "mps"])
def test_lfm2_vl_plugin():
    """Test LFM2.5-VL plugin instantiation."""
    # Test plugin can be instantiated with correct tokens
    lfm2_vl_plugin = get_mm_plugin(name="lfm2_vl", image_token="<image>")
    assert lfm2_vl_plugin is not None
    assert lfm2_vl_plugin.image_token == "<image>"
    assert lfm2_vl_plugin.video_token is None
    assert lfm2_vl_plugin.audio_token is None
    assert lfm2_vl_plugin.__class__.__name__ == "LFMVLPlugin"