mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 09:52:14 +08:00 
			
		
		
		
	[data] fix mm plugin for qwen omni video training (#9388)
Co-authored-by: frozenleaves <frozen@Mac.local>
This commit is contained in:
		
							parent
							
								
									767b344fb4
								
							
						
					
					
						commit
						215580c77d
					
				@ -68,6 +68,8 @@ if TYPE_CHECKING:
 | 
			
		||||
    from transformers import PreTrainedTokenizer, ProcessorMixin
 | 
			
		||||
    from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
 | 
			
		||||
    from transformers.image_processing_utils import BaseImageProcessor
 | 
			
		||||
    from transformers.video_processing_utils import BaseVideoProcessor
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    class EncodedImage(TypedDict):
 | 
			
		||||
        path: Optional[str]
 | 
			
		||||
@ -1482,6 +1484,7 @@ class Qwen2VLPlugin(BasePlugin):
 | 
			
		||||
        processor: "MMProcessor",
 | 
			
		||||
    ) -> dict[str, "torch.Tensor"]:
 | 
			
		||||
        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
 | 
			
		||||
        video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
 | 
			
		||||
        mm_inputs = {}
 | 
			
		||||
        if len(images) != 0:
 | 
			
		||||
            images = self._regularize_images(
 | 
			
		||||
@ -1499,7 +1502,7 @@ class Qwen2VLPlugin(BasePlugin):
 | 
			
		||||
                video_fps=getattr(processor, "video_fps", 2.0),
 | 
			
		||||
                video_maxlen=getattr(processor, "video_maxlen", 128),
 | 
			
		||||
            )
 | 
			
		||||
            mm_inputs.update(image_processor(images=None, videos=video_data["videos"], return_tensors="pt"))
 | 
			
		||||
            mm_inputs.update(video_processor(videos=video_data["videos"], return_tensors="pt"))
 | 
			
		||||
            temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
 | 
			
		||||
            if "second_per_grid_ts" in processor.model_input_names:
 | 
			
		||||
                mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in video_data["fps_per_video"]]
 | 
			
		||||
@ -1818,6 +1821,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
 | 
			
		||||
        processor: "MMProcessor",
 | 
			
		||||
    ) -> dict[str, "torch.Tensor"]:
 | 
			
		||||
        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
 | 
			
		||||
        video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
 | 
			
		||||
        feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
 | 
			
		||||
        mm_inputs = {}
 | 
			
		||||
        if len(images) != 0:
 | 
			
		||||
@ -1836,7 +1840,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
 | 
			
		||||
                video_fps=getattr(processor, "video_fps", 2.0),
 | 
			
		||||
                video_maxlen=getattr(processor, "video_maxlen", 128),
 | 
			
		||||
            )
 | 
			
		||||
            mm_inputs.update(image_processor(images=None, videos=video_dict["videos"], return_tensors="pt"))
 | 
			
		||||
            mm_inputs.update(video_processor(videos=video_dict["videos"], return_tensors="pt"))
 | 
			
		||||
            temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
 | 
			
		||||
            mm_inputs["video_second_per_grid"] = torch.tensor(
 | 
			
		||||
                [temporal_patch_size / fps for fps in video_dict["fps_per_video"]]
 | 
			
		||||
 | 
			
		||||
@ -57,7 +57,7 @@ def launch():
 | 
			
		||||
    if is_env_enabled("USE_MCA"):
 | 
			
		||||
    # force use torchrun
 | 
			
		||||
        os.environ["FORCE_TORCHRUN"] = "1"
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
 | 
			
		||||
        # launch distributed training
 | 
			
		||||
        nnodes = os.getenv("NNODES", "1")
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user