mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 03:32:50 +08:00
[data] specify position_ids in PackedSupervisedDatasetProcessor for neat_packing (#7318)
* use position_ids for neat_packing with fa2 * revert fa2 changes
This commit is contained in:
parent
6faa6fb53d
commit
f06a74ad4e
@ -165,7 +165,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
|
||||
knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
|
||||
for knapsack in knapsacks:
|
||||
packed_input_ids, packed_attention_masks, packed_labels = [], [], []
|
||||
packed_images, packed_videos, packed_audios = [], [], []
|
||||
packed_images, packed_videos, packed_audios, packed_position_ids = [], [], [], []
|
||||
for i, length in enumerate(knapsack):
|
||||
index = length2indexes[length].pop()
|
||||
packed_input_ids += batch_input_ids[index]
|
||||
@ -175,6 +175,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
|
||||
packed_audios += batch_audios[index]
|
||||
if self.data_args.neat_packing:
|
||||
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
|
||||
packed_position_ids += list(range(len(batch_input_ids[index])))
|
||||
else:
|
||||
packed_attention_masks += [1] * len(batch_input_ids[index])
|
||||
|
||||
@ -184,6 +185,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
|
||||
packed_labels += [IGNORE_INDEX] * pad_length
|
||||
if self.data_args.neat_packing:
|
||||
packed_attention_masks += [0] * pad_length
|
||||
packed_position_ids += [0] * pad_length
|
||||
else:
|
||||
packed_attention_masks += [1] * pad_length # more efficient flash_attn
|
||||
|
||||
@ -196,5 +198,6 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
|
||||
model_inputs["images"].append(packed_images or None)
|
||||
model_inputs["videos"].append(packed_videos or None)
|
||||
model_inputs["audios"].append(packed_audios or None)
|
||||
model_inputs["position_ids"].append(packed_position_ids or None)
|
||||
|
||||
return model_inputs
|
||||
|
Loading…
x
Reference in New Issue
Block a user