Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-28 11:14:18 +08:00)
[misc] fix import error (#9299)
commit d9d67ba62d (parent a442fa90ad)
@@ -86,6 +86,14 @@ def is_rouge_available():
     return _is_package_available("rouge_chinese")
 
 
+def is_safetensors_available():
+    return _is_package_available("safetensors")
+
+
+def is_sglang_available():
+    return _is_package_available("sglang")
+
+
 def is_starlette_available():
     return _is_package_available("sse_starlette")
 
@@ -101,7 +109,3 @@ def is_uvicorn_available():
 
 def is_vllm_available():
     return _is_package_available("vllm")
-
-
-def is_sglang_available():
-    return _is_package_available("sglang")
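The two hunks above add an `is_safetensors_available` helper and move `is_sglang_available` so the availability checks stay in alphabetical order. The `_is_package_available` helper they call is not shown in this diff; the snippet below is only a rough sketch of how such a check is commonly written, an assumption rather than the project's actual implementation:

    import importlib.util


    def _is_package_available(name: str) -> bool:
        # The spec resolves only when the package is installed and importable.
        return importlib.util.find_spec(name) is not None


    def is_safetensors_available() -> bool:
        return _is_package_available("safetensors")

Checks like these let optional dependencies such as safetensors or sglang stay uninstalled without breaking module import.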
@@ -20,15 +20,17 @@ from dataclasses import asdict, dataclass, field, fields
 from typing import Any, Literal, Optional, Union
 
 import torch
+from omegaconf import OmegaConf
 from transformers.training_args import _convert_str_dict
 from typing_extensions import Self
-from omegaconf import OmegaConf
 
 from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
 from ..extras.logging import get_logger
 
 
 logger = get_logger(__name__)
 
 
 @dataclass
 class BaseModelArguments:
     r"""Arguments pertaining to the model."""
@@ -168,7 +170,7 @@ class BaseModelArguments:
         default="offload",
         metadata={"help": "Path to offload model weights."},
     )
-    use_cache: bool = field(
+    use_kv_cache: bool = field(
         default=True,
         metadata={"help": "Whether or not to use KV cache in generation."},
     )
@@ -81,6 +81,11 @@ class RayArguments:
 class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
     r"""Arguments pertaining to the trainer."""
 
+    overwrite_output_dir: bool = field(
+        default=False,
+        metadata={"help": "deprecated"},
+    )
+
     def __post_init__(self):
         Seq2SeqTrainingArguments.__post_init__(self)
         RayArguments.__post_init__(self)
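The new `overwrite_output_dir` field is accepted but marked "deprecated", so configs that still set the key keep parsing even if the upstream `Seq2SeqTrainingArguments` no longer defines it. A minimal sketch of that pattern, using illustrative class names rather than the project's real ones:

    from dataclasses import dataclass, field


    @dataclass
    class UpstreamArgs:  # stand-in for Seq2SeqTrainingArguments
        output_dir: str = "saves/demo"


    @dataclass
    class TrainingArgsSketch(UpstreamArgs):
        # Accepted for backward compatibility; nothing is expected to read it.
        overwrite_output_dir: bool = field(default=False, metadata={"help": "deprecated"})


    args = TrainingArgsSketch(**{"output_dir": "saves/demo", "overwrite_output_dir": True})
    print(args.overwrite_output_dir)  # True, but has no effect on training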
@@ -94,7 +94,7 @@ def _description_based_initialization(
         if len(valid_token_ids) == 0:
             # Fallback: use mean of all existing embeddings
             logger.warning_rank0(
-                f"Description for token {i+1}/{num_new_tokens} contains no valid tokens. "
+                f"Description for token {i + 1}/{num_new_tokens} contains no valid tokens. "
                 "Using mean of existing embeddings."
             )
             base_embedding = embed_weight[:-num_new_tokens].mean(dim=0)
@@ -28,11 +28,11 @@ if TYPE_CHECKING:
 
 def configure_kv_cache(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable:
-        setattr(config, "use_cache", model_args.use_cache)
+        setattr(config, "use_cache", model_args.use_kv_cache)
         if hasattr(config, "text_config"):
-            setattr(config.text_config, "use_cache", model_args.use_cache)
+            setattr(config.text_config, "use_cache", model_args.use_kv_cache)
 
-        if model_args.use_cache:
+        if model_args.use_kv_cache:
             logger.info_rank0("KV cache is enabled for faster generation.")
         else:
             logger.info_rank0("KV cache is disabled.")
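Together with the field rename above, this hunk shows that only LLaMA-Factory's argument name changes to `use_kv_cache`; the Hugging Face config attribute it is written into is still called `use_cache`. A minimal illustration of that mapping (the value is set by hand here instead of going through the project's argument parser):

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("gpt2")

    use_kv_cache = False  # what ModelArguments.use_kv_cache would carry
    setattr(config, "use_cache", use_kv_cache)  # HF models still read `use_cache`
    print(config.use_cache)  # False: generation recomputes keys/values every step

Configs that previously set `use_cache` would presumably need to switch to `use_kv_cache` after this change.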
@@ -154,7 +154,17 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_t
     ]:
         setattr(text_config, "output_router_logits", True)
 
-    if model_type in ["ernie4_5_moe", "granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
+    if model_type in [
+        "ernie4_5_moe",
+        "granitemoe",
+        "jamba",
+        "llama4",
+        "mixtral",
+        "olmoe",
+        "phimoe",
+        "qwen2_moe",
+        "qwen3_moe",
+    ]:
         setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
 
     elif text_config and getattr(text_config, "model_type", None) in ["qwen3_moe"]:
@@ -26,16 +26,13 @@ import transformers
 from peft import PeftModel
 from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
-from transformers.utils import (
-    SAFE_WEIGHTS_NAME,
-    WEIGHTS_NAME,
-    is_safetensors_available,
-)
+from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
 from typing_extensions import override
 
 from ..extras import logging
 from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.misc import get_peak_memory, is_env_enabled, use_ray
+from ..extras.packages import is_safetensors_available
 
 
 if is_safetensors_available():
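This last hunk is the import fix named in the commit title: `is_safetensors_available` is now taken from the project's own `..extras.packages` (added in the first hunk) rather than from `transformers.utils`, presumably because not every transformers version exports it from there. The guarded block that follows is truncated in this view; the sketch below only illustrates the usual gating pattern and assumes the absolute package name `llamafactory`:

    # Assumed illustration; the actual guarded body is not shown above.
    from llamafactory.extras.packages import is_safetensors_available

    if is_safetensors_available():
        from safetensors import safe_open  # imported only when safetensors is installed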