mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 03:32:50 +08:00)
add rlhf-v dataset
Former-commit-id: 8e49940746c1a6ff910f07dbefbec14af9d0f3c6
This commit is contained in:
parent 236f97b35c
commit bfdcc6bacf
@@ -291,6 +291,7 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t

 - [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
 - [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
+- [RLHF-V (en)](https://huggingface.co/datasets/openbmb/RLHF-V-Dataset)
 - [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
 - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
 - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
@@ -292,6 +292,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272

 - [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
 - [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
+- [RLHF-V (en)](https://huggingface.co/datasets/openbmb/RLHF-V-Dataset)
 - [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
 - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
 - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
@@ -433,6 +433,17 @@
       "rejected": "rejected"
     }
   },
+  "rlhf_v": {
+    "hf_hub_url": "llamafactory/RLHF-V",
+    "ranking": true,
+    "formatting": "sharegpt",
+    "columns": {
+      "messages": "conversations",
+      "chosen": "chosen",
+      "rejected": "rejected",
+      "images": "images"
+    }
+  },
   "orca_pairs": {
     "hf_hub_url": "Intel/orca_dpo_pairs",
     "ranking": true,
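For orientation, a single `rlhf_v` record is expected to follow the sharegpt preference layout implied by the column mapping above. The sketch below is illustrative only: the keys come from the mapping, while the message text and image path are invented.

```python
# Hypothetical rlhf_v sample shaped by the "columns" mapping above; all values are invented.
sample = {
    "conversations": [  # mapped from "messages"
        {"from": "human", "value": "<image>Describe what is happening in this photo."}
    ],
    "chosen": {"from": "gpt", "value": "A response grounded in the image content."},
    "rejected": {"from": "gpt", "value": "A response that hallucinates details."},
    "images": ["path/to/example.jpg"],  # local path or URL resolved by the data loader
}
```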
@@ -36,6 +36,18 @@ llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
 llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
 ```

+#### DPO/ORPO/SimPO Training
+
+```bash
+llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
+```
+
+#### Multimodal DPO/ORPO/SimPO Training
+
+```bash
+llamafactory-cli train examples/train_lora/qwen2vl_lora_dpo.yaml
+```
+
 #### Reward Modeling

 ```bash
@@ -48,12 +60,6 @@ llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
 llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
 ```

-#### DPO/ORPO/SimPO Training
-
-```bash
-llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
-```
-
 #### KTO Training

 ```bash
@@ -36,6 +36,18 @@ llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
 llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
 ```

+#### DPO/ORPO/SimPO Training
+
+```bash
+llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
+```
+
+#### Multimodal DPO/ORPO/SimPO Training
+
+```bash
+llamafactory-cli train examples/train_lora/qwen2vl_lora_dpo.yaml
+```
+
 #### Reward Modeling

 ```bash
@@ -48,12 +60,6 @@ llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
 llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
 ```

-#### DPO/ORPO/SimPO Training
-
-```bash
-llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
-```
-
 #### KTO Training

 ```bash
examples/train_lora/qwen2vl_lora_dpo.yaml (new file, 41 lines)
@@ -0,0 +1,41 @@
+### model
+model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
+
+### method
+stage: dpo
+do_train: true
+finetuning_type: lora
+lora_target: all
+pref_beta: 0.1
+pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
+
+### dataset
+dataset: rlhf_v
+template: qwen2_vl
+cutoff_len: 1024
+max_samples: 1000
+overwrite_cache: true
+preprocessing_num_workers: 16
+
+### output
+output_dir: saves/qwen2_vl-7b/lora/dpo
+logging_steps: 10
+save_steps: 500
+plot_loss: true
+overwrite_output_dir: true
+
+### train
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 8
+learning_rate: 5.0e-6
+num_train_epochs: 3.0
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+bf16: true
+ddp_timeout: 180000000
+
+### eval
+val_size: 0.1
+per_device_eval_batch_size: 1
+eval_strategy: steps
+eval_steps: 500
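The `pref_beta` and `pref_loss` options above select the preference objective. As a rough sketch (not LLaMA-Factory code, names are illustrative), the default `sigmoid` choice corresponds to the standard DPO loss, with `beta` playing the role of `pref_beta`:

```python
import torch
import torch.nn.functional as F

def dpo_sigmoid_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    ref_chosen_logps: torch.Tensor,
    ref_rejected_logps: torch.Tensor,
    beta: float = 0.1,  # corresponds to pref_beta in the config above
) -> torch.Tensor:
    """-log sigmoid(beta * (policy log-prob margin - reference log-prob margin))."""
    policy_margin = policy_chosen_logps - policy_rejected_logps
    ref_margin = ref_chosen_logps - ref_rejected_logps
    return -F.logsigmoid(beta * (policy_margin - ref_margin)).mean()
```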
@@ -14,7 +14,7 @@

 import os
 from functools import partial
-from typing import TYPE_CHECKING, Any, Dict, List, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union

 from datasets import Features

@@ -33,19 +33,17 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


-def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
+def _convert_images(images: Sequence[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
     r"""
     Optionally concatenates image path to dataset dir when loading from local disk.
     """
-    outputs = []
+    images = images[:]
     if dataset_attr.load_from in ["script", "file"]:
-        for image in images:
-            if isinstance(image, str) and os.path.isfile(os.path.join(data_args.dataset_dir, image)):
-                outputs.append(os.path.join(data_args.dataset_dir, image))
-            else:
-                outputs.append(image)
+        for i in range(len(images)):
+            if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, images[i])):
+                images[i] = os.path.join(data_args.dataset_dir, images[i])

-    return outputs
+    return images


 def convert_alpaca(
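To make the behaviour of the rewritten `_convert_images` concrete, here is a standalone sketch of the same idea (hypothetical helper name and paths; only the path-joining logic mirrors the diff):

```python
import os

def join_local_image_paths(images, dataset_dir):
    """Copy the sequence, then prefix dataset_dir to entries that resolve to local files."""
    images = list(images)
    for i in range(len(images)):
        if isinstance(images[i], str) and os.path.isfile(os.path.join(dataset_dir, images[i])):
            images[i] = os.path.join(dataset_dir, images[i])
    return images

# With dataset_dir="data" and an existing file data/images/0.jpg:
#   join_local_image_paths(["images/0.jpg", "https://host/1.jpg"], "data")
#   -> ["data/images/0.jpg", "https://host/1.jpg"]  (non-file entries pass through unchanged)
```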
@@ -142,15 +142,15 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq)
                     "attention_mask": feature["{}_attention_mask".format(key)],
                     "labels": feature["{}_labels".format(key)],
                 }
+                if "{}_token_type_ids".format(key) in feature:
+                    target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
+
                 if "pixel_values" in feature:  # image data are same for chosen and rejected
                     target_feature["pixel_values"] = feature["pixel_values"]

                 if "image_grid_thw" in feature:
                     target_feature["image_grid_thw"] = feature["image_grid_thw"]

-                if "{}_token_type_ids".format(key) in feature:
-                    target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
-
                 concatenated_features.append(target_feature)

         return super().__call__(concatenated_features)
@@ -177,16 +177,16 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq)
                 "attention_mask": feature["kl_attention_mask"],
                 "labels": feature["kl_labels"],
             }
+            if "token_type_ids" in feature:
+                target_feature["token_type_ids"] = feature["token_type_ids"]
+                kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
+
             if "pixel_values" in feature:
                 target_feature["pixel_values"] = feature["pixel_values"]

             if "image_grid_thw" in feature:
                 target_feature["image_grid_thw"] = feature["image_grid_thw"]

-            if "token_type_ids" in feature:
-                target_feature["token_type_ids"] = feature["token_type_ids"]
-                kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
-
             target_features.append(target_feature)
             kl_features.append(kl_feature)
             kto_tags.append(feature["kto_tags"])
@@ -19,6 +19,23 @@ if TYPE_CHECKING:
     from transformers.image_processing_utils import BaseImageProcessor


+def _regularize_images(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> List["ImageObject"]:
+    r"""
+    Regularizes images to avoid error. Including resizing and mode convert.
+    """
+    images = images[:]
+    image_resolution = getattr(processor, "image_resolution", 512)
+    for i in range(len(images)):
+        if max(images[i].width, images[i].height) > image_resolution:
+            factor = image_resolution / max(images[i].width, images[i].height)
+            images[i] = images[i].resize((int(images[i].width * factor), int(images[i].height * factor)))
+
+        if images[i].mode != "RGB":
+            images[i] = images[i].convert("RGB")
+
+    return images
+
+
 def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
     r"""
     Processes visual inputs.
@@ -34,6 +51,7 @@ def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin")
     """
     image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
     if len(images) != 0:
+        images = _regularize_images(images, processor)
         image_inputs = image_processor(images=images, return_tensors="pt")
     else:  # add NoneType for fake images
         image = Image.new("RGB", (64, 64), (255, 255, 255))
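As a quick sanity check of the `_regularize_images` logic above (illustrative only, not part of the commit): the longer image side is capped at `image_resolution` while the aspect ratio is preserved, and non-RGB images are converted to RGB.

```python
from PIL import Image

image_resolution = 512                 # default used when the processor has no image_resolution attribute
img = Image.new("RGBA", (2048, 1024))  # oversized, non-RGB test image

if max(img.width, img.height) > image_resolution:
    factor = image_resolution / max(img.width, img.height)  # 0.25 here
    img = img.resize((int(img.width * factor), int(img.height * factor)))
if img.mode != "RGB":
    img = img.convert("RGB")

print(img.size, img.mode)              # (512, 256) RGB
```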
@@ -138,6 +138,10 @@
         default=False,
         metadata={"help": "Whether or not to randomly initialize the model weights."},
     )
+    image_resolution: int = field(
+        default=512,
+        metadata={"help": "Keeps the height or width of image below this resolution."},
+    )
     infer_backend: Literal["huggingface", "vllm"] = field(
         default="huggingface",
         metadata={"help": "Backend engine used at inference."},
@@ -99,6 +99,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
         setattr(processor, "tokenizer", tokenizer)
         setattr(processor, "image_seqlen", get_image_seqlen(config))
+        setattr(processor, "image_resolution", model_args.image_resolution)
     except Exception:
         processor = None

@@ -176,7 +176,6 @@ class CustomDPOTrainer(DPOTrainer):
         batch = {k: v.detach().clone() for k, v in batch.items()}  # avoid error
-
         all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)

         all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
         if self.loss_type in ["ipo", "orpo", "simpo"]:
             all_logps = all_logps / valid_length
@@ -127,17 +127,16 @@ class CustomKTOTrainer(KTOTrainer):
             "input_ids": batch["{}input_ids".format(prefix)],
             "attention_mask": batch["{}attention_mask".format(prefix)],
         }
+        if "{}token_type_ids".format(prefix) in batch:
+            model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
+
         if "pixel_values" in batch:
             model_inputs["pixel_values"] = batch["pixel_values"]

         if "image_grid_thw" in batch:
             model_inputs["image_grid_thw"] = batch["image_grid_thw"]

-        if "{}token_type_ids".format(prefix) in batch:
-            model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
-
-
         logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)

         logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)])
         return logps, logps / valid_length