[misc] fix import error (#9299)

Yaowei Zheng, 2025-10-17 17:46:27 +08:00, committed by GitHub
parent a442fa90ad
commit d9d67ba62d
7 changed files with 34 additions and 16 deletions

----------------------------------------

@@ -86,6 +86,14 @@ def is_rouge_available():
     return _is_package_available("rouge_chinese")


+def is_safetensors_available():
+    return _is_package_available("safetensors")
+
+
+def is_sglang_available():
+    return _is_package_available("sglang")
+
+
 def is_starlette_available():
     return _is_package_available("sse_starlette")
@@ -101,7 +109,3 @@ def is_uvicorn_available():

 def is_vllm_available():
     return _is_package_available("vllm")
-
-
-def is_sglang_available():
-    return _is_package_available("sglang")
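
All of these helpers route through `_is_package_available`. For context, such a check is usually a thin wrapper over importlib; a minimal sketch of what it might look like, assuming the common pattern rather than this file's verbatim implementation:

    import importlib.metadata
    import importlib.util


    def _is_package_available(name: str) -> bool:
        # Require both an importable module and distribution metadata,
        # which filters out stray namespace packages and partial installs.
        if importlib.util.find_spec(name) is None:
            return False
        try:
            importlib.metadata.version(name)
            return True
        except importlib.metadata.PackageNotFoundError:
            return False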

----------------------------------------

@@ -20,15 +20,17 @@ from dataclasses import asdict, dataclass, field, fields
 from typing import Any, Literal, Optional, Union

 import torch
+from omegaconf import OmegaConf
 from transformers.training_args import _convert_str_dict
 from typing_extensions import Self

-from omegaconf import OmegaConf
 from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
 from ..extras.logging import get_logger

 logger = get_logger(__name__)


 @dataclass
 class BaseModelArguments:
     r"""Arguments pertaining to the model."""
@@ -168,7 +170,7 @@ class BaseModelArguments:
         default="offload",
         metadata={"help": "Path to offload model weights."},
     )
-    use_cache: bool = field(
+    use_kv_cache: bool = field(
         default=True,
         metadata={"help": "Whether or not to use KV cache in generation."},
    )
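
The help text already said "KV cache", so the rename brings the field name in line with it, and plausibly avoids confusion with the `use_cache` attribute that transformers model configs define themselves. Any downstream config or call that set `use_cache` must now say `use_kv_cache`; a hypothetical invocation (import path assumed):

    from llamafactory.hparams import ModelArguments  # import path assumed

    # The dataclass field is now use_kv_cache; passing use_cache would raise a TypeError.
    args = ModelArguments(model_name_or_path="Qwen/Qwen2.5-7B-Instruct", use_kv_cache=False)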

----------------------------------------

@@ -81,6 +81,11 @@ class RayArguments:
 class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
     r"""Arguments pertaining to the trainer."""

+    overwrite_output_dir: bool = field(
+        default=False,
+        metadata={"help": "deprecated"},
+    )
+
     def __post_init__(self):
         Seq2SeqTrainingArguments.__post_init__(self)
         RayArguments.__post_init__(self)
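
The fixed default and the "deprecated" help string suggest this field exists only so that older configs which still set it keep parsing after upstream dropped it. The general pattern, sketched on a made-up pair of classes:

    from dataclasses import dataclass, field


    @dataclass
    class UpstreamArgs:  # stand-in for a third-party argument class
        output_dir: str = "out"


    @dataclass
    class CompatArgs(UpstreamArgs):
        # Re-declare a field the upstream class removed so configs that
        # still contain it parse as a no-op instead of raising an error.
        overwrite_output_dir: bool = field(default=False, metadata={"help": "deprecated"})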

----------------------------------------

@@ -28,11 +28,11 @@ if TYPE_CHECKING:
 def configure_kv_cache(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable:
-        setattr(config, "use_cache", model_args.use_cache)
+        setattr(config, "use_cache", model_args.use_kv_cache)
         if hasattr(config, "text_config"):
-            setattr(config.text_config, "use_cache", model_args.use_cache)
+            setattr(config.text_config, "use_cache", model_args.use_kv_cache)

-        if model_args.use_cache:
+        if model_args.use_kv_cache:
             logger.info_rank0("KV cache is enabled for faster generation.")
         else:
             logger.info_rank0("KV cache is disabled.")
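
Only the user-facing argument was renamed; the attribute written onto the config keeps its upstream name `use_cache`, since that is what transformers consults during generation. The effect can be reproduced with public APIs alone:

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("gpt2")
    # Mirrors configure_kv_cache(..., is_trainable=False) with use_kv_cache=False:
    setattr(config, "use_cache", False)
    print(config.use_cache)  # False: generation recomputes attention instead of caching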

----------------------------------------

@@ -154,7 +154,17 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_t
     ]:
         setattr(text_config, "output_router_logits", True)

-    if model_type in ["ernie4_5_moe", "granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
+    if model_type in [
+        "ernie4_5_moe",
+        "granitemoe",
+        "jamba",
+        "llama4",
+        "mixtral",
+        "olmoe",
+        "phimoe",
+        "qwen2_moe",
+        "qwen3_moe",
+    ]:
         setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
     elif text_config and getattr(text_config, "model_type", None) in ["qwen3_moe"]:
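
This hunk only rewraps the model-type list, but the two knobs it feeds are standard Hugging Face MoE config fields: `output_router_logits` makes the model return router logits so a load-balancing auxiliary loss can be computed, and `router_aux_loss_coef` scales that loss. For example, with Mixtral:

    from transformers import MixtralConfig

    # Both fields exist on Hugging Face MoE configs such as Mixtral's.
    config = MixtralConfig(output_router_logits=True, router_aux_loss_coef=0.01)
    print(config.router_aux_loss_coef)  # 0.01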

----------------------------------------

@@ -26,16 +26,13 @@ import transformers
 from peft import PeftModel
 from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
-from transformers.utils import (
-    SAFE_WEIGHTS_NAME,
-    WEIGHTS_NAME,
-    is_safetensors_available,
-)
+from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
 from typing_extensions import override

 from ..extras import logging
 from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.misc import get_peak_memory, is_env_enabled, use_ray
+from ..extras.packages import is_safetensors_available

 if is_safetensors_available():
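
This last hunk is where the import error actually gets fixed: `is_safetensors_available` is no longer pulled from `transformers.utils`, which apparently stopped exporting it, but from the project's own `extras.packages` module extended in the first file. The guarded import it feeds keeps the usual shape; a sketch of the pattern, with the guarded symbols assumed rather than copied from the file:

    from llamafactory.extras.packages import is_safetensors_available  # absolute form of the relative import above

    if is_safetensors_available():
        # Imported lazily so environments without safetensors still load the module.
        from safetensors import safe_open
        from safetensors.torch import save_file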