Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)
[misc] fix import error (#9299)

This commit is contained in:
    parent a442fa90ad
    commit d9d67ba62d
@@ -86,6 +86,14 @@ def is_rouge_available():
     return _is_package_available("rouge_chinese")
 
 
+def is_safetensors_available():
+    return _is_package_available("safetensors")
+
+
+def is_sglang_available():
+    return _is_package_available("sglang")
+
+
 def is_starlette_available():
     return _is_package_available("sse_starlette")
 
@@ -101,7 +109,3 @@ def is_uvicorn_available():
 
 def is_vllm_available():
     return _is_package_available("vllm")
-
-
-def is_sglang_available():
-    return _is_package_available("sglang")
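For context, here is a minimal sketch of what a helper like _is_package_available typically does (an illustrative assumption, not the repository's actual implementation). Centralizing these checks in the project's own extras.packages module is what lets callers stop importing is_safetensors_available from transformers.utils, which appears to be the import error this commit fixes (see the last hunk below).

# Illustrative sketch only -- not the repository's actual helper.
import importlib.util


def _is_package_available(name: str) -> bool:
    """Return True if `name` can be imported in the current environment."""
    return importlib.util.find_spec(name) is not None


def is_safetensors_available() -> bool:
    return _is_package_available("safetensors")


def is_sglang_available() -> bool:
    return _is_package_available("sglang")


if __name__ == "__main__":
    print(is_safetensors_available(), is_sglang_available())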
@@ -20,15 +20,17 @@ from dataclasses import asdict, dataclass, field, fields
 from typing import Any, Literal, Optional, Union
 
 import torch
+from omegaconf import OmegaConf
 from transformers.training_args import _convert_str_dict
 from typing_extensions import Self
-from omegaconf import OmegaConf
 
 from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
 from ..extras.logging import get_logger
 
 
 logger = get_logger(__name__)
 
 
 @dataclass
 class BaseModelArguments:
     r"""Arguments pertaining to the model."""
@@ -168,7 +170,7 @@ class BaseModelArguments:
         default="offload",
         metadata={"help": "Path to offload model weights."},
     )
-    use_cache: bool = field(
+    use_kv_cache: bool = field(
         default=True,
         metadata={"help": "Whether or not to use KV cache in generation."},
     )
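To show how such field(...) declarations with a "help" metadata entry are typically consumed, here is a small self-contained sketch (illustrative only; the class and values below are made up):

from dataclasses import dataclass, field


@dataclass
class _ExampleArguments:
    # The "help" string in metadata becomes the argument description when the
    # dataclass is parsed by a tool such as transformers.HfArgumentParser.
    use_kv_cache: bool = field(
        default=True,
        metadata={"help": "Whether or not to use KV cache in generation."},
    )


args = _ExampleArguments(use_kv_cache=False)
print(args.use_kv_cache)  # False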
@@ -81,6 +81,11 @@ class RayArguments:
 class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
     r"""Arguments pertaining to the trainer."""
 
+    overwrite_output_dir: bool = field(
+        default=False,
+        metadata={"help": "deprecated"},
+    )
+
     def __post_init__(self):
         Seq2SeqTrainingArguments.__post_init__(self)
         RayArguments.__post_init__(self)
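The explicit calls to both parent __post_init__ methods are needed because dataclasses do not chain __post_init__ across base classes automatically; a minimal, self-contained illustration with made-up classes:

from dataclasses import dataclass


@dataclass
class _Base1:
    def __post_init__(self):
        print("_Base1 post-init")


@dataclass
class _Base2:
    def __post_init__(self):
        print("_Base2 post-init")


@dataclass
class _Combined(_Base1, _Base2):
    def __post_init__(self):
        # Without these explicit calls, only this method would run and the
        # base-class initialization hooks would be skipped.
        _Base1.__post_init__(self)
        _Base2.__post_init__(self)


_Combined()  # prints both base messages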
@@ -94,7 +94,7 @@ def _description_based_initialization(
             if len(valid_token_ids) == 0:
                 # Fallback: use mean of all existing embeddings
                 logger.warning_rank0(
-                    f"Description for token {i+1}/{num_new_tokens} contains no valid tokens. "
+                    f"Description for token {i + 1}/{num_new_tokens} contains no valid tokens. "
                     "Using mean of existing embeddings."
                 )
                 base_embedding = embed_weight[:-num_new_tokens].mean(dim=0)
@@ -28,11 +28,11 @@ if TYPE_CHECKING:
 
 def configure_kv_cache(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable:
-        setattr(config, "use_cache", model_args.use_cache)
+        setattr(config, "use_cache", model_args.use_kv_cache)
         if hasattr(config, "text_config"):
-            setattr(config.text_config, "use_cache", model_args.use_cache)
+            setattr(config.text_config, "use_cache", model_args.use_kv_cache)
 
-        if model_args.use_cache:
+        if model_args.use_kv_cache:
             logger.info_rank0("KV cache is enabled for faster generation.")
         else:
             logger.info_rank0("KV cache is disabled.")
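As a rough illustration of what these setattr calls control, a hedged sketch (the config class below is an arbitrary example, not from this repository):

from transformers import GPT2Config  # any decoder config exposing `use_cache`

config = GPT2Config()
setattr(config, "use_cache", False)  # what configure_kv_cache does for inference
print(config.use_cache)  # False -> generate() will not reuse cached key/value states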
@@ -154,7 +154,17 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_t
     ]:
         setattr(text_config, "output_router_logits", True)
 
-    if model_type in ["ernie4_5_moe", "granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
+    if model_type in [
+        "ernie4_5_moe",
+        "granitemoe",
+        "jamba",
+        "llama4",
+        "mixtral",
+        "olmoe",
+        "phimoe",
+        "qwen2_moe",
+        "qwen3_moe",
+    ]:
         setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
 
     elif text_config and getattr(text_config, "model_type", None) in ["qwen3_moe"]:
@@ -26,16 +26,13 @@ import transformers
 from peft import PeftModel
 from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
-from transformers.utils import (
-    SAFE_WEIGHTS_NAME,
-    WEIGHTS_NAME,
-    is_safetensors_available,
-)
+from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
 from typing_extensions import override
 
 from ..extras import logging
 from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.misc import get_peak_memory, is_env_enabled, use_ray
+from ..extras.packages import is_safetensors_available
 
 
 if is_safetensors_available():
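The point of this hunk appears to be that is_safetensors_available is no longer imported from transformers.utils (where it may not exist in every transformers version) but from the project's own extras.packages. A hedged, self-contained sketch of the resulting guarded-import pattern; load_file is only an example of what such a guard might protect, since the guarded body is not shown in this excerpt:

import importlib.util


def is_safetensors_available() -> bool:
    # Local stand-in for the helper now imported from ..extras.packages,
    # so the check no longer depends on transformers' internal utilities.
    return importlib.util.find_spec("safetensors") is not None


if is_safetensors_available():
    from safetensors.torch import load_file  # optional dependency, imported lazily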