diff --git a/src/llmtuner/extras/patches/llama_patch.py b/src/llmtuner/extras/patches/llama_patch.py
index 42691194..046fe094 100644
--- a/src/llmtuner/extras/patches/llama_patch.py
+++ b/src/llmtuner/extras/patches/llama_patch.py
@@ -55,46 +55,32 @@ class LlamaShiftShortAttention(LlamaAttention):
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        if getattr(self, "shift_ratio", None) and self.training: # shift
-            group_size = int(q_len * getattr(self, "shift_ratio"))
-            if q_len % group_size > 0:
-                raise ValueError("q_len {} should be divisible by group size {}.".format(q_len, group_size))
-            num_group = q_len // group_size
-            for state in (query_states, key_states, value_states):
+        if getattr(self.config, "group_size_ratio", None) and self.training: # shift
+            groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
+            assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
+            num_groups = q_len // groupsz
+            def shift(state: torch.Tensor) -> torch.Tensor:
                 state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
-                state[:, :, self.num_heads//2:] = state[:, :, self.num_heads//2:].roll(-group_size//2, dims=1)
-                state = state.reshape(bsz * num_group, group_size, self.num_heads, self.head_dim).transpose(1, 2)
+                state[:, :, self.num_heads//2:] = state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
+                return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
+
+            query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
+            if attention_mask is not None:
+                attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
 
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-
         if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                )
             attn_weights = attn_weights + attention_mask
 
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
-
-        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
-
+        attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
         attn_output = attn_output.transpose(1, 2).contiguous()
 
-        if getattr(self, "shift_ratio", None) and self.training: # shift back
+        if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
             attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
-            attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(group_size//2, dims=1)
+            attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
         attn_output = self.o_proj(attn_output)
@@ -160,19 +146,21 @@ class LlamaFlashAttention2(LlamaAttention):
         key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
         value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
 
-        if getattr(self, "shift_ratio", None) and self.training: # shift
-            group_size = int(q_len * getattr(self, "shift_ratio"))
-            if q_len % group_size > 0:
-                raise ValueError("q_len {} should be divisible by group size {}.".format(q_len, group_size))
-            num_group = q_len // group_size
-            for state in (query_states, key_states, value_states):
-                state[:, :, self.num_heads//2:] = state[:, :, self.num_heads//2:].roll(-group_size//2, dims=1)
-                state = state.reshape(bsz * num_group, group_size, self.num_heads, self.head_dim)
+        if getattr(self.config, "group_size_ratio", None) and self.training: # shift
+            groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
+            assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
+            num_groups = q_len // groupsz
+            def shift(state: torch.Tensor) -> torch.Tensor:
+                state[:, :, self.num_heads//2:] = state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
+                return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
+
+            query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
+            if attention_mask is not None:
+                attention_mask = attention_mask.reshape(bsz * num_groups, groupsz)
 
         if attention_mask is not None:
             logger.warning_once("Padded sequences are less efficient in FlashAttention.")
-            batch_size = query_states.shape[0]
-            # -q_len: assumes left padding
+            # -q_len: assumes left padding when q_len != kv_len
             unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask[:, -q_len:])
             unpadded_k, _, cu_seqlens_k, max_seqlen_k = unpad_input(key_states, attention_mask)
             unpadded_v, _, _, _ = unpad_input(value_states, attention_mask)
@@ -188,15 +176,15 @@ class LlamaFlashAttention2(LlamaAttention):
                 softmax_scale=None,
                 causal=True,
             )
-            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len)
+            attn_output = pad_input(attn_output_unpad, indices_q, bsz, q_len)
         else:
             attn_output = flash_attn_func(
                 query_states, key_states, value_states, 0.0, softmax_scale=None, causal=True
             )
 
-        if getattr(self, "shift_ratio", None) and self.training: # shift back
+        if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
             attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
-            attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(group_size//2, dims=1)
+            attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
         attn_output = self.o_proj(attn_output)
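
To make the new `shift` helper easier to follow, here is a minimal standalone sketch of the grouping trick, not the repository code: the tensor sizes are invented and the layout matches the FlashAttention branch, i.e. (bsz, seq_len, n_heads, head_dim). Half of the heads are rolled by half a group so that tokens near a group boundary can still exchange information; after attention the roll is undone.

import torch

bsz, q_len, n_heads, head_dim = 2, 16, 4, 8   # toy shapes, not taken from any model
group_size_ratio = 0.25                       # the value the loader sets for LLaMA
groupsz = int(q_len * group_size_ratio)       # 4 tokens per group
num_groups = q_len // groupsz                 # 4 groups per sequence
assert q_len % groupsz == 0

def shift(state: torch.Tensor) -> torch.Tensor:
    # Roll the second half of the heads by half a group, then fold groups into the batch dim.
    state = state.clone()
    state[:, :, n_heads // 2:] = state[:, :, n_heads // 2:].roll(-groupsz // 2, dims=1)
    return state.reshape(bsz * num_groups, groupsz, n_heads, head_dim)

x = torch.randn(bsz, q_len, n_heads, head_dim)
grouped = shift(x)
print(grouped.shape)  # torch.Size([8, 4, 4, 8]): attention now only sees 4-token groups

# "Shift back": undo the grouping and the roll, recovering the original token order.
restored = grouped.reshape(bsz, q_len, n_heads, head_dim).clone()
restored[:, :, n_heads // 2:] = restored[:, :, n_heads // 2:].roll(groupsz // 2, dims=1)
assert torch.equal(restored, x)
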
diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index 0570b33c..820307a7 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -103,7 +103,6 @@ def load_model_and_tokenizer(
                 logger.info("Using dynamic NTK scaling.")
 
         elif hasattr(config, "rope_scaling"): # for LLaMA and Falcon models
-            require_version("transformers>=4.31.0", "RoPE scaling requires transformers>=4.31.0")
             if is_trainable:
                 if model_args.rope_scaling == "dynamic":
                     logger.warning(
@@ -128,7 +127,7 @@ def load_model_and_tokenizer(
         else:
             logger.warning("Current model does not support RoPE scaling.")
 
-    # Set FlashAttention-2 and S^2-Attn
+    # Set FlashAttention-2
     if model_args.flash_attn:
         if getattr(config, "model_type", None) == "llama":
             LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
@@ -136,12 +135,22 @@ def load_model_and_tokenizer(
             LlamaModule.LlamaModel._prepare_decoder_attention_mask = (
                 LlamaPatches._prepare_decoder_attention_mask
             )
             logger.info("Using FlashAttention-2 for faster training and inference.")
+        elif getattr(config, "model_type", None) == "qwen":
+            logger.info("Qwen models automatically enable FlashAttention if installed.")
         else:
             logger.warning("Current model does not support FlashAttention-2.")
     elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
         LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention
         logger.warning("Using `--flash_attn` for faster training in large context length.")
 
+    # Set shift short attention (S^2-Attn)
+    if is_trainable and model_args.shift_attn:
+        if getattr(config, "model_type", None) == "llama":
+            setattr(config, "group_size_ratio", 0.25)
+            logger.info("Using shift short attention with group_size_ratio=1/4.")
+        else:
+            logger.warning("Current model does not support shift short attention.")
+
     # Quantization configurations (using bitsandbytes library).
     is_mergeable = True
     if model_args.quantization_bit is not None:
@@ -176,14 +185,6 @@ def load_model_and_tokenizer(
         **config_kwargs
     )
 
-    # Set shift short attention (S^2-Attn)
-    if is_trainable and model_args.shift_attn:
-        if getattr(config, "model_type", None) == "llama":
-            setattr(model, "shift_ratio", 0.25)
-            logger.info("Using shift short attention proposed by LongLoRA.")
-        else:
-            logger.warning("Current model does not support shift short attention.")
-
     # Disable custom generate method (for Qwen and Baichuan2)
     if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
         model.generate = MethodType(PreTrainedModel.generate, model)
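
Because `group_size_ratio` now lives on the config, which is available to every attention layer, the flag can be set before the model is built and read inside `forward()` as `self.config.group_size_ratio`. A rough sketch of that ordering follows; it mirrors, but is not, the loader itself, the patch-module import path is assumed from the file paths in this diff, and the checkpoint name is hypothetical.

import torch
import transformers.models.llama.modeling_llama as LlamaModule
from transformers import AutoConfig, AutoModelForCausalLM

from llmtuner.extras.patches import llama_patch as LlamaPatches  # import path assumed from the diff

# 1. Swap in the patched attention class before the model is instantiated.
LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention

# 2. Record the S^2-Attn group size on the config so the patched attention
#    can read it via `self.config.group_size_ratio` during training.
config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")  # hypothetical checkpoint
setattr(config, "group_size_ratio", 0.25)

# 3. Build the model; in training mode the patched attention applies the shift.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", config=config, torch_dtype=torch.float16
)
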
diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/tuner/core/parser.py
index d2d0113a..ff3ddff5 100644
--- a/src/llmtuner/tuner/core/parser.py
+++ b/src/llmtuner/tuner/core/parser.py
@@ -149,6 +149,9 @@ def get_train_args(
     if general_args.stage == "ppo" and data_args.streaming:
         raise ValueError("Streaming mode does not suppport PPO training currently.")
 
+    if general_args.stage == "ppo" and model_args.shift_attn:
+        raise ValueError("PPO training is incompatible with S^2-Attn.")
+
     if training_args.max_steps == -1 and data_args.streaming:
         raise ValueError("Please specify `max_steps` in streaming mode.")
 
diff --git a/src/llmtuner/tuner/dpo/workflow.py b/src/llmtuner/tuner/dpo/workflow.py
index 4abd3894..545485c6 100644
--- a/src/llmtuner/tuner/dpo/workflow.py
+++ b/src/llmtuner/tuner/dpo/workflow.py
@@ -29,6 +29,7 @@ def run_dpo(
     dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
     data_collator = DPODataCollatorWithPadding(
         tokenizer=tokenizer,
+        pad_to_multiple_of=4,
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
     )
 
diff --git a/src/llmtuner/tuner/rm/workflow.py b/src/llmtuner/tuner/rm/workflow.py
index edc8e7c5..6d2c4422 100644
--- a/src/llmtuner/tuner/rm/workflow.py
+++ b/src/llmtuner/tuner/rm/workflow.py
@@ -27,7 +27,7 @@ def run_rm(
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
     dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
-    data_collator = PairwiseDataCollatorWithPadding(tokenizer)
+    data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=4)
 
     training_args_dict = training_args.to_dict()
     training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
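
The new `pad_to_multiple_of=4` arguments keep the shift path's divisibility assertion satisfied: with `group_size_ratio=0.25` the patched attention requires `q_len % int(q_len * 0.25) == 0`, and padding every batch to a multiple of 4 guarantees that. A quick plain-Python sanity check of that claim (nothing below comes from the repository):

group_size_ratio = 0.25  # the value the loader sets for LLaMA models

def shift_ok(q_len: int) -> bool:
    # Same computation as the assert in the patched attention forward.
    groupsz = int(q_len * group_size_ratio)
    return groupsz > 0 and q_len % groupsz == 0

# Every length that is a multiple of 4 passes, since 4k % k == 0.
assert all(shift_ok(q_len) for q_len in range(4, 4097, 4))

# Lengths that are not multiples of 4 can fail, e.g. 18 tokens -> groupsz = 4, 18 % 4 != 0.
print(shift_ok(18))  # False
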
diff --git a/src/llmtuner/tuner/sft/workflow.py b/src/llmtuner/tuner/sft/workflow.py
index d45571d2..63070965 100644
--- a/src/llmtuner/tuner/sft/workflow.py
+++ b/src/llmtuner/tuner/sft/workflow.py
@@ -33,6 +33,7 @@ def run_sft(
 
     data_collator = DataCollatorForSeq2Seq(
         tokenizer=tokenizer,
+        pad_to_multiple_of=4, # for shift short attention
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
     )
 
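
As a usage note, the collator change can be checked in isolation: `DataCollatorForSeq2Seq` rounds each batch's padded length up to the requested multiple. The tokenizer name below is only illustrative, and LLaMA tokenizers need a pad token assigned first.

from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # illustrative checkpoint
tokenizer.pad_token = tokenizer.eos_token  # LLaMA tokenizers ship without a pad token

collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, pad_to_multiple_of=4, label_pad_token_id=-100)
features = [
    {"input_ids": [1, 2, 3, 4, 5], "labels": [1, 2, 3, 4, 5]},  # length 5
    {"input_ids": [1, 2, 3], "labels": [1, 2, 3]},              # length 3
]
batch = collator(features)
print(batch["input_ids"].shape)  # torch.Size([2, 8]): padded up to the next multiple of 4
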