diff --git a/src/llamafactory/data/aligner.py b/src/llamafactory/data/aligner.py index 2e2fb2c8..6a74a843 100644 --- a/src/llamafactory/data/aligner.py +++ b/src/llamafactory/data/aligner.py @@ -57,9 +57,9 @@ def convert_alpaca( prompt.append({"role": Role.USER.value, "content": "\n".join(content)}) # "prompt\nquery" - if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag], bool): # kto example + if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}] - if examples[dataset_attr.kto_tag]: + if examples[dataset_attr.kto_tag][i]: response = response + [{"role": Role.ASSISTANT.value, "content": ""}] else: response = [{"role": Role.ASSISTANT.value, "content": ""}] + response