mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-03-04 02:35:59 +08:00
[v1] support quantization (#10161)
This commit is contained in:
1
.github/workflows/tests_cuda.yml
vendored
1
.github/workflows/tests_cuda.yml
vendored
@@ -61,6 +61,7 @@ jobs:
|
|||||||
uv venv
|
uv venv
|
||||||
uv pip install -e .
|
uv pip install -e .
|
||||||
uv pip install -r requirements/dev.txt
|
uv pip install -r requirements/dev.txt
|
||||||
|
uv pip install -r requirements/bitsandbytes.txt
|
||||||
|
|
||||||
- name: Check quality
|
- name: Check quality
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
43
examples/v1/train_qlora/quantization.yaml
Normal file
43
examples/v1/train_qlora/quantization.yaml
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
model: Qwen/Qwen3-0.6B
|
||||||
|
trust_remote_code: true
|
||||||
|
model_class: llm
|
||||||
|
|
||||||
|
template: qwen3_nothink
|
||||||
|
|
||||||
|
# PEFT Configuration
|
||||||
|
peft_config:
|
||||||
|
name: lora
|
||||||
|
r: 16
|
||||||
|
lora_alpha: 32
|
||||||
|
lora_dropout: 0.05
|
||||||
|
target_modules: all
|
||||||
|
|
||||||
|
# Kernel Config
|
||||||
|
kernel_config:
|
||||||
|
name: auto
|
||||||
|
include_kernels: auto
|
||||||
|
|
||||||
|
# FSDP Config
|
||||||
|
dist_config:
|
||||||
|
name: fsdp2
|
||||||
|
dcp_path: null
|
||||||
|
|
||||||
|
# Quantization Config
|
||||||
|
quant_config:
|
||||||
|
name: bnb # choice: auto/bnb if auto is selected, the quantization method will be automatically selected based on the model and environment.
|
||||||
|
quantization_bit: 4 # choice: 8/4(bnb)
|
||||||
|
|
||||||
|
### data
|
||||||
|
train_dataset: data/v1_sft_demo.yaml
|
||||||
|
|
||||||
|
### training
|
||||||
|
output_dir: outputs/test_quantization
|
||||||
|
micro_batch_size: 1
|
||||||
|
cutoff_len: 2048
|
||||||
|
learning_rate: 1.0e-4
|
||||||
|
bf16: false
|
||||||
|
max_steps: 10
|
||||||
|
|
||||||
|
### sample
|
||||||
|
sample_backend: hf
|
||||||
|
max_new_tokens: 128
|
||||||
@@ -90,6 +90,26 @@ class ModelEngine:
|
|||||||
Transformers can choose the proper model init context.
|
Transformers can choose the proper model init context.
|
||||||
https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538
|
https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538
|
||||||
"""
|
"""
|
||||||
|
if self.args.init_config is not None:
|
||||||
|
from ..plugins.model_plugins.initialization import InitPlugin
|
||||||
|
|
||||||
|
init_device = InitPlugin(self.args.init_config.name)()
|
||||||
|
else:
|
||||||
|
init_device = DistributedInterface().current_device
|
||||||
|
|
||||||
|
init_kwargs = {"device_map": init_device}
|
||||||
|
|
||||||
|
if self.args.quant_config is not None:
|
||||||
|
from ..plugins.model_plugins.quantization import QuantizationPlugin
|
||||||
|
|
||||||
|
init_kwargs = QuantizationPlugin(self.args.quant_config.name)(
|
||||||
|
init_kwargs=init_kwargs,
|
||||||
|
config=self.model_config,
|
||||||
|
tokenizer=self.processor,
|
||||||
|
model_args=self.args,
|
||||||
|
is_trainable=self.is_train,
|
||||||
|
)
|
||||||
|
|
||||||
if self.args.model_class == ModelClass.LLM:
|
if self.args.model_class == ModelClass.LLM:
|
||||||
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
|
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
|
||||||
|
|
||||||
@@ -107,14 +127,8 @@ class ModelEngine:
|
|||||||
|
|
||||||
AutoClass = AutoModel
|
AutoClass = AutoModel
|
||||||
|
|
||||||
if self.args.init_config is not None:
|
|
||||||
from ..plugins.model_plugins.initialization import InitPlugin
|
|
||||||
|
|
||||||
init_device = InitPlugin(self.args.init_config.name)()
|
|
||||||
else:
|
|
||||||
init_device = DistributedInterface().current_device
|
|
||||||
|
|
||||||
if init_device.type == DeviceType.META:
|
if init_device.type == DeviceType.META:
|
||||||
|
assert self.args.quant_config is None, "Quantization is not supported with meta device."
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
model = AutoClass.from_config(self.model_config)
|
model = AutoClass.from_config(self.model_config)
|
||||||
else:
|
else:
|
||||||
@@ -122,8 +136,8 @@ class ModelEngine:
|
|||||||
self.args.model,
|
self.args.model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
dtype="auto",
|
dtype="auto",
|
||||||
device_map=init_device,
|
|
||||||
trust_remote_code=self.args.trust_remote_code,
|
trust_remote_code=self.args.trust_remote_code,
|
||||||
|
**init_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.args.peft_config is None:
|
if self.args.peft_config is None:
|
||||||
|
|||||||
@@ -0,0 +1,122 @@
|
|||||||
|
# Copyright 2025 HuggingFace Inc., the KVCache.AI team, Approaching AI, and the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# This code is inspired by the HuggingFace's transformers library.
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import BitsAndBytesConfig
|
||||||
|
|
||||||
|
from ...accelerator.helper import get_current_device
|
||||||
|
from ...config.model_args import ModelArguments
|
||||||
|
from ...utils import logging
|
||||||
|
from ...utils.packages import check_version
|
||||||
|
from ...utils.plugin import BasePlugin
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transformers import PretrainedConfig, PreTrainedTokenizer
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class QuantizationPlugin(BasePlugin):
|
||||||
|
r"""Plugin for model quantization."""
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
init_kwargs: dict[str, Any] = None,
|
||||||
|
config: "PretrainedConfig" = None,
|
||||||
|
tokenizer: "PreTrainedTokenizer" = None,
|
||||||
|
model_args: "ModelArguments" = None,
|
||||||
|
is_trainable: bool = False,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return super().__call__(
|
||||||
|
init_kwargs, config=config, tokenizer=tokenizer, model_args=model_args, is_trainable=is_trainable
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@QuantizationPlugin("auto").register()
|
||||||
|
def quantization_auto(
|
||||||
|
init_kwargs: dict[str, Any],
|
||||||
|
**kwargs,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Automatic quantization selection, only support bnb currently.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
init_kwargs (dict[str, Any]): The kwargs for model initialization.
|
||||||
|
**kwargs: Keyword arguments containing the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[str, Any]: The updated kwargs for model initialization.
|
||||||
|
"""
|
||||||
|
model_args: ModelArguments = kwargs.get("model_args", None)
|
||||||
|
quant_config = model_args.quant_config
|
||||||
|
|
||||||
|
quantization_bit = quant_config.get("quantization_bit", None)
|
||||||
|
if quantization_bit is not None:
|
||||||
|
logger.info_rank0(f"Loading {quantization_bit}-bit quantized model.")
|
||||||
|
if quantization_bit in [8, 4]:
|
||||||
|
return quantization_with_bnb(init_kwargs, **kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported quantization bit: {quantization_bit} for auto quantization.")
|
||||||
|
logger.warning_rank0("No quantization method applied.")
|
||||||
|
return init_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
@QuantizationPlugin("bnb").register()
|
||||||
|
def quantization_with_bnb(
|
||||||
|
init_kwargs: dict[str, Any],
|
||||||
|
model_args: "ModelArguments" = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
r"""Quantization with BNB."""
|
||||||
|
logger.info_rank0("Using Bitsandbytes quantization.")
|
||||||
|
quantization_bit = model_args.quant_config.get("quantization_bit", None)
|
||||||
|
if quantization_bit is None:
|
||||||
|
logger.warning_rank0("quantization_bit is not specified, default to 8-bit quantization.")
|
||||||
|
quantization_bit = 4
|
||||||
|
assert quantization_bit in [8, 4], "Bitsandbytes only accepts 4-bit or 8-bit quantization."
|
||||||
|
if quantization_bit == 8:
|
||||||
|
check_version("bitsandbytes>=0.37.0", mandatory=True)
|
||||||
|
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
|
||||||
|
elif quantization_bit == 4:
|
||||||
|
check_version("bitsandbytes>=0.39.0", mandatory=True)
|
||||||
|
init_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
|
load_in_4bit=True,
|
||||||
|
bnb_4bit_compute_dtype=model_args.quant_config.get("compute_dtype", torch.float16),
|
||||||
|
bnb_4bit_use_double_quant=model_args.quant_config.get("double_quantization", True),
|
||||||
|
bnb_4bit_quant_type=model_args.quant_config.get("quantization_type", "nf4"),
|
||||||
|
bnb_4bit_quant_storage=model_args.quant_config.get(
|
||||||
|
"compute_dtype", torch.float16
|
||||||
|
), # crucial for fsdp+qlora
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Bitsandbytes only accepts 4-bit or 8-bit quantization.")
|
||||||
|
|
||||||
|
# TODO: improve deepspeed zero3 and fsdp detection.
|
||||||
|
if kwargs.get("is_trainable", False):
|
||||||
|
logger.info_rank0("Detected inference mode, setting device_map for bitsandbytes quantization.")
|
||||||
|
init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference
|
||||||
|
else:
|
||||||
|
logger.info_rank0("Detected training mode, skip setting device_map for bitsandbytes quantization.")
|
||||||
|
if model_args.quant_config.get("quantization_bit") != 4:
|
||||||
|
raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")
|
||||||
|
|
||||||
|
check_version("bitsandbytes>=0.43.0", mandatory=True)
|
||||||
|
|
||||||
|
logger.info_rank0(f"Quantizing model to {model_args.quant_config.get('quantization_bit')} bit with bitsandbytes.")
|
||||||
|
return init_kwargs
|
||||||
|
|||||||
@@ -21,6 +21,13 @@ from functools import lru_cache
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
|
from . import logging
|
||||||
|
from .env import is_env_enabled
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -41,3 +48,22 @@ def _get_package_version(name: str) -> "Version":
|
|||||||
@lru_cache
|
@lru_cache
|
||||||
def is_transformers_version_greater_than(content: str):
|
def is_transformers_version_greater_than(content: str):
|
||||||
return _get_package_version("transformers") >= version.parse(content)
|
return _get_package_version("transformers") >= version.parse(content)
|
||||||
|
|
||||||
|
|
||||||
|
def check_version(requirement: str, mandatory: bool = False) -> None:
|
||||||
|
r"""Optionally check the package version."""
|
||||||
|
if is_env_enabled("DISABLE_VERSION_CHECK") and not mandatory:
|
||||||
|
logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if "gptqmodel" in requirement or "autoawq" in requirement:
|
||||||
|
pip_command = f"pip install {requirement} --no-build-isolation"
|
||||||
|
else:
|
||||||
|
pip_command = f"pip install {requirement}"
|
||||||
|
|
||||||
|
if mandatory:
|
||||||
|
hint = f"To fix: run `{pip_command}`."
|
||||||
|
else:
|
||||||
|
hint = f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
|
||||||
|
|
||||||
|
require_version(requirement, hint)
|
||||||
|
|||||||
51
tests_v1/plugins/model_plugins/test_quantization_plugin.py
Normal file
51
tests_v1/plugins/model_plugins/test_quantization_plugin.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llamafactory.v1.config.model_args import ModelArguments
|
||||||
|
from llamafactory.v1.core.model_engine import ModelEngine
|
||||||
|
|
||||||
|
|
||||||
|
bitsandbytes = pytest.importorskip("bitsandbytes")
|
||||||
|
|
||||||
|
|
||||||
|
def check_quantization_status(model):
|
||||||
|
quantized_info = {"bnb": []}
|
||||||
|
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
# check BitsAndBytes quantization
|
||||||
|
if isinstance(module, bitsandbytes.nn.modules.Linear8bitLt) or isinstance(
|
||||||
|
module, bitsandbytes.nn.modules.Linear4bit
|
||||||
|
):
|
||||||
|
quantized_info["bnb"].append(name)
|
||||||
|
|
||||||
|
return quantized_info
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.runs_on(["cuda"])
|
||||||
|
@pytest.mark.parametrize("name, quantization_bit", [("bnb", 4), ("auto", 4)])
|
||||||
|
def test_quantization_plugin(name, quantization_bit):
|
||||||
|
model_args = ModelArguments(
|
||||||
|
model="llamafactory/tiny-random-qwen3",
|
||||||
|
quant_config={
|
||||||
|
"name": name,
|
||||||
|
"quantization_bit": quantization_bit,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
model_engine = ModelEngine(model_args=model_args)
|
||||||
|
quantized_info = check_quantization_status(model_engine.model)
|
||||||
|
print(f"Quantized weights for method {name} with {quantization_bit} bit: {quantized_info}")
|
||||||
|
assert any(v for v in quantized_info.values()), "model is not quantized properly."
|
||||||
Reference in New Issue
Block a user