update arg name

This commit is contained in:
hiyouga
2024-07-03 23:23:24 +08:00
parent 575a02a23d
commit 8a6a7b9c8a
3 changed files with 20 additions and 34 deletions

View File

@@ -283,28 +283,15 @@ def patch_for_multipack(model_type, model_name, attn_implementation):
)
def patch_remote(model_name, config_name, modeling_name):
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# we need to load the model here in order for modeling_* to be available
with init_empty_weights():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
modeling_arch = importlib.import_module(module_name)
modeling_arch._get_unpad_data = get_unpad_data # pylint: disable=protected-access
# check exist load_balancing_loss_func for moe model
if hasattr(modeling_arch, "load_balancing_loss_func"):
modeling_arch.load_balancing_loss_func = load_balancing_loss_func
def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.block_diag_attn:
return
model_type = getattr(config, "model_type", None)
def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
attn_implementation = getattr(config, "attn_implementation", "")
else:
attn_implementation = getattr(config, "_attn_implementation", "")
if getattr(config, "model_type", None) in SUPPORTED_CLASS_EFFECIENT_PACKING:
patch_for_multipack(getattr(config, "model_type", None), model_args.model_name_or_path, attn_implementation)
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
patch_for_block_diag_attn(model_type)
logger.info("Using packing sequences without cross-contamination attention for efficient training.")
else:
raise ValueError("Current model does not support packing sequences for efficient training. Please set config `efficient_packing` is False")
raise ValueError("Current model does not support packing sequences for efficient training.")