mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-03 04:02:49 +08:00
Merge pull request #3785 from enji-zhou/feature/add_kto
add kto Former-commit-id: 33a354548e78a7f7f51d63f80974920827d30252
This commit is contained in:
commit
97469892c3
@ -32,6 +32,15 @@
|
|||||||
"history": "history"
|
"history": "history"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"kto-mix-test": {
|
||||||
|
"file_name": "kto-mix-test.json",
|
||||||
|
"file_sha1": "91b59f657007dc4b17529fc643v9b9cd6d640fha",
|
||||||
|
"columns": {
|
||||||
|
"prompt": "instruction",
|
||||||
|
"response": "output",
|
||||||
|
"tag": "tag"
|
||||||
|
}
|
||||||
|
},
|
||||||
"glaive_toolcall": {
|
"glaive_toolcall": {
|
||||||
"file_name": "glaive_toolcall_10k.json",
|
"file_name": "glaive_toolcall_10k.json",
|
||||||
"formatting": "sharegpt",
|
"formatting": "sharegpt",
|
||||||
|
5462
data/kto-mix-test.json
Normal file
5462
data/kto-mix-test.json
Normal file
File diff suppressed because one or more lines are too long
@ -1,4 +1,4 @@
|
|||||||
from .collator import PairwiseDataCollatorWithPadding
|
from .collator import PairwiseDataCollatorWithPadding,KTODataCollatorWithPadding
|
||||||
from .loader import get_dataset
|
from .loader import get_dataset
|
||||||
from .template import Template, get_template_and_fix_tokenizer, templates
|
from .template import Template, get_template_and_fix_tokenizer, templates
|
||||||
from .utils import Role, split_dataset
|
from .utils import Role, split_dataset
|
||||||
@ -6,6 +6,7 @@ from .utils import Role, split_dataset
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"PairwiseDataCollatorWithPadding",
|
"PairwiseDataCollatorWithPadding",
|
||||||
|
"KTODataCollatorWithPadding",
|
||||||
"get_dataset",
|
"get_dataset",
|
||||||
"Template",
|
"Template",
|
||||||
"get_template_and_fix_tokenizer",
|
"get_template_and_fix_tokenizer",
|
||||||
|
@ -29,7 +29,7 @@ def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "
|
|||||||
def convert_alpaca(
|
def convert_alpaca(
|
||||||
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
||||||
) -> Dict[str, List[Any]]:
|
) -> Dict[str, List[Any]]:
|
||||||
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
|
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": [], "tag": []}
|
||||||
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
||||||
for i in range(len(examples[dataset_attr.prompt])):
|
for i in range(len(examples[dataset_attr.prompt])):
|
||||||
prompt = []
|
prompt = []
|
||||||
@ -61,6 +61,7 @@ def convert_alpaca(
|
|||||||
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
|
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
|
||||||
outputs["tools"].append("")
|
outputs["tools"].append("")
|
||||||
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
|
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
|
||||||
|
outputs["tag"].append(examples[dataset_attr.tag][i] if dataset_attr.tag else True)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@ -137,6 +138,7 @@ def align_dataset(
|
|||||||
"system": {"dtype": "string", "_type": "Value"},
|
"system": {"dtype": "string", "_type": "Value"},
|
||||||
"tools": {"dtype": "string", "_type": "Value"},
|
"tools": {"dtype": "string", "_type": "Value"},
|
||||||
"images": [{"_type": "Image"}],
|
"images": [{"_type": "Image"}],
|
||||||
|
"tag": {"dtype": "bool", "_type": "Value"},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
|
@ -49,3 +49,36 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
|||||||
batch = super().__call__(concatenated_features)
|
batch = super().__call__(concatenated_features)
|
||||||
batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
|
batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
@dataclass
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
    Data collator for KTO data.

    Each feature carries two sequences: the target sequence (`input_ids`, `attention_mask`,
    `labels`) and the KL sequence (`kl_input_ids`, `kl_attention_mask`, `kl_labels`), plus a `tag`.
    The two sequences are padded independently by the parent collator and merged into one batch.
    """

    def __call__(self, features, return_tensors=None):
        r"""
        Collates KTO features into a single batch.

        Args:
            features: list of dicts with the keys `input_ids`, `attention_mask`, `labels`,
                `kl_input_ids`, `kl_attention_mask`, `kl_labels` and `tag`.
            return_tensors: forwarded to the parent collator; `None` keeps the parent's default.

        Returns:
            The padded target batch extended with `KL_completion_input_ids`,
            `KL_completion_attention_mask`, `kl_labels` and a `tag` tensor.
        """
        target_features = []
        kl_features = []
        tags = []
        for feature in features:
            target_features.append(
                {
                    "input_ids": feature["input_ids"],
                    "attention_mask": feature["attention_mask"],
                    "labels": feature["labels"],
                }
            )
            kl_features.append(
                {
                    "input_ids": feature["kl_input_ids"],
                    "attention_mask": feature["kl_attention_mask"],
                    "labels": feature["kl_labels"],
                }
            )
            tags.append(feature["tag"])

        # Forward `return_tensors` instead of dropping it, so an explicit caller choice is honoured.
        batch = super().__call__(target_features, return_tensors=return_tensors)
        kl_batch = super().__call__(kl_features, return_tensors=return_tensors)
        batch["KL_completion_input_ids"] = kl_batch["input_ids"]
        batch["KL_completion_attention_mask"] = kl_batch["attention_mask"]
        batch["kl_labels"] = kl_batch["labels"]
        batch["tag"] = torch.tensor(tags)
        return batch
|
@ -116,7 +116,7 @@ def get_dataset(
|
|||||||
model_args: "ModelArguments",
|
model_args: "ModelArguments",
|
||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
training_args: "Seq2SeqTrainingArguments",
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
stage: Literal["pt", "sft", "rm", "ppo"],
|
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
processor: Optional["ProcessorMixin"] = None,
|
processor: Optional["ProcessorMixin"] = None,
|
||||||
) -> Union["Dataset", "IterableDataset"]:
|
) -> Union["Dataset", "IterableDataset"]:
|
||||||
|
@ -28,6 +28,7 @@ class DatasetAttr:
|
|||||||
""" columns """
|
""" columns """
|
||||||
system: Optional[str] = None
|
system: Optional[str] = None
|
||||||
images: Optional[str] = None
|
images: Optional[str] = None
|
||||||
|
tag: Optional[bool] = None
|
||||||
""" columns for the alpaca format """
|
""" columns for the alpaca format """
|
||||||
prompt: Optional[str] = "instruction"
|
prompt: Optional[str] = "instruction"
|
||||||
query: Optional[str] = "input"
|
query: Optional[str] = "input"
|
||||||
@ -106,7 +107,7 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
|||||||
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
||||||
|
|
||||||
if "columns" in dataset_info[name]:
|
if "columns" in dataset_info[name]:
|
||||||
column_names = ["system", "images"]
|
column_names = ["system", "images", "tag"]
|
||||||
if dataset_attr.formatting == "alpaca":
|
if dataset_attr.formatting == "alpaca":
|
||||||
column_names.extend(["prompt", "query", "response", "history"])
|
column_names.extend(["prompt", "query", "response", "history"])
|
||||||
else:
|
else:
|
||||||
|
@ -70,7 +70,7 @@ def preprocess_supervised_dataset(
|
|||||||
) -> Dict[str, List[List[int]]]:
|
) -> Dict[str, List[List[int]]]:
|
||||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||||
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
||||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
model_inputs = {"input_ids": [], "attention_mask": [], "labels": [], "tag": []}
|
||||||
if processor is not None:
|
if processor is not None:
|
||||||
model_inputs["pixel_values"] = []
|
model_inputs["pixel_values"] = []
|
||||||
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
|
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
|
||||||
@ -111,11 +111,102 @@ def preprocess_supervised_dataset(
|
|||||||
model_inputs["input_ids"].append(input_ids)
|
model_inputs["input_ids"].append(input_ids)
|
||||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||||
model_inputs["labels"].append(labels)
|
model_inputs["labels"].append(labels)
|
||||||
|
model_inputs["tag"].append(examples["tag"])
|
||||||
if processor is not None:
|
if processor is not None:
|
||||||
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
|
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
|
||||||
|
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
|
def _encode_kto_messages(
    messages: List[Dict[str, str]],
    system: str,
    tools: str,
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    data_args: "DataArguments",
):
    r"""
    Encodes one conversation into a pair `(input_ids, labels)`.

    Inputs have the format `<bos> X Y <eos>` and labels the format `<ignore> ... <ignore> Y <eos>`;
    for multiturn examples only the prompt part of each prompt-response pair is masked.
    """
    input_ids, labels = [], []
    for turn_idx, (source_ids, target_ids) in enumerate(
        template.encode_multiturn(
            tokenizer,
            messages,
            system,
            tools,
            data_args.cutoff_len,
            data_args.reserved_label_len,
        )
    ):
        if data_args.train_on_prompt:
            source_mask = source_ids
        elif turn_idx != 0 and template.efficient_eos:
            source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
        else:
            source_mask = [IGNORE_INDEX] * len(source_ids)

        input_ids += source_ids + target_ids
        labels += source_mask + target_ids

    if template.efficient_eos:
        input_ids += [tokenizer.eos_token_id]
        labels += [tokenizer.eos_token_id]

    return input_ids, labels


def preprocess_kto_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
    r"""
    Tokenizes a batch of KTO examples.

    For every valid example two sequences are built: the target sequence (prompt + its own
    response) and the KL sequence (prompt + a mismatched response). The `tag` of the example is
    carried over unchanged.

    Args:
        examples: batch with the columns `prompt`, `response`, `system`, `tools`, `tag`
            (and `images` when a processor is given).
        template: chat template providing `encode_multiturn` and `efficient_eos`.
        tokenizer: tokenizer passed to the template; its `eos_token_id` is used.
        processor: optional multimodal processor; when set, `pixel_values` are produced.
        data_args: provides `cutoff_len`, `reserved_label_len` and `train_on_prompt`.

    Returns:
        Dict with `input_ids`, `attention_mask`, `labels`, `kl_input_ids`, `kl_attention_mask`,
        `kl_labels`, `tag` and, with a processor, `pixel_values`.
    """
    model_inputs = {
        "input_ids": [],
        "attention_mask": [],
        "labels": [],
        "kl_input_ids": [],
        "kl_attention_mask": [],
        "kl_labels": [],
        "tag": [],
    }
    # Create mismatched prompt/completion pairs for the KL estimate by reversing the order of
    # the completions in the batch. Kept in a local so the caller's batch is not modified.
    kl_responses = examples["response"][::-1]
    if processor is not None:
        model_inputs["pixel_values"] = []
        preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)

    for i in range(len(examples["prompt"])):
        # Skip malformed examples: the prompt must end on a user turn and have one response.
        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
            continue

        if processor is not None:
            examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]

        messages = examples["prompt"][i] + examples["response"][i]
        kl_messages = examples["prompt"][i] + kl_responses[i]
        input_ids, labels = _encode_kto_messages(
            messages, examples["system"][i], examples["tools"][i], template, tokenizer, data_args
        )
        kl_input_ids, kl_labels = _encode_kto_messages(
            kl_messages, examples["system"][i], examples["tools"][i], template, tokenizer, data_args
        )

        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append([1] * len(input_ids))
        model_inputs["labels"].append(labels)
        model_inputs["kl_input_ids"].append(kl_input_ids)
        model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
        model_inputs["kl_labels"].append(kl_labels)
        model_inputs["tag"].append(examples["tag"][i])
        if processor is not None:
            model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))

    desirable = sum(1 for tag in model_inputs["tag"] if tag is True)
    undesirable = sum(1 for tag in model_inputs["tag"] if tag is False)
    logger.info("desirable data in KTO dataset: {},undesirable data in KTO dataset: {}".format(desirable, undesirable))
    if desirable == 0 or undesirable == 0:
        logger.warning("Your dataset only has one preference type.")

    return model_inputs
|
||||||
|
|
||||||
def preprocess_packed_supervised_dataset(
|
def preprocess_packed_supervised_dataset(
|
||||||
examples: Dict[str, List[Any]],
|
examples: Dict[str, List[Any]],
|
||||||
@ -289,7 +380,7 @@ def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer:
|
|||||||
def get_preprocess_and_print_func(
|
def get_preprocess_and_print_func(
|
||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
training_args: "Seq2SeqTrainingArguments",
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
stage: Literal["pt", "sft", "rm", "ppo"],
|
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||||
template: "Template",
|
template: "Template",
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
processor: Optional["ProcessorMixin"],
|
processor: Optional["ProcessorMixin"],
|
||||||
@ -328,6 +419,15 @@ def get_preprocess_and_print_func(
|
|||||||
data_args=data_args,
|
data_args=data_args,
|
||||||
)
|
)
|
||||||
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
|
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
|
||||||
|
elif stage == "kto":
|
||||||
|
preprocess_func = partial(
|
||||||
|
preprocess_kto_dataset,
|
||||||
|
template=template,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
processor=processor,
|
||||||
|
data_args=data_args,
|
||||||
|
)
|
||||||
|
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
|
||||||
else:
|
else:
|
||||||
preprocess_func = partial(
|
preprocess_func = partial(
|
||||||
preprocess_unsupervised_dataset,
|
preprocess_unsupervised_dataset,
|
||||||
|
@ -45,6 +45,7 @@ TRAINING_STAGES = {
|
|||||||
"Reward Modeling": "rm",
|
"Reward Modeling": "rm",
|
||||||
"PPO": "ppo",
|
"PPO": "ppo",
|
||||||
"DPO": "dpo",
|
"DPO": "dpo",
|
||||||
|
"KTO": "kto",
|
||||||
"ORPO": "orpo",
|
"ORPO": "orpo",
|
||||||
"Pre-Training": "pt",
|
"Pre-Training": "pt",
|
||||||
}
|
}
|
||||||
|
@ -133,6 +133,22 @@ class RLHFArguments:
|
|||||||
default=0.0,
|
default=0.0,
|
||||||
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
|
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
|
||||||
)
|
)
|
||||||
|
kto_beta: float = field(
|
||||||
|
default=0.1,
|
||||||
|
metadata={"help": "The beta parameter for the KTO loss."},
|
||||||
|
)
|
||||||
|
kto_ftx: float = field(
|
||||||
|
default=0.0,
|
||||||
|
metadata={"help": "The supervised fine-tuning loss coefficient in KTO training."},
|
||||||
|
)
|
||||||
|
kto_desirable_weight: float = field(
|
||||||
|
default=1.0,
|
||||||
|
metadata={"help": "The desirable weight for the KTO loss."},
|
||||||
|
)
|
||||||
|
kto_undesirable_weight: float = field(
|
||||||
|
default=1.0,
|
||||||
|
metadata={"help": "The undesirable weight for the KTO loss."},
|
||||||
|
)
|
||||||
orpo_beta: float = field(
|
orpo_beta: float = field(
|
||||||
default=0.1,
|
default=0.1,
|
||||||
metadata={"help": "The beta (lambda) parameter in ORPO loss representing the weight of the SFT loss."},
|
metadata={"help": "The beta (lambda) parameter in ORPO loss representing the weight of the SFT loss."},
|
||||||
@ -291,7 +307,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
|
|||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
|
metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
|
||||||
)
|
)
|
||||||
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "orpo"] = field(
|
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "orpo", "kto"] = field(
|
||||||
default="sft",
|
default="sft",
|
||||||
metadata={"help": "Which stage will be performed in training."},
|
metadata={"help": "Which stage will be performed in training."},
|
||||||
)
|
)
|
||||||
|
4
src/llamafactory/train/kto/__init__.py
Normal file
4
src/llamafactory/train/kto/__init__.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
from .workflow import run_kto
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["run_kto"]
|
206
src/llamafactory/train/kto/trainer.py
Normal file
206
src/llamafactory/train/kto/trainer.py
Normal file
@ -0,0 +1,206 @@
|
|||||||
|
from collections import defaultdict
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from types import MethodType
|
||||||
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import Trainer
|
||||||
|
from trl import KTOTrainer
|
||||||
|
from trl.trainer.utils import disable_dropout_in_model
|
||||||
|
|
||||||
|
from ...extras.constants import IGNORE_INDEX
|
||||||
|
from ..utils import create_custom_optimzer, create_custom_scheduler
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transformers import PreTrainedModel
|
||||||
|
|
||||||
|
from ...hparams import FinetuningArguments
|
||||||
|
|
||||||
|
|
||||||
|
class CustomKTOTrainer(KTOTrainer):
    r"""
    KTO trainer that reads its hyper-parameters from `FinetuningArguments`.

    The constructor sets the attributes used by the loss code by hand and then calls
    `Trainer.__init__` directly rather than `KTOTrainer.__init__`.
    """

    def __init__(
        self,
        model: Union["PreTrainedModel", torch.nn.Module],
        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]],
        finetuning_args: "FinetuningArguments",
        disable_dropout: bool = True,
        **kwargs,
    ):
        r"""
        Args:
            model: the policy model to train.
            ref_model: the reference model, or `None` (in which case the policy model with its
                adapter disabled is used as reference, see `get_batch_loss_metrics`).
            finetuning_args: provides `kto_beta`, `kto_ftx`, `kto_desirable_weight`,
                `kto_undesirable_weight` and `use_badam`.
            disable_dropout: whether to disable dropout in `model` and `ref_model`.
            **kwargs: forwarded to `Trainer.__init__`.

        Raises:
            AttributeError: if the trainer has no `accelerator` attribute after initialization.
        """
        if disable_dropout:
            disable_dropout_in_model(model)
            if ref_model is not None:
                disable_dropout_in_model(ref_model)

        # These attributes are assigned before `Trainer.__init__` runs.
        self.finetuning_args = finetuning_args
        self.reference_free = False
        self.use_dpo_data_collator = True  # hack to avoid warning
        self.generate_during_eval = False  # disable at evaluation
        self.label_pad_token_id = IGNORE_INDEX
        self.padding_value = 0
        self.is_encoder_decoder = model.config.is_encoder_decoder
        self.precompute_ref_log_probs = False
        self._precomputed_train_ref_log_probs = False
        self._precomputed_eval_ref_log_probs = False
        self._peft_has_been_casted_to_bf16 = False
        self.ref_model = ref_model
        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        # KTO parameter
        self.beta = finetuning_args.kto_beta
        self.ftx_gamma = finetuning_args.kto_ftx
        self.desirable_weight = finetuning_args.kto_desirable_weight
        self.undesirable_weight = finetuning_args.kto_undesirable_weight

        # NOTE(review): calls the base `Trainer` constructor, not `KTOTrainer.__init__` —
        # presumably to skip TRL's own dataset/tokenization setup; confirm against TRL.
        Trainer.__init__(self, model=model, **kwargs)
        if not hasattr(self, "accelerator"):
            raise AttributeError("Please update `transformers`.")

        if ref_model is not None:
            if self.is_deepspeed_enabled:
                if not (
                    getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
                ):  # quantized models are already set on the correct device
                    self.ref_model = self._prepare_deepspeed(self.ref_model)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

        if finetuning_args.use_badam:
            from badam import clip_grad_norm_for_sparse_tensor

            # Replace the accelerator's gradient clipping with BAdam's implementation.
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)

    def create_optimizer(self) -> "torch.optim.Optimizer":
        r"""
        Sets the optimizer from `create_custom_optimzer` if none exists yet, then defers to the parent.
        """
        if self.optimizer is None:
            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
        return super().create_optimizer()

    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        r"""
        Calls `create_custom_scheduler` (its return value is not used here), then defers to the parent.
        """
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    def sft_loss(self, chosen_logits: "torch.FloatTensor", chosen_labels: "torch.LongTensor") -> "torch.Tensor":
        r"""
        Computes supervised cross-entropy loss of given labels under the given logits.

        Returns:
            The negated mean (NaNs ignored) of the values returned by `get_batch_logps` with
            `average_log_prob=True`; since `nanmean()` is called without a dimension, this is a
            single value rather than one loss per sample.
        """
        all_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True)
        return -all_logps.nanmean()

    def forward(
        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
        r"""
        Runs `model` on the KL inputs and the target inputs of `batch`.

        Returns:
            `(chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)`, where
            "chosen" rows are those whose `batch["tag"]` entry is truthy and "rejected" rows the rest.

        Raises:
            ValueError: if the number of computed log-probabilities differs from `len(batch["tag"])`.
        """
        # The KL forward pass is run without gradient tracking.
        with torch.no_grad():
            KL_logits = model(
                batch["KL_completion_input_ids"],
                attention_mask=batch["KL_completion_attention_mask"],
            ).logits

        completion_logits = model(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
        ).logits

        completion_logps = self.get_batch_logps(
            completion_logits,
            batch["labels"],
            average_log_prob=False,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )

        KL_logps = self.get_batch_logps(
            KL_logits,
            batch["kl_labels"],
            average_log_prob=False,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )

        if completion_logps.shape[0] != len(batch["tag"]):
            raise ValueError(
                "There is a mismatch between the number of examples in this batch and the number of "
                "examples for which an output sequence was predicted."
            )
        # Split the batch rows by tag.
        chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["tag"][i]]
        rejected_idx = [i for i in range(completion_logps.shape[0]) if not batch["tag"][i]]

        chosen_logps = completion_logps[chosen_idx, ...]
        rejected_logps = completion_logps[rejected_idx, ...]

        chosen_logits = completion_logits[chosen_idx, ...]
        rejected_logits = completion_logits[rejected_idx, ...]

        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)

    def get_batch_loss_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
    ):
        """Compute the KTO loss and other metrics for the given batch of inputs for train or test.

        Returns:
            A pair `(losses, metrics)`: the NaN-ignoring mean of the KTO losses (plus the weighted
            SFT loss when `ftx_gamma` is enabled and the batch has tagged rows) and a dict of
            gathered reward/log-prob sums, counts and the KL value.
        """
        metrics = {}
        # Move every tensor of the batch to the accelerator device; other values are kept as-is.
        batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}

        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
            policy_KL_logps,
        ) = self.forward(model, batch)

        with torch.no_grad():
            if self.ref_model is None:
                # No separate reference model: reuse the policy with its adapter disabled.
                ref_model = self.model
                ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
            else:
                ref_model = self.ref_model
                ref_context = nullcontext()

            with ref_context:
                (
                    reference_chosen_logps,
                    reference_rejected_logps,
                    _,
                    _,
                    reference_KL_logps,
                ) = self.forward(ref_model, batch)

        losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            policy_KL_logps,
            reference_chosen_logps,
            reference_rejected_logps,
            reference_KL_logps,
        )
        losses = losses.nanmean()
        # Optional SFT term on the rows selected by `batch['tag']`.
        # NOTE(review): mask indexing assumes `batch['tag']` is a boolean tensor — confirm with the collator.
        if self.ftx_gamma > 1e-6 and len(batch["labels"][batch['tag']])>0:
            losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, batch["labels"][batch['tag']])

        num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
        num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)

        # Sum the per-process counts across all processes.
        all_num_chosen = self.accelerator.gather(num_chosen).sum().item()
        all_num_rejected = self.accelerator.gather(num_rejected).sum().item()

        if all_num_chosen > 0:
            metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item()
            metrics["logps/chosen_sum"] = self.accelerator.gather(policy_chosen_logps.nansum()).nansum().item()
            metrics["count/chosen"] = all_num_chosen

        if all_num_rejected > 0:
            metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item()
            metrics["logps/rejected_sum"] = self.accelerator.gather(policy_rejected_logps.nansum()).nansum().item()
            metrics["count/rejected"] = all_num_rejected

        metrics["kl"] = kl.item()

        return losses, metrics
|
78
src/llamafactory/train/kto/workflow.py
Normal file
78
src/llamafactory/train/kto/workflow.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
|
from ...data import KTODataCollatorWithPadding, get_dataset, split_dataset
|
||||||
|
from ...extras.constants import IGNORE_INDEX
|
||||||
|
from ...extras.ploting import plot_loss
|
||||||
|
from ...hparams import ModelArguments
|
||||||
|
from ...model import load_model, load_tokenizer
|
||||||
|
from ..utils import create_modelcard_and_push, create_ref_model
|
||||||
|
from .trainer import CustomKTOTrainer
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||||
|
|
||||||
|
from ...hparams import DataArguments, FinetuningArguments
|
||||||
|
|
||||||
|
|
||||||
|
def run_kto(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    callbacks: Optional[List["TrainerCallback"]] = None,
):
    r"""
    Runs the KTO stage: loads tokenizer, dataset and model, builds the trainer, then trains
    and/or evaluates according to `training_args` and finally creates the model card.
    """
    tok_bundle = load_tokenizer(model_args)
    tokenizer = tok_bundle["tokenizer"]
    dataset = get_dataset(model_args, data_args, training_args, stage="kto", **tok_bundle)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

    if data_args.ignore_pad_token_for_loss:
        label_pad_id = IGNORE_INDEX
    else:
        label_pad_id = tokenizer.pad_token_id

    collator = KTODataCollatorWithPadding(
        tokenizer=tokenizer,
        pad_to_multiple_of=8,
        label_pad_token_id=label_pad_id,
    )

    # Reference model: reuse the policy itself only when no reference model is configured
    # and we are not training.
    reuse_policy = finetuning_args.ref_model is None and not training_args.do_train
    ref_model = model if reuse_policy else create_ref_model(model_args, finetuning_args)

    # Keep the extra KTO columns so the collator can read them.
    training_args.remove_unused_columns = False  # important for pairwise dataset

    dataset_splits = split_dataset(dataset, data_args, training_args)
    trainer = CustomKTOTrainer(
        model=model,
        ref_model=ref_model,
        args=training_args,
        finetuning_args=finetuning_args,
        tokenizer=tokenizer,
        data_collator=collator,
        callbacks=callbacks,
        **dataset_splits,
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        trainer.save_model()
        train_metrics = train_result.metrics
        trainer.log_metrics("train", train_metrics)
        trainer.save_metrics("train", train_metrics)
        trainer.save_state()
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval")
        if model is ref_model:  # unable to compute rewards without a reference model
            for reward_key in [name for name in metrics.keys() if "rewards" in name]:
                metrics.pop(reward_key)

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Create model card
    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
|
@ -14,7 +14,7 @@ from .ppo import run_ppo
|
|||||||
from .pt import run_pt
|
from .pt import run_pt
|
||||||
from .rm import run_rm
|
from .rm import run_rm
|
||||||
from .sft import run_sft
|
from .sft import run_sft
|
||||||
|
from .kto import run_kto
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import TrainerCallback
|
from transformers import TrainerCallback
|
||||||
@ -39,6 +39,8 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
|
|||||||
run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
|
run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||||
elif finetuning_args.stage == "orpo":
|
elif finetuning_args.stage == "orpo":
|
||||||
run_orpo(model_args, data_args, training_args, finetuning_args, callbacks)
|
run_orpo(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||||
|
elif finetuning_args.stage == "kto":
|
||||||
|
run_kto(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown task.")
|
raise ValueError("Unknown task.")
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user