mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-11-05 02:12:14 +08:00

[misc] fix import error (#9299)

commit d9d67ba62d (parent: a442fa90ad)
@@ -86,6 +86,14 @@ def is_rouge_available():
     return _is_package_available("rouge_chinese")
 
 
+def is_safetensors_available():
+    return _is_package_available("safetensors")
+
+
+def is_sglang_available():
+    return _is_package_available("sglang")
+
+
 def is_starlette_available():
     return _is_package_available("sse_starlette")
 
@@ -101,7 +109,3 @@ def is_uvicorn_available():
 
 def is_vllm_available():
     return _is_package_available("vllm")
-
-
-def is_sglang_available():
-    return _is_package_available("sglang")
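The two hunks above define is_safetensors_available locally and move is_sglang_available into alphabetical order among the other availability checks. As a point of reference, here is a minimal sketch of how a _is_package_available helper is commonly written, assuming it mirrors the importlib-based probe used in transformers; the repository's actual helper may differ in details.

# Minimal sketch, assuming _is_package_available wraps an importlib probe;
# the repository's actual helper may differ.
import importlib.metadata
import importlib.util


def _is_package_available(name: str) -> bool:
    if importlib.util.find_spec(name) is None:  # module is not importable at all
        return False
    try:
        importlib.metadata.version(name)  # an installed distribution provides it
        return True
    except importlib.metadata.PackageNotFoundError:
        return False


def is_safetensors_available() -> bool:
    return _is_package_available("safetensors")

Defining the check once in the project's own packages module, rather than importing it from transformers, insulates the code from upstream removals like the one this commit works around.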
@@ -20,15 +20,17 @@ from dataclasses import asdict, dataclass, field, fields
 from typing import Any, Literal, Optional, Union
 
 import torch
+from omegaconf import OmegaConf
 from transformers.training_args import _convert_str_dict
 from typing_extensions import Self
-from omegaconf import OmegaConf
 
 from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
 from ..extras.logging import get_logger
 
+
 logger = get_logger(__name__)
 
+
 @dataclass
 class BaseModelArguments:
     r"""Arguments pertaining to the model."""
@@ -168,7 +170,7 @@ class BaseModelArguments:
         default="offload",
         metadata={"help": "Path to offload model weights."},
     )
-    use_cache: bool = field(
+    use_kv_cache: bool = field(
         default=True,
         metadata={"help": "Whether or not to use KV cache in generation."},
     )
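The rename means user-facing configs must switch the key from use_cache to use_kv_cache. A hypothetical round-trip through HfArgumentParser, with DemoArguments as a stand-in for the repo's BaseModelArguments:

# Hypothetical parsing sketch; DemoArguments stands in for BaseModelArguments.
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class DemoArguments:
    use_kv_cache: bool = field(
        default=True,
        metadata={"help": "Whether or not to use KV cache in generation."},
    )


parser = HfArgumentParser(DemoArguments)
# HfArgumentParser coerces "false"/"true" strings for bool fields.
(args,) = parser.parse_args_into_dataclasses(["--use_kv_cache", "false"])
print(args.use_kv_cache)  # False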
@@ -81,6 +81,11 @@ class RayArguments:
 class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
     r"""Arguments pertaining to the trainer."""
 
+    overwrite_output_dir: bool = field(
+        default=False,
+        metadata={"help": "deprecated"},
+    )
+
     def __post_init__(self):
         Seq2SeqTrainingArguments.__post_init__(self)
         RayArguments.__post_init__(self)
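overwrite_output_dir is re-declared here as an inert field, presumably because newer transformers releases deprecate or drop it from the upstream training arguments; re-adding it keeps old configs that still pass the key from failing to parse. A toy sketch of the pattern, with stand-in class names:

# Toy back-compat sketch; UpstreamArgs stands in for an upstream
# Seq2SeqTrainingArguments that no longer declares the field.
from dataclasses import dataclass, field


@dataclass
class UpstreamArgs:
    output_dir: str = "out"


@dataclass
class CompatArgs(UpstreamArgs):
    # Accepted for backward compatibility but otherwise ignored.
    overwrite_output_dir: bool = field(
        default=False,
        metadata={"help": "deprecated"},
    )


args = CompatArgs(output_dir="saves/demo", overwrite_output_dir=True)
print(args.overwrite_output_dir)  # True, but nothing acts on it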
@@ -28,11 +28,11 @@ if TYPE_CHECKING:
 
 def configure_kv_cache(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable:
-        setattr(config, "use_cache", model_args.use_cache)
+        setattr(config, "use_cache", model_args.use_kv_cache)
         if hasattr(config, "text_config"):
-            setattr(config.text_config, "use_cache", model_args.use_cache)
+            setattr(config.text_config, "use_cache", model_args.use_kv_cache)
 
-        if model_args.use_cache:
+        if model_args.use_kv_cache:
             logger.info_rank0("KV cache is enabled for faster generation.")
         else:
             logger.info_rank0("KV cache is disabled.")
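A hypothetical call sequence for the patched function: at inference time (is_trainable=False) the renamed flag is copied onto the config and, for multimodal models, onto the nested text_config. The import path and model id below are assumptions for illustration, and SimpleNamespace stands in for the real ModelArguments.

# Hypothetical usage sketch; the import path is an assumption.
from types import SimpleNamespace

from transformers import AutoConfig

from llamafactory.model.model_utils.kv_cache import configure_kv_cache  # assumed location

model_args = SimpleNamespace(use_kv_cache=False)  # stand-in for ModelArguments
config = AutoConfig.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # example model
configure_kv_cache(config, model_args, is_trainable=False)
assert config.use_cache is False  # KV cache disabled for generation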
@@ -154,7 +154,17 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_t
         ]:
             setattr(text_config, "output_router_logits", True)
 
-    if model_type in ["ernie4_5_moe", "granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
+    if model_type in [
+        "ernie4_5_moe",
+        "granitemoe",
+        "jamba",
+        "llama4",
+        "mixtral",
+        "olmoe",
+        "phimoe",
+        "qwen2_moe",
+        "qwen3_moe",
+    ]:
         setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
 
     elif text_config and getattr(text_config, "model_type", None) in ["qwen3_moe"]:
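This hunk only reflows the model-type list; the behavior is unchanged. For context, a brief illustrative sketch of what the two settings do on a MoE config, using MixtralConfig (the coefficient value is an example, not the repo's default):

# Illustrative only: expose router logits so the auxiliary load-balancing
# loss can be computed, and set its coefficient (example value).
from transformers import MixtralConfig

config = MixtralConfig()
setattr(config, "output_router_logits", True)
setattr(config, "router_aux_loss_coef", 0.001)
print(config.router_aux_loss_coef)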
@@ -26,16 +26,13 @@ import transformers
 from peft import PeftModel
 from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
-from transformers.utils import (
-    SAFE_WEIGHTS_NAME,
-    WEIGHTS_NAME,
-    is_safetensors_available,
-)
+from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
 from typing_extensions import override
 
 from ..extras import logging
 from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.misc import get_peak_memory, is_env_enabled, use_ray
+from ..extras.packages import is_safetensors_available
 
 
 if is_safetensors_available():
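This hunk is the import error the commit title refers to: is_safetensors_available is evidently no longer exported by transformers.utils in recent releases, so importing it fails at module load time. The commit points the import at the repo's own extras.packages helper added in the first hunk. An alternative defensive pattern, shown only as a sketch:

# Sketch of a try/except fallback; the commit instead defines the helper
# once in ..extras.packages so there is a single source of truth.
try:
    from transformers.utils import is_safetensors_available
except ImportError:  # removed in newer transformers releases
    import importlib.util

    def is_safetensors_available() -> bool:
        return importlib.util.find_spec("safetensors") is not None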