mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 22:32:54 +08:00
fix some
Former-commit-id: 2ee8ba2f390551af1b865cfa813f5c8b7bbb41c5
This commit is contained in:
parent
aeb4f82ef2
commit
7138b43873
@ -1,8 +1,8 @@
|
|||||||
import math
|
import math
|
||||||
|
import re
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
|
||||||
import re
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -277,31 +277,32 @@ class CpmOPlugin(BasePlugin):
|
|||||||
|
|
||||||
message["content"] = content.replace("{{image}}", "(<image>./</image>)")
|
message["content"] = content.replace("{{image}}", "(<image>./</image>)")
|
||||||
|
|
||||||
if num_image_tokens>0:
|
if num_image_tokens > 0:
|
||||||
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||||
|
|
||||||
pattern = "(<image>./</image>)"
|
pattern = "(<image>./</image>)"
|
||||||
images, image_sizes, tgt_sizes = mm_inputs["pixel_values"], mm_inputs["image_sizes"], mm_inputs["tgt_sizes"]
|
images, image_sizes, _ = mm_inputs["pixel_values"], mm_inputs["image_sizes"], mm_inputs["tgt_sizes"]
|
||||||
|
|
||||||
input_ids_list = []
|
|
||||||
image_bounds_list = []
|
|
||||||
image_index = 0
|
image_index = 0
|
||||||
for index, message in enumerate(messages):
|
for index, message in enumerate(messages):
|
||||||
text = message['content']
|
text = message["content"]
|
||||||
image_tags = re.findall(pattern, text)
|
image_tags = re.findall(pattern, text)
|
||||||
text_chunks = text.split(pattern)
|
text_chunks = text.split(pattern)
|
||||||
final_text = ""
|
final_text = ""
|
||||||
for i in range(len(image_tags)):
|
for i in range(len(image_tags)):
|
||||||
final_text = final_text + text_chunks[i] + \
|
final_text = (
|
||||||
image_processor.get_slice_image_placeholder(
|
final_text
|
||||||
|
+ text_chunks[i]
|
||||||
|
+ image_processor.get_slice_image_placeholder(
|
||||||
image_sizes[image_index][i],
|
image_sizes[image_index][i],
|
||||||
i,
|
i,
|
||||||
image_processor.max_slice_nums,
|
image_processor.max_slice_nums,
|
||||||
image_processor.use_image_id,
|
image_processor.use_image_id,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
image_index += 1
|
image_index += 1
|
||||||
final_text += text_chunks[-1]
|
final_text += text_chunks[-1]
|
||||||
messages[index]['content'] = final_text
|
messages[index]["content"] = final_text
|
||||||
|
|
||||||
if len(images) != num_image_tokens:
|
if len(images) != num_image_tokens:
|
||||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||||
@ -316,7 +317,6 @@ class CpmOPlugin(BasePlugin):
|
|||||||
processor: "ProcessorMixin",
|
processor: "ProcessorMixin",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Dict[str, "torch.Tensor"]:
|
) -> Dict[str, "torch.Tensor"]:
|
||||||
|
|
||||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||||
|
|
||||||
mm_inputs = {}
|
mm_inputs = {}
|
||||||
@ -327,14 +327,16 @@ class CpmOPlugin(BasePlugin):
|
|||||||
image_resolution=getattr(processor, "image_resolution", 512 * 512),
|
image_resolution=getattr(processor, "image_resolution", 512 * 512),
|
||||||
)
|
)
|
||||||
if "valid_image_nums_ls" in kwargs:
|
if "valid_image_nums_ls" in kwargs:
|
||||||
valid_image_nums_ls = kwargs['valid_image_nums_ls']
|
valid_image_nums_ls = kwargs["valid_image_nums_ls"]
|
||||||
new_images = []
|
new_images = []
|
||||||
idx = 0
|
idx = 0
|
||||||
for valid_image_nums in valid_image_nums_ls:
|
for valid_image_nums in valid_image_nums_ls:
|
||||||
new_images.append(images[idx:idx+valid_image_nums])
|
new_images.append(images[idx : idx + valid_image_nums])
|
||||||
idx += valid_image_nums
|
idx += valid_image_nums
|
||||||
images = new_images
|
images = new_images
|
||||||
image_inputs = image_processor(images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt")
|
image_inputs = image_processor(
|
||||||
|
images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
|
||||||
|
)
|
||||||
mm_inputs.update(image_inputs)
|
mm_inputs.update(image_inputs)
|
||||||
|
|
||||||
if len(videos) != 0:
|
if len(videos) != 0:
|
||||||
@ -348,22 +350,22 @@ class CpmOPlugin(BasePlugin):
|
|||||||
return mm_inputs
|
return mm_inputs
|
||||||
|
|
||||||
def trim_and_pad(self, seq, padding_value=0):
|
def trim_and_pad(self, seq, padding_value=0):
|
||||||
return pad_sequence([s for s in seq], batch_first=True, padding_value=padding_value)
|
return pad_sequence(seq, batch_first=True, padding_value=padding_value)
|
||||||
|
|
||||||
def pad_data(self, features):
|
def pad_data(self, features):
|
||||||
features['position_ids'] = [torch.arange(input_ids.size(0)).long() for input_ids in features['input_ids']]
|
features["position_ids"] = [torch.arange(input_ids.size(0)).long() for input_ids in features["input_ids"]]
|
||||||
features['input_ids'] = self.trim_and_pad(
|
features["input_ids"] = self.trim_and_pad(
|
||||||
[input_ids for input_ids in features['input_ids']],
|
features["input_ids"],
|
||||||
)
|
)
|
||||||
features['position_ids'] = self.trim_and_pad(
|
features["position_ids"] = self.trim_and_pad(
|
||||||
[position_ids for position_ids in features['position_ids']],
|
features["position_ids"],
|
||||||
)
|
)
|
||||||
features['labels'] = self.trim_and_pad(
|
features["labels"] = self.trim_and_pad(
|
||||||
[labels for labels in features['labels']],
|
features["labels"],
|
||||||
padding_value=-100,
|
padding_value=-100,
|
||||||
)
|
)
|
||||||
features['attention_mask'] = self.trim_and_pad(
|
features["attention_mask"] = self.trim_and_pad(
|
||||||
[attention_mask for attention_mask in features['attention_mask']],
|
features["attention_mask"],
|
||||||
)
|
)
|
||||||
return features
|
return features
|
||||||
|
|
||||||
@ -379,11 +381,12 @@ class CpmOPlugin(BasePlugin):
|
|||||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||||
self._validate_input(images, videos)
|
self._validate_input(images, videos)
|
||||||
image_bounds_list = []
|
image_bounds_list = []
|
||||||
position_ids = []
|
|
||||||
valid_image_nums_ls = []
|
valid_image_nums_ls = []
|
||||||
for input_ids in batch_ids:
|
for input_ids in batch_ids:
|
||||||
input_ids_ = torch.tensor(input_ids)
|
input_ids_ = torch.tensor(input_ids)
|
||||||
start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (input_ids_ == processor.tokenizer.slice_start_id)
|
start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
|
||||||
|
input_ids_ == processor.tokenizer.slice_start_id
|
||||||
|
)
|
||||||
end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id)
|
end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id)
|
||||||
image_start_tokens = torch.where(start_cond)[0]
|
image_start_tokens = torch.where(start_cond)[0]
|
||||||
image_start_tokens += 1
|
image_start_tokens += 1
|
||||||
@ -398,9 +401,11 @@ class CpmOPlugin(BasePlugin):
|
|||||||
)
|
)
|
||||||
image_bounds_list.append(image_bounds)
|
image_bounds_list.append(image_bounds)
|
||||||
mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls)
|
mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls)
|
||||||
mm_inputs.update({
|
mm_inputs.update(
|
||||||
|
{
|
||||||
"image_bound": image_bounds_list,
|
"image_bound": image_bounds_list,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
return mm_inputs
|
return mm_inputs
|
||||||
|
|
||||||
|
|
||||||
|
@ -570,13 +570,13 @@ _register_template(
|
|||||||
_register_template(
|
_register_template(
|
||||||
name="cpm_o",
|
name="cpm_o",
|
||||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
|
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||||
format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"),
|
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
|
||||||
format_observation=StringFormatter(
|
format_observation=StringFormatter(
|
||||||
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
|
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
|
||||||
),
|
),
|
||||||
format_tools=ToolFormatter(tool_format="qwen"),
|
format_tools=ToolFormatter(tool_format="qwen"),
|
||||||
format_separator=EmptyFormatter(slots=["\n"]),
|
|
||||||
default_system="You are a helpful assistant.",
|
default_system="You are a helpful assistant.",
|
||||||
stop_words=["<|im_end|>"],
|
stop_words=["<|im_end|>"],
|
||||||
mm_plugin=get_mm_plugin(name="cpm_o", image_token="<image>", video_token="<video>"),
|
mm_plugin=get_mm_plugin(name="cpm_o", image_token="<image>", video_token="<video>"),
|
||||||
|
@ -24,7 +24,6 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from transformers import Seq2SeqTrainer
|
from transformers import Seq2SeqTrainer
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
import copy
|
|
||||||
|
|
||||||
from ...extras import logging
|
from ...extras import logging
|
||||||
from ...extras.constants import IGNORE_INDEX
|
from ...extras.constants import IGNORE_INDEX
|
||||||
|
Loading…
x
Reference in New Issue
Block a user