From 108c31e1fcc437033804d5080ac803001786343b Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 27 Sep 2023 21:55:50 +0800 Subject: [PATCH] support LongLoRA Former-commit-id: 90375f600d5601866836123597fa3ef52008eeef --- README.md | 4 +- README_zh.md | 4 +- src/llmtuner/dsets/preprocess.py | 35 ++- src/llmtuner/extras/patches/flash_llama.py | 301 --------------------- src/llmtuner/extras/patches/llama_patch.py | 232 ++++++++++++++++ src/llmtuner/hparams/data_args.py | 4 + src/llmtuner/hparams/model_args.py | 6 +- src/llmtuner/tuner/core/loader.py | 56 ++-- 8 files changed, 313 insertions(+), 329 deletions(-) delete mode 100644 src/llmtuner/extras/patches/flash_llama.py create mode 100644 src/llmtuner/extras/patches/llama_patch.py diff --git a/README.md b/README.md index bbcb3b6f..50157e50 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,8 @@ ## Changelog +[23/09/27] We supported **S^2-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA). Try `--shift_attn` argument to enable shift short attention. + [23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#evaluation) to evaluate your models. [23/09/10] We supported using **[FlashAttention](https://github.com/Dao-AILab/flash-attention)** for the LLaMA models. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs. 
@@ -50,7 +52,7 @@ | [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan | | [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 | | [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern | -| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml | +| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B/14B | c_attn | chatml | | [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | xverse | | [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 | | [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - | diff --git a/README_zh.md b/README_zh.md index e9f0ccd7..2b9d8533 100644 --- a/README_zh.md +++ b/README_zh.md @@ -14,6 +14,8 @@ ## 更新日志 +[23/09/27] 我们支持了 [LongLoRA](https://github.com/dvlab-research/LongLoRA) 提出的 **S^2-Attn**。请使用 `--shift_attn` 参数以启用该功能。 + [23/09/23] 我们在项目中集成了 MMLU、C-Eval 和 CMMLU 评估集。使用方法请参阅[此示例](#模型评估)。 [23/09/10] 我们支持了 LLaMA 模型的 **[FlashAttention](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU,请使用 `--flash_attn` 参数以启用 FlashAttention-2(实验性功能)。 @@ -50,7 +52,7 @@ | [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan | | [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 | | [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern | -| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml | +| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B/14B | c_attn | chatml | | [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | xverse | | [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 | | [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - | diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py index 320a54ef..5031a817 100644 --- a/src/llmtuner/dsets/preprocess.py 
+++ b/src/llmtuner/dsets/preprocess.py @@ -22,6 +22,9 @@ def preprocess_dataset( column_names = list(next(iter(dataset)).keys()) template = get_template_and_fix_tokenizer(data_args.template, tokenizer) + if template.efficient_eos and data_args.sft_packing: + raise ValueError("Current template is incompatible with packing.") + def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]: for i in range(len(examples["prompt"])): query, response = examples["prompt"][i], examples["response"][i] @@ -96,6 +99,28 @@ def preprocess_dataset( return model_inputs + def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]: + # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`; labels are an unmasked copy of the inputs + # we do not mask the inputs in packed training. + model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} + input_ids, labels = [], [] + for query, response, history, system in construct_example(examples): + for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, system): + input_ids += source_ids + target_ids + labels += source_ids + target_ids # TODO: try masking source_ids here + + total_length = len(input_ids) + block_size = data_args.cutoff_len + # we drop the small remainder, and if the total_length < block_size, we exclude this batch + total_length = (total_length // block_size) * block_size + # split by chunks of cutoff_len + for i in range(0, total_length, block_size): + model_inputs["input_ids"].append(input_ids[i: i + block_size]) + model_inputs["attention_mask"].append([1] * block_size) + model_inputs["labels"].append(labels[i: i + block_size]) + + return model_inputs + def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]: # build inputs with format `<bos> X` and labels with format `Y <eos>` model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} @@ -166,19 +191,19 @@ def preprocess_dataset( if stage == "pt": dataset = 
dataset.filter(lambda example: example["prompt"]) - preprocess_function = preprocess_pretrain_dataset + preprocess_func = preprocess_pretrain_dataset print_function = print_unsupervised_dataset_example elif stage == "sft" and not training_args.predict_with_generate: dataset = dataset.filter(lambda example: example["prompt"] and example["response"]) - preprocess_function = preprocess_supervised_dataset + preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset print_function = print_supervised_dataset_example elif stage == "rm": dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1) - preprocess_function = preprocess_pairwise_dataset + preprocess_func = preprocess_pairwise_dataset print_function = print_pairwise_dataset_example else: dataset = dataset.filter(lambda example: example["prompt"]) - preprocess_function = preprocess_unsupervised_dataset + preprocess_func = preprocess_unsupervised_dataset print_function = print_unsupervised_dataset_example with training_args.main_process_first(desc="dataset map pre-processing"): @@ -191,7 +216,7 @@ def preprocess_dataset( ) dataset = dataset.map( - preprocess_function, + preprocess_func, batched=True, remove_columns=column_names, **kwargs diff --git a/src/llmtuner/extras/patches/flash_llama.py b/src/llmtuner/extras/patches/flash_llama.py deleted file mode 100644 index 1d6ee66d..00000000 --- a/src/llmtuner/extras/patches/flash_llama.py +++ /dev/null @@ -1,301 +0,0 @@ -# coding=utf-8 -# Modified from: -# [1] https://huggingface.co/Birchlabs/flash_llama/blob/main/modeling_flash_llama.py -# [2] https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama2_flash_attn_monkey_patch.py -# [3] https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/main/modeling_flash_llama.py -# [4] https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py -# With fix from Alex Birch: 
https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/discussions/17 - -import torch -from typing import TYPE_CHECKING, Optional, Tuple -from transformers.utils import logging - -if TYPE_CHECKING: - from transformers.models.llama.configuration_llama import LlamaConfig - -try: - from flash_attn.flash_attn_interface import ( - flash_attn_kvpacked_func, - flash_attn_varlen_kvpacked_func - ) - from flash_attn.bert_padding import pad_input, unpad_input - print(">>>> FlashAttention installed") -except ImportError: - raise ImportError("Please install FlashAttention from https://github.com/Dao-AILab/flash-attention") - -try: - from flash_attn.layers.rotary import apply_rotary_emb_func - print(">>>> Flash RoPE installed") -except ImportError: - raise ImportError("Please install RoPE kernels from https://github.com/Dao-AILab/flash-attention") - - -logger = logging.get_logger(__name__) - - -class LlamaRMSNorm(torch.nn.Module): - - def __init__(self, hidden_size, eps=1e-6): - super().__init__() - self.weight = torch.nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return (self.weight * hidden_states).to(input_dtype) # for fp32 weight - - -class FlashRotaryEmbedding(torch.nn.Module): - - def __init__( - self, - dim: int, - base=10000.0, - interleaved=False, - scale_base=None, - scaling_factor=1.0, - pos_idx_in_fp32=True, - device=None - ): - super().__init__() - self.dim = dim - self.base = float(base) - self.pos_idx_in_fp32 = pos_idx_in_fp32 - # Generate and save the inverse frequency buffer (non trainable) - inv_freq = self._compute_inv_freq(device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.interleaved = interleaved - self.scale_base = scale_base - self.scaling_factor = 
scaling_factor - scale = ( - (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim) - if scale_base is not None else None - ) - self.register_buffer("scale", scale) - - self._seq_len_cached = 0 - self._cos_cached = None - self._sin_cached = None - self._cos_k_cached = None - self._sin_k_cached = None - - def _compute_inv_freq(self, device=None): - return 1 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)) - - def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): - if ( - seqlen > self._seq_len_cached or self._cos_cached.device != device - or self._cos_cached.dtype != dtype - or (self.training and self._cos_cached.is_inference()) - ): - self._seq_len_cached = seqlen - if self.pos_idx_in_fp32: - t = torch.arange(seqlen, device=device, dtype=torch.float32) - t /= self.scaling_factor - if self.inv_freq.dtype != torch.float32: - inv_freq = self.inv_freq.to(torch.float32) - else: - inv_freq = self.inv_freq - else: - t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) - t /= self.scaling_factor - inv_freq = self.inv_freq - freqs = torch.outer(t, inv_freq) - if self.scale is None: - self._cos_cached = torch.cos(freqs).to(dtype) - self._sin_cached = torch.sin(freqs).to(dtype) - else: - power = ( - (torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2) / self.scale_base - ) - scale = self.scale.to(device=power.device) ** power.unsqueeze(-1) - # We want the multiplication by scale to happen in fp32 - self._cos_cached = (torch.cos(freqs) * scale).to(dtype) - self._sin_cached = (torch.sin(freqs) * scale).to(dtype) - self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) - self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) - - def forward(self, q: torch.Tensor, k: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]: - r""" - q: (batch, seqlen, nheads, headdim) - k: (batch, seqlen, nheads, headdim) - 
seqlen_offset: can be used in generation where the qkv being passed in is only the last - token in the batch. - """ - self._update_cos_sin_cache(q.shape[1] + seqlen_offset, device=q.device, dtype=q.dtype) - if self.scale is None: - return apply_rotary_emb_func( - q, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:], - self.interleaved, True # inplace=True - ), apply_rotary_emb_func( - k, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:], - self.interleaved, True # inplace=True - ) - else: - assert False - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - r""" - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, slen, _, num_key_value_heads, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, :, :, None, :].expand(batch, slen, 2, num_key_value_heads, n_rep, head_dim) - return hidden_states.reshape(batch, slen, 2, num_key_value_heads * n_rep, head_dim) - - -class LlamaAttention(torch.nn.Module): - - def __init__(self, config: "LlamaConfig"): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - - self.q_proj = torch.nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = torch.nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.v_proj = torch.nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.o_proj = torch.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - self.register_buffer( - "norm_factor", - torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()), - persistent=False, - ) - - if self.config.rope_scaling is None: - scaling_factor = 1 - else: - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - assert scaling_type == "linear" - - self.rotary_emb = FlashRotaryEmbedding( - self.head_dim, base=10000, interleaved=False, scaling_factor=scaling_factor - ) - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, h_size = hidden_states.size() - - has_layer_past = past_key_value is not None - - if has_layer_past: - past_kv = past_key_value[0] - past_len = past_key_value[1] - else: - past_len = 0 - - q = self.q_proj(hidden_states) - k = self.k_proj(hidden_states) - v = self.v_proj(hidden_states) - - q = q.view(bsz, q_len, self.num_heads, self.head_dim) - k = k.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - v = v.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - - q, k = self.rotary_emb(q, k, past_len) - - kv = torch.stack([k, v], 2) - kv = 
repeat_kv(kv, self.num_key_value_groups) - - # Cache QKV values - if has_layer_past: - new_len = past_len+q.size(1) - if new_len > past_kv.size(1): - past_kv = torch.cat( - [past_kv, torch.empty(bsz, 256, 2, kv.size(3), kv.size(4), dtype=kv.dtype, device=kv.device)], - dim=1 - ) - past_kv[:, past_len:new_len] = kv - kv = past_kv[:, :new_len] - else: - past_kv = kv - - past_key_value = (past_kv, past_len + q.size(1)) if use_cache else None - - if attention_mask is not None: - # varlen, ignore padding tokens, efficient for large batch with many paddings - logger.warning_once("padded sequences is less efficient") - - unpadded_kv, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(kv, attention_mask) - unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask[:, -q.size(1):]) - attn_outputs = flash_attn_varlen_kvpacked_func( - unpadded_q, unpadded_kv, cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, - dropout_p=0.0, softmax_scale=1.0 / self.norm_factor, - causal=(not has_layer_past), return_attn_probs=output_attentions - ) - - attn_output = attn_outputs[0] if output_attentions else attn_outputs - attn_output = pad_input(attn_output, indices_q, bsz, q_len).reshape(bsz, q_len, h_size) - attn_weights = attn_outputs[2] if output_attentions else None - - else: - # no padding tokens, more efficient - attn_outputs = flash_attn_kvpacked_func( - q, kv, dropout_p=0.0, softmax_scale=1.0 / self.norm_factor, - causal=(not has_layer_past), return_attn_probs=output_attentions - ) - attn_output = attn_outputs[0] if output_attentions else attn_outputs - attn_output = attn_output.reshape(bsz, q_len, h_size) - attn_weights = attn_outputs[2] if output_attentions else None - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# Disable the transformation of the attention mask in LlamaModel as flash attention -# takes a boolean key_padding_mask. 
Fills in the past kv length for use in forward. -def _prepare_decoder_attention_mask( - self, attention_mask, input_shape, inputs_embeds, past_key_values_length -): - # [bsz, seq_len] - if past_key_values_length > 0 and attention_mask is not None: - attention_mask = torch.cat( - ( - torch.full( - (input_shape[0], past_key_values_length), - True, - dtype=attention_mask.dtype, - device=attention_mask.device - ), - attention_mask - ), - dim=-1 - ) - - if attention_mask is not None and torch.all(attention_mask): - return None # This uses the faster call when training with full samples - - return attention_mask diff --git a/src/llmtuner/extras/patches/llama_patch.py b/src/llmtuner/extras/patches/llama_patch.py new file mode 100644 index 00000000..ba4e603c --- /dev/null +++ b/src/llmtuner/extras/patches/llama_patch.py @@ -0,0 +1,232 @@ +# coding=utf-8 +# Modified from: +# [1] https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + +import math +import torch +import torch.nn as nn +from typing import Optional, Tuple +from transformers.utils import logging +from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv + +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore + from flash_attn.bert_padding import pad_input, unpad_input # type: ignore +except ImportError: + raise ImportError("Please install FlashAttention from https://github.com/Dao-AILab/flash-attention") + + +logger = logging.get_logger(__name__) + + +class LlamaRMSNorm(nn.Module): + + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance 
+ self.variance_epsilon) + return (self.weight * hidden_states).to(input_dtype) + + +class LlamaShiftShortAttention(LlamaAttention): + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + if getattr(self, "num_key_value_groups"): + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if getattr(self, "shift_ratio", None) and self.training: # shift + group_size = int(q_len * getattr(self, "shift_ratio")) + if q_len % group_size > 0: + raise ValueError("q_len {} should be divisible by group size {}.".format(q_len, group_size)) + num_group = q_len // group_size + for state in 
(query_states, key_states, value_states): + state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim) + state[:, :, self.num_heads//2:] = state[:, :, self.num_heads//2:].roll(-group_size//2, dims=1) + state = state.reshape(bsz * num_group, group_size, self.num_heads, self.head_dim).transpose(1, 2) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + if getattr(self, "shift_ratio", None) and self.training: # shift back + attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim) + attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(group_size//2, dims=1) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaFlashAttention2(LlamaAttention): + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: 
Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # LlamaFlashAttention2 attention does not support output_attentions + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + if getattr(self, "num_key_value_groups"): + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) + key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) + value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) + + if getattr(self, "shift_ratio", None) and self.training: # shift + group_size = int(q_len * getattr(self, 
"shift_ratio")) + if q_len % group_size > 0: + raise ValueError("q_len {} should be divisible by group size {}.".format(q_len, group_size)) + num_group = q_len // group_size + for state in (query_states, key_states, value_states): + state[:, :, self.num_heads//2:] = state[:, :, self.num_heads//2:].roll(-group_size//2, dims=1) + state = state.reshape(bsz * num_group, group_size, self.num_heads, self.head_dim) + + if attention_mask is not None: + logger.warning_once("Padded sequences are less efficient.") + batch_size = query_states.shape[0] + # -q_len: assumes left padding + unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask[:, -q_len:]) + unpadded_k, _, cu_seqlens_k, max_seqlen_k = unpad_input(key_states, attention_mask) + unpadded_v, _, _, _ = unpad_input(value_states, attention_mask) + attn_output_unpad = flash_attn_varlen_func( + unpadded_q, + unpadded_k, + unpadded_v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + dropout_p=0.0, + softmax_scale=None, + causal=True, + ) + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, 0.0, softmax_scale=None, causal=True + ) + + if getattr(self, "shift_ratio", None) and self.training: # shift back + attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim) + attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(group_size//2, dims=1) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Disable the transformation of the attention mask in LlamaModel as flash attention +# takes a boolean padding_mask. Fills in the past kv length for use in forward. 
+def _prepare_decoder_attention_mask( + self, + attention_mask: torch.Tensor, + input_shape: torch.Tensor, + inputs_embeds: torch.Tensor, + past_key_values_length: int +) -> torch.Tensor: + if attention_mask is not None and torch.all(attention_mask): + return None # This uses the faster call when training with full samples + + return attention_mask diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py index 02a603f0..44797ba1 100644 --- a/src/llmtuner/hparams/data_args.py +++ b/src/llmtuner/hparams/data_args.py @@ -90,6 +90,10 @@ class DataArguments: default=0, metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."} ) + sft_packing: Optional[bool] = field( + default=False, + metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."} + ) def init_for_training(self): # support mixing multiple datasets dataset_names = [ds.strip() for ds in self.dataset.split(",")] diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py index 0d7a3d52..e5bbc04c 100644 --- a/src/llmtuner/hparams/model_args.py +++ b/src/llmtuner/hparams/model_args.py @@ -45,7 +45,11 @@ class ModelArguments: ) flash_attn: Optional[bool] = field( default=False, - metadata={"help": "Enable flash attention for faster training."} + metadata={"help": "Enable FlashAttention-2 for faster training."} + ) + shift_attn: Optional[bool] = field( + default=False, + metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."} ) checkpoint_dir: Optional[str] = field( default=None, diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py index 911b0c5d..36525d33 100644 --- a/src/llmtuner/tuner/core/loader.py +++ b/src/llmtuner/tuner/core/loader.py @@ -13,17 +13,19 @@ from transformers import ( PreTrainedModel, PreTrainedTokenizerBase ) +from transformers.models.llama import modeling_llama as LlamaModule from transformers.utils import 
check_min_version from transformers.utils.versions import require_version from trl import AutoModelForCausalLMWithValueHead try: from transformers.integrations import is_deepspeed_zero3_enabled -except ImportError: +except ImportError: # https://github.com/huggingface/transformers/releases/tag/v4.33.1 from transformers.deepspeed import is_deepspeed_zero3_enabled from llmtuner.extras.logging import reset_logging, get_logger from llmtuner.extras.misc import count_parameters +from llmtuner.extras.patches import llama_patch as LlamaPatches from llmtuner.extras.save_and_load import load_valuehead_params from llmtuner.hparams import FinetuningArguments from llmtuner.tuner.core.adapter import init_adapter @@ -73,10 +75,6 @@ def load_model_and_tokenizer( **config_kwargs ) - # Fix tokenizer (for ChatGLM2) - if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__): - tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer) - if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None: model_to_load = model_args.checkpoint_dir[0] else: @@ -84,10 +82,15 @@ def load_model_and_tokenizer( config = AutoConfig.from_pretrained(model_to_load, **config_kwargs) + # Fix tokenizer (for ChatGLM2) + if getattr(config, "model_type", None) == "chatglm": + tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer) + # Fix config (for Qwen) - if hasattr(config, "fp16") and hasattr(config, "bf16"): + if getattr(config, "model_type", None) == "qwen": setattr(config, "fp16", model_args.compute_dtype == torch.float16) setattr(config, "bf16", model_args.compute_dtype == torch.bfloat16) + setattr(config, "fp32", model_args.compute_dtype == torch.float32) # Set RoPE scaling if model_args.rope_scaling is not None: @@ -103,7 +106,6 @@ def load_model_and_tokenizer( require_version("transformers>=4.31.0", "RoPE scaling requires transformers>=4.31.0") if is_trainable: if model_args.rope_scaling == "dynamic": - assert not model_args.flash_attn, 
"Flash attention does not support dynamic rope scaling." logger.warning( "Dynamic NTK may not work well with fine-tuning. " "See: https://github.com/huggingface/transformers/pull/24653" @@ -126,17 +128,23 @@ def load_model_and_tokenizer( else: logger.warning("Current model does not support RoPE scaling.") - # Set flash attention - if model_args.flash_attn and getattr(config, "model_type", None) == "llama": - import transformers.models.llama.modeling_llama as LlamaModule - import llmtuner.extras.patches.flash_llama as FlashLlama - LlamaModule.LlamaRMSNorm = FlashLlama.LlamaRMSNorm - LlamaModule.LlamaAttention = FlashLlama.LlamaAttention - LlamaModule.LlamaModel._prepare_decoder_attention_mask = FlashLlama._prepare_decoder_attention_mask - if not hasattr(config, "num_key_value_heads"): # for LLaMA-1 models - setattr(config, "num_key_value_heads", getattr(config, "num_attention_heads")) - if getattr(config, "pretraining_tp", 1) != 1: - setattr(config, "pretraining_tp", 1) + # Fix RMSNorm in fp32 weight (https://github.com/huggingface/transformers/pull/23535) + if getattr(config, "model_type", None) == "llama": + LlamaModule.LlamaRMSNorm = LlamaPatches.LlamaRMSNorm + + # Set FlashAttention-2 + if model_args.flash_attn: + if getattr(config, "model_type", None) == "llama": + LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2 + LlamaModule.LlamaModel._prepare_decoder_attention_mask = ( + LlamaPatches._prepare_decoder_attention_mask + ) + logger.info("Using FlashAttention-2 for faster training and inference.") + else: + logger.warning("Current model does not support FlashAttention-2.") + elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama": + LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention + logger.warning("Using `--flash_attn` for faster training in large context length.") # Quantization configurations (using bitsandbytes library). 
is_mergeable = True @@ -172,12 +180,20 @@ def load_model_and_tokenizer( **config_kwargs ) - # Disable custom generate method (for Qwen) + # Set shift short attention (S^2-Attn) + if is_trainable and model_args.shift_attn: + if getattr(config, "model_type", None) == "llama": + setattr(model, "shift_ratio", 0.25) + logger.info("Using shift short attention proposed by LongLoRA.") + else: + logger.warning("Current model does not support shift short attention.") + + # Disable custom generate method (for Qwen and Baichuan2) if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__): model.generate = MethodType(PreTrainedModel.generate, model) # Fix LM head (for ChatGLM2) - if not hasattr(model, "lm_head") and hasattr(model, "transformer"): + if getattr(config, "model_type", None) == "chatglm": setattr(model, "lm_head", model.transformer.output_layer) # Register auto class to save the custom code files.