Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-11-04 18:02:19 +08:00
[misc] fix import error (#9296)

parent 8c341cbaae
commit a442fa90ad
@@ -14,8 +14,6 @@
 
 from typing import TYPE_CHECKING
 
-from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
-
 from ...extras import logging
 from ...extras.constants import AttentionFunction
 
@@ -30,6 +28,8 @@ logger = logging.get_logger(__name__)
 
 
 def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
+    from transformers.utils import is_flash_attn_2_available
+
     if getattr(config, "model_type", None) == "gemma2":
         if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
             if is_flash_attn_2_available():
@@ -51,6 +51,8 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
         requested_attn_implementation = "eager"
 
     elif model_args.flash_attn == AttentionFunction.SDPA:
+        from transformers.utils import is_torch_sdpa_available
+
         if not is_torch_sdpa_available():
             logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
             return
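For context, the fix moves the transformers.utils imports from module level into the function body, so importing this module no longer fails when the installed transformers build does not expose those helpers. Below is a minimal, self-contained sketch of that deferred-import idea; the sdpa_attention_usable wrapper and its try/except guard are illustrative assumptions, not LLaMA-Factory code.

# Minimal sketch of the deferred-import pattern applied in this commit
# (illustrative only; this helper is not part of LLaMA-Factory).

def sdpa_attention_usable() -> bool:
    """Report whether torch SDPA attention can be used, without failing at import time."""
    # The lookup happens only when the function runs, so importing this
    # module never raises even if the helper is absent from the installed
    # transformers version (or transformers is not installed at all).
    try:
        from transformers.utils import is_torch_sdpa_available
    except ImportError:
        return False
    return is_torch_sdpa_available()


if __name__ == "__main__":
    print("SDPA attention usable:", sdpa_attention_usable())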