From a83756b5e902e2943598567ebfe2177fe4b45f54 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Fri, 30 Aug 2024 02:14:31 +0800 Subject: [PATCH] refactor mm training Former-commit-id: 3382317e32f88ed377d3e7759bdeaf0f2559d22a --- README.md | 15 +- README_zh.md | 15 +- data/dataset_info.json | 14 - data/mllm_demo.json | 12 +- data/qwen2vl_demo.json | 140 --------- examples/README.md | 1 + examples/README_zh.md | 1 + examples/train_full/qwen2vl_full_sft.yaml | 40 --- examples/train_lora/qwen2vl_lora_sft.yaml | 12 +- requirements.txt | 2 +- src/llamafactory/__init__.py | 6 +- src/llamafactory/chat/hf_engine.py | 33 +-- src/llamafactory/data/__init__.py | 8 +- src/llamafactory/data/collator.py | 49 ++-- src/llamafactory/data/mm_plugin.py | 271 ++++++++++++++++++ src/llamafactory/data/processors/feedback.py | 45 ++- src/llamafactory/data/processors/pairwise.py | 47 +-- .../data/processors/processor_utils.py | 46 +-- .../data/processors/supervised.py | 71 +---- .../data/processors/unsupervised.py | 31 +- src/llamafactory/data/template.py | 45 ++- src/llamafactory/extras/constants.py | 43 ++- src/llamafactory/extras/misc.py | 2 +- src/llamafactory/hparams/model_args.py | 2 +- src/llamafactory/hparams/parser.py | 2 +- .../model/model_utils/liger_kernel.py | 2 +- .../model/model_utils/longlora.py | 2 +- src/llamafactory/model/model_utils/misc.py | 7 +- src/llamafactory/model/model_utils/packing.py | 2 +- src/llamafactory/train/kto/trainer.py | 3 + src/llamafactory/train/ppo/workflow.py | 6 +- src/llamafactory/webui/runner.py | 2 +- 32 files changed, 505 insertions(+), 472 deletions(-) delete mode 100644 data/qwen2vl_demo.json delete mode 100644 examples/train_full/qwen2vl_full_sft.yaml create mode 100644 src/llamafactory/data/mm_plugin.py diff --git a/README.md b/README.md index 5c58d193..19a628d4 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,7 @@ Choose your path: ## Features -- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc. +- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, Yi, Gemma, Baichuan, ChatGLM, Phi, etc. - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc. - **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ. - **Advanced algorithms**: GaLore, BAdam, Adam-mini, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning. @@ -72,14 +72,16 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ ## Changelog -[24/08/27] We support **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Try `use_liger_kernel: true` for efficient training. +[24/08/30] We supported fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. + +[24/08/27] We support **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Try `enable_liger_kernel: true` for efficient training. [24/08/09] We support **[Adam-mini](https://arxiv.org/abs/2406.16793)** optimizer. See [examples](examples/README.md) for usage. Thank [@relic-yuexi](https://github.com/relic-yuexi)'s PR. -[24/07/04] We support [contamination-free packed training](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing). Use `neat_packing: true` to activate it. Thank [@chuan298](https://github.com/chuan298)'s PR. -
Full Changelog +[24/07/04] We support [contamination-free packed training](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing). Use `neat_packing: true` to activate it. Thank [@chuan298](https://github.com/chuan298)'s PR. + [24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage. [24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models. @@ -172,14 +174,15 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - | | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | | [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 | -| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna | +| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava | | [MiniCPM](https://huggingface.co/openbmb) | 1B/2B | cpm | | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral | | [OLMo](https://huggingface.co/allenai) | 1B/7B | - | -| [PaliGemma](https://huggingface.co/google) | 3B | gemma | +| [PaliGemma](https://huggingface.co/google) | 3B | paligemma | | [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - | | [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi | | [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen | +| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B | qwen2_vl | | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - | | [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse | | [Yi/Yi-1.5](https://huggingface.co/01-ai) | 6B/9B/34B | yi | diff --git a/README_zh.md b/README_zh.md index 6f8e5aa4..bbe6d159 100644 --- a/README_zh.md +++ b/README_zh.md @@ -48,7 +48,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272 ## 项目特色 -- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。 +- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Qwen2-VL、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。 - **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。 - **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。 - **先进算法**:GaLore、BAdam、Adam-mini、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。 @@ -73,14 +73,16 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272 ## 更新日志 -[24/08/27] 我们支持了 **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**。请使用 `use_liger_kernel: true` 来加速训练。 +[24/08/30] 我们支持了 **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** 模型的微调。 + +[24/08/27] 我们支持了 **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**。请使用 `enable_liger_kernel: true` 来加速训练。 [24/08/09] 我们支持了 **[Adam-mini](https://arxiv.org/abs/2406.16793)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。感谢 [@relic-yuexi](https://github.com/relic-yuexi) 的 PR。 -[24/07/04] 我们支持了[无污染打包训练](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing)。请使用 `neat_packing: true` 参数。感谢 [@chuan298](https://github.com/chuan298) 的 PR。 -
展开日志 +[24/07/04] 我们支持了[无污染打包训练](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing)。请使用 `neat_packing: true` 参数。感谢 [@chuan298](https://github.com/chuan298) 的 PR。 + [24/06/16] 我们支持了 **[PiSSA](https://arxiv.org/abs/2404.02948)** 算法。详细用法请参照 [examples](examples/README_zh.md)。 [24/06/07] 我们支持了 **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** 和 **[GLM-4](https://github.com/THUDM/GLM-4)** 模型的微调。 @@ -173,14 +175,15 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272 | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - | | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | | [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 | -| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna | +| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava | | [MiniCPM](https://huggingface.co/openbmb) | 1B/2B | cpm | | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral | | [OLMo](https://huggingface.co/allenai) | 1B/7B | - | -| [PaliGemma](https://huggingface.co/google) | 3B | gemma | +| [PaliGemma](https://huggingface.co/google) | 3B | paligemma | | [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - | | [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi | | [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen | +| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B | qwen2_vl | | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - | | [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse | | [Yi/Yi-1.5](https://huggingface.co/01-ai) | 6B/9B/34B | yi | diff --git a/data/dataset_info.json b/data/dataset_info.json index b2f99fe1..b00456d2 100644 --- a/data/dataset_info.json +++ b/data/dataset_info.json @@ -38,20 +38,6 @@ "assistant_tag": "assistant" } }, - "qwen2vl_demo": { - "file_name": "qwen2vl_demo.json", - "formatting": "sharegpt", - "columns": { - "messages": "messages", - "images": "images" - }, - "tags": { - "role_tag": "role", - "content_tag": "content", - "user_tag": "user", - "assistant_tag": "assistant" - } - }, "alpaca_en": { "hf_hub_url": "llamafactory/alpaca_en", "ms_hub_url": "llamafactory/alpaca_en" diff --git a/data/mllm_demo.json b/data/mllm_demo.json index 39bda392..9cc0e7e1 100644 --- a/data/mllm_demo.json +++ b/data/mllm_demo.json @@ -2,7 +2,7 @@ { "messages": [ { - "content": "Who are they?", + "content": "Who are they?", "role": "user" }, { @@ -25,7 +25,7 @@ { "messages": [ { - "content": "Who is he?", + "content": "Who is he?", "role": "user" }, { @@ -48,7 +48,7 @@ { "messages": [ { - "content": "Please describe this image", + "content": "Please describe this image", "role": "user" }, { @@ -71,7 +71,7 @@ { "messages": [ { - "content": "他们是谁?", + "content": "他们是谁?", "role": "user" }, { @@ -94,7 +94,7 @@ { "messages": [ { - "content": "他是谁?", + "content": "他是谁?", "role": "user" }, { @@ -117,7 +117,7 @@ { "messages": [ { - "content": "请描述这张图片", + "content": "请描述这张图片", "role": "user" }, { diff --git a/data/qwen2vl_demo.json b/data/qwen2vl_demo.json deleted file mode 100644 index f8ab29ff..00000000 --- a/data/qwen2vl_demo.json +++ /dev/null @@ -1,140 +0,0 @@ -[ - { - "messages": [ - { - "content": "<|image_pad|>Who are they?", - "role": "user" - }, - { - "content": "They're Kane and Gretzka from Bayern Munich.", - "role": "assistant" - }, - { - "content": "What are they doing?", - "role": "user" - }, - { - "content": "They are celebrating on the 
soccer field.", - "role": "assistant" - } - ], - "images": [ - "mllm_demo_data/1.jpg" - ] - }, - { - "messages": [ - { - "content": "<|image_pad|>Who is he?", - "role": "user" - }, - { - "content": "He's Thomas Muller from Bayern Munich.", - "role": "assistant" - }, - { - "content": "<|image_pad|>Why is he on the ground?", - "role": "user" - }, - { - "content": "Because he's sliding on his knees to celebrate.", - "role": "assistant" - } - ], - "images": [ - "mllm_demo_data/2.jpg","mllm_demo_data/2.jpg" - ] - }, - { - "messages": [ - { - "content": "<|image_pad|>Please describe this image", - "role": "user" - }, - { - "content": "Chinese astronaut Gui Haichao is giving a speech.", - "role": "assistant" - }, - { - "content": "What has he accomplished?", - "role": "user" - }, - { - "content": "He was appointed to be a payload specialist on Shenzhou 16 mission in June 2022, thus becoming the first Chinese civilian of Group 3 in space on 30 May 2023. He is responsible for the on-orbit operation of space science experimental payloads.", - "role": "assistant" - } - ], - "images": [ - "mllm_demo_data/3.jpg" - ] - }, - { - "messages": [ - { - "content": "<|image_pad|>他们是谁?", - "role": "user" - }, - { - "content": "他们是拜仁慕尼黑的凯恩和格雷茨卡。", - "role": "assistant" - }, - { - "content": "<|image_pad|>他们在做什么?", - "role": "user" - }, - { - "content": "他们在足球场上庆祝。", - "role": "assistant" - } - ], - "images": [ - "mllm_demo_data/1.jpg","mllm_demo_data/1.jpg" - ] - }, - { - "messages": [ - { - "content": "<|image_pad|>他是谁?", - "role": "user" - }, - { - "content": "他是来自拜仁慕尼黑的托马斯·穆勒。", - "role": "assistant" - }, - { - "content": "他为什么在地上?", - "role": "user" - }, - { - "content": "因为他正在双膝跪地滑行庆祝。", - "role": "assistant" - } - ], - "images": [ - "mllm_demo_data/2.jpg" - ] - }, - { - "messages": [ - { - "content": "<|image_pad|>请描述这张图片", - "role": "user" - }, - { - "content": "中国宇航员桂海潮正在讲话。", - "role": "assistant" - }, - { - "content": "他取得过哪些成就?", - "role": "user" - }, - { - "content": "他于2022年6月被任命为神舟十六号任务的有效载荷专家,从而成为2023年5月30日进入太空的首位平民宇航员。他负责在轨操作空间科学实验有效载荷。", - "role": "assistant" - } - ], - "images": [ - "mllm_demo_data/3.jpg" - ] - } -] diff --git a/examples/README.md b/examples/README.md index 34d5f198..e92bf052 100644 --- a/examples/README.md +++ b/examples/README.md @@ -33,6 +33,7 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml ```bash llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml +llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml ``` #### Reward Modeling diff --git a/examples/README_zh.md b/examples/README_zh.md index 037a7fe6..88588c3a 100644 --- a/examples/README_zh.md +++ b/examples/README_zh.md @@ -33,6 +33,7 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml ```bash llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml +llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml ``` #### 奖励模型训练 diff --git a/examples/train_full/qwen2vl_full_sft.yaml b/examples/train_full/qwen2vl_full_sft.yaml deleted file mode 100644 index ea8a3422..00000000 --- a/examples/train_full/qwen2vl_full_sft.yaml +++ /dev/null @@ -1,40 +0,0 @@ -### model -model_name_or_path: qwen2-vl-hf/qwen2-vl-7b-hf -visual_inputs: true - -### method -stage: sft -do_train: true -finetuning_type: full -deepspeed: examples/deepspeed/ds_z3_config.json - -### dataset -dataset: qwen2vl_demo -template: qwen2vl -cutoff_len: 1024 -max_samples: 1000 -overwrite_cache: true -preprocessing_num_workers: 16 - -### output -output_dir: saves/qwen2-vl-7b/full/sft 
-logging_steps: 10 -save_steps: 500 -plot_loss: true -overwrite_output_dir: true - -### train -per_device_train_batch_size: 1 -gradient_accumulation_steps: 1 -learning_rate: 1.0e-5 -num_train_epochs: 3.0 -lr_scheduler_type: cosine -warmup_ratio: 0.1 -bf16: true -ddp_timeout: 180000000 - -### eval -val_size: 0.1 -per_device_eval_batch_size: 1 -eval_strategy: steps -eval_steps: 500 \ No newline at end of file diff --git a/examples/train_lora/qwen2vl_lora_sft.yaml b/examples/train_lora/qwen2vl_lora_sft.yaml index e893ea01..74b58922 100644 --- a/examples/train_lora/qwen2vl_lora_sft.yaml +++ b/examples/train_lora/qwen2vl_lora_sft.yaml @@ -1,5 +1,5 @@ ### model -model_name_or_path: qwen2-vl-hf/qwen2-vl-7b-hf +model_name_or_path: Qwen/Qwen2-VL-7B-Instruct visual_inputs: true ### method @@ -9,23 +9,23 @@ finetuning_type: lora lora_target: all ### dataset -dataset: qwen2vl_demo -template: qwen2vl +dataset: mllm_demo +template: qwen2_vl cutoff_len: 1024 max_samples: 1000 overwrite_cache: true preprocessing_num_workers: 16 ### output -output_dir: saves/qwen2-vl-7b/lora/sft +output_dir: saves/qwen2_vl-7b/lora/sft logging_steps: 10 save_steps: 500 plot_loss: true overwrite_output_dir: true ### train -per_device_train_batch_size: 2 -gradient_accumulation_steps: 1 +per_device_train_batch_size: 1 +gradient_accumulation_steps: 8 learning_rate: 1.0e-4 num_train_epochs: 3.0 lr_scheduler_type: cosine diff --git a/requirements.txt b/requirements.txt index 0cc71ae4..54d58bb3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -transformers>=4.41.2,<=4.43.4 +transformers>=4.41.2,<=4.45.0 datasets>=2.16.0,<=2.21.0 accelerate>=0.30.1,<=0.33.0 peft>=0.11.1,<=0.12.0 diff --git a/src/llamafactory/__init__.py b/src/llamafactory/__init__.py index 5cd86134..ed54278f 100644 --- a/src/llamafactory/__init__.py +++ b/src/llamafactory/__init__.py @@ -20,7 +20,7 @@ Level: Dependency graph: main: - transformers>=4.41.2,<=4.43.4 + transformers>=4.41.2,<=4.44.3 datasets>=2.16.0,<=2.21.0 accelerate>=0.30.1,<=0.33.0 peft>=0.11.1,<=0.12.0 @@ -28,9 +28,9 @@ Dependency graph: attention: transformers>=4.42.4 (gemma+fa2) longlora: - transformers>=4.41.2,<=4.43.4 + transformers>=4.41.2,<=4.44.3 packing: - transformers>=4.41.2,<=4.43.4 + transformers>=4.41.2,<=4.44.3 """ from .cli import VERSION diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index 6e728c2b..dabfca2a 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -22,6 +22,7 @@ import torch from transformers import GenerationConfig, TextIteratorStreamer from ..data import get_template_and_fix_tokenizer +from ..extras.constants import IMAGE_PLACEHOLDER from ..extras.logging import get_logger from ..extras.misc import get_logits_processor from ..model import load_model, load_tokenizer @@ -31,7 +32,6 @@ from .base_engine import BaseEngine, Response if TYPE_CHECKING: from numpy.typing import NDArray from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin - from transformers.image_processing_utils import BaseImageProcessor from trl import PreTrainedModelWrapper from ..data import Template @@ -81,27 +81,19 @@ class HuggingfaceEngine(BaseEngine): image: Optional["NDArray"] = None, input_kwargs: Optional[Dict[str, Any]] = {}, ) -> Tuple[Dict[str, Any], int]: - if ( - processor is not None - and image is not None - and not hasattr(processor, "image_seq_length") - and template.image_token not in messages[0]["content"] - ): # llava-like models - messages[0]["content"] = 
template.image_token + messages[0]["content"] + if image is not None: + if IMAGE_PLACEHOLDER not in messages[0]["content"]: + messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"] + + messages = template.mm_plugin.process_messages(messages, [image], processor) paired_messages = messages + [{"role": "assistant", "content": ""}] system = system or generating_args["default_system"] - pixel_values = None prompt_ids, _ = template.encode_oneturn( tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools ) - if processor is not None and image is not None: # add image features - image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") - batch_feature = image_processor(image, return_tensors="pt") - pixel_values = batch_feature.to(model.device)["pixel_values"] # shape (B, C, H, W) - if hasattr(processor, "image_seq_length"): # paligemma models - image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) - prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids + if image is not None: + prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor) prompt_length = len(prompt_ids) inputs = torch.tensor([prompt_ids], device=model.device) @@ -164,8 +156,13 @@ class HuggingfaceEngine(BaseEngine): logits_processor=get_logits_processor(), ) - if pixel_values is not None: - gen_kwargs["pixel_values"] = pixel_values + if image is not None: + mm_inputs = template.mm_plugin.get_mm_inputs( + images=[image], feature_seqlens={"token_type_ids": prompt_length}, processor=processor + ) + for key, value in mm_inputs.items(): + value = value if isinstance(value, torch.Tensor) else torch.tensor(value) + gen_kwargs[key] = value.to(model.device) return gen_kwargs, prompt_length diff --git a/src/llamafactory/data/__init__.py b/src/llamafactory/data/__init__.py index 4da742b4..9bfd9708 100644 --- a/src/llamafactory/data/__init__.py +++ b/src/llamafactory/data/__init__.py @@ -12,13 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding, SFTDataCollatorWith4DAttentionMask +from .collator import ( + CustomDataCollatorForSeq2Seq, + KTODataCollatorWithPadding, + PairwiseDataCollatorWithPadding, + SFTDataCollatorWith4DAttentionMask, +) from .data_utils import Role, split_dataset from .loader import get_dataset from .template import TEMPLATES, Template, get_template_and_fix_tokenizer __all__ = [ + "CustomDataCollatorForSeq2Seq", "KTODataCollatorWithPadding", "PairwiseDataCollatorWithPadding", "SFTDataCollatorWith4DAttentionMask", diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index a4197a20..0885705a 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -62,15 +62,11 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype @dataclass -class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq): +class CustomDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): r""" - Data collator for 4d attention mask. + Data collator for custom models (like Qwen2-VL). 
""" - block_diag_attn: bool = False - attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager" - compute_dtype: "torch.dtype" = torch.float32 - def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]: image_grid_thw = None if "image_grid_thw" in features[0]: @@ -83,23 +79,18 @@ class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq): torch.Tensor(feature["pixel_values"]) for feature in features if feature["image_grid_thw"][0][0] > 0 ] if image_grid_thw_list: - image_grid_thw = torch.cat(image_grid_thw_list, 0) + image_grid_thw = torch.cat(image_grid_thw_list, dim=0) + pixel_values = torch.cat(pixel_values_list, dim=0) else: - # Handle the case where the list is empty, for example: image_grid_thw = None - if pixel_values_list: - pixel_values = torch.cat(pixel_values_list, 0) - else: - # Handle the case where the list is empty, for example: pixel_values = None + features = [ {key: feature[key] for key in feature if key not in ["image_grid_thw", "pixel_values"]} for feature in features ] features = super().__call__(features) - if self.block_diag_attn and self.attn_implementation != "flash_attention_2": - features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype) if image_grid_thw is not None: features["image_grid_thw"] = image_grid_thw features["pixel_values"] = pixel_values @@ -108,7 +99,25 @@ class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq): @dataclass -class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq): +class SFTDataCollatorWith4DAttentionMask(CustomDataCollatorForSeq2Seq): + r""" + Data collator for 4d attention mask. + """ + + block_diag_attn: bool = False + attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager" + compute_dtype: "torch.dtype" = torch.float32 + + def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]: + features = super().__call__(features) + if self.block_diag_attn and self.attn_implementation != "flash_attention_2": + features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype) + + return features + + +@dataclass +class PairwiseDataCollatorWithPadding(CustomDataCollatorForSeq2Seq): r""" Data collator for pairwise data. """ @@ -128,9 +137,12 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq): "attention_mask": feature["{}_attention_mask".format(key)], "labels": feature["{}_labels".format(key)], } - if "pixel_values" in feature: + if "pixel_values" in feature: # image data are same for chosen and rejected target_feature["pixel_values"] = feature["pixel_values"] + if "image_grid_thw" in feature: + target_feature["image_grid_thw"] = feature["image_grid_thw"] + if "{}_token_type_ids".format(key) in feature: target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)] @@ -140,7 +152,7 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq): @dataclass -class KTODataCollatorWithPadding(DataCollatorForSeq2Seq): +class KTODataCollatorWithPadding(CustomDataCollatorForSeq2Seq): r""" Data collator for KTO data. 
""" @@ -163,6 +175,9 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq): if "pixel_values" in feature: target_feature["pixel_values"] = feature["pixel_values"] + if "image_grid_thw" in feature: + target_feature["image_grid_thw"] = feature["image_grid_thw"] + if "token_type_ids" in feature: target_feature["token_type_ids"] = feature["token_type_ids"] kl_feature["token_type_ids"] = feature["kl_token_type_ids"] diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py new file mode 100644 index 00000000..714f09fb --- /dev/null +++ b/src/llamafactory/data/mm_plugin.py @@ -0,0 +1,271 @@ +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple + +from PIL.Image import Image +from transformers import ProcessorMixin + +from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER +from ..extras.packages import is_pillow_available + + +if is_pillow_available(): + import torch + from PIL import Image + + +if TYPE_CHECKING: + from PIL.Image import Image as ImageObject + from transformers import PreTrainedTokenizer, ProcessorMixin + from transformers.image_processing_utils import BaseImageProcessor + + +def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "torch.Tensor": + r""" + Processes visual inputs. (currently only supports a single image) + + Returns: + pixel_values: tensor with shape (B, C, H, W) + """ + image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") + image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255)) + return image_processor([image], return_tensors="pt")["pixel_values"] + + +def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[List[int]]: + r""" + Gets paligemma token type ids for computing loss. + + Returns: + token_type_ids: shape (1, seq_len) + """ + image_seq_length = getattr(processor, "image_seq_length") + return [[0] * image_seq_length + [1] * (input_len - image_seq_length)] + + +def get_qwen2vl_image_inputs( + images: Sequence["ImageObject"], processor: "ProcessorMixin" +) -> Dict[str, "torch.Tensor"]: + r""" + Processes qwen2-vl visual inputs. Supports multiple images. 
+ + Returns: + pixel_values: tensor with shape (num_patches, patch_dim) + image_grid_thw: tensot with shape (num_images, 3), where the three numbers are time, width, height + + It holds num_patches == torch.prod(image_grid_thw) + """ + image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") + if len(images) != 0: + image_inputs = image_processor(images=images, return_tensors="pt") + else: + image = Image.new("RGB", (56, 56), (255, 255, 255)) + image_inputs = image_processor(images=[image], return_tensors="pt") + image_inputs["image_grid_thw"][0][0] = 0 # fake image + + return {"pixel_values": image_inputs["pixel_values"], "image_grid_thw": image_inputs["image_grid_thw"]} + + +class BasePlugin: + def __init__(self, image_token: str) -> None: + self.image_token = image_token + + def process_messages( + self, + messages: Sequence[Dict[str, str]], + images: Sequence["ImageObject"], + processor: Optional["ProcessorMixin"], + ) -> List[Dict[str, str]]: + return messages + + def process_token_ids( + self, + input_ids: List[int], + labels: Optional[List[int]], + tokenizer: "PreTrainedTokenizer", + processor: Optional["ProcessorMixin"], + ) -> Tuple[List[int], Optional[List[int]]]: + return input_ids, labels + + def get_mm_inputs( + self, + images: Sequence["ImageObject"], + feature_seqlens: Dict[str, int], + processor: Optional["ProcessorMixin"], + ) -> Dict[str, Any]: + return {} + + def process_model_inputs( + self, + model_inputs: Dict[str, List[Any]], + images: Sequence["ImageObject"], + feature_seqlens: Dict[str, int], + processor: Optional["ProcessorMixin"], + ) -> None: + return + + +class LlavaPlugin(BasePlugin): + def process_messages( + self, + messages: Sequence[Dict[str, str]], + images: Sequence["ImageObject"], + processor: Optional["ProcessorMixin"], + ) -> List[Dict[str, str]]: + image_count = 0 + new_messages = [] + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + image_count += 1 + if image_count > 1: + raise ValueError("Llava model only accepts one image per sample.") + + content = content.replace(IMAGE_PLACEHOLDER, self.image_token, 1) + + new_messages.append({"role": message["role"], "content": content}) + + return new_messages + + def get_mm_inputs( + self, + images: Sequence["ImageObject"], + feature_seqlens: Dict[str, int], + processor: Optional["ProcessorMixin"], + ) -> Dict[str, Any]: + return {"pixel_values": get_pixel_values(images, processor)} + + def process_model_inputs( + self, + model_inputs: Dict[str, List[Any]], + images: Sequence["ImageObject"], + feature_seqlens: Dict[str, int], + processor: Optional["ProcessorMixin"], + ) -> None: + mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor) + model_inputs["pixel_values"].append(mm_inputs["pixel_values"][0]) + + +class PaliGemmaPlugin(BasePlugin): + def process_messages( + self, + messages: Sequence[Dict[str, str]], + images: Sequence["ImageObject"], + processor: Optional["ProcessorMixin"], + ) -> List[Dict[str, str]]: + image_count = 0 + new_messages = [] + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + image_count += 1 + if image_count > 1: + raise ValueError("PaliGemma model only accepts one image per sample.") + + content = content.replace(IMAGE_PLACEHOLDER, "", 1) + + new_messages.append({"role": message["role"], "content": content}) + + return new_messages + + def process_token_ids( + self, + input_ids: List[int], + labels: Optional[List[int]], + tokenizer: "PreTrainedTokenizer", + 
processor: Optional["ProcessorMixin"], + ) -> Tuple[List[int], Optional[List[int]]]: + image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") + image_seq_length: int = getattr(image_processor, "image_seq_length") + image_token_id = tokenizer.convert_tokens_to_ids(self.image_token) + input_ids = [image_token_id] * image_seq_length + input_ids + if labels is not None: + labels = [IGNORE_INDEX] * image_seq_length + labels + + return input_ids, labels + + def get_mm_inputs( + self, + images: Sequence["ImageObject"], + feature_seqlens: Dict[str, int], + processor: Optional["ProcessorMixin"], + ) -> Dict[str, Any]: + mm_inputs = {"pixel_values": get_pixel_values(images, processor)} + for feature_name, feature_length in feature_seqlens.items(): + mm_inputs[feature_name] = get_paligemma_token_type_ids(feature_length, processor) + + return mm_inputs + + def process_model_inputs( + self, + model_inputs: Dict[str, List[Any]], + images: Sequence["ImageObject"], + feature_seqlens: Dict[str, int], + processor: Optional["ProcessorMixin"], + ) -> None: + mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor) + model_inputs["pixel_values"].append(mm_inputs["pixel_values"][0]) + for feature_name in feature_seqlens.keys(): + model_inputs[feature_name].append(mm_inputs[feature_name][0]) + + +class Qwen2vlPlugin(BasePlugin): + def process_messages( + self, + messages: Sequence[Dict[str, str]], + images: Sequence["ImageObject"], + processor: Optional["ProcessorMixin"], + ) -> List[Dict[str, str]]: + image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") + merge_length: int = getattr(image_processor, "merge_size") ** 2 + if len(images) > 0: + image_grid_thw = get_qwen2vl_image_inputs(images, processor)["image_grid_thw"] + + index = 0 + new_messages = [] + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + content = content.replace( + IMAGE_PLACEHOLDER, + "<|vision_start|>{}<|vision_end|>".format( + self.image_token * (image_grid_thw[index].prod() // merge_length) + ), + 1, + ) + index += 1 + + new_messages.append({"role": message["role"], "content": content}) + + return new_messages + + def get_mm_inputs( + self, + images: Sequence["ImageObject"], + feature_seqlens: Dict[str, int], + processor: Optional["ProcessorMixin"], + ) -> Dict[str, Any]: + return get_qwen2vl_image_inputs(images, processor) + + def process_model_inputs( + self, + model_inputs: Dict[str, List[Any]], + images: Sequence["ImageObject"], + feature_seqlens: Dict[str, int], + processor: Optional["ProcessorMixin"], + ) -> None: + mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor) + model_inputs["pixel_values"].append(mm_inputs["pixel_values"]) + model_inputs["image_grid_thw"].append(mm_inputs["image_grid_thw"]) + + +PLUGINS = { + "llava": LlavaPlugin, + "paligemma": PaliGemmaPlugin, + "qwen2_vl": Qwen2vlPlugin, +} + + +def get_mm_plugin(name: str, image_token: str) -> "BasePlugin": + if name not in PLUGINS: + raise ValueError("{} not found.".format(name)) + + return PLUGINS[name](image_token) diff --git a/src/llamafactory/data/processors/feedback.py b/src/llamafactory/data/processors/feedback.py index bed0c33c..c09ef488 100644 --- a/src/llamafactory/data/processors/feedback.py +++ b/src/llamafactory/data/processors/feedback.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from ...extras.constants import IGNORE_INDEX from ...extras.logging import get_logger -from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen +from .processor_utils import infer_seqlen if TYPE_CHECKING: @@ -40,9 +41,6 @@ def _encode_feedback_example( processor: Optional["ProcessorMixin"], cutoff_len: int, ) -> Tuple[List[int], List[int], List[int], List[int], bool]: - if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models - prompt[0]["content"] = template.image_token + prompt[0]["content"] - if response[0]["content"]: # desired example kto_tag = True messages = prompt + [response[0]] @@ -62,10 +60,8 @@ def _encode_feedback_example( response_ids += [tokenizer.eos_token_id] kl_response_ids += [tokenizer.eos_token_id] - if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models - image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) - prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids - kl_prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + kl_prompt_ids + prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor) + kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, tokenizer, processor) source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len) prompt_ids = prompt_ids[:source_len] @@ -91,28 +87,15 @@ def preprocess_feedback_dataset( ) -> Dict[str, List[List[int]]]: # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs kl_response = examples["response"][::-1] - model_inputs = { - "input_ids": [], - "attention_mask": [], - "labels": [], - "kl_input_ids": [], - "kl_attention_mask": [], - "kl_labels": [], - "kto_tags": [], - } - if processor is not None: - model_inputs["pixel_values"] = [] - if hasattr(processor, "image_seq_length"): # paligemma models - model_inputs["token_type_ids"] = [] - model_inputs["kl_token_type_ids"] = [] - + model_inputs = defaultdict(list) for i in range(len(examples["prompt"])): if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2: logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) continue + prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor) input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example( - prompt=examples["prompt"][i], + prompt=prompt, response=examples["response"][i], kl_response=kl_response[i], system=examples["system"][i], @@ -129,11 +112,15 @@ def preprocess_feedback_dataset( model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids)) model_inputs["kl_labels"].append(kl_labels) model_inputs["kto_tags"].append(kto_tag) - if processor is not None: - model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor)) - if hasattr(processor, "image_seq_length"): # paligemma models - model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor)) - model_inputs["kl_token_type_ids"].append(get_paligemma_token_type_ids(len(kl_input_ids), processor)) + template.mm_plugin.process_model_inputs( + model_inputs=model_inputs, + images=examples["images"][i], + feature_seqlens={ + "token_type_ids": len(input_ids), + "kl_token_type_ids": len(kl_input_ids), + }, 
+ processor=processor, + ) desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag]) undesirable_num = len(model_inputs["kto_tags"]) - desirable_num diff --git a/src/llamafactory/data/processors/pairwise.py b/src/llamafactory/data/processors/pairwise.py index ddf885b5..fec25783 100644 --- a/src/llamafactory/data/processors/pairwise.py +++ b/src/llamafactory/data/processors/pairwise.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from ...extras.constants import IGNORE_INDEX from ...extras.logging import get_logger -from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen +from .processor_utils import infer_seqlen if TYPE_CHECKING: @@ -39,9 +40,6 @@ def _encode_pairwise_example( processor: Optional["ProcessorMixin"], cutoff_len: int, ) -> Tuple[List[int], List[int], List[int], List[int]]: - if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models - prompt[0]["content"] = template.image_token + prompt[0]["content"] - chosen_messages = prompt + [response[0]] rejected_messages = prompt + [response[1]] prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools) @@ -51,10 +49,7 @@ def _encode_pairwise_example( chosen_ids += [tokenizer.eos_token_id] rejected_ids += [tokenizer.eos_token_id] - if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models - image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) - prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids - + prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor) # consider the response is more important source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len) prompt_ids = prompt_ids[:source_len] @@ -77,27 +72,15 @@ def preprocess_pairwise_dataset( data_args: "DataArguments", ) -> Dict[str, List[List[int]]]: # build input pairs with format ` X`, `Y1 ` and `Y2 ` - model_inputs = { - "chosen_input_ids": [], - "chosen_attention_mask": [], - "chosen_labels": [], - "rejected_input_ids": [], - "rejected_attention_mask": [], - "rejected_labels": [], - } - if processor is not None: - model_inputs["pixel_values"] = [] - if hasattr(processor, "image_seq_length"): # paligemma models - model_inputs["chosen_token_type_ids"] = [] - model_inputs["rejected_token_type_ids"] = [] - + model_inputs = defaultdict(list) for i in range(len(examples["prompt"])): if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2: logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) continue + prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor) chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example( - prompt=examples["prompt"][i], + prompt=prompt, response=examples["response"][i], system=examples["system"][i], tools=examples["tools"][i], @@ -112,15 +95,15 @@ def preprocess_pairwise_dataset( model_inputs["rejected_input_ids"].append(rejected_input_ids) model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids)) model_inputs["rejected_labels"].append(rejected_labels) - if processor is not None: - 
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor)) - if hasattr(processor, "image_seq_length"): # paligemma models - model_inputs["chosen_token_type_ids"].append( - get_paligemma_token_type_ids(len(chosen_input_ids), processor) - ) - model_inputs["rejected_token_type_ids"].append( - get_paligemma_token_type_ids(len(rejected_input_ids), processor) - ) + template.mm_plugin.process_model_inputs( + model_inputs=model_inputs, + images=examples["images"][i], + feature_seqlens={ + "chosen_token_type_ids": len(chosen_input_ids), + "rejected_token_type_ids": len(rejected_input_ids), + }, + processor=processor, + ) return model_inputs diff --git a/src/llamafactory/data/processors/processor_utils.py b/src/llamafactory/data/processors/processor_utils.py index 845ef8f8..8e13d100 100644 --- a/src/llamafactory/data/processors/processor_utils.py +++ b/src/llamafactory/data/processors/processor_utils.py @@ -13,20 +13,7 @@ # limitations under the License. import bisect -from typing import TYPE_CHECKING, List, Sequence, Tuple - -from ...extras.packages import is_pillow_available - - -if is_pillow_available(): - from PIL import Image - - -if TYPE_CHECKING: - from numpy.typing import NDArray - from PIL.Image import Image as ImageObject - from transformers import ProcessorMixin - from transformers.image_processing_utils import BaseImageProcessor +from typing import List, Sequence, Tuple def search_for_fit(numbers: Sequence[int], capacity: int) -> int: @@ -61,37 +48,6 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]: return knapsacks -def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray": - r""" - Processes visual inputs. (currently only supports a single image) - """ - image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") - image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255)) - return image_processor(image, return_tensors="pt")["pixel_values"][0] # shape (C, H, W) - - -def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]: - r""" - Gets paligemma token type ids for computing loss. - """ - image_seq_length = getattr(processor, "image_seq_length") - return [0] * image_seq_length + [1] * (input_len - image_seq_length) - - -def get_qwen2vl_image_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray": - r""" - Processes visual inputs. support multi images - """ - image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") - if len(images) != 0: - image_inputs = image_processor(images=images, return_tensors="pt") - else: - image = Image.new("RGB", (56, 56), (255, 255, 255)) - image_inputs = image_processor(images=[image], return_tensors="pt") - image_inputs["image_grid_thw"][0][0] = 0 - return {"pixel_values": image_inputs["pixel_values"], "image_grid_thw": image_inputs["image_grid_thw"]} - - def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]: r""" Computes the real sequence length after truncation by the cutoff_len. 
diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py index f324036e..6f857d24 100644 --- a/src/llamafactory/data/processors/supervised.py +++ b/src/llamafactory/data/processors/supervised.py @@ -17,17 +17,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from ...extras.constants import IGNORE_INDEX from ...extras.logging import get_logger -from .processor_utils import ( - get_paligemma_token_type_ids, - get_pixel_values, - get_qwen2vl_image_inputs, - greedy_knapsack, - infer_seqlen, -) +from .processor_utils import greedy_knapsack, infer_seqlen if TYPE_CHECKING: - from PIL.Image import Image as ImageObject from transformers import PreTrainedTokenizer, ProcessorMixin from ...hparams import DataArguments @@ -43,41 +36,15 @@ def _encode_supervised_example( system: Optional[str], tools: Optional[str], template: "Template", - images: Sequence["ImageObject"], tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], cutoff_len: int, train_on_prompt: bool, mask_history: bool, ) -> Tuple[List[int], List[int]]: - if processor is not None and "image_grid_thw" in processor.model_input_names: # qwen2_vl models - image_processor = getattr(processor, "image_processor") - merge_length = image_processor.merge_size**2 - if len(images) > 0: - image_grid_thw = get_qwen2vl_image_inputs(images, processor)["image_grid_thw"] - index = 0 - for message in prompt: - content = message["content"] - while "<|image_pad|>" in content: - content = content.replace( - "<|image_pad|>", - template.vision_start_token - + "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length) - + template.vision_end_token, - 1, - ) - index += 1 - message["content"] = content.replace("<|placeholder|>", "<|image_pad|>") - elif processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models - prompt[0]["content"] = template.image_token + prompt[0]["content"] - messages = prompt + response input_ids, labels = [], [] - - if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models - image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) - input_ids += [image_token_id] * getattr(processor, "image_seq_length") - labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length") + input_ids, labels = template.mm_plugin.process_token_ids(input_ids, labels, tokenizer, processor) encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools) total_length = 1 if template.efficient_eos else 0 @@ -125,28 +92,21 @@ def preprocess_supervised_dataset( tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], data_args: "DataArguments", -) -> Dict[str, List[List[int]]]: +) -> Dict[str, List[Any]]: # build inputs with format ` X Y ` and labels with format ` ... Y ` # for multiturn examples, we only mask the prompt part in each prompt-response pair. 
- model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} - if processor is not None: - model_inputs["pixel_values"] = [] - if hasattr(processor, "image_seq_length"): # paligemma models - model_inputs["token_type_ids"] = [] - if "image_grid_thw" in processor.model_input_names: # qwen2_vl models - model_inputs["image_grid_thw"] = [] - + model_inputs = defaultdict(list) for i in range(len(examples["prompt"])): if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1: logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) continue + prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor) input_ids, labels = _encode_supervised_example( - prompt=examples["prompt"][i], + prompt=prompt, response=examples["response"][i], system=examples["system"][i], tools=examples["tools"][i], - images=examples["images"][i], template=template, tokenizer=tokenizer, processor=processor, @@ -157,15 +117,12 @@ def preprocess_supervised_dataset( model_inputs["input_ids"].append(input_ids) model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["labels"].append(labels) - if processor is not None: - if "image_grid_thw" in processor.model_input_names: # qwen2_vl models - image_inputs = get_qwen2vl_image_inputs(examples["images"][i], processor) - model_inputs["pixel_values"].append(image_inputs["pixel_values"]) - model_inputs["image_grid_thw"].append(image_inputs["image_grid_thw"]) - else: - model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor)) - if hasattr(processor, "image_seq_length"): # paligemma models - model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor)) + template.mm_plugin.process_model_inputs( + model_inputs=model_inputs, + images=examples["images"][i], + feature_seqlens={"token_type_ids": len(input_ids)}, + processor=processor, + ) return model_inputs @@ -175,7 +132,7 @@ def preprocess_packed_supervised_dataset( template: "Template", tokenizer: "PreTrainedTokenizer", data_args: "DataArguments", -) -> Dict[str, List[List[int]]]: +) -> Dict[str, List[Any]]: # build inputs with format ` X1 Y1 X2 Y2 ` # and labels with format ` ... Y1 ... Y2 ` valid_num = 0 @@ -209,7 +166,7 @@ def preprocess_packed_supervised_dataset( batch_labels.append(labels) valid_num += 1 - model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} + model_inputs = defaultdict(list) knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1) # reserved for the padding token for knapsack in knapsacks: packed_input_ids, packed_attention_masks, packed_labels = [], [], [] diff --git a/src/llamafactory/data/processors/unsupervised.py b/src/llamafactory/data/processors/unsupervised.py index 7bd1904b..cf9ff643 100644 --- a/src/llamafactory/data/processors/unsupervised.py +++ b/src/llamafactory/data/processors/unsupervised.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from ...extras.logging import get_logger from ..data_utils import Role -from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen +from .processor_utils import infer_seqlen if TYPE_CHECKING: @@ -39,9 +40,6 @@ def _encode_unsupervised_example( processor: Optional["ProcessorMixin"], cutoff_len: int, ) -> Tuple[List[int], List[int]]: - if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models - prompt[0]["content"] = template.image_token + prompt[0]["content"] - if len(response) == 1: messages = prompt + response else: @@ -51,10 +49,7 @@ def _encode_unsupervised_example( if template.efficient_eos: labels += [tokenizer.eos_token_id] - if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models - image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) - input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids - + input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, tokenizer, processor) source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len) input_ids = input_ids[:source_len] labels = labels[:target_len] @@ -69,19 +64,15 @@ def preprocess_unsupervised_dataset( data_args: "DataArguments", ) -> Dict[str, List[List[int]]]: # build inputs with format ` X` and labels with format `Y ` - model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} - if processor is not None: - model_inputs["pixel_values"] = [] - if hasattr(processor, "image_seq_length"): # paligemma models - model_inputs["token_type_ids"] = [] - + model_inputs = defaultdict(list) for i in range(len(examples["prompt"])): if len(examples["prompt"][i]) % 2 != 1: logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) continue + prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor) input_ids, labels = _encode_unsupervised_example( - prompt=examples["prompt"][i], + prompt=prompt, response=examples["response"][i], system=examples["system"][i], tools=examples["tools"][i], @@ -93,10 +84,12 @@ def preprocess_unsupervised_dataset( model_inputs["input_ids"].append(input_ids) model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["labels"].append(labels) - if processor is not None: - model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor)) - if hasattr(processor, "image_seq_length"): # paligemma models - model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor)) + template.mm_plugin.process_model_inputs( + model_inputs=model_inputs, + images=examples["images"][i], + feature_seqlens={"token_type_ids": len(input_ids)}, + processor=processor, + ) return model_inputs diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index cce97357..fe0a104f 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -15,9 +15,11 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union +from ..extras.constants import IMAGE_PLACEHOLDER from ..extras.logging import get_logger from .data_utils import Role from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter +from .mm_plugin import BasePlugin, get_mm_plugin if TYPE_CHECKING: @@ -41,11 +43,9 @@ class 
Template: format_prefix: "Formatter" default_system: str stop_words: List[str] - image_token: str - vision_start_token: str - vision_end_token: str efficient_eos: bool replace_eos: bool + mm_plugin: "BasePlugin" def encode_oneturn( self, @@ -207,11 +207,9 @@ def _register_template( format_prefix: Optional["Formatter"] = None, default_system: str = "", stop_words: Sequence[str] = [], - image_token: str = "", - vision_start_token: str = "<|vision_start|>", - vision_end_token: str = "<|vision_end|>", efficient_eos: bool = False, replace_eos: bool = False, + mm_plugin: "BasePlugin" = BasePlugin(IMAGE_PLACEHOLDER), ) -> None: r""" Registers a chat template. @@ -258,11 +256,9 @@ def _register_template( format_prefix=format_prefix or default_prefix_formatter, default_system=default_system, stop_words=stop_words, - image_token=image_token, - vision_start_token=vision_start_token, - vision_end_token=vision_end_token, efficient_eos=efficient_eos, replace_eos=replace_eos, + mm_plugin=mm_plugin, ) @@ -722,6 +718,17 @@ _register_template( ) +_register_template( + name="llava", + format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]), + default_system=( + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." + ), + mm_plugin=get_mm_plugin(name="llava", image_token=""), +) + + _register_template( name="mistral", format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]), @@ -766,6 +773,19 @@ _register_template( ) +_register_template( + name="paligemma", + format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), + format_observation=StringFormatter( + slots=["tool\n{{content}}\nmodel\n"] + ), + format_separator=EmptyFormatter(slots=["\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + efficient_eos=True, + mm_plugin=get_mm_plugin(name="paligemma", image_token=""), +) + + _register_template( name="phi", format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]), @@ -790,17 +810,15 @@ _register_template( _register_template( - name="qwen2vl", + name="qwen2_vl", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_separator=EmptyFormatter(slots=["\n"]), default_system="You are a helpful assistant.", - image_token="<|image_pad|>", - vision_start_token="<|vision_start|>", - vision_end_token="<|vision_end|>", stop_words=["<|im_end|>"], replace_eos=True, + mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>"), ) @@ -915,6 +933,7 @@ _register_template( ), stop_words=["###"], efficient_eos=True, + mm_plugin=get_mm_plugin(name="llava", image_token=""), ) diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index 80df135c..fc2d3460 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -47,6 +47,8 @@ FILEEXT2TYPE = { IGNORE_INDEX = -100 +IMAGE_PLACEHOLDER = "" + LAYERNORM_NAMES = {"norm", "ln"} LLAMABOARD_CONFIG = "llamaboard_config.yaml" @@ -785,7 +787,7 @@ register_model_group( DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf", }, }, - template="vicuna", + template="llava", vision=True, ) @@ -930,27 +932,28 @@ register_model_group( register_model_group( models={ - "PaliGemma-3B-pt-224": { + 
"PaliGemma-3B-pt-224-Chat": { DownloadSource.DEFAULT: "google/paligemma-3b-pt-224", DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-224", }, - "PaliGemma-3B-pt-448": { + "PaliGemma-3B-pt-448-Chat": { DownloadSource.DEFAULT: "google/paligemma-3b-pt-448", DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-448", }, - "PaliGemma-3B-pt-896": { + "PaliGemma-3B-pt-896-Chat": { DownloadSource.DEFAULT: "google/paligemma-3b-pt-896", DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-896", }, - "PaliGemma-3B-mix-224": { + "PaliGemma-3B-mix-224-Chat": { DownloadSource.DEFAULT: "google/paligemma-3b-mix-224", DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-224", }, - "PaliGemma-3B-mix-448": { + "PaliGemma-3B-mix-448-Chat": { DownloadSource.DEFAULT: "google/paligemma-3b-mix-448", DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-448", }, }, + template="paligemma", vision=True, ) @@ -1329,6 +1332,34 @@ register_model_group( ) +register_model_group( + models={ + "Qwen2VL-2B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct", + DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct", + }, + "Qwen2VL-7B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct", + DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct", + }, + "Qwen2VL-2B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8", + }, + "Qwen2VL-2B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-AWQ", + }, + "Qwen2VL-7B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8", + }, + "Qwen2VL-7B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-AWQ", + }, + }, + template="qwen2_vl", + vision=True, +) + + register_model_group( models={ "SOLAR-10.7B": { diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py index 8908b807..d3105e65 100644 --- a/src/llamafactory/extras/misc.py +++ b/src/llamafactory/extras/misc.py @@ -79,7 +79,7 @@ def check_dependencies() -> None: if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]: logger.warning("Version checking has been disabled, may lead to unexpected behaviors.") else: - require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4") + require_version("transformers>=4.41.2,<=4.45.0", "To fix: pip install transformers>=4.41.2,<=4.45.0") require_version("datasets>=2.16.0,<=2.21.0", "To fix: pip install datasets>=2.16.0,<=2.21.0") require_version("accelerate>=0.30.1,<=0.33.0", "To fix: pip install accelerate>=0.30.1,<=0.33.0") require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0") diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index f209e338..b25e31e0 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -117,7 +117,7 @@ class ModelArguments: default=False, metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."}, ) - use_liger_kernel: bool = field( + enable_liger_kernel: bool = field( default=False, metadata={"help": "Whether or not to enable liger kernel for faster training."}, ) diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index 00a4c72c..7e3e39cd 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -116,7 +116,7 @@ def _check_extra_dependencies( if model_args.use_unsloth: require_version("unsloth", "Please install unsloth: 

-    if model_args.use_liger_kernel:
+    if model_args.enable_liger_kernel:
         require_version("liger-kernel", "To fix: pip install liger-kernel")

     if model_args.mixture_of_depths is not None:
diff --git a/src/llamafactory/model/model_utils/liger_kernel.py b/src/llamafactory/model/model_utils/liger_kernel.py
index e40169ad..31edd97c 100644
--- a/src/llamafactory/model/model_utils/liger_kernel.py
+++ b/src/llamafactory/model/model_utils/liger_kernel.py
@@ -27,7 +27,7 @@ logger = get_logger(__name__)


 def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
-    if not is_trainable or not model_args.use_liger_kernel:
+    if not is_trainable or not model_args.enable_liger_kernel:
         return

     if getattr(config, "model_type", None) == "gemma":
diff --git a/src/llamafactory/model/model_utils/longlora.py b/src/llamafactory/model/model_utils/longlora.py
index e518aefb..ef39bcd9 100644
--- a/src/llamafactory/model/model_utils/longlora.py
+++ b/src/llamafactory/model/model_utils/longlora.py
@@ -353,7 +353,7 @@ def llama_sdpa_attention_forward(


 def _apply_llama_patch() -> None:
-    require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4")
+    require_version("transformers>=4.41.2,<=4.44.3", "To fix: pip install transformers>=4.41.2,<=4.44.3")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
diff --git a/src/llamafactory/model/model_utils/misc.py b/src/llamafactory/model/model_utils/misc.py
index 96233002..d49222a3 100644
--- a/src/llamafactory/model/model_utils/misc.py
+++ b/src/llamafactory/model/model_utils/misc.py
@@ -36,11 +36,14 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
         forbidden_modules.add("output")
     elif model.config.model_type in ["llava", "paligemma"]:
         forbidden_modules.add("multi_modal_projector")
-    elif model.config.model_type in ["qwen2_vl"]:
+    elif model.config.model_type == "qwen2_vl":
         forbidden_modules.add("merger")

     if freeze_vision_tower:
-        forbidden_modules.add("vision_tower")
+        if model.config.model_type == "qwen2_vl":
+            forbidden_modules.add("visual")
+        else:
+            forbidden_modules.add("vision_tower")

     module_names = set()
     for name, module in model.named_modules():
diff --git a/src/llamafactory/model/model_utils/packing.py b/src/llamafactory/model/model_utils/packing.py
index ded7f295..3d7f2dad 100644
--- a/src/llamafactory/model/model_utils/packing.py
+++ b/src/llamafactory/model/model_utils/packing.py
@@ -114,7 +114,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor


 def _patch_for_block_diag_attn(model_type: str) -> None:
-    require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4")
+    require_version("transformers>=4.41.2,<=4.44.3", "To fix: pip install transformers>=4.41.2,<=4.44.3")

     if is_transformers_version_greater_than_4_43():
         import transformers.modeling_flash_attention_utils
diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py
index deb3fce2..5cb2518f 100644
--- a/src/llamafactory/train/kto/trainer.py
+++ b/src/llamafactory/train/kto/trainer.py
@@ -130,6 +130,9 @@ class CustomKTOTrainer(KTOTrainer):
         if "pixel_values" in batch:
             model_inputs["pixel_values"] = batch["pixel_values"]

+        if "image_grid_thw" in batch:
+            model_inputs["image_grid_thw"] = batch["image_grid_thw"]
+
         if "{}token_type_ids".format(prefix) in batch:
"{}token_type_ids".format(prefix) in batch: model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)] diff --git a/src/llamafactory/train/ppo/workflow.py b/src/llamafactory/train/ppo/workflow.py index 6cea52d9..a2685f33 100644 --- a/src/llamafactory/train/ppo/workflow.py +++ b/src/llamafactory/train/ppo/workflow.py @@ -17,9 +17,7 @@ from typing import TYPE_CHECKING, List, Optional -from transformers import DataCollatorWithPadding - -from ...data import get_dataset +from ...data import CustomDataCollatorForSeq2Seq, get_dataset from ...extras.ploting import plot_loss from ...model import load_model, load_tokenizer from ..callbacks import fix_valuehead_checkpoint @@ -47,7 +45,7 @@ def run_ppo( model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True) tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training - data_collator = DataCollatorWithPadding(tokenizer=tokenizer) + data_collator = CustomDataCollatorForSeq2Seq(tokenizer=tokenizer) # Create reference model and reward model ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True) diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index 72176986..67d910fa 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -115,7 +115,7 @@ class Runner: rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto", use_unsloth=(get("top.booster") == "unsloth"), - use_liger_kernel=(get("top.booster") == "liger_kernel"), + enable_liger_kernel=(get("top.booster") == "liger_kernel"), visual_inputs=get("top.visual_inputs"), dataset_dir=get("train.dataset_dir"), dataset=",".join(get("train.dataset")),