Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-18 04:40:35 +08:00)
[model] switch to gptqmodel (#8108)
@@ -29,10 +29,8 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)


-def configure_attn_implementation(
-    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
-) -> None:
-    if getattr(config, "model_type", None) == "gemma2" and is_trainable:
+def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
+    if getattr(config, "model_type", None) == "gemma2":
         if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
             if is_flash_attn_2_available():
                 if model_args.flash_attn != AttentionFunction.FA2:
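For reference, below is a minimal self-contained sketch of what the simplified helper could look like after this change. Only the lines visible in the hunk above come from the commit; the AttentionFunction enum values, the ModelArguments stand-in, and the fallback branch are assumptions added for illustration.

from dataclasses import dataclass
from enum import Enum, unique
from typing import TYPE_CHECKING

from transformers.utils import is_flash_attn_2_available

if TYPE_CHECKING:
    from transformers import PretrainedConfig


@unique
class AttentionFunction(str, Enum):
    # Stand-in for the project's AttentionFunction enum (assumed member values).
    AUTO = "auto"
    DISABLED = "disabled"
    FA2 = "fa2"


@dataclass
class ModelArguments:
    # Minimal stand-in for the project's ModelArguments; only the field used here.
    flash_attn: AttentionFunction = AttentionFunction.AUTO


def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
    # Gemma-2 uses soft-capped attention, so it is steered toward FlashAttention-2.
    # Note: the is_trainable parameter and the `and is_trainable` guard are removed in this commit.
    if getattr(config, "model_type", None) == "gemma2":
        if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
            if is_flash_attn_2_available():
                if model_args.flash_attn != AttentionFunction.FA2:
                    model_args.flash_attn = AttentionFunction.FA2  # force FA2 for gemma2
            else:
                model_args.flash_attn = AttentionFunction.DISABLED  # assumed fallback to eager attention

This sketch only mirrors the control flow visible in the hunk; in the repository the real AttentionFunction enum, ModelArguments class, and logger warnings live in their own modules.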