[model] Update ernie_vl to adapt new version (#9665)

This commit is contained in:
Xunpeng Xiao
2025-12-26 19:57:49 +08:00
committed by GitHub
parent a882e2d5fc
commit 3c17f2722c
5 changed files with 24 additions and 16 deletions

View File

@@ -480,21 +480,35 @@ class ErnieVLPlugin(BasePlugin):
self._validate_input(processor, images, videos, audios) self._validate_input(processor, images, videos, audios)
self._validate_messages(messages, images, videos, audios) self._validate_messages(messages, images, videos, audios)
messages = deepcopy(messages) messages = deepcopy(messages)
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
merge_length: int = getattr(image_processor, "merge_size") ** 2
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
image_grid_thw = mm_inputs.get("image_grid_thw", [])
video_grid_thw = mm_inputs.get("video_grid_thw", [])
else:
image_grid_thw = [None] * len(images)
video_grid_thw = [None] * len(videos)
image_idx, video_idx = 0, 0 image_idx, video_idx = 0, 0
for message in messages: for message in messages:
content = message["content"] content = message["content"]
image_token = self.image_token or "<|image@placeholder|>" image_token = self.image_token or "<|IMAGE_PLACEHOLDER|>"
video_token = self.video_token or "<|video@placeholder|>" video_token = self.video_token or "<|VIDEO_PLACEHOLDER|>"
while IMAGE_PLACEHOLDER in content: while IMAGE_PLACEHOLDER in content:
image_seqlen = image_grid_thw[image_idx].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace(
IMAGE_PLACEHOLDER, f"Picture {image_idx + 1}:<|IMAGE_START|>{image_token * image_seqlen}<|IMAGE_END|>", 1
)
image_idx += 1 image_idx += 1
content = content.replace(
IMAGE_PLACEHOLDER, f"Picture {image_idx}:<|IMAGE_START|>{image_token}<|IMAGE_END|>", 1
)
while VIDEO_PLACEHOLDER in content: while VIDEO_PLACEHOLDER in content:
video_idx += 1 video_seqlen = video_grid_thw[video_idx].prod() // merge_length if self.expand_mm_tokens else 1
content = content.replace( content = content.replace(
VIDEO_PLACEHOLDER, f"Video {video_idx}:<|VIDEO_START|>{video_token}<|VIDEO_END|>", 1 VIDEO_PLACEHOLDER, f"Video {video_idx + 1}:<|VIDEO_START|>{video_token * video_seqlen}<|VIDEO_END|>", 1
) )
video_idx += 1
message["content"] = content message["content"] = content
return messages return messages

View File

@@ -981,7 +981,7 @@ register_template(
replace_eos=True, replace_eos=True,
replace_jinja_template=True, replace_jinja_template=True,
template_class=ReasoningTemplate, template_class=ReasoningTemplate,
mm_plugin=get_mm_plugin(name="ernie_vl", image_token="<|image@placeholder|>", video_token="<|video@placeholder|>"), mm_plugin=get_mm_plugin(name="ernie_vl", image_token="<|IMAGE_PLACEHOLDER|>", video_token="<|VIDEO_PLACEHOLDER|>"),
) )

View File

@@ -205,10 +205,6 @@ def load_model(
if not is_trainable: if not is_trainable:
model.requires_grad_(False) model.requires_grad_(False)
for param in model.parameters():
if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32:
param.data = param.data.to(model_args.compute_dtype)
model.eval() model.eval()
else: else:
model.train() model.train()

View File

@@ -158,7 +158,7 @@ def patch_config(
# do not cast data type of the model deepspeed zero3 without qlora # do not cast data type of the model deepspeed zero3 without qlora
if not (is_deepspeed_zero3_enabled() and model_args.quantization_bit is None): if not (is_deepspeed_zero3_enabled() and model_args.quantization_bit is None):
init_kwargs["torch_dtype"] = model_args.compute_dtype init_kwargs["torch_dtype"] = "auto"
if init_kwargs["low_cpu_mem_usage"] and not is_fsdp_enabled(): # fsdp does not need device map if init_kwargs["low_cpu_mem_usage"] and not is_fsdp_enabled(): # fsdp does not need device map
if "device_map" not in init_kwargs and model_args.device_map: if "device_map" not in init_kwargs and model_args.device_map:

View File

@@ -84,9 +84,7 @@ def load_reference_model(
model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained( model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(
model_path, torch_dtype=torch.float16, device_map="auto" model_path, torch_dtype=torch.float16, device_map="auto"
) )
if not is_trainable:
model.v_head = model.v_head.to(torch.float16)
return model return model
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto") model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto")