diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index 1004544d..f63b6434 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -63,7 +63,7 @@ class HuggingfaceEngine(BaseEngine):
         try:
             asyncio.get_event_loop()
         except RuntimeError:
-            logger.warning_once("There is no current event loop, creating a new one.")
+            logger.warning_rank0_once("There is no current event loop, creating a new one.")
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)
 
diff --git a/src/llamafactory/data/data_utils.py b/src/llamafactory/data/data_utils.py
index cbce026c..bd5d3587 100644
--- a/src/llamafactory/data/data_utils.py
+++ b/src/llamafactory/data/data_utils.py
@@ -56,12 +56,12 @@ def merge_dataset(
         return all_datasets[0]
     elif data_args.mix_strategy == "concat":
         if data_args.streaming:
-            logger.warning_once("The samples between different datasets will not be mixed in streaming mode.")
+            logger.warning_rank0_once("The samples between different datasets will not be mixed in streaming mode.")
 
         return concatenate_datasets(all_datasets)
     elif data_args.mix_strategy.startswith("interleave"):
         if not data_args.streaming:
-            logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
+            logger.warning_rank0_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
 
         return interleave_datasets(
             datasets=all_datasets,
diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py
index 863b6492..3c7e34a4 100644
--- a/src/llamafactory/data/loader.py
+++ b/src/llamafactory/data/loader.py
@@ -18,11 +18,10 @@ from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
 
 import numpy as np
 from datasets import DatasetDict, load_dataset, load_from_disk
-from transformers.utils.versions import require_version
 
 from ..extras import logging
 from ..extras.constants import FILEEXT2TYPE
-from ..extras.misc import has_tokenized_data
+from ..extras.misc import check_version, has_tokenized_data
 from .aligner import align_dataset
 from .data_utils import merge_dataset, split_dataset
 from .parser import get_dataset_list
@@ -84,7 +83,7 @@ def _load_single_dataset(
         raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")
 
     if dataset_attr.load_from == "ms_hub":
-        require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
+        check_version("modelscope>=1.11.0", mandatory=True)
         from modelscope import MsDataset  # type: ignore
         from modelscope.utils.config_ds import MS_DATASETS_CACHE  # type: ignore
 
@@ -103,7 +102,7 @@
             dataset = dataset.to_hf_dataset()
 
     elif dataset_attr.load_from == "om_hub":
-        require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
+        check_version("openmind>=0.8.0", mandatory=True)
         from openmind import OmDataset  # type: ignore
         from openmind.utils.hub import OM_DATASETS_CACHE  # type: ignore
 
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 4e1f418a..a8c46d11 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -73,10 +73,14 @@ class BasePlugin:
         Validates if this model accepts the input modalities.
         """
         if len(images) != 0 and self.image_token is None:
-            raise ValueError("This model does not support image input.")
+            raise ValueError(
+                "This model does not support image input. Please check whether the correct `template` is used."
+            )
 
         if len(videos) != 0 and self.video_token is None:
-            raise ValueError("This model does not support video input.")
+            raise ValueError(
+                "This model does not support video input. Please check whether the correct `template` is used."
+            )
 
     def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
         r"""
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 5768cf7b..ebe31553 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -15,10 +15,10 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
 
-from transformers.utils.versions import require_version
 from typing_extensions import override
 
 from ..extras import logging
+from ..extras.misc import check_version
 from .data_utils import Role
 from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
 from .mm_plugin import get_mm_plugin
@@ -365,7 +365,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
         raise ValueError(f"Template {data_args.template} does not exist.")
 
     if template.mm_plugin.__class__.__name__ != "BasePlugin":
-        require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
+        check_version("transformers>=4.45.0")
 
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")
diff --git a/src/llamafactory/extras/logging.py b/src/llamafactory/extras/logging.py
index 40889a88..8f98b055 100644
--- a/src/llamafactory/extras/logging.py
+++ b/src/llamafactory/extras/logging.py
@@ -68,7 +68,7 @@ class LoggerHandler(logging.Handler):
 
 class _Logger(logging.Logger):
     r"""
-    A logger that supports info_rank0 and warning_once.
+    A logger that supports rank0 logging.
     """
 
     def info_rank0(self, *args, **kwargs) -> None:
@@ -77,7 +77,7 @@ class _Logger(logging.Logger):
     def warning_rank0(self, *args, **kwargs) -> None:
         self.warning(*args, **kwargs)
 
-    def warning_once(self, *args, **kwargs) -> None:
+    def warning_rank0_once(self, *args, **kwargs) -> None:
         self.warning(*args, **kwargs)
 
 
@@ -163,11 +163,11 @@ def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:
 
 
 @lru_cache(None)
-def warning_once(self: "logging.Logger", *args, **kwargs) -> None:
+def warning_rank0_once(self: "logging.Logger", *args, **kwargs) -> None:
     if int(os.getenv("LOCAL_RANK", "0")) == 0:
         self.warning(*args, **kwargs)
 
 
 logging.Logger.info_rank0 = info_rank0
 logging.Logger.warning_rank0 = warning_rank0
-logging.Logger.warning_once = warning_once
+logging.Logger.warning_rank0_once = warning_rank0_once
diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py
index 11797f9f..beaed725 100644
--- a/src/llamafactory/extras/misc.py
+++ b/src/llamafactory/extras/misc.py
@@ -73,19 +73,31 @@ class AverageMeter:
         self.avg = self.sum / self.count
 
 
+def check_version(requirement: str, mandatory: bool = False) -> None:
+    r"""
+    Optionally checks the package version.
+    """
+    if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"] and not mandatory:
+        logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
+        return
+
+    if mandatory:
+        hint = f"To fix: run `pip install {requirement}`."
+    else:
+        hint = f"To fix: run `pip install {requirement}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
+
+    require_version(requirement, hint)
+
+
 def check_dependencies() -> None:
     r"""
     Checks the version of the required packages.
     """
-    if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
-        logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
-        return
-
-    require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
-    require_version("datasets>=2.16.0,<=3.1.0", "To fix: pip install datasets>=2.16.0,<=3.1.0")
-    require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
-    require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
-    require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
+    check_version("transformers>=4.41.2,<=4.46.1")
+    check_version("datasets>=2.16.0,<=3.1.0")
+    check_version("accelerate>=0.34.0,<=1.0.1")
+    check_version("peft>=0.11.1,<=0.12.0")
+    check_version("trl>=0.8.6,<=0.9.6")
 
 
 def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
@@ -253,7 +265,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
         return model_args.model_name_or_path
 
     if use_modelscope():
-        require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
+        check_version("modelscope>=1.11.0", mandatory=True)
         from modelscope import snapshot_download  # type: ignore
 
         revision = "master" if model_args.model_revision == "main" else model_args.model_revision
@@ -264,7 +276,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
         )
 
     if use_openmind():
-        require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
+        check_version("openmind>=0.8.0", mandatory=True)
         from openmind.utils.hub import snapshot_download  # type: ignore
 
         return snapshot_download(
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 62edbf78..456e34a2 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -29,11 +29,10 @@ from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.training_args import ParallelMode
 from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
-from transformers.utils.versions import require_version
 
 from ..extras import logging
 from ..extras.constants import CHECKPOINT_NAMES
-from ..extras.misc import check_dependencies, get_current_device
+from ..extras.misc import check_dependencies, check_version, get_current_device
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
@@ -124,38 +123,35 @@ def _check_extra_dependencies(
     finetuning_args: "FinetuningArguments",
     training_args: Optional["TrainingArguments"] = None,
 ) -> None:
-    if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
-        logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
-        return
-
     if model_args.use_unsloth:
-        require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
+        check_version("unsloth", mandatory=True)
 
     if model_args.enable_liger_kernel:
-        require_version("liger-kernel", "To fix: pip install liger-kernel")
+        check_version("liger-kernel", mandatory=True)
 
     if model_args.mixture_of_depths is not None:
-        require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
+        check_version("mixture-of-depth>=1.1.6", mandatory=True)
 
     if model_args.infer_backend == "vllm":
-        require_version("vllm>=0.4.3,<0.6.7", "To fix: pip install vllm>=0.4.3,<0.6.7")
+        check_version("vllm>=0.4.3,<0.6.7")
+        check_version("vllm", mandatory=True)
 
     if finetuning_args.use_galore:
-        require_version("galore_torch", "To fix: pip install galore_torch")
+        check_version("galore_torch", mandatory=True)
 
     if finetuning_args.use_badam:
-        require_version("badam>=1.2.1", "To fix: pip install badam>=1.2.1")
+        check_version("badam>=1.2.1", mandatory=True)
 
     if finetuning_args.use_adam_mini:
-        require_version("adam-mini", "To fix: pip install adam-mini")
+        check_version("adam-mini", mandatory=True)
 
     if finetuning_args.plot_loss:
-        require_version("matplotlib", "To fix: pip install matplotlib")
+        check_version("matplotlib", mandatory=True)
 
     if training_args is not None and training_args.predict_with_generate:
-        require_version("jieba", "To fix: pip install jieba")
-        require_version("nltk", "To fix: pip install nltk")
-        require_version("rouge_chinese", "To fix: pip install rouge-chinese")
+        check_version("jieba", mandatory=True)
+        check_version("nltk", mandatory=True)
+        check_version("rouge_chinese", mandatory=True)
 
 
 def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
diff --git a/src/llamafactory/model/model_utils/attention.py b/src/llamafactory/model/model_utils/attention.py
index bf243aaa..8ec74351 100644
--- a/src/llamafactory/model/model_utils/attention.py
+++ b/src/llamafactory/model/model_utils/attention.py
@@ -15,9 +15,9 @@
 from typing import TYPE_CHECKING
 
 from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
-from transformers.utils.versions import require_version
 
 from ...extras import logging
+from ...extras.misc import check_version
 
 
 if TYPE_CHECKING:
@@ -35,8 +35,8 @@ def configure_attn_implementation(
     if getattr(config, "model_type", None) == "gemma2" and is_trainable:
         if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
             if is_flash_attn_2_available():
-                require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
-                require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
+                check_version("transformers>=4.42.4")
+                check_version("flash_attn>=2.6.3")
                 if model_args.flash_attn != "fa2":
                     logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
                     model_args.flash_attn = "fa2"
diff --git a/src/llamafactory/model/model_utils/checkpointing.py b/src/llamafactory/model/model_utils/checkpointing.py
index 0fad48cf..80a50f3e 100644
--- a/src/llamafactory/model/model_utils/checkpointing.py
+++ b/src/llamafactory/model/model_utils/checkpointing.py
@@ -122,7 +122,7 @@ def _gradient_checkpointing_enable(
     if "value" in inspect.signature(self._set_gradient_checkpointing).parameters:  # old GC format
         self.apply(partial(self._set_gradient_checkpointing, value=True))
         self.enable_input_require_grads()
-        logger.warning_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
+        logger.warning_rank0_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
     else:  # have already enabled input require gradients
         self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
 
diff --git a/src/llamafactory/model/model_utils/longlora.py b/src/llamafactory/model/model_utils/longlora.py
index 96a7b40e..89457846 100644
--- a/src/llamafactory/model/model_utils/longlora.py
+++ b/src/llamafactory/model/model_utils/longlora.py
@@ -31,10 +31,10 @@ from transformers.models.llama.modeling_llama import (
     apply_rotary_pos_emb,
     repeat_kv,
 )
-from transformers.utils.versions import require_version
 
 from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
+from ...extras.misc import check_version
 from ...extras.packages import is_transformers_version_greater_than
 
 
@@ -353,7 +353,7 @@ def llama_sdpa_attention_forward(
 
 
 def _apply_llama_patch() -> None:
-    require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
+    check_version("transformers>=4.41.2,<=4.46.1")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
diff --git a/src/llamafactory/model/model_utils/moe.py b/src/llamafactory/model/model_utils/moe.py
index 642d164a..58039e2a 100644
--- a/src/llamafactory/model/model_utils/moe.py
+++ b/src/llamafactory/model/model_utils/moe.py
@@ -16,7 +16,8 @@ from typing import TYPE_CHECKING, Sequence
 
 import torch
 from transformers.integrations import is_deepspeed_zero3_enabled
-from transformers.utils.versions import require_version
+
+from ...extras.misc import check_version
 
 
 if TYPE_CHECKING:
@@ -26,7 +27,7 @@ if TYPE_CHECKING:
 
 
 def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None:
-    require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
+    check_version("deepspeed>=0.13.0")
    from deepspeed.utils import set_z3_leaf_modules  # type: ignore
 
     set_z3_leaf_modules(model, leaf_modules)
diff --git a/src/llamafactory/model/model_utils/packing.py b/src/llamafactory/model/model_utils/packing.py
index 014c8e87..34c3c55b 100644
--- a/src/llamafactory/model/model_utils/packing.py
+++ b/src/llamafactory/model/model_utils/packing.py
@@ -41,9 +41,9 @@ from typing import TYPE_CHECKING, Tuple
 
 import torch
 import torch.nn.functional as F
-from transformers.utils.versions import require_version
 
 from ...extras import logging
+from ...extras.misc import check_version
 from ...extras.packages import is_transformers_version_greater_than
 
 
@@ -118,6 +118,6 @@ def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable or not model_args.block_diag_attn:
         return
 
-    require_version("transformers>=4.43.0,<=4.46.1", "To fix: pip install transformers>=4.43.0,<=4.46.1")
+    check_version("transformers>=4.43.0,<=4.46.1")
     transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
     logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
diff --git a/src/llamafactory/model/model_utils/quantization.py b/src/llamafactory/model/model_utils/quantization.py
index 0739c566..e000ee23 100644
--- a/src/llamafactory/model/model_utils/quantization.py
+++ b/src/llamafactory/model/model_utils/quantization.py
@@ -26,11 +26,10 @@ from datasets import load_dataset
 from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
-from transformers.utils.versions import require_version
 
 from ...extras import logging
 from ...extras.constants import FILEEXT2TYPE
-from ...extras.misc import get_current_device
+from ...extras.misc import check_version, get_current_device
 
 
 if TYPE_CHECKING:
@@ -118,15 +117,15 @@ def configure_quantization(
         quant_method = quantization_config.get("quant_method", "")
 
         if quant_method == QuantizationMethod.GPTQ:
-            require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
+            check_version("auto_gptq>=0.5.0", mandatory=True)
             quantization_config.pop("disable_exllama", None)  # remove deprecated args
             quantization_config["use_exllama"] = False  # disable exllama
 
         if quant_method == QuantizationMethod.AWQ:
-            require_version("autoawq", "To fix: pip install autoawq")
+            check_version("autoawq", mandatory=True)
 
         if quant_method == QuantizationMethod.AQLM:
-            require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
+            check_version("aqlm>=1.1.0", mandatory=True)
             quantization_config["bits"] = 2
 
         quant_bits = quantization_config.get("bits", "?")
@@ -136,8 +135,8 @@ def configure_quantization(
         if model_args.export_quantization_bit not in [8, 4, 3, 2]:
             raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")
 
-        require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0")
-        require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
+        check_version("optimum>=1.17.0", mandatory=True)
+        check_version("auto_gptq>=0.5.0", mandatory=True)
         from accelerate.utils import get_max_memory
 
         if getattr(config, "model_type", None) == "chatglm":
@@ -154,10 +153,10 @@ def configure_quantization(
     elif model_args.quantization_bit is not None:  # on-the-fly
         if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
             if model_args.quantization_bit == 8:
-                require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
+                check_version("bitsandbytes>=0.37.0", mandatory=True)
                 init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
             elif model_args.quantization_bit == 4:
-                require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
+                check_version("bitsandbytes>=0.39.0", mandatory=True)
                 init_kwargs["quantization_config"] = BitsAndBytesConfig(
                     load_in_4bit=True,
                     bnb_4bit_compute_dtype=model_args.compute_dtype,
@@ -175,7 +174,7 @@ def configure_quantization(
                 if model_args.quantization_bit != 4:
                     raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")
 
                -require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
+                check_version("bitsandbytes>=0.43.0", mandatory=True)
             else:
                 init_kwargs["device_map"] = {"": get_current_device()}  # change auto device map for inference
 
@@ -187,7 +186,7 @@ def configure_quantization(
         if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
             raise ValueError("HQQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")
 
-        require_version("hqq", "To fix: pip install hqq")
+        check_version("hqq", mandatory=True)
         init_kwargs["quantization_config"] = HqqConfig(
             nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
         )  # use ATEN kernel (axis=0) for performance
@@ -199,6 +198,6 @@ def configure_quantization(
         if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
             raise ValueError("EETQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")
 
-        require_version("eetq", "To fix: pip install eetq")
+        check_version("eetq", mandatory=True)
         init_kwargs["quantization_config"] = EetqConfig()
         logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with EETQ.")
diff --git a/src/llamafactory/train/callbacks.py b/src/llamafactory/train/callbacks.py
index 4da4ec18..5906a4a6 100644
--- a/src/llamafactory/train/callbacks.py
+++ b/src/llamafactory/train/callbacks.py
@@ -239,7 +239,7 @@ class LogCallback(TrainerCallback):
             and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
             and args.overwrite_output_dir
         ):
-            logger.warning_once("Previous trainer log in this folder will be deleted.")
+            logger.warning_rank0_once("Previous trainer log in this folder will be deleted.")
             os.remove(os.path.join(args.output_dir, TRAINER_LOG))
 
     @override
diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py
index 1ccfa9ef..f41f24cb 100644
--- a/src/llamafactory/train/sft/workflow.py
+++ b/src/llamafactory/train/sft/workflow.py
@@ -122,7 +122,7 @@ def run_sft(
 
     # Predict
     if training_args.do_predict:
-        logger.warning_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
+        logger.warning_rank0_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
         predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
         trainer.log_metrics("predict", predict_results.metrics)
        trainer.save_metrics("predict", predict_results.metrics)
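
Usage sketch (illustrative only, not part of the patch): the calls below show how the new `check_version` helper added in src/llamafactory/extras/misc.py is invoked from the changed files; both requirements appear in the diff, and the import path simply mirrors the package layout shown above.

    from llamafactory.extras.misc import check_version

    # Skippable check: skipped with a one-time warning when DISABLE_VERSION_CHECK=1 is set.
    check_version("transformers>=4.45.0")

    # Mandatory check: enforced even when DISABLE_VERSION_CHECK=1 is set.
    check_version("modelscope>=1.11.0", mandatory=True)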