From 991267fd3bc627a8f9694a4ff8c1e5a910b6c7e1 Mon Sep 17 00:00:00 2001 From: sunyi0505 <1659275352@qq.com> Date: Thu, 12 Feb 2026 20:37:41 +0800 Subject: [PATCH] [v1] support quantization (#10161) --- .github/workflows/tests_cuda.yml | 1 + examples/v1/train_qlora/quantization.yaml | 43 ++++++ src/llamafactory/v1/core/model_engine.py | 30 +++-- .../v1/plugins/model_plugins/quantization.py | 122 ++++++++++++++++++ src/llamafactory/v1/utils/packages.py | 26 ++++ .../model_plugins/test_quantization_plugin.py | 51 ++++++++ 6 files changed, 265 insertions(+), 8 deletions(-) create mode 100644 examples/v1/train_qlora/quantization.yaml create mode 100644 tests_v1/plugins/model_plugins/test_quantization_plugin.py diff --git a/.github/workflows/tests_cuda.yml b/.github/workflows/tests_cuda.yml index 33558a5d0..f533b01c5 100644 --- a/.github/workflows/tests_cuda.yml +++ b/.github/workflows/tests_cuda.yml @@ -61,6 +61,7 @@ jobs: uv venv uv pip install -e . uv pip install -r requirements/dev.txt + uv pip install -r requirements/bitsandbytes.txt - name: Check quality run: | diff --git a/examples/v1/train_qlora/quantization.yaml b/examples/v1/train_qlora/quantization.yaml new file mode 100644 index 000000000..a063b207c --- /dev/null +++ b/examples/v1/train_qlora/quantization.yaml @@ -0,0 +1,43 @@ +model: Qwen/Qwen3-0.6B +trust_remote_code: true +model_class: llm + +template: qwen3_nothink + +# PEFT Configuration +peft_config: + name: lora + r: 16 + lora_alpha: 32 + lora_dropout: 0.05 + target_modules: all + +# Kernel Config +kernel_config: + name: auto + include_kernels: auto + +# FSDP Config +dist_config: + name: fsdp2 + dcp_path: null + +# Quantization Config +quant_config: + name: bnb # choice: auto/bnb if auto is selected, the quantization method will be automatically selected based on the model and environment. + quantization_bit: 4 # choice: 8/4(bnb) + +### data +train_dataset: data/v1_sft_demo.yaml + +### training +output_dir: outputs/test_quantization +micro_batch_size: 1 +cutoff_len: 2048 +learning_rate: 1.0e-4 +bf16: false +max_steps: 10 + +### sample +sample_backend: hf +max_new_tokens: 128 diff --git a/src/llamafactory/v1/core/model_engine.py b/src/llamafactory/v1/core/model_engine.py index 0a5a00204..c08522402 100644 --- a/src/llamafactory/v1/core/model_engine.py +++ b/src/llamafactory/v1/core/model_engine.py @@ -90,6 +90,26 @@ class ModelEngine: Transformers can choose the proper model init context. 
https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538 """ + if self.args.init_config is not None: + from ..plugins.model_plugins.initialization import InitPlugin + + init_device = InitPlugin(self.args.init_config.name)() + else: + init_device = DistributedInterface().current_device + + init_kwargs = {"device_map": init_device} + + if self.args.quant_config is not None: + from ..plugins.model_plugins.quantization import QuantizationPlugin + + init_kwargs = QuantizationPlugin(self.args.quant_config.name)( + init_kwargs=init_kwargs, + config=self.model_config, + tokenizer=self.processor, + model_args=self.args, + is_trainable=self.is_train, + ) + if self.args.model_class == ModelClass.LLM: from transformers import AutoModelForCausalLM, AutoModelForImageTextToText @@ -107,14 +127,8 @@ class ModelEngine: AutoClass = AutoModel - if self.args.init_config is not None: - from ..plugins.model_plugins.initialization import InitPlugin - - init_device = InitPlugin(self.args.init_config.name)() - else: - init_device = DistributedInterface().current_device - if init_device.type == DeviceType.META: + assert self.args.quant_config is None, "Quantization is not supported with meta device." with init_empty_weights(): model = AutoClass.from_config(self.model_config) else: @@ -122,8 +136,8 @@ class ModelEngine: self.args.model, config=self.model_config, dtype="auto", - device_map=init_device, trust_remote_code=self.args.trust_remote_code, + **init_kwargs, ) if self.args.peft_config is None: diff --git a/src/llamafactory/v1/plugins/model_plugins/quantization.py b/src/llamafactory/v1/plugins/model_plugins/quantization.py index e69de29bb..2ba74d47d 100644 --- a/src/llamafactory/v1/plugins/model_plugins/quantization.py +++ b/src/llamafactory/v1/plugins/model_plugins/quantization.py @@ -0,0 +1,122 @@ +# Copyright 2025 HuggingFace Inc., the KVCache.AI team, Approaching AI, and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from typing import TYPE_CHECKING, Any
+
+import torch
+from transformers import BitsAndBytesConfig
+
+from ...accelerator.helper import get_current_device
+from ...config.model_args import ModelArguments
+from ...utils import logging
+from ...utils.packages import check_version
+from ...utils.plugin import BasePlugin
+
+
+if TYPE_CHECKING:
+    from transformers import PretrainedConfig, PreTrainedTokenizer
+
+logger = logging.get_logger(__name__)
+
+
+class QuantizationPlugin(BasePlugin):
+    r"""Plugin for model quantization."""
+
+    def __call__(
+        self,
+        init_kwargs: dict[str, Any] = None,
+        config: "PretrainedConfig" = None,
+        tokenizer: "PreTrainedTokenizer" = None,
+        model_args: "ModelArguments" = None,
+        is_trainable: bool = False,
+    ) -> dict[str, Any]:
+        return super().__call__(
+            init_kwargs, config=config, tokenizer=tokenizer, model_args=model_args, is_trainable=is_trainable
+        )
+
+
+@QuantizationPlugin("auto").register()
+def quantization_auto(
+    init_kwargs: dict[str, Any],
+    **kwargs,
+) -> dict[str, Any]:
+    """Automatically select a quantization method; only bnb is supported currently.
+
+    Args:
+        init_kwargs (dict[str, Any]): The kwargs for model initialization.
+        **kwargs: Keyword arguments (config, tokenizer, model_args, is_trainable) forwarded to the selected method.
+
+    Returns:
+        dict[str, Any]: The updated kwargs for model initialization.
+    """
+    model_args: ModelArguments = kwargs.get("model_args", None)
+    quant_config = model_args.quant_config
+
+    quantization_bit = quant_config.get("quantization_bit", None)
+    if quantization_bit is not None:
+        logger.info_rank0(f"Loading {quantization_bit}-bit quantized model.")
+        if quantization_bit in [8, 4]:
+            return quantization_with_bnb(init_kwargs, **kwargs)
+        else:
+            raise ValueError(f"Unsupported quantization bit: {quantization_bit} for auto quantization.")
+    logger.warning_rank0("No quantization method applied.")
+    return init_kwargs
+
+
+@QuantizationPlugin("bnb").register()
+def quantization_with_bnb(
+    init_kwargs: dict[str, Any],
+    model_args: "ModelArguments" = None,
+    **kwargs,
+) -> dict[str, Any]:
+    r"""Configure on-the-fly 4-bit or 8-bit quantization with bitsandbytes."""
+    logger.info_rank0("Using Bitsandbytes quantization.")
+    quantization_bit = model_args.quant_config.get("quantization_bit", None)
+    if quantization_bit is None:
+        logger.warning_rank0("quantization_bit is not specified, defaulting to 4-bit quantization.")
+        quantization_bit = 4
+    assert quantization_bit in [8, 4], "Bitsandbytes only accepts 4-bit or 8-bit quantization."
+    if quantization_bit == 8:
+        check_version("bitsandbytes>=0.37.0", mandatory=True)
+        init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
+    elif quantization_bit == 4:
+        check_version("bitsandbytes>=0.39.0", mandatory=True)
+        init_kwargs["quantization_config"] = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_compute_dtype=model_args.quant_config.get("compute_dtype", torch.float16),
+            bnb_4bit_use_double_quant=model_args.quant_config.get("double_quantization", True),
+            bnb_4bit_quant_type=model_args.quant_config.get("quantization_type", "nf4"),
+            bnb_4bit_quant_storage=model_args.quant_config.get(
+                "compute_dtype", torch.float16
+            ),  # crucial for fsdp+qlora
+        )
+    else:
+        raise ValueError("Bitsandbytes only accepts 4-bit or 8-bit quantization.")
+
+    # TODO: improve deepspeed zero3 and fsdp detection.
+    if not kwargs.get("is_trainable", False):
+        logger.info_rank0("Detected inference mode, setting device_map for bitsandbytes quantization.")
+        init_kwargs["device_map"] = {"": get_current_device()}  # pin the model to the current device for inference
+    else:
+        logger.info_rank0("Detected training mode, skipping device_map for bitsandbytes quantization.")
+        if quantization_bit != 4:
+            raise ValueError("Only 4-bit quantized models can use fsdp+qlora or auto device map.")
+
+        check_version("bitsandbytes>=0.43.0", mandatory=True)
+
+    logger.info_rank0(f"Quantizing model to {quantization_bit}-bit with bitsandbytes.")
+    return init_kwargs
diff --git a/src/llamafactory/v1/utils/packages.py b/src/llamafactory/v1/utils/packages.py
index 8d86e01a8..b2b76aa65 100644
--- a/src/llamafactory/v1/utils/packages.py
+++ b/src/llamafactory/v1/utils/packages.py
@@ -21,6 +21,13 @@ from functools import lru_cache
 from typing import TYPE_CHECKING
 
 from packaging import version
+from transformers.utils.versions import require_version
+
+from . import logging
+from .env import is_env_enabled
+
+
+logger = logging.get_logger(__name__)
 
 
 if TYPE_CHECKING:
@@ -41,3 +48,22 @@ def _get_package_version(name: str) -> "Version":
 @lru_cache
 def is_transformers_version_greater_than(content: str):
     return _get_package_version("transformers") >= version.parse(content)
+
+
+def check_version(requirement: str, mandatory: bool = False) -> None:
+    r"""Check the package version, unless DISABLE_VERSION_CHECK is enabled and the check is not mandatory."""
+    if is_env_enabled("DISABLE_VERSION_CHECK") and not mandatory:
+        logger.warning_rank0_once("Version checking has been disabled; this may lead to unexpected behaviors.")
+        return
+
+    if "gptqmodel" in requirement or "autoawq" in requirement:
+        pip_command = f"pip install {requirement} --no-build-isolation"
+    else:
+        pip_command = f"pip install {requirement}"
+
+    if mandatory:
+        hint = f"To fix: run `{pip_command}`."
+    else:
+        hint = f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
+
+    require_version(requirement, hint)
diff --git a/tests_v1/plugins/model_plugins/test_quantization_plugin.py b/tests_v1/plugins/model_plugins/test_quantization_plugin.py
new file mode 100644
index 000000000..019580d50
--- /dev/null
+++ b/tests_v1/plugins/model_plugins/test_quantization_plugin.py
@@ -0,0 +1,51 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import pytest + +from llamafactory.v1.config.model_args import ModelArguments +from llamafactory.v1.core.model_engine import ModelEngine + + +bitsandbytes = pytest.importorskip("bitsandbytes") + + +def check_quantization_status(model): + quantized_info = {"bnb": []} + + for name, module in model.named_modules(): + # check BitsAndBytes quantization + if isinstance(module, bitsandbytes.nn.modules.Linear8bitLt) or isinstance( + module, bitsandbytes.nn.modules.Linear4bit + ): + quantized_info["bnb"].append(name) + + return quantized_info + + +@pytest.mark.runs_on(["cuda"]) +@pytest.mark.parametrize("name, quantization_bit", [("bnb", 4), ("auto", 4)]) +def test_quantization_plugin(name, quantization_bit): + model_args = ModelArguments( + model="llamafactory/tiny-random-qwen3", + quant_config={ + "name": name, + "quantization_bit": quantization_bit, + }, + ) + + model_engine = ModelEngine(model_args=model_args) + quantized_info = check_quantization_status(model_engine.model) + print(f"Quantized weights for method {name} with {quantization_bit} bit: {quantized_info}") + assert any(v for v in quantized_info.values()), "model is not quantized properly."
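
Usage sketch (illustrative, not taken verbatim from the patch): loading a bitsandbytes-quantized model through the new quant_config plugin path, assuming the same public API exercised by the test above (ModelArguments, ModelEngine, and the engine's model attribute). The model id and 4-bit setting mirror examples/v1/train_qlora/quantization.yaml.

    from llamafactory.v1.config.model_args import ModelArguments
    from llamafactory.v1.core.model_engine import ModelEngine

    # Build model arguments with an on-the-fly bitsandbytes 4-bit quantization config.
    model_args = ModelArguments(
        model="Qwen/Qwen3-0.6B",
        quant_config={"name": "bnb", "quantization_bit": 4},  # "auto" currently falls back to bnb
    )

    # ModelEngine routes quant_config through QuantizationPlugin while loading the weights,
    # so the resulting module holds bnb Linear4bit layers, ready for LoRA adapters or sampling.
    engine = ModelEngine(model_args=model_args)
    model = engine.model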