diff --git a/README.md b/README.md index af6ef66f..b9059426 100644 --- a/README.md +++ b/README.md @@ -68,16 +68,18 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ ## Changelog +[24/03/31] We supported **[ORPO](https://arxiv.org/abs/2403.07691)**. See `examples/lora_single_gpu` for usage. + [24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available at arXiv! [24/03/20] We supported **FSDP+QLoRA** that fine-tunes a 70B model on 2x24GB GPUs. See `examples/fsdp_qlora` for usage. +
Full Changelog + [24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See `examples/extras/loraplus` for usage. [24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. See `examples/extras/galore` for usage. -
Full Changelog - [24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported, merge it first.) [24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `--use_dora` to activate DoRA training. @@ -165,6 +167,7 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ | Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | | PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | | DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | > [!NOTE] > Use `--quantization_bit 4` argument to enable QLoRA. diff --git a/README_zh.md b/README_zh.md index d018ee32..5c81be44 100644 --- a/README_zh.md +++ b/README_zh.md @@ -68,16 +68,18 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd ## 更新日志 +[24/03/31] 我们支持了 **[ORPO](https://arxiv.org/abs/2403.07691)**。详细用法请参照 `examples/lora_single_gpu`。 + [24/03/21] 我们的论文 "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" 可在 arXiv 上查看! [24/03/20] 我们支持了能在 2x24GB GPU 上微调 70B 模型的 **FSDP+QLoRA**。详细用法请参照 `examples/fsdp_qlora`。 +
展开日志 + [24/03/13] 我们支持了 **[LoRA+](https://arxiv.org/abs/2402.12354)**。详细用法请参照 `examples/extras/loraplus`。 [24/03/07] 我们支持了梯度低秩投影(**[GaLore](https://arxiv.org/abs/2403.03507)**)算法。详细用法请参照 `examples/extras/galore`。 -
展开日志 - [24/03/07] 我们集成了 **[vLLM](https://github.com/vllm-project/vllm)** 以实现极速并发推理。请使用 `--infer_backend vllm` 来获得 **270%** 的推理速度。(尚不支持 LoRA,请先合并权重。) [24/02/28] 我们支持了 **[DoRA](https://arxiv.org/abs/2402.09353)** 微调。请使用 `--use_dora` 参数进行 DoRA 微调。 @@ -165,6 +167,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd | 奖励模型训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | | PPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | | DPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| ORPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | > [!NOTE] > 请使用 `--quantization_bit 4` 参数来启用 QLoRA 训练。 diff --git a/data/README.md b/data/README.md index fa2c9ee0..2ea0c117 100644 --- a/data/README.md +++ b/data/README.md @@ -34,6 +34,8 @@ If you are using a custom dataset, please provide your dataset definition in the Given above, you can use the custom dataset via specifying `--dataset dataset_name`. +---- + Currently we support dataset in **alpaca** or **sharegpt** format, the dataset in alpaca format should follow the below format: ```json @@ -84,6 +86,10 @@ For the preference datasets, the `response` column should be a string list whose } ``` +Remember to set `"ranking": true` for the preference datasets. + +---- + The dataset in sharegpt format should follow the below format: ```json diff --git a/data/README_zh.md b/data/README_zh.md index e0004f4a..b00f81d9 100644 --- a/data/README_zh.md +++ b/data/README_zh.md @@ -34,6 +34,8 @@ 添加后可通过指定 `--dataset 数据集名称` 参数使用自定义数据集。 +---- + 该项目目前支持两种格式的数据集:**alpaca** 和 **sharegpt**,其中 alpaca 格式的数据集按照以下方式组织: ```json @@ -84,6 +86,10 @@ } ``` +添加偏好数据集需要额外指定 `"ranking": true`。 + +---- + 而 sharegpt 格式的数据集按照以下方式组织: ```json diff --git a/examples/lora_single_gpu/README.md b/examples/lora_single_gpu/README.md index ae0f4722..151d0784 100644 --- a/examples/lora_single_gpu/README.md +++ b/examples/lora_single_gpu/README.md @@ -1,8 +1,9 @@ Usage: - `pretrain.sh`: do pre-train (optional) -- `sft.sh`: do supervised fine-tune +- `sft.sh`: do supervised fine-tuning - `reward.sh`: do reward modeling (must after sft.sh) - `ppo.sh`: do PPO training (must after sft.sh and reward.sh) - `dpo.sh`: do DPO training (must after sft.sh) +- `orpo.sh`: do ORPO training - `predict.sh`: do predict (must after sft.sh and dpo.sh) diff --git a/examples/lora_single_gpu/orpo.sh b/examples/lora_single_gpu/orpo.sh new file mode 100644 index 00000000..77662ecf --- /dev/null +++ b/examples/lora_single_gpu/orpo.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \ + --stage orpo \ + --do_train \ + --model_name_or_path meta-llama/Llama-2-7b-hf \ + --dataset comparison_gpt4_en \ + --dataset_dir ../../data \ + --template default \ + --finetuning_type lora \ + --lora_target q_proj,v_proj \ + --output_dir ../../saves/LLaMA2-7B/lora/orpo \ + --overwrite_cache \ + --overwrite_output_dir \ + --cutoff_len 1024 \ + --preprocessing_num_workers 16 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 8 \ + --lr_scheduler_type cosine \ + --logging_steps 10 \ + --warmup_steps 20 \ + --save_steps 100 \ + --eval_steps 100 \ + --evaluation_strategy steps \ + --load_best_model_at_end \ + --learning_rate 1e-5 \ + --num_train_epochs 1.0 \ + --max_samples 1000 \ + --val_size 0.1 \ + --plot_loss \ + --fp16 diff --git a/src/llmtuner/data/__init__.py 
b/src/llmtuner/data/__init__.py index 80dbf5ff..792e89d9 100644 --- a/src/llmtuner/data/__init__.py +++ b/src/llmtuner/data/__init__.py @@ -1,6 +1,15 @@ +from .collator import PairwiseDataCollatorWithPadding from .loader import get_dataset from .template import Template, get_template_and_fix_tokenizer, templates from .utils import Role, split_dataset -__all__ = ["get_dataset", "Template", "get_template_and_fix_tokenizer", "templates", "Role", "split_dataset"] +__all__ = [ + "PairwiseDataCollatorWithPadding", + "get_dataset", + "Template", + "get_template_and_fix_tokenizer", + "templates", + "Role", + "split_dataset", +] diff --git a/src/llmtuner/data/collator.py b/src/llmtuner/data/collator.py new file mode 100644 index 00000000..5e506546 --- /dev/null +++ b/src/llmtuner/data/collator.py @@ -0,0 +1,51 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Sequence, Tuple + +import torch +from transformers import DataCollatorForSeq2Seq + + +@dataclass +class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq): + r""" + Data collator for pairwise data. + """ + + def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor: + r""" + Masks out the input ids except for the responses. + """ + padded_labels = [] + for feature, (prompt_len, answer_len) in zip(batch, positions): + if self.tokenizer.padding_side == "left": + start, end = feature.size(0) - answer_len, feature.size(0) + else: + start, end = prompt_len, prompt_len + answer_len + padded_tensor = self.label_pad_token_id * torch.ones_like(feature) + padded_tensor[start:end] = feature[start:end] + padded_labels.append(padded_tensor) + return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory + + def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]: + r""" + Pads batched data to the longest sequence in the batch. + + We generate 2 * n examples where the first n examples represent chosen examples and + the last n examples represent rejected examples. 
+ """ + concatenated_features = [] + label_positions = [] + for key in ("chosen_ids", "rejected_ids"): + for feature in features: + prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key]) + concatenated_features.append( + { + "input_ids": feature["prompt_ids"] + feature[key], + "attention_mask": [1] * (prompt_len + answer_len), + } + ) + label_positions.append((prompt_len, answer_len)) + + batch = super().__call__(concatenated_features) + batch["labels"] = self._pad_labels(batch["input_ids"], label_positions) + return batch diff --git a/src/llmtuner/data/loader.py b/src/llmtuner/data/loader.py index 935695ad..0ab734e0 100644 --- a/src/llmtuner/data/loader.py +++ b/src/llmtuner/data/loader.py @@ -117,7 +117,6 @@ def get_dataset( data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", stage: Literal["pt", "sft", "rm", "ppo"], - # split: Optional[str] = "train", # TODO: add split ) -> Union["Dataset", "IterableDataset"]: template = get_template_and_fix_tokenizer(tokenizer, data_args.template) if data_args.train_on_prompt and template.efficient_eos: @@ -138,6 +137,9 @@ def get_dataset( with training_args.main_process_first(desc="load dataset"): all_datasets = [] for dataset_attr in get_dataset_list(data_args): + if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True): + raise ValueError("The dataset is not applicable in the current training stage.") + all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args)) dataset = merge_dataset(all_datasets, data_args, training_args) diff --git a/src/llmtuner/data/preprocess.py b/src/llmtuner/data/preprocess.py index 7fb0a9b6..b8edfa10 100644 --- a/src/llmtuner/data/preprocess.py +++ b/src/llmtuner/data/preprocess.py @@ -23,23 +23,25 @@ def preprocess_pretrain_dataset( ) -> Dict[str, List[List[int]]]: # build grouped texts with format `X1 X2 X3 ...` if packing is enabled text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]] - if not data_args.packing: - return tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len) - tokenized_examples = tokenizer(text_examples, add_special_tokens=False) - concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()} - total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]]) - block_size = data_args.cutoff_len - # we drop the small remainder, and if the total_length < block_size, we exclude this batch - total_length = (total_length // block_size) * block_size - # split by chunks of cutoff_len - result = { - k: [t[i : i + block_size] for i in range(0, total_length, block_size)] - for k, t in concatenated_examples.items() - } - if data_args.template == "gemma": - for i in range(len(result["input_ids"])): - result["input_ids"][i][0] = tokenizer.bos_token_id + if not data_args.packing: + if data_args.template == "gemma": + text_examples = [tokenizer.bos_token + example for example in text_examples] + + result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len) + else: + tokenized_examples = tokenizer(text_examples, add_special_tokens=False) + concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()} + total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]]) + block_size = data_args.cutoff_len + total_length = (total_length // block_size) * block_size + result = { + k: [t[i : i + block_size] for i in 
range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + if data_args.template == "gemma": + for i in range(len(result["input_ids"])): + result["input_ids"][i][0] = tokenizer.bos_token_id return result diff --git a/src/llmtuner/data/utils.py b/src/llmtuner/data/utils.py index c0b6d6c2..83ee0610 100644 --- a/src/llmtuner/data/utils.py +++ b/src/llmtuner/data/utils.py @@ -44,7 +44,7 @@ def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None: def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]: max_target_len = int(max_len * (target_len / (source_len + target_len))) max_target_len = max(max_target_len, reserved_label_len) - max_source_len = max_len - max_target_len + max_source_len = max_len - min(max_target_len, target_len) return max_source_len, max_target_len diff --git a/src/llmtuner/extras/callbacks.py b/src/llmtuner/extras/callbacks.py index 985b0292..6e347c3c 100644 --- a/src/llmtuner/extras/callbacks.py +++ b/src/llmtuner/extras/callbacks.py @@ -134,6 +134,7 @@ class LogCallback(TrainerCallback): eval_loss=state.log_history[-1].get("eval_loss", None), predict_loss=state.log_history[-1].get("predict_loss", None), reward=state.log_history[-1].get("reward", None), + accuracy=state.log_history[-1].get("rewards/accuracies", None), learning_rate=state.log_history[-1].get("learning_rate", None), epoch=state.log_history[-1].get("epoch", None), percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100, diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py index 12ba8b23..8af8d8e8 100644 --- a/src/llmtuner/extras/constants.py +++ b/src/llmtuner/extras/constants.py @@ -39,9 +39,12 @@ TRAINING_STAGES = { "Reward Modeling": "rm", "PPO": "ppo", "DPO": "dpo", + "ORPO": "orpo", "Pre-Training": "pt", } +STAGES_USE_PAIR_DATA = ["rm", "dpo", "orpo"] + V_HEAD_WEIGHTS_NAME = "value_head.bin" V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors" diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py index c1f08334..177a9f8a 100644 --- a/src/llmtuner/hparams/finetuning_args.py +++ b/src/llmtuner/hparams/finetuning_args.py @@ -110,6 +110,10 @@ class RLHFArguments: default=0.0, metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."}, ) + orpo_beta: float = field( + default=0.1, + metadata={"help": "The beta (lambda) parameter in ORPO loss representing the weight of the SFT loss."}, + ) ppo_buffer_size: int = field( default=1, metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."}, @@ -209,7 +213,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA default=False, metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."}, ) - stage: Literal["pt", "sft", "rm", "ppo", "dpo"] = field( + stage: Literal["pt", "sft", "rm", "ppo", "dpo", "orpo"] = field( default="sft", metadata={"help": "Which stage will be performed in training."}, ) diff --git a/src/llmtuner/train/dpo/trainer.py b/src/llmtuner/train/dpo/trainer.py index 39e84679..c7e385da 100644 --- a/src/llmtuner/train/dpo/trainer.py +++ b/src/llmtuner/train/dpo/trainer.py @@ -74,7 +74,7 @@ class CustomDPOTrainer(DPOTrainer): create_custom_scheduler(self.args, num_training_steps, optimizer) return super().create_scheduler(num_training_steps, optimizer) - def sft_loss(self, chosen_logits: torch.FloatTensor, chosen_labels: 
torch.LongTensor) -> torch.Tensor: + def sft_loss(self, chosen_logits: "torch.FloatTensor", chosen_labels: "torch.LongTensor") -> "torch.Tensor": r""" Computes supervised cross-entropy loss of given labels under the given logits. @@ -85,8 +85,8 @@ class CustomDPOTrainer(DPOTrainer): return -all_logps def concatenated_forward( - self, model: "PreTrainedModel", batch: Dict[str, torch.Tensor] - ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] + ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]: batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error all_logits = model( @@ -107,9 +107,9 @@ class CustomDPOTrainer(DPOTrainer): def get_batch_loss_metrics( self, model: "PreTrainedModel", - batch: Dict[str, torch.Tensor], + batch: Dict[str, "torch.Tensor"], train_eval: Literal["train", "eval"] = "train", - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + ) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]: r""" Computes the DPO loss and other metrics for the given batch of inputs for train or test. """ @@ -142,21 +142,22 @@ class CustomDPOTrainer(DPOTrainer): reference_chosen_logps, reference_rejected_logps, ) + batch_loss = losses.mean() if self.ftx_gamma > 1e-6: batch_size = batch["input_ids"].size(0) // 2 chosen_labels, _ = batch["labels"].split(batch_size, dim=0) - losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels) + batch_loss += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels).mean() reward_accuracies = (chosen_rewards > rejected_rewards).float() prefix = "eval_" if train_eval == "eval" else "" - metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean() - metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean() - metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean() - metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean() - metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean() - metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean() - metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean() - metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean() + metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.cpu().mean() + metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.cpu().mean() + metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.cpu().mean() + metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).cpu().mean() + metrics["{}logps/rejected".format(prefix)] = policy_rejected_logps.detach().cpu().mean() + metrics["{}logps/chosen".format(prefix)] = policy_chosen_logps.detach().cpu().mean() + metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().cpu().mean() + metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().cpu().mean() - return losses.mean(), metrics + return batch_loss, metrics diff --git a/src/llmtuner/train/dpo/workflow.py b/src/llmtuner/train/dpo/workflow.py index 851de982..4a1e867e 100644 --- a/src/llmtuner/train/dpo/workflow.py +++ b/src/llmtuner/train/dpo/workflow.py @@ -2,13 +2,12 @@ from typing import TYPE_CHECKING, List, Optional -from ...data import get_dataset, split_dataset +from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset from ...extras.constants import 
IGNORE_INDEX from ...extras.ploting import plot_loss from ...hparams import ModelArguments from ...model import load_model, load_tokenizer from ..utils import create_modelcard_and_push, create_ref_model -from .collator import DPODataCollatorWithPadding from .trainer import CustomDPOTrainer @@ -29,7 +28,7 @@ def run_dpo( dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm") model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) - data_collator = DPODataCollatorWithPadding( + data_collator = PairwiseDataCollatorWithPadding( tokenizer=tokenizer, pad_to_multiple_of=8, label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id, @@ -64,7 +63,7 @@ def run_dpo( trainer.save_metrics("train", train_result.metrics) trainer.save_state() if trainer.is_world_process_zero() and finetuning_args.plot_loss: - plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) + plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "accuracy"]) # Evaluation if training_args.do_eval: diff --git a/src/llmtuner/train/orpo/__init__.py b/src/llmtuner/train/orpo/__init__.py new file mode 100644 index 00000000..e79d5ea3 --- /dev/null +++ b/src/llmtuner/train/orpo/__init__.py @@ -0,0 +1,4 @@ +from .workflow import run_orpo + + +__all__ = ["run_orpo"] diff --git a/src/llmtuner/train/orpo/trainer.py b/src/llmtuner/train/orpo/trainer.py new file mode 100644 index 00000000..291351e4 --- /dev/null +++ b/src/llmtuner/train/orpo/trainer.py @@ -0,0 +1,150 @@ +from collections import defaultdict +from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from transformers import Trainer +from trl import DPOTrainer +from trl.trainer.utils import disable_dropout_in_model + +from ...extras.constants import IGNORE_INDEX +from ..utils import create_custom_optimzer, create_custom_scheduler + + +if TYPE_CHECKING: + from transformers import PreTrainedModel + + from ...hparams import FinetuningArguments + + +class CustomORPOTrainer(DPOTrainer): + def __init__( + self, + model: Union["PreTrainedModel", "torch.nn.Module"], + finetuning_args: "FinetuningArguments", + disable_dropout: bool = True, + **kwargs, + ): + if disable_dropout: + disable_dropout_in_model(model) + + self.finetuning_args = finetuning_args + self.reference_free = False + self.use_dpo_data_collator = True # hack to avoid warning + self.generate_during_eval = False # disable at evaluation + self.label_pad_token_id = IGNORE_INDEX + self.padding_value = 0 + self.is_encoder_decoder = model.config.is_encoder_decoder + self.precompute_ref_log_probs = False + self._precomputed_train_ref_log_probs = False + self._precomputed_eval_ref_log_probs = False + self._peft_has_been_casted_to_bf16 = False + + self.beta = finetuning_args.orpo_beta + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + + Trainer.__init__(self, model=model, **kwargs) + + def create_optimizer(self) -> "torch.optim.Optimizer": + if self.optimizer is None: + self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args) + return super().create_optimizer() + + def create_scheduler( + self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None + ) -> "torch.optim.lr_scheduler.LRScheduler": + create_custom_scheduler(self.args, num_training_steps, optimizer) + return super().create_scheduler(num_training_steps, optimizer) + + def sft_loss(self, chosen_logits: "torch.FloatTensor", chosen_labels: 
"torch.LongTensor") -> "torch.Tensor": + r""" + Computes supervised cross-entropy loss of given labels under the given logits. + + Returns: + A tensor of shape (batch_size,) containing the cross-entropy loss of each samples. + """ + all_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True) + return -all_logps + + # Borrowed from: + # https://github.com/huggingface/trl/blob/0ee349dcd43b0f4b3169449f16751c38ac4a609f/trl/trainer/orpo_trainer.py#L592 + def odds_ratio_loss( + self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor" + ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]: + r""" + Computes ORPO's odds ratio (OR) loss. + + Args: + policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + + Returns: + A tuple of five tensors: (losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen). + """ + + # Derived from Eqs. (4) and (7) from https://arxiv.org/abs/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x) + log_odds = (chosen_logps - rejected_logps) - ( + torch.log(1 - torch.exp(chosen_logps)) - torch.log(1 - torch.exp(rejected_logps)) + ) + ratio = F.logsigmoid(log_odds) + losses = self.beta * ratio + + chosen_rewards = self.beta * chosen_logps.detach() + rejected_rewards = self.beta * rejected_logps.detach() + + return losses, chosen_rewards, rejected_rewards, ratio, log_odds + + def concatenated_forward( + self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] + ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]: + all_logits = model( + input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True + ).logits.to(torch.float32) + + all_logps = self.get_batch_logps( + all_logits, + batch["labels"], + average_log_prob=False, + label_pad_token_id=self.label_pad_token_id, + ) + batch_size = batch["input_ids"].size(0) // 2 + chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0) + chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0) + return chosen_logps, rejected_logps, chosen_logits, rejected_logits + + def get_batch_loss_metrics( + self, + model: "PreTrainedModel", + batch: Dict[str, "torch.Tensor"], + train_eval: Literal["train", "eval"] = "train", + ) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]: + r""" + Computes the ORPO loss and other metrics for the given batch of inputs for train or test. 
+ """ + metrics = {} + chosen_logps, rejected_logps, chosen_logits, rejected_logits = self.concatenated_forward(model, batch) + + losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss( + chosen_logps, rejected_logps + ) + batch_size = batch["input_ids"].size(0) // 2 + chosen_labels, _ = batch["labels"].split(batch_size, dim=0) + sft_loss = self.sft_loss(chosen_logits, chosen_labels) + batch_loss = (sft_loss - losses).mean() + + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + prefix = "eval_" if train_eval == "eval" else "" + metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.cpu().mean() + metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.cpu().mean() + metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.cpu().mean() + metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).cpu().mean() + metrics["{}logps/rejected".format(prefix)] = rejected_logps.detach().cpu().mean() + metrics["{}logps/chosen".format(prefix)] = chosen_logps.detach().cpu().mean() + metrics["{}logits/rejected".format(prefix)] = rejected_logits.detach().cpu().mean() + metrics["{}logits/chosen".format(prefix)] = chosen_logits.detach().cpu().mean() + metrics["{}sft_loss".format(prefix)] = sft_loss.detach().cpu().mean() + metrics["{}log_odds_ratio".format(prefix)] = log_odds_ratio.detach().cpu().mean() + metrics["{}log_odds_chosen".format(prefix)] = log_odds_chosen.detach().cpu().mean() + + return batch_loss, metrics diff --git a/src/llmtuner/train/orpo/workflow.py b/src/llmtuner/train/orpo/workflow.py new file mode 100644 index 00000000..1d549d28 --- /dev/null +++ b/src/llmtuner/train/orpo/workflow.py @@ -0,0 +1,68 @@ +# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py + +from typing import TYPE_CHECKING, List, Optional + +from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset +from ...extras.constants import IGNORE_INDEX +from ...extras.ploting import plot_loss +from ...hparams import ModelArguments +from ...model import load_model, load_tokenizer +from ..utils import create_modelcard_and_push +from .trainer import CustomORPOTrainer + + +if TYPE_CHECKING: + from transformers import Seq2SeqTrainingArguments, TrainerCallback + + from ...hparams import DataArguments, FinetuningArguments + + +def run_orpo( + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", + callbacks: Optional[List["TrainerCallback"]] = None, +): + tokenizer = load_tokenizer(model_args) + dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm") + model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) + + data_collator = PairwiseDataCollatorWithPadding( + tokenizer=tokenizer, + pad_to_multiple_of=8, + label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id, + ) + + # Update arguments + training_args.remove_unused_columns = False # important for pairwise dataset + + # Initialize our Trainer + trainer = CustomORPOTrainer( + model=model, + args=training_args, + finetuning_args=finetuning_args, + tokenizer=tokenizer, + data_collator=data_collator, + callbacks=callbacks, + **split_dataset(dataset, data_args, training_args), + ) + + # Training + if training_args.do_train: + train_result = 
trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) + trainer.save_model() + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + if trainer.is_world_process_zero() and finetuning_args.plot_loss: + plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "accuracy"]) + + # Evaluation + if training_args.do_eval: + metrics = trainer.evaluate(metric_key_prefix="eval") + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Create model card + create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args) diff --git a/src/llmtuner/train/rm/workflow.py b/src/llmtuner/train/rm/workflow.py index dd4b8467..f260f82e 100644 --- a/src/llmtuner/train/rm/workflow.py +++ b/src/llmtuner/train/rm/workflow.py @@ -2,13 +2,12 @@ from typing import TYPE_CHECKING, List, Optional -from ...data import get_dataset, split_dataset +from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset from ...extras.callbacks import FixValueHeadModelCallback from ...extras.misc import fix_valuehead_checkpoint from ...extras.ploting import plot_loss from ...model import load_model, load_tokenizer from ..utils import create_modelcard_and_push -from .collator import PairwiseDataCollatorWithPadding from .metric import compute_accuracy from .trainer import PairwiseTrainer diff --git a/src/llmtuner/train/tuner.py b/src/llmtuner/train/tuner.py index 1b8e3cb7..299e4f2a 100644 --- a/src/llmtuner/train/tuner.py +++ b/src/llmtuner/train/tuner.py @@ -9,6 +9,7 @@ from ..extras.logging import get_logger from ..hparams import get_infer_args, get_train_args from ..model import load_model_and_tokenizer from .dpo import run_dpo +from .orpo import run_orpo from .ppo import run_ppo from .pt import run_pt from .rm import run_rm @@ -36,6 +37,8 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks) elif finetuning_args.stage == "dpo": run_dpo(model_args, data_args, training_args, finetuning_args, callbacks) + elif finetuning_args.stage == "orpo": + run_orpo(model_args, data_args, training_args, finetuning_args, callbacks) else: raise ValueError("Unknown task.") diff --git a/src/llmtuner/webui/common.py b/src/llmtuner/webui/common.py index 798e6408..96ef2737 100644 --- a/src/llmtuner/webui/common.py +++ b/src/llmtuner/webui/common.py @@ -11,6 +11,7 @@ from ..extras.constants import ( DEFAULT_MODULE, DEFAULT_TEMPLATE, PEFT_METHODS, + STAGES_USE_PAIR_DATA, SUPPORTED_MODELS, TRAINING_STAGES, DownloadSource, @@ -127,7 +128,7 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]: def list_dataset(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown": dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR) - ranking = TRAINING_STAGES[training_stage] in ["rm", "dpo"] + ranking = TRAINING_STAGES[training_stage] in STAGES_USE_PAIR_DATA datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking] return gr.Dropdown(value=[], choices=datasets)
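
The `data/README.md` changes above require `"ranking": true` on preference datasets, and `get_dataset` now raises a `ValueError` when a dataset's `ranking` flag does not match the training stage. Below is a minimal sketch of a matching `dataset_info.json` entry; the dataset name `my_pairs`, the file name `my_pairs.json`, and the `file_name` key for local files are assumptions for illustration, not part of this patch.

```python
import json

# Hypothetical dataset_info.json entry for a local preference dataset.
# "my_pairs" and "my_pairs.json" are placeholder names used only for illustration.
entry = {
    "my_pairs": {
        "file_name": "my_pairs.json",
        "ranking": True,  # without this flag, the rm/dpo/orpo stages now reject the dataset
    }
}
print(json.dumps(entry, indent=2))
```

Such a dataset would then be selected with `--dataset my_pairs`, analogous to `--dataset comparison_gpt4_en` in `orpo.sh`.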
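
The new `PairwiseDataCollatorWithPadding` turns each pairwise feature into two concatenated sequences, with all chosen examples first and all rejected examples last, which is why the trainers split batches with `batch["input_ids"].size(0) // 2`. A minimal sketch of that ordering follows; the `features` dicts use made-up token ids, not real tokenizer output.

```python
# Toy features shaped like the output of the "rm" preprocessing stage (values are made up).
features = [
    {"prompt_ids": [1, 2, 3], "chosen_ids": [4, 5], "rejected_ids": [6]},
    {"prompt_ids": [1, 9], "chosen_ids": [7, 8], "rejected_ids": [5, 4]},
]

concatenated = []
for key in ("chosen_ids", "rejected_ids"):  # chosen examples first, rejected examples last
    for feature in features:
        concatenated.append(feature["prompt_ids"] + feature[key])

print(concatenated)
# [[1, 2, 3, 4, 5], [1, 9, 7, 8], [1, 2, 3, 6], [1, 9, 5, 4]]
```

After padding, `_pad_labels` keeps labels only over the response positions and fills the prompt and padding positions with `label_pad_token_id`.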
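
`CustomORPOTrainer` combines the supervised loss on the chosen response with a beta-weighted odds-ratio term, so no reference model is needed. The following is a minimal standalone sketch of that objective, assuming per-sample log-probabilities are already available; `orpo_loss`, `chosen_nll`, and the toy tensors are illustrative names, not code from the patch.

```python
import torch
import torch.nn.functional as F


def orpo_loss(
    chosen_logps: torch.Tensor,    # log P(chosen | prompt) per sample, shape (batch,)
    rejected_logps: torch.Tensor,  # log P(rejected | prompt) per sample, shape (batch,)
    chosen_nll: torch.Tensor,      # SFT (cross-entropy) loss on the chosen response, shape (batch,)
    beta: float = 0.1,             # matches the default of the orpo_beta argument
) -> torch.Tensor:
    # Log odds ratio between chosen and rejected, with odds(y) = P(y | x) / (1 - P(y | x)).
    log_odds = (chosen_logps - rejected_logps) - (
        torch.log(1 - torch.exp(chosen_logps)) - torch.log(1 - torch.exp(rejected_logps))
    )
    # Total objective: L_SFT + beta * L_OR, where L_OR = -log sigmoid(log_odds).
    return (chosen_nll - beta * F.logsigmoid(log_odds)).mean()


# Toy check with made-up sequence probabilities.
chosen_logps = torch.log(torch.tensor([0.6, 0.7]))
rejected_logps = torch.log(torch.tensor([0.3, 0.2]))
chosen_nll = -chosen_logps  # stand-in for the token-averaged SFT loss
print(orpo_loss(chosen_logps, rejected_logps, chosen_nll))
```

This mirrors `odds_ratio_loss` plus the `sft_loss - losses` combination in `get_batch_loss_metrics`, where `losses` is `beta * logsigmoid(log_odds)`.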