Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-04 04:32:50 +08:00)

Commit d6c77d9196: reimplement neftune
Parent: 3927550fa3
Former-commit-id: 7b4acf7265b04cc4a674b3dcafdb90e76f149e39
@@ -22,7 +22,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846

 ## Changelog

-[23/10/21] We supported [NEFTune](https://arxiv.org/abs/2310.05914) optimization. Try `--neftune_noise_alpha` argument to activate NEFTune, e.g., `--neftune_noise_alpha 5`.
+[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neft_alpha` argument to activate NEFTune, e.g., `--neft_alpha 5`.

 [23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try `--shift_attn` argument to enable shift short attention.

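For reference, NEFTune (per the linked paper) adds zero-mean uniform noise to the input embeddings during training only. With sequence length $L$, embedding dimension $d$, and the scaling factor $\alpha$ supplied via `--neft_alpha`:

$$\tilde{E} = E + \epsilon, \qquad \epsilon_{ij} \sim \mathrm{Uniform}\left(-\frac{\alpha}{\sqrt{Ld}},\ \frac{\alpha}{\sqrt{Ld}}\right)$$

Larger `--neft_alpha` values inject stronger noise; no noise is added at inference time.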
The corresponding changelog updates in the Chinese README (shown here in English):

@@ -22,7 +22,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846

 ## Changelog

-[23/10/21] We supported [NEFTune](https://arxiv.org/abs/2310.05914) optimization. Try the `--neftune_noise_alpha` argument to activate NEFTune, e.g., `--neftune_noise_alpha 5`.
+[23/10/21] We supported the **[NEFTune](https://arxiv.org/abs/2310.05914)** training trick. Use the `--neft_alpha` argument to enable NEFTune, e.g., `--neft_alpha 5`.

 [23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Use the `--shift_attn` argument to enable this feature.

@@ -38,7 +38,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846

 [23/07/29] We released two 13B instruction-tuned models on Hugging Face. See our Hugging Face projects ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft)) for details.

-[23/07/18] We developed an **all-in-one Web UI** for training and evaluation. Please try using `train_web.py` to fine-tune models in your browser. Thanks to [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in developing this feature.
+[23/07/18] We developed an **all-in-one Web UI** for training and evaluation. Please use `train_web.py` to fine-tune models in your browser. Thanks to [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in developing this feature.

 [23/07/09] We open-sourced **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹, an easy-to-use package for efficiently editing the factual knowledge of large language models. Follow our [FastEdit](https://github.com/hiyouga/FastEdit) project if you are interested.

@@ -46,7 +46,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846

 [23/06/22] We aligned the [demo API](src/api_demo.py) with the format of the [OpenAI API](https://platform.openai.com/docs/api-reference/chat), so you can plug the fine-tuned model into **any ChatGPT-based application**.

-[23/06/03] We implemented 4-bit LoRA training (a.k.a. **[QLoRA](https://github.com/artidoro/qlora)**). Please try using the `--quantization_bit 4` argument for 4-bit quantized fine-tuning.
+[23/06/03] We implemented 4-bit LoRA training (a.k.a. **[QLoRA](https://github.com/artidoro/qlora)**). Please use the `--quantization_bit 4` argument for 4-bit quantized fine-tuning.

 ## Models

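The `--quantization_bit 4` entry above refers to QLoRA-style training: the frozen base model is loaded in 4-bit and only LoRA adapters are trained. A rough, illustrative sketch with plain transformers + peft (the model name and LoRA hyperparameters below are placeholder assumptions, not taken from this repository):

```python
# Illustrative only: the project exposes this via its own CLI flag (--quantization_bit 4);
# this sketch shows the underlying QLoRA recipe directly. Requires bitsandbytes and a GPU.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # quantize the frozen base weights to 4 bit
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",             # placeholder model name
    quantization_config=bnb_config,
)
model = get_peft_model(model, LoraConfig(r=8, lora_alpha=16, task_type="CAUSAL_LM"))
model.print_trainable_parameters()          # only the LoRA adapters are trainable
```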
@@ -75,9 +75,13 @@ class FinetuningArguments:
         default=0.1,
         metadata={"help": "The beta parameter for the DPO loss."}
     )
-    neftune_noise_alpha: Optional[float] = field(
-        default=None,
-        metadata={"help": "The alpha parameter for the NEFTune noise. By setting this the NEFTune optimization will be activated."}
+    upcast_layernorm: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to upcast the layernorm weights in fp32."}
+    )
+    neft_alpha: Optional[float] = field(
+        default=0,
+        metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune."}
     )

     def __post_init__(self):
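Since these are standard dataclass fields, they are populated straight from the command line by `HfArgumentParser`. A minimal sketch (the trimmed dataclass below is a hypothetical stand-in for the project's full `FinetuningArguments`, kept only to show the flag-to-field mapping):

```python
# Hypothetical, trimmed-down dataclass for illustration only.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class FinetuningArguments:
    upcast_layernorm: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to upcast the layernorm weights in fp32."},
    )
    neft_alpha: Optional[float] = field(
        default=0,
        metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune."},
    )


parser = HfArgumentParser(FinetuningArguments)
finetuning_args, = parser.parse_args_into_dataclasses(
    ["--neft_alpha", "5", "--upcast_layernorm", "true"]
)
print(finetuning_args.neft_alpha, finetuning_args.upcast_layernorm)  # 5.0 True
```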
@@ -62,10 +62,6 @@ class ModelArguments:
         default=None,
         metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
     )
-    upcast_layernorm: Optional[bool] = field(
-        default=False,
-        metadata={"help": "Whether to upcast the layernorm weights in fp32."}
-    )
     plot_loss: Optional[bool] = field(
         default=False,
         metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
@@ -206,8 +206,7 @@ def load_model_and_tokenizer(
         tokenizer.__class__.register_for_auto_class()

     # Initialize adapters
-    if is_trainable:
-        model = prepare_model_for_training(model, model_args.upcast_layernorm, finetuning_args.finetuning_type)
+    model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
     model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
     model = model.train() if is_trainable else model.eval()

@@ -146,7 +146,7 @@ def get_train_args(
         if not finetuning_args.resume_lora_training:
             raise ValueError("Quantized model cannot create new LoRA weight. Merge them first.")

-    if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
+    if training_args.do_train and model_args.quantization_bit is not None and (not finetuning_args.upcast_layernorm):
         logger.warning("We recommend enable `upcast_layernorm` in quantized training.")

     if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
@@ -1,10 +1,12 @@
 import torch
+from types import MethodType
 from typing import TYPE_CHECKING, List, Optional

 from llmtuner.extras.constants import LAYERNORM_NAMES

 if TYPE_CHECKING:
     from transformers.modeling_utils import PreTrainedModel
+    from llmtuner.hparams import FinetuningArguments


 def find_all_linear_modules(
@@ -31,8 +33,7 @@ def find_all_linear_modules(

 def prepare_model_for_training(
     model: "PreTrainedModel",
-    upcast_layernorm: bool,
-    finetuning_type: str,
+    finetuning_args: "FinetuningArguments",
     output_layer_name: Optional[str] = "lm_head",
     use_gradient_checkpointing: Optional[bool] = True,
     layernorm_names: Optional[List[str]] = LAYERNORM_NAMES
@@ -44,31 +45,42 @@ def prepare_model_for_training(
     (3) upcast the lm_head to fp32
     Inspired by: https://github.com/huggingface/peft/blob/v0.2.0/src/peft/utils/other.py#L33
     """
-    if upcast_layernorm:
+    if finetuning_args.upcast_layernorm:
         for name, param in model.named_parameters():
             if param.ndim == 1 and any(ln_name in name for ln_name in layernorm_names):
                 param.data = param.data.to(torch.float32)

+    if finetuning_args.neft_alpha > 1e-6:
+        input_embed: torch.nn.Embedding = model.get_input_embeddings()
+
+        def noisy_forward(self: torch.nn.Embedding, x: torch.Tensor) -> torch.Tensor:
+            embeddings = input_embed.forward(x)
+            if self.training:
+                dims = self.num_embeddings * self.embedding_dim
+                mag_norm = finetuning_args.neft_alpha / (dims ** 0.5)
+                embeddings += torch.zeros_like(embeddings).uniform_(-mag_norm, mag_norm)
+            return embeddings
+
+        input_embed.forward = MethodType(noisy_forward, input_embed)
+
     if use_gradient_checkpointing:
         if hasattr(model, "enable_input_require_grads"):
             model.enable_input_require_grads()
         else:
-            def make_inputs_require_grad(module, input, output):
+            def make_inputs_require_grad(module: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
                 output.requires_grad_(True)
             model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

         model.gradient_checkpointing_enable()
         model.config.use_cache = False # turn off when gradient checkpointing is enabled

-    if finetuning_type != "full" and hasattr(model, output_layer_name):
+    if finetuning_args.finetuning_type != "full" and hasattr(model, output_layer_name):
         output_layer: torch.nn.Linear = getattr(model, output_layer_name)
         input_dtype = output_layer.weight.dtype

-        class CastOutputToFloat(torch.nn.Sequential):
-
-            def forward(self, x: torch.Tensor) -> torch.Tensor:
-                return super().forward(x.to(input_dtype)).to(torch.float32)
-
-        setattr(model, output_layer_name, CastOutputToFloat(output_layer))
+        def forward_in_fp32(self, x: torch.Tensor) -> torch.Tensor:
+            return output_layer.forward(x.to(input_dtype)).to(torch.float32)
+
+        output_layer.forward = MethodType(forward_in_fp32, output_layer)

     return model
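This hunk is the core of the reimplementation: instead of customizing the trainer, the commit monkey-patches the input embedding so that uniform noise is added only while the model is in training mode. A minimal standalone sketch of the same pattern (the helper name `add_neftune_noise` is an assumption for illustration, not part of this commit; it captures the original `forward` before patching):

```python
# Standalone sketch of the embedding patch above; works with any model exposing
# get_input_embeddings(), e.g. a small GPT-2 checkpoint.
from types import MethodType

import torch


def add_neftune_noise(model: torch.nn.Module, neft_alpha: float) -> None:
    embed: torch.nn.Embedding = model.get_input_embeddings()
    orig_forward = embed.forward  # capture the original lookup so the wrapper never calls itself

    def noisy_forward(self: torch.nn.Embedding, x: torch.Tensor) -> torch.Tensor:
        embeddings = orig_forward(x)
        if self.training:  # noise is applied during training only; eval is untouched
            # The commit scales by the embedding-table size (num_embeddings * embedding_dim);
            # the NEFTune paper scales by sequence length * embedding_dim instead.
            mag_norm = neft_alpha / (self.num_embeddings * self.embedding_dim) ** 0.5
            embeddings = embeddings + torch.zeros_like(embeddings).uniform_(-mag_norm, mag_norm)
        return embeddings

    embed.forward = MethodType(noisy_forward, embed)
```

Calling `add_neftune_noise(model, 5.0)` right after loading the model mirrors the effect of passing `--neft_alpha 5` through the project's CLI.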
@@ -3,10 +3,8 @@ import json
 import torch
 import numpy as np
 import torch.nn as nn
-from functools import wraps
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
-from transformers import Seq2SeqTrainer, PreTrainedModel, Trainer
-from peft import PeftModel
+from transformers import Seq2SeqTrainer

 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.logging import get_logger
@@ -23,14 +21,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
     """

-    def __init__(self, model: Union["PreTrainedModel", nn.Module] = None, neftune_noise_alpha: Optional[float] = 0, **kwargs):
-        super().__init__(model, **kwargs)
-        self.neftune_noise_alpha = neftune_noise_alpha
-        self._neftune_activated = False
-
-        if self.neftune_noise_alpha:
-            self._activate_neftune(model)
-
     def prediction_step(
         self,
         model: nn.Module,
@@ -109,71 +99,3 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             for pred, label in zip(decoded_preds, decoded_labels):
                 res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
             writer.write("\n".join(res))
-
-
-    @wraps(Trainer.train)
-    def train(self, *args, **kwargs):
-        output = super().train(*args, **kwargs)
-
-        # After training we make sure to retrieve back the original forward pass method
-        # for the embedding layer.
-        if self.neftune_noise_alpha is not None:
-            self._deactivate_neftune(self.model)
-
-        return output
-
-    def _toggle_neftune(self, model, activate=True):
-        """Toggle NEFTune optimization for a model (i.e. activate or deactivate).
-        This optimization based on this paper: https://arxiv.org/abs/2310.05914
-
-        Parameters:
-        model : PreTrainedModel or PeftModel
-            The model to toggle the noise for.
-        activate : bool, optional (default=True)
-            Whether to activate the noise or not.
-        """
-        if activate == self._neftune_activated:
-            return
-
-        self._neftune_activated = activate
-
-        embeddings = (model.get_input_embeddings() if isinstance(model, PreTrainedModel)
-                      else model.base_model.get_input_embeddings() if isinstance(model, PeftModel)
-                      else None)
-
-        if embeddings:
-            if activate:
-                embeddings.neftune_noise_alpha = self.neftune_noise_alpha
-                embeddings._trl_old_forward = embeddings.forward
-                neftune_method = _neftune_forward_function.__get__(embeddings, embeddings.__class__)
-                setattr(embeddings, "forward", neftune_method)
-                logger.info("NEFTune activated with alpha: ", self.neftune_noise_alpha)
-            elif hasattr(embeddings, "_trl_old_forward"):
-                embeddings.forward = embeddings._trl_old_forward
-                del embeddings._trl_old_forward
-                del embeddings.neftune_noise_alpha
-                logger.info("NEFTune deactivated")
-
-    _activate_neftune = lambda self, model: self._toggle_neftune(model, activate=True)
-    _deactivate_neftune = lambda self, model: self._toggle_neftune(model, activate=False)
-
-
-def _neftune_forward_function(self, input: torch.Tensor) -> torch.Tensor:
-    """
-    This code is adapted from the original source code that can be found here: https://github.com/neelsjain/NEFTune
-    """
-    embeddings = torch.nn.functional.embedding(
-        input,
-        self.weight,
-        self.padding_idx,
-        self.max_norm,
-        self.norm_type,
-        self.scale_grad_by_freq,
-        self.sparse)
-
-    if self.training:
-        dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
-        mag_norm = self.neftune_noise_alpha / torch.sqrt(dims)
-        embeddings += torch.zeros_like(embeddings).uniform_(-mag_norm, mag_norm)
-
-    return embeddings
@@ -53,7 +53,6 @@ def run_sft(
         data_collator=data_collator,
         callbacks=callbacks,
         compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
-        neftune_noise_alpha=finetuning_args.neftune_noise_alpha,
         **split_dataset(dataset, data_args, training_args)
     )
