From 3c17f2722c4ac190add3a017bf80f3fda1292aaf Mon Sep 17 00:00:00 2001
From: Xunpeng Xiao <124695565+tangefly@users.noreply.github.com>
Date: Fri, 26 Dec 2025 19:57:49 +0800
Subject: [PATCH] [model] Update ernie_vl to adapt new version (#9665)

---
 src/llamafactory/data/mm_plugin.py   | 28 +++++++++++++++++++++-------
 src/llamafactory/data/template.py    |  2 +-
 src/llamafactory/model/loader.py     |  4 ----
 src/llamafactory/model/patcher.py    |  2 +-
 src/llamafactory/train/test_utils.py |  4 +---
 5 files changed, 24 insertions(+), 16 deletions(-)

diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 4a277fd5a..291554021 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -480,21 +480,35 @@ class ErnieVLPlugin(BasePlugin):
         self._validate_input(processor, images, videos, audios)
         self._validate_messages(messages, images, videos, audios)
         messages = deepcopy(messages)
+
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
+
+        merge_length: int = getattr(image_processor, "merge_size") ** 2
+        if self.expand_mm_tokens:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            image_grid_thw = mm_inputs.get("image_grid_thw", [])
+            video_grid_thw = mm_inputs.get("video_grid_thw", [])
+        else:
+            image_grid_thw = [None] * len(images)
+            video_grid_thw = [None] * len(videos)
+
         image_idx, video_idx = 0, 0
         for message in messages:
             content = message["content"]
-            image_token = self.image_token or "<|image@placeholder|>"
-            video_token = self.video_token or "<|video@placeholder|>"
+            image_token = self.image_token or "<|IMAGE_PLACEHOLDER|>"
+            video_token = self.video_token or "<|VIDEO_PLACEHOLDER|>"
             while IMAGE_PLACEHOLDER in content:
+                image_seqlen = image_grid_thw[image_idx].prod() // merge_length if self.expand_mm_tokens else 1
+                content = content.replace(
+                    IMAGE_PLACEHOLDER, f"Picture {image_idx + 1}:<|IMAGE_START|>{image_token * image_seqlen}<|IMAGE_END|>", 1
+                )
                 image_idx += 1
-                content = content.replace(
-                    IMAGE_PLACEHOLDER, f"Picture {image_idx}:<|IMAGE_START|>{image_token}<|IMAGE_END|>", 1
-                )
             while VIDEO_PLACEHOLDER in content:
-                video_idx += 1
+                video_seqlen = video_grid_thw[video_idx].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
-                    VIDEO_PLACEHOLDER, f"Video {video_idx}:<|VIDEO_START|>{video_token}<|VIDEO_END|>", 1
+                    VIDEO_PLACEHOLDER, f"Video {video_idx + 1}:<|VIDEO_START|>{video_token * video_seqlen}<|VIDEO_END|>", 1
                 )
+                video_idx += 1
             message["content"] = content
 
         return messages

diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 6a8a38b7f..db9301063 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -981,7 +981,7 @@ register_template(
     replace_eos=True,
     replace_jinja_template=True,
     template_class=ReasoningTemplate,
-    mm_plugin=get_mm_plugin(name="ernie_vl", image_token="<|image@placeholder|>", video_token="<|video@placeholder|>"),
+    mm_plugin=get_mm_plugin(name="ernie_vl", image_token="<|IMAGE_PLACEHOLDER|>", video_token="<|VIDEO_PLACEHOLDER|>"),
 )
 
 

diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index 25710c31d..72f510a44 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -205,10 +205,6 @@ def load_model(
 
     if not is_trainable:
         model.requires_grad_(False)
-        for param in model.parameters():
-            if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32:
-                param.data = param.data.to(model_args.compute_dtype)
-
         model.eval()
     else:
         model.train()

diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 1bf2c1320..7401641aa 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -158,7 +158,7 @@ def patch_config(
 
     # do not cast data type of the model deepspeed zero3 without qlora
     if not (is_deepspeed_zero3_enabled() and model_args.quantization_bit is None):
-        init_kwargs["torch_dtype"] = model_args.compute_dtype
+        init_kwargs["torch_dtype"] = "auto"
 
     if init_kwargs["low_cpu_mem_usage"] and not is_fsdp_enabled():  # fsdp does not need device map
         if "device_map" not in init_kwargs and model_args.device_map:

diff --git a/src/llamafactory/train/test_utils.py b/src/llamafactory/train/test_utils.py
index 631dbd87f..0f73d1c5e 100644
--- a/src/llamafactory/train/test_utils.py
+++ b/src/llamafactory/train/test_utils.py
@@ -84,9 +84,7 @@ def load_reference_model(
         model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(
             model_path, torch_dtype=torch.float16, device_map="auto"
         )
-        if not is_trainable:
-            model.v_head = model.v_head.to(torch.float16)
-
+        return model
 
     model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto")
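
Reviewer note (illustrative, not part of the patch): the mm_plugin.py hunk changes the ERNIE-VL plugin from emitting a single image/video token per media item to expanding the placeholder by the processor's patch grid, mirroring the grid-based expansion other VL plugins in this file use. Below is a minimal standalone sketch of that arithmetic. The grid shape, merge size, and placeholder strings are assumed example values; only the grid_thw.prod() // merge_size ** 2 formula and the "Picture {n}:<|IMAGE_START|>...<|IMAGE_END|>" framing come from the diff.

    from math import prod

    # Assumed example values, not taken from the ERNIE processor:
    # one image patchified into a (t=1, h=32, w=32) grid, merged 2x2.
    image_grid_thw = [(1, 32, 32)]
    merge_size = 2
    merge_length = merge_size ** 2  # patches covered by one merged token

    IMAGE_PLACEHOLDER = "<image>"          # generic dataset-side placeholder
    image_token = "<|IMAGE_PLACEHOLDER|>"  # model-side token from the new template

    content = "Describe this photo: <image>"
    image_idx = 0
    while IMAGE_PLACEHOLDER in content:
        # tokens consumed by this image = total patches / patches per merged token
        image_seqlen = prod(image_grid_thw[image_idx]) // merge_length  # 1024 // 4 = 256
        content = content.replace(
            IMAGE_PLACEHOLDER,
            f"Picture {image_idx + 1}:<|IMAGE_START|>{image_token * image_seqlen}<|IMAGE_END|>",
            1,
        )
        image_idx += 1  # incremented after the replace so labels stay 1-based

    print(content.count(image_token))  # 256

This also shows why the hunk moves image_idx += 1 below the replace call: the 1-based "Picture N:" label now comes from image_idx + 1, so the counter must still hold the 0-based index when the f-string is built, and that same index is needed to look up image_grid_thw[image_idx].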