add rlhf-v dataset

Former-commit-id: 8e49940746c1a6ff910f07dbefbec14af9d0f3c6
This commit is contained in:
hiyouga 2024-09-01 22:57:41 +08:00
parent 236f97b35c
commit bfdcc6bacf
13 changed files with 118 additions and 33 deletions

View File

@ -291,6 +291,7 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k) - [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized) - [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
- [RLHF-V (en)](https://huggingface.co/datasets/openbmb/RLHF-V-Dataset)
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs) - [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar) - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)

View File

@ -292,6 +292,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k) - [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
- [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized) - [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
- [RLHF-V (en)](https://huggingface.co/datasets/openbmb/RLHF-V-Dataset)
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs) - [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar) - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)

View File

@ -433,6 +433,17 @@
"rejected": "rejected" "rejected": "rejected"
} }
}, },
"rlhf_v": {
"hf_hub_url": "llamafactory/RLHF-V",
"ranking": true,
"formatting": "sharegpt",
"columns": {
"messages": "conversations",
"chosen": "chosen",
"rejected": "rejected",
"images": "images"
}
},
"orca_pairs": { "orca_pairs": {
"hf_hub_url": "Intel/orca_dpo_pairs", "hf_hub_url": "Intel/orca_dpo_pairs",
"ranking": true, "ranking": true,

View File

@ -36,6 +36,18 @@ llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
``` ```
#### DPO/ORPO/SimPO Training
```bash
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
```
#### Multimodal DPO/ORPO/SimPO Training
```bash
llamafactory-cli train examples/train_lora/qwen2vl_lora_dpo.yaml
```
#### Reward Modeling #### Reward Modeling
```bash ```bash
@ -48,12 +60,6 @@ llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
``` ```
#### DPO/ORPO/SimPO Training
```bash
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
```
#### KTO Training #### KTO Training
```bash ```bash

View File

@ -36,6 +36,18 @@ llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
``` ```
#### DPO/ORPO/SimPO 训练
```bash
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
```
#### 多模态 DPO/ORPO/SimPO 训练
```bash
llamafactory-cli train examples/train_lora/qwen2vl_lora_dpo.yaml
```
#### 奖励模型训练 #### 奖励模型训练
```bash ```bash
@ -48,12 +60,6 @@ llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
``` ```
#### DPO/ORPO/SimPO 训练
```bash
llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
```
#### KTO 训练 #### KTO 训练
```bash ```bash

View File

@ -0,0 +1,41 @@
### model
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
### method
stage: dpo
do_train: true
finetuning_type: lora
lora_target: all
pref_beta: 0.1
pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
### dataset
dataset: rlhf_v
template: qwen2_vl
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: saves/qwen2_vl-7b/lora/dpo
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 5.0e-6
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500

View File

@ -14,7 +14,7 @@
import os import os
from functools import partial from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Union from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union
from datasets import Features from datasets import Features
@ -33,19 +33,17 @@ if TYPE_CHECKING:
logger = get_logger(__name__) logger = get_logger(__name__)
def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]: def _convert_images(images: Sequence[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
r""" r"""
Optionally concatenates image path to dataset dir when loading from local disk. Optionally concatenates image path to dataset dir when loading from local disk.
""" """
outputs = [] images = images[:]
if dataset_attr.load_from in ["script", "file"]: if dataset_attr.load_from in ["script", "file"]:
for image in images: for i in range(len(images)):
if isinstance(image, str) and os.path.isfile(os.path.join(data_args.dataset_dir, image)): if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, images[i])):
outputs.append(os.path.join(data_args.dataset_dir, image)) images[i] = os.path.join(data_args.dataset_dir, images[i])
else:
outputs.append(image)
return outputs return images
def convert_alpaca( def convert_alpaca(

View File

@ -142,15 +142,15 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
"attention_mask": feature["{}_attention_mask".format(key)], "attention_mask": feature["{}_attention_mask".format(key)],
"labels": feature["{}_labels".format(key)], "labels": feature["{}_labels".format(key)],
} }
if "{}_token_type_ids".format(key) in feature:
target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
if "pixel_values" in feature: # image data are same for chosen and rejected if "pixel_values" in feature: # image data are same for chosen and rejected
target_feature["pixel_values"] = feature["pixel_values"] target_feature["pixel_values"] = feature["pixel_values"]
if "image_grid_thw" in feature: if "image_grid_thw" in feature:
target_feature["image_grid_thw"] = feature["image_grid_thw"] target_feature["image_grid_thw"] = feature["image_grid_thw"]
if "{}_token_type_ids".format(key) in feature:
target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
concatenated_features.append(target_feature) concatenated_features.append(target_feature)
return super().__call__(concatenated_features) return super().__call__(concatenated_features)
@ -177,16 +177,16 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
"attention_mask": feature["kl_attention_mask"], "attention_mask": feature["kl_attention_mask"],
"labels": feature["kl_labels"], "labels": feature["kl_labels"],
} }
if "token_type_ids" in feature:
target_feature["token_type_ids"] = feature["token_type_ids"]
kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
if "pixel_values" in feature: if "pixel_values" in feature:
target_feature["pixel_values"] = feature["pixel_values"] target_feature["pixel_values"] = feature["pixel_values"]
if "image_grid_thw" in feature: if "image_grid_thw" in feature:
target_feature["image_grid_thw"] = feature["image_grid_thw"] target_feature["image_grid_thw"] = feature["image_grid_thw"]
if "token_type_ids" in feature:
target_feature["token_type_ids"] = feature["token_type_ids"]
kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
target_features.append(target_feature) target_features.append(target_feature)
kl_features.append(kl_feature) kl_features.append(kl_feature)
kto_tags.append(feature["kto_tags"]) kto_tags.append(feature["kto_tags"])

View File

@ -19,6 +19,23 @@ if TYPE_CHECKING:
from transformers.image_processing_utils import BaseImageProcessor from transformers.image_processing_utils import BaseImageProcessor
def _regularize_images(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> List["ImageObject"]:
r"""
Regularizes images to avoid error. Including resizing and mode convert.
"""
images = images[:]
image_resolution = getattr(processor, "image_resolution", 512)
for i in range(len(images)):
if max(images[i].width, images[i].height) > image_resolution:
factor = image_resolution / max(images[i].width, images[i].height)
images[i] = images[i].resize((int(images[i].width * factor), int(images[i].height * factor)))
if images[i].mode != "RGB":
images[i] = images[i].convert("RGB")
return images
def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]: def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
r""" r"""
Processes visual inputs. Processes visual inputs.
@ -34,6 +51,7 @@ def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin")
""" """
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
if len(images) != 0: if len(images) != 0:
images = _regularize_images(images, processor)
image_inputs = image_processor(images=images, return_tensors="pt") image_inputs = image_processor(images=images, return_tensors="pt")
else: # add NoneType for fake images else: # add NoneType for fake images
image = Image.new("RGB", (64, 64), (255, 255, 255)) image = Image.new("RGB", (64, 64), (255, 255, 255))

View File

@ -138,6 +138,10 @@ class ModelArguments:
default=False, default=False,
metadata={"help": "Whether or not to randomly initialize the model weights."}, metadata={"help": "Whether or not to randomly initialize the model weights."},
) )
image_resolution: int = field(
default=512,
metadata={"help": "Keeps the height or width of image below this resolution."},
)
infer_backend: Literal["huggingface", "vllm"] = field( infer_backend: Literal["huggingface", "vllm"] = field(
default="huggingface", default="huggingface",
metadata={"help": "Backend engine used at inference."}, metadata={"help": "Backend engine used at inference."},

View File

@ -99,6 +99,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs) processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
setattr(processor, "tokenizer", tokenizer) setattr(processor, "tokenizer", tokenizer)
setattr(processor, "image_seqlen", get_image_seqlen(config)) setattr(processor, "image_seqlen", get_image_seqlen(config))
setattr(processor, "image_resolution", model_args.image_resolution)
except Exception: except Exception:
processor = None processor = None

View File

@ -176,7 +176,6 @@ class CustomDPOTrainer(DPOTrainer):
batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error
all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32) all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"]) all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
if self.loss_type in ["ipo", "orpo", "simpo"]: if self.loss_type in ["ipo", "orpo", "simpo"]:
all_logps = all_logps / valid_length all_logps = all_logps / valid_length

View File

@ -127,17 +127,16 @@ class CustomKTOTrainer(KTOTrainer):
"input_ids": batch["{}input_ids".format(prefix)], "input_ids": batch["{}input_ids".format(prefix)],
"attention_mask": batch["{}attention_mask".format(prefix)], "attention_mask": batch["{}attention_mask".format(prefix)],
} }
if "{}token_type_ids".format(prefix) in batch:
model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
if "pixel_values" in batch: if "pixel_values" in batch:
model_inputs["pixel_values"] = batch["pixel_values"] model_inputs["pixel_values"] = batch["pixel_values"]
if "image_grid_thw" in batch: if "image_grid_thw" in batch:
model_inputs["image_grid_thw"] = batch["image_grid_thw"] model_inputs["image_grid_thw"] = batch["image_grid_thw"]
if "{}token_type_ids".format(prefix) in batch:
model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32) logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)]) logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)])
return logps, logps / valid_length return logps, logps / valid_length