mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	support ORPO
Former-commit-id: f44a4c27e2461cdaa1b16865f597a31033c0e6d9
This commit is contained in:
		
							parent
							
								
									526111a303
								
							
						
					
					
						commit
						d764cd8736
					
				@ -68,16 +68,18 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | 
			
		||||
 | 
			
		||||
## Changelog
 | 
			
		||||
 | 
			
		||||
[24/03/31] We supported **[ORPO](https://arxiv.org/abs/2403.07691)**. See `examples/lora_single_gpu` for usage.
 | 
			
		||||
 | 
			
		||||
[24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available at arXiv!
 | 
			
		||||
 | 
			
		||||
[24/03/20] We supported **FSDP+QLoRA** that fine-tunes a 70B model on 2x24GB GPUs. See `examples/fsdp_qlora` for usage.
 | 
			
		||||
 | 
			
		||||
<details><summary>Full Changelog</summary>
 | 
			
		||||
 | 
			
		||||
[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See `examples/extras/loraplus` for usage.
 | 
			
		||||
 | 
			
		||||
[24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. See `examples/extras/galore` for usage.
 | 
			
		||||
 | 
			
		||||
<details><summary>Full Changelog</summary>
 | 
			
		||||
 | 
			
		||||
[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported, merge it first.)
 | 
			
		||||
 | 
			
		||||
[24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `--use_dora` to activate DoRA training.
 | 
			
		||||
@ -165,6 +167,7 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
 | 
			
		||||
| Reward Modeling        | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 | 
			
		||||
| PPO Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 | 
			
		||||
| DPO Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 | 
			
		||||
| ORPO Training          | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 | 
			
		||||
 | 
			
		||||
> [!NOTE]
 | 
			
		||||
> Use `--quantization_bit 4` argument to enable QLoRA.
 | 
			
		||||
 | 
			
		||||
@ -68,16 +68,18 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 | 
			
		||||
 | 
			
		||||
## 更新日志
 | 
			
		||||
 | 
			
		||||
[24/03/31] 我们支持了 **[ORPO](https://arxiv.org/abs/2403.07691)**。详细用法请参照 `examples/lora_single_gpu`。
 | 
			
		||||
 | 
			
		||||
[24/03/21] 我们的论文 "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" 可在 arXiv 上查看!
 | 
			
		||||
 | 
			
		||||
[24/03/20] 我们支持了能在 2x24GB GPU 上微调 70B 模型的 **FSDP+QLoRA**。详细用法请参照 `examples/fsdp_qlora`。
 | 
			
		||||
 | 
			
		||||
<details><summary>展开日志</summary>
 | 
			
		||||
 | 
			
		||||
[24/03/13] 我们支持了 **[LoRA+](https://arxiv.org/abs/2402.12354)**。详细用法请参照 `examples/extras/loraplus`。
 | 
			
		||||
 | 
			
		||||
[24/03/07] 我们支持了梯度低秩投影(**[GaLore](https://arxiv.org/abs/2403.03507)**)算法。详细用法请参照 `examples/extras/galore`。
 | 
			
		||||
 | 
			
		||||
<details><summary>展开日志</summary>
 | 
			
		||||
 | 
			
		||||
[24/03/07] 我们集成了 **[vLLM](https://github.com/vllm-project/vllm)** 以实现极速并发推理。请使用 `--infer_backend vllm` 来获得 **270%** 的推理速度。(尚不支持 LoRA,请先合并权重。)
 | 
			
		||||
 | 
			
		||||
[24/02/28] 我们支持了 **[DoRA](https://arxiv.org/abs/2402.09353)** 微调。请使用 `--use_dora` 参数进行 DoRA 微调。
 | 
			
		||||
@ -165,6 +167,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 | 
			
		||||
| 奖励模型训练            | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 | 
			
		||||
| PPO 训练               | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 | 
			
		||||
| DPO 训练               | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 | 
			
		||||
| ORPO 训练              | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 | 
			
		||||
 | 
			
		||||
> [!NOTE]
 | 
			
		||||
> 请使用 `--quantization_bit 4` 参数来启用 QLoRA 训练。
 | 
			
		||||
 | 
			
		||||
@ -34,6 +34,8 @@ If you are using a custom dataset, please provide your dataset definition in the
 | 
			
		||||
 | 
			
		||||
Given above, you can use the custom dataset via specifying `--dataset dataset_name`.
 | 
			
		||||
 | 
			
		||||
----
 | 
			
		||||
 | 
			
		||||
Currently we support dataset in **alpaca** or **sharegpt** format, the dataset in alpaca format should follow the below format:
 | 
			
		||||
 | 
			
		||||
```json
 | 
			
		||||
@ -84,6 +86,10 @@ For the preference datasets, the `response` column should be a string list whose
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Remember to set `"ranking": true` for the preference datasets.
 | 
			
		||||
 | 
			
		||||
----
 | 
			
		||||
 | 
			
		||||
The dataset in sharegpt format should follow the below format:
 | 
			
		||||
 | 
			
		||||
```json
 | 
			
		||||
 | 
			
		||||
@ -34,6 +34,8 @@
 | 
			
		||||
 | 
			
		||||
添加后可通过指定 `--dataset 数据集名称` 参数使用自定义数据集。
 | 
			
		||||
 | 
			
		||||
----
 | 
			
		||||
 | 
			
		||||
该项目目前支持两种格式的数据集:**alpaca** 和 **sharegpt**,其中 alpaca 格式的数据集按照以下方式组织:
 | 
			
		||||
 | 
			
		||||
```json
 | 
			
		||||
@ -84,6 +86,10 @@
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
添加偏好数据集需要额外指定 `"ranking": true`。
 | 
			
		||||
 | 
			
		||||
----
 | 
			
		||||
 | 
			
		||||
而 sharegpt 格式的数据集按照以下方式组织:
 | 
			
		||||
 | 
			
		||||
```json
 | 
			
		||||
 | 
			
		||||
@ -1,8 +1,9 @@
 | 
			
		||||
Usage:
 | 
			
		||||
 | 
			
		||||
- `pretrain.sh`: do pre-train (optional)
 | 
			
		||||
- `sft.sh`: do supervised fine-tune
 | 
			
		||||
- `sft.sh`: do supervised fine-tuning
 | 
			
		||||
- `reward.sh`: do reward modeling (must after sft.sh)
 | 
			
		||||
- `ppo.sh`: do PPO training (must after sft.sh and reward.sh)
 | 
			
		||||
- `dpo.sh`: do DPO training (must after sft.sh)
 | 
			
		||||
- `orpo.sh`: do ORPO training
 | 
			
		||||
- `predict.sh`: do predict (must after sft.sh and dpo.sh)
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										32
									
								
								examples/lora_single_gpu/orpo.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								examples/lora_single_gpu/orpo.sh
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,32 @@
 | 
			
		||||
#!/bin/bash
 | 
			
		||||
 | 
			
		||||
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
 | 
			
		||||
    --stage orpo \
 | 
			
		||||
    --do_train \
 | 
			
		||||
    --model_name_or_path meta-llama/Llama-2-7b-hf \
 | 
			
		||||
    --dataset comparison_gpt4_en \
 | 
			
		||||
    --dataset_dir ../../data \
 | 
			
		||||
    --template default \
 | 
			
		||||
    --finetuning_type lora \
 | 
			
		||||
    --lora_target q_proj,v_proj \
 | 
			
		||||
    --output_dir ../../saves/LLaMA2-7B/lora/orpo \
 | 
			
		||||
    --overwrite_cache \
 | 
			
		||||
    --overwrite_output_dir \
 | 
			
		||||
    --cutoff_len 1024 \
 | 
			
		||||
    --preprocessing_num_workers 16 \
 | 
			
		||||
    --per_device_train_batch_size 1 \
 | 
			
		||||
    --per_device_eval_batch_size 1 \
 | 
			
		||||
    --gradient_accumulation_steps 8 \
 | 
			
		||||
    --lr_scheduler_type cosine \
 | 
			
		||||
    --logging_steps 10 \
 | 
			
		||||
    --warmup_steps 20 \
 | 
			
		||||
    --save_steps 100 \
 | 
			
		||||
    --eval_steps 100 \
 | 
			
		||||
    --evaluation_strategy steps \
 | 
			
		||||
    --load_best_model_at_end \
 | 
			
		||||
    --learning_rate 1e-5 \
 | 
			
		||||
    --num_train_epochs 1.0 \
 | 
			
		||||
    --max_samples 1000 \
 | 
			
		||||
    --val_size 0.1 \
 | 
			
		||||
    --plot_loss \
 | 
			
		||||
    --fp16
 | 
			
		||||
@ -1,6 +1,15 @@
 | 
			
		||||
from .collator import PairwiseDataCollatorWithPadding
 | 
			
		||||
from .loader import get_dataset
 | 
			
		||||
from .template import Template, get_template_and_fix_tokenizer, templates
 | 
			
		||||
from .utils import Role, split_dataset
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
__all__ = ["get_dataset", "Template", "get_template_and_fix_tokenizer", "templates", "Role", "split_dataset"]
 | 
			
		||||
__all__ = [
 | 
			
		||||
    "PairwiseDataCollatorWithPadding",
 | 
			
		||||
    "get_dataset",
 | 
			
		||||
    "Template",
 | 
			
		||||
    "get_template_and_fix_tokenizer",
 | 
			
		||||
    "templates",
 | 
			
		||||
    "Role",
 | 
			
		||||
    "split_dataset",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										51
									
								
								src/llmtuner/data/collator.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										51
									
								
								src/llmtuner/data/collator.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,51 @@
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from typing import Any, Dict, List, Sequence, Tuple
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from transformers import DataCollatorForSeq2Seq
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
 | 
			
		||||
    r"""
 | 
			
		||||
    Data collator for pairwise data.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
 | 
			
		||||
        r"""
 | 
			
		||||
        Masks out the input ids except for the responses.
 | 
			
		||||
        """
 | 
			
		||||
        padded_labels = []
 | 
			
		||||
        for feature, (prompt_len, answer_len) in zip(batch, positions):
 | 
			
		||||
            if self.tokenizer.padding_side == "left":
 | 
			
		||||
                start, end = feature.size(0) - answer_len, feature.size(0)
 | 
			
		||||
            else:
 | 
			
		||||
                start, end = prompt_len, prompt_len + answer_len
 | 
			
		||||
            padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
 | 
			
		||||
            padded_tensor[start:end] = feature[start:end]
 | 
			
		||||
            padded_labels.append(padded_tensor)
 | 
			
		||||
        return torch.stack(padded_labels, dim=0).contiguous()  # in contiguous memory
 | 
			
		||||
 | 
			
		||||
    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
 | 
			
		||||
        r"""
 | 
			
		||||
        Pads batched data to the longest sequence in the batch.
 | 
			
		||||
 | 
			
		||||
        We generate 2 * n examples where the first n examples represent chosen examples and
 | 
			
		||||
        the last n examples represent rejected examples.
 | 
			
		||||
        """
 | 
			
		||||
        concatenated_features = []
 | 
			
		||||
        label_positions = []
 | 
			
		||||
        for key in ("chosen_ids", "rejected_ids"):
 | 
			
		||||
            for feature in features:
 | 
			
		||||
                prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
 | 
			
		||||
                concatenated_features.append(
 | 
			
		||||
                    {
 | 
			
		||||
                        "input_ids": feature["prompt_ids"] + feature[key],
 | 
			
		||||
                        "attention_mask": [1] * (prompt_len + answer_len),
 | 
			
		||||
                    }
 | 
			
		||||
                )
 | 
			
		||||
                label_positions.append((prompt_len, answer_len))
 | 
			
		||||
 | 
			
		||||
        batch = super().__call__(concatenated_features)
 | 
			
		||||
        batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
 | 
			
		||||
        return batch
 | 
			
		||||
@ -117,7 +117,6 @@ def get_dataset(
 | 
			
		||||
    data_args: "DataArguments",
 | 
			
		||||
    training_args: "Seq2SeqTrainingArguments",
 | 
			
		||||
    stage: Literal["pt", "sft", "rm", "ppo"],
 | 
			
		||||
    # split: Optional[str] = "train", # TODO: add split
 | 
			
		||||
) -> Union["Dataset", "IterableDataset"]:
 | 
			
		||||
    template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
 | 
			
		||||
    if data_args.train_on_prompt and template.efficient_eos:
 | 
			
		||||
@ -138,6 +137,9 @@ def get_dataset(
 | 
			
		||||
    with training_args.main_process_first(desc="load dataset"):
 | 
			
		||||
        all_datasets = []
 | 
			
		||||
        for dataset_attr in get_dataset_list(data_args):
 | 
			
		||||
            if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
 | 
			
		||||
                raise ValueError("The dataset is not applicable in the current training stage.")
 | 
			
		||||
 | 
			
		||||
            all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
 | 
			
		||||
        dataset = merge_dataset(all_datasets, data_args, training_args)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -23,23 +23,25 @@ def preprocess_pretrain_dataset(
 | 
			
		||||
) -> Dict[str, List[List[int]]]:
 | 
			
		||||
    # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
 | 
			
		||||
    text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
 | 
			
		||||
    if not data_args.packing:
 | 
			
		||||
        return tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
 | 
			
		||||
 | 
			
		||||
    tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
 | 
			
		||||
    concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
 | 
			
		||||
    total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
 | 
			
		||||
    block_size = data_args.cutoff_len
 | 
			
		||||
    # we drop the small remainder, and if the total_length < block_size, we exclude this batch
 | 
			
		||||
    total_length = (total_length // block_size) * block_size
 | 
			
		||||
    # split by chunks of cutoff_len
 | 
			
		||||
    result = {
 | 
			
		||||
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
 | 
			
		||||
        for k, t in concatenated_examples.items()
 | 
			
		||||
    }
 | 
			
		||||
    if data_args.template == "gemma":
 | 
			
		||||
        for i in range(len(result["input_ids"])):
 | 
			
		||||
            result["input_ids"][i][0] = tokenizer.bos_token_id
 | 
			
		||||
    if not data_args.packing:
 | 
			
		||||
        if data_args.template == "gemma":
 | 
			
		||||
            text_examples = [tokenizer.bos_token + example for example in text_examples]
 | 
			
		||||
 | 
			
		||||
        result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
 | 
			
		||||
    else:
 | 
			
		||||
        tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
 | 
			
		||||
        concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
 | 
			
		||||
        total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
 | 
			
		||||
        block_size = data_args.cutoff_len
 | 
			
		||||
        total_length = (total_length // block_size) * block_size
 | 
			
		||||
        result = {
 | 
			
		||||
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
 | 
			
		||||
            for k, t in concatenated_examples.items()
 | 
			
		||||
        }
 | 
			
		||||
        if data_args.template == "gemma":
 | 
			
		||||
            for i in range(len(result["input_ids"])):
 | 
			
		||||
                result["input_ids"][i][0] = tokenizer.bos_token_id
 | 
			
		||||
 | 
			
		||||
    return result
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -44,7 +44,7 @@ def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
 | 
			
		||||
def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
 | 
			
		||||
    max_target_len = int(max_len * (target_len / (source_len + target_len)))
 | 
			
		||||
    max_target_len = max(max_target_len, reserved_label_len)
 | 
			
		||||
    max_source_len = max_len - max_target_len
 | 
			
		||||
    max_source_len = max_len - min(max_target_len, target_len)
 | 
			
		||||
    return max_source_len, max_target_len
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -134,6 +134,7 @@ class LogCallback(TrainerCallback):
 | 
			
		||||
            eval_loss=state.log_history[-1].get("eval_loss", None),
 | 
			
		||||
            predict_loss=state.log_history[-1].get("predict_loss", None),
 | 
			
		||||
            reward=state.log_history[-1].get("reward", None),
 | 
			
		||||
            accuracy=state.log_history[-1].get("rewards/accuracies", None),
 | 
			
		||||
            learning_rate=state.log_history[-1].get("learning_rate", None),
 | 
			
		||||
            epoch=state.log_history[-1].get("epoch", None),
 | 
			
		||||
            percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
 | 
			
		||||
 | 
			
		||||
@ -39,9 +39,12 @@ TRAINING_STAGES = {
 | 
			
		||||
    "Reward Modeling": "rm",
 | 
			
		||||
    "PPO": "ppo",
 | 
			
		||||
    "DPO": "dpo",
 | 
			
		||||
    "ORPO": "orpo",
 | 
			
		||||
    "Pre-Training": "pt",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
STAGES_USE_PAIR_DATA = ["rm", "dpo", "orpo"]
 | 
			
		||||
 | 
			
		||||
V_HEAD_WEIGHTS_NAME = "value_head.bin"
 | 
			
		||||
 | 
			
		||||
V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
 | 
			
		||||
 | 
			
		||||
@ -110,6 +110,10 @@ class RLHFArguments:
 | 
			
		||||
        default=0.0,
 | 
			
		||||
        metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
 | 
			
		||||
    )
 | 
			
		||||
    orpo_beta: float = field(
 | 
			
		||||
        default=0.1,
 | 
			
		||||
        metadata={"help": "The beta (lambda) parameter in ORPO loss representing the weight of the SFT loss."},
 | 
			
		||||
    )
 | 
			
		||||
    ppo_buffer_size: int = field(
 | 
			
		||||
        default=1,
 | 
			
		||||
        metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."},
 | 
			
		||||
@ -209,7 +213,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
 | 
			
		||||
        default=False,
 | 
			
		||||
        metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
 | 
			
		||||
    )
 | 
			
		||||
    stage: Literal["pt", "sft", "rm", "ppo", "dpo"] = field(
 | 
			
		||||
    stage: Literal["pt", "sft", "rm", "ppo", "dpo", "orpo"] = field(
 | 
			
		||||
        default="sft",
 | 
			
		||||
        metadata={"help": "Which stage will be performed in training."},
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
@ -74,7 +74,7 @@ class CustomDPOTrainer(DPOTrainer):
 | 
			
		||||
        create_custom_scheduler(self.args, num_training_steps, optimizer)
 | 
			
		||||
        return super().create_scheduler(num_training_steps, optimizer)
 | 
			
		||||
 | 
			
		||||
    def sft_loss(self, chosen_logits: torch.FloatTensor, chosen_labels: torch.LongTensor) -> torch.Tensor:
 | 
			
		||||
    def sft_loss(self, chosen_logits: "torch.FloatTensor", chosen_labels: "torch.LongTensor") -> "torch.Tensor":
 | 
			
		||||
        r"""
 | 
			
		||||
        Computes supervised cross-entropy loss of given labels under the given logits.
 | 
			
		||||
 | 
			
		||||
@ -85,8 +85,8 @@ class CustomDPOTrainer(DPOTrainer):
 | 
			
		||||
        return -all_logps
 | 
			
		||||
 | 
			
		||||
    def concatenated_forward(
 | 
			
		||||
        self, model: "PreTrainedModel", batch: Dict[str, torch.Tensor]
 | 
			
		||||
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
 | 
			
		||||
        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
 | 
			
		||||
    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
 | 
			
		||||
        batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()})  # avoid error
 | 
			
		||||
 | 
			
		||||
        all_logits = model(
 | 
			
		||||
@ -107,9 +107,9 @@ class CustomDPOTrainer(DPOTrainer):
 | 
			
		||||
    def get_batch_loss_metrics(
 | 
			
		||||
        self,
 | 
			
		||||
        model: "PreTrainedModel",
 | 
			
		||||
        batch: Dict[str, torch.Tensor],
 | 
			
		||||
        batch: Dict[str, "torch.Tensor"],
 | 
			
		||||
        train_eval: Literal["train", "eval"] = "train",
 | 
			
		||||
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
 | 
			
		||||
    ) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
 | 
			
		||||
        r"""
 | 
			
		||||
        Computes the DPO loss and other metrics for the given batch of inputs for train or test.
 | 
			
		||||
        """
 | 
			
		||||
@ -142,21 +142,22 @@ class CustomDPOTrainer(DPOTrainer):
 | 
			
		||||
            reference_chosen_logps,
 | 
			
		||||
            reference_rejected_logps,
 | 
			
		||||
        )
 | 
			
		||||
        batch_loss = losses.mean()
 | 
			
		||||
        if self.ftx_gamma > 1e-6:
 | 
			
		||||
            batch_size = batch["input_ids"].size(0) // 2
 | 
			
		||||
            chosen_labels, _ = batch["labels"].split(batch_size, dim=0)
 | 
			
		||||
            losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels)
 | 
			
		||||
            batch_loss += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels).mean()
 | 
			
		||||
 | 
			
		||||
        reward_accuracies = (chosen_rewards > rejected_rewards).float()
 | 
			
		||||
 | 
			
		||||
        prefix = "eval_" if train_eval == "eval" else ""
 | 
			
		||||
        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean()
 | 
			
		||||
        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean()
 | 
			
		||||
        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean()
 | 
			
		||||
        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean()
 | 
			
		||||
        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean()
 | 
			
		||||
        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean()
 | 
			
		||||
        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean()
 | 
			
		||||
        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean()
 | 
			
		||||
        metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.cpu().mean()
 | 
			
		||||
        metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.cpu().mean()
 | 
			
		||||
        metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.cpu().mean()
 | 
			
		||||
        metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).cpu().mean()
 | 
			
		||||
        metrics["{}logps/rejected".format(prefix)] = policy_rejected_logps.detach().cpu().mean()
 | 
			
		||||
        metrics["{}logps/chosen".format(prefix)] = policy_chosen_logps.detach().cpu().mean()
 | 
			
		||||
        metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().cpu().mean()
 | 
			
		||||
        metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().cpu().mean()
 | 
			
		||||
 | 
			
		||||
        return losses.mean(), metrics
 | 
			
		||||
        return batch_loss, metrics
 | 
			
		||||
 | 
			
		||||
@ -2,13 +2,12 @@
 | 
			
		||||
 | 
			
		||||
from typing import TYPE_CHECKING, List, Optional
 | 
			
		||||
 | 
			
		||||
from ...data import get_dataset, split_dataset
 | 
			
		||||
from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset
 | 
			
		||||
from ...extras.constants import IGNORE_INDEX
 | 
			
		||||
from ...extras.ploting import plot_loss
 | 
			
		||||
from ...hparams import ModelArguments
 | 
			
		||||
from ...model import load_model, load_tokenizer
 | 
			
		||||
from ..utils import create_modelcard_and_push, create_ref_model
 | 
			
		||||
from .collator import DPODataCollatorWithPadding
 | 
			
		||||
from .trainer import CustomDPOTrainer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -29,7 +28,7 @@ def run_dpo(
 | 
			
		||||
    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
 | 
			
		||||
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
 | 
			
		||||
 | 
			
		||||
    data_collator = DPODataCollatorWithPadding(
 | 
			
		||||
    data_collator = PairwiseDataCollatorWithPadding(
 | 
			
		||||
        tokenizer=tokenizer,
 | 
			
		||||
        pad_to_multiple_of=8,
 | 
			
		||||
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
 | 
			
		||||
@ -64,7 +63,7 @@ def run_dpo(
 | 
			
		||||
        trainer.save_metrics("train", train_result.metrics)
 | 
			
		||||
        trainer.save_state()
 | 
			
		||||
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
 | 
			
		||||
            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
 | 
			
		||||
            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "accuracy"])
 | 
			
		||||
 | 
			
		||||
    # Evaluation
 | 
			
		||||
    if training_args.do_eval:
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										4
									
								
								src/llmtuner/train/orpo/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								src/llmtuner/train/orpo/__init__.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,4 @@
 | 
			
		||||
from .workflow import run_orpo
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
__all__ = ["run_orpo"]
 | 
			
		||||
							
								
								
									
										150
									
								
								src/llmtuner/train/orpo/trainer.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										150
									
								
								src/llmtuner/train/orpo/trainer.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,150 @@
 | 
			
		||||
from collections import defaultdict
 | 
			
		||||
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from transformers import Trainer
 | 
			
		||||
from trl import DPOTrainer
 | 
			
		||||
from trl.trainer.utils import disable_dropout_in_model
 | 
			
		||||
 | 
			
		||||
from ...extras.constants import IGNORE_INDEX
 | 
			
		||||
from ..utils import create_custom_optimzer, create_custom_scheduler
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from transformers import PreTrainedModel
 | 
			
		||||
 | 
			
		||||
    from ...hparams import FinetuningArguments
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CustomORPOTrainer(DPOTrainer):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        model: Union["PreTrainedModel", "torch.nn.Module"],
 | 
			
		||||
        finetuning_args: "FinetuningArguments",
 | 
			
		||||
        disable_dropout: bool = True,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ):
 | 
			
		||||
        if disable_dropout:
 | 
			
		||||
            disable_dropout_in_model(model)
 | 
			
		||||
 | 
			
		||||
        self.finetuning_args = finetuning_args
 | 
			
		||||
        self.reference_free = False
 | 
			
		||||
        self.use_dpo_data_collator = True  # hack to avoid warning
 | 
			
		||||
        self.generate_during_eval = False  # disable at evaluation
 | 
			
		||||
        self.label_pad_token_id = IGNORE_INDEX
 | 
			
		||||
        self.padding_value = 0
 | 
			
		||||
        self.is_encoder_decoder = model.config.is_encoder_decoder
 | 
			
		||||
        self.precompute_ref_log_probs = False
 | 
			
		||||
        self._precomputed_train_ref_log_probs = False
 | 
			
		||||
        self._precomputed_eval_ref_log_probs = False
 | 
			
		||||
        self._peft_has_been_casted_to_bf16 = False
 | 
			
		||||
 | 
			
		||||
        self.beta = finetuning_args.orpo_beta
 | 
			
		||||
        self._stored_metrics = defaultdict(lambda: defaultdict(list))
 | 
			
		||||
 | 
			
		||||
        Trainer.__init__(self, model=model, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def create_optimizer(self) -> "torch.optim.Optimizer":
 | 
			
		||||
        if self.optimizer is None:
 | 
			
		||||
            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
 | 
			
		||||
        return super().create_optimizer()
 | 
			
		||||
 | 
			
		||||
    def create_scheduler(
 | 
			
		||||
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
 | 
			
		||||
    ) -> "torch.optim.lr_scheduler.LRScheduler":
 | 
			
		||||
        create_custom_scheduler(self.args, num_training_steps, optimizer)
 | 
			
		||||
        return super().create_scheduler(num_training_steps, optimizer)
 | 
			
		||||
 | 
			
		||||
    def sft_loss(self, chosen_logits: "torch.FloatTensor", chosen_labels: "torch.LongTensor") -> "torch.Tensor":
 | 
			
		||||
        r"""
 | 
			
		||||
        Computes supervised cross-entropy loss of given labels under the given logits.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            A tensor of shape (batch_size,) containing the cross-entropy loss of each samples.
 | 
			
		||||
        """
 | 
			
		||||
        all_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True)
 | 
			
		||||
        return -all_logps
 | 
			
		||||
 | 
			
		||||
    # Borrowed from:
 | 
			
		||||
    # https://github.com/huggingface/trl/blob/0ee349dcd43b0f4b3169449f16751c38ac4a609f/trl/trainer/orpo_trainer.py#L592
 | 
			
		||||
    def odds_ratio_loss(
 | 
			
		||||
        self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor"
 | 
			
		||||
    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
 | 
			
		||||
        r"""
 | 
			
		||||
        Computes ORPO's odds ratio (OR) loss.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
 | 
			
		||||
            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            A tuple of five tensors: (losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen).
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        # Derived from Eqs. (4) and (7) from https://arxiv.org/abs/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x)
 | 
			
		||||
        log_odds = (chosen_logps - rejected_logps) - (
 | 
			
		||||
            torch.log(1 - torch.exp(chosen_logps)) - torch.log(1 - torch.exp(rejected_logps))
 | 
			
		||||
        )
 | 
			
		||||
        ratio = F.logsigmoid(log_odds)
 | 
			
		||||
        losses = self.beta * ratio
 | 
			
		||||
 | 
			
		||||
        chosen_rewards = self.beta * chosen_logps.detach()
 | 
			
		||||
        rejected_rewards = self.beta * rejected_logps.detach()
 | 
			
		||||
 | 
			
		||||
        return losses, chosen_rewards, rejected_rewards, ratio, log_odds
 | 
			
		||||
 | 
			
		||||
    def concatenated_forward(
 | 
			
		||||
        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
 | 
			
		||||
    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
 | 
			
		||||
        all_logits = model(
 | 
			
		||||
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True
 | 
			
		||||
        ).logits.to(torch.float32)
 | 
			
		||||
 | 
			
		||||
        all_logps = self.get_batch_logps(
 | 
			
		||||
            all_logits,
 | 
			
		||||
            batch["labels"],
 | 
			
		||||
            average_log_prob=False,
 | 
			
		||||
            label_pad_token_id=self.label_pad_token_id,
 | 
			
		||||
        )
 | 
			
		||||
        batch_size = batch["input_ids"].size(0) // 2
 | 
			
		||||
        chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
 | 
			
		||||
        chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
 | 
			
		||||
        return chosen_logps, rejected_logps, chosen_logits, rejected_logits
 | 
			
		||||
 | 
			
		||||
    def get_batch_loss_metrics(
 | 
			
		||||
        self,
 | 
			
		||||
        model: "PreTrainedModel",
 | 
			
		||||
        batch: Dict[str, "torch.Tensor"],
 | 
			
		||||
        train_eval: Literal["train", "eval"] = "train",
 | 
			
		||||
    ) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
 | 
			
		||||
        r"""
 | 
			
		||||
        Computes the ORPO loss and other metrics for the given batch of inputs for train or test.
 | 
			
		||||
        """
 | 
			
		||||
        metrics = {}
 | 
			
		||||
        chosen_logps, rejected_logps, chosen_logits, rejected_logits = self.concatenated_forward(model, batch)
 | 
			
		||||
 | 
			
		||||
        losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
 | 
			
		||||
            chosen_logps, rejected_logps
 | 
			
		||||
        )
 | 
			
		||||
        batch_size = batch["input_ids"].size(0) // 2
 | 
			
		||||
        chosen_labels, _ = batch["labels"].split(batch_size, dim=0)
 | 
			
		||||
        sft_loss = self.sft_loss(chosen_logits, chosen_labels)
 | 
			
		||||
        batch_loss = (sft_loss - losses).mean()
 | 
			
		||||
 | 
			
		||||
        reward_accuracies = (chosen_rewards > rejected_rewards).float()
 | 
			
		||||
 | 
			
		||||
        prefix = "eval_" if train_eval == "eval" else ""
 | 
			
		||||
        metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.cpu().mean()
 | 
			
		||||
        metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.cpu().mean()
 | 
			
		||||
        metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.cpu().mean()
 | 
			
		||||
        metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).cpu().mean()
 | 
			
		||||
        metrics["{}logps/rejected".format(prefix)] = rejected_logps.detach().cpu().mean()
 | 
			
		||||
        metrics["{}logps/chosen".format(prefix)] = chosen_logps.detach().cpu().mean()
 | 
			
		||||
        metrics["{}logits/rejected".format(prefix)] = rejected_logits.detach().cpu().mean()
 | 
			
		||||
        metrics["{}logits/chosen".format(prefix)] = chosen_logits.detach().cpu().mean()
 | 
			
		||||
        metrics["{}sft_loss".format(prefix)] = sft_loss.detach().cpu().mean()
 | 
			
		||||
        metrics["{}log_odds_ratio".format(prefix)] = log_odds_ratio.detach().cpu().mean()
 | 
			
		||||
        metrics["{}log_odds_chosen".format(prefix)] = log_odds_chosen.detach().cpu().mean()
 | 
			
		||||
 | 
			
		||||
        return batch_loss, metrics
 | 
			
		||||
							
								
								
									
										68
									
								
								src/llmtuner/train/orpo/workflow.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										68
									
								
								src/llmtuner/train/orpo/workflow.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,68 @@
 | 
			
		||||
# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
 | 
			
		||||
 | 
			
		||||
from typing import TYPE_CHECKING, List, Optional
 | 
			
		||||
 | 
			
		||||
from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset
 | 
			
		||||
from ...extras.constants import IGNORE_INDEX
 | 
			
		||||
from ...extras.ploting import plot_loss
 | 
			
		||||
from ...hparams import ModelArguments
 | 
			
		||||
from ...model import load_model, load_tokenizer
 | 
			
		||||
from ..utils import create_modelcard_and_push
 | 
			
		||||
from .trainer import CustomORPOTrainer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from transformers import Seq2SeqTrainingArguments, TrainerCallback
 | 
			
		||||
 | 
			
		||||
    from ...hparams import DataArguments, FinetuningArguments
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_orpo(
 | 
			
		||||
    model_args: "ModelArguments",
 | 
			
		||||
    data_args: "DataArguments",
 | 
			
		||||
    training_args: "Seq2SeqTrainingArguments",
 | 
			
		||||
    finetuning_args: "FinetuningArguments",
 | 
			
		||||
    callbacks: Optional[List["TrainerCallback"]] = None,
 | 
			
		||||
):
 | 
			
		||||
    tokenizer = load_tokenizer(model_args)
 | 
			
		||||
    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm")
 | 
			
		||||
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
 | 
			
		||||
 | 
			
		||||
    data_collator = PairwiseDataCollatorWithPadding(
 | 
			
		||||
        tokenizer=tokenizer,
 | 
			
		||||
        pad_to_multiple_of=8,
 | 
			
		||||
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Update arguments
 | 
			
		||||
    training_args.remove_unused_columns = False  # important for pairwise dataset
 | 
			
		||||
 | 
			
		||||
    # Initialize our Trainer
 | 
			
		||||
    trainer = CustomORPOTrainer(
 | 
			
		||||
        model=model,
 | 
			
		||||
        args=training_args,
 | 
			
		||||
        finetuning_args=finetuning_args,
 | 
			
		||||
        tokenizer=tokenizer,
 | 
			
		||||
        data_collator=data_collator,
 | 
			
		||||
        callbacks=callbacks,
 | 
			
		||||
        **split_dataset(dataset, data_args, training_args),
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Training
 | 
			
		||||
    if training_args.do_train:
 | 
			
		||||
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
 | 
			
		||||
        trainer.save_model()
 | 
			
		||||
        trainer.log_metrics("train", train_result.metrics)
 | 
			
		||||
        trainer.save_metrics("train", train_result.metrics)
 | 
			
		||||
        trainer.save_state()
 | 
			
		||||
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
 | 
			
		||||
            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "accuracy"])
 | 
			
		||||
 | 
			
		||||
    # Evaluation
 | 
			
		||||
    if training_args.do_eval:
 | 
			
		||||
        metrics = trainer.evaluate(metric_key_prefix="eval")
 | 
			
		||||
        trainer.log_metrics("eval", metrics)
 | 
			
		||||
        trainer.save_metrics("eval", metrics)
 | 
			
		||||
 | 
			
		||||
    # Create model card
 | 
			
		||||
    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
 | 
			
		||||
@ -2,13 +2,12 @@
 | 
			
		||||
 | 
			
		||||
from typing import TYPE_CHECKING, List, Optional
 | 
			
		||||
 | 
			
		||||
from ...data import get_dataset, split_dataset
 | 
			
		||||
from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset
 | 
			
		||||
from ...extras.callbacks import FixValueHeadModelCallback
 | 
			
		||||
from ...extras.misc import fix_valuehead_checkpoint
 | 
			
		||||
from ...extras.ploting import plot_loss
 | 
			
		||||
from ...model import load_model, load_tokenizer
 | 
			
		||||
from ..utils import create_modelcard_and_push
 | 
			
		||||
from .collator import PairwiseDataCollatorWithPadding
 | 
			
		||||
from .metric import compute_accuracy
 | 
			
		||||
from .trainer import PairwiseTrainer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -9,6 +9,7 @@ from ..extras.logging import get_logger
 | 
			
		||||
from ..hparams import get_infer_args, get_train_args
 | 
			
		||||
from ..model import load_model_and_tokenizer
 | 
			
		||||
from .dpo import run_dpo
 | 
			
		||||
from .orpo import run_orpo
 | 
			
		||||
from .ppo import run_ppo
 | 
			
		||||
from .pt import run_pt
 | 
			
		||||
from .rm import run_rm
 | 
			
		||||
@ -36,6 +37,8 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra
 | 
			
		||||
        run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
 | 
			
		||||
    elif finetuning_args.stage == "dpo":
 | 
			
		||||
        run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
 | 
			
		||||
    elif finetuning_args.stage == "orpo":
 | 
			
		||||
        run_orpo(model_args, data_args, training_args, finetuning_args, callbacks)
 | 
			
		||||
    else:
 | 
			
		||||
        raise ValueError("Unknown task.")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -11,6 +11,7 @@ from ..extras.constants import (
 | 
			
		||||
    DEFAULT_MODULE,
 | 
			
		||||
    DEFAULT_TEMPLATE,
 | 
			
		||||
    PEFT_METHODS,
 | 
			
		||||
    STAGES_USE_PAIR_DATA,
 | 
			
		||||
    SUPPORTED_MODELS,
 | 
			
		||||
    TRAINING_STAGES,
 | 
			
		||||
    DownloadSource,
 | 
			
		||||
@ -127,7 +128,7 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
 | 
			
		||||
 | 
			
		||||
def list_dataset(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
 | 
			
		||||
    dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
 | 
			
		||||
    ranking = TRAINING_STAGES[training_stage] in ["rm", "dpo"]
 | 
			
		||||
    ranking = TRAINING_STAGES[training_stage] in STAGES_USE_PAIR_DATA
 | 
			
		||||
    datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
 | 
			
		||||
    return gr.Dropdown(value=[], choices=datasets)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user