From 96265ec154db57bcdb7fe0ccf29f32522e2930d6 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Thu, 15 Feb 2024 02:27:36 +0800 Subject: [PATCH] support llama pro #2338 , add rslora Former-commit-id: 7924ffc55d98e33bfbfbca303e46c8f476435673 --- README.md | 6 +- README_zh.md | 6 +- requirements.txt | 2 +- src/llmtuner/api/app.py | 32 ++++--- src/llmtuner/data/utils.py | 2 +- src/llmtuner/extras/misc.py | 3 + src/llmtuner/hparams/data_args.py | 44 ++++++--- src/llmtuner/hparams/evaluation_args.py | 34 +++++-- src/llmtuner/hparams/finetuning_args.py | 120 ++++++++++++++++-------- src/llmtuner/hparams/generating_args.py | 15 ++- src/llmtuner/hparams/model_args.py | 78 ++++++++++----- src/llmtuner/hparams/parser.py | 34 ++++--- src/llmtuner/model/__init__.py | 4 +- src/llmtuner/model/adapter.py | 51 +++++----- src/llmtuner/model/loader.py | 31 +++--- src/llmtuner/model/patcher.py | 5 + src/llmtuner/model/utils.py | 16 +--- src/llmtuner/train/utils.py | 22 +++-- tests/llama_pro.py | 108 +++++++++++++++++++++ tests/llamafy_baichuan2.py | 6 +- tests/llamafy_internlm2.py | 6 +- tests/llamafy_qwen.py | 6 +- tests/loftq_init.py | 2 +- tests/test_toolcall.py | 8 +- 24 files changed, 438 insertions(+), 203 deletions(-) create mode 100644 tests/llama_pro.py diff --git a/README.md b/README.md index ae489085..32150a7a 100644 --- a/README.md +++ b/README.md @@ -55,16 +55,18 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ ## Changelog +[24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See `tests/llama_pro.py` for usage. + [24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details. [24/01/18] We supported **agent tuning** for most models, equipping model with tool using abilities by fine-tuning with `--dataset glaive_toolcall`. +
Full Changelog + [23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `--use_unsloth` argument to activate unsloth patch. It achieves 1.7x speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details. [23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement). -
Full Changelog - [23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#use-modelscope-hub-optional) for usage. [23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neftune_noise_alpha` argument to activate NEFTune, e.g., `--neftune_noise_alpha 5`. diff --git a/README_zh.md b/README_zh.md index dd6962f7..f99f91bc 100644 --- a/README_zh.md +++ b/README_zh.md @@ -55,16 +55,18 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846 ## 更新日志 +[24/02/15] 我们支持了 [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro) 提出的**块扩展**方法。详细用法请参照 `tests/llama_pro.py`。 + [24/02/05] Qwen1.5(Qwen2 测试版)系列模型已在 LLaMA-Factory 中实现微调支持。详情请查阅该[博客页面](https://qwenlm.github.io/zh/blog/qwen1.5/)。 [24/01/18] 我们针对绝大多数模型实现了 **Agent 微调**,微调时指定 `--dataset glaive_toolcall` 即可使模型获得工具调用能力。 +
展开日志 + [23/12/23] 我们针对 LLaMA, Mistral 和 Yi 模型支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的 LoRA 训练加速。请使用 `--use_unsloth` 参数启用 unsloth 优化。该方法可提供 1.7 倍的训练速度,详情请查阅[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。 [23/12/12] 我们支持了微调最新的混合专家模型 **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**。硬件需求请查阅[此处](#硬件依赖)。 -
展开日志 - [23/12/01] 我们支持了从 **[魔搭社区](https://modelscope.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#使用魔搭社区可跳过)。 [23/10/21] 我们支持了 **[NEFTune](https://arxiv.org/abs/2310.05914)** 训练技巧。请使用 `--neftune_noise_alpha` 参数启用 NEFTune,例如 `--neftune_noise_alpha 5`。 diff --git a/requirements.txt b/requirements.txt index c74f4fa8..c754efce 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ torch>=1.13.1 transformers>=4.37.2 datasets>=2.14.3 accelerate>=0.21.0 -peft>=0.7.0 +peft>=0.8.2 trl>=0.7.6 gradio>=3.38.0,<4.0.0 scipy diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py index 776e8c84..7b1560d3 100644 --- a/src/llmtuner/api/app.py +++ b/src/llmtuner/api/app.py @@ -74,6 +74,13 @@ def create_app(chat_model: "ChatModel") -> "FastAPI": ) semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1))) + role_mapping = { + Role.USER: DataRole.USER, + Role.ASSISTANT: DataRole.ASSISTANT, + Role.SYSTEM: DataRole.SYSTEM, + Role.FUNCTION: DataRole.FUNCTION, + Role.TOOL: DataRole.OBSERVATION, + } @app.get("/v1/models", response_model=ModelList) async def list_models(): @@ -85,28 +92,27 @@ def create_app(chat_model: "ChatModel") -> "FastAPI": if not chat_model.can_generate: raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed") - if len(request.messages) == 0 or request.messages[-1].role not in [Role.USER, Role.TOOL]: + if len(request.messages) == 0: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length") - messages = [dictify(message) for message in request.messages] - if len(messages) and messages[0]["role"] == Role.SYSTEM: - system = messages.pop(0)["content"] + if role_mapping[request.messages[0].role] == DataRole.SYSTEM: + system = request.messages.pop(0).content else: - system = None + system = "" - if len(messages) % 2 == 0: + if len(request.messages) % 2 == 0: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...") - for i in range(len(messages)): - if i % 2 == 0 and messages[i]["role"] not in [Role.USER, Role.TOOL]: + input_messages = [] + for i, message in enumerate(request.messages): + input_messages.append({"role": role_mapping[message.role], "content": message.content}) + if i % 2 == 0 and input_messages[i]["role"] not in [DataRole.USER, DataRole.OBSERVATION]: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role") - elif i % 2 == 1 and messages[i]["role"] not in [Role.ASSISTANT, Role.FUNCTION]: + elif i % 2 == 1 and input_messages[i]["role"] not in [DataRole.ASSISTANT, DataRole.FUNCTION]: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role") - elif messages[i]["role"] == Role.TOOL: - messages[i]["role"] = DataRole.OBSERVATION tool_list = request.tools - if len(tool_list): + if isinstance(tool_list, list) and len(tool_list): try: tools = json.dumps([tool["function"] for tool in tool_list], ensure_ascii=False) except Exception: @@ -116,7 +122,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI": async with semaphore: loop = asyncio.get_running_loop() - return await loop.run_in_executor(None, chat_completion, messages, system, tools, request) + return await loop.run_in_executor(None, chat_completion, input_messages, system, tools, request) def chat_completion(messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest): if request.stream: diff --git a/src/llmtuner/data/utils.py b/src/llmtuner/data/utils.py index 418186dd..90e3fa81 100644 --- a/src/llmtuner/data/utils.py +++ b/src/llmtuner/data/utils.py @@ -20,8 +20,8 @@ class Role(str, Enum): USER = "user" ASSISTANT = "assistant" SYSTEM = "system" - OBSERVATION = "observation" FUNCTION = "function" + OBSERVATION = "observation" def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None: diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py index e892f39a..348f6c6c 100644 --- a/src/llmtuner/extras/misc.py +++ b/src/llmtuner/extras/misc.py @@ -10,6 +10,7 @@ from transformers.utils import ( WEIGHTS_NAME, is_torch_bf16_gpu_available, is_torch_cuda_available, + is_torch_mps_available, is_torch_npu_available, is_torch_xpu_available, ) @@ -133,6 +134,8 @@ def get_current_device() -> torch.device: device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0")) elif is_torch_npu_available(): device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0")) + elif is_torch_mps_available(): + device = "mps:{}".format(os.environ.get("LOCAL_RANK", "0")) elif is_torch_cuda_available(): device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0")) else: diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py index c387d37f..539e5489 100644 --- a/src/llmtuner/hparams/data_args.py +++ b/src/llmtuner/hparams/data_args.py @@ -9,30 +9,40 @@ class DataArguments: """ template: Optional[str] = field( - default=None, metadata={"help": "Which template to use for constructing prompts in training and inference."} + default=None, + metadata={"help": "Which template to use for constructing prompts in training and inference."}, ) dataset: Optional[str] = field( default=None, metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}, ) dataset_dir: Optional[str] = field( - default="data", metadata={"help": "Path to the folder containing the datasets."} + default="data", + metadata={"help": "Path to the folder containing the datasets."}, ) split: Optional[str] = field( - default="train", metadata={"help": "Which dataset split to use for training and evaluation."} + default="train", + metadata={"help": "Which dataset split to use for training and evaluation."}, ) cutoff_len: Optional[int] = field( - default=1024, metadata={"help": "The cutoff length of the model inputs after tokenization."} + default=1024, + metadata={"help": "The cutoff length of the model inputs after tokenization."}, ) reserved_label_len: Optional[int] = field( - default=1, metadata={"help": "The minimum cutoff length reserved for label after tokenization."} + default=1, + metadata={"help": "The minimum cutoff length reserved for label after tokenization."}, ) train_on_prompt: Optional[bool] = field( - default=False, metadata={"help": "Whether to disable the mask on the prompt or not."} + default=False, + metadata={"help": "Whether to disable the mask on the prompt or not."}, + ) + streaming: Optional[bool] = field( + default=False, + metadata={"help": "Enable dataset streaming."}, ) - streaming: Optional[bool] = field(default=False, metadata={"help": "Enable dataset streaming."}) buffer_size: Optional[int] = field( - default=16384, metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."} + default=16384, + metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}, ) mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field( default="concat", @@ -43,13 +53,16 @@ class DataArguments: metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."}, ) overwrite_cache: Optional[bool] = field( - default=False, metadata={"help": "Overwrite the cached training and evaluation sets."} + default=False, + metadata={"help": "Overwrite the cached training and evaluation sets."}, ) preprocessing_num_workers: Optional[int] = field( - default=None, metadata={"help": "The number of processes to use for the preprocessing."} + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, ) max_samples: Optional[int] = field( - default=None, metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."} + default=None, + metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}, ) eval_num_beams: Optional[int] = field( default=None, @@ -62,13 +75,16 @@ class DataArguments: }, ) val_size: Optional[float] = field( - default=0, metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."} + default=0, + metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}, ) sft_packing: Optional[bool] = field( - default=False, metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."} + default=False, + metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}, ) cache_path: Optional[str] = field( - default=None, metadata={"help": "Path to save or load the preprocessed datasets."} + default=None, + metadata={"help": "Path to save or load the preprocessed datasets."}, ) def __post_init__(self): diff --git a/src/llmtuner/hparams/evaluation_args.py b/src/llmtuner/hparams/evaluation_args.py index bd4263f9..4257f47b 100644 --- a/src/llmtuner/hparams/evaluation_args.py +++ b/src/llmtuner/hparams/evaluation_args.py @@ -11,15 +11,33 @@ class EvaluationArguments: Arguments pertaining to specify the evaluation parameters. """ - task: str = field(metadata={"help": "Name of the evaluation task."}) - task_dir: Optional[str] = field( - default="evaluation", metadata={"help": "Path to the folder containing the evaluation datasets."} + task: str = field( + metadata={"help": "Name of the evaluation task."}, + ) + task_dir: Optional[str] = field( + default="evaluation", + metadata={"help": "Path to the folder containing the evaluation datasets."}, + ) + batch_size: Optional[int] = field( + default=4, + metadata={"help": "The batch size per GPU for evaluation."}, + ) + seed: Optional[int] = field( + default=42, + metadata={"help": "Random seed to be used with data loaders."}, + ) + lang: Optional[Literal["en", "zh"]] = field( + default="en", + metadata={"help": "Language used at evaluation."}, + ) + n_shot: Optional[int] = field( + default=5, + metadata={"help": "Number of examplars for few-shot learning."}, + ) + save_dir: Optional[str] = field( + default=None, + metadata={"help": "Path to save the evaluation results."}, ) - batch_size: Optional[int] = field(default=4, metadata={"help": "The batch size per GPU for evaluation."}) - seed: Optional[int] = field(default=42, metadata={"help": "Random seed to be used with data loaders."}) - lang: Optional[Literal["en", "zh"]] = field(default="en", metadata={"help": "Language used at evaluation."}) - n_shot: Optional[int] = field(default=5, metadata={"help": "Number of examplars for few-shot learning."}) - save_dir: Optional[str] = field(default=None, metadata={"help": "Path to save the evaluation results."}) download_mode: Optional[DownloadMode] = field( default=DownloadMode.REUSE_DATASET_IF_EXISTS, metadata={"help": "Download mode used for the evaluation datasets."}, diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py index 7c143918..b76079e4 100644 --- a/src/llmtuner/hparams/finetuning_args.py +++ b/src/llmtuner/hparams/finetuning_args.py @@ -10,20 +10,25 @@ class FreezeArguments: """ name_module_trainable: Optional[str] = field( - default="mlp", + default=None, metadata={ - "help": 'Name of trainable modules for partial-parameter (freeze) fine-tuning. \ - Use commas to separate multiple modules. \ - LLaMA choices: ["mlp", "self_attn"], \ - BLOOM & Falcon & ChatGLM choices: ["mlp", "self_attention"], \ - Qwen choices: ["mlp", "attn"], \ - Phi choices: ["mlp", "mixer"], \ - InternLM2 choices: ["feed_forward", "attention"], \ - Others choices: the same as LLaMA.' + "help": """Name of trainable modules for partial-parameter (freeze) fine-tuning. \ + Use commas to separate multiple modules. \ + Use "all" to specify all the available modules. \ + LLaMA choices: ["mlp", "self_attn"], \ + BLOOM & Falcon & ChatGLM choices: ["mlp", "self_attention"], \ + Qwen choices: ["mlp", "attn"], \ + InternLM2 choices: ["feed_forward", "attention"], \ + Others choices: the same as LLaMA.""" }, ) num_layer_trainable: Optional[int] = field( - default=3, metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."} + default=3, + metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."}, + ) + use_llama_pro: Optional[bool] = field( + default=False, + metadata={"help": "Whether or not to use llama pro for partial-parameter (freeze) fine-tuning."}, ) @@ -40,27 +45,42 @@ class LoraArguments: }, ) lora_alpha: Optional[int] = field( - default=None, metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."} + default=None, + metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."}, + ) + lora_dropout: Optional[float] = field( + default=0.0, + metadata={"help": "Dropout rate for the LoRA fine-tuning."}, + ) + lora_rank: Optional[int] = field( + default=8, + metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}, ) - lora_dropout: Optional[float] = field(default=0.0, metadata={"help": "Dropout rate for the LoRA fine-tuning."}) - lora_rank: Optional[int] = field(default=8, metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}) lora_target: Optional[str] = field( default=None, metadata={ - "help": 'Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \ - LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \ - BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \ - Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \ - Qwen choices: ["c_attn", "attn.c_proj", "w1", "w2", "mlp.c_proj"], \ - Phi choices: ["Wqkv", "out_proj", "fc1", "fc2"], \ - Others choices: the same as LLaMA.' + "help": """Name(s) of target modules to apply LoRA. \ + Use commas to separate multiple modules. \ + Use "all" to specify all the available modules. \ + LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \ + BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \ + Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \ + Qwen choices: ["c_attn", "attn.c_proj", "w1", "w2", "mlp.c_proj"], \ + InternLM2 choices: ["wqkv", "wo", "w1", "w2", "w3"], \ + Others choices: the same as LLaMA.""" }, ) lora_bf16_mode: Optional[bool] = field( - default=False, metadata={"help": "Whether or not to train lora adapters in bf16 precision."} + default=False, + metadata={"help": "Whether or not to train lora adapters in bf16 precision."}, + ) + use_rslora: Optional[bool] = field( + default=False, + metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."}, ) create_new_adapter: Optional[bool] = field( - default=False, metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."} + default=False, + metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."}, ) @@ -70,49 +90,65 @@ class RLHFArguments: Arguments pertaining to the PPO and DPO training. """ - dpo_beta: Optional[float] = field(default=0.1, metadata={"help": "The beta parameter for the DPO loss."}) + dpo_beta: Optional[float] = field( + default=0.1, + metadata={"help": "The beta parameter for the DPO loss."}, + ) dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto"]] = field( - default="sigmoid", metadata={"help": "The type of DPO loss to use."} + default="sigmoid", + metadata={"help": "The type of DPO loss to use."}, ) dpo_ftx: Optional[float] = field( - default=0, metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."} + default=0, + metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."}, ) ppo_buffer_size: Optional[int] = field( default=1, metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."}, ) ppo_epochs: Optional[int] = field( - default=4, metadata={"help": "The number of epochs to perform in a PPO optimization step."} + default=4, + metadata={"help": "The number of epochs to perform in a PPO optimization step."}, ) ppo_logger: Optional[str] = field( - default=None, metadata={"help": 'Log with either "wandb" or "tensorboard" in PPO training.'} + default=None, + metadata={"help": 'Log with either "wandb" or "tensorboard" in PPO training.'}, ) ppo_score_norm: Optional[bool] = field( - default=False, metadata={"help": "Use score normalization in PPO training."} + default=False, + metadata={"help": "Use score normalization in PPO training."}, ) ppo_target: Optional[float] = field( - default=6.0, metadata={"help": "Target KL value for adaptive KL control in PPO training."} + default=6.0, + metadata={"help": "Target KL value for adaptive KL control in PPO training."}, ) ppo_whiten_rewards: Optional[bool] = field( - default=False, metadata={"help": "Whiten the rewards before compute advantages in PPO training."} + default=False, + metadata={"help": "Whiten the rewards before compute advantages in PPO training."}, ) ref_model: Optional[str] = field( - default=None, metadata={"help": "Path to the reference model used for the PPO or DPO training."} + default=None, + metadata={"help": "Path to the reference model used for the PPO or DPO training."}, ) ref_model_adapters: Optional[str] = field( - default=None, metadata={"help": "Path to the adapters of the reference model."} + default=None, + metadata={"help": "Path to the adapters of the reference model."}, ) ref_model_quantization_bit: Optional[int] = field( - default=None, metadata={"help": "The number of bits to quantize the reference model."} + default=None, + metadata={"help": "The number of bits to quantize the reference model."}, ) reward_model: Optional[str] = field( - default=None, metadata={"help": "Path to the reward model used for the PPO training."} + default=None, + metadata={"help": "Path to the reward model used for the PPO training."}, ) reward_model_adapters: Optional[str] = field( - default=None, metadata={"help": "Path to the adapters of the reward model."} + default=None, + metadata={"help": "Path to the adapters of the reward model."}, ) reward_model_quantization_bit: Optional[int] = field( - default=None, metadata={"help": "The number of bits to quantize the reward model."} + default=None, + metadata={"help": "The number of bits to quantize the reward model."}, ) reward_model_type: Optional[Literal["lora", "full", "api"]] = field( default="lora", @@ -127,16 +163,20 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments): """ stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field( - default="sft", metadata={"help": "Which stage will be performed in training."} + default="sft", + metadata={"help": "Which stage will be performed in training."}, ) finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field( - default="lora", metadata={"help": "Which fine-tuning method to use."} + default="lora", + metadata={"help": "Which fine-tuning method to use."}, ) disable_version_checking: Optional[bool] = field( - default=False, metadata={"help": "Whether or not to disable version checking."} + default=False, + metadata={"help": "Whether or not to disable version checking."}, ) plot_loss: Optional[bool] = field( - default=False, metadata={"help": "Whether or not to save the training loss curves."} + default=False, + metadata={"help": "Whether or not to save the training loss curves."}, ) def __post_init__(self): diff --git a/src/llmtuner/hparams/generating_args.py b/src/llmtuner/hparams/generating_args.py index bfd1395f..06b5dfc3 100644 --- a/src/llmtuner/hparams/generating_args.py +++ b/src/llmtuner/hparams/generating_args.py @@ -9,10 +9,12 @@ class GeneratingArguments: """ do_sample: Optional[bool] = field( - default=True, metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."} + default=True, + metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}, ) temperature: Optional[float] = field( - default=0.95, metadata={"help": "The value used to modulate the next token probabilities."} + default=0.95, + metadata={"help": "The value used to modulate the next token probabilities."}, ) top_p: Optional[float] = field( default=0.7, @@ -25,7 +27,8 @@ class GeneratingArguments: metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}, ) num_beams: Optional[int] = field( - default=1, metadata={"help": "Number of beams for beam search. 1 means no beam search."} + default=1, + metadata={"help": "Number of beams for beam search. 1 means no beam search."}, ) max_length: Optional[int] = field( default=512, @@ -36,10 +39,12 @@ class GeneratingArguments: metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}, ) repetition_penalty: Optional[float] = field( - default=1.0, metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."} + default=1.0, + metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}, ) length_penalty: Optional[float] = field( - default=1.0, metadata={"help": "Exponential penalty to the length that is used with beam-based generation."} + default=1.0, + metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}, ) def to_dict(self) -> Dict[str, Any]: diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py index 86006681..52cd973f 100644 --- a/src/llmtuner/hparams/model_args.py +++ b/src/llmtuner/hparams/model_args.py @@ -9,10 +9,13 @@ class ModelArguments: """ model_name_or_path: str = field( - metadata={"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."} + metadata={ + "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models." + }, ) adapter_name_or_path: Optional[str] = field( - default=None, metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."} + default=None, + metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."}, ) cache_dir: Optional[str] = field( default=None, @@ -23,7 +26,8 @@ class ModelArguments: metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."}, ) resize_vocab: Optional[bool] = field( - default=False, metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."} + default=False, + metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."}, ) split_special_tokens: Optional[bool] = field( default=False, @@ -34,60 +38,88 @@ class ModelArguments: metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, ) quantization_bit: Optional[int] = field( - default=None, metadata={"help": "The number of bits to quantize the model."} + default=None, + metadata={"help": "The number of bits to quantize the model."}, ) quantization_type: Optional[Literal["fp4", "nf4"]] = field( - default="nf4", metadata={"help": "Quantization data type to use in int4 training."} + default="nf4", + metadata={"help": "Quantization data type to use in int4 training."}, ) double_quantization: Optional[bool] = field( - default=True, metadata={"help": "Whether or not to use double quantization in int4 training."} + default=True, + metadata={"help": "Whether or not to use double quantization in int4 training."}, ) rope_scaling: Optional[Literal["linear", "dynamic"]] = field( - default=None, metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."} + default=None, + metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."}, ) flash_attn: Optional[bool] = field( - default=False, metadata={"help": "Enable FlashAttention-2 for faster training."} + default=False, + metadata={"help": "Enable FlashAttention-2 for faster training."}, ) shift_attn: Optional[bool] = field( - default=False, metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."} + default=False, + metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}, ) use_unsloth: Optional[bool] = field( - default=False, metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."} + default=False, + metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."}, ) disable_gradient_checkpointing: Optional[bool] = field( - default=False, metadata={"help": "Whether or not to disable gradient checkpointing."} + default=False, + metadata={"help": "Whether or not to disable gradient checkpointing."}, ) upcast_layernorm: Optional[bool] = field( - default=False, metadata={"help": "Whether or not to upcast the layernorm weights in fp32."} + default=False, + metadata={"help": "Whether or not to upcast the layernorm weights in fp32."}, ) upcast_lmhead_output: Optional[bool] = field( - default=False, metadata={"help": "Whether or not to upcast the output of lm_head in fp32."} + default=False, + metadata={"help": "Whether or not to upcast the output of lm_head in fp32."}, + ) + hf_hub_token: Optional[str] = field( + default=None, + metadata={"help": "Auth token to log in with Hugging Face Hub."}, + ) + ms_hub_token: Optional[str] = field( + default=None, + metadata={"help": "Auth token to log in with ModelScope Hub."}, ) - hf_hub_token: Optional[str] = field(default=None, metadata={"help": "Auth token to log in with Hugging Face Hub."}) - ms_hub_token: Optional[str] = field(default=None, metadata={"help": "Auth token to log in with ModelScope Hub."}) export_dir: Optional[str] = field( - default=None, metadata={"help": "Path to the directory to save the exported model."} + default=None, + metadata={"help": "Path to the directory to save the exported model."}, ) export_size: Optional[int] = field( - default=1, metadata={"help": "The file shard size (in GB) of the exported model."} + default=1, + metadata={"help": "The file shard size (in GB) of the exported model."}, ) export_quantization_bit: Optional[int] = field( - default=None, metadata={"help": "The number of bits to quantize the exported model."} + default=None, + metadata={"help": "The number of bits to quantize the exported model."}, ) export_quantization_dataset: Optional[str] = field( - default=None, metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."} + default=None, + metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."}, ) export_quantization_nsamples: Optional[int] = field( - default=128, metadata={"help": "The number of samples used for quantization."} + default=128, + metadata={"help": "The number of samples used for quantization."}, ) export_quantization_maxlen: Optional[int] = field( - default=1024, metadata={"help": "The maximum length of the model inputs used for quantization."} + default=1024, + metadata={"help": "The maximum length of the model inputs used for quantization."}, ) export_legacy_format: Optional[bool] = field( - default=False, metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."} + default=False, + metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."}, ) export_hub_model_id: Optional[str] = field( - default=None, metadata={"help": "The name of the repository if push the model to the Hugging Face hub."} + default=None, + metadata={"help": "The name of the repository if push the model to the Hugging Face hub."}, + ) + print_param_status: Optional[bool] = field( + default=False, + metadata={"help": "For debugging purposes, print the status of the parameters in the model."}, ) def __post_init__(self): diff --git a/src/llmtuner/hparams/parser.py b/src/llmtuner/hparams/parser.py index a09f84bc..6685816c 100644 --- a/src/llmtuner/hparams/parser.py +++ b/src/llmtuner/hparams/parser.py @@ -30,12 +30,15 @@ _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArgu _EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments] -def _check_dependencies(): - require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2") - require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3") - require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0") - require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0") - require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6") +def _check_dependencies(disabled: bool) -> None: + if disabled: + logger.warning("Version checking has been disabled, may lead to unexpected behaviors.") + else: + require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2") + require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3") + require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0") + require_version("peft>=0.8.2", "To fix: pip install peft>=0.8.2") + require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6") def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]: @@ -130,6 +133,13 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: if training_args.do_train and training_args.predict_with_generate: raise ValueError("`predict_with_generate` cannot be set as True while training.") + if ( + training_args.do_train + and finetuning_args.finetuning_type == "freeze" + and finetuning_args.name_module_trainable is None + ): + raise ValueError("Please specify `name_module_trainable` in Freeze training.") + if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None: raise ValueError("Please specify `lora_target` in LoRA training.") @@ -137,9 +147,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: raise ValueError("Install Unsloth: https://github.com/unslothai/unsloth") _verify_model_args(model_args, finetuning_args) - - if not finetuning_args.disable_version_checking: - _check_dependencies() + _check_dependencies(disabled=finetuning_args.disable_version_checking) if ( training_args.do_train @@ -240,13 +248,11 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: _set_transformers_logging() _verify_model_args(model_args, finetuning_args) + _check_dependencies(disabled=finetuning_args.disable_version_checking) if data_args.template is None: raise ValueError("Please specify which `template` to use.") - if not finetuning_args.disable_version_checking: - _check_dependencies() - return model_args, data_args, finetuning_args, generating_args @@ -255,13 +261,11 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS: _set_transformers_logging() _verify_model_args(model_args, finetuning_args) + _check_dependencies(disabled=finetuning_args.disable_version_checking) if data_args.template is None: raise ValueError("Please specify which `template` to use.") - if not finetuning_args.disable_version_checking: - _check_dependencies() - transformers.set_seed(eval_args.seed) return model_args, data_args, eval_args, finetuning_args diff --git a/src/llmtuner/model/__init__.py b/src/llmtuner/model/__init__.py index 6d598361..7d0d15d6 100644 --- a/src/llmtuner/model/__init__.py +++ b/src/llmtuner/model/__init__.py @@ -1,5 +1,5 @@ from .loader import load_model_and_tokenizer -from .utils import dispatch_model, get_modelcard_args, load_valuehead_params +from .utils import dispatch_model, load_valuehead_params -__all__ = ["load_model_and_tokenizer", "dispatch_model", "get_modelcard_args", "load_valuehead_params"] +__all__ = ["load_model_and_tokenizer", "dispatch_model", "load_valuehead_params"] diff --git a/src/llmtuner/model/adapter.py b/src/llmtuner/model/adapter.py index 3d3f95c0..b0dc6489 100644 --- a/src/llmtuner/model/adapter.py +++ b/src/llmtuner/model/adapter.py @@ -1,8 +1,7 @@ -import inspect from typing import TYPE_CHECKING import torch -from peft import LoraConfig, PeftModel, TaskType, get_peft_model +from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model from transformers.integrations import is_deepspeed_zero3_enabled from ..extras.logging import get_logger @@ -47,12 +46,22 @@ def init_adapter( if not num_layers: raise ValueError("Current model does not support freeze tuning.") - if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0 - trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)] - else: # fine-tuning the first n layers if num_layer_trainable < 0 - trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)] # noqa: C416 + if finetuning_args.use_llama_pro: + if num_layers % finetuning_args.num_layer_trainable != 0: + raise ValueError( + "`num_layers` {} should be divisible by `num_layer_trainable` {}.".format( + num_layers, finetuning_args.num_layer_trainable + ) + ) - freeze_modules = set() + stride = num_layers // finetuning_args.num_layer_trainable + trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride) + elif finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0 + trainable_layer_ids = range(num_layers - finetuning_args.num_layer_trainable, num_layers) + else: # fine-tuning the first n layers if num_layer_trainable < 0 + trainable_layer_ids = range(-finetuning_args.num_layer_trainable) + + freeze_modules = {"all"} for name, _ in model.named_modules(): if "0." in name: freeze_modules.add(name.split("0.")[-1].split(".")[0]) @@ -65,13 +74,13 @@ def init_adapter( ) for idx in trainable_layer_ids: - trainable_layers.append("{:d}.{}".format(idx, module_name)) + trainable_layers.append(".{:d}.{}".format(idx, module_name if module_name != "all" else "")) for name, param in model.named_parameters(): - if not any(trainable_layer in name for trainable_layer in trainable_layers): - param.requires_grad_(False) - else: + if any(trainable_layer in name for trainable_layer in trainable_layers): param.data = param.data.to(torch.float32) + else: + param.requires_grad_(False) if finetuning_args.finetuning_type == "lora": logger.info("Fine-tuning method: LoRA") @@ -94,7 +103,7 @@ def init_adapter( adapter_to_merge = model_args.adapter_name_or_path for adapter in adapter_to_merge: - model = PeftModel.from_pretrained(model, adapter) + model: "LoraModel" = PeftModel.from_pretrained(model, adapter) model = model.merge_and_unload() if len(adapter_to_merge) > 0: @@ -114,22 +123,14 @@ def init_adapter( "target_modules": target_modules, "lora_alpha": finetuning_args.lora_alpha, "lora_dropout": finetuning_args.lora_dropout, + "use_rslora": finetuning_args.use_rslora, } if model_args.use_unsloth: - from unsloth import FastLlamaModel, FastMistralModel # type: ignore + from unsloth import FastLanguageModel # type: ignore unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length} - if "loftq_config" in inspect.signature(FastLlamaModel.get_peft_model).parameters: - unsloth_peft_kwargs["loftq_config"] = {} - - if getattr(model.config, "model_type", None) == "llama": - model = FastLlamaModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs) - elif getattr(model.config, "model_type", None) == "mistral": - model = FastMistralModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs) - else: - raise NotImplementedError - + model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs) else: lora_config = LoraConfig( task_type=TaskType.CAUSAL_LM, @@ -142,7 +143,7 @@ def init_adapter( for param in filter(lambda p: p.requires_grad, model.parameters()): param.data = param.data.to(torch.bfloat16 if finetuning_args.lora_bf16_mode else torch.float32) - if model_args.adapter_name_or_path is not None: - logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path))) + if model_args.adapter_name_or_path is not None: + logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path))) return model diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py index 6d75c15b..29d213a7 100644 --- a/src/llmtuner/model/loader.py +++ b/src/llmtuner/model/loader.py @@ -55,7 +55,7 @@ def load_model_and_tokenizer( model = None if is_trainable and model_args.use_unsloth: - from unsloth import FastLlamaModel, FastMistralModel # type: ignore + from unsloth import FastLanguageModel # type: ignore unsloth_kwargs = { "model_name": model_args.model_name_or_path, @@ -63,14 +63,12 @@ def load_model_and_tokenizer( "dtype": model_args.compute_dtype, "load_in_4bit": model_args.quantization_bit == 4, "token": model_args.hf_hub_token, - "device_map": get_current_device(), + "device_map": {"": get_current_device()}, "rope_scaling": getattr(config, "rope_scaling", None), } - if getattr(config, "model_type", None) == "llama": - model, _ = FastLlamaModel.from_pretrained(**unsloth_kwargs) - elif getattr(config, "model_type", None) == "mistral": - model, _ = FastMistralModel.from_pretrained(**unsloth_kwargs) - else: + try: + model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) + except NotImplementedError: logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) model_args.use_unsloth = False @@ -87,17 +85,6 @@ def load_model_and_tokenizer( **config_kwargs, ) - # Add llama-factory tag to push these tags on the Hub. - # the feature is available since 4.37.0 but adding the check - # just in case - if hasattr(model, "add_model_tags"): - model.add_model_tags(["llama-factory"]) - else: - logger.warning_once( - "Was not able to properly tag the model, if you want to use the model tagging feature, make sure to " - "have transformers>=4.37.0 installed on your environment." - ) - patch_model(model, tokenizer, model_args, is_trainable) register_autoclass(config, model, tokenizer) @@ -134,4 +121,12 @@ def load_model_and_tokenizer( if not is_trainable: logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.") + if model_args.print_param_status: + for name, param in model.named_parameters(): + print( + "name: {}, dtype: {}, device: {}, trainable: {}".format( + name, param.dtype, param.device, param.requires_grad + ) + ) + return model, tokenizer diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py index bb774e08..1dd9f2da 100644 --- a/src/llmtuner/model/patcher.py +++ b/src/llmtuner/model/patcher.py @@ -300,6 +300,11 @@ def patch_model( if is_trainable: patch_mixtral_replace_moe_impl() + try: + model.add_model_tags(["llama-factory"]) + except Exception: + logger.warning("Cannot properly tag the model.") + def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None: def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None: diff --git a/src/llmtuner/model/utils.py b/src/llmtuner/model/utils.py index 9d45c290..02056330 100644 --- a/src/llmtuner/model/utils.py +++ b/src/llmtuner/model/utils.py @@ -1,5 +1,5 @@ import inspect -from typing import TYPE_CHECKING, Any, Dict, List +from typing import TYPE_CHECKING, Dict, List import torch from transformers import PreTrainedModel @@ -13,7 +13,7 @@ from ..extras.misc import get_current_device if TYPE_CHECKING: from transformers import PretrainedConfig, PreTrainedTokenizer - from ..hparams import DataArguments, FinetuningArguments, ModelArguments + from ..hparams import ModelArguments logger = get_logger(__name__) @@ -76,18 +76,6 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]: return list(module_names) -def get_modelcard_args( - model_args: "ModelArguments", data_args: "DataArguments", finetuning_args: "FinetuningArguments" -) -> Dict[str, Any]: - return { - "tasks": "text-generation", - "license": "other", - "finetuned_from": model_args.model_name_or_path, - "dataset": [dataset.strip() for dataset in data_args.dataset.split(",")], - "tags": ["llama-factory"] + (["lora"] if finetuning_args.finetuning_type == "lora" else []), - } - - def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]: r""" Loads value head parameters from Hugging Face Hub or local disk. diff --git a/src/llmtuner/train/utils.py b/src/llmtuner/train/utils.py index e8d2603b..4af994ad 100644 --- a/src/llmtuner/train/utils.py +++ b/src/llmtuner/train/utils.py @@ -4,7 +4,7 @@ import torch from ..extras.logging import get_logger from ..hparams import FinetuningArguments, ModelArguments -from ..model import get_modelcard_args, load_model_and_tokenizer, load_valuehead_params +from ..model import load_model_and_tokenizer, load_valuehead_params if TYPE_CHECKING: @@ -25,14 +25,18 @@ def create_modelcard_and_push( training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments", ) -> None: - if training_args.do_train: - if training_args.push_to_hub: - trainer.push_to_hub(**get_modelcard_args(model_args, data_args, finetuning_args)) - return - try: - trainer.create_model_card(**get_modelcard_args(model_args, data_args, finetuning_args)) - except Exception as err: - logger.warning("Failed to create model card: {}".format(str(err))) + kwargs = { + "tasks": "text-generation", + "finetuned_from": model_args.model_name_or_path, + "dataset": [dataset.strip() for dataset in data_args.dataset.split(",")], + "tags": ["llama-factory", finetuning_args.finetuning_type], + } + if not training_args.do_train: + pass + elif training_args.push_to_hub: + trainer.push_to_hub(**kwargs) + else: + trainer.create_model_card(license="other", **kwargs) # prevent from connecting to hub def create_ref_model( diff --git a/tests/llama_pro.py b/tests/llama_pro.py new file mode 100644 index 00000000..fe2ae71d --- /dev/null +++ b/tests/llama_pro.py @@ -0,0 +1,108 @@ +# coding=utf-8 +# Performs block expansion for LLaMA, Mistral or Qwen1.5 models. +# Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8 +# Inspired by: https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py + +import json +import os +from collections import OrderedDict +from typing import TYPE_CHECKING, Optional + +import fire +import torch +from safetensors.torch import save_file +from tqdm import tqdm +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +from transformers.modeling_utils import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + shard_checkpoint, +) + + +if TYPE_CHECKING: + from transformers import PretrainedConfig, PreTrainedModel + + +def block_expansion( + model_name_or_path: str, + output_dir: str, + num_expand: int, + shard_size: Optional[str] = "2GB", + save_safetensors: Optional[bool] = False, +): + config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path) + num_layers = getattr(config, "num_hidden_layers") + setattr(config, "num_hidden_layers", num_layers + num_expand) + config.save_pretrained(output_dir) + + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + tokenizer.save_pretrained(output_dir) + + model: "PreTrainedModel" = AutoModelForCausalLM.from_pretrained( + model_name_or_path, + torch_dtype="auto", + trust_remote_code=True, + low_cpu_mem_usage=True, + ) + state_dict = model.state_dict() + + if num_layers % num_expand != 0: + raise ValueError("`num_layers` {} should be divisible by `num_expand` {}.".format(num_layers, num_expand)) + + split = num_layers // num_expand + layer_cnt = 0 + output_state_dict = OrderedDict() + for i in range(num_layers): + for key, value in state_dict.items(): + if ".{:d}.".format(i) in key: + output_state_dict[key.replace(".{:d}.".format(i), ".{:d}.".format(layer_cnt))] = value + + print("Add layer {} copied from layer {}".format(layer_cnt, i)) + layer_cnt += 1 + if (i + 1) % split == 0: + for key, value in state_dict.items(): + if ".{:d}.".format(i) in key: + if "down_proj" in key or "o_proj" in key: + output_state_dict[key.replace(".{:d}.".format(i), ".{:d}.".format(layer_cnt))] = ( + torch.zeros_like(value) + ) + else: + output_state_dict[key.replace(".{:d}.".format(i), ".{:d}.".format(layer_cnt))] = value + + print("Add layer {} expanded from layer {}".format(layer_cnt, i)) + layer_cnt += 1 + + for key, value in state_dict.items(): + if key not in output_state_dict: + output_state_dict[key] = value + + weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME + shards, index = shard_checkpoint(output_state_dict, max_shard_size=shard_size, weights_name=weights_name) + + for shard_file, shard in tqdm(shards.items(), desc="Save weights"): + if save_safetensors: + save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"}) + else: + torch.save(shard, os.path.join(output_dir, shard_file)) + + if index is None: + print("Model weights saved in {}".format(os.path.join(output_dir, weights_name))) + else: + index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME + with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f: + json.dump(index, f, indent=2, sort_keys=True) + print("Model weights saved in {}".format(output_dir)) + + print("Fine-tune this model with:") + print(" --model_name_or_path {} \\".format(output_dir)) + print(" --finetuning_type freeze \\") + print(" --name_module_trainable all \\") + print(" --num_layer_trainable {} \\".format(num_expand)) + print(" --use_llama_pro") + + +if __name__ == "__main__": + fire.Fire(block_expansion) diff --git a/tests/llamafy_baichuan2.py b/tests/llamafy_baichuan2.py index 91666bae..1ae58879 100644 --- a/tests/llamafy_baichuan2.py +++ b/tests/llamafy_baichuan2.py @@ -1,6 +1,6 @@ # coding=utf-8 # Converts the Baichuan2-7B model in the same format as LLaMA2-7B. -# Usage: python llamafy_baichuan2.py --input_dir input --output_dir output --shard_size 10GB +# Usage: python llamafy_baichuan2.py --input_dir input --output_dir output # Inspired by: https://huggingface.co/fireballoon/baichuan-llama-7b/blob/main/convert_baichuan_to_llama.py # Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied @@ -76,7 +76,9 @@ def save_config(input_dir: str, output_dir: str): print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME))) -def llamafy_baichuan2(input_dir: str, output_dir: str, shard_size: str, save_safetensors: Optional[bool] = False): +def llamafy_baichuan2( + input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False +): try: os.makedirs(output_dir, exist_ok=False) except Exception as e: diff --git a/tests/llamafy_internlm2.py b/tests/llamafy_internlm2.py index f1c4c9af..b6b03e7d 100644 --- a/tests/llamafy_internlm2.py +++ b/tests/llamafy_internlm2.py @@ -1,6 +1,6 @@ # coding=utf-8 # Converts the InternLM2 model in the same format as LLaMA2. -# Usage: python llamafy_internlm2.py --input_dir input --output_dir output --shard_size 10GB +# Usage: python llamafy_internlm2.py --input_dir input --output_dir output # Warning: We have found that the converted model cannot infer correctly. It will be fixed later. import json @@ -98,7 +98,9 @@ def save_config(input_dir: str, output_dir: str): print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME))) -def llamafy_internlm2(input_dir: str, output_dir: str, shard_size: str, save_safetensors: Optional[bool] = False): +def llamafy_internlm2( + input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False +): try: os.makedirs(output_dir, exist_ok=False) except Exception as e: diff --git a/tests/llamafy_qwen.py b/tests/llamafy_qwen.py index 5b66b7ef..69cf3e8e 100644 --- a/tests/llamafy_qwen.py +++ b/tests/llamafy_qwen.py @@ -1,6 +1,6 @@ # coding=utf-8 # Converts the Qwen models in the same format as LLaMA2. -# Usage: python llamafy_qwen.py --input_dir input --output_dir output --shard_size 10GB +# Usage: python llamafy_qwen.py --input_dir input --output_dir output # Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied import json @@ -128,7 +128,9 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str): print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME))) -def llamafy_qwen(input_dir: str, output_dir: str, shard_size: str, save_safetensors: Optional[bool] = False): +def llamafy_qwen( + input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False +): try: os.makedirs(output_dir, exist_ok=False) except Exception as e: diff --git a/tests/loftq_init.py b/tests/loftq_init.py index be7b07c5..7f244316 100644 --- a/tests/loftq_init.py +++ b/tests/loftq_init.py @@ -26,7 +26,7 @@ class Shell(nn.Module): def unwrap_model(model: nn.Module, pattern=".base_layer") -> None: - for name in set([k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k]): # noqa: C403 + for name in {k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k}: parent_name = ".".join(name.split(".")[:-1]) child_name = name.split(".")[-1] parent_module = model.get_submodule(parent_name) diff --git a/tests/test_toolcall.py b/tests/test_toolcall.py index 666a33e7..e3351e3e 100644 --- a/tests/test_toolcall.py +++ b/tests/test_toolcall.py @@ -1,13 +1,10 @@ import json -import os from typing import Sequence from openai import OpenAI from transformers.utils.versions import require_version -os.environ["OPENAI_BASE_URL"] = "http://192.168.0.1:8000/v1" -os.environ["OPENAI_API_KEY"] = "0" require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0") @@ -24,7 +21,10 @@ tool_map = {"calculate_gpa": calculate_gpa} if __name__ == "__main__": - client = OpenAI() + client = OpenAI( + api_key="0", + base_url="http://localhost:8000/v1", + ) tools = [ { "type": "function",