Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)
add rlhf-v dataset

Former-commit-id: 3fd18fc34a0c994a738504746abfd5548e002437
parent: 7621526d22
commit: 60cf12727b
@@ -291,6 +291,7 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
 - [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
 - [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
+- [RLHF-V (en)](https://huggingface.co/datasets/openbmb/RLHF-V-Dataset)
 - [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
 - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
 - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
@@ -292,6 +292,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 - [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
 - [UltraFeedback (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)
+- [RLHF-V (en)](https://huggingface.co/datasets/openbmb/RLHF-V-Dataset)
 - [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
 - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
 - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
@@ -36,6 +36,18 @@ llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
 llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
 ```
 
+#### DPO/ORPO/SimPO Training
+
+```bash
+llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
+```
+
+#### Multimodal DPO/ORPO/SimPO Training
+
+```bash
+llamafactory-cli train examples/train_lora/qwen2vl_lora_dpo.yaml
+```
+
 #### Reward Modeling
 
 ```bash
@@ -48,12 +60,6 @@ llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
 llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
 ```
 
-#### DPO/ORPO/SimPO Training
-
-```bash
-llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
-```
-
 #### KTO Training
 
 ```bash
@@ -36,6 +36,18 @@ llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
 llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
 ```
 
+#### DPO/ORPO/SimPO Training
+
+```bash
+llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
+```
+
+#### Multimodal DPO/ORPO/SimPO Training
+
+```bash
+llamafactory-cli train examples/train_lora/qwen2vl_lora_dpo.yaml
+```
+
 #### Reward Modeling
 
 ```bash
@@ -48,12 +60,6 @@ llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
 llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
 ```
 
-#### DPO/ORPO/SimPO Training
-
-```bash
-llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
-```
-
 #### KTO Training
 
 ```bash
examples/train_lora/qwen2vl_lora_dpo.yaml (new file, 41 lines)
@@ -0,0 +1,41 @@
+### model
+model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
+
+### method
+stage: dpo
+do_train: true
+finetuning_type: lora
+lora_target: all
+pref_beta: 0.1
+pref_loss: sigmoid  # choices: [sigmoid (dpo), orpo, simpo]
+
+### dataset
+dataset: rlhf_v
+template: qwen2_vl
+cutoff_len: 1024
+max_samples: 1000
+overwrite_cache: true
+preprocessing_num_workers: 16
+
+### output
+output_dir: saves/qwen2_vl-7b/lora/dpo
+logging_steps: 10
+save_steps: 500
+plot_loss: true
+overwrite_output_dir: true
+
+### train
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 8
+learning_rate: 5.0e-6
+num_train_epochs: 3.0
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+bf16: true
+ddp_timeout: 180000000
+
+### eval
+val_size: 0.1
+per_device_eval_batch_size: 1
+eval_strategy: steps
+eval_steps: 500
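With `pref_loss: sigmoid`, this config trains with the standard DPO objective, and `pref_beta` is the β that scales the policy-versus-reference margin. A minimal sketch of that loss, assuming plain tensors of summed log-probabilities (function and argument names here are illustrative, not LLaMA-Factory internals):

```python
import torch
import torch.nn.functional as F

def dpo_sigmoid_loss(
    policy_chosen_logps: torch.Tensor,    # log p_theta(chosen), shape (batch,)
    policy_rejected_logps: torch.Tensor,  # log p_theta(rejected), shape (batch,)
    ref_chosen_logps: torch.Tensor,       # same quantities under the frozen reference model
    ref_rejected_logps: torch.Tensor,
    beta: float = 0.1,                    # matches pref_beta above
) -> torch.Tensor:
    # DPO rewards the policy for widening its chosen-over-rejected margin
    # relative to the reference model.
    policy_margin = policy_chosen_logps - policy_rejected_logps
    ref_margin = ref_chosen_logps - ref_rejected_logps
    return -F.logsigmoid(beta * (policy_margin - ref_margin)).mean()
```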
@@ -14,7 +14,7 @@
 import os
 from functools import partial
-from typing import TYPE_CHECKING, Any, Dict, List, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union
 
 from datasets import Features
@@ -33,19 +33,17 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
+def _convert_images(images: Sequence[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
     r"""
     Optionally concatenates image path to dataset dir when loading from local disk.
     """
-    outputs = []
+    images = images[:]
     if dataset_attr.load_from in ["script", "file"]:
-        for image in images:
-            if isinstance(image, str) and os.path.isfile(os.path.join(data_args.dataset_dir, image)):
-                outputs.append(os.path.join(data_args.dataset_dir, image))
-            else:
-                outputs.append(image)
+        for i in range(len(images)):
+            if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, images[i])):
+                images[i] = os.path.join(data_args.dataset_dir, images[i])
 
-    return outputs
+    return images
 
 
 def convert_alpaca(
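The rewrite keeps the behavior (relative paths that resolve to a file inside `dataset_dir` get prefixed; everything else passes through) while accepting any `Sequence` and mutating a shallow copy. A standalone sketch of the same rule, with a hypothetical `dataset_dir`:

```python
import os
from typing import Any, List, Sequence

def prefix_local_images(images: Sequence[Any], dataset_dir: str) -> List[Any]:
    # Shallow-copy so the caller's sequence is untouched, then rewrite in place.
    images = list(images)
    for i in range(len(images)):
        if isinstance(images[i], str) and os.path.isfile(os.path.join(dataset_dir, images[i])):
            images[i] = os.path.join(dataset_dir, images[i])  # relative -> under dataset_dir
    return images

# e.g. prefix_local_images(["cat.jpg", "https://example.com/dog.jpg"], "data")
# returns ["data/cat.jpg", "https://example.com/dog.jpg"] if data/cat.jpg exists.
```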
@@ -142,15 +142,15 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
                     "attention_mask": feature["{}_attention_mask".format(key)],
                     "labels": feature["{}_labels".format(key)],
                 }
-                if "{}_token_type_ids".format(key) in feature:
-                    target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
-
                 if "pixel_values" in feature:  # image data are same for chosen and rejected
                     target_feature["pixel_values"] = feature["pixel_values"]
 
                 if "image_grid_thw" in feature:
                     target_feature["image_grid_thw"] = feature["image_grid_thw"]
 
+                if "{}_token_type_ids".format(key) in feature:
+                    target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
+
                 concatenated_features.append(target_feature)
 
         return super().__call__(concatenated_features)
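This collator flattens each preference pair into two ordinary rows by stripping the `chosen_`/`rejected_` key prefixes, copying the shared image tensors into both rows. A minimal sketch of that flattening under the same conventions (row ordering and feature names are illustrative):

```python
from typing import Any, Dict, List

def flatten_pairs(features: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # Turn n preference pairs into 2n rows: all "chosen" rows first, then
    # all "rejected" rows, so downstream code can split the batch in half.
    flat = []
    for key in ("chosen", "rejected"):
        for feature in features:
            row = {
                "input_ids": feature[f"{key}_input_ids"],
                "attention_mask": feature[f"{key}_attention_mask"],
                "labels": feature[f"{key}_labels"],
            }
            if "pixel_values" in feature:  # same image for chosen and rejected
                row["pixel_values"] = feature["pixel_values"]
            flat.append(row)
    return flat
```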
@@ -177,16 +177,16 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
                 "attention_mask": feature["kl_attention_mask"],
                 "labels": feature["kl_labels"],
             }
-            if "token_type_ids" in feature:
-                target_feature["token_type_ids"] = feature["token_type_ids"]
-                kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
-
             if "pixel_values" in feature:
                 target_feature["pixel_values"] = feature["pixel_values"]
 
             if "image_grid_thw" in feature:
                 target_feature["image_grid_thw"] = feature["image_grid_thw"]
 
+            if "token_type_ids" in feature:
+                target_feature["token_type_ids"] = feature["token_type_ids"]
+                kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
+
             target_features.append(target_feature)
             kl_features.append(kl_feature)
             kto_tags.append(feature["kto_tags"])
@@ -19,6 +19,23 @@ if TYPE_CHECKING:
     from transformers.image_processing_utils import BaseImageProcessor
 
 
+def _regularize_images(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> List["ImageObject"]:
+    r"""
+    Regularizes images to avoid errors, including resizing and mode conversion.
+    """
+    images = images[:]
+    image_resolution = getattr(processor, "image_resolution", 512)
+    for i in range(len(images)):
+        if max(images[i].width, images[i].height) > image_resolution:
+            factor = image_resolution / max(images[i].width, images[i].height)
+            images[i] = images[i].resize((int(images[i].width * factor), int(images[i].height * factor)))
+
+        if images[i].mode != "RGB":
+            images[i] = images[i].convert("RGB")
+
+    return images
+
+
 def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
     r"""
     Processes visual inputs.
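The resize preserves aspect ratio by scaling the longer side down to `image_resolution` (512 by default, configurable via the new model argument below). A quick Pillow check of that arithmetic, independent of LLaMA-Factory:

```python
from PIL import Image

# A hypothetical 2048x1024 RGBA image: too large, and not RGB.
image = Image.new("RGBA", (2048, 1024))

image_resolution = 512
if max(image.width, image.height) > image_resolution:
    factor = image_resolution / max(image.width, image.height)  # 512 / 2048 = 0.25
    image = image.resize((int(image.width * factor), int(image.height * factor)))

if image.mode != "RGB":
    image = image.convert("RGB")

print(image.size, image.mode)  # (512, 256) RGB; aspect ratio preserved
```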
@@ -34,6 +51,7 @@ def _get_mm_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin")
     """
     image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
     if len(images) != 0:
+        images = _regularize_images(images, processor)
        image_inputs = image_processor(images=images, return_tensors="pt")
     else:  # add NoneType for fake images
         image = Image.new("RGB", (64, 64), (255, 255, 255))
@@ -138,6 +138,10 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whether or not to randomly initialize the model weights."},
     )
+    image_resolution: int = field(
+        default=512,
+        metadata={"help": "Keeps the height or width of image below this resolution."},
+    )
     infer_backend: Literal["huggingface", "vllm"] = field(
         default="huggingface",
         metadata={"help": "Backend engine used at inference."},
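A dataclass field like this surfaces automatically as a CLI flag or YAML key. A toy parse with transformers' `HfArgumentParser`, using a trimmed-down stand-in for `ModelArguments`:

```python
from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class Args:  # trimmed-down stand-in for ModelArguments
    image_resolution: int = field(
        default=512,
        metadata={"help": "Keeps the height or width of image below this resolution."},
    )

# "--image_resolution 768" on the command line (or image_resolution: 768 in YAML)
# overrides the default of 512.
(args,) = HfArgumentParser(Args).parse_args_into_dataclasses(["--image_resolution", "768"])
print(args.image_resolution)  # 768
```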
@@ -99,6 +99,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
         setattr(processor, "tokenizer", tokenizer)
         setattr(processor, "image_seqlen", get_image_seqlen(config))
+        setattr(processor, "image_resolution", model_args.image_resolution)
     except Exception:
         processor = None
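Attaching `image_resolution` to the processor is what lets `_regularize_images` read it back with a safe default: processors that were never patched simply fall back to 512. A toy illustration of the pattern (the class and values here are illustrative):

```python
class DummyProcessor:
    """Stand-in for a HuggingFace processor; any object with settable attributes works."""

processor = DummyProcessor()
setattr(processor, "image_resolution", 768)  # plumbed in from ModelArguments

# Consumers read it defensively, defaulting to 512 when the attribute is absent.
print(getattr(processor, "image_resolution", 512))         # 768
print(getattr(DummyProcessor(), "image_resolution", 512))  # 512
```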
@@ -176,7 +176,6 @@ class CustomDPOTrainer(DPOTrainer):
             batch = {k: v.detach().clone() for k, v in batch.items()}  # avoid error
 
         all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
-
         all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
         if self.loss_type in ["ipo", "orpo", "simpo"]:
             all_logps = all_logps / valid_length
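For `ipo`/`orpo`/`simpo`, the summed log-probabilities are divided by the number of label tokens, giving per-token averages (the length-normalized form these losses expect). A minimal sketch of that normalization, assuming a labels tensor that marks ignored positions with -100 as HuggingFace does:

```python
import torch

IGNORE_INDEX = -100

def normalize_logps(sum_logps: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # valid_length counts the non-ignored label tokens in each sequence.
    valid_length = (labels != IGNORE_INDEX).sum(dim=-1)
    return sum_logps / valid_length  # per-token average log-probability

labels = torch.tensor([[5, 9, 2, IGNORE_INDEX], [7, IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX]])
print(normalize_logps(torch.tensor([-6.0, -1.0]), labels))  # tensor([-2., -1.])
```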
@@ -127,17 +127,16 @@ class CustomKTOTrainer(KTOTrainer):
             "input_ids": batch["{}input_ids".format(prefix)],
             "attention_mask": batch["{}attention_mask".format(prefix)],
         }
-        if "{}token_type_ids".format(prefix) in batch:
-            model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
-
         if "pixel_values" in batch:
             model_inputs["pixel_values"] = batch["pixel_values"]
 
         if "image_grid_thw" in batch:
             model_inputs["image_grid_thw"] = batch["image_grid_thw"]
 
+        if "{}token_type_ids".format(prefix) in batch:
+            model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
+
         logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
 
         logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)])
         return logps, logps / valid_length