From 8a09b1e73235261c3be799ba92c3c4ce8d2c6f5f Mon Sep 17 00:00:00 2001 From: simonJJJ <821898965@qq.com> Date: Wed, 28 Aug 2024 16:51:35 +0800 Subject: [PATCH] initial-commit Former-commit-id: aeb85f200bd824748008dae6047c2607dfcdf174 --- data/dataset_info.json | 14 ++ data/qwen2vl_demo.json | 140 ++++++++++++++++++ examples/train_full/qwen2vl_full_sft.yaml | 40 +++++ examples/train_lora/qwen2vl_lora_sft.yaml | 40 +++++ src/llamafactory/data/collator.py | 28 ++++ .../data/processors/processor_utils.py | 14 ++ .../data/processors/supervised.py | 41 ++++- src/llamafactory/data/template.py | 21 +++ src/llamafactory/model/adapter.py | 2 +- src/llamafactory/model/model_utils/misc.py | 2 + 10 files changed, 337 insertions(+), 5 deletions(-) create mode 100644 data/qwen2vl_demo.json create mode 100644 examples/train_full/qwen2vl_full_sft.yaml create mode 100644 examples/train_lora/qwen2vl_lora_sft.yaml diff --git a/data/dataset_info.json b/data/dataset_info.json index b00456d2..b2f99fe1 100644 --- a/data/dataset_info.json +++ b/data/dataset_info.json @@ -38,6 +38,20 @@ "assistant_tag": "assistant" } }, + "qwen2vl_demo": { + "file_name": "qwen2vl_demo.json", + "formatting": "sharegpt", + "columns": { + "messages": "messages", + "images": "images" + }, + "tags": { + "role_tag": "role", + "content_tag": "content", + "user_tag": "user", + "assistant_tag": "assistant" + } + }, "alpaca_en": { "hf_hub_url": "llamafactory/alpaca_en", "ms_hub_url": "llamafactory/alpaca_en" diff --git a/data/qwen2vl_demo.json b/data/qwen2vl_demo.json new file mode 100644 index 00000000..f8ab29ff --- /dev/null +++ b/data/qwen2vl_demo.json @@ -0,0 +1,140 @@ +[ + { + "messages": [ + { + "content": "<|image_pad|>Who are they?", + "role": "user" + }, + { + "content": "They're Kane and Gretzka from Bayern Munich.", + "role": "assistant" + }, + { + "content": "What are they doing?", + "role": "user" + }, + { + "content": "They are celebrating on the soccer field.", + "role": "assistant" + } + ], + "images": [ + "mllm_demo_data/1.jpg" + ] + }, + { + "messages": [ + { + "content": "<|image_pad|>Who is he?", + "role": "user" + }, + { + "content": "He's Thomas Muller from Bayern Munich.", + "role": "assistant" + }, + { + "content": "<|image_pad|>Why is he on the ground?", + "role": "user" + }, + { + "content": "Because he's sliding on his knees to celebrate.", + "role": "assistant" + } + ], + "images": [ + "mllm_demo_data/2.jpg","mllm_demo_data/2.jpg" + ] + }, + { + "messages": [ + { + "content": "<|image_pad|>Please describe this image", + "role": "user" + }, + { + "content": "Chinese astronaut Gui Haichao is giving a speech.", + "role": "assistant" + }, + { + "content": "What has he accomplished?", + "role": "user" + }, + { + "content": "He was appointed to be a payload specialist on Shenzhou 16 mission in June 2022, thus becoming the first Chinese civilian of Group 3 in space on 30 May 2023. 
He is responsible for the on-orbit operation of space science experimental payloads.", + "role": "assistant" + } + ], + "images": [ + "mllm_demo_data/3.jpg" + ] + }, + { + "messages": [ + { + "content": "<|image_pad|>他们是谁?", + "role": "user" + }, + { + "content": "他们是拜仁慕尼黑的凯恩和格雷茨卡。", + "role": "assistant" + }, + { + "content": "<|image_pad|>他们在做什么?", + "role": "user" + }, + { + "content": "他们在足球场上庆祝。", + "role": "assistant" + } + ], + "images": [ + "mllm_demo_data/1.jpg","mllm_demo_data/1.jpg" + ] + }, + { + "messages": [ + { + "content": "<|image_pad|>他是谁?", + "role": "user" + }, + { + "content": "他是来自拜仁慕尼黑的托马斯·穆勒。", + "role": "assistant" + }, + { + "content": "他为什么在地上?", + "role": "user" + }, + { + "content": "因为他正在双膝跪地滑行庆祝。", + "role": "assistant" + } + ], + "images": [ + "mllm_demo_data/2.jpg" + ] + }, + { + "messages": [ + { + "content": "<|image_pad|>请描述这张图片", + "role": "user" + }, + { + "content": "中国宇航员桂海潮正在讲话。", + "role": "assistant" + }, + { + "content": "他取得过哪些成就?", + "role": "user" + }, + { + "content": "他于2022年6月被任命为神舟十六号任务的有效载荷专家,从而成为2023年5月30日进入太空的首位平民宇航员。他负责在轨操作空间科学实验有效载荷。", + "role": "assistant" + } + ], + "images": [ + "mllm_demo_data/3.jpg" + ] + } +] diff --git a/examples/train_full/qwen2vl_full_sft.yaml b/examples/train_full/qwen2vl_full_sft.yaml new file mode 100644 index 00000000..ea8a3422 --- /dev/null +++ b/examples/train_full/qwen2vl_full_sft.yaml @@ -0,0 +1,40 @@ +### model +model_name_or_path: qwen2-vl-hf/qwen2-vl-7b-hf +visual_inputs: true + +### method +stage: sft +do_train: true +finetuning_type: full +deepspeed: examples/deepspeed/ds_z3_config.json + +### dataset +dataset: qwen2vl_demo +template: qwen2vl +cutoff_len: 1024 +max_samples: 1000 +overwrite_cache: true +preprocessing_num_workers: 16 + +### output +output_dir: saves/qwen2-vl-7b/full/sft +logging_steps: 10 +save_steps: 500 +plot_loss: true +overwrite_output_dir: true + +### train +per_device_train_batch_size: 1 +gradient_accumulation_steps: 1 +learning_rate: 1.0e-5 +num_train_epochs: 3.0 +lr_scheduler_type: cosine +warmup_ratio: 0.1 +bf16: true +ddp_timeout: 180000000 + +### eval +val_size: 0.1 +per_device_eval_batch_size: 1 +eval_strategy: steps +eval_steps: 500 \ No newline at end of file diff --git a/examples/train_lora/qwen2vl_lora_sft.yaml b/examples/train_lora/qwen2vl_lora_sft.yaml new file mode 100644 index 00000000..e893ea01 --- /dev/null +++ b/examples/train_lora/qwen2vl_lora_sft.yaml @@ -0,0 +1,40 @@ +### model +model_name_or_path: qwen2-vl-hf/qwen2-vl-7b-hf +visual_inputs: true + +### method +stage: sft +do_train: true +finetuning_type: lora +lora_target: all + +### dataset +dataset: qwen2vl_demo +template: qwen2vl +cutoff_len: 1024 +max_samples: 1000 +overwrite_cache: true +preprocessing_num_workers: 16 + +### output +output_dir: saves/qwen2-vl-7b/lora/sft +logging_steps: 10 +save_steps: 500 +plot_loss: true +overwrite_output_dir: true + +### train +per_device_train_batch_size: 2 +gradient_accumulation_steps: 1 +learning_rate: 1.0e-4 +num_train_epochs: 3.0 +lr_scheduler_type: cosine +warmup_ratio: 0.1 +bf16: true +ddp_timeout: 180000000 + +### eval +val_size: 0.1 +per_device_eval_batch_size: 1 +eval_strategy: steps +eval_steps: 500 diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index a603a7e8..a4197a20 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -72,9 +72,37 @@ class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq): compute_dtype: "torch.dtype" = torch.float32 def __call__(self, features: 
Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]: + image_grid_thw = None + if "image_grid_thw" in features[0]: + image_grid_thw_list = [ + torch.Tensor(feature["image_grid_thw"]).long() + for feature in features + if feature["image_grid_thw"][0][0] > 0 + ] + pixel_values_list = [ + torch.Tensor(feature["pixel_values"]) for feature in features if feature["image_grid_thw"][0][0] > 0 + ] + if image_grid_thw_list: + image_grid_thw = torch.cat(image_grid_thw_list, 0) + else: + # Handle the case where the list is empty, for example: + image_grid_thw = None + if pixel_values_list: + pixel_values = torch.cat(pixel_values_list, 0) + else: + # Handle the case where the list is empty, for example: + pixel_values = None + features = [ + {key: feature[key] for key in feature if key not in ["image_grid_thw", "pixel_values"]} + for feature in features + ] + features = super().__call__(features) if self.block_diag_attn and self.attn_implementation != "flash_attention_2": features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype) + if image_grid_thw is not None: + features["image_grid_thw"] = image_grid_thw + features["pixel_values"] = pixel_values return features diff --git a/src/llamafactory/data/processors/processor_utils.py b/src/llamafactory/data/processors/processor_utils.py index 435cf6ca..845ef8f8 100644 --- a/src/llamafactory/data/processors/processor_utils.py +++ b/src/llamafactory/data/processors/processor_utils.py @@ -78,6 +78,20 @@ def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> return [0] * image_seq_length + [1] * (input_len - image_seq_length) +def get_qwen2vl_image_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray": + r""" + Processes visual inputs. support multi images + """ + image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") + if len(images) != 0: + image_inputs = image_processor(images=images, return_tensors="pt") + else: + image = Image.new("RGB", (56, 56), (255, 255, 255)) + image_inputs = image_processor(images=[image], return_tensors="pt") + image_inputs["image_grid_thw"][0][0] = 0 + return {"pixel_values": image_inputs["pixel_values"], "image_grid_thw": image_inputs["image_grid_thw"]} + + def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]: r""" Computes the real sequence length after truncation by the cutoff_len. 
diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py index 950de12a..90da57a1 100644 --- a/src/llamafactory/data/processors/supervised.py +++ b/src/llamafactory/data/processors/supervised.py @@ -17,10 +17,17 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from ...extras.constants import IGNORE_INDEX from ...extras.logging import get_logger -from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack, infer_seqlen +from .processor_utils import ( + get_paligemma_token_type_ids, + get_pixel_values, + get_qwen2vl_image_inputs, + greedy_knapsack, + infer_seqlen, +) if TYPE_CHECKING: + from PIL.Image import Image as ImageObject from transformers import PreTrainedTokenizer, ProcessorMixin from ...hparams import DataArguments @@ -36,12 +43,31 @@ def _encode_supervised_example( system: Optional[str], tools: Optional[str], template: "Template", + images: Sequence["ImageObject"], tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"], cutoff_len: int, train_on_prompt: bool, mask_history: bool, ) -> Tuple[List[int], List[int]]: + if "image_grid_thw" in processor.model_input_names: # qwen2_vl models + image_processor = getattr(processor, "image_processor") + merge_length = image_processor.merge_size**2 + if len(images) > 0: + image_grid_thw = get_qwen2vl_image_inputs(images, processor)["image_grid_thw"] + index = 0 + for message in prompt: + content = message["content"] + while "<|image_pad|>" in content: + content = content.replace( + "<|image_pad|>", + template.vision_start_token + + "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length) + + template.vision_end_token, + 1, + ) + index += 1 + message["content"] = content.replace("<|placeholder|>", "<|image_pad|>") if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models prompt[0]["content"] = template.image_token + prompt[0]["content"] @@ -107,6 +133,8 @@ def preprocess_supervised_dataset( model_inputs["pixel_values"] = [] if hasattr(processor, "image_seq_length"): # paligemma models model_inputs["token_type_ids"] = [] + if "image_grid_thw" in processor.model_input_names: # qwen2_vl models + model_inputs["image_grid_thw"] = [] for i in range(len(examples["prompt"])): if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1: @@ -129,9 +157,14 @@ def preprocess_supervised_dataset( model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["labels"].append(labels) if processor is not None: - model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor)) - if hasattr(processor, "image_seq_length"): # paligemma models - model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor)) + if "image_grid_thw" in processor.model_input_names: # qwen2_vl models + image_inputs = get_qwen2vl_image_inputs(examples["images"][i], processor) + model_inputs["pixel_values"].append(image_inputs["pixel_values"]) + model_inputs["image_grid_thw"].append(image_inputs["image_grid_thw"]) + else: + model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor)) + if hasattr(processor, "image_seq_length"): # paligemma models + model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor)) return model_inputs diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index f9eeb66a..cce97357 100644 --- 
a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -42,6 +42,8 @@ class Template: default_system: str stop_words: List[str] image_token: str + vision_start_token: str + vision_end_token: str efficient_eos: bool replace_eos: bool @@ -206,6 +208,8 @@ def _register_template( default_system: str = "", stop_words: Sequence[str] = [], image_token: str = "", + vision_start_token: str = "<|vision_start|>", + vision_end_token: str = "<|vision_end|>", efficient_eos: bool = False, replace_eos: bool = False, ) -> None: @@ -255,6 +259,8 @@ def _register_template( default_system=default_system, stop_words=stop_words, image_token=image_token, + vision_start_token=vision_start_token, + vision_end_token=vision_end_token, efficient_eos=efficient_eos, replace_eos=replace_eos, ) @@ -783,6 +789,21 @@ _register_template( ) +_register_template( + name="qwen2vl", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + default_system="You are a helpful assistant.", + image_token="<|image_pad|>", + vision_start_token="<|vision_start|>", + vision_end_token="<|vision_end|>", + stop_words=["<|im_end|>"], + replace_eos=True, +) + + _register_template( name="sailor", format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]), diff --git a/src/llamafactory/model/adapter.py b/src/llamafactory/model/adapter.py index 7caef9cc..f18bcbc9 100644 --- a/src/llamafactory/model/adapter.py +++ b/src/llamafactory/model/adapter.py @@ -212,7 +212,7 @@ def _setup_lora_tuning( target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers) if model_args.visual_inputs and finetuning_args.freeze_vision_tower: - target_modules = "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules)) + target_modules = "^(?!.*(?:vision_tower|visual)).*(?:{}).*".format("|".join(target_modules)) if ( finetuning_args.use_dora diff --git a/src/llamafactory/model/model_utils/misc.py b/src/llamafactory/model/model_utils/misc.py index a2812228..96233002 100644 --- a/src/llamafactory/model/model_utils/misc.py +++ b/src/llamafactory/model/model_utils/misc.py @@ -36,6 +36,8 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) forbidden_modules.add("output") elif model.config.model_type in ["llava", "paligemma"]: forbidden_modules.add("multi_modal_projector") + elif model.config.model_type in ["qwen2_vl"]: + forbidden_modules.add("merger") if freeze_vision_tower: forbidden_modules.add("vision_tower")
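
A minimal, self-contained sketch of the visual-feature bookkeeping this patch introduces (plain torch only; the batch contents, tensor shapes, helper name, and merge_size value below are illustrative assumptions, not outputs of the real Qwen2-VL image processor). It mirrors two pieces of the change: the collator skips the dummy white image that get_qwen2vl_image_inputs inserts for text-only examples (marked by a zero in the first grid dimension) and concatenates the remaining image_grid_thw / pixel_values across the batch, while _encode_supervised_example expands each <|image_pad|> placeholder into grid_thw.prod() // merge_size**2 image tokens wrapped in <|vision_start|> ... <|vision_end|>:

    import torch

    def merge_visual_features(features):
        """Concatenate image_grid_thw / pixel_values over a batch, skipping
        text-only examples whose sentinel grid starts with t == 0."""
        keep = [f for f in features if f["image_grid_thw"][0][0] > 0]
        if not keep:  # the whole batch is text-only
            return None, None
        grids = torch.cat([torch.as_tensor(f["image_grid_thw"], dtype=torch.long) for f in keep], dim=0)
        pixels = torch.cat([torch.as_tensor(f["pixel_values"], dtype=torch.float32) for f in keep], dim=0)
        return grids, pixels

    # Toy batch: one example with a real image (grid 1x4x4 -> 16 patches of a
    # hypothetical feature dim 8) and one text-only example carrying the dummy
    # image whose grid has t set to 0 so the collator can drop it.
    batch = [
        {"image_grid_thw": [[1, 4, 4]], "pixel_values": torch.randn(16, 8)},
        {"image_grid_thw": [[0, 4, 4]], "pixel_values": torch.randn(16, 8)},
    ]
    image_grid_thw, pixel_values = merge_visual_features(batch)
    print(image_grid_thw.shape, pixel_values.shape)  # torch.Size([1, 3]) torch.Size([16, 8])

    # Placeholder expansion used in supervised.py: with merge_size = 2
    # (assumed processor default), a 1x4x4 grid becomes 4 <|image_pad|>
    # tokens placed between <|vision_start|> and <|vision_end|>.
    merge_size = 2
    print(int(image_grid_thw[0].prod()) // merge_size**2)  # 4

Once the patch is applied, the new configs can be exercised through the existing CLI, e.g. llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml.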