diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index cf60d944..67652125 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -157,7 +157,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): features["position_ids"] = [torch.arange(input_ids.size(0)).long() for input_ids in features["input_ids"]] features["position_ids"] = pad_sequence(features["position_ids"], batch_first=True, padding_value=0) new_features = {"data": features} - new_features.update(features) + new_features.update({"labels": features['labels']}) features = new_features return features diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 3de51904..85e0f62f 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -383,6 +383,7 @@ class CpmOPlugin(BasePlugin): self._validate_input(images, videos) image_bounds_list = [] valid_image_nums_ls = [] + flag = False for input_ids in batch_ids: input_ids_ = torch.tensor(input_ids) @@ -394,6 +395,8 @@ class CpmOPlugin(BasePlugin): image_start_tokens += 1 image_end_tokens = torch.where(end_cond)[0] valid_image_nums = max(len(image_start_tokens), len(image_end_tokens)) + if valid_image_nums > 0: + flag = True valid_image_nums_ls.append(valid_image_nums) image_bounds = torch.hstack( [ @@ -402,6 +405,10 @@ class CpmOPlugin(BasePlugin): ] ) image_bounds_list.append(image_bounds) + + if not flag and len(images)>0: + valid_image_nums_ls = [1 for _ in range(len(batch_ids))] + image_bounds_list = [torch.arange(64) for _ in range(len(batch_ids))] mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls) mm_inputs.update({"image_bound": image_bounds_list}) diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 0afa5821..58dcf561 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -571,12 +571,6 @@ _register_template( format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), - format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"), - format_observation=StringFormatter( - slots=["<|im_start|>user\n\n{{content}}\n<|im_end|>\n<|im_start|>assistant\n"] - ), - format_tools=ToolFormatter(tool_format="qwen"), - default_system="You are a helpful assistant.", stop_words=["<|im_end|>"], mm_plugin=get_mm_plugin(name="cpm_o", image_token="", video_token="