From 0feb2ad35c430dd018dbb187c1eb9371e225167f Mon Sep 17 00:00:00 2001 From: ylfeng Date: Fri, 31 May 2024 21:40:41 +0800 Subject: [PATCH] fix eos Former-commit-id: 84aee579013f0c095a918a8c61611ccbb1d7fc84 --- .../data/processors/supervised.py | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py index 65aa4b4e..f94cebba 100644 --- a/src/llamafactory/data/processors/supervised.py +++ b/src/llamafactory/data/processors/supervised.py @@ -151,17 +151,11 @@ def preprocess_packed_supervised_dataset( ): if data_args.train_on_prompt: source_mask = source_ids - elif len(input_ids) != 0 and template.efficient_eos: - source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) else: source_mask = [IGNORE_INDEX] * len(source_ids) - input_ids += source_ids + target_ids - labels += source_mask + target_ids - - if template.efficient_eos: - input_ids += [tokenizer.eos_token_id] - labels += [tokenizer.eos_token_id] + input_ids.append(source_ids + target_ids) + labels.append(source_mask + target_ids) # prepare for packing lengths = [] @@ -174,7 +168,8 @@ def preprocess_packed_supervised_dataset( lengths.append(length) length2examples_idx[length].append(idx) - knapsacks = efficient_greedy_knapsack(lengths, data_args.cutoff_len) + # cutoff_len - 1 for efficient_eos + knapsacks = efficient_greedy_knapsack(lengths, data_args.cutoff_len - int(template.efficient_eos)) for knapsack in knapsacks: packed_input_ids = [] @@ -190,8 +185,15 @@ def preprocess_packed_supervised_dataset( # padding to cutoff_len if total_length < data_args.cutoff_len: pad_length = data_args.cutoff_len - total_length - packed_input_ids.append([tokenizer.eos_token_id] * pad_length) - packed_labels.append([IGNORE_INDEX] * pad_length) + if template.efficient_eos: + # ensure an eos is present + packed_input_ids.append([tokenizer.eos_token_id] * pad_length) + 
packed_labels.append([tokenizer.eos_token_id] + [IGNORE_INDEX] * (pad_length - 1)) + else: + # without an eos, pad with 0? + packed_input_ids.append([0] * pad_length) + packed_labels.append([tokenizer.eos_token_id] + [IGNORE_INDEX] * (pad_length - 1)) + elif total_length == data_args.cutoff_len: pad_length = 0 else: