mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-04-27 18:29:08 +08:00
[misc] code lint (#10439)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -152,7 +152,7 @@ def _make_packed_feature(
|
||||
video_subseq_ids = packing_params["video_subseq_ids"]
|
||||
audio_subseq_ids = packing_params["audio_subseq_ids"]
|
||||
unpadded_length = packing_params["unpadded_length"]
|
||||
right_padding_length = packing_params["right_padding_length"] # which only preserved in tests
|
||||
right_padding_length = packing_params["right_padding_length"] # which only preserved in tests
|
||||
cutoff_plus_one = sequence_boundaries[-1]
|
||||
content_len = unpadded_length
|
||||
pad_len = right_padding_length
|
||||
@@ -229,10 +229,11 @@ def _make_packed_features(
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def _get_expected_position_ids(packing_params, get_rope_func, input_ids, attention_mask) -> torch.Tensor:
|
||||
bound_list = packing_params["sequence_boundaries"]
|
||||
input_ids_slices = [input_ids[bound_list[i]:bound_list[i+1]] for i in range(len(bound_list) - 1)]
|
||||
attention_mask_slices = [attention_mask[bound_list[i]:bound_list[i+1]] for i in range(len(bound_list) - 1)]
|
||||
input_ids_slices = [input_ids[bound_list[i] : bound_list[i + 1]] for i in range(len(bound_list) - 1)]
|
||||
attention_mask_slices = [attention_mask[bound_list[i] : bound_list[i + 1]] for i in range(len(bound_list) - 1)]
|
||||
img_counts_by_subseq = Counter(packing_params["image_subseq_ids"])
|
||||
all_position_ids = []
|
||||
for i, input_ids_slice in enumerate(input_ids_slices):
|
||||
@@ -296,7 +297,7 @@ def test_multimodal_collator_with_packing():
|
||||
features[0]["input_ids"],
|
||||
features[0]["attention_mask"],
|
||||
)
|
||||
batch_input = data_collator(features) # [3, bsz, seq_len]
|
||||
batch_input = data_collator(features) # [3, bsz, seq_len]
|
||||
valid_len = expected_position_ids.shape[-1]
|
||||
assert batch_input["position_ids"][1:, :, :valid_len].eq(expected_position_ids).all()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user