mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	support HQQ/EETQ #4113
Former-commit-id: b7cb51ddb394f04fe4646b2c297fc8d918c9979e
This commit is contained in:
		
							parent
							
								
									08fa707085
								
							
						
					
					
						commit
						8aaf1185a5
					
				@ -48,7 +48,7 @@ Choose your path:
 | 
			
		||||
 | 
			
		||||
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
 | 
			
		||||
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
 | 
			
		||||
- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
 | 
			
		||||
- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
 | 
			
		||||
- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
 | 
			
		||||
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
 | 
			
		||||
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
 | 
			
		||||
@ -341,7 +341,7 @@ cd LLaMA-Factory
 | 
			
		||||
pip install -e ".[torch,metrics]"
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Extra dependencies available: torch, torch_npu, metrics, deepspeed, bitsandbytes, vllm, galore, badam, gptq, awq, aqlm, qwen, modelscope, quality
 | 
			
		||||
Extra dependencies available: torch, torch-npu, metrics, deepspeed, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, qwen, modelscope, quality
 | 
			
		||||
 | 
			
		||||
> [!TIP]
 | 
			
		||||
> Use `pip install --no-deps -e .` to resolve package conflicts.
 | 
			
		||||
 | 
			
		||||
@ -48,7 +48,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 | 
			
		||||
 | 
			
		||||
- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
 | 
			
		||||
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
 | 
			
		||||
- **多种精度**:32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
 | 
			
		||||
- **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
 | 
			
		||||
- **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
 | 
			
		||||
- **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
 | 
			
		||||
- **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。
 | 
			
		||||
@ -341,7 +341,7 @@ cd LLaMA-Factory
 | 
			
		||||
pip install -e ".[torch,metrics]"
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
可选的额外依赖项:torch、torch_npu、metrics、deepspeed、bitsandbytes、vllm、galore、badam、gptq、awq、aqlm、qwen、modelscope、quality
 | 
			
		||||
可选的额外依赖项:torch、torch-npu、metrics、deepspeed、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、qwen、modelscope、quality
 | 
			
		||||
 | 
			
		||||
> [!TIP]
 | 
			
		||||
> 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										8
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										8
									
								
								setup.py
									
									
									
									
									
								
							@ -39,12 +39,14 @@ extra_require = {
 | 
			
		||||
    "metrics": ["nltk", "jieba", "rouge-chinese"],
 | 
			
		||||
    "deepspeed": ["deepspeed>=0.10.0"],
 | 
			
		||||
    "bitsandbytes": ["bitsandbytes>=0.39.0"],
 | 
			
		||||
    "vllm": ["vllm>=0.4.3"],
 | 
			
		||||
    "galore": ["galore-torch"],
 | 
			
		||||
    "badam": ["badam>=1.2.1"],
 | 
			
		||||
    "hqq": ["hqq"],
 | 
			
		||||
    "eetq": ["eetq"],
 | 
			
		||||
    "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
 | 
			
		||||
    "awq": ["autoawq"],
 | 
			
		||||
    "aqlm": ["aqlm[gpu]>=1.1.0"],
 | 
			
		||||
    "vllm": ["vllm>=0.4.3"],
 | 
			
		||||
    "galore": ["galore-torch"],
 | 
			
		||||
    "badam": ["badam>=1.2.1"],
 | 
			
		||||
    "qwen": ["transformers_stream_generator"],
 | 
			
		||||
    "modelscope": ["modelscope"],
 | 
			
		||||
    "dev": ["ruff", "pytest"],
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,7 @@
 | 
			
		||||
# Copyright 2024 the LlamaFactory team.
 | 
			
		||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 | 
			
		||||
#
 | 
			
		||||
# This code is inspired by the HuggingFace's transformers library.
 | 
			
		||||
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
 | 
			
		||||
@ -77,6 +77,10 @@ class ModelArguments:
 | 
			
		||||
        default=True,
 | 
			
		||||
        metadata={"help": "Whether or not to use memory-efficient model loading."},
 | 
			
		||||
    )
 | 
			
		||||
    quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
 | 
			
		||||
        default="bitsandbytes",
 | 
			
		||||
        metadata={"help": "Quantization method to use for on-the-fly quantization."},
 | 
			
		||||
    )
 | 
			
		||||
    quantization_bit: Optional[int] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={"help": "The number of bits to quantize the model using bitsandbytes."},
 | 
			
		||||
@ -235,9 +239,6 @@ class ModelArguments:
 | 
			
		||||
        if self.new_special_tokens is not None:  # support multiple special tokens
 | 
			
		||||
            self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
 | 
			
		||||
 | 
			
		||||
        assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
 | 
			
		||||
        assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."
 | 
			
		||||
 | 
			
		||||
        if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
 | 
			
		||||
            raise ValueError("Quantization dataset is necessary for exporting.")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -14,10 +14,12 @@
 | 
			
		||||
 | 
			
		||||
from .loader import load_config, load_model, load_tokenizer
 | 
			
		||||
from .model_utils.misc import find_all_linear_modules
 | 
			
		||||
from .model_utils.quantization import QuantizationMethod
 | 
			
		||||
from .model_utils.valuehead import load_valuehead_params
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
__all__ = [
 | 
			
		||||
    "QuantizationMethod",
 | 
			
		||||
    "load_config",
 | 
			
		||||
    "load_model",
 | 
			
		||||
    "load_tokenizer",
 | 
			
		||||
 | 
			
		||||
@ -186,11 +186,11 @@ def load_model(
 | 
			
		||||
 | 
			
		||||
    trainable_params, all_param = count_parameters(model)
 | 
			
		||||
    if is_trainable:
 | 
			
		||||
        param_stats = "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
 | 
			
		||||
        param_stats = "trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
 | 
			
		||||
            trainable_params, all_param, 100 * trainable_params / all_param
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        param_stats = "all params: {:d}".format(all_param)
 | 
			
		||||
        param_stats = "all params: {:,}".format(all_param)
 | 
			
		||||
 | 
			
		||||
    logger.info(param_stats)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from datasets import load_dataset
 | 
			
		||||
from transformers import BitsAndBytesConfig, GPTQConfig
 | 
			
		||||
from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
 | 
			
		||||
from transformers.integrations import is_deepspeed_zero3_enabled
 | 
			
		||||
from transformers.modeling_utils import is_fsdp_enabled
 | 
			
		||||
from transformers.utils.versions import require_version
 | 
			
		||||
@ -59,7 +59,7 @@ class QuantizationMethod(str, Enum):
 | 
			
		||||
 | 
			
		||||
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]:
 | 
			
		||||
    r"""
 | 
			
		||||
    Prepares the dataset to perform AutoGPTQ.
 | 
			
		||||
    Prepares the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization.
 | 
			
		||||
    """
 | 
			
		||||
    if os.path.isfile(model_args.export_quantization_dataset):
 | 
			
		||||
        data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
 | 
			
		||||
@ -93,7 +93,7 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
 | 
			
		||||
        word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
 | 
			
		||||
        input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
 | 
			
		||||
        attention_mask = sample["attention_mask"][:, word_idx : word_idx + maxlen]
 | 
			
		||||
        samples.append({"input_ids": input_ids, "attention_mask": attention_mask})
 | 
			
		||||
        samples.append({"input_ids": input_ids.tolist(), "attention_mask": attention_mask.tolist()})
 | 
			
		||||
 | 
			
		||||
    return samples
 | 
			
		||||
 | 
			
		||||
@ -105,7 +105,7 @@ def configure_quantization(
 | 
			
		||||
    init_kwargs: Dict[str, Any],
 | 
			
		||||
) -> None:
 | 
			
		||||
    r"""
 | 
			
		||||
    Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
 | 
			
		||||
    Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)
 | 
			
		||||
    """
 | 
			
		||||
    if getattr(config, "quantization_config", None):  # ptq
 | 
			
		||||
        if is_deepspeed_zero3_enabled():
 | 
			
		||||
@ -131,6 +131,9 @@ def configure_quantization(
 | 
			
		||||
        logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))
 | 
			
		||||
 | 
			
		||||
    elif model_args.export_quantization_bit is not None:  # auto-gptq
 | 
			
		||||
        if model_args.export_quantization_bit not in [8, 4, 3, 2]:
 | 
			
		||||
            raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")
 | 
			
		||||
 | 
			
		||||
        require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0")
 | 
			
		||||
        require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
 | 
			
		||||
        from accelerate.utils import get_max_memory
 | 
			
		||||
@ -146,30 +149,48 @@ def configure_quantization(
 | 
			
		||||
        init_kwargs["max_memory"] = get_max_memory()
 | 
			
		||||
        logger.info("Quantizing model to {} bit with AutoGPTQ.".format(model_args.export_quantization_bit))
 | 
			
		||||
 | 
			
		||||
    elif model_args.quantization_bit is not None:  # bnb
 | 
			
		||||
        if model_args.quantization_bit == 8:
 | 
			
		||||
            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
 | 
			
		||||
            init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
 | 
			
		||||
    elif model_args.quantization_bit is not None:  # on-the-fly
 | 
			
		||||
        if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
 | 
			
		||||
            if model_args.quantization_bit == 8:
 | 
			
		||||
                require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
 | 
			
		||||
                init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
 | 
			
		||||
            elif model_args.quantization_bit == 4:
 | 
			
		||||
                require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
 | 
			
		||||
                init_kwargs["quantization_config"] = BitsAndBytesConfig(
 | 
			
		||||
                    load_in_4bit=True,
 | 
			
		||||
                    bnb_4bit_compute_dtype=model_args.compute_dtype,
 | 
			
		||||
                    bnb_4bit_use_double_quant=model_args.double_quantization,
 | 
			
		||||
                    bnb_4bit_quant_type=model_args.quantization_type,
 | 
			
		||||
                    bnb_4bit_quant_storage=model_args.compute_dtype,  # crucial for fsdp+qlora
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                raise ValueError("Bitsandbytes only accepts 4-bit or 8-bit quantization.")
 | 
			
		||||
 | 
			
		||||
        elif model_args.quantization_bit == 4:
 | 
			
		||||
            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
 | 
			
		||||
            init_kwargs["quantization_config"] = BitsAndBytesConfig(
 | 
			
		||||
                load_in_4bit=True,
 | 
			
		||||
                bnb_4bit_compute_dtype=model_args.compute_dtype,
 | 
			
		||||
                bnb_4bit_use_double_quant=model_args.double_quantization,
 | 
			
		||||
                bnb_4bit_quant_type=model_args.quantization_type,
 | 
			
		||||
                bnb_4bit_quant_storage=model_args.compute_dtype,  # crucial for fsdp+qlora
 | 
			
		||||
            )
 | 
			
		||||
            # Do not assign device map if:
 | 
			
		||||
            # 1. deepspeed zero3 or fsdp (train)
 | 
			
		||||
            # 2. auto quantization device map (inference)
 | 
			
		||||
            if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
 | 
			
		||||
                if model_args.quantization_bit != 4:
 | 
			
		||||
                    raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")
 | 
			
		||||
 | 
			
		||||
        # Do not assign device map if:
 | 
			
		||||
        # 1. deepspeed zero3 or fsdp (train)
 | 
			
		||||
        # 2. auto quantization device map (inference)
 | 
			
		||||
        if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
 | 
			
		||||
            if model_args.quantization_bit != 4:
 | 
			
		||||
                raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")
 | 
			
		||||
                require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
 | 
			
		||||
            else:
 | 
			
		||||
                init_kwargs["device_map"] = {"": get_current_device()}  # change auto device map for inference
 | 
			
		||||
 | 
			
		||||
            require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
 | 
			
		||||
        else:
 | 
			
		||||
            init_kwargs["device_map"] = {"": get_current_device()}  # change auto device map for inference
 | 
			
		||||
            logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit))
 | 
			
		||||
        elif model_args.quantization_method == QuantizationMethod.HQQ.value:
 | 
			
		||||
            if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
 | 
			
		||||
                raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")
 | 
			
		||||
 | 
			
		||||
        logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit))
 | 
			
		||||
            require_version("hqq", "To fix: pip install hqq")
 | 
			
		||||
            init_kwargs["quantization_config"] = HqqConfig(
 | 
			
		||||
                nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
 | 
			
		||||
            )  # use ATEN kernel (axis=0) for performance
 | 
			
		||||
            logger.info("Quantizing model to {} bit with HQQ.".format(model_args.quantization_bit))
 | 
			
		||||
        elif model_args.quantization_method == QuantizationMethod.EETQ.value:
 | 
			
		||||
            if model_args.quantization_bit != 8:
 | 
			
		||||
                raise ValueError("EETQ only accepts 8-bit quantization.")
 | 
			
		||||
 | 
			
		||||
            require_version("eetq", "To fix: pip install eetq")
 | 
			
		||||
            init_kwargs["quantization_config"] = EetqConfig()
 | 
			
		||||
            logger.info("Quantizing model to {} bit with EETQ.".format(model_args.quantization_bit))
 | 
			
		||||
 | 
			
		||||
@ -23,7 +23,7 @@ from ..data import Role
 | 
			
		||||
from ..extras.constants import PEFT_METHODS
 | 
			
		||||
from ..extras.misc import torch_gc
 | 
			
		||||
from ..extras.packages import is_gradio_available
 | 
			
		||||
from .common import get_save_dir
 | 
			
		||||
from .common import QUANTIZATION_BITS, get_save_dir
 | 
			
		||||
from .locales import ALERTS
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -76,11 +76,17 @@ class WebChatModel(ChatModel):
 | 
			
		||||
            yield error
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        if get("top.quantization_bit") in QUANTIZATION_BITS:
 | 
			
		||||
            quantization_bit = int(get("top.quantization_bit"))
 | 
			
		||||
        else:
 | 
			
		||||
            quantization_bit = None
 | 
			
		||||
 | 
			
		||||
        yield ALERTS["info_loading"][lang]
 | 
			
		||||
        args = dict(
 | 
			
		||||
            model_name_or_path=model_path,
 | 
			
		||||
            finetuning_type=finetuning_type,
 | 
			
		||||
            quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
 | 
			
		||||
            quantization_bit=quantization_bit,
 | 
			
		||||
            quantization_method=get("top.quantization_method"),
 | 
			
		||||
            template=get("top.template"),
 | 
			
		||||
            flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
 | 
			
		||||
            use_unsloth=(get("top.booster") == "unsloth"),
 | 
			
		||||
 | 
			
		||||
@ -47,6 +47,8 @@ DEFAULT_CONFIG_DIR = "config"
 | 
			
		||||
DEFAULT_DATA_DIR = "data"
 | 
			
		||||
DEFAULT_SAVE_DIR = "saves"
 | 
			
		||||
USER_CONFIG = "user_config.yaml"
 | 
			
		||||
QUANTIZATION_BITS = ["8", "6", "5", "4", "3", "2", "1"]
 | 
			
		||||
GPTQ_BITS = ["8", "4", "3", "2"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_save_dir(*paths: str) -> os.PathLike:
 | 
			
		||||
 | 
			
		||||
@ -18,7 +18,7 @@ from ...extras.constants import PEFT_METHODS
 | 
			
		||||
from ...extras.misc import torch_gc
 | 
			
		||||
from ...extras.packages import is_gradio_available
 | 
			
		||||
from ...train.tuner import export_model
 | 
			
		||||
from ..common import get_save_dir
 | 
			
		||||
from ..common import GPTQ_BITS, get_save_dir
 | 
			
		||||
from ..locales import ALERTS
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -32,9 +32,6 @@ if TYPE_CHECKING:
 | 
			
		||||
    from ..engine import Engine
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
GPTQ_BITS = ["8", "4", "3", "2"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def can_quantize(checkpoint_path: Union[str, List[str]]) -> "gr.Dropdown":
 | 
			
		||||
    if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0:
 | 
			
		||||
        return gr.Dropdown(value="none", interactive=False)
 | 
			
		||||
 | 
			
		||||
@ -18,7 +18,7 @@ from ...data import TEMPLATES
 | 
			
		||||
from ...extras.constants import METHODS, SUPPORTED_MODELS
 | 
			
		||||
from ...extras.packages import is_gradio_available
 | 
			
		||||
from ..common import get_model_info, list_checkpoints, save_config
 | 
			
		||||
from ..utils import can_quantize
 | 
			
		||||
from ..utils import can_quantize, can_quantize_to
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if is_gradio_available():
 | 
			
		||||
@ -43,10 +43,11 @@ def create_top() -> Dict[str, "Component"]:
 | 
			
		||||
 | 
			
		||||
    with gr.Accordion(open=False) as advanced_tab:
 | 
			
		||||
        with gr.Row():
 | 
			
		||||
            quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=2)
 | 
			
		||||
            template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=2)
 | 
			
		||||
            rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
 | 
			
		||||
            booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3)
 | 
			
		||||
            quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=1)
 | 
			
		||||
            quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=1)
 | 
			
		||||
            template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=1)
 | 
			
		||||
            rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=2)
 | 
			
		||||
            booster = gr.Radio(choices=["auto", "flashattn2", "unsloth"], value="auto", scale=2)
 | 
			
		||||
            visual_inputs = gr.Checkbox(scale=1)
 | 
			
		||||
 | 
			
		||||
    model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False).then(
 | 
			
		||||
@ -58,6 +59,7 @@ def create_top() -> Dict[str, "Component"]:
 | 
			
		||||
        list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
 | 
			
		||||
    )
 | 
			
		||||
    checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False)
 | 
			
		||||
    quantization_method.change(can_quantize_to, [quantization_method], [quantization_bit], queue=False)
 | 
			
		||||
 | 
			
		||||
    return dict(
 | 
			
		||||
        lang=lang,
 | 
			
		||||
@ -67,6 +69,7 @@ def create_top() -> Dict[str, "Component"]:
 | 
			
		||||
        checkpoint_path=checkpoint_path,
 | 
			
		||||
        advanced_tab=advanced_tab,
 | 
			
		||||
        quantization_bit=quantization_bit,
 | 
			
		||||
        quantization_method=quantization_method,
 | 
			
		||||
        template=template,
 | 
			
		||||
        rope_scaling=rope_scaling,
 | 
			
		||||
        booster=booster,
 | 
			
		||||
 | 
			
		||||
@ -85,15 +85,29 @@ LOCALES = {
 | 
			
		||||
    "quantization_bit": {
 | 
			
		||||
        "en": {
 | 
			
		||||
            "label": "Quantization bit",
 | 
			
		||||
            "info": "Enable 4/8-bit model quantization (QLoRA).",
 | 
			
		||||
            "info": "Enable quantization (QLoRA).",
 | 
			
		||||
        },
 | 
			
		||||
        "ru": {
 | 
			
		||||
            "label": "Уровень квантования",
 | 
			
		||||
            "info": "Включить 4/8-битное квантование модели (QLoRA).",
 | 
			
		||||
            "info": "Включить квантование (QLoRA).",
 | 
			
		||||
        },
 | 
			
		||||
        "zh": {
 | 
			
		||||
            "label": "量化等级",
 | 
			
		||||
            "info": "启用 4/8 比特模型量化(QLoRA)。",
 | 
			
		||||
            "info": "启用量化(QLoRA)。",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    "quantization_method": {
 | 
			
		||||
        "en": {
 | 
			
		||||
            "label": "Quantization method",
 | 
			
		||||
            "info": "Quantization algorithm to use.",
 | 
			
		||||
        },
 | 
			
		||||
        "ru": {
 | 
			
		||||
            "label": "Метод квантования",
 | 
			
		||||
            "info": "Алгоритм квантования, который следует использовать.",
 | 
			
		||||
        },
 | 
			
		||||
        "zh": {
 | 
			
		||||
            "label": "量化方法",
 | 
			
		||||
            "info": "使用的量化算法。",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    "template": {
 | 
			
		||||
 | 
			
		||||
@ -71,6 +71,7 @@ class Manager:
 | 
			
		||||
            self._id_to_elem["top.finetuning_type"],
 | 
			
		||||
            self._id_to_elem["top.checkpoint_path"],
 | 
			
		||||
            self._id_to_elem["top.quantization_bit"],
 | 
			
		||||
            self._id_to_elem["top.quantization_method"],
 | 
			
		||||
            self._id_to_elem["top.template"],
 | 
			
		||||
            self._id_to_elem["top.rope_scaling"],
 | 
			
		||||
            self._id_to_elem["top.booster"],
 | 
			
		||||
 | 
			
		||||
@ -22,7 +22,7 @@ from transformers.trainer import TRAINING_ARGS_NAME
 | 
			
		||||
from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
 | 
			
		||||
from ..extras.misc import is_gpu_or_npu_available, torch_gc
 | 
			
		||||
from ..extras.packages import is_gradio_available
 | 
			
		||||
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir, load_config
 | 
			
		||||
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
 | 
			
		||||
from .locales import ALERTS, LOCALES
 | 
			
		||||
from .utils import abort_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
 | 
			
		||||
 | 
			
		||||
@ -104,6 +104,11 @@ class Runner:
 | 
			
		||||
        model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
 | 
			
		||||
        user_config = load_config()
 | 
			
		||||
 | 
			
		||||
        if get("top.quantization_bit") in QUANTIZATION_BITS:
 | 
			
		||||
            quantization_bit = int(get("top.quantization_bit"))
 | 
			
		||||
        else:
 | 
			
		||||
            quantization_bit = None
 | 
			
		||||
 | 
			
		||||
        args = dict(
 | 
			
		||||
            stage=TRAINING_STAGES[get("train.training_stage")],
 | 
			
		||||
            do_train=True,
 | 
			
		||||
@ -111,7 +116,8 @@ class Runner:
 | 
			
		||||
            cache_dir=user_config.get("cache_dir", None),
 | 
			
		||||
            preprocessing_num_workers=16,
 | 
			
		||||
            finetuning_type=finetuning_type,
 | 
			
		||||
            quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
 | 
			
		||||
            quantization_bit=quantization_bit,
 | 
			
		||||
            quantization_method=get("top.quantization_method"),
 | 
			
		||||
            template=get("top.template"),
 | 
			
		||||
            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
 | 
			
		||||
            flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
 | 
			
		||||
@ -234,13 +240,19 @@ class Runner:
 | 
			
		||||
        model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
 | 
			
		||||
        user_config = load_config()
 | 
			
		||||
 | 
			
		||||
        if get("top.quantization_bit") in QUANTIZATION_BITS:
 | 
			
		||||
            quantization_bit = int(get("top.quantization_bit"))
 | 
			
		||||
        else:
 | 
			
		||||
            quantization_bit = None
 | 
			
		||||
 | 
			
		||||
        args = dict(
 | 
			
		||||
            stage="sft",
 | 
			
		||||
            model_name_or_path=get("top.model_path"),
 | 
			
		||||
            cache_dir=user_config.get("cache_dir", None),
 | 
			
		||||
            preprocessing_num_workers=16,
 | 
			
		||||
            finetuning_type=finetuning_type,
 | 
			
		||||
            quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
 | 
			
		||||
            quantization_bit=quantization_bit,
 | 
			
		||||
            quantization_method=get("top.quantization_method"),
 | 
			
		||||
            template=get("top.template"),
 | 
			
		||||
            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
 | 
			
		||||
            flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
 | 
			
		||||
 | 
			
		||||
@ -25,6 +25,7 @@ from yaml import safe_dump, safe_load
 | 
			
		||||
from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_LOG, TRAINING_ARGS, TRAINING_STAGES
 | 
			
		||||
from ..extras.packages import is_gradio_available, is_matplotlib_available
 | 
			
		||||
from ..extras.ploting import gen_loss_plot
 | 
			
		||||
from ..model import QuantizationMethod
 | 
			
		||||
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir
 | 
			
		||||
from .locales import ALERTS
 | 
			
		||||
 | 
			
		||||
@ -55,6 +56,18 @@ def can_quantize(finetuning_type: str) -> "gr.Dropdown":
 | 
			
		||||
        return gr.Dropdown(interactive=True)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def can_quantize_to(quantization_method: str) -> "gr.Dropdown":
 | 
			
		||||
    r"""
 | 
			
		||||
    Returns the available quantization bits.
 | 
			
		||||
    """
 | 
			
		||||
    if quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
 | 
			
		||||
        return gr.Dropdown(choices=["none", "8", "4"])
 | 
			
		||||
    elif quantization_method == QuantizationMethod.HQQ.value:
 | 
			
		||||
        return gr.Dropdown(choices=["none", "8", "6", "5", "4", "3", "2", "1"])
 | 
			
		||||
    elif quantization_method == QuantizationMethod.EETQ.value:
 | 
			
		||||
        return gr.Dropdown(choices=["none", "8"])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]:
 | 
			
		||||
    r"""
 | 
			
		||||
    Modifys states after changing the training stage.
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user