From f4be51f35668fe58f2b20c38fa9f49bb3f828378 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Tue, 2 Apr 2024 14:26:31 +0800
Subject: [PATCH] add moe aux loss control #3085

Former-commit-id: b267aeb53fc49d2eeb0f3fc5ebe55e643f5db377
---
 src/llmtuner/extras/misc.py        |  8 +++-----
 src/llmtuner/hparams/model_args.py |  4 ++++
 src/llmtuner/model/loader.py       |  3 +--
 src/llmtuner/model/patcher.py      | 24 +++++++++++++++---------
 4 files changed, 23 insertions(+), 16 deletions(-)

diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py
index c7b687e9..60cf153b 100644
--- a/src/llmtuner/extras/misc.py
+++ b/src/llmtuner/extras/misc.py
@@ -203,17 +203,15 @@ def torch_gc() -> None:
         torch.cuda.ipc_collect()


-def try_download_model_from_ms(model_args: "ModelArguments") -> None:
+def try_download_model_from_ms(model_args: "ModelArguments") -> str:
     if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
-        return
+        return model_args.model_name_or_path

     try:
         from modelscope import snapshot_download

         revision = "master" if model_args.model_revision == "main" else model_args.model_revision
-        model_args.model_name_or_path = snapshot_download(
-            model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir
-        )
+        return snapshot_download(model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir)
     except ImportError:
         raise ImportError("Please install modelscope via `pip install modelscope -U`")

diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py
index f96fb636..be71d32f 100644
--- a/src/llmtuner/hparams/model_args.py
+++ b/src/llmtuner/hparams/model_args.py
@@ -73,6 +73,10 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
     )
+    moe_aux_loss_coef: Optional[float] = field(
+        default=None,
+        metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
+    )
     disable_gradient_checkpointing: bool = field(
         default=False,
         metadata={"help": "Whether or not to disable gradient checkpointing."},
     )
diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index b1816aa7..d05c0886 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -20,6 +20,7 @@ logger = get_logger(__name__)


 def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
+    model_args.model_name_or_path = try_download_model_from_ms(model_args)
     return {
         "trust_remote_code": True,
         "cache_dir": model_args.cache_dir,
@@ -34,9 +35,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":

     Note: including inplace operation of model_args.
     """
-    try_download_model_from_ms(model_args)
     init_kwargs = _get_init_kwargs(model_args)
-
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path,
         use_fast=model_args.use_fast_tokenizer,
diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 379b0c48..7132470a 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -290,11 +290,6 @@ def patch_config(
     if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))

-    if getattr(config, "model_type", None) == "qwen":
-        setattr(config, "use_flash_attn", model_args.flash_attn)
-        for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
-            setattr(config, dtype_name, model_args.compute_dtype == dtype)
-
     _configure_attn_implementation(config, model_args, init_kwargs)
     _configure_rope(config, model_args, is_trainable)
     _configure_longlora(config, model_args, is_trainable)
@@ -304,11 +299,25 @@ def patch_config(
         setattr(config, "use_cache", True)
         logger.info("Using KV cache for faster generation.")

+    if model_args.moe_aux_loss_coef is not None:
+        if getattr(config, "model_type", None) in ["mixtral", "qwen2_moe"]:
+            setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
+        elif getattr(config, "model_type", None) == "deepseek":
+            setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
+
+    if getattr(config, "model_type", None) == "qwen":
+        setattr(config, "use_flash_attn", model_args.flash_attn)
+        for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
+            setattr(config, dtype_name, model_args.compute_dtype == dtype)
+
+    if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn:
+        setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flashattn
+
     init_kwargs["torch_dtype"] = model_args.compute_dtype
     if not is_deepspeed_zero3_enabled():
         init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
         if init_kwargs["low_cpu_mem_usage"]:
-            if "device_map" not in init_kwargs:  # quant models cannot use auto device map
+            if "device_map" not in init_kwargs:
                 init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}

            if init_kwargs["device_map"] == "auto":
@@ -333,9 +342,6 @@ def patch_model(
         setattr(model, "lm_head", model.transformer.output_layer)
         setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])

-    if is_trainable and getattr(model.config, "model_type", None) == "qwen2" and model_args.flash_attn:
-        setattr(model.config, "use_cache", False)  # qwen2 does not support use_cache when using flashattn
-
     if model_args.resize_vocab:
         _resize_embedding_layer(model, tokenizer)