Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)

[misc] fix import error (#9299)

This commit is contained in:

parent a442fa90ad
commit d9d67ba62d
@@ -86,6 +86,14 @@ def is_rouge_available():
     return _is_package_available("rouge_chinese")
 
 
+def is_safetensors_available():
+    return _is_package_available("safetensors")
+
+
+def is_sglang_available():
+    return _is_package_available("sglang")
+
+
 def is_starlette_available():
     return _is_package_available("sse_starlette")
 
@@ -101,7 +109,3 @@ def is_uvicorn_available():
 
 def is_vllm_available():
     return _is_package_available("vllm")
-
-
-def is_sglang_available():
-    return _is_package_available("sglang")
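Note on the two hunks above: the module exposes one small availability helper per optional dependency, each delegating to _is_package_available; this change adds is_safetensors_available and moves is_sglang_available into alphabetical order. A minimal sketch of the pattern, assuming _is_package_available is a find_spec-based check (its body is not part of this diff):

# Sketch only: _is_package_available is assumed to wrap importlib.util.find_spec.
import importlib.util


def _is_package_available(name: str) -> bool:
    # True when the package can be imported in the current environment.
    return importlib.util.find_spec(name) is not None


def is_safetensors_available() -> bool:
    return _is_package_available("safetensors")


def is_sglang_available() -> bool:
    return _is_package_available("sglang")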
@@ -20,15 +20,17 @@ from dataclasses import asdict, dataclass, field, fields
 from typing import Any, Literal, Optional, Union
 
 import torch
+from omegaconf import OmegaConf
 from transformers.training_args import _convert_str_dict
 from typing_extensions import Self
-from omegaconf import OmegaConf
 
 from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
 from ..extras.logging import get_logger
 
+
 logger = get_logger(__name__)
 
+
 @dataclass
 class BaseModelArguments:
     r"""Arguments pertaining to the model."""
@@ -168,7 +170,7 @@ class BaseModelArguments:
         default="offload",
         metadata={"help": "Path to offload model weights."},
     )
-    use_cache: bool = field(
+    use_kv_cache: bool = field(
         default=True,
         metadata={"help": "Whether or not to use KV cache in generation."},
     )
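The field rename above means configs that previously set use_cache should now set use_kv_cache; the default and help text are unchanged. A hypothetical before/after sketch (argument dicts invented for illustration, not taken from this diff):

# Sketch only: hypothetical argument dicts illustrating the renamed key.
old_args = {"model_name_or_path": "path/to/model", "use_cache": False}     # before this change
new_args = {"model_name_or_path": "path/to/model", "use_kv_cache": False}  # after this change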
@@ -81,6 +81,11 @@ class RayArguments:
 class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
     r"""Arguments pertaining to the trainer."""
 
+    overwrite_output_dir: bool = field(
+        default=False,
+        metadata={"help": "deprecated"},
+    )
+
     def __post_init__(self):
         Seq2SeqTrainingArguments.__post_init__(self)
         RayArguments.__post_init__(self)
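overwrite_output_dir is re-declared above as a plain field marked "deprecated", which appears intended to let older configs that still set this key keep parsing rather than fail on an unknown argument. A small sketch of the idea, using a stand-in dataclass rather than the real TrainingArguments:

# Sketch only: a stand-in dataclass, not the real TrainingArguments.
from dataclasses import dataclass, field


@dataclass
class DemoTrainingArguments:
    output_dir: str = field(default="saves/demo")
    # Accepted but otherwise ignored, so legacy configs keep parsing.
    overwrite_output_dir: bool = field(default=False, metadata={"help": "deprecated"})


args = DemoTrainingArguments(**{"output_dir": "saves/demo", "overwrite_output_dir": True})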
@@ -94,7 +94,7 @@ def _description_based_initialization(
             if len(valid_token_ids) == 0:
                 # Fallback: use mean of all existing embeddings
                 logger.warning_rank0(
-                    f"Description for token {i+1}/{num_new_tokens} contains no valid tokens. "
+                    f"Description for token {i + 1}/{num_new_tokens} contains no valid tokens. "
                     "Using mean of existing embeddings."
                 )
                 base_embedding = embed_weight[:-num_new_tokens].mean(dim=0)
@@ -28,11 +28,11 @@ if TYPE_CHECKING:
 
 def configure_kv_cache(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable:
-        setattr(config, "use_cache", model_args.use_cache)
+        setattr(config, "use_cache", model_args.use_kv_cache)
         if hasattr(config, "text_config"):
-            setattr(config.text_config, "use_cache", model_args.use_cache)
+            setattr(config.text_config, "use_cache", model_args.use_kv_cache)
 
-        if model_args.use_cache:
+        if model_args.use_kv_cache:
             logger.info_rank0("KV cache is enabled for faster generation.")
         else:
             logger.info_rank0("KV cache is disabled.")
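Note that only the framework-side argument is renamed; the attribute written onto the Hugging Face config object is still use_cache, as the hunk above shows. A hedged sketch of the resulting behavior, with SimpleNamespace stand-ins instead of real PretrainedConfig / ModelArguments objects:

# Sketch only: SimpleNamespace objects stand in for PretrainedConfig / ModelArguments.
from types import SimpleNamespace

model_args = SimpleNamespace(use_kv_cache=True)
config = SimpleNamespace()
is_trainable = False

if not is_trainable:
    # Mirrors the inference path above: the HF config key keeps its original name.
    setattr(config, "use_cache", model_args.use_kv_cache)

print(config.use_cache)  # True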
@@ -154,7 +154,17 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_t
     ]:
         setattr(text_config, "output_router_logits", True)
 
-    if model_type in ["ernie4_5_moe", "granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
+    if model_type in [
+        "ernie4_5_moe",
+        "granitemoe",
+        "jamba",
+        "llama4",
+        "mixtral",
+        "olmoe",
+        "phimoe",
+        "qwen2_moe",
+        "qwen3_moe",
+    ]:
         setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
 
     elif text_config and getattr(text_config, "model_type", None) in ["qwen3_moe"]:
@@ -26,16 +26,13 @@ import transformers
 from peft import PeftModel
 from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
-from transformers.utils import (
-    SAFE_WEIGHTS_NAME,
-    WEIGHTS_NAME,
-    is_safetensors_available,
-)
+from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
 from typing_extensions import override
 
 from ..extras import logging
 from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.misc import get_peak_memory, is_env_enabled, use_ray
+from ..extras.packages import is_safetensors_available
 
 
 if is_safetensors_available():
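This last hunk is the import-error fix named in the commit title: is_safetensors_available is no longer imported from transformers.utils, where it is not guaranteed to exist across transformers versions, but from the project's own ..extras.packages helper added in the first hunk. What sits under the if is_safetensors_available(): guard is not shown in this diff; the sketch below assumes typical safetensors entry points and that the installed package name is llamafactory:

# Sketch only: illustrates the guarded-import pattern; the guarded names are assumptions.
from llamafactory.extras.packages import is_safetensors_available

if is_safetensors_available():
    from safetensors import safe_open          # assumed usage
    from safetensors.torch import save_file    # assumed usage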