Mirror of https://github.com/hiyouga/LLaMA-Factory.git
Commit: update arg name
@@ -283,28 +283,15 @@ def patch_for_multipack(model_type, model_name, attn_implementation):
     )


-def patch_remote(model_name, config_name, modeling_name):
-    model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
-    # we need to load the model here in order for modeling_* to be available
-    with init_empty_weights():
-        AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
-    module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
-    modeling_arch = importlib.import_module(module_name)
-    modeling_arch._get_unpad_data = get_unpad_data  # pylint: disable=protected-access
-
-    # check exist load_balancing_loss_func for moe model
-    if hasattr(modeling_arch, "load_balancing_loss_func"):
-        modeling_arch.load_balancing_loss_func = load_balancing_loss_func
+def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    if not is_trainable or not model_args.block_diag_attn:
+        return
+
+    model_type = getattr(config, "model_type", None)
-
-
-def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
-    if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
-        attn_implementation = getattr(config, "attn_implementation", "")
-    else:
-        attn_implementation = getattr(config, "_attn_implementation", "")
-
-    if getattr(config, "model_type", None) in SUPPORTED_CLASS_EFFECIENT_PACKING:
-        patch_for_multipack(getattr(config, "model_type", None), model_args.model_name_or_path, attn_implementation)
+    if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
+        patch_for_block_diag_attn(model_type)
         logger.info("Using packing sequences without cross-contamination attention for efficient training.")
     else:
-        raise ValueError("Current model does not support packing sequences for efficient training. Please set config `efficient_packing` is False")
+        raise ValueError("Current model does not support packing sequences for efficient training.")
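The removed patch_remote shows the mechanism this feature relies on: it imports the model's modeling_* module and swaps its private _get_unpad_data helper, which the FlashAttention-2 code paths in transformers call to locate sequence boundaries. The retained patch_for_block_diag_attn presumably installs a similar replacement for the supported built-in architectures. Below is a minimal sketch of such a replacement, assuming the data collator labels the tokens of the i-th packed sequence with the integer i (1-based) and padding with 0, e.g. [1, 1, 2, 2, 2, 0]; the function name mirrors the diff, but the body is illustrative rather than the repository's exact code.

import torch
import torch.nn.functional as F


def get_unpad_data(attention_mask: torch.Tensor):
    # Sketch only: assumes tokens of the i-th packed sequence are labelled i
    # (1-based), padding is 0, and sequences appear left to right in
    # increasing-id order within each row.
    max_num = int(attention_mask.max().item())
    # length of every packed sequence, row by row, in order of appearance
    counts = torch.stack(
        [(attention_mask == i + 1).sum(dim=-1) for i in range(max_num)], dim=-1
    ).flatten()
    seqlens_in_batch = counts[counts.nonzero().squeeze(dim=-1)].to(torch.int32)
    # flat indices of all non-padding tokens
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max().item())
    # cumulative sequence boundaries prefixed with 0, the cu_seqlens format
    # FlashAttention-2 expects
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch

Because cu_seqlens records a boundary at every change of segment id, FlashAttention treats each packed sequence as an independent attention block, which is the "without cross-contamination" behavior the log message describes.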
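For attention implementations that take an explicit mask rather than cu_seqlens, the same idea can be pictured as a block-diagonal boolean mask. A small illustration under the same id-labelled convention (the helper name here is mine, not from the repository); in practice a causal mask would be combined on top of it:

import torch


def block_diag_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # True where query token q and key token k belong to the same packed
    # sequence and neither is padding; shape [batch, seq_len, seq_len].
    same_segment = attention_mask.unsqueeze(2) == attention_mask.unsqueeze(1)
    non_pad = (attention_mask != 0).unsqueeze(2) & (attention_mask != 0).unsqueeze(1)
    return same_segment & non_pad


mask = block_diag_attention_mask(torch.tensor([[1, 1, 2, 2, 2, 0]]))
# mask[0] is block-diagonal: a 2x2 block for sequence 1, a 3x3 block for
# sequence 2, and all-False rows/columns at the padding position.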