Former-commit-id: 84aee579013f0c095a918a8c61611ccbb1d7fc84
This commit is contained in:
ylfeng 2024-05-31 21:40:41 +08:00
parent 8350e508d3
commit 0feb2ad35c

View File

@ -151,17 +151,11 @@ def preprocess_packed_supervised_dataset(
): ):
if data_args.train_on_prompt: if data_args.train_on_prompt:
source_mask = source_ids source_mask = source_ids
elif len(input_ids) != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else: else:
source_mask = [IGNORE_INDEX] * len(source_ids) source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids input_ids.append(source_ids + target_ids)
labels += source_mask + target_ids labels.append(source_mask + target_ids)
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
# prepare for packing # prepare for packing
lengths = [] lengths = []
@ -174,7 +168,8 @@ def preprocess_packed_supervised_dataset(
lengths.append(length) lengths.append(length)
length2examples_idx[length].append(idx) length2examples_idx[length].append(idx)
knapsacks = efficient_greedy_knapsack(lengths, data_args.cutoff_len) # cutoff_len - 1 for efficient_eos
knapsacks = efficient_greedy_knapsack(lengths, data_args.cutoff_len - int(template.efficient_eos))
for knapsack in knapsacks: for knapsack in knapsacks:
packed_input_ids = [] packed_input_ids = []
@ -190,8 +185,15 @@ def preprocess_packed_supervised_dataset(
# padding to cutoff_len # padding to cutoff_len
if total_length < data_args.cutoff_len: if total_length < data_args.cutoff_len:
pad_length = data_args.cutoff_len - total_length pad_length = data_args.cutoff_len - total_length
packed_input_ids.append([tokenizer.eos_token_id] * pad_length) if template.efficient_eos:
packed_labels.append([IGNORE_INDEX] * pad_length) # 确保有 eos
packed_input_ids.append([tokenizer.eos_token_id] * pad_length)
packed_labels.append([tokenizer.eos_token_id] + [IGNORE_INDEX] * (pad_length - 1))
else:
# 无 eos 的情况下,使用 0 填充?
packed_input_ids.append([0] * pad_length)
packed_labels.append([tokenizer.eos_token_id] + [IGNORE_INDEX] * (pad_length - 1))
elif total_length == data_args.cutoff_len: elif total_length == data_args.cutoff_len:
pad_length = 0 pad_length = 0
else: else: