diff --git a/README.md b/README.md
index 33e3fe2c..a4805c8a 100644
--- a/README.md
+++ b/README.md
@@ -291,6 +291,7 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
 - [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
 - [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
+- [RLHF-V (en)](https://huggingface.co/datasets/openbmb/RLHF-V-Dataset)
 - [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
 - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
 - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
diff --git a/README_zh.md b/README_zh.md
index 9465f027..883c7945 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -292,6 +292,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 - [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
 - [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
+- [RLHF-V (en)](https://huggingface.co/datasets/openbmb/RLHF-V-Dataset)
 - [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
 - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
 - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
diff --git a/data/dataset_info.json b/data/dataset_info.json
index b00456d2..02597150 100644
--- a/data/dataset_info.json
+++ b/data/dataset_info.json
@@ -433,6 +433,17 @@
       "rejected": "rejected"
     }
   },
+  "rlhf_v": {
+    "hf_hub_url": "llamafactory/RLHF-V",
+    "ranking": true,
+    "formatting": "sharegpt",
+    "columns": {
+      "messages": "conversations",
+      "chosen": "chosen",
+      "rejected": "rejected",
+      "images": "images"
+    }
+  },
   "orca_pairs": {
     "hf_hub_url": "Intel/orca_dpo_pairs",
     "ranking": true,
diff --git a/examples/README.md b/examples/README.md
index d6dccb1c..5df1886f 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -36,6 +36,18 @@ llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
 llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
 ```

+#### DPO/ORPO/SimPO Training
+
+```bash
+llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
+```
+
+#### Multimodal DPO/ORPO/SimPO Training
+
+```bash
+llamafactory-cli train examples/train_lora/qwen2vl_lora_dpo.yaml
+```
+
 #### Reward Modeling

 ```bash
@@ -48,12 +60,6 @@ llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
 llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
 ```

-#### DPO/ORPO/SimPO Training
-
-```bash
-llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
-```
-
 #### KTO Training

 ```bash
diff --git a/examples/README_zh.md b/examples/README_zh.md
index 037136a1..46d43402 100644
--- a/examples/README_zh.md
+++ b/examples/README_zh.md
@@ -36,6 +36,18 @@ llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
 llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
 ```

+#### DPO/ORPO/SimPO 训练
+
+```bash
+llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
+```
+
+#### 多模态 DPO/ORPO/SimPO 训练
+
+```bash
+llamafactory-cli train examples/train_lora/qwen2vl_lora_dpo.yaml
+```
+
 #### 奖励模型训练

 ```bash
@@ -48,12 +60,6 @@ llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
 llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
 ```

-#### DPO/ORPO/SimPO 训练
-
-```bash
-llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
-```
-
 #### KTO 训练

 ```bash
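
The `rlhf_v` entry registered in `data/dataset_info.json` above declares a pairwise (`"ranking": true`) sharegpt-format dataset with an extra image column. As a rough illustration of the record shape that column mapping expects (the field contents below are hypothetical, not copied from the actual dataset):

```python
# Hypothetical example of one pairwise multimodal record under the
# "rlhf_v" column mapping: "conversations" holds the shared prompt
# turns, "chosen"/"rejected" the two candidate answers, and "images"
# the inputs referenced by <image> placeholders in the prompt.
sample = {
    "conversations": [
        {"from": "human", "value": "<image>What is happening in this picture?"},
    ],
    "chosen": {"from": "gpt", "value": "A dog is catching a frisbee mid-air."},
    "rejected": {"from": "gpt", "value": "A cat is sleeping on a windowsill."},
    "images": ["relative/path/to/image.jpg"],
}
```
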
diff --git a/examples/train_lora/qwen2vl_lora_dpo.yaml b/examples/train_lora/qwen2vl_lora_dpo.yaml
new file mode 100644
index 00000000..4ff72cea
--- /dev/null
+++ b/examples/train_lora/qwen2vl_lora_dpo.yaml
@@ -0,0 +1,41 @@
+### model
+model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
+
+### method
+stage: dpo
+do_train: true
+finetuning_type: lora
+lora_target: all
+pref_beta: 0.1
+pref_loss: sigmoid  # choices: [sigmoid (dpo), orpo, simpo]
+
+### dataset
+dataset: rlhf_v
+template: qwen2_vl
+cutoff_len: 1024
+max_samples: 1000
+overwrite_cache: true
+preprocessing_num_workers: 16
+
+### output
+output_dir: saves/qwen2_vl-7b/lora/dpo
+logging_steps: 10
+save_steps: 500
+plot_loss: true
+overwrite_output_dir: true
+
+### train
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 8
+learning_rate: 5.0e-6
+num_train_epochs: 3.0
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+bf16: true
+ddp_timeout: 180000000
+
+### eval
+val_size: 0.1
+per_device_eval_batch_size: 1
+eval_strategy: steps
+eval_steps: 500
diff --git a/src/llamafactory/data/aligner.py b/src/llamafactory/data/aligner.py
index ef70d75b..a3440ff5 100644
--- a/src/llamafactory/data/aligner.py
+++ b/src/llamafactory/data/aligner.py
@@ -14,7 +14,7 @@

 import os
 from functools import partial
-from typing import TYPE_CHECKING, Any, Dict, List, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union

 from datasets import Features

@@ -33,19 +33,17 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


-def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
+def _convert_images(images: Sequence[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
     r"""
     Optionally concatenates image path to dataset dir when loading from local disk.
     """
-    outputs = []
+    images = images[:]
     if dataset_attr.load_from in ["script", "file"]:
-        for image in images:
-            if isinstance(image, str) and os.path.isfile(os.path.join(data_args.dataset_dir, image)):
-                outputs.append(os.path.join(data_args.dataset_dir, image))
-            else:
-                outputs.append(image)
+        for i in range(len(images)):
+            if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, images[i])):
+                images[i] = os.path.join(data_args.dataset_dir, images[i])

-    return outputs
+    return images


 def convert_alpaca(
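
The `_convert_images` rewrite preserves the behavior of the removed `outputs` loop but mutates a shallow copy in place, so non-string entries (for example, already-decoded image objects) pass through unchanged. A minimal standalone sketch of the same logic, with plain arguments standing in for `DatasetAttr` and `DataArguments`:

```python
import os
from typing import Any, List, Sequence

def convert_images_sketch(images: Sequence[Any], load_from: str, dataset_dir: str) -> List[Any]:
    images = list(images)  # shallow copy so the caller's sequence is untouched
    if load_from in ["script", "file"]:
        for i in range(len(images)):
            # Only rewrite string entries that resolve to a real file on disk;
            # anything else (e.g. a PIL.Image) is kept as-is.
            if isinstance(images[i], str) and os.path.isfile(os.path.join(dataset_dir, images[i])):
                images[i] = os.path.join(dataset_dir, images[i])
    return images

# convert_images_sketch(["cat.jpg"], "file", "data") -> ["data/cat.jpg"]
# (assuming data/cat.jpg exists; otherwise the entry is returned unchanged)
```
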
diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index 968d7018..eecf9052 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -142,15 +142,15 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
                 "attention_mask": feature["{}_attention_mask".format(key)],
                 "labels": feature["{}_labels".format(key)],
             }
+            if "{}_token_type_ids".format(key) in feature:
+                target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
+
             if "pixel_values" in feature:  # image data are same for chosen and rejected
                 target_feature["pixel_values"] = feature["pixel_values"]

             if "image_grid_thw" in feature:
                 target_feature["image_grid_thw"] = feature["image_grid_thw"]

-            if "{}_token_type_ids".format(key) in feature:
-                target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
-
             concatenated_features.append(target_feature)

         return super().__call__(concatenated_features)
@@ -177,16 +177,16 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
                 "attention_mask": feature["kl_attention_mask"],
                 "labels": feature["kl_labels"],
             }
+            if "token_type_ids" in feature:
+                target_feature["token_type_ids"] = feature["token_type_ids"]
+                kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
+
             if "pixel_values" in feature:
                 target_feature["pixel_values"] = feature["pixel_values"]

             if "image_grid_thw" in feature:
                 target_feature["image_grid_thw"] = feature["image_grid_thw"]

-            if "token_type_ids" in feature:
-                target_feature["token_type_ids"] = feature["token_type_ids"]
-                kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
-
             target_features.append(target_feature)
             kl_features.append(kl_feature)
             kto_tags.append(feature["kto_tags"])
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index acd81ca0..a3636737 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -19,6 +19,23 @@ if TYPE_CHECKING:
     from transformers.image_processing_utils import BaseImageProcessor


+def _regularize_images(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> List["ImageObject"]:
+    r"""
+    Regularizes images to avoid errors, including resizing and mode conversion.
+    """
+    images = images[:]
+    image_resolution = getattr(processor, "image_resolution", 512)
+    for i in range(len(images)):
+        if max(images[i].width, images[i].height) > image_resolution:
+            factor = image_resolution / max(images[i].width, images[i].height)
+            images[i] = images[i].resize((int(images[i].width * factor), int(images[i].height * factor)))
+
+        if images[i].mode != "RGB":
+            images[i] = images[i].convert("RGB")
+
+    return images
+
+
 def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
     r"""
     Processes visual inputs.
@@ -34,6 +51,7 @@ def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin")
     """
     image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
     if len(images) != 0:
+        images = _regularize_images(images, processor)
         image_inputs = image_processor(images=images, return_tensors="pt")
     else:  # add NoneType for fake images
         image = Image.new("RGB", (64, 64), (255, 255, 255))
diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index 65f0fa62..eddb2b1d 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -138,6 +138,10 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whether or not to randomly initialize the model weights."},
     )
+    image_resolution: int = field(
+        default=512,
+        metadata={"help": "Keeps the height and width of the image below this resolution."},
+    )
     infer_backend: Literal["huggingface", "vllm"] = field(
         default="huggingface",
         metadata={"help": "Backend engine used at inference."},
diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index 8ca8efdf..ed1f5741 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -99,6 +99,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
         setattr(processor, "tokenizer", tokenizer)
         setattr(processor, "image_seqlen", get_image_seqlen(config))
+        setattr(processor, "image_resolution", model_args.image_resolution)
     except Exception:
         processor = None
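
Taken together, `image_resolution` flows from `ModelArguments` onto the processor in `load_tokenizer`, and `_regularize_images` reads it back (defaulting to 512) to cap the longer image side before the image processor runs. A quick, self-contained check of the resize rule using Pillow directly:

```python
from PIL import Image

image_resolution = 512  # mirrors the new ModelArguments default
image = Image.new("RGB", (1024, 768))  # dummy oversized image

if max(image.width, image.height) > image_resolution:
    factor = image_resolution / max(image.width, image.height)  # 0.5 here
    image = image.resize((int(image.width * factor), int(image.height * factor)))

print(image.size)  # (512, 384): the longer side is capped, aspect ratio kept
```
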
"torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32) - all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"]) if self.loss_type in ["ipo", "orpo", "simpo"]: all_logps = all_logps / valid_length diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py index 5cb2518f..16136d6a 100644 --- a/src/llamafactory/train/kto/trainer.py +++ b/src/llamafactory/train/kto/trainer.py @@ -127,17 +127,16 @@ class CustomKTOTrainer(KTOTrainer): "input_ids": batch["{}input_ids".format(prefix)], "attention_mask": batch["{}attention_mask".format(prefix)], } + if "{}token_type_ids".format(prefix) in batch: + model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)] + if "pixel_values" in batch: model_inputs["pixel_values"] = batch["pixel_values"] if "image_grid_thw" in batch: model_inputs["image_grid_thw"] = batch["image_grid_thw"] - if "{}token_type_ids".format(prefix) in batch: - model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)] - logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32) - logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)]) return logps, logps / valid_length