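Fix bug in packed SFT dataset: build the attention mask as `[1] * block_size` instead of calling `len()` on the integer `block_size`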

Former-commit-id: de196143064772db770a45235424b3c911b2e147
This commit is contained in:
hiyouga 2023-09-28 01:16:46 +08:00
parent f61a000e73
commit f88088c43d

View File

@@ -116,7 +116,7 @@ def preprocess_dataset(
# split by chunks of cutoff_len # split by chunks of cutoff_len
for i in range(0, total_length, block_size): for i in range(0, total_length, block_size):
model_inputs["input_ids"].append(input_ids[i: i + block_size]) model_inputs["input_ids"].append(input_ids[i: i + block_size])
model_inputs["attention_mask"].append([1] * len(block_size)) model_inputs["attention_mask"].append([1] * block_size)
model_inputs["labels"].append(labels[i: i + block_size]) model_inputs["labels"].append(labels[i: i + block_size])
return model_inputs return model_inputs