From 8350e508d3612374194cb825594c7a92e454068c Mon Sep 17 00:00:00 2001
From: ylfeng
Date: Fri, 31 May 2024 15:33:54 +0800
Subject: [PATCH] supervised packing with greedy knapsack algorithm

Former-commit-id: f9db439cb7511b12aa3524d5fdcc45864aebda91
---
 .../data/processors/supervised.py             | 102 ++++++++++++++++--
 1 file changed, 92 insertions(+), 10 deletions(-)

diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py
index b119aa22..65aa4b4e 100644
--- a/src/llamafactory/data/processors/supervised.py
+++ b/src/llamafactory/data/processors/supervised.py
@@ -1,3 +1,5 @@
+import itertools
+from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional

 from ...extras.constants import IGNORE_INDEX
@@ -16,6 +18,52 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


+def binary_search_for_fit(numbers, capacity):
+    """
+    Perform binary search to find the largest number that fits into the knapsack with the given capacity.
+    """
+    left, right = 0, len(numbers) - 1
+    result = -1  # If no number fits, return -1
+
+    while left <= right:
+        mid = (left + right) // 2
+        if numbers[mid] <= capacity:
+            result = mid
+            left = mid + 1
+        else:
+            right = mid - 1
+
+    return result
+
+
+def efficient_greedy_knapsack(numbers, capacity):
+    """
+    An efficient greedy algorithm with binary search for the knapsack problem.
+    """
+    numbers.sort()  # Sort numbers in ascending order for binary search
+    knapsacks = []
+
+    while numbers:
+        current_knapsack = []
+        remaining_capacity = capacity
+
+        while True:
+            index = binary_search_for_fit(numbers, remaining_capacity)
+            if index == -1:
+                break  # No more numbers fit in this knapsack
+
+            # Add the found number to the knapsack and update the remaining capacity
+            current_knapsack.append(numbers[index])
+            remaining_capacity -= numbers[index]
+
+            # Remove the number from the list
+            numbers.pop(index)
+
+        knapsacks.append(current_knapsack)
+
+    return knapsacks
+
+
 def preprocess_supervised_dataset(
     examples: Dict[str, List[Any]],
     template: "Template",
@@ -115,16 +163,50 @@ def preprocess_packed_supervised_dataset(
         input_ids += [tokenizer.eos_token_id]
         labels += [tokenizer.eos_token_id]

-    total_length = len(input_ids)
-    block_size = data_args.cutoff_len
-    # we drop the small remainder, and if the total_length < block_size, we exclude this batch
-    total_length = (total_length // block_size) * block_size
-    # split by chunks of cutoff_len
-    for i in range(0, total_length, block_size):
-        if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]):
-            model_inputs["input_ids"].append(input_ids[i : i + block_size])
-            model_inputs["attention_mask"].append([1] * block_size)
-            model_inputs["labels"].append(labels[i : i + block_size])
+    # prepare for packing
+    lengths = []
+    length2examples_idx = defaultdict(list)
+    for idx, example in enumerate(input_ids):
+        length = len(example)
+        if length > data_args.cutoff_len:
+            logger.warning("Dropped example with length {} > cutoff_len {}".format(length, data_args.cutoff_len))
+            continue
+        lengths.append(length)
+        length2examples_idx[length].append(idx)
+
+    knapsacks = efficient_greedy_knapsack(lengths, data_args.cutoff_len)
+
+    for knapsack in knapsacks:
+        packed_input_ids = []
+        packed_labels = []
+
+        total_length = 0
+        for length in knapsack:
+            total_length += length
+            idx = length2examples_idx[length].pop()
+            packed_input_ids.append(input_ids[idx])
+            packed_labels.append(labels[idx])
+
+        # padding to cutoff_len
+        if total_length < data_args.cutoff_len:
+            pad_length = data_args.cutoff_len - total_length
+            packed_input_ids.append([tokenizer.eos_token_id] * pad_length)
+            packed_labels.append([IGNORE_INDEX] * pad_length)
+        elif total_length == data_args.cutoff_len:
+            pad_length = 0
+        else:
+            logger.warning(
+                "Dropped packed example with total length {} > cutoff_len {}".format(
+                    total_length, data_args.cutoff_len
+                )
+            )
+            continue
+
+        # concat all
+        model_inputs["input_ids"].append(list(itertools.chain(*packed_input_ids)))
+
+        model_inputs["labels"].append(list(itertools.chain(*packed_labels)))
+        model_inputs["attention_mask"].append([1] * total_length + [0] * pad_length)

     return model_inputs

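Note on the algorithm: despite its name, efficient_greedy_knapsack is a greedy
bin-packing heuristic rather than a true knapsack solver. It opens a bin,
repeatedly adds the largest remaining length that still fits (found by binary
search on the ascending-sorted list), and opens a new bin when nothing fits.
The sketch below is not part of the patch; greedy_pack and the toy lengths are
illustrative only. It mirrors the same loop with the stdlib's bisect, where
bisect.bisect_right(numbers, remaining) - 1 computes the same index as
binary_search_for_fit:

import bisect

def greedy_pack(lengths, capacity):
    # Assumes every length <= capacity; the patch guarantees this by
    # dropping over-length examples before calling the packer. Without
    # that filter, the outer loop would never terminate.
    numbers = sorted(lengths)
    bins = []
    while numbers:
        current, remaining = [], capacity
        while numbers:
            i = bisect.bisect_right(numbers, remaining) - 1
            if i < 0:
                break  # nothing left fits in this bin
            remaining -= numbers[i]
            current.append(numbers.pop(i))
        bins.append(current)
    return bins

print(greedy_pack([7, 2, 5, 3, 1, 6], capacity=8))
# -> [[7, 1], [6, 2], [5, 3]]; every bin sums to at most 8

Unlike the patched helper, this sketch sorts a copy; the patch's version sorts
and drains the list passed in, which is fine there because lengths is rebuilt
on every call.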
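To make the padding step at the end of the third hunk concrete, here is a
minimal standalone sketch with made-up token ids; CUTOFF_LEN and EOS_ID stand
in for data_args.cutoff_len and tokenizer.eos_token_id. (The hunk treats
input_ids as holding one token list per example; the sketch follows that view.)

import itertools

CUTOFF_LEN = 8
EOS_ID = 2           # stand-in for tokenizer.eos_token_id
IGNORE_INDEX = -100  # the sentinel the patch uses for masked labels

# one knapsack holding two examples of lengths 3 and 2 (total 5)
packed_input_ids = [[5, 6, 7], [8, 9]]
packed_labels = [[IGNORE_INDEX, 6, 7], [IGNORE_INDEX, 9]]
total_length = sum(len(seq) for seq in packed_input_ids)

# pad to CUTOFF_LEN as the patch does: eos as the pad token,
# IGNORE_INDEX as the pad label so padding contributes no loss
pad_length = CUTOFF_LEN - total_length
packed_input_ids.append([EOS_ID] * pad_length)
packed_labels.append([IGNORE_INDEX] * pad_length)

print(list(itertools.chain(*packed_input_ids)))
# -> [5, 6, 7, 8, 9, 2, 2, 2]
print(list(itertools.chain(*packed_labels)))
# -> [-100, 6, 7, -100, 9, -100, -100, -100]
print([1] * total_length + [0] * pad_length)
# -> [1, 1, 1, 1, 1, 0, 0, 0]

One design consequence worth noting: the attention mask only zeroes the
padding, so examples packed into the same sequence can still attend to each
other across the eos boundaries; the patch does not build a block-diagonal
mask to isolate them.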