From 4741eec2d106bbba30050b0c410ffbbc193c8fce Mon Sep 17 00:00:00 2001 From: fzc8578 <1428195643@qq.com> Date: Mon, 13 Jan 2025 14:19:38 +0800 Subject: [PATCH] fix style Former-commit-id: 0cc7260a93bf7c65451e376245aa143f9237d7d8 --- src/llamafactory/data/collator.py | 2 +- src/llamafactory/data/mm_plugin.py | 6 +++--- tests/data/test_mm_plugin.py | 32 ++++++++++++++++++------------ 3 files changed, 23 insertions(+), 17 deletions(-) diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 67652125..dfd853ca 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -157,7 +157,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): features["position_ids"] = [torch.arange(input_ids.size(0)).long() for input_ids in features["input_ids"]] features["position_ids"] = pad_sequence(features["position_ids"], batch_first=True, padding_value=0) new_features = {"data": features} - new_features.update({"labels": features['labels']}) + new_features.update({"labels": features["labels"]}) features = new_features return features diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 85e0f62f..acd1981e 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -384,7 +384,7 @@ class CpmOPlugin(BasePlugin): image_bounds_list = [] valid_image_nums_ls = [] flag = False - + for input_ids in batch_ids: input_ids_ = torch.tensor(input_ids) start_cond = (input_ids_ == processor.tokenizer.im_start_id) | ( @@ -405,8 +405,8 @@ class CpmOPlugin(BasePlugin): ] ) image_bounds_list.append(image_bounds) - - if not flag and len(images)>0: + + if not flag and len(images) > 0: valid_image_nums_ls = [1 for _ in range(len(batch_ids))] image_bounds_list = [torch.arange(64) for _ in range(len(batch_ids))] diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py index 21a88e62..12b8dfab 100644 --- a/tests/data/test_mm_plugin.py +++ b/tests/data/test_mm_plugin.py @@ -76,11 +76,16 @@ def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None: if isinstance(batch_a[key], torch.Tensor): assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5) elif isinstance(batch_a[key], list) and all(isinstance(item, torch.Tensor) for item in batch_a[key]): - assert len(batch_a[key]) == len(batch_b[key]) - for tensor_a, tensor_b in zip(batch_a[key], batch_b[key]): - assert torch.allclose(tensor_a, tensor_b, rtol=1e-4, atol=1e-5) - elif isinstance(batch_a[key], list) and all(isinstance(item, list) for item in batch_a[key]) \ - and len(batch_a[key])>0 and len(batch_a[key][0])>0 and isinstance(batch_a[key][0][0], torch.Tensor): + assert len(batch_a[key]) == len(batch_b[key]) + for tensor_a, tensor_b in zip(batch_a[key], batch_b[key]): + assert torch.allclose(tensor_a, tensor_b, rtol=1e-4, atol=1e-5) + elif ( + isinstance(batch_a[key], list) + and all(isinstance(item, list) for item in batch_a[key]) + and len(batch_a[key]) > 0 + and len(batch_a[key][0]) > 0 + and isinstance(batch_a[key][0][0], torch.Tensor) + ): for item_a, item_b in zip(batch_a[key], batch_b[key]): assert len(item_a) == len(item_a) for tensor_a, tensor_b in zip(item_a, item_b): @@ -140,18 +145,19 @@ def test_cpm_o_plugin(): check_inputs = {"plugin": cpm_o_plugin, **tokenizer_module} image_seqlen = 64 check_inputs["expected_mm_messages"] = [ - {key: value.replace("", f"0{'' * image_seqlen}") for key, value in message.items()} + { + key: value.replace("", f"0{'' * image_seqlen}") + for key, value in message.items() + } for message in MM_MESSAGES ] check_inputs["expected_mm_inputs"] = { - "pixel_values": [[]], - "image_sizes": [[]], - "tgt_sizes": [[]], - "image_bound": [torch.tensor([], dtype=torch.int64).reshape(0,2)] - } - check_inputs["expected_no_mm_inputs"] = { - "image_bound": [torch.tensor([], dtype=torch.int64).reshape(0,2)] + "pixel_values": [[]], + "image_sizes": [[]], + "tgt_sizes": [[]], + "image_bound": [torch.tensor([], dtype=torch.int64).reshape(0, 2)], } + check_inputs["expected_no_mm_inputs"] = {"image_bound": [torch.tensor([], dtype=torch.int64).reshape(0, 2)]} _check_plugin(**check_inputs)