Mirror of https://github.com/hiyouga/LLaMA-Factory.git
Synced 2025-10-14 23:58:11 +08:00
250 lines, 10 KiB, Python
# Copyright 2024 the LlamaFactory team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from collections import defaultdict
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
|
|
|
from ...extras.constants import IGNORE_INDEX
|
|
from ...extras.logging import get_logger
|
|
from .processor_utils import (
|
|
get_paligemma_token_type_ids,
|
|
get_pixel_values,
|
|
get_qwen2vl_image_inputs,
|
|
greedy_knapsack,
|
|
infer_seqlen,
|
|
)
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from PIL.Image import Image as ImageObject
|
|
from transformers import PreTrainedTokenizer, ProcessorMixin
|
|
|
|
from ...hparams import DataArguments
|
|
from ..template import Template
|
|
|
|
|
|
# Module-level logger obtained from the project's `get_logger` helper; used to report dropped examples.
logger = get_logger(__name__)
|
|
|
|
|
|
def _encode_supervised_example(
    prompt: Sequence[Dict[str, str]],
    response: Sequence[Dict[str, str]],
    system: Optional[str],
    tools: Optional[str],
    template: "Template",
    images: Sequence["ImageObject"],
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    cutoff_len: int,
    train_on_prompt: bool,
    mask_history: bool,
) -> Tuple[List[int], List[int]]:
    r"""
    Encodes one prompt/response conversation into ``(input_ids, labels)`` for supervised training.

    Args:
        prompt: prompt messages; each is a dict with (at least) a ``"content"`` key.
        response: response messages appended after ``prompt``.
        system: optional system prompt forwarded to ``template.encode_multiturn``.
        tools: optional tool description forwarded to ``template.encode_multiturn``.
        template: chat template that turns the messages into (source_ids, target_ids) pairs.
        images: images attached to this example (only consulted for qwen2_vl processors).
        tokenizer: tokenizer used for special-token ids.
        processor: optional multimodal processor; its attributes select the qwen2_vl,
            paligemma or llava-like handling below. ``None`` means text-only.
        cutoff_len: maximum length of the returned sequences, including any image-token
            prefix and the trailing eos added for ``efficient_eos`` templates.
        train_on_prompt: keep prompt tokens in ``labels`` instead of masking them.
        mask_history: train on the last turn's response only, and give later turns
            priority when the conversation has to be truncated.

    Returns:
        ``(input_ids, labels)`` of equal length; masked label positions hold ``IGNORE_INDEX``.

    Raises:
        ValueError: if a qwen2_vl prompt contains more ``<|image_pad|>`` tokens than images.
    """
    # Shallow-copy the messages: the multimodal branches below rewrite ``content`` and
    # must not leak those edits back into the caller's batch.
    prompt = [dict(message) for message in prompt]

    if processor is not None and "image_grid_thw" in processor.model_input_names:  # qwen2_vl models
        image_processor = getattr(processor, "image_processor")
        merge_length = image_processor.merge_size**2
        image_grid_thw = get_qwen2vl_image_inputs(images, processor)["image_grid_thw"] if len(images) > 0 else []
        index = 0
        for message in prompt:
            content = message["content"]
            while "<|image_pad|>" in content:
                if index >= len(image_grid_thw):
                    raise ValueError(
                        "Found more `<|image_pad|>` tokens than images ({}) in the prompt.".format(len(images))
                    )

                # Expand one pad into as many tokens as this image occupies after merging. A temporary
                # marker is used so the `while` condition does not re-match the expanded tokens.
                num_image_tokens = int(image_grid_thw[index].prod()) // merge_length
                content = content.replace(
                    "<|image_pad|>",
                    template.vision_start_token + "<|placeholder|>" * num_image_tokens + template.vision_end_token,
                    1,
                )
                index += 1

            message["content"] = content.replace("<|placeholder|>", "<|image_pad|>")
    elif processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
        prompt[0]["content"] = template.image_token + prompt[0]["content"]

    messages = prompt + list(response)

    # paligemma models: a fixed block of image tokens opens the sequence and is never trained on.
    if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
        image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
        image_seq_length = getattr(processor, "image_seq_length")
        prefix_ids = [image_token_id] * image_seq_length
        prefix_labels = [IGNORE_INDEX] * image_seq_length
    else:
        prefix_ids, prefix_labels = [], []

    encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
    # The image prefix and the trailing eos (efficient_eos) both consume part of the budget.
    total_length = len(prefix_ids) + (1 if template.efficient_eos else 0)
    if mask_history:
        encoded_pairs = encoded_pairs[::-1]  # high priority for last turns

    input_ids, labels = [], []
    for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
        if total_length >= cutoff_len:
            break

        source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
        source_ids = source_ids[:source_len]
        target_ids = target_ids[:target_len]
        total_length += source_len + target_len

        if train_on_prompt:
            source_label = source_ids
        elif template.efficient_eos and source_len > 0:
            # efficient_eos templates label the first prompt position with eos (kept from the original
            # implementation); guarded so an empty source cannot misalign labels with input_ids.
            source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
        else:
            source_label = [IGNORE_INDEX] * source_len

        if mask_history and turn_idx != 0:  # train on the last turn only
            target_label = [IGNORE_INDEX] * target_len
        else:
            target_label = target_ids

        if mask_history:  # pairs were reversed, so prepend to restore chronological order
            input_ids = source_ids + target_ids + input_ids
            labels = source_label + target_label + labels
        else:
            input_ids += source_ids + target_ids
            labels += source_label + target_label

    # Attach the image prefix only now so it stays at the front even when turns were prepended.
    input_ids = prefix_ids + input_ids
    labels = prefix_labels + labels

    if template.efficient_eos:
        input_ids += [tokenizer.eos_token_id]
        labels += [tokenizer.eos_token_id]

    return input_ids, labels
|
|
|
|
|
|
def preprocess_supervised_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
    r"""
    Encodes a batch of supervised examples into unpacked model inputs.

    Inputs are laid out as `<bos> X Y <eos>` and labels as `<ignore> ... <ignore> Y <eos>`;
    for multi-turn examples only the prompt part of each prompt-response pair is masked.

    Args:
        examples: batched columns ``prompt``, ``response``, ``system``, ``tools`` and ``images``.
        template: chat template forwarded to the encoder.
        tokenizer: tokenizer forwarded to the encoder.
        processor: optional multimodal processor; when given, image features are added.
        data_args: supplies ``cutoff_len``, ``train_on_prompt`` and ``mask_history``.

    Returns:
        Column dict with ``input_ids``, ``attention_mask`` and ``labels``, plus ``pixel_values``
        when a processor is given, ``token_type_ids`` for paligemma and ``image_grid_thw`` for qwen2_vl.
        Examples without an odd number of prompt messages and exactly one response are skipped.
    """
    has_processor = processor is not None
    is_paligemma = has_processor and hasattr(processor, "image_seq_length")
    is_qwen2vl = has_processor and "image_grid_thw" in processor.model_input_names

    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
    if has_processor:
        model_inputs["pixel_values"] = []
        if is_paligemma:  # paligemma models
            model_inputs["token_type_ids"] = []
        if is_qwen2vl:  # qwen2_vl models
            model_inputs["image_grid_thw"] = []

    for idx, prompt in enumerate(examples["prompt"]):
        response = examples["response"][idx]
        if len(prompt) % 2 == 0 or len(response) != 1:
            logger.warning("Dropped invalid example: {}".format(prompt + response))
            continue

        sample_images = examples["images"][idx]
        input_ids, labels = _encode_supervised_example(
            prompt=prompt,
            response=response,
            system=examples["system"][idx],
            tools=examples["tools"][idx],
            images=sample_images,
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            cutoff_len=data_args.cutoff_len,
            train_on_prompt=data_args.train_on_prompt,
            mask_history=data_args.mask_history,
        )
        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append([1] * len(input_ids))
        model_inputs["labels"].append(labels)

        if not has_processor:
            continue

        if is_qwen2vl:
            vision_features = get_qwen2vl_image_inputs(sample_images, processor)
            model_inputs["pixel_values"].append(vision_features["pixel_values"])
            model_inputs["image_grid_thw"].append(vision_features["image_grid_thw"])
        else:
            model_inputs["pixel_values"].append(get_pixel_values(sample_images, processor))

        if is_paligemma:
            model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))

    return model_inputs
|
|
|
|
|
|
def preprocess_packed_supervised_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
    r"""
    Encodes a batch of text-only supervised examples and packs them into fixed-length sequences.

    Inputs are laid out as `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>` and labels as
    `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`; every packed sequence is
    padded to exactly ``data_args.cutoff_len`` tokens.

    Args:
        examples: batched columns ``prompt``, ``response``, ``system`` and ``tools``.
        template: chat template forwarded to the encoder.
        tokenizer: tokenizer forwarded to the encoder; its ``pad_token_id`` fills the padding.
        data_args: supplies ``cutoff_len``, ``train_on_prompt``, ``mask_history`` and ``neat_packing``.

    Returns:
        Column dict with ``input_ids``, ``attention_mask`` and ``labels``, one row per packed
        sequence. With ``neat_packing`` the attention mask numbers the packed examples 1, 2, ...
        and marks padding with 0; otherwise it is all ones.

    Raises:
        ValueError: if a packed sequence does not come out exactly ``cutoff_len`` tokens long.
    """
    # One slot of every packed sequence is reserved for the padding token, so both the encoder
    # budget and the knapsack capacity are cutoff_len - 1.
    capacity = data_args.cutoff_len - 1
    batch_input_ids, batch_labels = [], []
    lengths = []
    length2indexes = defaultdict(list)  # sequence length -> indexes into batch_input_ids
    for prompt, response, system, tools in zip(
        examples["prompt"], examples["response"], examples["system"], examples["tools"]
    ):
        if len(prompt) % 2 != 1 or len(response) != 1:
            logger.warning("Dropped invalid example: {}".format(prompt + response))
            continue

        input_ids, labels = _encode_supervised_example(
            prompt=prompt,
            response=response,
            system=system,
            tools=tools,
            template=template,
            images=[],  # required argument; packing is text-only (processor=None), so images are unused
            tokenizer=tokenizer,
            processor=None,
            cutoff_len=capacity,
            train_on_prompt=data_args.train_on_prompt,
            mask_history=data_args.mask_history,
        )
        length = len(input_ids)
        # Compare against the same capacity handed to greedy_knapsack: anything longer cannot
        # fit in a pack that must keep one slot for padding.
        if length > capacity:
            logger.warning("Dropped lengthy example with length {} > {}.".format(length, capacity))
            continue

        length2indexes[length].append(len(batch_input_ids))
        lengths.append(length)
        batch_input_ids.append(input_ids)
        batch_labels.append(labels)

    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
    knapsacks = greedy_knapsack(lengths, capacity)
    for knapsack in knapsacks:
        packed_input_ids, packed_attention_masks, packed_labels = [], [], []
        for seq_idx, length in enumerate(knapsack):
            index = length2indexes[length].pop()
            packed_input_ids += batch_input_ids[index]
            packed_labels += batch_labels[index]
            if data_args.neat_packing:
                packed_attention_masks += [seq_idx + 1] * length  # start from 1
            else:
                packed_attention_masks += [1] * length

        pad_length = data_args.cutoff_len - len(packed_input_ids)
        if pad_length > 0:
            packed_input_ids += [tokenizer.pad_token_id] * pad_length
            packed_labels += [IGNORE_INDEX] * pad_length
            # neat_packing masks the padding out; plain packing keeps it attended (more efficient flash_attn).
            packed_attention_masks += [0 if data_args.neat_packing else 1] * pad_length

        if len(packed_input_ids) != data_args.cutoff_len:
            raise ValueError("The length of packed example should be identical to the cutoff length.")

        model_inputs["input_ids"].append(packed_input_ids)
        model_inputs["attention_mask"].append(packed_attention_masks)
        model_inputs["labels"].append(packed_labels)

    return model_inputs
|
|
|
|
|
|
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    r"""
    Prints one encoded supervised example for inspection.

    Shows the raw input ids, their decoded text, the raw label ids, and the decoded text of the
    labels after every ``IGNORE_INDEX`` position has been removed. Special tokens are kept.
    """
    input_ids = example["input_ids"]
    label_ids = example["labels"]
    trained_ids = [token_id for token_id in label_ids if token_id != IGNORE_INDEX]

    print("input_ids:\n{}".format(input_ids))
    print("inputs:\n{}".format(tokenizer.decode(input_ids, skip_special_tokens=False)))
    print("label_ids:\n{}".format(label_ids))
    print("labels:\n{}".format(tokenizer.decode(trained_ids, skip_special_tokens=False)))
|