diff --git a/README.md b/README.md index 5b34b915..ab500587 100644 --- a/README.md +++ b/README.md @@ -261,6 +261,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ | [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral | | [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen | | [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio | +| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 7B | qwen2_omni | | [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl | | [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 | | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - | diff --git a/README_zh.md b/README_zh.md index 344ac488..4862060e 100644 --- a/README_zh.md +++ b/README_zh.md @@ -263,6 +263,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc | [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral | | [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen | | [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio | +| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 7B | qwen2_omni | | [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl | | [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 | | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - | diff --git a/scripts/lora_part_merge.py b/scripts/lora_part_merge.py new file mode 100644 index 00000000..5000cece --- /dev/null +++ b/scripts/lora_part_merge.py @@ -0,0 +1,90 @@ +# Copyright 2025 HuggingFace Inc. and the LlamaFactory team. +# +# This code is based on the HuggingFace's PEFT library. +# https://github.com/huggingface/peft/blob/v0.10.0/examples/loftq_finetuning/quantize_save_load.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import shutil + +import fire +from peft import PeftModel +from transformers import AutoModel, AutoProcessor, AutoTokenizer + + +def merge_lora( + base_model_path: str, + lora_checkpoint_path: str, + extra_file: str = "spk_dict.pt", + submodule_name: str = "thinker", + save_path: str = "./merged_model_checkpoint", +): + """Load the original model, tokenizer, and processor configuration, merge the LoRA weights. + + for a specified submodule, and save the final merged model along with its configurations. + + Args: + base_model_path (str): Path to the original model directory. + lora_checkpoint_path (str): Path to the directory containing LoRA weights. + extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt"). + submodule_name (str): Name of the submodule to merge (default: "thinker"). + save_path (str): Directory where the merged model and configurations will be saved. + """ + # 1. Load the original model, tokenizer, and processor + model = AutoModel.from_pretrained(base_model_path) + tokenizer = AutoTokenizer.from_pretrained(base_model_path) + + try: + processor = AutoProcessor.from_pretrained(base_model_path) + except Exception: + print("Processor configuration not found, skipping processor load.") + processor = None + + print("Successfully loaded the original model, tokenizer, and processor (if available).") + + # 2. Extract the submodule to be merged (e.g., model.thinker) + if not hasattr(model, submodule_name): + raise AttributeError(f"The model does not have a submodule named '{submodule_name}'.") + base_submodule = getattr(model, submodule_name) + print(f"Successfully extracted submodule: {submodule_name}.") + + # 3. Load the LoRA weights onto the extracted submodule + lora_model = PeftModel.from_pretrained(base_submodule, lora_checkpoint_path) + print("LoRA weights loaded successfully.") + + # 4. Merge the LoRA weights into the submodule and unload the LoRA modules + merged_submodule = lora_model.merge_and_unload() + print("LoRA weights merged successfully.") + + # 5. Replace the original submodule with the merged submodule in the model + setattr(model, submodule_name, merged_submodule) + + # 6. Save the final merged model along with the tokenizer and processor configuration + model.save_pretrained(save_path) + tokenizer.save_pretrained(save_path) + if processor is not None: + processor.save_pretrained(save_path) + + print(f"Merged model and configuration saved to {save_path}.") + + source_file = os.path.join(base_model_path, extra_file) + target_file = os.path.join(save_path, extra_file) + if os.path.exists(source_file): + shutil.copy(source_file, target_file) + print(f"File '{extra_file}' copied from {base_model_path} to {save_path}.") + else: + print(f"File '{extra_file}' not found in {base_model_path}, skipping copy.") + + +if __name__ == "__main__": + fire.Fire(merge_lora) diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 838c0719..93c7349c 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -190,10 +190,27 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): "video_grid_thw": mm_inputs.get("video_grid_thw"), "attention_mask": features["attention_mask"], } - if "second_per_grid_ts" in mm_inputs: + if "second_per_grid_ts" in mm_inputs: # for qwen2vl rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts") - features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs) + if getattr(self.model.config, "model_type", None) == "qwen2_5_omni": # for qwen2omni + feature_attention_mask = mm_inputs.get("feature_attention_mask", None) + if feature_attention_mask is not None: + audio_feature_lengths = torch.sum( + feature_attention_mask, dim=1 + ) # FIXME need to get video image lengths + rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input + + delta0 = (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(1) + # avoid conflict + rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid", None) + new_position_ids, rope_deltas = self.model.get_rope_index(**rope_index_kwargs) + features["position_ids"], features["rope_deltas"] = ( + new_position_ids.clone(), + rope_deltas - delta0, + ) # avoid inplace operation FIXME + else: # for qwen2vl + features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs) if "cross_attention_mask" in mm_inputs: # for mllama inputs when pad_to_multiple_of is enabled cross_attention_mask = mm_inputs.pop("cross_attention_mask") diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 4555be8e..b6928636 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -146,6 +146,12 @@ class MMPluginMixin: video_processor: BaseImageProcessor = getattr( processor, "video_processor", getattr(processor, "image_processor", None) ) + if image_processor is None and video_processor is None: # hack for qwen2_5_omni + image_processor, video_processor = ( + getattr(processor, "omni_processor", None), + getattr(processor, "omni_processor", None), + ) + feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) if len(images) != 0 and self.image_token is None: raise ValueError( @@ -1104,6 +1110,186 @@ class Qwen2AudioPlugin(BasePlugin): return self._get_mm_inputs(images, videos, audios, processor) +class Qwen2OmniPlugin(BasePlugin): + @override + def _get_mm_inputs( + self, + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: "MMProcessor", + imglens: Optional[list[int]] = None, + ) -> dict[str, "torch.Tensor"]: + mm_inputs = {} + if len(images) != 0: + image_processor: BaseImageProcessor = getattr(processor, "omni_processor", None) # FIXME + images = self._regularize_images( + images, + image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), + image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), + ) + if imglens is not None: + images = _make_batched_images(images, imglens) + + image_processor_kwargs = {} + mm_inputs.update(image_processor(images, return_tensors="pt", **image_processor_kwargs)) + + if len(videos) != 0: + video_processor: BaseImageProcessor = getattr( + processor, "video_processor", getattr(processor, "omni_processor", None) + ) + videos = self._regularize_videos( + videos, + image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), + image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), + video_fps=getattr(processor, "video_fps", 2.0), + video_maxlen=getattr(processor, "video_maxlen", 128), + ) + if "videos" in inspect.signature(video_processor.preprocess).parameters: # for qwen2_vl and video_llava + mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt")) + fps = [2.0] * len(videos) # FIXME hardcode + video_second_per_grid = [fps[i] / video_processor.temporal_patch_size for i in range(len(fps))] + mm_inputs["video_second_per_grid"] = torch.tensor(video_second_per_grid) + + else: + raise NotImplementedError + + if len(audios) != 0: + feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) + audios = self._regularize_audios( + audios, + sampling_rate=getattr(feature_extractor, "sampling_rate", 16000), + ) + mm_inputs.update( + feature_extractor( + audios, + sampling_rate=getattr(feature_extractor, "sampling_rate", 16000), + return_attention_mask=True, + padding="max_length", + return_tensors="pt", + ) + ) + mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask") # prevent conflicts + + return mm_inputs + + @override + def process_messages( + self, + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + messages = deepcopy(messages) + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0 + use_audio_in_video = getattr(processor, "use_audio_in_video", False) + + # get length or size from mm_inputs + if "feature_attention_mask" in mm_inputs: + input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1 + audio_lengths = (input_lengths - 2) // 2 + 1 + if mm_inputs.get("image_grid_thw", None) is not None: + image_grid_thw = mm_inputs["image_grid_thw"] + merge_length = processor.omni_processor.merge_size**2 + if mm_inputs.get("video_grid_thw", None) is not None: + video_grid_thw = mm_inputs["video_grid_thw"] + merge_length = processor.omni_processor.merge_size**2 + + if use_audio_in_video: + assert audio_lengths is not None, "audio_lengths should be exist when use_audio_in_video is `True`" + assert mm_inputs.get("video_grid_thw", None) is not None, ( + "video_grid_thw should be exist when use_audio_in_video is `True`" + ) + positions_list = [] + for i, message in enumerate(messages): # get multimodal index when use_audio + positions = [] + for special_token in [self.audio_token, self.image_token, self.video_token]: + start = 0 + while True: + pos = message[i].find(special_token, start) + if pos == -1: + break + positions.append((pos, special_token)) + start = pos + len(special_token) + positions_list.append(positions.sort(key=lambda x: x[0])) + + for message in messages: + content = message["content"] + # separate with audio-video + while IMAGE_PLACEHOLDER in content: + image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length + content = content.replace( + IMAGE_PLACEHOLDER, + f"<|vision_bos|>{self.image_token * image_token_replace_length}<|vision_eos|>", + 1, + ) + num_image_tokens += 1 + + if not use_audio_in_video: + while AUDIO_PLACEHOLDER in content: + audio_token_replace_length = audio_lengths[num_audio_tokens] + content = content.replace( + AUDIO_PLACEHOLDER, + f"<|audio_bos|>{self.audio_token * audio_token_replace_length}<|audio_eos|>", + 1, + ) + num_audio_tokens += 1 + # TODO handle video_input and use_audio_in_video + while VIDEO_PLACEHOLDER in content: + video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length + content = content.replace( + VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_replace_length}<|vision_eos|>", 1 + ) + num_video_tokens += 1 + else: # if use the audio of video # deal video token and audio token togather + while VIDEO_PLACEHOLDER in content: + audio_t_index = torch.arange(audio_lengths[num_audio_tokens]) + video_t_index = ( + torch.arange(video_grid_thw[num_video_tokens][0]) + .view(-1, 1, 1) + .expand( + -1, + video_grid_thw[num_video_tokens][1] // self.omni_processor.merge_size, + video_grid_thw[num_video_tokens][2] // self.omni_processor.merge_size, + ) + .flatten() + * mm_inputs["video_second_per_grid"][num_video_tokens] + * 25 # FIXME hardcode of position_id_per_seconds=25 + ).long() + t_ntoken_per_chunk = 50 # FIXME hardcode: [25 * 2] + video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk) + audio_chunk_indices = self.get_chunked_index(audio_t_index, t_ntoken_per_chunk) + placeholder_string = "" + for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))): + video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None + audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None + placeholder_string = "<|vision_bos|>" + "<|audio_bos|>" + if video_chunk_index is not None: + placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0]) + if audio_chunk_index is not None: + placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0]) + placeholder_string += "<|audio_eos|>" + "<|vision_eos|>" + content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1) + content = content.replace(AUDIO_PLACEHOLDER, "", 1) + num_audio_tokens += 1 + num_video_tokens += 1 + message["content"] = content + + if len(audios) != num_audio_tokens: + raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.") + if len(images) != num_image_tokens: + raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.") + if len(videos) != num_video_tokens: + raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.") + + return messages + + @dataclass class Qwen2VLPlugin(BasePlugin): @override @@ -1328,6 +1514,7 @@ PLUGINS = { "paligemma": PaliGemmaPlugin, "pixtral": PixtralPlugin, "qwen2_audio": Qwen2AudioPlugin, + "qwen2_omni": Qwen2OmniPlugin, "qwen2_vl": Qwen2VLPlugin, "video_llava": VideoLlavaPlugin, } diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index b4d5d51b..a29bf959 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -1367,6 +1367,24 @@ register_template( ) +# copied from qwen template +register_template( + name="qwen2_omni", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"), + format_observation=StringFormatter( + slots=["<|im_start|>user\n\n{{content}}\n<|im_end|>\n<|im_start|>assistant\n"] + ), + format_tools=ToolFormatter(tool_format="qwen"), + default_system="You are a helpful assistant.", + stop_words=["<|im_end|>"], + mm_plugin=get_mm_plugin( + name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>" + ), +) + # copied from qwen template register_template( name="qwen2_vl", diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index d0cd4891..702e9bc6 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -2270,6 +2270,18 @@ register_model_group( ) +register_model_group( + models={ + "Qwen2.5-Omni-7B": { + DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B", + DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B", + } + }, + template="qwen2_omni", + multimodal=True, +) + + register_model_group( models={ "Qwen2-VL-2B": { diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index 1cf2271b..fec05374 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -222,6 +222,10 @@ class ProcessorArguments: default=False, metadata={"help": "Use pan and scan to process image for gemma3."}, ) + use_audio_in_video: bool = field( + default=False, + metadata={"help": "Whether or not to use audio in video inputs."}, + ) video_max_pixels: int = field( default=256 * 256, metadata={"help": "The maximum number of pixels of video inputs."}, diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py index 7b116397..5360a69a 100644 --- a/src/llamafactory/model/loader.py +++ b/src/llamafactory/model/loader.py @@ -21,6 +21,7 @@ from transformers import ( AutoModelForCausalLM, AutoModelForImageTextToText, AutoModelForSeq2SeqLM, + AutoModelForTextToWaveform, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer, @@ -147,6 +148,8 @@ def load_model( load_class = AutoModelForImageTextToText elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys(): # audio-text load_class = AutoModelForSeq2SeqLM + elif type(config) in AutoModelForTextToWaveform._model_mapping.keys(): # audio hack for qwen2_5_omni + load_class = AutoModelForTextToWaveform else: load_class = AutoModelForCausalLM @@ -154,6 +157,8 @@ def load_model( model = load_class.from_config(config, trust_remote_code=model_args.trust_remote_code) else: model = load_class.from_pretrained(**init_kwargs) + if load_class is AutoModelForTextToWaveform: + model = model.thinker # use part of Omni model if model_args.mixture_of_depths == "convert": model = convert_pretrained_model_to_mod(model, config, model_args) diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py index 76162802..0158a50b 100644 --- a/src/llamafactory/model/model_utils/visual.py +++ b/src/llamafactory/model/model_utils/visual.py @@ -257,6 +257,17 @@ _register_composite_model( ) +_register_composite_model( + model_type="qwen2_5_omni_thinker", + projector_key="visual.merger", + vision_model_keys=["visual.patch_embed", "visual.blocks", "audio_tower"], + language_model_keys=["model", "lm_head"], + lora_conflict_keys=[ + "patch_embed", + ], +) + + _register_composite_model( model_type="qwen2_vl", projector_key="visual.merger",