mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2026-05-05 07:38:55 +08:00)
[misc] bump transformers version upperbound (#10446)
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import os
 from collections import Counter
 
@@ -230,22 +231,39 @@ def _make_packed_features(
     ]
 
 
-def _get_expected_position_ids(packing_params, get_rope_func, input_ids, attention_mask) -> torch.Tensor:
+def _get_expected_position_ids(
+    packing_params,
+    get_rope_func,
+    input_ids,
+    attention_mask,
+    image_token_id: int | None = None,
+    video_token_id: int | None = None,
+) -> torch.Tensor:
     bound_list = packing_params["sequence_boundaries"]
     input_ids_slices = [input_ids[bound_list[i] : bound_list[i + 1]] for i in range(len(bound_list) - 1)]
     attention_mask_slices = [attention_mask[bound_list[i] : bound_list[i + 1]] for i in range(len(bound_list) - 1)]
     img_counts_by_subseq = Counter(packing_params["image_subseq_ids"])
+    needs_mm_token_type_ids = "mm_token_type_ids" in inspect.signature(get_rope_func).parameters
     all_position_ids = []
     for i, input_ids_slice in enumerate(input_ids_slices):
         img_cnt = img_counts_by_subseq[i]
         if sum(attention_mask_slices[i]) == 0:
             continue
 
+        input_ids_tensor = torch.tensor(input_ids_slice).unsqueeze(0)
         rope_func_kwargs = {
-            "input_ids": torch.tensor(input_ids_slice).unsqueeze(0),
+            "input_ids": input_ids_tensor,
             "attention_mask": torch.tensor(attention_mask_slices[i]).unsqueeze(0),
             "image_grid_thw": [torch.tensor([1, 4, 4])] * img_cnt,
         }
+        if needs_mm_token_type_ids:
+            mm_token_type_ids = torch.zeros_like(input_ids_tensor)
+            if image_token_id is not None:
+                mm_token_type_ids[input_ids_tensor == image_token_id] = 1
+            if video_token_id is not None:
+                mm_token_type_ids[input_ids_tensor == video_token_id] = 2
+            rope_func_kwargs["mm_token_type_ids"] = mm_token_type_ids
+
         position_ids, _ = get_rope_func(**rope_func_kwargs)
         all_position_ids.append(position_ids)
 
@@ -296,6 +314,8 @@ def test_multimodal_collator_with_packing():
         data_collator.get_rope_func,
         features[0]["input_ids"],
         features[0]["attention_mask"],
+        image_token_id=getattr(model.config, "image_token_id", None),
+        video_token_id=getattr(model.config, "video_token_id", None),
     )
     batch_input = data_collator(features)  # [3, bsz, seq_len]
     valid_len = expected_position_ids.shape[-1]
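For context, a minimal standalone sketch of the idea the test helper verifies: a packed row is split at sequence_boundaries and position ids restart from zero for each packed sub-sequence. This sketch is not part of the commit; all names and values below are invented for illustration.

import torch

# Toy packed row: two samples of length 4 and 6 packed into one sequence.
packed_input_ids = list(range(10))
sequence_boundaries = [0, 4, 10]  # slice i spans [boundaries[i], boundaries[i + 1])

position_ids = torch.empty(len(packed_input_ids), dtype=torch.long)
for i in range(len(sequence_boundaries) - 1):
    start, end = sequence_boundaries[i], sequence_boundaries[i + 1]
    position_ids[start:end] = torch.arange(end - start)  # positions restart per sub-sequence

print(position_ids.tolist())  # [0, 1, 2, 3, 0, 1, 2, 3, 4, 5]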