[model] Add Ministral3 (#9582)
Co-authored-by: kingsley <kingsleydodonow@gmail.com>
@@ -22,7 +22,7 @@ from typing import TYPE_CHECKING, Any
 import torch
 from datasets import load_dataset
-from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
+from transformers import BitsAndBytesConfig, EetqConfig, FineGrainedFP8Config, GPTQConfig, HqqConfig
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled

@@ -83,6 +83,7 @@ def configure_quantization(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
+    is_trainable: bool,
     init_kwargs: dict[str, Any],
 ) -> None:
     r"""Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)."""

@@ -109,6 +110,10 @@ def configure_quantization(
         check_version("aqlm>=1.1.0", mandatory=True)
         quantization_config["bits"] = 2

+    if quant_method == QuantizationMethod.FP8 and is_trainable:
+        quant_config = FineGrainedFP8Config(dequantize=True)
+        init_kwargs["quantization_config"] = quant_config
+
     quant_bits = quantization_config.get("bits", "?")
     logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")
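For context, a minimal sketch (not part of this commit) of how a quantization config placed into init_kwargs is typically consumed, assuming the kwargs are later forwarded to transformers' from_pretrained; the checkpoint id below is hypothetical:

# Sketch only: mirrors the added FP8 branch, where an FP8-quantized checkpoint
# loaded for training is dequantized on load so its weights become trainable.
from typing import Any

from transformers import AutoModelForCausalLM, FineGrainedFP8Config

init_kwargs: dict[str, Any] = {}
init_kwargs["quantization_config"] = FineGrainedFP8Config(dequantize=True)

model = AutoModelForCausalLM.from_pretrained(
    "org/fp8-quantized-model",  # hypothetical checkpoint id
    **init_kwargs,
)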