diff --git a/src/llamafactory/data/aligner.py b/src/llamafactory/data/aligner.py index f634f21e..b71964b0 100644 --- a/src/llamafactory/data/aligner.py +++ b/src/llamafactory/data/aligner.py @@ -145,6 +145,7 @@ def convert_sharegpt( if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]: logger.warning_rank0(f"Invalid role tag in {messages}.") broken_data = True + break aligned_messages.append( {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]} @@ -156,7 +157,10 @@ def convert_sharegpt( logger.warning_rank0(f"Invalid message count in {messages}.") broken_data = True - if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example + if broken_data: + logger.warning_rank0("Skipping this abnormal example.") + prompt, response = [], [] + elif dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example prompt = aligned_messages[:-1] response = aligned_messages[-1:] if example[dataset_attr.kto_tag]: @@ -186,10 +190,6 @@ def convert_sharegpt( prompt = aligned_messages[:-1] response = aligned_messages[-1:] - if broken_data: - logger.warning_rank0("Skipping this abnormal example.") - prompt, response = [], [] - regularize_medias = partial(_regularize_medias, dataset_attr=dataset_attr, data_args=data_args) output = { "_prompt": prompt,