format style

This commit is contained in:
hiyouga
2024-01-20 20:15:56 +08:00
parent f6d6e00337
commit 638234ceee
73 changed files with 1492 additions and 2325 deletions

View File

@@ -27,7 +27,9 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
if dataset_attr.response:
if isinstance(examples[dataset_attr.response][i], list):
response = [{"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]]
response = [
{"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]
]
else:
response = [{"role": Role.ASSISTANT, "content": examples[dataset_attr.response][i]}]
else:
@@ -47,10 +49,10 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
dataset_attr.user_tag: Role.USER,
dataset_attr.assistant_tag: Role.ASSISTANT,
dataset_attr.observation_tag: Role.OBSERVATION,
dataset_attr.function_tag: Role.FUNCTION
dataset_attr.function_tag: Role.FUNCTION,
}
for i, messages in enumerate(examples[dataset_attr.messages]):
messages = messages[:len(messages) // 2 * 2] # should be multiples of 2
messages = messages[: len(messages) // 2 * 2] # should be multiples of 2
if len(messages) == 0:
continue
@@ -65,7 +67,9 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
if message[dataset_attr.role_tag] not in accept_tags:
raise ValueError("Invalid role tag in {}.".format(messages))
prompt.append({"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]})
prompt.append(
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
)
last_message = prompt.pop(-1)
response.append(last_message)
@@ -98,12 +102,7 @@ def align_dataset(
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=(not data_args.overwrite_cache),
desc="Converting format of dataset"
desc="Converting format of dataset",
)
return dataset.map(
convert_func,
batched=True,
remove_columns=column_names,
**kwargs
)
return dataset.map(convert_func, batched=True, remove_columns=column_names, **kwargs)