mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-22 13:42:51 +08:00
fix eos
Former-commit-id: 84aee579013f0c095a918a8c61611ccbb1d7fc84
This commit is contained in:
parent
8350e508d3
commit
0feb2ad35c
@ -151,17 +151,11 @@ def preprocess_packed_supervised_dataset(
|
||||
):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif len(input_ids) != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_mask + target_ids
|
||||
|
||||
if template.efficient_eos:
|
||||
input_ids += [tokenizer.eos_token_id]
|
||||
labels += [tokenizer.eos_token_id]
|
||||
input_ids.append(source_ids + target_ids)
|
||||
labels.append(source_mask + target_ids)
|
||||
|
||||
# prepare for packing
|
||||
lengths = []
|
||||
@ -174,7 +168,8 @@ def preprocess_packed_supervised_dataset(
|
||||
lengths.append(length)
|
||||
length2examples_idx[length].append(idx)
|
||||
|
||||
knapsacks = efficient_greedy_knapsack(lengths, data_args.cutoff_len)
|
||||
# cutoff_len - 1 for efficient_eos
|
||||
knapsacks = efficient_greedy_knapsack(lengths, data_args.cutoff_len - int(template.efficient_eos))
|
||||
|
||||
for knapsack in knapsacks:
|
||||
packed_input_ids = []
|
||||
@ -190,8 +185,15 @@ def preprocess_packed_supervised_dataset(
|
||||
# padding to cutoff_len
|
||||
if total_length < data_args.cutoff_len:
|
||||
pad_length = data_args.cutoff_len - total_length
|
||||
packed_input_ids.append([tokenizer.eos_token_id] * pad_length)
|
||||
packed_labels.append([IGNORE_INDEX] * pad_length)
|
||||
if template.efficient_eos:
|
||||
# 确保有 eos
|
||||
packed_input_ids.append([tokenizer.eos_token_id] * pad_length)
|
||||
packed_labels.append([tokenizer.eos_token_id] + [IGNORE_INDEX] * (pad_length - 1))
|
||||
else:
|
||||
# 无 eos 的情况下,使用 0 填充?
|
||||
packed_input_ids.append([0] * pad_length)
|
||||
packed_labels.append([tokenizer.eos_token_id] + [IGNORE_INDEX] * (pad_length - 1))
|
||||
|
||||
elif total_length == data_args.cutoff_len:
|
||||
pad_length = 0
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user