From 8350e508d3612374194cb825594c7a92e454068c Mon Sep 17 00:00:00 2001 From: ylfeng Date: Fri, 31 May 2024 15:33:54 +0800 Subject: [PATCH 1/6] supervised packing with greedy knapsack algorithm Former-commit-id: f9db439cb7511b12aa3524d5fdcc45864aebda91 --- .../data/processors/supervised.py | 102 ++++++++++++++++-- 1 file changed, 92 insertions(+), 10 deletions(-) diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py index b119aa22..65aa4b4e 100644 --- a/src/llamafactory/data/processors/supervised.py +++ b/src/llamafactory/data/processors/supervised.py @@ -1,3 +1,5 @@ +import itertools +from collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, List, Optional from ...extras.constants import IGNORE_INDEX @@ -16,6 +18,52 @@ if TYPE_CHECKING: logger = get_logger(__name__) +def binary_search_for_fit(numbers, capacity): + """ + Perform binary search to find the largest number that fits into the knapsack with the given capacity. + """ + left, right = 0, len(numbers) - 1 + result = -1 # If no number fits, return -1 + + while left <= right: + mid = (left + right) // 2 + if numbers[mid] <= capacity: + result = mid + left = mid + 1 + else: + right = mid - 1 + + return result + + +def efficient_greedy_knapsack(numbers, capacity): + """ + An efficient greedy algorithm with binary search for the knapsack problem. + """ + numbers.sort() # Sort numbers in ascending order for binary search + knapsacks = [] + + while numbers: + current_knapsack = [] + remaining_capacity = capacity + + while True: + index = binary_search_for_fit(numbers, remaining_capacity) + if index == -1: + break # No more numbers fit in this knapsack + + # Add the found number to the knapsack and update the remaining capacity + current_knapsack.append(numbers[index]) + remaining_capacity -= numbers[index] + + # Remove the number from the list + numbers.pop(index) + + knapsacks.append(current_knapsack) + + return knapsacks + + def preprocess_supervised_dataset( examples: Dict[str, List[Any]], template: "Template", @@ -115,16 +163,50 @@ def preprocess_packed_supervised_dataset( input_ids += [tokenizer.eos_token_id] labels += [tokenizer.eos_token_id] - total_length = len(input_ids) - block_size = data_args.cutoff_len - # we drop the small remainder, and if the total_length < block_size, we exclude this batch - total_length = (total_length // block_size) * block_size - # split by chunks of cutoff_len - for i in range(0, total_length, block_size): - if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]): - model_inputs["input_ids"].append(input_ids[i : i + block_size]) - model_inputs["attention_mask"].append([1] * block_size) - model_inputs["labels"].append(labels[i : i + block_size]) + # prepare for packing + lengths = [] + length2examples_idx = defaultdict(list) + for idx, example in enumerate(input_ids): + length = len(example) + if length > data_args.cutoff_len: + logger.warning("Dropped example with length {} > cutoff_len {}".format(length, data_args.cutoff_len)) + continue + lengths.append(length) + length2examples_idx[length].append(idx) + + knapsacks = efficient_greedy_knapsack(lengths, data_args.cutoff_len) + + for knapsack in knapsacks: + packed_input_ids = [] + packed_labels = [] + + total_length = 0 + for length in knapsack: + total_length += length + idx = length2examples_idx[length].pop() + packed_input_ids.append(input_ids[idx]) + packed_labels.append(labels[idx]) + + # padding to cutoff_len + if total_length < data_args.cutoff_len: + 
pad_length = data_args.cutoff_len - total_length
+                packed_input_ids.append([tokenizer.eos_token_id] * pad_length)
+                packed_labels.append([IGNORE_INDEX] * pad_length)
+            elif total_length == data_args.cutoff_len:
+                pad_length = 0
+            else:
+                logger.warning(
+                    "Dropped packed example with total length {} > cutoff_len {}".format(
+                        total_length, data_args.cutoff_len
+                    )
+                )
+                continue
+
+            # concat all
+            model_inputs["input_ids"].append(list(itertools.chain(*packed_input_ids)))
+
+            model_inputs["labels"].append(list(itertools.chain(*packed_labels)))
+            model_inputs["attention_mask"].append([1] * total_length + [0] * pad_length)
 
     return model_inputs


From 0feb2ad35c430dd018dbb187c1eb9371e225167f Mon Sep 17 00:00:00 2001
From: ylfeng
Date: Fri, 31 May 2024 21:40:41 +0800
Subject: [PATCH 2/6] fix eos

Former-commit-id: 84aee579013f0c095a918a8c61611ccbb1d7fc84
---
 .../data/processors/supervised.py             | 24 ++++++++++---------
 1 file changed, 13 insertions(+), 11 deletions(-)

diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py
index 65aa4b4e..f94cebba 100644
--- a/src/llamafactory/data/processors/supervised.py
+++ b/src/llamafactory/data/processors/supervised.py
@@ -151,17 +151,11 @@ def preprocess_packed_supervised_dataset(
     ):
         if data_args.train_on_prompt:
             source_mask = source_ids
-        elif len(input_ids) != 0 and template.efficient_eos:
-            source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
         else:
             source_mask = [IGNORE_INDEX] * len(source_ids)
 
-        input_ids += source_ids + target_ids
-        labels += source_mask + target_ids
-
-        if template.efficient_eos:
-            input_ids += [tokenizer.eos_token_id]
-            labels += [tokenizer.eos_token_id]
+        input_ids.append(source_ids + target_ids)
+        labels.append(source_mask + target_ids)
 
     # prepare for packing
     lengths = []
@@ -174,7 +168,8 @@ def preprocess_packed_supervised_dataset(
         lengths.append(length)
         length2examples_idx[length].append(idx)
 
-    knapsacks = efficient_greedy_knapsack(lengths, data_args.cutoff_len)
+    # cutoff_len - 1 for efficient_eos
+    knapsacks = efficient_greedy_knapsack(lengths, data_args.cutoff_len - int(template.efficient_eos))
 
     for knapsack in knapsacks:
         packed_input_ids = []
@@ -190,8 +185,15 @@ def preprocess_packed_supervised_dataset(
         # padding to cutoff_len
         if total_length < data_args.cutoff_len:
            pad_length = data_args.cutoff_len - total_length
-            packed_input_ids.append([tokenizer.eos_token_id] * pad_length)
-            packed_labels.append([IGNORE_INDEX] * pad_length)
+            if template.efficient_eos:
+                # make sure there is an eos
+                packed_input_ids.append([tokenizer.eos_token_id] * pad_length)
+                packed_labels.append([tokenizer.eos_token_id] + [IGNORE_INDEX] * (pad_length - 1))
+            else:
+                # pad with 0 when there is no eos?
+ packed_input_ids.append([0] * pad_length) + packed_labels.append([tokenizer.eos_token_id] + [IGNORE_INDEX] * (pad_length - 1)) + elif total_length == data_args.cutoff_len: pad_length = 0 else: From 62d55b71a36dc5b79b506d06e2f60c88c5d76306 Mon Sep 17 00:00:00 2001 From: ylfeng Date: Fri, 31 May 2024 21:43:08 +0800 Subject: [PATCH 3/6] remove empty line Former-commit-id: b47e3174472f458a3a8b84a66b475da8fce6db79 --- src/llamafactory/data/processors/supervised.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py index f94cebba..eaceb5b8 100644 --- a/src/llamafactory/data/processors/supervised.py +++ b/src/llamafactory/data/processors/supervised.py @@ -206,7 +206,6 @@ def preprocess_packed_supervised_dataset( # concat all model_inputs["input_ids"].append(list(itertools.chain(*packed_input_ids))) - model_inputs["labels"].append(list(itertools.chain(*packed_labels))) model_inputs["attention_mask"].append([1] * total_length + [0] * pad_length) From 21df5f0bd04a4bc3330a3d6e1b3715628ffb7555 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Fri, 7 Jun 2024 03:38:04 +0800 Subject: [PATCH 4/6] Update supervised.py Former-commit-id: 8cecade7082a52f413517ea20b1c5dd812db8e53 --- .../data/processors/supervised.py | 233 ++++++++---------- 1 file changed, 107 insertions(+), 126 deletions(-) diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py index eaceb5b8..cd49fd0c 100644 --- a/src/llamafactory/data/processors/supervised.py +++ b/src/llamafactory/data/processors/supervised.py @@ -1,10 +1,10 @@ -import itertools +import bisect from collections import defaultdict -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from ...extras.constants import IGNORE_INDEX from ...extras.logging import get_logger -from .mm_utils import get_paligemma_token_type_ids, get_pixel_values +from .mm_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack if TYPE_CHECKING: @@ -18,29 +18,19 @@ if TYPE_CHECKING: logger = get_logger(__name__) -def binary_search_for_fit(numbers, capacity): +def search_for_fit(numbers: Sequence[int], capacity: int) -> int: + r""" + Finds the index of largest number that fits into the knapsack with the given capacity. """ - Perform binary search to find the largest number that fits into the knapsack with the given capacity. - """ - left, right = 0, len(numbers) - 1 - result = -1 # If no number fits, return -1 - - while left <= right: - mid = (left + right) // 2 - if numbers[mid] <= capacity: - result = mid - left = mid + 1 - else: - right = mid - 1 - - return result + index = bisect.bisect(numbers, capacity) + return -1 if index == 0 else (index - 1) -def efficient_greedy_knapsack(numbers, capacity): - """ +def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]: + r""" An efficient greedy algorithm with binary search for the knapsack problem. 
""" - numbers.sort() # Sort numbers in ascending order for binary search + numbers.sort() # sort numbers in ascending order for binary search knapsacks = [] while numbers: @@ -48,22 +38,60 @@ def efficient_greedy_knapsack(numbers, capacity): remaining_capacity = capacity while True: - index = binary_search_for_fit(numbers, remaining_capacity) + index = search_for_fit(numbers, remaining_capacity) if index == -1: - break # No more numbers fit in this knapsack + break # no more numbers fit in this knapsack - # Add the found number to the knapsack and update the remaining capacity - current_knapsack.append(numbers[index]) - remaining_capacity -= numbers[index] - - # Remove the number from the list - numbers.pop(index) + remaining_capacity -= numbers[index] # update the remaining capacity + current_knapsack.append(numbers.pop(index)) # add the number to knapsack knapsacks.append(current_knapsack) return knapsacks +def _encode_supervised_example( + prompt: Sequence[Dict[str, str]], + response: Sequence[Dict[str, str]], + system: Optional[str], + tools: Optional[str], + template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + data_args: "DataArguments", +) -> Tuple[List[int], List[int]]: + if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models + prompt[0]["content"] = template.image_token + prompt[0]["content"] + + messages = prompt + response + input_ids, labels = [], [] + + if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models + image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) + input_ids += [image_token_id] * getattr(processor, "image_seq_length") + labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length") + + encoded_pairs = template.encode_multiturn( + tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len + ) + for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs): + if data_args.train_on_prompt: + source_mask = source_ids + elif turn_idx != 0 and template.efficient_eos: + source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) + else: + source_mask = [IGNORE_INDEX] * len(source_ids) + + input_ids += source_ids + target_ids + labels += source_mask + target_ids + + if template.efficient_eos: + input_ids += [tokenizer.eos_token_id] + labels += [tokenizer.eos_token_id] + + return input_ids, labels + + def preprocess_supervised_dataset( examples: Dict[str, List[Any]], template: "Template", @@ -84,41 +112,16 @@ def preprocess_supervised_dataset( logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) continue - if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models - examples["prompt"][i][0]["content"] = template.image_token + examples["prompt"][i][0]["content"] - - messages = examples["prompt"][i] + examples["response"][i] - input_ids, labels = [], [] - - if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models - image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) - input_ids += [image_token_id] * getattr(processor, "image_seq_length") - labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length") - - for turn_idx, (source_ids, target_ids) in enumerate( - template.encode_multiturn( - tokenizer, - messages, - examples["system"][i], - examples["tools"][i], - data_args.cutoff_len, - data_args.reserved_label_len, - ) - ): - if 
data_args.train_on_prompt:
-                source_mask = source_ids
-            elif turn_idx != 0 and template.efficient_eos:
-                source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
-            else:
-                source_mask = [IGNORE_INDEX] * len(source_ids)
-
-            input_ids += source_ids + target_ids
-            labels += source_mask + target_ids
-
-        if template.efficient_eos:
-            input_ids += [tokenizer.eos_token_id]
-            labels += [tokenizer.eos_token_id]
-
+        input_ids, labels = _encode_supervised_example(
+            prompt=examples["prompt"][i],
+            response=examples["response"][i],
+            system=examples["system"][i],
+            tools=examples["tools"][i],
+            template=template,
+            tokenizer=tokenizer,
+            processor=processor,
+            data_args=data_args,
+        )
         model_inputs["input_ids"].append(input_ids)
         model_inputs["attention_mask"].append([1] * len(input_ids))
         model_inputs["labels"].append(labels)
@@ -138,76 +141,54 @@ def preprocess_packed_supervised_dataset(
 ) -> Dict[str, List[List[int]]]:
     # build inputs with format `<bos> X1 Y1 <eos> X2 Y2 <eos>`
     # and labels with format `<ignore> ... Y1 <eos> <ignore> ... Y2 <eos>`
-    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
-    input_ids, labels = [], []
+    valid_num = 0
+    batch_input_ids, batch_labels = [], []
+    lengths = []
+    length2indexes = defaultdict(list)
     for i in range(len(examples["prompt"])):
         if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
             logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
             continue
 
-        messages = examples["prompt"][i] + examples["response"][i]
-        for source_ids, target_ids in template.encode_multiturn(
-            tokenizer, messages, examples["system"][i], examples["tools"][i]
-        ):
-            if data_args.train_on_prompt:
-                source_mask = source_ids
-            else:
-                source_mask = [IGNORE_INDEX] * len(source_ids)
-
-            input_ids.append(source_ids + target_ids)
-            labels.append(source_mask + target_ids)
-
-    # prepare for packing
-    lengths = []
-    length2examples_idx = defaultdict(list)
-    for idx, example in enumerate(input_ids):
-        length = len(example)
+        input_ids, labels = _encode_supervised_example(
+            prompt=examples["prompt"][i],
+            response=examples["response"][i],
+            system=examples["system"][i],
+            tools=examples["tools"][i],
+            template=template,
+            tokenizer=tokenizer,
+            processor=None,
+            data_args=data_args,
+        )
+        length = len(input_ids)
         if length > data_args.cutoff_len:
-            logger.warning("Dropped example with length {} > cutoff_len {}".format(length, data_args.cutoff_len))
-            continue
-        lengths.append(length)
-        length2examples_idx[length].append(idx)
-
-    # cutoff_len - 1 for efficient_eos
-    knapsacks = efficient_greedy_knapsack(lengths, data_args.cutoff_len - int(template.efficient_eos))
-
-    for knapsack in knapsacks:
-        packed_input_ids = []
-        packed_labels = []
-
-        total_length = 0
-        for length in knapsack:
-            total_length += length
-            idx = length2examples_idx[length].pop()
-            packed_input_ids.append(input_ids[idx])
-            packed_labels.append(labels[idx])
-
-        # padding to cutoff_len
-        if total_length < data_args.cutoff_len:
-            pad_length = data_args.cutoff_len - total_length
-            if template.efficient_eos:
-                # make sure there is an eos
-                packed_input_ids.append([tokenizer.eos_token_id] * pad_length)
-                packed_labels.append([tokenizer.eos_token_id] + [IGNORE_INDEX] * (pad_length - 1))
-            else:
-                # pad with 0 when there is no eos?
- packed_input_ids.append([0] * pad_length) - packed_labels.append([tokenizer.eos_token_id] + [IGNORE_INDEX] * (pad_length - 1)) - - elif total_length == data_args.cutoff_len: - pad_length = 0 + logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len)) else: - logger.warning( - "Dropped packed example with total length {} > cutoff_len {}".format( - total_length, data_args.cutoff_len - ) - ) - continue + lengths.append(length) + length2indexes[length].append(valid_num) + batch_input_ids.append(input_ids) + batch_labels.append(labels) + valid_num += 1 - # concat all - model_inputs["input_ids"].append(list(itertools.chain(*packed_input_ids))) - model_inputs["labels"].append(list(itertools.chain(*packed_labels))) - model_inputs["attention_mask"].append([1] * total_length + [0] * pad_length) + model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} + knapsacks = greedy_knapsack(lengths, data_args.cutoff_len) + for knapsack in knapsacks: + packed_input_ids, packed_labels = [], [] + for length in knapsack: + index = length2indexes[length].pop() + packed_input_ids += batch_input_ids[index] + packed_labels += batch_labels[index] + + if len(packed_input_ids) <= data_args.cutoff_len: + pad_length = data_args.cutoff_len - len(packed_input_ids) + packed_input_ids += [tokenizer.pad_token_id] * pad_length + packed_labels += [IGNORE_INDEX] * pad_length + else: + raise ValueError("The length of packed example exceeds the cutoff length.") + + model_inputs["input_ids"].append(packed_input_ids) + model_inputs["attention_mask"].append([1] * len(packed_input_ids)) + model_inputs["labels"].append(packed_labels) return model_inputs From fd7bd911a6573b9cd31b4a7b9dd1fdf699a39592 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Fri, 7 Jun 2024 03:38:23 +0800 Subject: [PATCH 5/6] Update supervised.py Former-commit-id: 788e8232fc4ed58ab2439a9bc2e38f64e12c6eb3 --- src/llamafactory/data/processors/supervised.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py index cd49fd0c..502b591c 100644 --- a/src/llamafactory/data/processors/supervised.py +++ b/src/llamafactory/data/processors/supervised.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from ...extras.constants import IGNORE_INDEX from ...extras.logging import get_logger -from .mm_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack +from .mm_utils import get_paligemma_token_type_ids, get_pixel_values if TYPE_CHECKING: From e3ef239bc06d513db606a3bfb20e8b83e4b44388 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Fri, 7 Jun 2024 03:42:08 +0800 Subject: [PATCH 6/6] Update supervised.py Former-commit-id: c09ad8bab38bc2f151da3a924eba225111af2481 --- src/llamafactory/data/processors/supervised.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py index 502b591c..a340a1ab 100644 --- a/src/llamafactory/data/processors/supervised.py +++ b/src/llamafactory/data/processors/supervised.py @@ -179,15 +179,16 @@ def preprocess_packed_supervised_dataset( packed_input_ids += batch_input_ids[index] packed_labels += batch_labels[index] - if len(packed_input_ids) <= data_args.cutoff_len: + if len(packed_input_ids) < data_args.cutoff_len: pad_length = data_args.cutoff_len - len(packed_input_ids) packed_input_ids += 
[tokenizer.pad_token_id] * pad_length packed_labels += [IGNORE_INDEX] * pad_length - else: - raise ValueError("The length of packed example exceeds the cutoff length.") + + if len(packed_input_ids) != data_args.cutoff_len: + raise ValueError("The length of packed example should be identical to the cutoff length.") model_inputs["input_ids"].append(packed_input_ids) - model_inputs["attention_mask"].append([1] * len(packed_input_ids)) + model_inputs["attention_mask"].append([1] * data_args.cutoff_len) model_inputs["labels"].append(packed_labels) return model_inputs
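
For reference, the knapsack helpers that the series converges on are self-contained and can be exercised without a tokenizer or template. The following standalone sketch reproduces `search_for_fit` and `greedy_knapsack` as they stand after PATCH 6 and packs a made-up list of sequence lengths under an illustrative cutoff_len of 8 (the toy lengths and the demo driver are not from the patches):

import bisect
from typing import List, Sequence


def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
    r"""
    Finds the index of largest number that fits into the knapsack with the given capacity.
    """
    index = bisect.bisect(numbers, capacity)
    return -1 if index == 0 else (index - 1)


def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
    r"""
    An efficient greedy algorithm with binary search for the knapsack problem.
    """
    numbers.sort()  # sort numbers in ascending order for binary search
    knapsacks = []

    while numbers:
        current_knapsack = []
        remaining_capacity = capacity

        while True:
            index = search_for_fit(numbers, remaining_capacity)
            if index == -1:
                break  # no more numbers fit in this knapsack

            remaining_capacity -= numbers[index]  # update the remaining capacity
            current_knapsack.append(numbers.pop(index))  # add the number to knapsack

        knapsacks.append(current_knapsack)

    return knapsacks


if __name__ == "__main__":
    # made-up sequence lengths with an illustrative cutoff_len of 8
    print(greedy_knapsack([7, 3, 2, 5, 4, 1, 2], 8))
    # -> [[7, 1], [5, 3], [4, 2, 2]]; every knapsack sums to at most 8

Keeping `numbers` sorted is what lets `bisect.bisect` locate the largest remaining sequence that still fits in O(log n) per placement; this is the same binary-search idea as PATCH 1's hand-rolled `binary_search_for_fit`, reduced to a stdlib call in PATCH 4.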