mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	Merge branch 'main' into add_dataset_sample_num
Former-commit-id: 26300127c45f24e63b91f1b0cc73e46c3a936a91
This commit is contained in:
		
						commit
						a3b52fd380
					
				
							
								
								
									
										17
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								README.md
									
									
									
									
									
								
							@ -69,12 +69,12 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | 
			
		||||
 | 
			
		||||
## Changelog
 | 
			
		||||
 | 
			
		||||
[24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models, you need to fine-tune them with `gemma` template for chat completion.
 | 
			
		||||
 | 
			
		||||
[24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage.
 | 
			
		||||
 | 
			
		||||
[24/05/14] We supported training and inference on the Ascend NPU devices. Check [installation](#installation) section for details.
 | 
			
		||||
 | 
			
		||||
[24/05/13] We supported fine-tuning the **Yi-1.5** series models.
 | 
			
		||||
 | 
			
		||||
<details><summary>Full Changelog</summary>
 | 
			
		||||
 | 
			
		||||
[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage.
 | 
			
		||||
@ -160,6 +160,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | 
			
		||||
| [LLaVA-1.5](https://huggingface.co/llava-hf)             | 7B/13B                           | q_proj,v_proj     | vicuna    |
 | 
			
		||||
| [Mistral/Mixtral](https://huggingface.co/mistralai)      | 7B/8x7B/8x22B                    | q_proj,v_proj     | mistral   |
 | 
			
		||||
| [OLMo](https://huggingface.co/allenai)                   | 1B/7B                            | q_proj,v_proj     | -         |
 | 
			
		||||
| [PaliGemma](https://huggingface.co/google)               | 3B                               | q_proj,v_proj     | gemma     |
 | 
			
		||||
| [Phi-1.5/2](https://huggingface.co/microsoft)            | 1.3B/2.7B                        | q_proj,v_proj     | -         |
 | 
			
		||||
| [Phi-3](https://huggingface.co/microsoft)                | 3.8B                             | qkv_proj          | phi       |
 | 
			
		||||
| [Qwen](https://huggingface.co/Qwen)                      | 1.8B/7B/14B/72B                  | c_attn            | qwen      |
 | 
			
		||||
@ -284,11 +285,11 @@ huggingface-cli login
 | 
			
		||||
| ------------ | ------- | --------- |
 | 
			
		||||
| python       | 3.8     | 3.10      |
 | 
			
		||||
| torch        | 1.13.1  | 2.2.0     |
 | 
			
		||||
| transformers | 4.37.2  | 4.40.1    |
 | 
			
		||||
| transformers | 4.37.2  | 4.41.0    |
 | 
			
		||||
| datasets     | 2.14.3  | 2.19.1    |
 | 
			
		||||
| accelerate   | 0.27.2  | 0.30.0    |
 | 
			
		||||
| peft         | 0.9.0   | 0.10.0    |
 | 
			
		||||
| trl          | 0.8.1   | 0.8.6     |
 | 
			
		||||
| accelerate   | 0.27.2  | 0.30.1    |
 | 
			
		||||
| peft         | 0.9.0   | 0.11.1    |
 | 
			
		||||
| trl          | 0.8.2   | 0.8.6     |
 | 
			
		||||
 | 
			
		||||
| Optional     | Minimum | Recommend |
 | 
			
		||||
| ------------ | ------- | --------- |
 | 
			
		||||
@ -344,6 +345,8 @@ To enable FlashAttention-2 on the Windows platform, you need to install the prec
 | 
			
		||||
 | 
			
		||||
<details><summary>For Ascend NPU users</summary>
 | 
			
		||||
 | 
			
		||||
Join [NPU user group](assets/wechat_npu.jpg).
 | 
			
		||||
 | 
			
		||||
To utilize Ascend NPU devices for (distributed) training and inference, you need to install the **[torch-npu](https://gitee.com/ascend/pytorch)** library and the **[Ascend CANN Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**.
 | 
			
		||||
 | 
			
		||||
| Requirement  | Minimum | Recommend |
 | 
			
		||||
@ -356,7 +359,7 @@ To utilize Ascend NPU devices for (distributed) training and inference, you need
 | 
			
		||||
Docker image:
 | 
			
		||||
 | 
			
		||||
- 32GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html)
 | 
			
		||||
- 64GB: Coming soon
 | 
			
		||||
- 64GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
 | 
			
		||||
 | 
			
		||||
Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										17
									
								
								README_zh.md
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								README_zh.md
									
									
									
									
									
								
							@ -69,12 +69,12 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 | 
			
		||||
 | 
			
		||||
## 更新日志
 | 
			
		||||
 | 
			
		||||
[24/05/20] 我们支持了 **PaliGemma** 系列模型的微调。注意 PaliGemma 是预训练模型,你需要使用 `gemma` 模板进行微调使其获得对话能力。
 | 
			
		||||
 | 
			
		||||
[24/05/18] 我们支持了 **[KTO](https://arxiv.org/abs/2402.01306)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
 | 
			
		||||
 | 
			
		||||
[24/05/14] 我们支持了昇腾 NPU 设备的训练和推理。详情请查阅[安装](#安装-llama-factory)部分。
 | 
			
		||||
 | 
			
		||||
[24/05/13] 我们支持了 Yi-1.5 系列模型的微调。
 | 
			
		||||
 | 
			
		||||
<details><summary>展开日志</summary>
 | 
			
		||||
 | 
			
		||||
[24/04/26] 我们支持了多模态模型 **LLaVA-1.5** 的微调。详细用法请参照 [examples](examples/README_zh.md)。
 | 
			
		||||
@ -160,6 +160,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 | 
			
		||||
| [LLaVA-1.5](https://huggingface.co/llava-hf)             | 7B/13B                           | q_proj,v_proj     | vicuna    |
 | 
			
		||||
| [Mistral/Mixtral](https://huggingface.co/mistralai)      | 7B/8x7B/8x22B                    | q_proj,v_proj     | mistral   |
 | 
			
		||||
| [OLMo](https://huggingface.co/allenai)                   | 1B/7B                            | q_proj,v_proj     | -         |
 | 
			
		||||
| [PaliGemma](https://huggingface.co/google)               | 3B                               | q_proj,v_proj     | gemma     |
 | 
			
		||||
| [Phi-1.5/2](https://huggingface.co/microsoft)            | 1.3B/2.7B                        | q_proj,v_proj     | -         |
 | 
			
		||||
| [Phi-3](https://huggingface.co/microsoft)                | 3.8B                             | qkv_proj          | phi       |
 | 
			
		||||
| [Qwen](https://huggingface.co/Qwen)                      | 1.8B/7B/14B/72B                  | c_attn            | qwen      |
 | 
			
		||||
@ -284,11 +285,11 @@ huggingface-cli login
 | 
			
		||||
| ------------ | ------- | --------- |
 | 
			
		||||
| python       | 3.8     | 3.10      |
 | 
			
		||||
| torch        | 1.13.1  | 2.2.0     |
 | 
			
		||||
| transformers | 4.37.2  | 4.40.1    |
 | 
			
		||||
| transformers | 4.37.2  | 4.41.0    |
 | 
			
		||||
| datasets     | 2.14.3  | 2.19.1    |
 | 
			
		||||
| accelerate   | 0.27.2  | 0.30.0    |
 | 
			
		||||
| peft         | 0.9.0   | 0.10.0    |
 | 
			
		||||
| trl          | 0.8.1   | 0.8.6     |
 | 
			
		||||
| accelerate   | 0.27.2  | 0.30.1    |
 | 
			
		||||
| peft         | 0.9.0   | 0.11.1    |
 | 
			
		||||
| trl          | 0.8.2   | 0.8.6     |
 | 
			
		||||
 | 
			
		||||
| 可选项       | 至少     | 推荐      |
 | 
			
		||||
| ------------ | ------- | --------- |
 | 
			
		||||
@ -344,6 +345,8 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
 | 
			
		||||
 | 
			
		||||
<details><summary>昇腾 NPU 用户指南</summary>
 | 
			
		||||
 | 
			
		||||
加入 [NPU 用户群](assets/wechat_npu.jpg)。
 | 
			
		||||
 | 
			
		||||
如果使用昇腾 NPU 设备进行(分布式)训练或推理,需要安装 **[torch-npu](https://gitee.com/ascend/pytorch)** 库和 **[Ascend CANN Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**。
 | 
			
		||||
 | 
			
		||||
| 依赖项       | 至少     | 推荐      |
 | 
			
		||||
@ -356,7 +359,7 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
 | 
			
		||||
Docker 镜像:
 | 
			
		||||
 | 
			
		||||
- 32GB:[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html)
 | 
			
		||||
- 64GB:敬请期待
 | 
			
		||||
- 64GB:[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
 | 
			
		||||
 | 
			
		||||
请记得使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定您使用的设备。
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -7,7 +7,7 @@
 | 
			
		||||
  "hf_hub_url": "Hugging Face 的数据集仓库地址(若指定,则忽略 script_url 和 file_name)",
 | 
			
		||||
  "ms_hub_url": "ModelScope 的数据集仓库地址(若指定,则忽略 script_url 和 file_name)",
 | 
			
		||||
  "script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略 file_name)",
 | 
			
		||||
  "file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
 | 
			
		||||
  "file_name": "该目录下数据集文件夹或文件的名称(若上述参数未指定,则此项必需)",
 | 
			
		||||
  "formatting": "数据集格式(可选,默认:alpaca,可以为 alpaca 或 sharegpt)",
 | 
			
		||||
  "ranking": "是否为偏好数据集(可选,默认:False)",
 | 
			
		||||
  "subset": "数据集子集的名称(可选,默认:None)",
 | 
			
		||||
 | 
			
		||||
@ -34,7 +34,8 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
 | 
			
		||||
        features = datasets.Features(
 | 
			
		||||
            {
 | 
			
		||||
                "instruction": datasets.Value("string"),
 | 
			
		||||
                "output": datasets.Sequence(datasets.Value("string")),
 | 
			
		||||
                "chosen": datasets.Value("string"),
 | 
			
		||||
                "rejected": datasets.Value("string"),
 | 
			
		||||
                "history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
@ -8,6 +8,7 @@ import torch
 | 
			
		||||
from transformers import GenerationConfig, TextIteratorStreamer
 | 
			
		||||
 | 
			
		||||
from ..data import get_template_and_fix_tokenizer
 | 
			
		||||
from ..extras.constants import IMAGE_TOKEN
 | 
			
		||||
from ..extras.misc import get_logits_processor
 | 
			
		||||
from ..model import load_model, load_tokenizer
 | 
			
		||||
from .base_engine import BaseEngine, Response
 | 
			
		||||
@ -55,14 +56,28 @@ class HuggingfaceEngine(BaseEngine):
 | 
			
		||||
        image: Optional["NDArray"] = None,
 | 
			
		||||
        input_kwargs: Optional[Dict[str, Any]] = {},
 | 
			
		||||
    ) -> Tuple[Dict[str, Any], int]:
 | 
			
		||||
        if processor is not None and image is not None and "<image>" not in messages[0]["content"]:
 | 
			
		||||
            messages[0]["content"] = "<image>" + messages[0]["content"]
 | 
			
		||||
        if (
 | 
			
		||||
            processor is not None
 | 
			
		||||
            and image is not None
 | 
			
		||||
            and not hasattr(processor, "image_seq_length")
 | 
			
		||||
            and IMAGE_TOKEN not in messages[0]["content"]
 | 
			
		||||
        ):  # llava-like models
 | 
			
		||||
            messages[0]["content"] = IMAGE_TOKEN + messages[0]["content"]
 | 
			
		||||
 | 
			
		||||
        paired_messages = messages + [{"role": "assistant", "content": ""}]
 | 
			
		||||
        system = system or generating_args["default_system"]
 | 
			
		||||
        pixel_values = None
 | 
			
		||||
        prompt_ids, _ = template.encode_oneturn(
 | 
			
		||||
            tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
 | 
			
		||||
        )
 | 
			
		||||
        if processor is not None and image is not None:  # add image features
 | 
			
		||||
            image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
 | 
			
		||||
            batch_feature = image_processor(image, return_tensors="pt")
 | 
			
		||||
            pixel_values = batch_feature.to(model.device)["pixel_values"]  # shape (B, C, H, W)
 | 
			
		||||
            if hasattr(processor, "image_seq_length"):  # paligemma models
 | 
			
		||||
                image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
 | 
			
		||||
                prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
 | 
			
		||||
 | 
			
		||||
        prompt_length = len(prompt_ids)
 | 
			
		||||
        inputs = torch.tensor([prompt_ids], device=model.device)
 | 
			
		||||
 | 
			
		||||
@ -122,10 +137,8 @@ class HuggingfaceEngine(BaseEngine):
 | 
			
		||||
            logits_processor=get_logits_processor(),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if processor is not None and image is not None:
 | 
			
		||||
            image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
 | 
			
		||||
            pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
 | 
			
		||||
            gen_kwargs["pixel_values"] = pixel_values.to(model.device)
 | 
			
		||||
        if pixel_values is not None:
 | 
			
		||||
            gen_kwargs["pixel_values"] = pixel_values
 | 
			
		||||
 | 
			
		||||
        return gen_kwargs, prompt_length
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -2,6 +2,7 @@ import uuid
 | 
			
		||||
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
 | 
			
		||||
 | 
			
		||||
from ..data import get_template_and_fix_tokenizer
 | 
			
		||||
from ..extras.constants import IMAGE_TOKEN
 | 
			
		||||
from ..extras.logging import get_logger
 | 
			
		||||
from ..extras.misc import get_device_count, infer_optim_dtype
 | 
			
		||||
from ..extras.packages import is_vllm_available
 | 
			
		||||
@ -17,7 +18,6 @@ if is_vllm_available():
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    import torch
 | 
			
		||||
    from numpy.typing import NDArray
 | 
			
		||||
    from transformers.image_processing_utils import BaseImageProcessor
 | 
			
		||||
 | 
			
		||||
@ -67,7 +67,7 @@ class VllmEngine(BaseEngine):
 | 
			
		||||
            patch_size = config.vision_config.patch_size
 | 
			
		||||
            self.image_feature_size = (image_size // patch_size) ** 2
 | 
			
		||||
            engine_args["image_input_type"] = "pixel_values"
 | 
			
		||||
            engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids("<image>")
 | 
			
		||||
            engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
 | 
			
		||||
            engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
 | 
			
		||||
            engine_args["image_feature_size"] = self.image_feature_size
 | 
			
		||||
            if getattr(config, "is_yi_vl_derived_model", None):
 | 
			
		||||
@ -92,14 +92,28 @@ class VllmEngine(BaseEngine):
 | 
			
		||||
        **input_kwargs,
 | 
			
		||||
    ) -> AsyncIterator["RequestOutput"]:
 | 
			
		||||
        request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
 | 
			
		||||
        if self.processor is not None and image is not None and "<image>" not in messages[0]["content"]:
 | 
			
		||||
            messages[0]["content"] = "<image>" * self.image_feature_size + messages[0]["content"]
 | 
			
		||||
 | 
			
		||||
        if (
 | 
			
		||||
            self.processor is not None
 | 
			
		||||
            and image is not None
 | 
			
		||||
            and not hasattr(self.processor, "image_seq_length")
 | 
			
		||||
            and IMAGE_TOKEN not in messages[0]["content"]
 | 
			
		||||
        ):  # llava-like models
 | 
			
		||||
            messages[0]["content"] = IMAGE_TOKEN * self.image_feature_size + messages[0]["content"]
 | 
			
		||||
 | 
			
		||||
        paired_messages = messages + [{"role": "assistant", "content": ""}]
 | 
			
		||||
        system = system or self.generating_args["default_system"]
 | 
			
		||||
        prompt_ids, _ = self.template.encode_oneturn(
 | 
			
		||||
            tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if self.processor is not None and image is not None:  # add image features
 | 
			
		||||
            image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
 | 
			
		||||
            pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
 | 
			
		||||
            multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
 | 
			
		||||
        else:
 | 
			
		||||
            multi_modal_data = None
 | 
			
		||||
 | 
			
		||||
        prompt_length = len(prompt_ids)
 | 
			
		||||
 | 
			
		||||
        use_beam_search: bool = self.generating_args["num_beams"] > 1
 | 
			
		||||
@ -144,13 +158,6 @@ class VllmEngine(BaseEngine):
 | 
			
		||||
            skip_special_tokens=True,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if self.processor is not None and image is not None:
 | 
			
		||||
            image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
 | 
			
		||||
            pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
 | 
			
		||||
            multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
 | 
			
		||||
        else:
 | 
			
		||||
            multi_modal_data = None
 | 
			
		||||
 | 
			
		||||
        result_generator = self.model.generate(
 | 
			
		||||
            prompt=None,
 | 
			
		||||
            sampling_params=sampling_params,
 | 
			
		||||
 | 
			
		||||
@ -1,5 +1,5 @@
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from typing import Any, Dict, List, Sequence, Tuple
 | 
			
		||||
from typing import Any, Dict, Sequence
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from transformers import DataCollatorForSeq2Seq
 | 
			
		||||
@ -11,21 +11,6 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
 | 
			
		||||
    Data collator for pairwise data.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
 | 
			
		||||
        r"""
 | 
			
		||||
        Masks out the input ids except for the responses.
 | 
			
		||||
        """
 | 
			
		||||
        padded_labels = []
 | 
			
		||||
        for feature, (prompt_len, answer_len) in zip(batch, positions):
 | 
			
		||||
            if self.tokenizer.padding_side == "left":
 | 
			
		||||
                start, end = feature.size(0) - answer_len, feature.size(0)
 | 
			
		||||
            else:
 | 
			
		||||
                start, end = prompt_len, prompt_len + answer_len
 | 
			
		||||
            padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
 | 
			
		||||
            padded_tensor[start:end] = feature[start:end]
 | 
			
		||||
            padded_labels.append(padded_tensor)
 | 
			
		||||
        return torch.stack(padded_labels, dim=0).contiguous()  # in contiguous memory
 | 
			
		||||
 | 
			
		||||
    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
 | 
			
		||||
        r"""
 | 
			
		||||
        Pads batched data to the longest sequence in the batch.
 | 
			
		||||
@ -34,21 +19,22 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
 | 
			
		||||
        the last n examples represent rejected examples.
 | 
			
		||||
        """
 | 
			
		||||
        concatenated_features = []
 | 
			
		||||
        label_positions = []
 | 
			
		||||
        for key in ("chosen_ids", "rejected_ids"):
 | 
			
		||||
        for key in ("chosen", "rejected"):
 | 
			
		||||
            for feature in features:
 | 
			
		||||
                prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
 | 
			
		||||
                concatenated_features.append(
 | 
			
		||||
                    {
 | 
			
		||||
                        "input_ids": feature["prompt_ids"] + feature[key],
 | 
			
		||||
                        "attention_mask": [1] * (prompt_len + answer_len),
 | 
			
		||||
                    }
 | 
			
		||||
                )
 | 
			
		||||
                label_positions.append((prompt_len, answer_len))
 | 
			
		||||
                target_feature = {
 | 
			
		||||
                    "input_ids": feature["{}_input_ids".format(key)],
 | 
			
		||||
                    "attention_mask": feature["{}_attention_mask".format(key)],
 | 
			
		||||
                    "labels": feature["{}_labels".format(key)],
 | 
			
		||||
                }
 | 
			
		||||
                if "pixel_values" in feature:
 | 
			
		||||
                    target_feature["pixel_values"] = feature["pixel_values"]
 | 
			
		||||
 | 
			
		||||
        batch = super().__call__(concatenated_features)
 | 
			
		||||
        batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
 | 
			
		||||
        return batch
 | 
			
		||||
                if "{}_token_type_ids".format(key) in feature:
 | 
			
		||||
                    target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
 | 
			
		||||
 | 
			
		||||
                concatenated_features.append(target_feature)
 | 
			
		||||
 | 
			
		||||
        return super().__call__(concatenated_features)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
@ -62,20 +48,25 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
 | 
			
		||||
        kl_features = []
 | 
			
		||||
        kto_tags = []
 | 
			
		||||
        for feature in features:
 | 
			
		||||
            target_features.append(
 | 
			
		||||
                {
 | 
			
		||||
                    "input_ids": feature["input_ids"],
 | 
			
		||||
                    "attention_mask": feature["attention_mask"],
 | 
			
		||||
                    "labels": feature["labels"],
 | 
			
		||||
                }
 | 
			
		||||
            )
 | 
			
		||||
            kl_features.append(
 | 
			
		||||
                {
 | 
			
		||||
                    "input_ids": feature["kl_input_ids"],
 | 
			
		||||
                    "attention_mask": feature["kl_attention_mask"],
 | 
			
		||||
                    "labels": feature["kl_labels"],
 | 
			
		||||
                }
 | 
			
		||||
            )
 | 
			
		||||
            target_feature = {
 | 
			
		||||
                "input_ids": feature["input_ids"],
 | 
			
		||||
                "attention_mask": feature["attention_mask"],
 | 
			
		||||
                "labels": feature["labels"],
 | 
			
		||||
            }
 | 
			
		||||
            kl_feature = {
 | 
			
		||||
                "input_ids": feature["kl_input_ids"],
 | 
			
		||||
                "attention_mask": feature["kl_attention_mask"],
 | 
			
		||||
                "labels": feature["kl_labels"],
 | 
			
		||||
            }
 | 
			
		||||
            if "pixel_values" in feature:
 | 
			
		||||
                target_feature["pixel_values"] = feature["pixel_values"]
 | 
			
		||||
 | 
			
		||||
            if "token_type_ids" in feature:
 | 
			
		||||
                target_feature["token_type_ids"] = feature["token_type_ids"]
 | 
			
		||||
                kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
 | 
			
		||||
 | 
			
		||||
            target_features.append(target_feature)
 | 
			
		||||
            kl_features.append(kl_feature)
 | 
			
		||||
            kto_tags.append(feature["kto_tags"])
 | 
			
		||||
 | 
			
		||||
        batch = super().__call__(target_features)
 | 
			
		||||
@ -83,5 +74,8 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
 | 
			
		||||
        batch["kl_input_ids"] = kl_batch["input_ids"]
 | 
			
		||||
        batch["kl_attention_mask"] = kl_batch["attention_mask"]
 | 
			
		||||
        batch["kl_labels"] = kl_batch["labels"]
 | 
			
		||||
        if "token_type_ids" in batch:
 | 
			
		||||
            batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
 | 
			
		||||
 | 
			
		||||
        batch["kto_tags"] = torch.tensor(kto_tags)
 | 
			
		||||
        return batch
 | 
			
		||||
 | 
			
		||||
@ -2,6 +2,7 @@ import inspect
 | 
			
		||||
import os
 | 
			
		||||
import numpy as np
 | 
			
		||||
from numpy.random import RandomState
 | 
			
		||||
import sys
 | 
			
		||||
from typing import TYPE_CHECKING, Literal, Optional, Union
 | 
			
		||||
 | 
			
		||||
from datasets import load_dataset, load_from_disk
 | 
			
		||||
@ -180,12 +181,15 @@ def get_dataset(
 | 
			
		||||
                logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
 | 
			
		||||
                logger.info("Please restart the training with `--tokenized_path {}`.".format(data_args.tokenized_path))
 | 
			
		||||
 | 
			
		||||
            exit(0)
 | 
			
		||||
            sys.exit(0)
 | 
			
		||||
 | 
			
		||||
        if training_args.should_log:
 | 
			
		||||
            try:
 | 
			
		||||
                print_function(next(iter(dataset)))
 | 
			
		||||
            except StopIteration:
 | 
			
		||||
                raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
 | 
			
		||||
                if stage == "pt":
 | 
			
		||||
                    raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
 | 
			
		||||
                else:
 | 
			
		||||
                    raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
 | 
			
		||||
 | 
			
		||||
        return dataset
 | 
			
		||||
 | 
			
		||||
@ -1,380 +1,25 @@
 | 
			
		||||
from functools import partial
 | 
			
		||||
from itertools import chain
 | 
			
		||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple
 | 
			
		||||
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
 | 
			
		||||
 | 
			
		||||
from ..extras.constants import IGNORE_INDEX
 | 
			
		||||
from ..extras.logging import get_logger
 | 
			
		||||
from ..extras.packages import is_pillow_available
 | 
			
		||||
from .utils import Role
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if is_pillow_available():
 | 
			
		||||
    from PIL import Image
 | 
			
		||||
from .processors.feedback import preprocess_feedback_dataset
 | 
			
		||||
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
 | 
			
		||||
from .processors.pretrain import preprocess_pretrain_dataset
 | 
			
		||||
from .processors.supervised import (
 | 
			
		||||
    preprocess_packed_supervised_dataset,
 | 
			
		||||
    preprocess_supervised_dataset,
 | 
			
		||||
    print_supervised_dataset_example,
 | 
			
		||||
)
 | 
			
		||||
from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from numpy.typing import NDArray
 | 
			
		||||
    from PIL.Image import Image as ImageObject
 | 
			
		||||
    from transformers import ProcessorMixin, Seq2SeqTrainingArguments
 | 
			
		||||
    from transformers.image_processing_utils import BaseImageProcessor
 | 
			
		||||
    from transformers.tokenization_utils import PreTrainedTokenizer
 | 
			
		||||
 | 
			
		||||
    from ..hparams import DataArguments
 | 
			
		||||
    from .template import Template
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _preprocess_visual_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
 | 
			
		||||
    # process visual inputs (currently only supports a single image)
 | 
			
		||||
    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
 | 
			
		||||
    image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
 | 
			
		||||
    return image_processor(image, return_tensors="pt")["pixel_values"][0]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def preprocess_pretrain_dataset(
 | 
			
		||||
    examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
 | 
			
		||||
) -> Dict[str, List[List[int]]]:
 | 
			
		||||
    # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
 | 
			
		||||
    text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
 | 
			
		||||
 | 
			
		||||
    if not data_args.packing:
 | 
			
		||||
        if data_args.template == "gemma":
 | 
			
		||||
            text_examples = [tokenizer.bos_token + example for example in text_examples]
 | 
			
		||||
 | 
			
		||||
        result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
 | 
			
		||||
    else:
 | 
			
		||||
        tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
 | 
			
		||||
        concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
 | 
			
		||||
        total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
 | 
			
		||||
        block_size = data_args.cutoff_len
 | 
			
		||||
        total_length = (total_length // block_size) * block_size
 | 
			
		||||
        result = {
 | 
			
		||||
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
 | 
			
		||||
            for k, t in concatenated_examples.items()
 | 
			
		||||
        }
 | 
			
		||||
        if data_args.template == "gemma":
 | 
			
		||||
            for i in range(len(result["input_ids"])):
 | 
			
		||||
                result["input_ids"][i][0] = tokenizer.bos_token_id
 | 
			
		||||
 | 
			
		||||
    return result
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def preprocess_supervised_dataset(
 | 
			
		||||
    examples: Dict[str, List[Any]],
 | 
			
		||||
    template: "Template",
 | 
			
		||||
    tokenizer: "PreTrainedTokenizer",
 | 
			
		||||
    processor: Optional["ProcessorMixin"],
 | 
			
		||||
    data_args: "DataArguments",
 | 
			
		||||
) -> Dict[str, List[List[int]]]:
 | 
			
		||||
    # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
 | 
			
		||||
    # for multiturn examples, we only mask the prompt part in each prompt-response pair.
 | 
			
		||||
    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
 | 
			
		||||
    if processor is not None:
 | 
			
		||||
        model_inputs["pixel_values"] = []
 | 
			
		||||
        preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
 | 
			
		||||
 | 
			
		||||
    for i in range(len(examples["prompt"])):
 | 
			
		||||
        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
 | 
			
		||||
            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        if processor is not None:
 | 
			
		||||
            examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
 | 
			
		||||
 | 
			
		||||
        messages = examples["prompt"][i] + examples["response"][i]
 | 
			
		||||
        input_ids, labels = [], []
 | 
			
		||||
        for turn_idx, (source_ids, target_ids) in enumerate(
 | 
			
		||||
            template.encode_multiturn(
 | 
			
		||||
                tokenizer,
 | 
			
		||||
                messages,
 | 
			
		||||
                examples["system"][i],
 | 
			
		||||
                examples["tools"][i],
 | 
			
		||||
                data_args.cutoff_len,
 | 
			
		||||
                data_args.reserved_label_len,
 | 
			
		||||
            )
 | 
			
		||||
        ):
 | 
			
		||||
            if data_args.train_on_prompt:
 | 
			
		||||
                source_mask = source_ids
 | 
			
		||||
            elif turn_idx != 0 and template.efficient_eos:
 | 
			
		||||
                source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
 | 
			
		||||
            else:
 | 
			
		||||
                source_mask = [IGNORE_INDEX] * len(source_ids)
 | 
			
		||||
 | 
			
		||||
            input_ids += source_ids + target_ids
 | 
			
		||||
            labels += source_mask + target_ids
 | 
			
		||||
 | 
			
		||||
        if template.efficient_eos:
 | 
			
		||||
            input_ids += [tokenizer.eos_token_id]
 | 
			
		||||
            labels += [tokenizer.eos_token_id]
 | 
			
		||||
 | 
			
		||||
        model_inputs["input_ids"].append(input_ids)
 | 
			
		||||
        model_inputs["attention_mask"].append([1] * len(input_ids))
 | 
			
		||||
        model_inputs["labels"].append(labels)
 | 
			
		||||
        if processor is not None:
 | 
			
		||||
            model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
 | 
			
		||||
 | 
			
		||||
    return model_inputs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def preprocess_packed_supervised_dataset(
 | 
			
		||||
    examples: Dict[str, List[Any]],
 | 
			
		||||
    template: "Template",
 | 
			
		||||
    tokenizer: "PreTrainedTokenizer",
 | 
			
		||||
    data_args: "DataArguments",
 | 
			
		||||
) -> Dict[str, List[List[int]]]:
 | 
			
		||||
    # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
 | 
			
		||||
    # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
 | 
			
		||||
    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
 | 
			
		||||
    input_ids, labels = [], []
 | 
			
		||||
    for i in range(len(examples["prompt"])):
 | 
			
		||||
        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
 | 
			
		||||
            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        messages = examples["prompt"][i] + examples["response"][i]
 | 
			
		||||
        for source_ids, target_ids in template.encode_multiturn(
 | 
			
		||||
            tokenizer, messages, examples["system"][i], examples["tools"][i]
 | 
			
		||||
        ):
 | 
			
		||||
            if data_args.train_on_prompt:
 | 
			
		||||
                source_mask = source_ids
 | 
			
		||||
            elif len(input_ids) != 0 and template.efficient_eos:
 | 
			
		||||
                source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
 | 
			
		||||
            else:
 | 
			
		||||
                source_mask = [IGNORE_INDEX] * len(source_ids)
 | 
			
		||||
 | 
			
		||||
            input_ids += source_ids + target_ids
 | 
			
		||||
            labels += source_mask + target_ids
 | 
			
		||||
 | 
			
		||||
    if template.efficient_eos:
 | 
			
		||||
        input_ids += [tokenizer.eos_token_id]
 | 
			
		||||
        labels += [tokenizer.eos_token_id]
 | 
			
		||||
 | 
			
		||||
    total_length = len(input_ids)
 | 
			
		||||
    block_size = data_args.cutoff_len
 | 
			
		||||
    # we drop the small remainder, and if the total_length < block_size, we exclude this batch
 | 
			
		||||
    total_length = (total_length // block_size) * block_size
 | 
			
		||||
    # split by chunks of cutoff_len
 | 
			
		||||
    for i in range(0, total_length, block_size):
 | 
			
		||||
        if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]):
 | 
			
		||||
            model_inputs["input_ids"].append(input_ids[i : i + block_size])
 | 
			
		||||
            model_inputs["attention_mask"].append([1] * block_size)
 | 
			
		||||
            model_inputs["labels"].append(labels[i : i + block_size])
 | 
			
		||||
 | 
			
		||||
    return model_inputs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def preprocess_unsupervised_dataset(
 | 
			
		||||
    examples: Dict[str, List[Any]],
 | 
			
		||||
    template: "Template",
 | 
			
		||||
    tokenizer: "PreTrainedTokenizer",
 | 
			
		||||
    processor: Optional["ProcessorMixin"],
 | 
			
		||||
    data_args: "DataArguments",
 | 
			
		||||
) -> Dict[str, List[List[int]]]:
 | 
			
		||||
    # build inputs with format `<bos> X` and labels with format `Y <eos>`
 | 
			
		||||
    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
 | 
			
		||||
    if processor is not None:
 | 
			
		||||
        model_inputs["pixel_values"] = []
 | 
			
		||||
        preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
 | 
			
		||||
 | 
			
		||||
    for i in range(len(examples["prompt"])):
 | 
			
		||||
        if len(examples["prompt"][i]) % 2 != 1:
 | 
			
		||||
            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        if processor is not None:
 | 
			
		||||
            examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
 | 
			
		||||
 | 
			
		||||
        if len(examples["response"][i]) == 1:
 | 
			
		||||
            messages = examples["prompt"][i] + examples["response"][i]
 | 
			
		||||
        else:
 | 
			
		||||
            messages = examples["prompt"][i] + [{"role": Role.ASSISTANT.value, "content": ""}]
 | 
			
		||||
 | 
			
		||||
        input_ids, labels = template.encode_oneturn(
 | 
			
		||||
            tokenizer,
 | 
			
		||||
            messages,
 | 
			
		||||
            examples["system"][i],
 | 
			
		||||
            examples["tools"][i],
 | 
			
		||||
            data_args.cutoff_len,
 | 
			
		||||
            data_args.reserved_label_len,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if template.efficient_eos:
 | 
			
		||||
            labels += [tokenizer.eos_token_id]
 | 
			
		||||
 | 
			
		||||
        model_inputs["input_ids"].append(input_ids)
 | 
			
		||||
        model_inputs["attention_mask"].append([1] * len(input_ids))
 | 
			
		||||
        model_inputs["labels"].append(labels)
 | 
			
		||||
        if processor is not None:
 | 
			
		||||
            model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
 | 
			
		||||
 | 
			
		||||
    return model_inputs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def preprocess_pairwise_dataset(
 | 
			
		||||
    examples: Dict[str, List[Any]],
 | 
			
		||||
    template: "Template",
 | 
			
		||||
    tokenizer: "PreTrainedTokenizer",
 | 
			
		||||
    processor: Optional["ProcessorMixin"],
 | 
			
		||||
    data_args: "DataArguments",
 | 
			
		||||
) -> Dict[str, List[List[int]]]:
 | 
			
		||||
    # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
 | 
			
		||||
    model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
 | 
			
		||||
    if processor is not None:
 | 
			
		||||
        model_inputs["pixel_values"] = []
 | 
			
		||||
        preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
 | 
			
		||||
 | 
			
		||||
    for i in range(len(examples["prompt"])):
 | 
			
		||||
        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
 | 
			
		||||
            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        if processor is not None:
 | 
			
		||||
            examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
 | 
			
		||||
 | 
			
		||||
        chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
 | 
			
		||||
        rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
 | 
			
		||||
        prompt_ids, chosen_ids = template.encode_oneturn(
 | 
			
		||||
            tokenizer,
 | 
			
		||||
            chosen_messages,
 | 
			
		||||
            examples["system"][i],
 | 
			
		||||
            examples["tools"][i],
 | 
			
		||||
            data_args.cutoff_len,
 | 
			
		||||
            data_args.reserved_label_len,
 | 
			
		||||
        )
 | 
			
		||||
        _, rejected_ids = template.encode_oneturn(
 | 
			
		||||
            tokenizer,
 | 
			
		||||
            rejected_messages,
 | 
			
		||||
            examples["system"][i],
 | 
			
		||||
            examples["tools"][i],
 | 
			
		||||
            data_args.cutoff_len,
 | 
			
		||||
            data_args.reserved_label_len,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if template.efficient_eos:
 | 
			
		||||
            chosen_ids += [tokenizer.eos_token_id]
 | 
			
		||||
            rejected_ids += [tokenizer.eos_token_id]
 | 
			
		||||
 | 
			
		||||
        model_inputs["prompt_ids"].append(prompt_ids)
 | 
			
		||||
        model_inputs["chosen_ids"].append(chosen_ids)
 | 
			
		||||
        model_inputs["rejected_ids"].append(rejected_ids)
 | 
			
		||||
        if processor is not None:
 | 
			
		||||
            model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
 | 
			
		||||
 | 
			
		||||
    return model_inputs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def preprocess_kto_dataset(
 | 
			
		||||
    examples: Dict[str, List[Any]],
 | 
			
		||||
    template: "Template",
 | 
			
		||||
    tokenizer: "PreTrainedTokenizer",
 | 
			
		||||
    processor: Optional["ProcessorMixin"],
 | 
			
		||||
    data_args: "DataArguments",
 | 
			
		||||
) -> Dict[str, List[List[int]]]:
 | 
			
		||||
    # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
 | 
			
		||||
    kl_response = examples["response"][::-1]
 | 
			
		||||
    model_inputs = {
 | 
			
		||||
        "input_ids": [],
 | 
			
		||||
        "attention_mask": [],
 | 
			
		||||
        "labels": [],
 | 
			
		||||
        "kl_input_ids": [],
 | 
			
		||||
        "kl_attention_mask": [],
 | 
			
		||||
        "kl_labels": [],
 | 
			
		||||
        "kto_tags": [],
 | 
			
		||||
    }
 | 
			
		||||
    if processor is not None:
 | 
			
		||||
        model_inputs["pixel_values"] = []
 | 
			
		||||
        preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
 | 
			
		||||
 | 
			
		||||
    for i in range(len(examples["prompt"])):
 | 
			
		||||
        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
 | 
			
		||||
            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        if processor is not None:
 | 
			
		||||
            examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
 | 
			
		||||
 | 
			
		||||
        if examples["response"][i][0]["content"]:  # desired example
 | 
			
		||||
            kto_tag = True
 | 
			
		||||
            messages = examples["prompt"][i] + [examples["response"][i][0]]
 | 
			
		||||
        else:  # undesired example
 | 
			
		||||
            kto_tag = False
 | 
			
		||||
            messages = examples["prompt"][i] + [examples["response"][i][1]]
 | 
			
		||||
 | 
			
		||||
        if kl_response[i][0]["content"]:
 | 
			
		||||
            kl_messages = examples["prompt"][i] + [kl_response[i][0]]
 | 
			
		||||
        else:
 | 
			
		||||
            kl_messages = examples["prompt"][i] + [kl_response[i][1]]
 | 
			
		||||
 | 
			
		||||
        prompt_ids, response_ids = template.encode_oneturn(
 | 
			
		||||
            tokenizer,
 | 
			
		||||
            messages,
 | 
			
		||||
            examples["system"][i],
 | 
			
		||||
            examples["tools"][i],
 | 
			
		||||
            data_args.cutoff_len,
 | 
			
		||||
            data_args.reserved_label_len,
 | 
			
		||||
        )
 | 
			
		||||
        _, kl_response_ids = template.encode_oneturn(
 | 
			
		||||
            tokenizer,
 | 
			
		||||
            kl_messages,
 | 
			
		||||
            examples["system"][i],
 | 
			
		||||
            examples["tools"][i],
 | 
			
		||||
            data_args.cutoff_len,
 | 
			
		||||
            data_args.reserved_label_len,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if template.efficient_eos:
 | 
			
		||||
            response_ids += [tokenizer.eos_token_id]
 | 
			
		||||
            kl_response_ids += [tokenizer.eos_token_id]
 | 
			
		||||
 | 
			
		||||
        input_ids = prompt_ids + response_ids
 | 
			
		||||
        labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
 | 
			
		||||
        kl_input_ids = prompt_ids + kl_response_ids
 | 
			
		||||
        kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids
 | 
			
		||||
        model_inputs["input_ids"].append(input_ids)
 | 
			
		||||
        model_inputs["attention_mask"].append([1] * len(input_ids))
 | 
			
		||||
        model_inputs["labels"].append(labels)
 | 
			
		||||
        model_inputs["kl_input_ids"].append(kl_input_ids)
 | 
			
		||||
        model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
 | 
			
		||||
        model_inputs["kl_labels"].append(kl_labels)
 | 
			
		||||
        model_inputs["kto_tags"].append(kto_tag)
 | 
			
		||||
        if processor is not None:
 | 
			
		||||
            model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
 | 
			
		||||
 | 
			
		||||
    desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
 | 
			
		||||
    undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
 | 
			
		||||
    if desirable_num == 0 or undesirable_num == 0:
 | 
			
		||||
        logger.warning("Your dataset only has one preference type.")
 | 
			
		||||
 | 
			
		||||
    return model_inputs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
 | 
			
		||||
    print("input_ids:\n{}".format(example["input_ids"]))
 | 
			
		||||
    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
 | 
			
		||||
    print("label_ids:\n{}".format(example["labels"]))
 | 
			
		||||
    print(
 | 
			
		||||
        "labels:\n{}".format(
 | 
			
		||||
            tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
 | 
			
		||||
    print("prompt_ids:\n{}".format(example["prompt_ids"]))
 | 
			
		||||
    print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
 | 
			
		||||
    print("chosen_ids:\n{}".format(example["chosen_ids"]))
 | 
			
		||||
    print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
 | 
			
		||||
    print("rejected_ids:\n{}".format(example["rejected_ids"]))
 | 
			
		||||
    print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
 | 
			
		||||
    print("input_ids:\n{}".format(example["input_ids"]))
 | 
			
		||||
    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_preprocess_and_print_func(
 | 
			
		||||
    data_args: "DataArguments",
 | 
			
		||||
    training_args: "Seq2SeqTrainingArguments",
 | 
			
		||||
@ -419,7 +64,7 @@ def get_preprocess_and_print_func(
 | 
			
		||||
        print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
 | 
			
		||||
    elif stage == "kto":
 | 
			
		||||
        preprocess_func = partial(
 | 
			
		||||
            preprocess_kto_dataset,
 | 
			
		||||
            preprocess_feedback_dataset,
 | 
			
		||||
            template=template,
 | 
			
		||||
            tokenizer=tokenizer,
 | 
			
		||||
            processor=processor,
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										0
									
								
								src/llamafactory/data/processors/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								src/llamafactory/data/processors/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										110
									
								
								src/llamafactory/data/processors/feedback.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										110
									
								
								src/llamafactory/data/processors/feedback.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,110 @@
 | 
			
		||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
 | 
			
		||||
 | 
			
		||||
from ...extras.constants import IGNORE_INDEX, IMAGE_TOKEN
 | 
			
		||||
from ...extras.logging import get_logger
 | 
			
		||||
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from transformers import ProcessorMixin
 | 
			
		||||
    from transformers.tokenization_utils import PreTrainedTokenizer
 | 
			
		||||
 | 
			
		||||
    from ...hparams import DataArguments
 | 
			
		||||
    from ..template import Template
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def preprocess_feedback_dataset(
 | 
			
		||||
    examples: Dict[str, List[Any]],
 | 
			
		||||
    template: "Template",
 | 
			
		||||
    tokenizer: "PreTrainedTokenizer",
 | 
			
		||||
    processor: Optional["ProcessorMixin"],
 | 
			
		||||
    data_args: "DataArguments",
 | 
			
		||||
) -> Dict[str, List[List[int]]]:
 | 
			
		||||
    # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
 | 
			
		||||
    kl_response = examples["response"][::-1]
 | 
			
		||||
    model_inputs = {
 | 
			
		||||
        "input_ids": [],
 | 
			
		||||
        "attention_mask": [],
 | 
			
		||||
        "labels": [],
 | 
			
		||||
        "kl_input_ids": [],
 | 
			
		||||
        "kl_attention_mask": [],
 | 
			
		||||
        "kl_labels": [],
 | 
			
		||||
        "kto_tags": [],
 | 
			
		||||
    }
 | 
			
		||||
    if processor is not None:
 | 
			
		||||
        model_inputs["pixel_values"] = []
 | 
			
		||||
        if hasattr(processor, "image_seq_length"):  # paligemma models
 | 
			
		||||
            model_inputs["token_type_ids"] = []
 | 
			
		||||
            model_inputs["kl_token_type_ids"] = []
 | 
			
		||||
 | 
			
		||||
    for i in range(len(examples["prompt"])):
 | 
			
		||||
        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
 | 
			
		||||
            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
 | 
			
		||||
            examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
 | 
			
		||||
 | 
			
		||||
        if examples["response"][i][0]["content"]:  # desired example
 | 
			
		||||
            kto_tag = True
 | 
			
		||||
            messages = examples["prompt"][i] + [examples["response"][i][0]]
 | 
			
		||||
        else:  # undesired example
 | 
			
		||||
            kto_tag = False
 | 
			
		||||
            messages = examples["prompt"][i] + [examples["response"][i][1]]
 | 
			
		||||
 | 
			
		||||
        if kl_response[i][0]["content"]:
 | 
			
		||||
            kl_messages = examples["prompt"][i] + [kl_response[i][0]]
 | 
			
		||||
        else:
 | 
			
		||||
            kl_messages = examples["prompt"][i] + [kl_response[i][1]]
 | 
			
		||||
 | 
			
		||||
        prompt_ids, response_ids = template.encode_oneturn(
 | 
			
		||||
            tokenizer,
 | 
			
		||||
            messages,
 | 
			
		||||
            examples["system"][i],
 | 
			
		||||
            examples["tools"][i],
 | 
			
		||||
            data_args.cutoff_len,
 | 
			
		||||
            data_args.reserved_label_len,
 | 
			
		||||
        )
 | 
			
		||||
        _, kl_response_ids = template.encode_oneturn(
 | 
			
		||||
            tokenizer,
 | 
			
		||||
            kl_messages,
 | 
			
		||||
            examples["system"][i],
 | 
			
		||||
            examples["tools"][i],
 | 
			
		||||
            data_args.cutoff_len,
 | 
			
		||||
            data_args.reserved_label_len,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if template.efficient_eos:
 | 
			
		||||
            response_ids += [tokenizer.eos_token_id]
 | 
			
		||||
            kl_response_ids += [tokenizer.eos_token_id]
 | 
			
		||||
 | 
			
		||||
        if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
 | 
			
		||||
            image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
 | 
			
		||||
            prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
 | 
			
		||||
 | 
			
		||||
        input_ids = prompt_ids + response_ids
 | 
			
		||||
        labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
 | 
			
		||||
        kl_input_ids = prompt_ids + kl_response_ids
 | 
			
		||||
        kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids
 | 
			
		||||
        model_inputs["input_ids"].append(input_ids)
 | 
			
		||||
        model_inputs["attention_mask"].append([1] * len(input_ids))
 | 
			
		||||
        model_inputs["labels"].append(labels)
 | 
			
		||||
        model_inputs["kl_input_ids"].append(kl_input_ids)
 | 
			
		||||
        model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
 | 
			
		||||
        model_inputs["kl_labels"].append(kl_labels)
 | 
			
		||||
        model_inputs["kto_tags"].append(kto_tag)
 | 
			
		||||
        if processor is not None:
 | 
			
		||||
            model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
 | 
			
		||||
            if hasattr(processor, "image_seq_length"):  # paligemma models
 | 
			
		||||
                model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
 | 
			
		||||
                model_inputs["kl_token_type_ids"].append(get_paligemma_token_type_ids(len(kl_input_ids), processor))
 | 
			
		||||
 | 
			
		||||
    desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
 | 
			
		||||
    undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
 | 
			
		||||
    if desirable_num == 0 or undesirable_num == 0:
 | 
			
		||||
        logger.warning("Your dataset only has one preference type.")
 | 
			
		||||
 | 
			
		||||
    return model_inputs
 | 
			
		||||
							
								
								
									
										27
									
								
								src/llamafactory/data/processors/mm_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								src/llamafactory/data/processors/mm_utils.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,27 @@
 | 
			
		||||
from typing import TYPE_CHECKING, List, Sequence
 | 
			
		||||
 | 
			
		||||
from ...extras.packages import is_pillow_available
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if is_pillow_available():
 | 
			
		||||
    from PIL import Image
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from numpy.typing import NDArray
 | 
			
		||||
    from PIL.Image import Image as ImageObject
 | 
			
		||||
    from transformers import ProcessorMixin
 | 
			
		||||
    from transformers.image_processing_utils import BaseImageProcessor
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
 | 
			
		||||
    # process visual inputs (currently only supports a single image)
 | 
			
		||||
    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
 | 
			
		||||
    image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
 | 
			
		||||
    return image_processor(image, return_tensors="pt")["pixel_values"][0]  # shape (C, H, W)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
 | 
			
		||||
    # get paligemma token type ids for computing loss
 | 
			
		||||
    image_seq_length = getattr(processor, "image_seq_length")
 | 
			
		||||
    return [0] * image_seq_length + [1] * (input_len - image_seq_length)
 | 
			
		||||
							
								
								
									
										109
									
								
								src/llamafactory/data/processors/pairwise.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										109
									
								
								src/llamafactory/data/processors/pairwise.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,109 @@
 | 
			
		||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
 | 
			
		||||
 | 
			
		||||
from ...extras.constants import IGNORE_INDEX, IMAGE_TOKEN
 | 
			
		||||
from ...extras.logging import get_logger
 | 
			
		||||
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from transformers import ProcessorMixin
 | 
			
		||||
    from transformers.tokenization_utils import PreTrainedTokenizer
 | 
			
		||||
 | 
			
		||||
    from ...hparams import DataArguments
 | 
			
		||||
    from ..template import Template
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def preprocess_pairwise_dataset(
 | 
			
		||||
    examples: Dict[str, List[Any]],
 | 
			
		||||
    template: "Template",
 | 
			
		||||
    tokenizer: "PreTrainedTokenizer",
 | 
			
		||||
    processor: Optional["ProcessorMixin"],
 | 
			
		||||
    data_args: "DataArguments",
 | 
			
		||||
) -> Dict[str, List[List[int]]]:
 | 
			
		||||
    # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
 | 
			
		||||
    model_inputs = {
 | 
			
		||||
        "chosen_input_ids": [],
 | 
			
		||||
        "chosen_attention_mask": [],
 | 
			
		||||
        "chosen_labels": [],
 | 
			
		||||
        "rejected_input_ids": [],
 | 
			
		||||
        "rejected_attention_mask": [],
 | 
			
		||||
        "rejected_labels": [],
 | 
			
		||||
    }
 | 
			
		||||
    if processor is not None:
 | 
			
		||||
        model_inputs["pixel_values"] = []
 | 
			
		||||
        if hasattr(processor, "image_seq_length"):  # paligemma models
 | 
			
		||||
            model_inputs["chosen_token_type_ids"] = []
 | 
			
		||||
            model_inputs["rejected_token_type_ids"] = []
 | 
			
		||||
 | 
			
		||||
    for i in range(len(examples["prompt"])):
 | 
			
		||||
        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
 | 
			
		||||
            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
 | 
			
		||||
            examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
 | 
			
		||||
 | 
			
		||||
        chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
 | 
			
		||||
        rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
 | 
			
		||||
        prompt_ids, chosen_ids = template.encode_oneturn(
 | 
			
		||||
            tokenizer,
 | 
			
		||||
            chosen_messages,
 | 
			
		||||
            examples["system"][i],
 | 
			
		||||
            examples["tools"][i],
 | 
			
		||||
            data_args.cutoff_len,
 | 
			
		||||
            data_args.reserved_label_len,
 | 
			
		||||
        )
 | 
			
		||||
        _, rejected_ids = template.encode_oneturn(
 | 
			
		||||
            tokenizer,
 | 
			
		||||
            rejected_messages,
 | 
			
		||||
            examples["system"][i],
 | 
			
		||||
            examples["tools"][i],
 | 
			
		||||
            data_args.cutoff_len,
 | 
			
		||||
            data_args.reserved_label_len,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if template.efficient_eos:
 | 
			
		||||
            chosen_ids += [tokenizer.eos_token_id]
 | 
			
		||||
            rejected_ids += [tokenizer.eos_token_id]
 | 
			
		||||
 | 
			
		||||
        if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
 | 
			
		||||
            image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
 | 
			
		||||
            prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
 | 
			
		||||
 | 
			
		||||
        chosen_input_ids = prompt_ids + chosen_ids
 | 
			
		||||
        chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids
 | 
			
		||||
        rejected_input_ids = prompt_ids + rejected_ids
 | 
			
		||||
        rejected_labels = [IGNORE_INDEX] * len(prompt_ids) + rejected_ids
 | 
			
		||||
        model_inputs["chosen_input_ids"].append(chosen_input_ids)
 | 
			
		||||
        model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
 | 
			
		||||
        model_inputs["chosen_labels"].append(chosen_labels)
 | 
			
		||||
        model_inputs["rejected_input_ids"].append(rejected_input_ids)
 | 
			
		||||
        model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
 | 
			
		||||
        model_inputs["rejected_labels"].append(rejected_labels)
 | 
			
		||||
        if processor is not None:
 | 
			
		||||
            model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
 | 
			
		||||
            if hasattr(processor, "image_seq_length"):  # paligemma models
 | 
			
		||||
                model_inputs["chosen_token_type_ids"].append(
 | 
			
		||||
                    get_paligemma_token_type_ids(len(chosen_input_ids), processor)
 | 
			
		||||
                )
 | 
			
		||||
                model_inputs["rejected_token_type_ids"].append(
 | 
			
		||||
                    get_paligemma_token_type_ids(len(rejected_input_ids), processor)
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
    return model_inputs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
 | 
			
		||||
    valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
 | 
			
		||||
    valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
 | 
			
		||||
    print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
 | 
			
		||||
    print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
 | 
			
		||||
    print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
 | 
			
		||||
    print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)))
 | 
			
		||||
    print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
 | 
			
		||||
    print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
 | 
			
		||||
    print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
 | 
			
		||||
    print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)))
 | 
			
		||||
							
								
								
									
										36
									
								
								src/llamafactory/data/processors/pretrain.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								src/llamafactory/data/processors/pretrain.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,36 @@
 | 
			
		||||
from itertools import chain
 | 
			
		||||
from typing import TYPE_CHECKING, Any, Dict, List
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from transformers.tokenization_utils import PreTrainedTokenizer
 | 
			
		||||
 | 
			
		||||
    from ...hparams import DataArguments
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def preprocess_pretrain_dataset(
 | 
			
		||||
    examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
 | 
			
		||||
) -> Dict[str, List[List[int]]]:
 | 
			
		||||
    # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
 | 
			
		||||
    text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
 | 
			
		||||
 | 
			
		||||
    if not data_args.packing:
 | 
			
		||||
        if data_args.template == "gemma":
 | 
			
		||||
            text_examples = [tokenizer.bos_token + example for example in text_examples]
 | 
			
		||||
 | 
			
		||||
        result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
 | 
			
		||||
    else:
 | 
			
		||||
        tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
 | 
			
		||||
        concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
 | 
			
		||||
        total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
 | 
			
		||||
        block_size = data_args.cutoff_len
 | 
			
		||||
        total_length = (total_length // block_size) * block_size
 | 
			
		||||
        result = {
 | 
			
		||||
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
 | 
			
		||||
            for k, t in concatenated_examples.items()
 | 
			
		||||
        }
 | 
			
		||||
        if data_args.template == "gemma":
 | 
			
		||||
            for i in range(len(result["input_ids"])):
 | 
			
		||||
                result["input_ids"][i][0] = tokenizer.bos_token_id
 | 
			
		||||
 | 
			
		||||
    return result
 | 
			
		||||
							
								
								
									
										137
									
								
								src/llamafactory/data/processors/supervised.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										137
									
								
								src/llamafactory/data/processors/supervised.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,137 @@
 | 
			
		||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
 | 
			
		||||
 | 
			
		||||
from ...extras.constants import IGNORE_INDEX, IMAGE_TOKEN
 | 
			
		||||
from ...extras.logging import get_logger
 | 
			
		||||
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from transformers import ProcessorMixin
 | 
			
		||||
    from transformers.tokenization_utils import PreTrainedTokenizer
 | 
			
		||||
 | 
			
		||||
    from ...hparams import DataArguments
 | 
			
		||||
    from ..template import Template
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def preprocess_supervised_dataset(
 | 
			
		||||
    examples: Dict[str, List[Any]],
 | 
			
		||||
    template: "Template",
 | 
			
		||||
    tokenizer: "PreTrainedTokenizer",
 | 
			
		||||
    processor: Optional["ProcessorMixin"],
 | 
			
		||||
    data_args: "DataArguments",
 | 
			
		||||
) -> Dict[str, List[List[int]]]:
 | 
			
		||||
    # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
 | 
			
		||||
    # for multiturn examples, we only mask the prompt part in each prompt-response pair.
 | 
			
		||||
    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
 | 
			
		||||
    if processor is not None:
 | 
			
		||||
        model_inputs["pixel_values"] = []
 | 
			
		||||
        if hasattr(processor, "image_seq_length"):  # paligemma models
 | 
			
		||||
            model_inputs["token_type_ids"] = []
 | 
			
		||||
 | 
			
		||||
    for i in range(len(examples["prompt"])):
 | 
			
		||||
        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
 | 
			
		||||
            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
 | 
			
		||||
            examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
 | 
			
		||||
 | 
			
		||||
        messages = examples["prompt"][i] + examples["response"][i]
 | 
			
		||||
        input_ids, labels = [], []
 | 
			
		||||
 | 
			
		||||
        if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
 | 
			
		||||
            image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
 | 
			
		||||
            input_ids += [image_token_id] * getattr(processor, "image_seq_length")
 | 
			
		||||
            labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
 | 
			
		||||
 | 
			
		||||
        for turn_idx, (source_ids, target_ids) in enumerate(
 | 
			
		||||
            template.encode_multiturn(
 | 
			
		||||
                tokenizer,
 | 
			
		||||
                messages,
 | 
			
		||||
                examples["system"][i],
 | 
			
		||||
                examples["tools"][i],
 | 
			
		||||
                data_args.cutoff_len,
 | 
			
		||||
                data_args.reserved_label_len,
 | 
			
		||||
            )
 | 
			
		||||
        ):
 | 
			
		||||
            if data_args.train_on_prompt:
 | 
			
		||||
                source_mask = source_ids
 | 
			
		||||
            elif turn_idx != 0 and template.efficient_eos:
 | 
			
		||||
                source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
 | 
			
		||||
            else:
 | 
			
		||||
                source_mask = [IGNORE_INDEX] * len(source_ids)
 | 
			
		||||
 | 
			
		||||
            input_ids += source_ids + target_ids
 | 
			
		||||
            labels += source_mask + target_ids
 | 
			
		||||
 | 
			
		||||
        if template.efficient_eos:
 | 
			
		||||
            input_ids += [tokenizer.eos_token_id]
 | 
			
		||||
            labels += [tokenizer.eos_token_id]
 | 
			
		||||
 | 
			
		||||
        model_inputs["input_ids"].append(input_ids)
 | 
			
		||||
        model_inputs["attention_mask"].append([1] * len(input_ids))
 | 
			
		||||
        model_inputs["labels"].append(labels)
 | 
			
		||||
        if processor is not None:
 | 
			
		||||
            model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
 | 
			
		||||
            if hasattr(processor, "image_seq_length"):  # paligemma models
 | 
			
		||||
                model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
 | 
			
		||||
 | 
			
		||||
    return model_inputs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def preprocess_packed_supervised_dataset(
 | 
			
		||||
    examples: Dict[str, List[Any]],
 | 
			
		||||
    template: "Template",
 | 
			
		||||
    tokenizer: "PreTrainedTokenizer",
 | 
			
		||||
    data_args: "DataArguments",
 | 
			
		||||
) -> Dict[str, List[List[int]]]:
 | 
			
		||||
    # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
 | 
			
		||||
    # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
 | 
			
		||||
    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
 | 
			
		||||
    input_ids, labels = [], []
 | 
			
		||||
    for i in range(len(examples["prompt"])):
 | 
			
		||||
        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
 | 
			
		||||
            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        messages = examples["prompt"][i] + examples["response"][i]
 | 
			
		||||
        for source_ids, target_ids in template.encode_multiturn(
 | 
			
		||||
            tokenizer, messages, examples["system"][i], examples["tools"][i]
 | 
			
		||||
        ):
 | 
			
		||||
            if data_args.train_on_prompt:
 | 
			
		||||
                source_mask = source_ids
 | 
			
		||||
            elif len(input_ids) != 0 and template.efficient_eos:
 | 
			
		||||
                source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
 | 
			
		||||
            else:
 | 
			
		||||
                source_mask = [IGNORE_INDEX] * len(source_ids)
 | 
			
		||||
 | 
			
		||||
            input_ids += source_ids + target_ids
 | 
			
		||||
            labels += source_mask + target_ids
 | 
			
		||||
 | 
			
		||||
    if template.efficient_eos:
 | 
			
		||||
        input_ids += [tokenizer.eos_token_id]
 | 
			
		||||
        labels += [tokenizer.eos_token_id]
 | 
			
		||||
 | 
			
		||||
    total_length = len(input_ids)
 | 
			
		||||
    block_size = data_args.cutoff_len
 | 
			
		||||
    # we drop the small remainder, and if the total_length < block_size, we exclude this batch
 | 
			
		||||
    total_length = (total_length // block_size) * block_size
 | 
			
		||||
    # split by chunks of cutoff_len
 | 
			
		||||
    for i in range(0, total_length, block_size):
 | 
			
		||||
        if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]):
 | 
			
		||||
            model_inputs["input_ids"].append(input_ids[i : i + block_size])
 | 
			
		||||
            model_inputs["attention_mask"].append([1] * block_size)
 | 
			
		||||
            model_inputs["labels"].append(labels[i : i + block_size])
 | 
			
		||||
 | 
			
		||||
    return model_inputs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
 | 
			
		||||
    valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
 | 
			
		||||
    print("input_ids:\n{}".format(example["input_ids"]))
 | 
			
		||||
    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
 | 
			
		||||
    print("label_ids:\n{}".format(example["labels"]))
 | 
			
		||||
    print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
 | 
			
		||||
							
								
								
									
										76
									
								
								src/llamafactory/data/processors/unsupervised.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										76
									
								
								src/llamafactory/data/processors/unsupervised.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,76 @@
 | 
			
		||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
 | 
			
		||||
 | 
			
		||||
from ...extras.constants import IMAGE_TOKEN
 | 
			
		||||
from ...extras.logging import get_logger
 | 
			
		||||
from ..utils import Role
 | 
			
		||||
from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from transformers import ProcessorMixin
 | 
			
		||||
    from transformers.tokenization_utils import PreTrainedTokenizer
 | 
			
		||||
 | 
			
		||||
    from ...hparams import DataArguments
 | 
			
		||||
    from ..template import Template
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def preprocess_unsupervised_dataset(
 | 
			
		||||
    examples: Dict[str, List[Any]],
 | 
			
		||||
    template: "Template",
 | 
			
		||||
    tokenizer: "PreTrainedTokenizer",
 | 
			
		||||
    processor: Optional["ProcessorMixin"],
 | 
			
		||||
    data_args: "DataArguments",
 | 
			
		||||
) -> Dict[str, List[List[int]]]:
 | 
			
		||||
    # build inputs with format `<bos> X` and labels with format `Y <eos>`
 | 
			
		||||
    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
 | 
			
		||||
    if processor is not None:
 | 
			
		||||
        model_inputs["pixel_values"] = []
 | 
			
		||||
        if hasattr(processor, "image_seq_length"):  # paligemma models
 | 
			
		||||
            model_inputs["token_type_ids"] = []
 | 
			
		||||
 | 
			
		||||
    for i in range(len(examples["prompt"])):
 | 
			
		||||
        if len(examples["prompt"][i]) % 2 != 1:
 | 
			
		||||
            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
 | 
			
		||||
            examples["prompt"][i][0]["content"] = IMAGE_TOKEN + examples["prompt"][i][0]["content"]
 | 
			
		||||
 | 
			
		||||
        if len(examples["response"][i]) == 1:
 | 
			
		||||
            messages = examples["prompt"][i] + examples["response"][i]
 | 
			
		||||
        else:
 | 
			
		||||
            messages = examples["prompt"][i] + [{"role": Role.ASSISTANT.value, "content": ""}]
 | 
			
		||||
 | 
			
		||||
        input_ids, labels = template.encode_oneturn(
 | 
			
		||||
            tokenizer,
 | 
			
		||||
            messages,
 | 
			
		||||
            examples["system"][i],
 | 
			
		||||
            examples["tools"][i],
 | 
			
		||||
            data_args.cutoff_len,
 | 
			
		||||
            data_args.reserved_label_len,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if template.efficient_eos:
 | 
			
		||||
            labels += [tokenizer.eos_token_id]
 | 
			
		||||
 | 
			
		||||
        if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
 | 
			
		||||
            image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
 | 
			
		||||
            input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
 | 
			
		||||
 | 
			
		||||
        model_inputs["input_ids"].append(input_ids)
 | 
			
		||||
        model_inputs["attention_mask"].append([1] * len(input_ids))
 | 
			
		||||
        model_inputs["labels"].append(labels)
 | 
			
		||||
        if processor is not None:
 | 
			
		||||
            model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
 | 
			
		||||
            if hasattr(processor, "image_seq_length"):  # paligemma models
 | 
			
		||||
                model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
 | 
			
		||||
 | 
			
		||||
    return model_inputs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
 | 
			
		||||
    print("input_ids:\n{}".format(example["input_ids"]))
 | 
			
		||||
    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
 | 
			
		||||
@ -290,10 +290,10 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
 | 
			
		||||
                slot_items.append(placeholder)
 | 
			
		||||
                if slot_pieces[1]:
 | 
			
		||||
                    slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'")
 | 
			
		||||
        elif isinstance(slot, set):
 | 
			
		||||
            if "bos_token" in slot:
 | 
			
		||||
        elif isinstance(slot, set):  # do not use {{ eos_token }} since it may be replaced
 | 
			
		||||
            if "bos_token" in slot and tokenizer.bos_token_id is not None:
 | 
			
		||||
                slot_items.append("'" + tokenizer.bos_token + "'")
 | 
			
		||||
            elif "eos_token" in slot:  # do not use {{ eos_token }} since it may be replaced
 | 
			
		||||
            elif "eos_token" in slot and tokenizer.eos_token_id is not None:
 | 
			
		||||
                slot_items.append("'" + tokenizer.eos_token + "'")
 | 
			
		||||
        elif isinstance(slot, dict):
 | 
			
		||||
            raise ValueError("Dict is not supported.")
 | 
			
		||||
@ -325,9 +325,11 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
 | 
			
		||||
        jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
 | 
			
		||||
        jinja_template += "{% set content = " + system_message + " + message['content'] %}"
 | 
			
		||||
        jinja_template += "{% endif %}"
 | 
			
		||||
 | 
			
		||||
    jinja_template += "{% if message['role'] == 'user' %}"
 | 
			
		||||
    user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
 | 
			
		||||
    jinja_template += "{{ " + user_message + " }}"
 | 
			
		||||
 | 
			
		||||
    jinja_template += "{% elif message['role'] == 'assistant' %}"
 | 
			
		||||
    assistant_message = _convert_slots_to_jinja(
 | 
			
		||||
        template.format_assistant.apply() + template.format_separator.apply(), tokenizer
 | 
			
		||||
@ -614,6 +616,9 @@ _register_template(
 | 
			
		||||
    name="empty",
 | 
			
		||||
    format_user=StringFormatter(slots=["{{content}}"]),
 | 
			
		||||
    format_assistant=StringFormatter(slots=["{{content}}"]),
 | 
			
		||||
    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
 | 
			
		||||
    efficient_eos=True,
 | 
			
		||||
    force_system=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -22,6 +22,8 @@ FILEEXT2TYPE = {
 | 
			
		||||
 | 
			
		||||
IGNORE_INDEX = -100
 | 
			
		||||
 | 
			
		||||
IMAGE_TOKEN = "<image>"
 | 
			
		||||
 | 
			
		||||
LAYERNORM_NAMES = {"norm", "ln"}
 | 
			
		||||
 | 
			
		||||
METHODS = ["full", "freeze", "lora"]
 | 
			
		||||
@ -714,6 +716,28 @@ register_model_group(
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_model_group(
 | 
			
		||||
    models={
 | 
			
		||||
        "PaliGemma-3B-pt-224": {
 | 
			
		||||
            DownloadSource.DEFAULT: "google/paligemma-3b-pt-224",
 | 
			
		||||
        },
 | 
			
		||||
        "PaliGemma-3B-pt-448": {
 | 
			
		||||
            DownloadSource.DEFAULT: "google/paligemma-3b-pt-448",
 | 
			
		||||
        },
 | 
			
		||||
        "PaliGemma-3B-pt-896": {
 | 
			
		||||
            DownloadSource.DEFAULT: "google/paligemma-3b-pt-896",
 | 
			
		||||
        },
 | 
			
		||||
        "PaliGemma-3B-mix-224": {
 | 
			
		||||
            DownloadSource.DEFAULT: "google/paligemma-3b-mix-224",
 | 
			
		||||
        },
 | 
			
		||||
        "PaliGemma-3B-mix-448": {
 | 
			
		||||
            DownloadSource.DEFAULT: "google/paligemma-3b-mix-448",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    vision=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_model_group(
 | 
			
		||||
    models={
 | 
			
		||||
        "Phi-1.5-1.3B": {
 | 
			
		||||
 | 
			
		||||
@ -65,7 +65,7 @@ def check_dependencies() -> None:
 | 
			
		||||
        require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
 | 
			
		||||
        require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
 | 
			
		||||
        require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0")
 | 
			
		||||
        require_version("trl>=0.8.1", "To fix: pip install trl>=0.8.1")
 | 
			
		||||
        require_version("trl>=0.8.2", "To fix: pip install trl>=0.8.2")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
 | 
			
		||||
 | 
			
		||||
@ -145,7 +145,7 @@ class ModelArguments:
 | 
			
		||||
        default=1,
 | 
			
		||||
        metadata={"help": "The file shard size (in GB) of the exported model."},
 | 
			
		||||
    )
 | 
			
		||||
    export_device: str = field(
 | 
			
		||||
    export_device: Literal["cpu", "cuda"] = field(
 | 
			
		||||
        default="cpu",
 | 
			
		||||
        metadata={"help": "The device used in model export, use cuda to avoid addmm errors."},
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
@ -328,8 +328,8 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
 | 
			
		||||
    _verify_model_args(model_args, finetuning_args)
 | 
			
		||||
    _check_extra_dependencies(model_args, finetuning_args)
 | 
			
		||||
 | 
			
		||||
    if model_args.export_dir is not None:
 | 
			
		||||
        model_args.device_map = {"": torch.device(model_args.export_device)}
 | 
			
		||||
    if model_args.export_dir is not None and model_args.export_device == "cpu":
 | 
			
		||||
        model_args.device_map = {"": torch.device("cpu")}
 | 
			
		||||
    else:
 | 
			
		||||
        model_args.device_map = "auto"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -4,7 +4,7 @@ from types import MethodType
 | 
			
		||||
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from transformers import BatchEncoding, Trainer
 | 
			
		||||
from transformers import Trainer
 | 
			
		||||
from trl import DPOTrainer
 | 
			
		||||
from trl.trainer.utils import disable_dropout_in_model
 | 
			
		||||
 | 
			
		||||
@ -108,14 +108,8 @@ class CustomDPOTrainer(DPOTrainer):
 | 
			
		||||
 | 
			
		||||
        Otherwise the average log probabilities.
 | 
			
		||||
        """
 | 
			
		||||
        batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()})  # avoid error
 | 
			
		||||
 | 
			
		||||
        all_logits: "torch.Tensor" = model(
 | 
			
		||||
            input_ids=batch_copied["input_ids"],
 | 
			
		||||
            attention_mask=batch_copied["attention_mask"],
 | 
			
		||||
            return_dict=True,
 | 
			
		||||
            use_cache=False,
 | 
			
		||||
        ).logits.to(torch.float32)
 | 
			
		||||
        batch_copied = {k: v.detach().clone() for k, v in batch.items()}  # avoid error
 | 
			
		||||
        all_logits: "torch.Tensor" = model(**batch_copied, return_dict=True, use_cache=False).logits.to(torch.float32)
 | 
			
		||||
 | 
			
		||||
        all_logps = self.get_batch_logps(
 | 
			
		||||
            logits=all_logits,
 | 
			
		||||
 | 
			
		||||
@ -104,19 +104,23 @@ class CustomKTOTrainer(KTOTrainer):
 | 
			
		||||
        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
 | 
			
		||||
    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
 | 
			
		||||
        with torch.no_grad():
 | 
			
		||||
            kl_logits = model(
 | 
			
		||||
                input_ids=batch["kl_input_ids"],
 | 
			
		||||
                attention_mask=batch["kl_attention_mask"],
 | 
			
		||||
                return_dict=True,
 | 
			
		||||
                use_cache=False,
 | 
			
		||||
            ).logits.to(torch.float32)
 | 
			
		||||
            kl_model_inputs = {"input_ids": batch["kl_input_ids"], "attention_mask": batch["kl_attention_mask"]}
 | 
			
		||||
            if "pixel_values" in batch:
 | 
			
		||||
                kl_model_inputs["pixel_values"] = batch["pixel_values"]
 | 
			
		||||
 | 
			
		||||
        target_logits = model(
 | 
			
		||||
            input_ids=batch["input_ids"],
 | 
			
		||||
            attention_mask=batch["attention_mask"],
 | 
			
		||||
            return_dict=True,
 | 
			
		||||
            use_cache=False,
 | 
			
		||||
        ).logits.to(torch.float32)
 | 
			
		||||
            if "kl_token_type_ids" in batch:
 | 
			
		||||
                kl_model_inputs["token_type_ids"] = batch["kl_token_type_ids"]
 | 
			
		||||
 | 
			
		||||
            kl_logits = model(**kl_model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
 | 
			
		||||
 | 
			
		||||
        model_inputs = {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]}
 | 
			
		||||
        if "pixel_values" in batch:
 | 
			
		||||
            model_inputs["pixel_values"] = batch["pixel_values"]
 | 
			
		||||
 | 
			
		||||
        if "token_type_ids" in batch:
 | 
			
		||||
            model_inputs["token_type_ids"] = batch["token_type_ids"]
 | 
			
		||||
 | 
			
		||||
        target_logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
 | 
			
		||||
 | 
			
		||||
        target_logps = self.get_batch_logps(
 | 
			
		||||
            logits=target_logits,
 | 
			
		||||
 | 
			
		||||
@ -85,9 +85,7 @@ class CustomORPOTrainer(DPOTrainer):
 | 
			
		||||
        r"""
 | 
			
		||||
        Computes the average log probabilities of the labels under the given logits.
 | 
			
		||||
        """
 | 
			
		||||
        all_logits: "torch.Tensor" = model(
 | 
			
		||||
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True, use_cache=False
 | 
			
		||||
        ).logits.to(torch.float32)
 | 
			
		||||
        all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
 | 
			
		||||
 | 
			
		||||
        all_logps = self.get_batch_logps(
 | 
			
		||||
            logits=all_logits,
 | 
			
		||||
 | 
			
		||||
@ -184,14 +184,25 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
 | 
			
		||||
 | 
			
		||||
    with gr.Accordion(open=False) as rlhf_tab:
 | 
			
		||||
        with gr.Row():
 | 
			
		||||
            dpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
 | 
			
		||||
            dpo_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01)
 | 
			
		||||
            orpo_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
 | 
			
		||||
            pref_beta = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.01)
 | 
			
		||||
            pref_ftx = gr.Slider(minimum=0, maximum=10, value=0, step=0.01)
 | 
			
		||||
            pref_loss = gr.Dropdown(choices=["sigmoid", "hinge", "ipo", "kto_pair"], value="sigmoid")
 | 
			
		||||
            reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True)
 | 
			
		||||
            with gr.Column():
 | 
			
		||||
                ppo_score_norm = gr.Checkbox()
 | 
			
		||||
                ppo_whiten_rewards = gr.Checkbox()
 | 
			
		||||
 | 
			
		||||
    input_elems.update({dpo_beta, dpo_ftx, orpo_beta, reward_model})
 | 
			
		||||
    input_elems.update({pref_beta, pref_ftx, pref_loss, reward_model, ppo_score_norm, ppo_whiten_rewards})
 | 
			
		||||
    elem_dict.update(
 | 
			
		||||
        dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, orpo_beta=orpo_beta, reward_model=reward_model)
 | 
			
		||||
        dict(
 | 
			
		||||
            rlhf_tab=rlhf_tab,
 | 
			
		||||
            pref_beta=pref_beta,
 | 
			
		||||
            pref_ftx=pref_ftx,
 | 
			
		||||
            pref_loss=pref_loss,
 | 
			
		||||
            reward_model=reward_model,
 | 
			
		||||
            ppo_score_norm=ppo_score_norm,
 | 
			
		||||
            ppo_whiten_rewards=ppo_whiten_rewards,
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    with gr.Accordion(open=False) as galore_tab:
 | 
			
		||||
 | 
			
		||||
@ -774,52 +774,52 @@ LOCALES = {
 | 
			
		||||
            "label": "RLHF 参数设置",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    "dpo_beta": {
 | 
			
		||||
    "pref_beta": {
 | 
			
		||||
        "en": {
 | 
			
		||||
            "label": "DPO beta",
 | 
			
		||||
            "info": "Value of the beta parameter in the DPO loss.",
 | 
			
		||||
            "label": "Beta value",
 | 
			
		||||
            "info": "Value of the beta parameter in the loss.",
 | 
			
		||||
        },
 | 
			
		||||
        "ru": {
 | 
			
		||||
            "label": "DPO бета",
 | 
			
		||||
            "info": "Значение параметра бета в функции потерь DPO.",
 | 
			
		||||
            "label": "Бета значение",
 | 
			
		||||
            "info": "Значение параметра бета в функции потерь.",
 | 
			
		||||
        },
 | 
			
		||||
        "zh": {
 | 
			
		||||
            "label": "DPO beta 参数",
 | 
			
		||||
            "info": "DPO 损失函数中 beta 超参数大小。",
 | 
			
		||||
            "label": "Beta 参数",
 | 
			
		||||
            "info": "损失函数中 beta 超参数大小。",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    "dpo_ftx": {
 | 
			
		||||
    "pref_ftx": {
 | 
			
		||||
        "en": {
 | 
			
		||||
            "label": "DPO-ftx weight",
 | 
			
		||||
            "info": "The weight of SFT loss in the DPO-ftx.",
 | 
			
		||||
            "label": "Ftx gamma",
 | 
			
		||||
            "info": "The weight of SFT loss in the final loss.",
 | 
			
		||||
        },
 | 
			
		||||
        "ru": {
 | 
			
		||||
            "label": "Вес DPO-ftx",
 | 
			
		||||
            "info": "Вес функции потерь SFT в DPO-ftx.",
 | 
			
		||||
            "label": "Ftx гамма",
 | 
			
		||||
            "info": "Вес потери SFT в итоговой потере.",
 | 
			
		||||
        },
 | 
			
		||||
        "zh": {
 | 
			
		||||
            "label": "DPO-ftx 权重",
 | 
			
		||||
            "info": "DPO-ftx 中 SFT 损失的权重大小。",
 | 
			
		||||
            "label": "Ftx gamma",
 | 
			
		||||
            "info": "损失函数中 SFT 损失的权重大小。",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    "orpo_beta": {
 | 
			
		||||
    "pref_loss": {
 | 
			
		||||
        "en": {
 | 
			
		||||
            "label": "ORPO beta",
 | 
			
		||||
            "info": "Value of the beta parameter in the ORPO loss.",
 | 
			
		||||
            "label": "Loss type",
 | 
			
		||||
            "info": "The type of the loss function.",
 | 
			
		||||
        },
 | 
			
		||||
        "ru": {
 | 
			
		||||
            "label": "ORPO бета",
 | 
			
		||||
            "info": "Значение параметра бета в функции потерь ORPO.",
 | 
			
		||||
            "label": "Тип потерь",
 | 
			
		||||
            "info": "Тип функции потерь.",
 | 
			
		||||
        },
 | 
			
		||||
        "zh": {
 | 
			
		||||
            "label": "ORPO beta 参数",
 | 
			
		||||
            "info": "ORPO 损失函数中 beta 超参数大小。",
 | 
			
		||||
            "label": "损失类型",
 | 
			
		||||
            "info": "损失函数的类型。",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    "reward_model": {
 | 
			
		||||
        "en": {
 | 
			
		||||
            "label": "Reward model",
 | 
			
		||||
            "info": "Adapter of the reward model for PPO training.",
 | 
			
		||||
            "info": "Adapter of the reward model in PPO training.",
 | 
			
		||||
        },
 | 
			
		||||
        "ru": {
 | 
			
		||||
            "label": "Модель вознаграждения",
 | 
			
		||||
@ -830,6 +830,34 @@ LOCALES = {
 | 
			
		||||
            "info": "PPO 训练中奖励模型的适配器路径。",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    "ppo_score_norm": {
 | 
			
		||||
        "en": {
 | 
			
		||||
            "label": "Score norm",
 | 
			
		||||
            "info": "Normalizing scores in PPO training.",
 | 
			
		||||
        },
 | 
			
		||||
        "ru": {
 | 
			
		||||
            "label": "Норма оценок",
 | 
			
		||||
            "info": "Нормализация оценок в тренировке PPO.",
 | 
			
		||||
        },
 | 
			
		||||
        "zh": {
 | 
			
		||||
            "label": "奖励模型",
 | 
			
		||||
            "info": "PPO 训练中归一化奖励分数。",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    "ppo_whiten_rewards": {
 | 
			
		||||
        "en": {
 | 
			
		||||
            "label": "Whiten rewards",
 | 
			
		||||
            "info": "Whiten the rewards in PPO training.",
 | 
			
		||||
        },
 | 
			
		||||
        "ru": {
 | 
			
		||||
            "label": "Белые вознаграждения",
 | 
			
		||||
            "info": "Осветлите вознаграждения в обучении PPO.",
 | 
			
		||||
        },
 | 
			
		||||
        "zh": {
 | 
			
		||||
            "label": "白化奖励",
 | 
			
		||||
            "info": "PPO 训练中将奖励分数做白化处理。",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    "galore_tab": {
 | 
			
		||||
        "en": {
 | 
			
		||||
            "label": "GaLore configurations",
 | 
			
		||||
 | 
			
		||||
@ -145,11 +145,14 @@ class Runner:
 | 
			
		||||
            plot_loss=True,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # freeze config
 | 
			
		||||
        if args["finetuning_type"] == "freeze":
 | 
			
		||||
            args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")
 | 
			
		||||
            args["freeze_trainable_modules"] = get("train.freeze_trainable_modules")
 | 
			
		||||
            args["freeze_extra_modules"] = get("train.freeze_extra_modules") or None
 | 
			
		||||
        elif args["finetuning_type"] == "lora":
 | 
			
		||||
 | 
			
		||||
        # lora config
 | 
			
		||||
        if args["finetuning_type"] == "lora":
 | 
			
		||||
            args["lora_rank"] = get("train.lora_rank")
 | 
			
		||||
            args["lora_alpha"] = get("train.lora_alpha")
 | 
			
		||||
            args["lora_dropout"] = get("train.lora_dropout")
 | 
			
		||||
@ -163,6 +166,7 @@ class Runner:
 | 
			
		||||
            if args["use_llama_pro"]:
 | 
			
		||||
                args["num_layer_trainable"] = get("train.num_layer_trainable")
 | 
			
		||||
 | 
			
		||||
        # rlhf config
 | 
			
		||||
        if args["stage"] == "ppo":
 | 
			
		||||
            args["reward_model"] = ",".join(
 | 
			
		||||
                [
 | 
			
		||||
@ -171,31 +175,41 @@ class Runner:
 | 
			
		||||
                ]
 | 
			
		||||
            )
 | 
			
		||||
            args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"
 | 
			
		||||
            args["ppo_score_norm"] = get("train.ppo_score_norm")
 | 
			
		||||
            args["ppo_whiten_rewards"] = get("train.ppo_whiten_rewards")
 | 
			
		||||
            args["top_k"] = 0
 | 
			
		||||
            args["top_p"] = 0.9
 | 
			
		||||
        elif args["stage"] == "dpo":
 | 
			
		||||
            args["dpo_beta"] = get("train.dpo_beta")
 | 
			
		||||
            args["dpo_ftx"] = get("train.dpo_ftx")
 | 
			
		||||
            args["dpo_beta"] = get("train.pref_beta")
 | 
			
		||||
            args["dpo_ftx"] = get("train.pref_ftx")
 | 
			
		||||
            args["dpo_loss"] = get("train.pref_loss")
 | 
			
		||||
        elif args["stage"] == "kto":
 | 
			
		||||
            args["kto_beta"] = get("train.pref_beta")
 | 
			
		||||
            args["kto_ftx"] = get("train.pref_ftx")
 | 
			
		||||
        elif args["stage"] == "orpo":
 | 
			
		||||
            args["orpo_beta"] = get("train.orpo_beta")
 | 
			
		||||
 | 
			
		||||
        if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
 | 
			
		||||
            args["val_size"] = get("train.val_size")
 | 
			
		||||
            args["evaluation_strategy"] = "steps"
 | 
			
		||||
            args["eval_steps"] = args["save_steps"]
 | 
			
		||||
            args["per_device_eval_batch_size"] = args["per_device_train_batch_size"]
 | 
			
		||||
            args["load_best_model_at_end"] = args["stage"] not in ["rm", "ppo"]
 | 
			
		||||
            args["orpo_beta"] = get("train.pref_beta")
 | 
			
		||||
 | 
			
		||||
        # galore config
 | 
			
		||||
        if args["use_galore"]:
 | 
			
		||||
            args["galore_rank"] = get("train.galore_rank")
 | 
			
		||||
            args["galore_update_interval"] = get("train.galore_update_interval")
 | 
			
		||||
            args["galore_scale"] = get("train.galore_scale")
 | 
			
		||||
            args["galore_target"] = get("train.galore_target")
 | 
			
		||||
 | 
			
		||||
        # badam config
 | 
			
		||||
        if args["use_badam"]:
 | 
			
		||||
            args["badam_mode"] = get("train.badam_mode")
 | 
			
		||||
            args["badam_switch_mode"] = get("train.badam_switch_mode")
 | 
			
		||||
            args["badam_switch_interval"] = get("train.badam_switch_interval")
 | 
			
		||||
            args["badam_update_ratio"] = get("train.badam_update_ratio")
 | 
			
		||||
 | 
			
		||||
        # eval config
 | 
			
		||||
        if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
 | 
			
		||||
            args["val_size"] = get("train.val_size")
 | 
			
		||||
            args["evaluation_strategy"] = "steps"
 | 
			
		||||
            args["eval_steps"] = args["save_steps"]
 | 
			
		||||
            args["per_device_eval_batch_size"] = args["per_device_train_batch_size"]
 | 
			
		||||
 | 
			
		||||
        return args
 | 
			
		||||
 | 
			
		||||
    def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user