diff --git a/src/llamafactory/extras/packages.py b/src/llamafactory/extras/packages.py
index a622dd1e..99b55f55 100644
--- a/src/llamafactory/extras/packages.py
+++ b/src/llamafactory/extras/packages.py
@@ -86,6 +86,14 @@ def is_rouge_available():
     return _is_package_available("rouge_chinese")
 
 
+def is_safetensors_available():
+    return _is_package_available("safetensors")
+
+
+def is_sglang_available():
+    return _is_package_available("sglang")
+
+
 def is_starlette_available():
     return _is_package_available("sse_starlette")
 
@@ -101,7 +109,3 @@ def is_uvicorn_available():
 
 def is_vllm_available():
     return _is_package_available("vllm")
-
-
-def is_sglang_available():
-    return _is_package_available("sglang")
diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index 7762194a..45068abd 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -20,15 +20,17 @@ from dataclasses import asdict, dataclass, field, fields
 from typing import Any, Literal, Optional, Union
 
 import torch
+from omegaconf import OmegaConf
 from transformers.training_args import _convert_str_dict
 from typing_extensions import Self
-from omegaconf import OmegaConf
 
 from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
 from ..extras.logging import get_logger
 
+
 logger = get_logger(__name__)
 
+
 @dataclass
 class BaseModelArguments:
     r"""Arguments pertaining to the model."""
@@ -168,7 +170,7 @@ class BaseModelArguments:
         default="offload",
         metadata={"help": "Path to offload model weights."},
     )
-    use_cache: bool = field(
+    use_kv_cache: bool = field(
         default=True,
         metadata={"help": "Whether or not to use KV cache in generation."},
     )
diff --git a/src/llamafactory/hparams/training_args.py b/src/llamafactory/hparams/training_args.py
index 38cdf6af..84b657a9 100644
--- a/src/llamafactory/hparams/training_args.py
+++ b/src/llamafactory/hparams/training_args.py
@@ -81,6 +81,11 @@ class RayArguments:
 class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
     r"""Arguments pertaining to the trainer."""
 
+    overwrite_output_dir: bool = field(
+        default=False,
+        metadata={"help": "deprecated"},
+    )
+
     def __post_init__(self):
         Seq2SeqTrainingArguments.__post_init__(self)
         RayArguments.__post_init__(self)
diff --git a/src/llamafactory/model/model_utils/embedding.py b/src/llamafactory/model/model_utils/embedding.py
index ef53a789..b503f3b9 100644
--- a/src/llamafactory/model/model_utils/embedding.py
+++ b/src/llamafactory/model/model_utils/embedding.py
@@ -94,7 +94,7 @@ def _description_based_initialization(
         if len(valid_token_ids) == 0:
             # Fallback: use mean of all existing embeddings
             logger.warning_rank0(
-                f"Description for token {i+1}/{num_new_tokens} contains no valid tokens. "
+                f"Description for token {i + 1}/{num_new_tokens} contains no valid tokens. "
                 "Using mean of existing embeddings."
             )
             base_embedding = embed_weight[:-num_new_tokens].mean(dim=0)
diff --git a/src/llamafactory/model/model_utils/kv_cache.py b/src/llamafactory/model/model_utils/kv_cache.py
index cd2c119f..4f622f73 100644
--- a/src/llamafactory/model/model_utils/kv_cache.py
+++ b/src/llamafactory/model/model_utils/kv_cache.py
@@ -28,11 +28,11 @@ if TYPE_CHECKING:
 
 def configure_kv_cache(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable:
-        setattr(config, "use_cache", model_args.use_cache)
+        setattr(config, "use_cache", model_args.use_kv_cache)
         if hasattr(config, "text_config"):
-            setattr(config.text_config, "use_cache", model_args.use_cache)
+            setattr(config.text_config, "use_cache", model_args.use_kv_cache)
 
-        if model_args.use_cache:
+        if model_args.use_kv_cache:
             logger.info_rank0("KV cache is enabled for faster generation.")
         else:
             logger.info_rank0("KV cache is disabled.")
diff --git a/src/llamafactory/model/model_utils/moe.py b/src/llamafactory/model/model_utils/moe.py
index 0a541415..baf84066 100644
--- a/src/llamafactory/model/model_utils/moe.py
+++ b/src/llamafactory/model/model_utils/moe.py
@@ -154,7 +154,17 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_t
     ]:
         setattr(text_config, "output_router_logits", True)
 
-    if model_type in ["ernie4_5_moe", "granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
+    if model_type in [
+        "ernie4_5_moe",
+        "granitemoe",
+        "jamba",
+        "llama4",
+        "mixtral",
+        "olmoe",
+        "phimoe",
+        "qwen2_moe",
+        "qwen3_moe",
+    ]:
         setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
 
     elif text_config and getattr(text_config, "model_type", None) in ["qwen3_moe"]:
diff --git a/src/llamafactory/train/callbacks.py b/src/llamafactory/train/callbacks.py
index 82a37a76..5619568e 100644
--- a/src/llamafactory/train/callbacks.py
+++ b/src/llamafactory/train/callbacks.py
@@ -26,16 +26,13 @@ import transformers
 from peft import PeftModel
 from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
-from transformers.utils import (
-    SAFE_WEIGHTS_NAME,
-    WEIGHTS_NAME,
-    is_safetensors_available,
-)
+from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
 from typing_extensions import override
 
 from ..extras import logging
 from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.misc import get_peak_memory, is_env_enabled, use_ray
+from ..extras.packages import is_safetensors_available
 
 if is_safetensors_available():