diff --git a/src/llamafactory/data/__init__.py b/src/llamafactory/data/__init__.py index 792e89d9..0b3a8dcf 100644 --- a/src/llamafactory/data/__init__.py +++ b/src/llamafactory/data/__init__.py @@ -1,4 +1,4 @@ -from .collator import PairwiseDataCollatorWithPadding +from .collator import PairwiseDataCollatorWithPadding,KTODataCollatorWithPadding from .loader import get_dataset from .template import Template, get_template_and_fix_tokenizer, templates from .utils import Role, split_dataset @@ -6,6 +6,7 @@ from .utils import Role, split_dataset __all__ = [ "PairwiseDataCollatorWithPadding", + "KTODataCollatorWithPadding", "get_dataset", "Template", "get_template_and_fix_tokenizer", diff --git a/src/llamafactory/data/aligner.py b/src/llamafactory/data/aligner.py index 6bd12aad..2cf8a4f3 100644 --- a/src/llamafactory/data/aligner.py +++ b/src/llamafactory/data/aligner.py @@ -29,7 +29,7 @@ def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: " def convert_alpaca( examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments" ) -> Dict[str, List[Any]]: - outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []} + outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": [], "tag": []} convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args) for i in range(len(examples[dataset_attr.prompt])): prompt = [] @@ -61,6 +61,7 @@ def convert_alpaca( outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "") outputs["tools"].append("") outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else []) + outputs["tag"].append(examples[dataset_attr.tag][i] if dataset_attr.tag else True) return outputs @@ -137,6 +138,7 @@ def align_dataset( "system": {"dtype": "string", "_type": "Value"}, "tools": {"dtype": "string", "_type": "Value"}, "images": [{"_type": "Image"}], + "tag": {"dtype": "bool", "_type": "Value"}, } ) kwargs = {} diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 5e506546..517fa68c 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -49,3 +49,36 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq): batch = super().__call__(concatenated_features) batch["labels"] = self._pad_labels(batch["input_ids"], label_positions) return batch + +@dataclass +class KTODataCollatorWithPadding(DataCollatorForSeq2Seq): + r""" + Data collator for KTO data. + """ + def __call__(self, features, return_tensors=None): + concatenated_features = [] + kl_concatenated_features = [] + tags = [] + for feature in features: + concatenated_features.append( + { + "input_ids": feature["input_ids"], + "attention_mask": feature["attention_mask"], + "labels": feature["labels"], + } + ) + kl_concatenated_features.append( + { + "input_ids": feature["kl_input_ids"], + "attention_mask": feature["kl_attention_mask"], + "labels": feature["kl_labels"], + } + ) + tags.append(feature["tag"]) + batch = super().__call__(concatenated_features) + kl_batch = super().__call__(kl_concatenated_features) + batch["KL_completion_input_ids"] = kl_batch["input_ids"] + batch["KL_completion_attention_mask"] = kl_batch["attention_mask"] + batch["kl_labels"] = kl_batch["labels"] + batch["tag"] = torch.tensor(tags) + return batch \ No newline at end of file diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index 3cc01b0d..a04bf377 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -116,7 +116,7 @@ def get_dataset( model_args: "ModelArguments", data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", - stage: Literal["pt", "sft", "rm", "ppo"], + stage: Literal["pt", "sft", "rm", "ppo", "kto"], tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"] = None, ) -> Union["Dataset", "IterableDataset"]: diff --git a/src/llamafactory/data/parser.py b/src/llamafactory/data/parser.py index 848fd66c..33136551 100644 --- a/src/llamafactory/data/parser.py +++ b/src/llamafactory/data/parser.py @@ -28,6 +28,7 @@ class DatasetAttr: """ columns """ system: Optional[str] = None images: Optional[str] = None + tag: Optional[bool] = None """ columns for the alpaca format """ prompt: Optional[str] = "instruction" query: Optional[str] = "input" @@ -106,7 +107,7 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]: dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca") if "columns" in dataset_info[name]: - column_names = ["system", "images"] + column_names = ["system", "images", "tag"] if dataset_attr.formatting == "alpaca": column_names.extend(["prompt", "query", "response", "history"]) else: diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py index 38211b0c..4a348ce2 100644 --- a/src/llamafactory/data/preprocess.py +++ b/src/llamafactory/data/preprocess.py @@ -70,7 +70,7 @@ def preprocess_supervised_dataset( ) -> Dict[str, List[List[int]]]: # build inputs with format ` X Y ` and labels with format ` ... Y ` # for multiturn examples, we only mask the prompt part in each prompt-response pair. - model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} + model_inputs = {"input_ids": [], "attention_mask": [], "labels": [], "tag": []} if processor is not None: model_inputs["pixel_values"] = [] preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor) @@ -111,11 +111,102 @@ def preprocess_supervised_dataset( model_inputs["input_ids"].append(input_ids) model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["labels"].append(labels) + model_inputs["tag"].append(examples["tag"]) if processor is not None: model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i])) return model_inputs +def preprocess_kto_dataset( + examples: Dict[str, List[Any]], + template: "Template", + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + data_args: "DataArguments", +) -> Dict[str, List[List[int]]]: + # build inputs with format ` X Y ` and labels with format ` ... Y ` + # for multiturn examples, we only mask the prompt part in each prompt-response pair. + model_inputs = {"input_ids": [], "attention_mask": [], "labels": [],"kl_input_ids": [], "kl_attention_mask": [], "kl_labels": [], "tag": []} + """Creates mismatched pairs of prompts and completions for the KL dataset by reversing the order of completions.""" + examples['kl_response'] = examples['response'][::-1] + if processor is not None: + model_inputs["pixel_values"] = [] + preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor) + + for i in range(len(examples["prompt"])): + if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1: + continue + + if processor is not None: + examples["prompt"][i][0]["content"] = "" + examples["prompt"][i][0]["content"] + + messages = examples["prompt"][i] + examples["response"][i] + kl_messages = examples["prompt"][i] + examples["kl_response"][i] + input_ids, labels = [], [] + kl_input_ids, kl_labels = [], [] + for turn_idx, (source_ids, target_ids) in enumerate( + template.encode_multiturn( + tokenizer, + messages, + examples["system"][i], + examples["tools"][i], + data_args.cutoff_len, + data_args.reserved_label_len, + ) + ): + if data_args.train_on_prompt: + source_mask = source_ids + elif turn_idx != 0 and template.efficient_eos: + source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) + else: + source_mask = [IGNORE_INDEX] * len(source_ids) + + input_ids += source_ids + target_ids + labels += source_mask + target_ids + + if template.efficient_eos: + input_ids += [tokenizer.eos_token_id] + labels += [tokenizer.eos_token_id] + + for turn_idx, (source_ids, target_ids) in enumerate( + template.encode_multiturn( + tokenizer, + kl_messages, + examples["system"][i], + examples["tools"][i], + data_args.cutoff_len, + data_args.reserved_label_len, + ) + ): + if data_args.train_on_prompt: + source_mask = source_ids + elif turn_idx != 0 and template.efficient_eos: + source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) + else: + source_mask = [IGNORE_INDEX] * len(source_ids) + + kl_input_ids += source_ids + target_ids + kl_labels += source_mask + target_ids + + if template.efficient_eos: + kl_input_ids += [tokenizer.eos_token_id] + kl_labels += [tokenizer.eos_token_id] + + model_inputs["input_ids"].append(input_ids) + model_inputs["attention_mask"].append([1] * len(input_ids)) + model_inputs["labels"].append(labels) + model_inputs["kl_input_ids"].append(kl_input_ids) + model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids)) + model_inputs["kl_labels"].append(kl_labels) + model_inputs["tag"].append(examples["tag"][i]) + if processor is not None: + model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i])) + desirable = sum([1 for tag in model_inputs["tag"] if tag is True]) + undesirable = sum([1 for tag in model_inputs["tag"] if tag is False]) + logger.info("desirable data in KTO dataset: {},undesirable data in KTO dataset: {}".format(desirable, undesirable)) + if desirable == 0 or undesirable == 0: + logger.warning("Your dataset only has one preference type.") + return model_inputs def preprocess_packed_supervised_dataset( examples: Dict[str, List[Any]], @@ -289,7 +380,7 @@ def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: def get_preprocess_and_print_func( data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", - stage: Literal["pt", "sft", "rm", "ppo"], + stage: Literal["pt", "sft", "rm", "ppo", "kto"], template: "Template", tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], @@ -328,6 +419,15 @@ def get_preprocess_and_print_func( data_args=data_args, ) print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer) + elif stage == "kto": + preprocess_func = partial( + preprocess_kto_dataset, + template=template, + tokenizer=tokenizer, + processor=processor, + data_args=data_args, + ) + print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer) else: preprocess_func = partial( preprocess_unsupervised_dataset, diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index 6b967517..fecf0c38 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -45,6 +45,7 @@ TRAINING_STAGES = { "Reward Modeling": "rm", "PPO": "ppo", "DPO": "dpo", + "KTO": "kto", "ORPO": "orpo", "Pre-Training": "pt", } diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index e728c30a..e6840518 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -133,6 +133,22 @@ class RLHFArguments: default=0.0, metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."}, ) + kto_beta: float = field( + default=0.1, + metadata={"help": "The beta parameter for the KTO loss."}, + ) + kto_ftx: float = field( + default=0.0, + metadata={"help": "The supervised fine-tuning loss coefficient in KTO training."}, + ) + kto_desirable_weight: float = field( + default=1.0, + metadata={"help": "The desirable weight for the KTO loss."}, + ) + kto_undesirable_weight: float = field( + default=1.0, + metadata={"help": "The undesirable weight for the KTO loss."}, + ) orpo_beta: float = field( default=0.1, metadata={"help": "The beta (lambda) parameter in ORPO loss representing the weight of the SFT loss."}, @@ -291,7 +307,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA default=False, metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."}, ) - stage: Literal["pt", "sft", "rm", "ppo", "dpo", "orpo"] = field( + stage: Literal["pt", "sft", "rm", "ppo", "dpo", "orpo", "kto"] = field( default="sft", metadata={"help": "Which stage will be performed in training."}, ) diff --git a/src/llamafactory/train/kto/__init__.py b/src/llamafactory/train/kto/__init__.py new file mode 100644 index 00000000..34c7905a --- /dev/null +++ b/src/llamafactory/train/kto/__init__.py @@ -0,0 +1,4 @@ +from .workflow import run_kto + + +__all__ = ["run_kto"] diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py new file mode 100644 index 00000000..6f9f6754 --- /dev/null +++ b/src/llamafactory/train/kto/trainer.py @@ -0,0 +1,206 @@ +from collections import defaultdict +from contextlib import nullcontext +from types import MethodType +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch +from transformers import Trainer +from trl import KTOTrainer +from trl.trainer.utils import disable_dropout_in_model + +from ...extras.constants import IGNORE_INDEX +from ..utils import create_custom_optimzer, create_custom_scheduler + + +if TYPE_CHECKING: + from transformers import PreTrainedModel + + from ...hparams import FinetuningArguments + + +class CustomKTOTrainer(KTOTrainer): + def __init__( + self, + model: Union["PreTrainedModel", torch.nn.Module], + ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]], + finetuning_args: "FinetuningArguments", + disable_dropout: bool = True, + **kwargs, + ): + if disable_dropout: + disable_dropout_in_model(model) + if ref_model is not None: + disable_dropout_in_model(ref_model) + + self.finetuning_args = finetuning_args + self.reference_free = False + self.use_dpo_data_collator = True # hack to avoid warning + self.generate_during_eval = False # disable at evaluation + self.label_pad_token_id = IGNORE_INDEX + self.padding_value = 0 + self.is_encoder_decoder = model.config.is_encoder_decoder + self.precompute_ref_log_probs = False + self._precomputed_train_ref_log_probs = False + self._precomputed_eval_ref_log_probs = False + self._peft_has_been_casted_to_bf16 = False + self.ref_model = ref_model + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + + # KTO parameter + self.beta = finetuning_args.kto_beta + self.ftx_gamma = finetuning_args.kto_ftx + self.desirable_weight = finetuning_args.kto_desirable_weight + self.undesirable_weight = finetuning_args.kto_undesirable_weight + + + Trainer.__init__(self, model=model, **kwargs) + if not hasattr(self, "accelerator"): + raise AttributeError("Please update `transformers`.") + + if ref_model is not None: + if self.is_deepspeed_enabled: + if not ( + getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False) + ): # quantized models are already set on the correct device + self.ref_model = self._prepare_deepspeed(self.ref_model) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + if finetuning_args.use_badam: + from badam import clip_grad_norm_for_sparse_tensor + + self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator) + + def create_optimizer(self) -> "torch.optim.Optimizer": + if self.optimizer is None: + self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args) + return super().create_optimizer() + + def create_scheduler( + self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None + ) -> "torch.optim.lr_scheduler.LRScheduler": + create_custom_scheduler(self.args, num_training_steps, optimizer) + return super().create_scheduler(num_training_steps, optimizer) + + def sft_loss(self, chosen_logits: "torch.FloatTensor", chosen_labels: "torch.LongTensor") -> "torch.Tensor": + r""" + Computes supervised cross-entropy loss of given labels under the given logits. + Returns: + A tensor of shape (batch_size,) containing the cross-entropy loss of each samples. + """ + all_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True) + return -all_logps.nanmean() + + + def forward( + self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] + ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]: + with torch.no_grad(): + KL_logits = model( + batch["KL_completion_input_ids"], + attention_mask=batch["KL_completion_attention_mask"], + ).logits + + completion_logits = model( + batch["input_ids"], + attention_mask=batch["attention_mask"], + ).logits + + completion_logps = self.get_batch_logps( + completion_logits, + batch["labels"], + average_log_prob=False, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + ) + + KL_logps = self.get_batch_logps( + KL_logits, + batch["kl_labels"], + average_log_prob=False, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + ) + + if completion_logps.shape[0] != len(batch["tag"]): + raise ValueError( + "There is a mismatch between the number of examples in this batch and the number of " + "examples for which an output sequence was predicted." + ) + chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["tag"][i]] + rejected_idx = [i for i in range(completion_logps.shape[0]) if not batch["tag"][i]] + + chosen_logps = completion_logps[chosen_idx, ...] + rejected_logps = completion_logps[rejected_idx, ...] + + chosen_logits = completion_logits[chosen_idx, ...] + rejected_logits = completion_logits[rejected_idx, ...] + + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps) + + + def get_batch_loss_metrics( + self, + model, + batch: Dict[str, Union[List, torch.LongTensor]], + ): + """Compute the KTO loss and other metrics for the given batch of inputs for train or test.""" + metrics = {} + batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()} + + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + policy_KL_logps, + ) = self.forward(model, batch) + + with torch.no_grad(): + if self.ref_model is None: + ref_model = self.model + ref_context = self.accelerator.unwrap_model(self.model).disable_adapter() + else: + ref_model = self.ref_model + ref_context = nullcontext() + with ref_context: + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _, + reference_KL_logps, + ) = self.forward(ref_model, batch) + + losses, chosen_rewards, rejected_rewards, kl = self.kto_loss( + policy_chosen_logps, + policy_rejected_logps, + policy_KL_logps, + reference_chosen_logps, + reference_rejected_logps, + reference_KL_logps, + ) + losses = losses.nanmean() + if self.ftx_gamma > 1e-6 and len(batch["labels"][batch['tag']])>0: + losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, batch["labels"][batch['tag']]) + + + num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device) + num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device) + + all_num_chosen = self.accelerator.gather(num_chosen).sum().item() + all_num_rejected = self.accelerator.gather(num_rejected).sum().item() + + if all_num_chosen > 0: + metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item() + metrics["logps/chosen_sum"] = self.accelerator.gather(policy_chosen_logps.nansum()).nansum().item() + metrics["count/chosen"] = all_num_chosen + + if all_num_rejected > 0: + metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item() + metrics["logps/rejected_sum"] = self.accelerator.gather(policy_rejected_logps.nansum()).nansum().item() + metrics["count/rejected"] = all_num_rejected + + metrics["kl"] = kl.item() + + return losses, metrics \ No newline at end of file diff --git a/src/llamafactory/train/kto/workflow.py b/src/llamafactory/train/kto/workflow.py new file mode 100644 index 00000000..a2d0ec24 --- /dev/null +++ b/src/llamafactory/train/kto/workflow.py @@ -0,0 +1,78 @@ +from typing import TYPE_CHECKING, List, Optional + +from ...data import KTODataCollatorWithPadding, get_dataset, split_dataset +from ...extras.constants import IGNORE_INDEX +from ...extras.ploting import plot_loss +from ...hparams import ModelArguments +from ...model import load_model, load_tokenizer +from ..utils import create_modelcard_and_push, create_ref_model +from .trainer import CustomKTOTrainer + + +if TYPE_CHECKING: + from transformers import Seq2SeqTrainingArguments, TrainerCallback + + from ...hparams import DataArguments, FinetuningArguments + + +def run_kto( + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", + callbacks: Optional[List["TrainerCallback"]] = None, +): + tokenizer_module = load_tokenizer(model_args) + tokenizer = tokenizer_module["tokenizer"] + dataset = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module) + model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) + + data_collator = KTODataCollatorWithPadding( + tokenizer=tokenizer, + pad_to_multiple_of=8, + label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id, + ) + + # Create reference model + if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself + ref_model = model + else: + ref_model = create_ref_model(model_args, finetuning_args) + + # Update arguments + training_args.remove_unused_columns = False # important for pairwise dataset + + # Initialize our Trainer + trainer = CustomKTOTrainer( + model=model, + ref_model=ref_model, + args=training_args, + finetuning_args=finetuning_args, + tokenizer=tokenizer, + data_collator=data_collator, + callbacks=callbacks, + **split_dataset(dataset, data_args, training_args), + ) + + # Training + if training_args.do_train: + train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) + trainer.save_model() + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + if trainer.is_world_process_zero() and finetuning_args.plot_loss: + plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) + + # Evaluation + if training_args.do_eval: + metrics = trainer.evaluate(metric_key_prefix="eval") + if id(model) == id(ref_model): # unable to compute rewards without a reference model + remove_keys = [key for key in metrics.keys() if "rewards" in key] + for key in remove_keys: + metrics.pop(key) + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Create model card + create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args) diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index 8f103ca1..89dcb9ac 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -14,7 +14,7 @@ from .ppo import run_ppo from .pt import run_pt from .rm import run_rm from .sft import run_sft - +from .kto import run_kto if TYPE_CHECKING: from transformers import TrainerCallback @@ -39,6 +39,8 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb run_dpo(model_args, data_args, training_args, finetuning_args, callbacks) elif finetuning_args.stage == "orpo": run_orpo(model_args, data_args, training_args, finetuning_args, callbacks) + elif finetuning_args.stage == "kto": + run_kto(model_args, data_args, training_args, finetuning_args, callbacks) else: raise ValueError("Unknown task.")