[misc] fix import error (#9299)

Yaowei Zheng 2025-10-17 17:46:27 +08:00 committed by GitHub
parent a442fa90ad
commit d9d67ba62d
7 changed files with 34 additions and 16 deletions

View File

@@ -86,6 +86,14 @@ def is_rouge_available():
     return _is_package_available("rouge_chinese")


+def is_safetensors_available():
+    return _is_package_available("safetensors")
+
+
+def is_sglang_available():
+    return _is_package_available("sglang")
+
+
 def is_starlette_available():
     return _is_package_available("sse_starlette")

@@ -101,7 +109,3 @@ def is_uvicorn_available():

 def is_vllm_available():
     return _is_package_available("vllm")
-
-
-def is_sglang_available():
-    return _is_package_available("sglang")
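
Note on the helpers above: each availability check delegates to _is_package_available. A minimal sketch of how such a checker is commonly written is below; the importlib-based body is an assumption, not the exact helper in extras/packages.py.

from importlib.util import find_spec


def _is_package_available(name: str) -> bool:
    # Sketch only: treat a package as available when its import spec can be
    # resolved; the real helper may additionally check the installed
    # distribution's version or cache the result.
    return find_spec(name) is not None

With that in place, is_safetensors_available() and the relocated is_sglang_available() follow the same one-line pattern as the existing checks.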

View File

@@ -20,15 +20,17 @@ from dataclasses import asdict, dataclass, field, fields
 from typing import Any, Literal, Optional, Union

 import torch
+from omegaconf import OmegaConf
 from transformers.training_args import _convert_str_dict
 from typing_extensions import Self
-from omegaconf import OmegaConf

 from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
 from ..extras.logging import get_logger


 logger = get_logger(__name__)


 @dataclass
 class BaseModelArguments:
     r"""Arguments pertaining to the model."""
@@ -168,7 +170,7 @@ class BaseModelArguments:
         default="offload",
         metadata={"help": "Path to offload model weights."},
     )
-    use_cache: bool = field(
+    use_kv_cache: bool = field(
         default=True,
         metadata={"help": "Whether or not to use KV cache in generation."},
     )
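
The first hunk only reorders the OmegaConf import; the second renames the generation flag from use_cache to use_kv_cache (the help text is unchanged). A small sketch of the user-facing effect, assuming the usual HfArgumentParser flow and the llamafactory.hparams export; the model id is a placeholder.

from transformers import HfArgumentParser

from llamafactory.hparams import ModelArguments

parser = HfArgumentParser(ModelArguments)
(model_args,) = parser.parse_args_into_dataclasses(
    args=["--model_name_or_path", "Qwen/Qwen2.5-0.5B-Instruct", "--use_kv_cache", "false"]
)
print(model_args.use_kv_cache)  # False; the old --use_cache flag no longer exists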

View File

@@ -81,6 +81,11 @@ class RayArguments:
 class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
     r"""Arguments pertaining to the trainer."""

+    overwrite_output_dir: bool = field(
+        default=False,
+        metadata={"help": "deprecated"},
+    )
+
     def __post_init__(self):
         Seq2SeqTrainingArguments.__post_init__(self)
         RayArguments.__post_init__(self)
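
overwrite_output_dir is re-declared on the subclass with a "deprecated" help string, presumably so configs and code paths that still reference the attribute keep working on transformers versions where Seq2SeqTrainingArguments no longer defines it. A sketch of the effect, assuming TrainingArguments is exported from llamafactory.hparams:

from llamafactory.hparams import TrainingArguments

# The attribute exists again, so old YAML/CLI configs that set it still parse;
# the field metadata marks it as deprecated.
training_args = TrainingArguments(output_dir="saves/demo", overwrite_output_dir=True)
print(training_args.overwrite_output_dir)  # True; flagged as deprecated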

View File

@@ -28,11 +28,11 @@ if TYPE_CHECKING:

 def configure_kv_cache(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable:
-        setattr(config, "use_cache", model_args.use_cache)
+        setattr(config, "use_cache", model_args.use_kv_cache)
         if hasattr(config, "text_config"):
-            setattr(config.text_config, "use_cache", model_args.use_cache)
+            setattr(config.text_config, "use_cache", model_args.use_kv_cache)

-        if model_args.use_cache:
+        if model_args.use_kv_cache:
             logger.info_rank0("KV cache is enabled for faster generation.")
         else:
             logger.info_rank0("KV cache is disabled.")
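
A usage sketch for the consumer of the renamed argument; the import path for configure_kv_cache and the model id are assumptions, and only the is_trainable=False branch shown in this hunk is exercised.

from transformers import AutoConfig

from llamafactory.hparams import ModelArguments
from llamafactory.model.model_utils.kv_cache import configure_kv_cache  # assumed module path

model_args = ModelArguments(model_name_or_path="Qwen/Qwen2.5-0.5B-Instruct", use_kv_cache=False)
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
configure_kv_cache(config, model_args, is_trainable=False)
print(config.use_cache)  # False, copied from model_args.use_kv_cache onto the HF "use_cache" key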

View File

@@ -154,7 +154,17 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_t
     ]:
         setattr(text_config, "output_router_logits", True)

-    if model_type in ["ernie4_5_moe", "granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
+    if model_type in [
+        "ernie4_5_moe",
+        "granitemoe",
+        "jamba",
+        "llama4",
+        "mixtral",
+        "olmoe",
+        "phimoe",
+        "qwen2_moe",
+        "qwen3_moe",
+    ]:
         setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
     elif text_config and getattr(text_config, "model_type", None) in ["qwen3_moe"]:

View File

@@ -26,16 +26,13 @@ import transformers
 from peft import PeftModel
 from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
-from transformers.utils import (
-    SAFE_WEIGHTS_NAME,
-    WEIGHTS_NAME,
-    is_safetensors_available,
-)
+from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
 from typing_extensions import override

 from ..extras import logging
 from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.misc import get_peak_memory, is_env_enabled, use_ray
+from ..extras.packages import is_safetensors_available


 if is_safetensors_available():
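
This is the import error the commit title refers to: is_safetensors_available is no longer pulled from transformers.utils, which recent transformers releases may not export, but from the project's own extras.packages, whose new helper appears in the first file above. The guard then works exactly as before; a minimal sketch of the pattern, with the guarded import below being illustrative rather than the exact body of this module:

from llamafactory.extras.packages import is_safetensors_available

if is_safetensors_available():
    # Only touch safetensors when the package is actually installed.
    from safetensors.torch import load_file  # illustrative guarded import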