Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)
improve log
Former-commit-id: a6abf375975ffea3d51e1b944c9855b5f62ffac8
parent 3b843ac9d4
commit 647c51a772
@@ -63,7 +63,7 @@ class HuggingfaceEngine(BaseEngine):
         try:
             asyncio.get_event_loop()
         except RuntimeError:
-            logger.warning_once("There is no current event loop, creating a new one.")
+            logger.warning_rank0_once("There is no current event loop, creating a new one.")
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)
 
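Note for reviewers: the context around this hunk is the fallback HuggingfaceEngine uses to obtain an event loop in worker threads, where `asyncio.get_event_loop()` raises RuntimeError. A minimal sketch of that pattern (the helper name `_ensure_event_loop` is illustrative, not part of the codebase):

import asyncio

def _ensure_event_loop() -> "asyncio.AbstractEventLoop":
    # Reuse the loop already bound to this thread if one exists; otherwise
    # create a new loop and register it so later calls find it.
    try:
        return asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop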
@@ -56,12 +56,12 @@ def merge_dataset(
         return all_datasets[0]
     elif data_args.mix_strategy == "concat":
         if data_args.streaming:
-            logger.warning_once("The samples between different datasets will not be mixed in streaming mode.")
+            logger.warning_rank0_once("The samples between different datasets will not be mixed in streaming mode.")
 
         return concatenate_datasets(all_datasets)
     elif data_args.mix_strategy.startswith("interleave"):
         if not data_args.streaming:
-            logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
+            logger.warning_rank0_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
 
         return interleave_datasets(
             datasets=all_datasets,
@@ -18,11 +18,10 @@ from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
 
 import numpy as np
 from datasets import DatasetDict, load_dataset, load_from_disk
-from transformers.utils.versions import require_version
 
 from ..extras import logging
 from ..extras.constants import FILEEXT2TYPE
-from ..extras.misc import has_tokenized_data
+from ..extras.misc import check_version, has_tokenized_data
 from .aligner import align_dataset
 from .data_utils import merge_dataset, split_dataset
 from .parser import get_dataset_list
@@ -84,7 +83,7 @@ def _load_single_dataset(
         raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")
 
     if dataset_attr.load_from == "ms_hub":
-        require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
+        check_version("modelscope>=1.11.0", mandatory=True)
         from modelscope import MsDataset  # type: ignore
         from modelscope.utils.config_ds import MS_DATASETS_CACHE  # type: ignore
 
@@ -103,7 +102,7 @@ def _load_single_dataset(
             dataset = dataset.to_hf_dataset()
 
     elif dataset_attr.load_from == "om_hub":
-        require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
+        check_version("openmind>=0.8.0", mandatory=True)
         from openmind import OmDataset  # type: ignore
         from openmind.utils.hub import OM_DATASETS_CACHE  # type: ignore
 
@@ -73,10 +73,14 @@ class BasePlugin:
         Validates if this model accepts the input modalities.
         """
         if len(images) != 0 and self.image_token is None:
-            raise ValueError("This model does not support image input.")
+            raise ValueError(
+                "This model does not support image input. Please check whether the correct `template` is used."
+            )
 
         if len(videos) != 0 and self.video_token is None:
-            raise ValueError("This model does not support video input.")
+            raise ValueError(
+                "This model does not support video input. Please check whether the correct `template` is used."
+            )
 
     def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
         r"""
@@ -15,10 +15,10 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
 
-from transformers.utils.versions import require_version
 from typing_extensions import override
 
 from ..extras import logging
+from ..extras.misc import check_version
 from .data_utils import Role
 from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
 from .mm_plugin import get_mm_plugin
@@ -365,7 +365,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
             raise ValueError(f"Template {data_args.template} does not exist.")
 
     if template.mm_plugin.__class__.__name__ != "BasePlugin":
-        require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
+        check_version("transformers>=4.45.0")
 
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")
@@ -68,7 +68,7 @@ class LoggerHandler(logging.Handler):
 
 class _Logger(logging.Logger):
     r"""
-    A logger that supports info_rank0 and warning_once.
+    A logger that supports rank0 logging.
     """
 
     def info_rank0(self, *args, **kwargs) -> None:
@@ -77,7 +77,7 @@ class _Logger(logging.Logger):
     def warning_rank0(self, *args, **kwargs) -> None:
         self.warning(*args, **kwargs)
 
-    def warning_once(self, *args, **kwargs) -> None:
+    def warning_rank0_once(self, *args, **kwargs) -> None:
         self.warning(*args, **kwargs)
 
 
@@ -163,11 +163,11 @@ def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:
 
 
 @lru_cache(None)
-def warning_once(self: "logging.Logger", *args, **kwargs) -> None:
+def warning_rank0_once(self: "logging.Logger", *args, **kwargs) -> None:
     if int(os.getenv("LOCAL_RANK", "0")) == 0:
         self.warning(*args, **kwargs)
 
 
 logging.Logger.info_rank0 = info_rank0
 logging.Logger.warning_rank0 = warning_rank0
-logging.Logger.warning_once = warning_once
+logging.Logger.warning_rank0_once = warning_rank0_once
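Note for reviewers: the renamed helper combines two behaviors, rank-0 gating and once-only emission. A minimal self-contained sketch of the mechanism, mirroring the hunk above (the real module also binds `info_rank0` and `warning_rank0`):

import logging
import os
from functools import lru_cache


@lru_cache(None)  # cache keyed on (logger, message): each distinct warning fires at most once
def warning_rank0_once(self: "logging.Logger", *args, **kwargs) -> None:
    # LOCAL_RANK is set by torchrun/accelerate; only the local rank-0 process
    # logs, so an N-GPU run prints the warning once instead of N times.
    if int(os.getenv("LOCAL_RANK", "0")) == 0:
        self.warning(*args, **kwargs)


logging.Logger.warning_rank0_once = warning_rank0_once

logger = logging.getLogger(__name__)
logger.warning_rank0_once("no current event loop")  # emitted
logger.warning_rank0_once("no current event loop")  # suppressed by the cache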
@@ -73,19 +73,31 @@ class AverageMeter:
         self.avg = self.sum / self.count
 
 
+def check_version(requirement: str, mandatory: bool = False) -> None:
+    r"""
+    Optionally checks the package version.
+    """
+    if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"] and not mandatory:
+        logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
+        return
+
+    if mandatory:
+        hint = f"To fix: run `pip install {requirement}`."
+    else:
+        hint = f"To fix: run `pip install {requirement}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
+
+    require_version(requirement, hint)
+
+
 def check_dependencies() -> None:
     r"""
     Checks the version of the required packages.
     """
-    if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
-        logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
-        return
-
-    require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
-    require_version("datasets>=2.16.0,<=3.1.0", "To fix: pip install datasets>=2.16.0,<=3.1.0")
-    require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
-    require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
-    require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
+    check_version("transformers>=4.41.2,<=4.46.1")
+    check_version("datasets>=2.16.0,<=3.1.0")
+    check_version("accelerate>=0.34.0,<=1.0.1")
+    check_version("peft>=0.11.1,<=0.12.0")
+    check_version("trl>=0.8.6,<=0.9.6")
 
 
 def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
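Usage sketch of the new helper, matching how call sites elsewhere in this commit use it (`DISABLE_VERSION_CHECK` only skips non-mandatory checks; `require_version` from transformers raises when a requirement is unsatisfied):

import os

from llamafactory.extras.misc import check_version

os.environ["DISABLE_VERSION_CHECK"] = "1"
check_version("transformers>=4.41.2,<=4.46.1")       # optional: warns once on rank 0, then returns
check_version("modelscope>=1.11.0", mandatory=True)  # mandatory: still verified, raises if unmet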
@@ -253,7 +265,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
         return model_args.model_name_or_path
 
     if use_modelscope():
-        require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
+        check_version("modelscope>=1.11.0", mandatory=True)
         from modelscope import snapshot_download  # type: ignore
 
         revision = "master" if model_args.model_revision == "main" else model_args.model_revision
@@ -264,7 +276,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
         )
 
     if use_openmind():
-        require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
+        check_version("openmind>=0.8.0", mandatory=True)
         from openmind.utils.hub import snapshot_download  # type: ignore
 
         return snapshot_download(
@@ -29,11 +29,10 @@ from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.training_args import ParallelMode
 from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
-from transformers.utils.versions import require_version
 
 from ..extras import logging
 from ..extras.constants import CHECKPOINT_NAMES
-from ..extras.misc import check_dependencies, get_current_device
+from ..extras.misc import check_dependencies, check_version, get_current_device
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
@@ -124,38 +123,35 @@ def _check_extra_dependencies(
     finetuning_args: "FinetuningArguments",
     training_args: Optional["TrainingArguments"] = None,
 ) -> None:
-    if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
-        logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
-        return
-
     if model_args.use_unsloth:
-        require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
+        check_version("unsloth", mandatory=True)
 
     if model_args.enable_liger_kernel:
-        require_version("liger-kernel", "To fix: pip install liger-kernel")
+        check_version("liger-kernel", mandatory=True)
 
     if model_args.mixture_of_depths is not None:
-        require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
+        check_version("mixture-of-depth>=1.1.6", mandatory=True)
 
     if model_args.infer_backend == "vllm":
-        require_version("vllm>=0.4.3,<0.6.7", "To fix: pip install vllm>=0.4.3,<0.6.7")
+        check_version("vllm>=0.4.3,<0.6.7")
+        check_version("vllm", mandatory=True)
 
     if finetuning_args.use_galore:
-        require_version("galore_torch", "To fix: pip install galore_torch")
+        check_version("galore_torch", mandatory=True)
 
     if finetuning_args.use_badam:
-        require_version("badam>=1.2.1", "To fix: pip install badam>=1.2.1")
+        check_version("badam>=1.2.1", mandatory=True)
 
     if finetuning_args.use_adam_mini:
-        require_version("adam-mini", "To fix: pip install adam-mini")
+        check_version("adam-mini", mandatory=True)
 
     if finetuning_args.plot_loss:
-        require_version("matplotlib", "To fix: pip install matplotlib")
+        check_version("matplotlib", mandatory=True)
 
     if training_args is not None and training_args.predict_with_generate:
-        require_version("jieba", "To fix: pip install jieba")
-        require_version("nltk", "To fix: pip install nltk")
-        require_version("rouge_chinese", "To fix: pip install rouge-chinese")
+        check_version("jieba", mandatory=True)
+        check_version("nltk", mandatory=True)
+        check_version("rouge_chinese", mandatory=True)
 
 
 def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
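Behavioral note: previously `DISABLE_VERSION_CHECK=1` short-circuited this whole function; with the guard removed, hard runtime dependencies stay verified and only ranged version pins are skippable. A sketch, reusing `check_version` from the misc.py hunk above:

# with DISABLE_VERSION_CHECK=1 in the environment:
check_version("vllm>=0.4.3,<0.6.7")    # version pin skipped with a one-time rank-0 warning
check_version("vllm", mandatory=True)  # presence of vllm is still enforced and raises if absent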
@@ -15,9 +15,9 @@
 from typing import TYPE_CHECKING
 
 from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
-from transformers.utils.versions import require_version
 
 from ...extras import logging
+from ...extras.misc import check_version
 
 
 if TYPE_CHECKING:
@@ -35,8 +35,8 @@ def configure_attn_implementation(
     if getattr(config, "model_type", None) == "gemma2" and is_trainable:
         if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
             if is_flash_attn_2_available():
-                require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
-                require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
+                check_version("transformers>=4.42.4")
+                check_version("flash_attn>=2.6.3")
                 if model_args.flash_attn != "fa2":
                     logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
                     model_args.flash_attn = "fa2"
@@ -122,7 +122,7 @@ def _gradient_checkpointing_enable(
     if "value" in inspect.signature(self._set_gradient_checkpointing).parameters:  # old GC format
         self.apply(partial(self._set_gradient_checkpointing, value=True))
         self.enable_input_require_grads()
-        logger.warning_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
+        logger.warning_rank0_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
     else:  # have already enabled input require gradients
         self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
 
@@ -31,10 +31,10 @@ from transformers.models.llama.modeling_llama import (
     apply_rotary_pos_emb,
     repeat_kv,
 )
-from transformers.utils.versions import require_version
 
 from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
+from ...extras.misc import check_version
 from ...extras.packages import is_transformers_version_greater_than
 
 
@@ -353,7 +353,7 @@ def llama_sdpa_attention_forward(
 
 
 def _apply_llama_patch() -> None:
-    require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
+    check_version("transformers>=4.41.2,<=4.46.1")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
@@ -16,7 +16,8 @@ from typing import TYPE_CHECKING, Sequence
 
 import torch
 from transformers.integrations import is_deepspeed_zero3_enabled
-from transformers.utils.versions import require_version
+
+from ...extras.misc import check_version
 
 
 if TYPE_CHECKING:
@@ -26,7 +27,7 @@ if TYPE_CHECKING:
 
 
 def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None:
-    require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
+    check_version("deepspeed>=0.13.0")
     from deepspeed.utils import set_z3_leaf_modules  # type: ignore
 
     set_z3_leaf_modules(model, leaf_modules)
@@ -41,9 +41,9 @@ from typing import TYPE_CHECKING, Tuple
 
 import torch
 import torch.nn.functional as F
-from transformers.utils.versions import require_version
 
 from ...extras import logging
+from ...extras.misc import check_version
 from ...extras.packages import is_transformers_version_greater_than
 
 
@@ -118,6 +118,6 @@ def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable or not model_args.block_diag_attn:
         return
 
-    require_version("transformers>=4.43.0,<=4.46.1", "To fix: pip install transformers>=4.43.0,<=4.46.1")
+    check_version("transformers>=4.43.0,<=4.46.1")
     transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
     logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
@@ -26,11 +26,10 @@ from datasets import load_dataset
 from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
-from transformers.utils.versions import require_version
 
 from ...extras import logging
 from ...extras.constants import FILEEXT2TYPE
-from ...extras.misc import get_current_device
+from ...extras.misc import check_version, get_current_device
 
 
 if TYPE_CHECKING:
@@ -118,15 +117,15 @@ def configure_quantization(
         quant_method = quantization_config.get("quant_method", "")
 
         if quant_method == QuantizationMethod.GPTQ:
-            require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
+            check_version("auto_gptq>=0.5.0", mandatory=True)
             quantization_config.pop("disable_exllama", None)  # remove deprecated args
             quantization_config["use_exllama"] = False  # disable exllama
 
         if quant_method == QuantizationMethod.AWQ:
-            require_version("autoawq", "To fix: pip install autoawq")
+            check_version("autoawq", mandatory=True)
 
         if quant_method == QuantizationMethod.AQLM:
-            require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
+            check_version("aqlm>=1.1.0", mandatory=True)
             quantization_config["bits"] = 2
 
         quant_bits = quantization_config.get("bits", "?")
@@ -136,8 +135,8 @@ def configure_quantization(
         if model_args.export_quantization_bit not in [8, 4, 3, 2]:
             raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")
 
-        require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0")
-        require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
+        check_version("optimum>=1.17.0", mandatory=True)
+        check_version("auto_gptq>=0.5.0", mandatory=True)
         from accelerate.utils import get_max_memory
 
         if getattr(config, "model_type", None) == "chatglm":
@@ -154,10 +153,10 @@ def configure_quantization(
     elif model_args.quantization_bit is not None:  # on-the-fly
         if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
             if model_args.quantization_bit == 8:
-                require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
+                check_version("bitsandbytes>=0.37.0", mandatory=True)
                 init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
             elif model_args.quantization_bit == 4:
-                require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
+                check_version("bitsandbytes>=0.39.0", mandatory=True)
                 init_kwargs["quantization_config"] = BitsAndBytesConfig(
                     load_in_4bit=True,
                     bnb_4bit_compute_dtype=model_args.compute_dtype,
@@ -175,7 +174,7 @@ def configure_quantization(
                 if model_args.quantization_bit != 4:
                     raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")
 
-                require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
+                check_version("bitsandbytes>=0.43.0", mandatory=True)
             else:
                 init_kwargs["device_map"] = {"": get_current_device()}  # change auto device map for inference
 
@@ -187,7 +186,7 @@ def configure_quantization(
             if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
                 raise ValueError("HQQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")
 
-            require_version("hqq", "To fix: pip install hqq")
+            check_version("hqq", mandatory=True)
             init_kwargs["quantization_config"] = HqqConfig(
                 nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
             )  # use ATEN kernel (axis=0) for performance
@@ -199,6 +198,6 @@ def configure_quantization(
             if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
                 raise ValueError("EETQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")
 
-            require_version("eetq", "To fix: pip install eetq")
+            check_version("eetq", mandatory=True)
             init_kwargs["quantization_config"] = EetqConfig()
             logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with EETQ.")
@@ -239,7 +239,7 @@ class LogCallback(TrainerCallback):
             and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
             and args.overwrite_output_dir
         ):
-            logger.warning_once("Previous trainer log in this folder will be deleted.")
+            logger.warning_rank0_once("Previous trainer log in this folder will be deleted.")
             os.remove(os.path.join(args.output_dir, TRAINER_LOG))
 
     @override
@@ -122,7 +122,7 @@ def run_sft(
 
     # Predict
     if training_args.do_predict:
-        logger.warning_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
+        logger.warning_rank0_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
         predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
         trainer.log_metrics("predict", predict_results.metrics)
         trainer.save_metrics("predict", predict_results.metrics)