Former-commit-id: db1d5a4f51
This commit is contained in:
enji.zhou
2024-05-17 13:09:17 +08:00
parent 1bbbcb5895
commit 03956053b8
14 changed files with 5923 additions and 8 deletions

View File

@@ -1,4 +1,4 @@
from .collator import PairwiseDataCollatorWithPadding
from .collator import PairwiseDataCollatorWithPadding,KTODataCollatorWithPadding
from .loader import get_dataset
from .template import Template, get_template_and_fix_tokenizer, templates
from .utils import Role, split_dataset
@@ -6,6 +6,7 @@ from .utils import Role, split_dataset
__all__ = [
"PairwiseDataCollatorWithPadding",
"KTODataCollatorWithPadding",
"get_dataset",
"Template",
"get_template_and_fix_tokenizer",

View File

@@ -29,7 +29,7 @@ def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "
def convert_alpaca(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": [], "tag": []}
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
for i in range(len(examples[dataset_attr.prompt])):
prompt = []
@@ -61,6 +61,7 @@ def convert_alpaca(
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["tools"].append("")
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
outputs["tag"].append(examples[dataset_attr.tag][i] if dataset_attr.tag else True)
return outputs
@@ -137,6 +138,7 @@ def align_dataset(
"system": {"dtype": "string", "_type": "Value"},
"tools": {"dtype": "string", "_type": "Value"},
"images": [{"_type": "Image"}],
"tag": {"dtype": "bool", "_type": "Value"},
}
)
kwargs = {}

View File

@@ -49,3 +49,36 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
batch = super().__call__(concatenated_features)
batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
return batch
@dataclass
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
r"""
Data collator for KTO data.
"""
def __call__(self, features, return_tensors=None):
concatenated_features = []
kl_concatenated_features = []
tags = []
for feature in features:
concatenated_features.append(
{
"input_ids": feature["input_ids"],
"attention_mask": feature["attention_mask"],
"labels": feature["labels"],
}
)
kl_concatenated_features.append(
{
"input_ids": feature["kl_input_ids"],
"attention_mask": feature["kl_attention_mask"],
"labels": feature["kl_labels"],
}
)
tags.append(feature["tag"])
batch = super().__call__(concatenated_features)
kl_batch = super().__call__(kl_concatenated_features)
batch["KL_completion_input_ids"] = kl_batch["input_ids"]
batch["KL_completion_attention_mask"] = kl_batch["attention_mask"]
batch["kl_labels"] = kl_batch["labels"]
batch["tag"] = torch.tensor(tags)
return batch

View File

@@ -116,7 +116,7 @@ def get_dataset(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"],
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None,
) -> Union["Dataset", "IterableDataset"]:

View File

@@ -28,6 +28,7 @@ class DatasetAttr:
""" columns """
system: Optional[str] = None
images: Optional[str] = None
tag: Optional[bool] = None
""" columns for the alpaca format """
prompt: Optional[str] = "instruction"
query: Optional[str] = "input"
@@ -106,7 +107,7 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
if "columns" in dataset_info[name]:
column_names = ["system", "images"]
column_names = ["system", "images", "tag"]
if dataset_attr.formatting == "alpaca":
column_names.extend(["prompt", "query", "response", "history"])
else:

View File

@@ -70,7 +70,7 @@ def preprocess_supervised_dataset(
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
model_inputs = {"input_ids": [], "attention_mask": [], "labels": [], "tag": []}
if processor is not None:
model_inputs["pixel_values"] = []
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
@@ -111,11 +111,102 @@ def preprocess_supervised_dataset(
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["tag"].append(examples["tag"])
if processor is not None:
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
return model_inputs
def preprocess_kto_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": [],"kl_input_ids": [], "kl_attention_mask": [], "kl_labels": [], "tag": []}
"""Creates mismatched pairs of prompts and completions for the KL dataset by reversing the order of completions."""
examples['kl_response'] = examples['response'][::-1]
if processor is not None:
model_inputs["pixel_values"] = []
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
continue
if processor is not None:
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
messages = examples["prompt"][i] + examples["response"][i]
kl_messages = examples["prompt"][i] + examples["kl_response"][i]
input_ids, labels = [], []
kl_input_ids, kl_labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(
template.encode_multiturn(
tokenizer,
messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids
labels += source_mask + target_ids
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
for turn_idx, (source_ids, target_ids) in enumerate(
template.encode_multiturn(
tokenizer,
kl_messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
kl_input_ids += source_ids + target_ids
kl_labels += source_mask + target_ids
if template.efficient_eos:
kl_input_ids += [tokenizer.eos_token_id]
kl_labels += [tokenizer.eos_token_id]
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["kl_input_ids"].append(kl_input_ids)
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
model_inputs["kl_labels"].append(kl_labels)
model_inputs["tag"].append(examples["tag"][i])
if processor is not None:
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
desirable = sum([1 for tag in model_inputs["tag"] if tag is True])
undesirable = sum([1 for tag in model_inputs["tag"] if tag is False])
logger.info("desirable data in KTO dataset: {},undesirable data in KTO dataset: {}".format(desirable, undesirable))
if desirable == 0 or undesirable == 0:
logger.warning("Your dataset only has one preference type.")
return model_inputs
def preprocess_packed_supervised_dataset(
examples: Dict[str, List[Any]],
@@ -289,7 +380,7 @@ def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer:
def get_preprocess_and_print_func(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"],
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
@@ -328,6 +419,15 @@ def get_preprocess_and_print_func(
data_args=data_args,
)
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
elif stage == "kto":
preprocess_func = partial(
preprocess_kto_dataset,
template=template,
tokenizer=tokenizer,
processor=processor,
data_args=data_args,
)
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
else:
preprocess_func = partial(
preprocess_unsupervised_dataset,