Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-04 12:42:51 +08:00)

support quantization in export model

Former-commit-id: 3524aa1e58da94ab00e9a2024952ea1b4119b2af
Parent: 2db4cfab40
Commit: 7dbc670902
@@ -479,7 +479,10 @@ python src/export_model.py \
 ```
 
 > [!WARNING]
-> Merging LoRA weights into a GPTQ quantized model is not supported.
+> Merging LoRA weights into a quantized model is not supported.
 
+> [!TIP]
+> Use `--export_quantization_bit 4` and `--export_quantization_dataset data/wiki_demo.txt` to quantize the model.
+
 ### API Demo
 
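For illustration, here is a hedged sketch of driving a quantized export from Python through the dict-based `export_model(args)` entry point touched later in this commit. The `export_*` keys are the new `ExportArguments` fields; the model path, template name, and import location are placeholder assumptions, not confirmed by the diff.

```python
# Hedged sketch: quantized export driven programmatically instead of via the CLI.
# The export_* keys mirror the new ExportArguments fields added in this commit;
# every other value is a placeholder assumption.
from llmtuner import export_model  # assumed import path

export_model({
    "model_name_or_path": "meta-llama/Llama-2-7b-hf",     # placeholder base model
    "template": "default",                                # placeholder prompt template
    "export_dir": "exported_model",                       # output directory for the shards
    "export_size": 2,                                     # max shard size in GB
    "export_quantization_bit": 4,                         # enables the new GPTQ export path
    "export_quantization_dataset": "data/wiki_demo.txt",  # calibration text, as in the tip above
})
```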
@@ -479,7 +479,10 @@ python src/export_model.py \
 ```
 
 > [!WARNING]
-> Merging and exporting the LoRA weights of a GPTQ quantized model is not yet supported.
+> Merging and exporting the LoRA weights of a quantized model is not yet supported.
 
+> [!TIP]
+> Use `--export_quantization_bit 4` and `--export_quantization_dataset data/wiki_demo.txt` to quantize the exported model.
+
 ### API Service
 
@@ -3,7 +3,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
 
 from datasets import concatenate_datasets, interleave_datasets, load_dataset
 
-from llmtuner.data.utils import checksum, EXT2TYPE
+from llmtuner.data.utils import checksum
+from llmtuner.extras.constants import FILEEXT2TYPE
 from llmtuner.extras.logging import get_logger
 
 if TYPE_CHECKING:
@@ -39,12 +40,12 @@ def get_dataset(
                 for file_name in os.listdir(local_path):
                     data_files.append(os.path.join(local_path, file_name))
                     if data_path is None:
-                        data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
+                        data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
                     else:
-                        assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical."
+                        assert data_path == FILEEXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical."
             elif os.path.isfile(local_path): # is file
                 data_files.append(local_path)
-                data_path = EXT2TYPE.get(local_path.split(".")[-1], None)
+                data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
             else:
                 raise ValueError("File not found.")
 
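Only the name and home of the extension table change here (`EXT2TYPE` in `llmtuner.data.utils` becomes `FILEEXT2TYPE` in `llmtuner.extras.constants`, see the hunks below); the lookup itself is unchanged. A small self-contained sketch of what that lookup does, with made-up file names:

```python
# Toy illustration of the extension -> datasets-builder lookup used in get_dataset.
# The mapping mirrors FILEEXT2TYPE as added in constants.py below; file names are made up.
FILEEXT2TYPE = {
    "arrow": "arrow",
    "csv": "csv",
    "json": "json",
    "jsonl": "json",
    "parquet": "parquet",
    "txt": "text",
}

for name in ("alpaca_data.json", "corpus.txt", "table.parquet", "notes.md"):
    builder = FILEEXT2TYPE.get(name.split(".")[-1], None)
    print(name, "->", builder)  # json, text, parquet, None (unknown extensions resolve to None)
```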
@@ -12,16 +12,6 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-EXT2TYPE = {
-    "arrow": "arrow",
-    "csv": "csv",
-    "json": "json",
-    "jsonl": "json",
-    "parquet": "parquet",
-    "txt": "text"
-}
-
-
 def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
     if file_sha1 is None:
         logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
@@ -9,6 +9,15 @@ DEFAULT_MODULE = defaultdict(str)
 
 DEFAULT_TEMPLATE = defaultdict(str)
 
+FILEEXT2TYPE = {
+    "arrow": "arrow",
+    "csv": "csv",
+    "json": "json",
+    "jsonl": "json",
+    "parquet": "parquet",
+    "txt": "text"
+}
+
 IGNORE_INDEX = -100
 
 LAYERNORM_NAMES = {"norm", "ln"}
@@ -125,7 +125,38 @@ class RLHFArguments:
 
 
 @dataclass
-class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
+class ExportArguments:
+    r"""
+    Arguments pertaining to model exporting.
+    """
+    export_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the directory to save the exported model."}
+    )
+    export_size: Optional[int] = field(
+        default=1,
+        metadata={"help": "The file shard size (in GB) of the exported model."}
+    )
+    export_quantization_bit: Optional[int] = field(
+        default=None,
+        metadata={"help": "The number of bits to quantize the exported model."}
+    )
+    export_quantization_dataset: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."}
+    )
+    export_quantization_nsamples: Optional[int] = field(
+        default=128,
+        metadata={"help": "The number of samples used for quantization."}
+    )
+    export_quantization_maxlen: Optional[int] = field(
+        default=1024,
+        metadata={"help": "The maximum length of the model inputs used for quantization."}
+    )
+
+
+@dataclass
+class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, ExportArguments):
     r"""
     Arguments pertaining to which techniques we are going to fine-tuning with.
     """
@@ -141,14 +172,6 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
         default=False,
         metadata={"help": "Whether to upcast the layernorm weights in fp32."}
     )
-    export_dir: Optional[str] = field(
-        default=None,
-        metadata={"help": "Path to the directory to save the exported model."}
-    )
-    export_size: Optional[int] = field(
-        default=1,
-        metadata={"help": "The file shard size (in GB) of the exported model."}
-    )
     plot_loss: Optional[bool] = field(
         default=False,
         metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
@@ -170,6 +193,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
         assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
         assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
         assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
+        assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."
 
         if self.stage == "ppo" and self.reward_model is None:
             raise ValueError("Reward model is necessary for PPO training.")
@@ -177,6 +201,9 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
         if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
             raise ValueError("Freeze/Full PPO training needs `reward_model_type=full`.")
 
+        if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
+            raise ValueError("Quantization dataset is necessary for exporting.")
+
     def save_to_json(self, json_path: str):
         r"""Saves the content of this instance in JSON format inside `json_path`."""
         json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
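Because `FinetuningArguments` now inherits from `ExportArguments`, the new export flags can be parsed alongside the existing fine-tuning options. A minimal sketch using transformers' `HfArgumentParser`; the flag values are illustrative, and the import path for `FinetuningArguments` matches the one used in the patcher hunk further down:

```python
# Minimal sketch: the new export_* fields are ordinary dataclass fields on
# FinetuningArguments, so they can be filled from CLI-style arguments.
from transformers import HfArgumentParser
from llmtuner.hparams import FinetuningArguments

parser = HfArgumentParser(FinetuningArguments)
(finetuning_args,) = parser.parse_args_into_dataclasses(args=[
    "--export_dir", "exported_model",
    "--export_quantization_bit", "4",
    "--export_quantization_dataset", "data/wiki_demo.txt",
])

print(finetuning_args.export_quantization_bit)       # 4
print(finetuning_args.export_quantization_nsamples)  # 128 (default)
```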
@@ -62,7 +62,7 @@ def load_model_and_tokenizer(
     patcher.configure_rope(config, model_args, is_trainable)
     patcher.configure_flashattn(config_kwargs, model_args)
     patcher.configure_longlora(config, model_args, is_trainable)
-    patcher.configure_quantization(config, config_kwargs, model_args)
+    patcher.configure_quantization(config, config_kwargs, tokenizer, model_args, finetuning_args)
 
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
@@ -1,12 +1,16 @@
+import os
 import math
 import torch
+import random
 from types import MethodType
-from typing import TYPE_CHECKING, Any, Dict
+from typing import TYPE_CHECKING, Any, Dict, List
+from datasets import load_dataset
 
-from transformers import BitsAndBytesConfig, PreTrainedModel, PreTrainedTokenizerBase
+from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.utils.versions import require_version
 
+from llmtuner.extras.constants import FILEEXT2TYPE
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import get_current_device, infer_optim_dtype
 from llmtuner.extras.packages import is_flash_attn2_available
@@ -14,7 +18,7 @@ from llmtuner.extras.packages import is_flash_attn2_available
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedTokenizer
     from trl import AutoModelForCausalLMWithValueHead
-    from llmtuner.hparams import ModelArguments
+    from llmtuner.hparams import ModelArguments, FinetuningArguments
 
 
 logger = get_logger(__name__)
@@ -36,7 +40,13 @@ def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments",
         logger.warning("Current model does not support shift short attention.")
 
 
-def configure_quantization(config: "PretrainedConfig", config_kwargs: Dict[str, Any], model_args: "ModelArguments"):
+def configure_quantization(
+    config: "PretrainedConfig",
+    config_kwargs: Dict[str, Any],
+    tokenizer: "PreTrainedTokenizer",
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments"
+):
     if getattr(config, "quantization_config", None): # gptq or awq
         model_args.quantization_bit = None # remove bnb quantization
         config_kwargs["device_map"] = {"": get_current_device()}
@@ -63,6 +73,16 @@ def configure_quantization(config: "PretrainedConfig", config_kwargs: Dict[str,
         config_kwargs["device_map"] = {"": get_current_device()}
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
+    if finetuning_args.export_quantization_bit is not None: # gptq
+        require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
+        require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
+        config_kwargs["quantization_config"] = GPTQConfig(
+            bits=finetuning_args.export_quantization_bit,
+            dataset=get_quantization_dataset(tokenizer, model_args, finetuning_args)
+        )
+        config_kwargs["device_map"] = "auto"
+        logger.info("Quantizing model to {} bit.".format(finetuning_args.export_quantization_bit))
+
 
 def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool):
     if model_args.rope_scaling is not None:
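The new branch only stages a `GPTQConfig` in `config_kwargs`; the actual quantization happens when those kwargs are later forwarded to `AutoModelForCausalLM.from_pretrained`. Roughly, the effect is equivalent to the standalone sketch below; the model name and calibration texts are placeholders, and in the real path the texts come from `get_quantization_dataset`, shown in the next hunk.

```python
# Hedged sketch of what the staged config_kwargs amount to once they reach
# AutoModelForCausalLM.from_pretrained. Requires optimum>=1.16.0 and auto_gptq>=0.5.0.
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Stand-in for get_quantization_dataset(): a list of plain-text calibration samples.
calibration_texts = ["example calibration text ..."] * 128

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=GPTQConfig(bits=4, dataset=calibration_texts, tokenizer=tokenizer),
    device_map="auto",  # mirrors config_kwargs["device_map"] = "auto" above
)
```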
@@ -91,6 +111,40 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
         ))
 
 
+def get_quantization_dataset(
+    tokenizer: "PreTrainedTokenizer",
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments"
+) -> List[str]:
+    r"""
+    Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
+    TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
+    """
+    if os.path.isfile(finetuning_args.export_quantization_dataset):
+        data_path = FILEEXT2TYPE.get(finetuning_args.export_quantization_dataset.split(".")[-1], None)
+        data_files = finetuning_args.export_quantization_dataset
+    else:
+        data_path = finetuning_args.export_quantization_dataset
+        data_files = None
+
+    dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
+    maxlen = finetuning_args.export_quantization_maxlen
+
+    samples = []
+    for _ in range(finetuning_args.export_quantization_nsamples):
+        while True:
+            sample_idx = random.randint(0, len(dataset) - 1)
+            sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
+            if sample["input_ids"].size(1) >= maxlen:
+                break # TODO: fix large maxlen
+
+        word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
+        input_ids = sample["input_ids"][:, word_idx:word_idx+maxlen]
+        samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
+
+    return samples
+
+
 def patch_config(config: "PretrainedConfig", model_args: "ModelArguments"):
     if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
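`get_quantization_dataset` keeps redrawing random rows until it finds one with at least `export_quantization_maxlen` tokens, cuts a random `maxlen`-token window out of it, and decodes that window back to text (the decode round trip is what the TODO about the optimum PR refers to). The same windowing idea on plain Python lists, with all values made up:

```python
# Toy illustration of the calibration sampling strategy: reject rows that are
# shorter than the window, then take a random fixed-length slice of a long one.
import random

rows = [list(range(n)) for n in (12, 300, 45, 800)]  # fake tokenized rows of varying length
maxlen, nsamples = 64, 4

windows = []
for _ in range(nsamples):
    while True:
        row = random.choice(rows)
        if len(row) >= maxlen:  # skip rows that cannot fill a full window
            break
    start = random.randint(0, len(row) - maxlen - 1)
    windows.append(row[start:start + maxlen])

print([len(w) for w in windows])  # every calibration sample is exactly maxlen tokens long
```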
@@ -38,10 +38,11 @@ def export_model(args: Optional[Dict[str, Any]] = None):
     model_args, _, finetuning_args, _ = get_infer_args(args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
 
-    if getattr(model, "quantization_method", None) in ["gptq", "awq"]:
-        raise ValueError("Cannot export a GPTQ or AWQ quantized model.")
+    if getattr(model, "quantization_method", None):
+        raise ValueError("Cannot export a quantized model.")
 
     model.config.use_cache = True
+    model = model.to("cpu")
     model.save_pretrained(finetuning_args.export_dir, max_shard_size="{}GB".format(finetuning_args.export_size))
 
     try: