mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	[data] fix qwen2.5 omni plugin (#7573)
* align key with qwen2vl * nit && change scripts
This commit is contained in:
		
							parent
							
								
									7b9deb9410
								
							
						
					
					
						commit
						d32c6c014d
					
				@ -19,7 +19,7 @@ import shutil
 | 
			
		||||
 | 
			
		||||
import fire
 | 
			
		||||
from peft import PeftModel
 | 
			
		||||
from transformers import AutoModel, AutoProcessor, AutoTokenizer
 | 
			
		||||
from transformers import AutoModel, AutoProcessor, AutoTokenizer, Qwen2_5OmniThinkerForConditionalGeneration
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def merge_lora(
 | 
			
		||||
@ -31,7 +31,7 @@ def merge_lora(
 | 
			
		||||
):
 | 
			
		||||
    """Load the original model, tokenizer, and processor configuration, merge the LoRA weights.
 | 
			
		||||
 | 
			
		||||
    for a specified submodule, and save the final merged model along with its configurations.
 | 
			
		||||
    For a specified submodule, and save the final merged model along with its configurations.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        base_model_path (str): Path to the original model directory.
 | 
			
		||||
@ -86,5 +86,47 @@ def merge_lora(
 | 
			
		||||
        print(f"File '{extra_file}' not found in {base_model_path}, skipping copy.")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def save_full_model(
 | 
			
		||||
    saved_thinker_path: str,
 | 
			
		||||
    base_model_path: str,
 | 
			
		||||
    save_path: str,
 | 
			
		||||
    extra_file: str = "spk_dict.pt",
 | 
			
		||||
):
 | 
			
		||||
    """Load the saved thinker module and the original model, replace the thinker in the original model.
 | 
			
		||||
 | 
			
		||||
    Then save the complete model along with its tokenizer and processor configuration.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        saved_thinker_path (str): Path to the saved thinker weights.
 | 
			
		||||
        base_model_path (str): Directory path of the original model.
 | 
			
		||||
        save_path (str): Directory where the final complete model will be saved.
 | 
			
		||||
        extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt").
 | 
			
		||||
    """
 | 
			
		||||
    # Load the thinker module
 | 
			
		||||
    thinker = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(saved_thinker_path, device_map="cpu")
 | 
			
		||||
    # Load the original model
 | 
			
		||||
    base_model = AutoModel.from_pretrained(base_model_path, device_map="cpu")
 | 
			
		||||
    # Replace the thinker module in the original model
 | 
			
		||||
    base_model.thinker = thinker
 | 
			
		||||
 | 
			
		||||
    # Load the processor and tokenizer
 | 
			
		||||
    processor = AutoProcessor.from_pretrained(base_model_path, trust_remote_code=True)
 | 
			
		||||
    tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)
 | 
			
		||||
 | 
			
		||||
    # Save the complete model along with its configurations
 | 
			
		||||
    base_model.save_pretrained(save_path)
 | 
			
		||||
    tokenizer.save_pretrained(save_path)
 | 
			
		||||
    processor.save_pretrained(save_path)
 | 
			
		||||
    print(f"Complete model, tokenizer, and processor configuration have been saved to {save_path}.")
 | 
			
		||||
 | 
			
		||||
    source_file = os.path.join(base_model_path, extra_file)
 | 
			
		||||
    target_file = os.path.join(save_path, extra_file)
 | 
			
		||||
    if os.path.exists(source_file):
 | 
			
		||||
        shutil.copy(source_file, target_file)
 | 
			
		||||
        print(f"File '{extra_file}' copied from {base_model_path} to {save_path}.")
 | 
			
		||||
    else:
 | 
			
		||||
        print(f"File '{extra_file}' not found in {base_model_path}, skipping copy.")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    fire.Fire(merge_lora)
 | 
			
		||||
    fire.Fire({"save_full": save_full_model, "merge_lora": merge_lora})
 | 
			
		||||
@ -203,7 +203,6 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
 | 
			
		||||
 | 
			
		||||
                delta0 = (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(1)
 | 
			
		||||
                # avoid conflict
 | 
			
		||||
                rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid", None)
 | 
			
		||||
                new_position_ids, rope_deltas = self.model.get_rope_index(**rope_index_kwargs)
 | 
			
		||||
                features["position_ids"], features["rope_deltas"] = (
 | 
			
		||||
                    new_position_ids.clone(),
 | 
			
		||||
 | 
			
		||||
@ -1405,7 +1405,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
 | 
			
		||||
                            video_grid_thw[num_video_tokens][2] // self.image_processor.merge_size,
 | 
			
		||||
                        )
 | 
			
		||||
                        .flatten()
 | 
			
		||||
                        * mm_inputs["video_second_per_grid"][num_video_tokens]
 | 
			
		||||
                        * mm_inputs["second_per_grid_ts"][num_video_tokens]
 | 
			
		||||
                        * 25  # FIXME hardcode of position_id_per_seconds=25
 | 
			
		||||
                    ).long()
 | 
			
		||||
                    t_ntoken_per_chunk = 50  # FIXME hardcode: [25 * 2]
 | 
			
		||||
 | 
			
		||||
@ -157,7 +157,7 @@ def load_model(
 | 
			
		||||
                model = load_class.from_config(config, trust_remote_code=model_args.trust_remote_code)
 | 
			
		||||
            else:
 | 
			
		||||
                model = load_class.from_pretrained(**init_kwargs)
 | 
			
		||||
                if load_class is AutoModelForTextToWaveform:
 | 
			
		||||
                if getattr(model.config, "model_type", None) == "qwen2_5_omni":
 | 
			
		||||
                    model = model.thinker  # use part of Omni model
 | 
			
		||||
 | 
			
		||||
        if model_args.mixture_of_depths == "convert":
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user