diff --git a/src/llmtuner/data/aligner.py b/src/llmtuner/data/aligner.py index 5140f9d8..070a917c 100644 --- a/src/llmtuner/data/aligner.py +++ b/src/llmtuner/data/aligner.py @@ -53,32 +53,34 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr" if len(messages) == 0: continue - n_sys = 0 prompt = [] response = [] + n_sys = 0 for turn_idx, message in enumerate(messages): - accept_tags = [dataset_attr.user_tag, dataset_attr.observation_tag, dataset_attr.assistant_tag, dataset_attr.function_tag] - - if message[dataset_attr.role_tag] == "system": + if dataset_attr.system_tag and message[dataset_attr.role_tag] == dataset_attr.system_tag: outputs["system"].append(message[dataset_attr.content_tag]) - n_sys += 1 - elif message[dataset_attr.role_tag] not in accept_tags: - print("sytem attr", dataset_attr.system) - print("accepted tags", accept_tags) - raise ValueError("Invalid role tag in {}.".format(messages)) + n_sys = 1 + + if (turn_idx - n_sys) % 2 == 0: + accept_tags = [dataset_attr.user_tag, dataset_attr.observation_tag] else: - prompt.append( - {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]} - ) + accept_tags = [dataset_attr.assistant_tag, dataset_attr.function_tag] + + if message[dataset_attr.role_tag] not in accept_tags: + raise ValueError("Invalid role tag in {}.".format(messages)) + + prompt.append( + {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]} + ) last_message = prompt.pop(-1) response.append(last_message) outputs["prompt"].append(prompt) outputs["response"].append(response) - if n_sys == 0: + if not dataset_attr.system_tag: outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "") outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "") - assert n_sys <= 1 + return outputs diff --git a/src/llmtuner/data/parser.py b/src/llmtuner/data/parser.py index 474aad83..aa6319fd 100644 --- a/src/llmtuner/data/parser.py +++ b/src/llmtuner/data/parser.py @@ -37,6 +37,10 @@ class DatasetAttr: assistant_tag: Optional[str] = "gpt" observation_tag: Optional[str] = "observation" function_tag: Optional[str] = "function_call" + system_tag: Optional[str] = None + + assert system_tag is None or system is None, f"Can not provide both system message (system_tag={system_tag}) and system column(system={system})" + def __repr__(self) -> str: return self.dataset_name @@ -95,7 +99,7 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]: setattr(dataset_attr, column_name, dataset_info[name]["columns"].get(column_name, None)) if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]: - for tag in ["role_tag", "content_tag", "user_tag", "assistant_tag", "observation_tag", "function_tag"]: + for tag in ["role_tag", "content_tag", "user_tag", "assistant_tag", "observation_tag", "function_tag", "system_tag"]: setattr(dataset_attr, tag, dataset_info[name]["tags"].get(tag, None)) dataset_list.append(dataset_attr)