mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 03:32:50 +08:00
[model] add Qwen2.5-Omni model (#7537)
* preserve image_sizes * preserve image_sizes * init plugin * support audio-text2text lora * nit * support image/video-text2text, audio-text2text * remove args * remove lines * add docs && nit * remove some comments * fix && add merge part script * add license
This commit is contained in:
parent
468eea6f6d
commit
185c76f6ad
@ -261,6 +261,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
|||||||
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
|
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
|
||||||
| [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
| [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
||||||
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
|
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
|
||||||
|
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 7B | qwen2_omni |
|
||||||
| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl |
|
| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl |
|
||||||
| [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 |
|
| [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 |
|
||||||
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
||||||
|
@ -263,6 +263,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
|||||||
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
|
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
|
||||||
| [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
| [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
||||||
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
|
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
|
||||||
|
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 7B | qwen2_omni |
|
||||||
| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl |
|
| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl |
|
||||||
| [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 |
|
| [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 |
|
||||||
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
||||||
|
90
scripts/lora_part_merge.py
Normal file
90
scripts/lora_part_merge.py
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# This code is based on the HuggingFace's PEFT library.
|
||||||
|
# https://github.com/huggingface/peft/blob/v0.10.0/examples/loftq_finetuning/quantize_save_load.py
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
import fire
|
||||||
|
from peft import PeftModel
|
||||||
|
from transformers import AutoModel, AutoProcessor, AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def merge_lora(
|
||||||
|
base_model_path: str,
|
||||||
|
lora_checkpoint_path: str,
|
||||||
|
extra_file: str = "spk_dict.pt",
|
||||||
|
submodule_name: str = "thinker",
|
||||||
|
save_path: str = "./merged_model_checkpoint",
|
||||||
|
):
|
||||||
|
"""Load the original model, tokenizer, and processor configuration, merge the LoRA weights.
|
||||||
|
|
||||||
|
for a specified submodule, and save the final merged model along with its configurations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_model_path (str): Path to the original model directory.
|
||||||
|
lora_checkpoint_path (str): Path to the directory containing LoRA weights.
|
||||||
|
extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt").
|
||||||
|
submodule_name (str): Name of the submodule to merge (default: "thinker").
|
||||||
|
save_path (str): Directory where the merged model and configurations will be saved.
|
||||||
|
"""
|
||||||
|
# 1. Load the original model, tokenizer, and processor
|
||||||
|
model = AutoModel.from_pretrained(base_model_path)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(base_model_path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
processor = AutoProcessor.from_pretrained(base_model_path)
|
||||||
|
except Exception:
|
||||||
|
print("Processor configuration not found, skipping processor load.")
|
||||||
|
processor = None
|
||||||
|
|
||||||
|
print("Successfully loaded the original model, tokenizer, and processor (if available).")
|
||||||
|
|
||||||
|
# 2. Extract the submodule to be merged (e.g., model.thinker)
|
||||||
|
if not hasattr(model, submodule_name):
|
||||||
|
raise AttributeError(f"The model does not have a submodule named '{submodule_name}'.")
|
||||||
|
base_submodule = getattr(model, submodule_name)
|
||||||
|
print(f"Successfully extracted submodule: {submodule_name}.")
|
||||||
|
|
||||||
|
# 3. Load the LoRA weights onto the extracted submodule
|
||||||
|
lora_model = PeftModel.from_pretrained(base_submodule, lora_checkpoint_path)
|
||||||
|
print("LoRA weights loaded successfully.")
|
||||||
|
|
||||||
|
# 4. Merge the LoRA weights into the submodule and unload the LoRA modules
|
||||||
|
merged_submodule = lora_model.merge_and_unload()
|
||||||
|
print("LoRA weights merged successfully.")
|
||||||
|
|
||||||
|
# 5. Replace the original submodule with the merged submodule in the model
|
||||||
|
setattr(model, submodule_name, merged_submodule)
|
||||||
|
|
||||||
|
# 6. Save the final merged model along with the tokenizer and processor configuration
|
||||||
|
model.save_pretrained(save_path)
|
||||||
|
tokenizer.save_pretrained(save_path)
|
||||||
|
if processor is not None:
|
||||||
|
processor.save_pretrained(save_path)
|
||||||
|
|
||||||
|
print(f"Merged model and configuration saved to {save_path}.")
|
||||||
|
|
||||||
|
source_file = os.path.join(base_model_path, extra_file)
|
||||||
|
target_file = os.path.join(save_path, extra_file)
|
||||||
|
if os.path.exists(source_file):
|
||||||
|
shutil.copy(source_file, target_file)
|
||||||
|
print(f"File '{extra_file}' copied from {base_model_path} to {save_path}.")
|
||||||
|
else:
|
||||||
|
print(f"File '{extra_file}' not found in {base_model_path}, skipping copy.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
fire.Fire(merge_lora)
|
@ -190,10 +190,27 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
"video_grid_thw": mm_inputs.get("video_grid_thw"),
|
"video_grid_thw": mm_inputs.get("video_grid_thw"),
|
||||||
"attention_mask": features["attention_mask"],
|
"attention_mask": features["attention_mask"],
|
||||||
}
|
}
|
||||||
if "second_per_grid_ts" in mm_inputs:
|
if "second_per_grid_ts" in mm_inputs: # for qwen2vl
|
||||||
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
|
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
|
||||||
|
|
||||||
features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)
|
if getattr(self.model.config, "model_type", None) == "qwen2_5_omni": # for qwen2omni
|
||||||
|
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
|
||||||
|
if feature_attention_mask is not None:
|
||||||
|
audio_feature_lengths = torch.sum(
|
||||||
|
feature_attention_mask, dim=1
|
||||||
|
) # FIXME need to get video image lengths
|
||||||
|
rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input
|
||||||
|
|
||||||
|
delta0 = (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(1)
|
||||||
|
# avoid conflict
|
||||||
|
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid", None)
|
||||||
|
new_position_ids, rope_deltas = self.model.get_rope_index(**rope_index_kwargs)
|
||||||
|
features["position_ids"], features["rope_deltas"] = (
|
||||||
|
new_position_ids.clone(),
|
||||||
|
rope_deltas - delta0,
|
||||||
|
) # avoid inplace operation FIXME
|
||||||
|
else: # for qwen2vl
|
||||||
|
features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)
|
||||||
|
|
||||||
if "cross_attention_mask" in mm_inputs: # for mllama inputs when pad_to_multiple_of is enabled
|
if "cross_attention_mask" in mm_inputs: # for mllama inputs when pad_to_multiple_of is enabled
|
||||||
cross_attention_mask = mm_inputs.pop("cross_attention_mask")
|
cross_attention_mask = mm_inputs.pop("cross_attention_mask")
|
||||||
|
@ -146,6 +146,12 @@ class MMPluginMixin:
|
|||||||
video_processor: BaseImageProcessor = getattr(
|
video_processor: BaseImageProcessor = getattr(
|
||||||
processor, "video_processor", getattr(processor, "image_processor", None)
|
processor, "video_processor", getattr(processor, "image_processor", None)
|
||||||
)
|
)
|
||||||
|
if image_processor is None and video_processor is None: # hack for qwen2_5_omni
|
||||||
|
image_processor, video_processor = (
|
||||||
|
getattr(processor, "omni_processor", None),
|
||||||
|
getattr(processor, "omni_processor", None),
|
||||||
|
)
|
||||||
|
|
||||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
||||||
if len(images) != 0 and self.image_token is None:
|
if len(images) != 0 and self.image_token is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -1104,6 +1110,186 @@ class Qwen2AudioPlugin(BasePlugin):
|
|||||||
return self._get_mm_inputs(images, videos, audios, processor)
|
return self._get_mm_inputs(images, videos, audios, processor)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2OmniPlugin(BasePlugin):
|
||||||
|
@override
|
||||||
|
def _get_mm_inputs(
|
||||||
|
self,
|
||||||
|
images: list["ImageInput"],
|
||||||
|
videos: list["VideoInput"],
|
||||||
|
audios: list["AudioInput"],
|
||||||
|
processor: "MMProcessor",
|
||||||
|
imglens: Optional[list[int]] = None,
|
||||||
|
) -> dict[str, "torch.Tensor"]:
|
||||||
|
mm_inputs = {}
|
||||||
|
if len(images) != 0:
|
||||||
|
image_processor: BaseImageProcessor = getattr(processor, "omni_processor", None) # FIXME
|
||||||
|
images = self._regularize_images(
|
||||||
|
images,
|
||||||
|
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
|
||||||
|
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
|
||||||
|
)
|
||||||
|
if imglens is not None:
|
||||||
|
images = _make_batched_images(images, imglens)
|
||||||
|
|
||||||
|
image_processor_kwargs = {}
|
||||||
|
mm_inputs.update(image_processor(images, return_tensors="pt", **image_processor_kwargs))
|
||||||
|
|
||||||
|
if len(videos) != 0:
|
||||||
|
video_processor: BaseImageProcessor = getattr(
|
||||||
|
processor, "video_processor", getattr(processor, "omni_processor", None)
|
||||||
|
)
|
||||||
|
videos = self._regularize_videos(
|
||||||
|
videos,
|
||||||
|
image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
|
||||||
|
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
|
||||||
|
video_fps=getattr(processor, "video_fps", 2.0),
|
||||||
|
video_maxlen=getattr(processor, "video_maxlen", 128),
|
||||||
|
)
|
||||||
|
if "videos" in inspect.signature(video_processor.preprocess).parameters: # for qwen2_vl and video_llava
|
||||||
|
mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
|
||||||
|
fps = [2.0] * len(videos) # FIXME hardcode
|
||||||
|
video_second_per_grid = [fps[i] / video_processor.temporal_patch_size for i in range(len(fps))]
|
||||||
|
mm_inputs["video_second_per_grid"] = torch.tensor(video_second_per_grid)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
if len(audios) != 0:
|
||||||
|
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
||||||
|
audios = self._regularize_audios(
|
||||||
|
audios,
|
||||||
|
sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
|
||||||
|
)
|
||||||
|
mm_inputs.update(
|
||||||
|
feature_extractor(
|
||||||
|
audios,
|
||||||
|
sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
|
||||||
|
return_attention_mask=True,
|
||||||
|
padding="max_length",
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask") # prevent conflicts
|
||||||
|
|
||||||
|
return mm_inputs
|
||||||
|
|
||||||
|
@override
|
||||||
|
def process_messages(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
images: list["ImageInput"],
|
||||||
|
videos: list["VideoInput"],
|
||||||
|
audios: list["AudioInput"],
|
||||||
|
processor: Optional["MMProcessor"],
|
||||||
|
) -> list[dict[str, str]]:
|
||||||
|
self._validate_input(processor, images, videos, audios)
|
||||||
|
messages = deepcopy(messages)
|
||||||
|
if self.expand_mm_tokens:
|
||||||
|
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||||
|
num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
|
||||||
|
use_audio_in_video = getattr(processor, "use_audio_in_video", False)
|
||||||
|
|
||||||
|
# get length or size from mm_inputs
|
||||||
|
if "feature_attention_mask" in mm_inputs:
|
||||||
|
input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
|
||||||
|
audio_lengths = (input_lengths - 2) // 2 + 1
|
||||||
|
if mm_inputs.get("image_grid_thw", None) is not None:
|
||||||
|
image_grid_thw = mm_inputs["image_grid_thw"]
|
||||||
|
merge_length = processor.omni_processor.merge_size**2
|
||||||
|
if mm_inputs.get("video_grid_thw", None) is not None:
|
||||||
|
video_grid_thw = mm_inputs["video_grid_thw"]
|
||||||
|
merge_length = processor.omni_processor.merge_size**2
|
||||||
|
|
||||||
|
if use_audio_in_video:
|
||||||
|
assert audio_lengths is not None, "audio_lengths should be exist when use_audio_in_video is `True`"
|
||||||
|
assert mm_inputs.get("video_grid_thw", None) is not None, (
|
||||||
|
"video_grid_thw should be exist when use_audio_in_video is `True`"
|
||||||
|
)
|
||||||
|
positions_list = []
|
||||||
|
for i, message in enumerate(messages): # get multimodal index when use_audio
|
||||||
|
positions = []
|
||||||
|
for special_token in [self.audio_token, self.image_token, self.video_token]:
|
||||||
|
start = 0
|
||||||
|
while True:
|
||||||
|
pos = message[i].find(special_token, start)
|
||||||
|
if pos == -1:
|
||||||
|
break
|
||||||
|
positions.append((pos, special_token))
|
||||||
|
start = pos + len(special_token)
|
||||||
|
positions_list.append(positions.sort(key=lambda x: x[0]))
|
||||||
|
|
||||||
|
for message in messages:
|
||||||
|
content = message["content"]
|
||||||
|
# separate with audio-video
|
||||||
|
while IMAGE_PLACEHOLDER in content:
|
||||||
|
image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length
|
||||||
|
content = content.replace(
|
||||||
|
IMAGE_PLACEHOLDER,
|
||||||
|
f"<|vision_bos|>{self.image_token * image_token_replace_length}<|vision_eos|>",
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
num_image_tokens += 1
|
||||||
|
|
||||||
|
if not use_audio_in_video:
|
||||||
|
while AUDIO_PLACEHOLDER in content:
|
||||||
|
audio_token_replace_length = audio_lengths[num_audio_tokens]
|
||||||
|
content = content.replace(
|
||||||
|
AUDIO_PLACEHOLDER,
|
||||||
|
f"<|audio_bos|>{self.audio_token * audio_token_replace_length}<|audio_eos|>",
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
num_audio_tokens += 1
|
||||||
|
# TODO handle video_input and use_audio_in_video
|
||||||
|
while VIDEO_PLACEHOLDER in content:
|
||||||
|
video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length
|
||||||
|
content = content.replace(
|
||||||
|
VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_replace_length}<|vision_eos|>", 1
|
||||||
|
)
|
||||||
|
num_video_tokens += 1
|
||||||
|
else: # if use the audio of video # deal video token and audio token togather
|
||||||
|
while VIDEO_PLACEHOLDER in content:
|
||||||
|
audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
|
||||||
|
video_t_index = (
|
||||||
|
torch.arange(video_grid_thw[num_video_tokens][0])
|
||||||
|
.view(-1, 1, 1)
|
||||||
|
.expand(
|
||||||
|
-1,
|
||||||
|
video_grid_thw[num_video_tokens][1] // self.omni_processor.merge_size,
|
||||||
|
video_grid_thw[num_video_tokens][2] // self.omni_processor.merge_size,
|
||||||
|
)
|
||||||
|
.flatten()
|
||||||
|
* mm_inputs["video_second_per_grid"][num_video_tokens]
|
||||||
|
* 25 # FIXME hardcode of position_id_per_seconds=25
|
||||||
|
).long()
|
||||||
|
t_ntoken_per_chunk = 50 # FIXME hardcode: [25 * 2]
|
||||||
|
video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
|
||||||
|
audio_chunk_indices = self.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
|
||||||
|
placeholder_string = ""
|
||||||
|
for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
|
||||||
|
video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
|
||||||
|
audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
|
||||||
|
placeholder_string = "<|vision_bos|>" + "<|audio_bos|>"
|
||||||
|
if video_chunk_index is not None:
|
||||||
|
placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])
|
||||||
|
if audio_chunk_index is not None:
|
||||||
|
placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
|
||||||
|
placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
|
||||||
|
content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
|
||||||
|
content = content.replace(AUDIO_PLACEHOLDER, "", 1)
|
||||||
|
num_audio_tokens += 1
|
||||||
|
num_video_tokens += 1
|
||||||
|
message["content"] = content
|
||||||
|
|
||||||
|
if len(audios) != num_audio_tokens:
|
||||||
|
raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
|
||||||
|
if len(images) != num_image_tokens:
|
||||||
|
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||||
|
if len(videos) != num_video_tokens:
|
||||||
|
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Qwen2VLPlugin(BasePlugin):
|
class Qwen2VLPlugin(BasePlugin):
|
||||||
@override
|
@override
|
||||||
@ -1328,6 +1514,7 @@ PLUGINS = {
|
|||||||
"paligemma": PaliGemmaPlugin,
|
"paligemma": PaliGemmaPlugin,
|
||||||
"pixtral": PixtralPlugin,
|
"pixtral": PixtralPlugin,
|
||||||
"qwen2_audio": Qwen2AudioPlugin,
|
"qwen2_audio": Qwen2AudioPlugin,
|
||||||
|
"qwen2_omni": Qwen2OmniPlugin,
|
||||||
"qwen2_vl": Qwen2VLPlugin,
|
"qwen2_vl": Qwen2VLPlugin,
|
||||||
"video_llava": VideoLlavaPlugin,
|
"video_llava": VideoLlavaPlugin,
|
||||||
}
|
}
|
||||||
|
@ -1367,6 +1367,24 @@ register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# copied from qwen template
|
||||||
|
register_template(
|
||||||
|
name="qwen2_omni",
|
||||||
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
|
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||||
|
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||||
|
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
|
||||||
|
format_observation=StringFormatter(
|
||||||
|
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
|
||||||
|
),
|
||||||
|
format_tools=ToolFormatter(tool_format="qwen"),
|
||||||
|
default_system="You are a helpful assistant.",
|
||||||
|
stop_words=["<|im_end|>"],
|
||||||
|
mm_plugin=get_mm_plugin(
|
||||||
|
name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
# copied from qwen template
|
# copied from qwen template
|
||||||
register_template(
|
register_template(
|
||||||
name="qwen2_vl",
|
name="qwen2_vl",
|
||||||
|
@ -2270,6 +2270,18 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"Qwen2.5-Omni-7B": {
|
||||||
|
DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B",
|
||||||
|
DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
template="qwen2_omni",
|
||||||
|
multimodal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Qwen2-VL-2B": {
|
"Qwen2-VL-2B": {
|
||||||
|
@ -222,6 +222,10 @@ class ProcessorArguments:
|
|||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Use pan and scan to process image for gemma3."},
|
metadata={"help": "Use pan and scan to process image for gemma3."},
|
||||||
)
|
)
|
||||||
|
use_audio_in_video: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether or not to use audio in video inputs."},
|
||||||
|
)
|
||||||
video_max_pixels: int = field(
|
video_max_pixels: int = field(
|
||||||
default=256 * 256,
|
default=256 * 256,
|
||||||
metadata={"help": "The maximum number of pixels of video inputs."},
|
metadata={"help": "The maximum number of pixels of video inputs."},
|
||||||
|
@ -21,6 +21,7 @@ from transformers import (
|
|||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoModelForImageTextToText,
|
AutoModelForImageTextToText,
|
||||||
AutoModelForSeq2SeqLM,
|
AutoModelForSeq2SeqLM,
|
||||||
|
AutoModelForTextToWaveform,
|
||||||
AutoModelForVision2Seq,
|
AutoModelForVision2Seq,
|
||||||
AutoProcessor,
|
AutoProcessor,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
@ -147,6 +148,8 @@ def load_model(
|
|||||||
load_class = AutoModelForImageTextToText
|
load_class = AutoModelForImageTextToText
|
||||||
elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys(): # audio-text
|
elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys(): # audio-text
|
||||||
load_class = AutoModelForSeq2SeqLM
|
load_class = AutoModelForSeq2SeqLM
|
||||||
|
elif type(config) in AutoModelForTextToWaveform._model_mapping.keys(): # audio hack for qwen2_5_omni
|
||||||
|
load_class = AutoModelForTextToWaveform
|
||||||
else:
|
else:
|
||||||
load_class = AutoModelForCausalLM
|
load_class = AutoModelForCausalLM
|
||||||
|
|
||||||
@ -154,6 +157,8 @@ def load_model(
|
|||||||
model = load_class.from_config(config, trust_remote_code=model_args.trust_remote_code)
|
model = load_class.from_config(config, trust_remote_code=model_args.trust_remote_code)
|
||||||
else:
|
else:
|
||||||
model = load_class.from_pretrained(**init_kwargs)
|
model = load_class.from_pretrained(**init_kwargs)
|
||||||
|
if load_class is AutoModelForTextToWaveform:
|
||||||
|
model = model.thinker # use part of Omni model
|
||||||
|
|
||||||
if model_args.mixture_of_depths == "convert":
|
if model_args.mixture_of_depths == "convert":
|
||||||
model = convert_pretrained_model_to_mod(model, config, model_args)
|
model = convert_pretrained_model_to_mod(model, config, model_args)
|
||||||
|
@ -257,6 +257,17 @@ _register_composite_model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_register_composite_model(
|
||||||
|
model_type="qwen2_5_omni_thinker",
|
||||||
|
projector_key="visual.merger",
|
||||||
|
vision_model_keys=["visual.patch_embed", "visual.blocks", "audio_tower"],
|
||||||
|
language_model_keys=["model", "lm_head"],
|
||||||
|
lora_conflict_keys=[
|
||||||
|
"patch_embed",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_composite_model(
|
_register_composite_model(
|
||||||
model_type="qwen2_vl",
|
model_type="qwen2_vl",
|
||||||
projector_key="visual.merger",
|
projector_key="visual.merger",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user