Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-12-16 11:50:35 +08:00
format style
@@ -27,7 +27,9 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
         if dataset_attr.response:
             if isinstance(examples[dataset_attr.response][i], list):
-                response = [{"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]]
+                response = [
+                    {"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]
+                ]
             else:
                 response = [{"role": Role.ASSISTANT, "content": examples[dataset_attr.response][i]}]
         else:
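For readers skimming the hunk above: the reformatted branch builds the response list differently for list-valued and string-valued answer columns. A minimal standalone sketch (the Role class and sample values below are stand-ins, not the project's actual definitions):

from enum import Enum

class Role(str, Enum):  # stand-in for the project's Role constants
    ASSISTANT = "assistant"

def build_response(raw):
    # List-valued column: one assistant message per candidate answer.
    if isinstance(raw, list):
        return [{"role": Role.ASSISTANT, "content": content} for content in raw]
    # String-valued column: a single assistant message.
    return [{"role": Role.ASSISTANT, "content": raw}]

print(build_response(["answer A", "answer B"]))  # two assistant messages
print(build_response("single answer"))           # one assistant message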
@@ -47,10 +49,10 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
         dataset_attr.user_tag: Role.USER,
         dataset_attr.assistant_tag: Role.ASSISTANT,
         dataset_attr.observation_tag: Role.OBSERVATION,
-        dataset_attr.function_tag: Role.FUNCTION
+        dataset_attr.function_tag: Role.FUNCTION,
     }
     for i, messages in enumerate(examples[dataset_attr.messages]):
-        messages = messages[:len(messages) // 2 * 2] # should be multiples of 2
+        messages = messages[: len(messages) // 2 * 2] # should be multiples of 2
         if len(messages) == 0:
             continue

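The slice change above is whitespace-only; its behavior, trimming a conversation to complete user/assistant pairs, can be checked directly (the sample messages are illustrative):

messages = ["user turn 1", "assistant turn 1", "user turn 2"]  # odd number of turns
messages = messages[: len(messages) // 2 * 2]  # should be multiples of 2
print(messages)  # ['user turn 1', 'assistant turn 1'] -- the unpaired trailing turn is dropped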
@@ -65,7 +67,9 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
             if message[dataset_attr.role_tag] not in accept_tags:
                 raise ValueError("Invalid role tag in {}.".format(messages))

-            prompt.append({"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]})
+            prompt.append(
+                {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
+            )

         last_message = prompt.pop(-1)
         response.append(last_message)
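A rough illustration of the validation and tag mapping surrounding the reformatted prompt.append call; the concrete tag names ("from"/"value", "human"/"gpt") are assumptions chosen to make the snippet runnable:

tag_mapping = {"human": "user", "gpt": "assistant"}  # assumed user_tag/assistant_tag values
accept_tags = set(tag_mapping)

message = {"from": "human", "value": "Hello"}  # assumed ShareGPT-style role/content tags
if message["from"] not in accept_tags:
    raise ValueError("Invalid role tag in {}.".format(message))

prompt = []
prompt.append(
    {"role": tag_mapping[message["from"]], "content": message["value"]}
)
print(prompt)  # [{'role': 'user', 'content': 'Hello'}]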
@@ -98,12 +102,7 @@ def align_dataset(
     kwargs = dict(
         num_proc=data_args.preprocessing_num_workers,
         load_from_cache_file=(not data_args.overwrite_cache),
-        desc="Converting format of dataset"
+        desc="Converting format of dataset",
     )

-    return dataset.map(
-        convert_func,
-        batched=True,
-        remove_columns=column_names,
-        **kwargs
-    )
+    return dataset.map(convert_func, batched=True, remove_columns=column_names, **kwargs)
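The collapsed return statement is an ordinary datasets.Dataset.map call with batched conversion. A self-contained sketch of the same pattern; the toy columns and convert_func here are assumptions, not the repository's code:

from datasets import Dataset

dataset = Dataset.from_dict({"instruction": ["Say hi"], "output": ["Hi!"]})
column_names = list(dataset.column_names)

def convert_func(examples):
    # With batched=True, each column arrives as a list; return the aligned columns.
    prompts = [[{"role": "user", "content": text}] for text in examples["instruction"]]
    responses = [[{"role": "assistant", "content": text}] for text in examples["output"]]
    return {"prompt": prompts, "response": responses}

kwargs = dict(num_proc=1, load_from_cache_file=False, desc="Converting format of dataset")
aligned = dataset.map(convert_func, batched=True, remove_columns=column_names, **kwargs)
print(aligned[0])  # {'prompt': [{'role': 'user', 'content': 'Say hi'}], 'response': [...]}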