diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 55db4626..3de51904 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -383,6 +383,7 @@ class CpmOPlugin(BasePlugin): self._validate_input(images, videos) image_bounds_list = [] valid_image_nums_ls = [] + for input_ids in batch_ids: input_ids_ = torch.tensor(input_ids) start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (