[model] Update ernie_vl to adapt new version (#9665)
@@ -480,21 +480,35 @@ class ErnieVLPlugin(BasePlugin):
         self._validate_input(processor, images, videos, audios)
         self._validate_messages(messages, images, videos, audios)
         messages = deepcopy(messages)
+
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
+
+        merge_length: int = getattr(image_processor, "merge_size") ** 2
+        if self.expand_mm_tokens:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            image_grid_thw = mm_inputs.get("image_grid_thw", [])
+            video_grid_thw = mm_inputs.get("video_grid_thw", [])
+        else:
+            image_grid_thw = [None] * len(images)
+            video_grid_thw = [None] * len(videos)
+
         image_idx, video_idx = 0, 0
         for message in messages:
             content = message["content"]
-            image_token = self.image_token or "<|image@placeholder|>"
-            video_token = self.video_token or "<|video@placeholder|>"
+            image_token = self.image_token or "<|IMAGE_PLACEHOLDER|>"
+            video_token = self.video_token or "<|VIDEO_PLACEHOLDER|>"
             while IMAGE_PLACEHOLDER in content:
+                image_seqlen = image_grid_thw[image_idx].prod() // merge_length if self.expand_mm_tokens else 1
+                content = content.replace(
+                    IMAGE_PLACEHOLDER, f"Picture {image_idx + 1}:<|IMAGE_START|>{image_token * image_seqlen}<|IMAGE_END|>", 1
+                )
                 image_idx += 1
-                content = content.replace(
-                    IMAGE_PLACEHOLDER, f"Picture {image_idx}:<|IMAGE_START|>{image_token}<|IMAGE_END|>", 1
-                )
             while VIDEO_PLACEHOLDER in content:
-                video_idx += 1
+                video_seqlen = video_grid_thw[video_idx].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
-                    VIDEO_PLACEHOLDER, f"Video {video_idx}:<|VIDEO_START|>{video_token}<|VIDEO_END|>", 1
+                    VIDEO_PLACEHOLDER, f"Video {video_idx + 1}:<|VIDEO_START|>{video_token * video_seqlen}<|VIDEO_END|>", 1
                 )
+                video_idx += 1
             message["content"] = content

         return messages
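For context on the new expansion: when expand_mm_tokens is enabled, each placeholder now expands to the model's actual vision-token count, grid.prod() // merge_size ** 2, instead of a single token, and the running index advances only after the replacement. A minimal sketch of that arithmetic with a hypothetical 1 x 24 x 36 patch grid and merge_size 2 (prompt and values are illustrative, not taken from the commit):

import torch

IMAGE_PLACEHOLDER = "<image>"            # LLaMA-Factory's generic prompt placeholder
image_token = "<|IMAGE_PLACEHOLDER|>"    # ERNIE-VL's own image token

# Hypothetical processor output: one image patched into a 1 x 24 x 36 grid,
# with 2 x 2 patches merged per vision token.
image_grid_thw = [torch.tensor([1, 24, 36])]
merge_length = 2 ** 2

content = f"Describe the scene. {IMAGE_PLACEHOLDER}"
image_idx = 0
while IMAGE_PLACEHOLDER in content:
    # 1 * 24 * 36 // 4 = 216 repeated vision tokens for this image
    image_seqlen = int(image_grid_thw[image_idx].prod()) // merge_length
    content = content.replace(
        IMAGE_PLACEHOLDER,
        f"Picture {image_idx + 1}:<|IMAGE_START|>{image_token * image_seqlen}<|IMAGE_END|>",
        1,
    )
    image_idx += 1

print(content[:55])  # Describe the scene. Picture 1:<|IMAGE_START|><|IMAGE_...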
@@ -981,7 +981,7 @@ register_template(
     replace_eos=True,
     replace_jinja_template=True,
     template_class=ReasoningTemplate,
-    mm_plugin=get_mm_plugin(name="ernie_vl", image_token="<|image@placeholder|>", video_token="<|video@placeholder|>"),
+    mm_plugin=get_mm_plugin(name="ernie_vl", image_token="<|IMAGE_PLACEHOLDER|>", video_token="<|VIDEO_PLACEHOLDER|>"),
 )

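The registration now hands ERNIE-VL's real special tokens to the plugin, matching the new fallbacks in process_messages above. A quick sketch of the wiring, assuming the repo's get_mm_plugin import path and the BasePlugin attribute names:

from llamafactory.data.mm_plugin import get_mm_plugin

plugin = get_mm_plugin(
    name="ernie_vl",
    image_token="<|IMAGE_PLACEHOLDER|>",
    video_token="<|VIDEO_PLACEHOLDER|>",
)
print(plugin.image_token)  # <|IMAGE_PLACEHOLDER|>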
@@ -205,10 +205,6 @@ def load_model(
|
|||||||
|
|
||||||
if not is_trainable:
|
if not is_trainable:
|
||||||
model.requires_grad_(False)
|
model.requires_grad_(False)
|
||||||
for param in model.parameters():
|
|
||||||
if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32:
|
|
||||||
param.data = param.data.to(model_args.compute_dtype)
|
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
else:
|
else:
|
||||||
model.train()
|
model.train()
|
||||||
|
|||||||
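The deleted loop downcast any lingering float32 parameters to compute_dtype before inference; after the patch_config change below, weights load in the checkpoint's dtype, which presumably makes the manual pass redundant. A toy reproduction of the removed pattern, using a plain torch module as a stand-in for the loaded model:

import torch
from torch import nn

model = nn.Linear(4, 4)            # toy stand-in for the loaded model
compute_dtype = torch.bfloat16     # stand-in for model_args.compute_dtype

# The removed pattern: cast lingering fp32 parameters by hand.
for param in model.parameters():
    if param.data.dtype == torch.float32 and compute_dtype != torch.float32:
        param.data = param.data.to(compute_dtype)

model.requires_grad_(False)
model.eval()
print(next(model.parameters()).dtype)  # torch.bfloat16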
@@ -158,7 +158,7 @@ def patch_config(

     # do not cast data type of the model deepspeed zero3 without qlora
     if not (is_deepspeed_zero3_enabled() and model_args.quantization_bit is None):
-        init_kwargs["torch_dtype"] = model_args.compute_dtype
+        init_kwargs["torch_dtype"] = "auto"

     if init_kwargs["low_cpu_mem_usage"] and not is_fsdp_enabled(): # fsdp does not need device map
         if "device_map" not in init_kwargs and model_args.device_map:
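Setting torch_dtype="auto" lets transformers pick the dtype recorded in the checkpoint's config.json rather than forcing model_args.compute_dtype at load time. A minimal illustration; the model id is an arbitrary example, not taken from the commit:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",  # any checkpoint; dtype comes from its config
    torch_dtype="auto",
)
print(model.dtype)  # e.g. torch.bfloat16, as stored in the checkpoint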
@@ -84,8 +84,6 @@ def load_reference_model(
|
|||||||
model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(
|
model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||||
model_path, torch_dtype=torch.float16, device_map="auto"
|
model_path, torch_dtype=torch.float16, device_map="auto"
|
||||||
)
|
)
|
||||||
if not is_trainable:
|
|
||||||
model.v_head = model.v_head.to(torch.float16)
|
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|||||||
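The test helper no longer casts the value head to fp16 by hand; the dtype is left to trl's from_pretrained with torch_dtype=torch.float16. A sketch of the simplified helper with an illustrative model path; whether v_head actually comes back in fp16 depends on the installed trl version, which is presumably why the explicit cast existed:

import torch
from trl import AutoModelForCausalLMWithValueHead

model = AutoModelForCausalLMWithValueHead.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",  # illustrative path
    torch_dtype=torch.float16,
    device_map="auto",
)
print(model.v_head.summary.weight.dtype)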