diff --git a/README.md b/README.md
index d166c65c..e8933e7a 100644
--- a/README.md
+++ b/README.md
@@ -479,7 +479,10 @@ python src/export_model.py \
 ```
 
 > [!WARNING]
-> Merging LoRA weights into a GPTQ quantized model is not supported.
+> Merging LoRA weights into a quantized model is not supported.
+
+> [!TIP]
+> Use `--export_quantization_bit 4` and `--export_quantization_dataset data/wiki_demo.txt` to quantize the model.
 
 ### API Demo
 
diff --git a/README_zh.md b/README_zh.md
index dbf4b104..5f593c8e 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -479,7 +479,10 @@ python src/export_model.py \
 ```
 
 > [!WARNING]
-> 尚不支持 GPTQ 量化模型的 LoRA 权重合并及导出。
+> 尚不支持量化模型的 LoRA 权重合并及导出。
+
+> [!TIP]
+> 使用 `--export_quantization_bit 4` 和 `--export_quantization_dataset data/wiki_demo.txt` 量化导出模型。
 
 ### API 服务
 
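The tip above maps onto the Python entry point this patch also touches (`export_model` in `src/llmtuner/train/tuner.py`). A minimal sketch of the programmatic equivalent, assuming placeholder paths and the `default` template; the keys mirror the new `ExportArguments` fields introduced further down in this patch:

```python
# Hypothetical programmatic equivalent of the README tip; paths and template are placeholders.
from llmtuner.train.tuner import export_model

export_model(dict(
    model_name_or_path="path_to_model",                 # base model to quantize
    template="default",                                 # prompt template, as in the CLI example
    export_dir="path_to_export",                        # where the quantized shards are written
    export_quantization_bit=4,                          # GPTQ bit width (2/3/4/8 per the new assert)
    export_quantization_dataset="data/wiki_demo.txt",   # calibration corpus
))
```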
+ """ + export_dir: Optional[str] = field( + default=None, + metadata={"help": "Path to the directory to save the exported model."} + ) + export_size: Optional[int] = field( + default=1, + metadata={"help": "The file shard size (in GB) of the exported model."} + ) + export_quantization_bit: Optional[int] = field( + default=None, + metadata={"help": "The number of bits to quantize the exported model."} + ) + export_quantization_dataset: Optional[str] = field( + default=None, + metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."} + ) + export_quantization_nsamples: Optional[int] = field( + default=128, + metadata={"help": "The number of samples used for quantization."} + ) + export_quantization_maxlen: Optional[str] = field( + default=1024, + metadata={"help": "The maximum length of the model inputs used for quantization."} + ) + + +@dataclass +class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, ExportArguments): r""" Arguments pertaining to which techniques we are going to fine-tuning with. """ @@ -141,14 +172,6 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments): default=False, metadata={"help": "Whether to upcast the layernorm weights in fp32."} ) - export_dir: Optional[str] = field( - default=None, - metadata={"help": "Path to the directory to save the exported model."} - ) - export_size: Optional[int] = field( - default=1, - metadata={"help": "The file shard size (in GB) of the exported model."} - ) plot_loss: Optional[bool] = field( default=False, metadata={"help": "Whether to plot the training loss after fine-tuning or not."} @@ -170,6 +193,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments): assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method." assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." + assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization." 
if self.stage == "ppo" and self.reward_model is None: raise ValueError("Reward model is necessary for PPO training.") @@ -177,6 +201,9 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments): if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora": raise ValueError("Freeze/Full PPO training needs `reward_model_type=full`.") + if self.export_quantization_bit is not None and self.export_quantization_dataset is None: + raise ValueError("Quantization dataset is necessary for exporting.") + def save_to_json(self, json_path: str): r"""Saves the content of this instance in JSON format inside `json_path`.""" json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n" diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py index f4421be1..72ca9782 100644 --- a/src/llmtuner/model/loader.py +++ b/src/llmtuner/model/loader.py @@ -62,7 +62,7 @@ def load_model_and_tokenizer( patcher.configure_rope(config, model_args, is_trainable) patcher.configure_flashattn(config_kwargs, model_args) patcher.configure_longlora(config, model_args, is_trainable) - patcher.configure_quantization(config, config_kwargs, model_args) + patcher.configure_quantization(config, config_kwargs, tokenizer, model_args, finetuning_args) model = AutoModelForCausalLM.from_pretrained( model_args.model_name_or_path, diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py index e90976bd..c7b916cb 100644 --- a/src/llmtuner/model/patcher.py +++ b/src/llmtuner/model/patcher.py @@ -1,12 +1,16 @@ +import os import math import torch +import random from types import MethodType -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any, Dict, List +from datasets import load_dataset -from transformers import BitsAndBytesConfig, PreTrainedModel, PreTrainedTokenizerBase +from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase from transformers.integrations import is_deepspeed_zero3_enabled from transformers.utils.versions import require_version +from llmtuner.extras.constants import FILEEXT2TYPE from llmtuner.extras.logging import get_logger from llmtuner.extras.misc import get_current_device, infer_optim_dtype from llmtuner.extras.packages import is_flash_attn2_available @@ -14,7 +18,7 @@ from llmtuner.extras.packages import is_flash_attn2_available if TYPE_CHECKING: from transformers import PretrainedConfig, PreTrainedTokenizer from trl import AutoModelForCausalLMWithValueHead - from llmtuner.hparams import ModelArguments + from llmtuner.hparams import ModelArguments, FinetuningArguments logger = get_logger(__name__) @@ -36,7 +40,13 @@ def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", logger.warning("Current model does not support shift short attention.") -def configure_quantization(config: "PretrainedConfig", config_kwargs: Dict[str, Any], model_args: "ModelArguments"): +def configure_quantization( + config: "PretrainedConfig", + config_kwargs: Dict[str, Any], + tokenizer: "PreTrainedTokenizer", + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments" +): if getattr(config, "quantization_config", None): # gptq or awq model_args.quantization_bit = None # remove bnb quantization config_kwargs["device_map"] = {"": get_current_device()} @@ -63,6 +73,16 @@ def configure_quantization(config: "PretrainedConfig", config_kwargs: Dict[str, config_kwargs["device_map"] = {"": get_current_device()} logger.info("Quantizing model to {} 
bit.".format(model_args.quantization_bit)) + if finetuning_args.export_quantization_bit is not None: # gptq + require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0") + require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0") + config_kwargs["quantization_config"] = GPTQConfig( + bits=finetuning_args.export_quantization_bit, + dataset=get_quantization_dataset(tokenizer, model_args, finetuning_args) + ) + config_kwargs["device_map"] = "auto" + logger.info("Quantizing model to {} bit.".format(finetuning_args.export_quantization_bit)) + def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool): if model_args.rope_scaling is not None: @@ -91,6 +111,40 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_ )) +def get_quantization_dataset( + tokenizer: "PreTrainedTokenizer", + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments" +) -> List[str]: + r""" + Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133 + TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600 + """ + if os.path.isfile(finetuning_args.export_quantization_dataset): + data_path = FILEEXT2TYPE.get(finetuning_args.export_quantization_dataset.split(".")[-1], None) + data_files = finetuning_args.export_quantization_dataset + else: + data_path = finetuning_args.export_quantization_dataset + data_files = None + + dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir) + maxlen = finetuning_args.export_quantization_maxlen + + samples = [] + for _ in range(finetuning_args.export_quantization_nsamples): + while True: + sample_idx = random.randint(0, len(dataset) - 1) + sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt") + if sample["input_ids"].size(1) >= maxlen: + break # TODO: fix large maxlen + + word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1) + input_ids = sample["input_ids"][:, word_idx:word_idx+maxlen] + samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True)) + + return samples + + def patch_config(config: "PretrainedConfig", model_args: "ModelArguments"): if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32 model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None)) diff --git a/src/llmtuner/train/tuner.py b/src/llmtuner/train/tuner.py index 094aa50f..2ee4c0b7 100644 --- a/src/llmtuner/train/tuner.py +++ b/src/llmtuner/train/tuner.py @@ -38,10 +38,11 @@ def export_model(args: Optional[Dict[str, Any]] = None): model_args, _, finetuning_args, _ = get_infer_args(args) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args) - if getattr(model, "quantization_method", None) in ["gptq", "awq"]: - raise ValueError("Cannot export a GPTQ or AWQ quantized model.") + if getattr(model, "quantization_method", None): + raise ValueError("Cannot export a quantized model.") model.config.use_cache = True + model = model.to("cpu") model.save_pretrained(finetuning_args.export_dir, max_shard_size="{}GB".format(finetuning_args.export_size)) try: