[model] support LiquidAI's LFM2.5-VL vision-language model (#9729)

This commit is contained in:
Vo Van Phuc
2026-01-07 16:20:29 +07:00
committed by GitHub
parent b4e051bea4
commit 958fb523a2
5 changed files with 118 additions and 0 deletions

View File

@@ -2092,6 +2092,73 @@ class VideoLlavaPlugin(BasePlugin):
return messages
@dataclass
class LFMVLPlugin(BasePlugin):
r"""Plugin for LFM2.5-VL vision-language models.
LFM2.5-VL uses dynamic image token counts based on image resolution.
The image processor returns spatial_shapes tensor with [height, width] grid dimensions.
Token count per image = (spatial_h * spatial_w) / (downsample_factor^2)
"""
@override
def _get_mm_inputs(
self,
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: "MMProcessor",
) -> dict[str, "torch.Tensor"]:
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
images,
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
)["images"]
mm_inputs.update(image_processor(images, return_tensors="pt"))
return mm_inputs
@override
def process_messages(
self,
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
self._validate_messages(messages, images, videos, audios)
num_image_tokens = 0
messages = deepcopy(messages)
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
downsample_factor: int = getattr(image_processor, "downsample_factor", 2)
if self.expand_mm_tokens and len(images) > 0:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
spatial_shapes = mm_inputs.get("spatial_shapes", [])
else:
spatial_shapes = []
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
if self.expand_mm_tokens and len(spatial_shapes) > num_image_tokens:
h, w = spatial_shapes[num_image_tokens].tolist()
image_seqlen = (h * w) // (downsample_factor * downsample_factor)
else:
image_seqlen = 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
num_image_tokens += 1
message["content"] = content.replace("{{image}}", self.image_token)
return messages
PLUGINS = {
"base": BasePlugin,
"ernie_vl": ErnieVLPlugin,
@@ -2104,6 +2171,7 @@ PLUGINS = {
"llava": LlavaPlugin,
"llava_next": LlavaNextPlugin,
"llava_next_video": LlavaNextVideoPlugin,
"lfm2_vl": LFMVLPlugin,
"minicpm_v": MiniCPMVPlugin,
"mllama": MllamaPlugin,
"paligemma": PaliGemmaPlugin,

View File

@@ -1350,6 +1350,27 @@ register_template(
)
register_template(
name="lfm2_vl",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm"),
format_observation=StringFormatter(
slots=[
"<|im_start|>tool\n<|tool_response_start|>{{content}}<|tool_response_end|><|im_end|>\n"
"<|im_start|>assistant\n"
]
),
format_tools=ToolFormatter(tool_format="lfm"),
default_system="You are a helpful multimodal assistant by Liquid AI.",
stop_words=["<|im_end|>"],
tool_call_words=("<|tool_call_start|>", "<|tool_call_end|>"),
replace_eos=True,
mm_plugin=get_mm_plugin(name="lfm2_vl", image_token="<image>"),
)
register_template(
name="llama2",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),

View File

@@ -1506,6 +1506,17 @@ register_model_group(
)
register_model_group(
models={
"LFM2.5-VL-1.6B": {
DownloadSource.DEFAULT: "LiquidAI/LFM2.5-VL-1.6B",
},
},
template="lfm2_vl",
multimodal=True,
)
register_model_group(
models={
"Llama-7B": {

View File

@@ -151,6 +151,12 @@ def patch_config(
if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):
raise RuntimeError("InternLM3 model requires transformers>=4.47.1, please upgrade it.")
if getattr(config, "model_type", None) == "lfm2_vl" and not is_transformers_version_greater_than("4.58.0"):
raise RuntimeError(
"LFM2.5-VL model requires transformers>=4.58.0 or install from commit: "
"pip install git+https://github.com/huggingface/transformers.git@3c2517727ce28a30f5044e01663ee204deb1cdbe"
)
if getattr(config, "model_type", None) == "qwen3_omni_moe":
patch_qwen3_omni_moe_thinker_text_sparse_moe_block()

View File

@@ -419,3 +419,15 @@ def test_video_llava_plugin():
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu", "mps"])
def test_lfm2_vl_plugin():
"""Test LFM2.5-VL plugin instantiation."""
# Test plugin can be instantiated with correct tokens
lfm2_vl_plugin = get_mm_plugin(name="lfm2_vl", image_token="<image>")
assert lfm2_vl_plugin is not None
assert lfm2_vl_plugin.image_token == "<image>"
assert lfm2_vl_plugin.video_token is None
assert lfm2_vl_plugin.audio_token is None
assert lfm2_vl_plugin.__class__.__name__ == "LFMVLPlugin"