implement efficient packing without cross-contamination attention

This commit is contained in:
ancv
2024-06-12 11:56:01 +07:00
parent 972ec9c668
commit b2c367bc61
9 changed files with 287 additions and 8 deletions

View File

@@ -36,7 +36,7 @@ def get_preprocess_and_print_func(
)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
elif stage == "sft" and not training_args.predict_with_generate:
if data_args.packing:
if data_args.packing or data_args.efficient_packing:
preprocess_func = partial(
preprocess_packed_supervised_dataset,
template=template,