Update supervised.py

Former-commit-id: 8cecade7082a52f413517ea20b1c5dd812db8e53
hoshi-hiyouga authored 2024-06-07 03:38:04 +08:00, committed by GitHub
parent 62d55b71a3
commit 21df5f0bd0


@@ -1,10 +1,10 @@
-import itertools
+import bisect
 from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
-from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
+from .mm_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack


 if TYPE_CHECKING:
@@ -18,29 +18,19 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


-def binary_search_for_fit(numbers, capacity):
-    """
-    Perform binary search to find the largest number that fits into the knapsack with the given capacity.
-    """
-    left, right = 0, len(numbers) - 1
-    result = -1  # If no number fits, return -1
-    while left <= right:
-        mid = (left + right) // 2
-        if numbers[mid] <= capacity:
-            result = mid
-            left = mid + 1
-        else:
-            right = mid - 1
-    return result
+def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
+    r"""
+    Finds the index of largest number that fits into the knapsack with the given capacity.
+    """
+    index = bisect.bisect(numbers, capacity)
+    return -1 if index == 0 else (index - 1)


-def efficient_greedy_knapsack(numbers, capacity):
-    """
-    An efficient greedy algorithm with binary search for the knapsack problem.
-    """
-    numbers.sort()  # Sort numbers in ascending order for binary search
+def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
+    r"""
+    An efficient greedy algorithm with binary search for the knapsack problem.
+    """
+    numbers.sort()  # sort numbers in ascending order for binary search
     knapsacks = []

     while numbers:
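The two helpers introduced above are self-contained enough to try outside the repo (the loop body continues in the next hunk). Below is a minimal, runnable sketch mirroring the right-hand side of this change; the demo lengths and capacity are invented for illustration:

import bisect
from typing import List, Sequence


def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
    # assumes `numbers` is sorted ascending; bisect.bisect returns the
    # insertion point for `capacity`, so index - 1 is the largest element
    # that still fits (or -1 if even the smallest element is too big)
    index = bisect.bisect(numbers, capacity)
    return -1 if index == 0 else (index - 1)


def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
    numbers.sort()  # keep ascending order so binary search stays valid
    knapsacks = []
    while numbers:
        current_knapsack = []
        remaining_capacity = capacity
        while True:
            index = search_for_fit(numbers, remaining_capacity)
            if index == -1:
                break  # nothing left that fits in this knapsack
            remaining_capacity -= numbers[index]
            current_knapsack.append(numbers.pop(index))
        knapsacks.append(current_knapsack)
    return knapsacks


# Illustrative only: pack sequence lengths into bins of capacity 8.
print(greedy_knapsack([3, 7, 2, 5, 4], capacity=8))  # -> [[7], [5, 3], [4, 2]]

Because the list stays sorted after each pop, bisect remains valid across iterations, which lets each best-fit lookup run in O(log n) instead of the linear scan the old loop effectively performed.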
@@ -48,22 +38,60 @@ def efficient_greedy_knapsack(numbers, capacity):
         remaining_capacity = capacity

         while True:
-            index = binary_search_for_fit(numbers, remaining_capacity)
+            index = search_for_fit(numbers, remaining_capacity)
             if index == -1:
-                break  # No more numbers fit in this knapsack
+                break  # no more numbers fit in this knapsack

-            # Add the found number to the knapsack and update the remaining capacity
-            current_knapsack.append(numbers[index])
-            remaining_capacity -= numbers[index]
-
-            # Remove the number from the list
-            numbers.pop(index)
+            remaining_capacity -= numbers[index]  # update the remaining capacity
+            current_knapsack.append(numbers.pop(index))  # add the number to knapsack

         knapsacks.append(current_knapsack)

     return knapsacks


+def _encode_supervised_example(
+    prompt: Sequence[Dict[str, str]],
+    response: Sequence[Dict[str, str]],
+    system: Optional[str],
+    tools: Optional[str],
+    template: "Template",
+    tokenizer: "PreTrainedTokenizer",
+    processor: Optional["ProcessorMixin"],
+    data_args: "DataArguments",
+) -> Tuple[List[int], List[int]]:
+    if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
+        prompt[0]["content"] = template.image_token + prompt[0]["content"]
+
+    messages = prompt + response
+    input_ids, labels = [], []
+
+    if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
+        image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
+        input_ids += [image_token_id] * getattr(processor, "image_seq_length")
+        labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
+
+    encoded_pairs = template.encode_multiturn(
+        tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
+    )
+    for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
+        if data_args.train_on_prompt:
+            source_mask = source_ids
+        elif turn_idx != 0 and template.efficient_eos:
+            source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
+        else:
+            source_mask = [IGNORE_INDEX] * len(source_ids)
+
+        input_ids += source_ids + target_ids
+        labels += source_mask + target_ids
+
+    if template.efficient_eos:
+        input_ids += [tokenizer.eos_token_id]
+        labels += [tokenizer.eos_token_id]
+
+    return input_ids, labels
+
+
 def preprocess_supervised_dataset(
     examples: Dict[str, List[Any]],
     template: "Template",
@@ -84,41 +112,16 @@ def preprocess_supervised_dataset(
             logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
             continue

-        if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
-            examples["prompt"][i][0]["content"] = template.image_token + examples["prompt"][i][0]["content"]
-
-        messages = examples["prompt"][i] + examples["response"][i]
-        input_ids, labels = [], []
-
-        if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
-            image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
-            input_ids += [image_token_id] * getattr(processor, "image_seq_length")
-            labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
-
-        for turn_idx, (source_ids, target_ids) in enumerate(
-            template.encode_multiturn(
-                tokenizer,
-                messages,
-                examples["system"][i],
-                examples["tools"][i],
-                data_args.cutoff_len,
-                data_args.reserved_label_len,
-            )
-        ):
-            if data_args.train_on_prompt:
-                source_mask = source_ids
-            elif turn_idx != 0 and template.efficient_eos:
-                source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
-            else:
-                source_mask = [IGNORE_INDEX] * len(source_ids)
-
-            input_ids += source_ids + target_ids
-            labels += source_mask + target_ids
-
-        if template.efficient_eos:
-            input_ids += [tokenizer.eos_token_id]
-            labels += [tokenizer.eos_token_id]
-
+        input_ids, labels = _encode_supervised_example(
+            prompt=examples["prompt"][i],
+            response=examples["response"][i],
+            system=examples["system"][i],
+            tools=examples["tools"][i],
+            template=template,
+            tokenizer=tokenizer,
+            processor=processor,
+            data_args=data_args,
+        )
         model_inputs["input_ids"].append(input_ids)
         model_inputs["attention_mask"].append([1] * len(input_ids))
         model_inputs["labels"].append(labels)
@@ -138,76 +141,54 @@ def preprocess_packed_supervised_dataset(
 ) -> Dict[str, List[List[int]]]:
     # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
     # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
-    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
-    input_ids, labels = [], []
+    valid_num = 0
+    batch_input_ids, batch_labels = [], []
+    lengths = []
+    length2indexes = defaultdict(list)
     for i in range(len(examples["prompt"])):
         if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
             logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
             continue

-        messages = examples["prompt"][i] + examples["response"][i]
-        for source_ids, target_ids in template.encode_multiturn(
-            tokenizer, messages, examples["system"][i], examples["tools"][i]
-        ):
-            if data_args.train_on_prompt:
-                source_mask = source_ids
-            else:
-                source_mask = [IGNORE_INDEX] * len(source_ids)
-
-            input_ids.append(source_ids + target_ids)
-            labels.append(source_mask + target_ids)
-
-    # prepare for packing
-    lengths = []
-    length2examples_idx = defaultdict(list)
-    for idx, example in enumerate(input_ids):
-        length = len(example)
-        if length > data_args.cutoff_len:
-            logger.warning("Dropped example with length {} > cutoff_len {}".format(length, data_args.cutoff_len))
-            continue
-        lengths.append(length)
-        length2examples_idx[length].append(idx)
-
-    # cutoff_len - 1 for efficient_eos
-    knapsacks = efficient_greedy_knapsack(lengths, data_args.cutoff_len - int(template.efficient_eos))
-
+        input_ids, labels = _encode_supervised_example(
+            prompt=examples["prompt"][i],
+            response=examples["response"][i],
+            system=examples["system"][i],
+            tools=examples["tools"][i],
+            template=template,
+            tokenizer=tokenizer,
+            processor=None,
+            data_args=data_args,
+        )
+        length = len(input_ids)
+        if length > data_args.cutoff_len:
+            logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
+        else:
+            lengths.append(length)
+            length2indexes[length].append(valid_num)
+            batch_input_ids.append(input_ids)
+            batch_labels.append(labels)
+            valid_num += 1
+
+    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+    knapsacks = greedy_knapsack(lengths, data_args.cutoff_len)
     for knapsack in knapsacks:
-        packed_input_ids = []
-        packed_labels = []
-        total_length = 0
+        packed_input_ids, packed_labels = [], []
         for length in knapsack:
-            total_length += length
-            idx = length2examples_idx[length].pop()
-            packed_input_ids.append(input_ids[idx])
-            packed_labels.append(labels[idx])
+            index = length2indexes[length].pop()
+            packed_input_ids += batch_input_ids[index]
+            packed_labels += batch_labels[index]

-        # padding to cutoff_len
-        if total_length < data_args.cutoff_len:
-            pad_length = data_args.cutoff_len - total_length
-            if template.efficient_eos:
-                # make sure there is an eos token
-                packed_input_ids.append([tokenizer.eos_token_id] * pad_length)
-                packed_labels.append([tokenizer.eos_token_id] + [IGNORE_INDEX] * (pad_length - 1))
-            else:
-                # without an eos token, pad with zeros?
-                packed_input_ids.append([0] * pad_length)
-                packed_labels.append([tokenizer.eos_token_id] + [IGNORE_INDEX] * (pad_length - 1))
-        elif total_length == data_args.cutoff_len:
-            pad_length = 0
-        else:
-            logger.warning(
-                "Dropped packed example with total length {} > cutoff_len {}".format(
-                    total_length, data_args.cutoff_len
-                )
-            )
-            continue
-
-        # concat all
-        model_inputs["input_ids"].append(list(itertools.chain(*packed_input_ids)))
-        model_inputs["labels"].append(list(itertools.chain(*packed_labels)))
-        model_inputs["attention_mask"].append([1] * total_length + [0] * pad_length)
+        if len(packed_input_ids) <= data_args.cutoff_len:
+            pad_length = data_args.cutoff_len - len(packed_input_ids)
+            packed_input_ids += [tokenizer.pad_token_id] * pad_length
+            packed_labels += [IGNORE_INDEX] * pad_length
+        else:
+            raise ValueError("The length of packed example exceeds the cutoff length.")
+
+        model_inputs["input_ids"].append(packed_input_ids)
+        model_inputs["attention_mask"].append([1] * len(packed_input_ids))
+        model_inputs["labels"].append(packed_labels)

     return model_inputs
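To see how the rewritten packing path fits together end to end, here is a self-contained sketch of the bookkeeping: `length2indexes` maps each length to the indexes of same-length examples, so every knapsack of lengths can be resolved back to concrete rows. `PAD_TOKEN_ID` and `CUTOFF_LEN` are stand-ins for `tokenizer.pad_token_id` and `data_args.cutoff_len`, and the toy ids are invented for illustration:

from collections import defaultdict

PAD_TOKEN_ID = 0   # stand-in for tokenizer.pad_token_id
CUTOFF_LEN = 8     # stand-in for data_args.cutoff_len

# toy encoded examples (input_ids only; labels follow the same path)
batch_input_ids = [[1, 2, 3], [4, 5], [6, 7, 8, 9, 10]]

lengths = []
length2indexes = defaultdict(list)
for index, ids in enumerate(batch_input_ids):
    lengths.append(len(ids))
    length2indexes[len(ids)].append(index)

# greedy_knapsack(lengths, CUTOFF_LEN) would yield [[5, 3], [2]] here
for knapsack in [[5, 3], [2]]:
    packed = []
    for length in knapsack:
        packed += batch_input_ids[length2indexes[length].pop()]
    packed += [PAD_TOKEN_ID] * (CUTOFF_LEN - len(packed))  # pad to fixed length
    print(packed)
# [6, 7, 8, 9, 10, 1, 2, 3]
# [4, 5, 0, 0, 0, 0, 0, 0]

Note that the new code keeps `attention_mask` all ones across the padding and relies on `IGNORE_INDEX` labels to keep pad positions out of the loss; this is simpler than the old `[1] * total_length + [0] * pad_length` mask, though it means pad tokens are still attended to.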