mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-06 21:52:50 +08:00)

fix shift short attention

Former-commit-id: ab65c3063b31b9e6a1aeb62c57224c1296ccdadd
parent 8de990b1e9 · commit b6e81a0307
@@ -55,46 +55,32 @@ class LlamaShiftShortAttention(LlamaAttention):
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

-        if getattr(self, "shift_ratio", None) and self.training: # shift
-            group_size = int(q_len * getattr(self, "shift_ratio"))
-            if q_len % group_size > 0:
-                raise ValueError("q_len {} should be divisible by group size {}.".format(q_len, group_size))
-            num_group = q_len // group_size
-            for state in (query_states, key_states, value_states):
-                state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
-                state[:, :, self.num_heads//2:] = state[:, :, self.num_heads//2:].roll(-group_size//2, dims=1)
-                state = state.reshape(bsz * num_group, group_size, self.num_heads, self.head_dim).transpose(1, 2)
+        if getattr(self.config, "group_size_ratio", None) and self.training: # shift
+            groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
+            assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
+            num_groups = q_len // groupsz
+            def shift(state: torch.Tensor) -> torch.Tensor:
+                state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
+                state[:, :, self.num_heads//2:] = state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
+                return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
+
+            query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
+            if attention_mask is not None:
+                attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)

         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-
         if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                )
             attn_weights = attn_weights + attention_mask

         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)

-        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
-
         attn_output = attn_output.transpose(1, 2).contiguous()

-        if getattr(self, "shift_ratio", None) and self.training: # shift back
+        if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
             attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
-            attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(group_size//2, dims=1)
+            attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)

         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
         attn_output = self.o_proj(attn_output)
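The substance of the fix is visible in this hunk: the old eager path looped `for state in (query_states, key_states, value_states)` and assigned to the loop variable, which only rebinds the local name `state` and never replaces the three tensors, so attention silently ran on unshifted, ungrouped states. The new `shift` helper returns the transformed tensor and the results are assigned back explicitly. A minimal standalone sketch of the grouping with made-up shapes (not the repo's code):

    import torch

    bsz, n_heads, q_len, head_dim = 2, 8, 16, 4
    groupsz = int(q_len * 0.25)    # group_size_ratio = 1/4 -> 4-token groups
    num_groups = q_len // groupsz  # requires q_len % groupsz == 0

    def shift(state: torch.Tensor) -> torch.Tensor:
        state = state.transpose(1, 2)  # -> (bsz, q_len, n_heads, head_dim)
        # roll half of the heads by half a group so adjacent groups exchange information
        state[:, :, n_heads // 2:] = state[:, :, n_heads // 2:].roll(-groupsz // 2, dims=1)
        # fold the groups into the batch dimension
        return state.reshape(bsz * num_groups, groupsz, n_heads, head_dim).transpose(1, 2)

    q = torch.randn(bsz, n_heads, q_len, head_dim)
    assert shift(q).shape == (bsz * num_groups, n_heads, groupsz, head_dim)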
@@ -160,19 +146,21 @@ class LlamaFlashAttention2(LlamaAttention):
         key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
         value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)

-        if getattr(self, "shift_ratio", None) and self.training: # shift
-            group_size = int(q_len * getattr(self, "shift_ratio"))
-            if q_len % group_size > 0:
-                raise ValueError("q_len {} should be divisible by group size {}.".format(q_len, group_size))
-            num_group = q_len // group_size
-            for state in (query_states, key_states, value_states):
-                state[:, :, self.num_heads//2:] = state[:, :, self.num_heads//2:].roll(-group_size//2, dims=1)
-                state = state.reshape(bsz * num_group, group_size, self.num_heads, self.head_dim)
+        if getattr(self.config, "group_size_ratio", None) and self.training: # shift
+            groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
+            assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
+            num_groups = q_len // groupsz
+            def shift(state: torch.Tensor) -> torch.Tensor:
+                state[:, :, self.num_heads//2:] = state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
+                return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
+
+            query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
+            if attention_mask is not None:
+                attention_mask = attention_mask.reshape(bsz * num_groups, groupsz)

         if attention_mask is not None:
             logger.warning_once("Padded sequences are less efficient in FlashAttention.")
-            batch_size = query_states.shape[0]
-            # -q_len: assumes left padding
+            # -q_len: assumes left padding when q_len != kv_len
             unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask[:, -q_len:])
             unpadded_k, _, cu_seqlens_k, max_seqlen_k = unpad_input(key_states, attention_mask)
             unpadded_v, _, _, _ = unpad_input(value_states, attention_mask)
@@ -188,15 +176,15 @@ class LlamaFlashAttention2(LlamaAttention):
                 softmax_scale=None,
                 causal=True,
             )
-            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len)
+            attn_output = pad_input(attn_output_unpad, indices_q, bsz, q_len)
         else:
             attn_output = flash_attn_func(
                 query_states, key_states, value_states, 0.0, softmax_scale=None, causal=True
             )

-        if getattr(self, "shift_ratio", None) and self.training: # shift back
+        if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
             attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
-            attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(group_size//2, dims=1)
+            attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)

         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
         attn_output = self.o_proj(attn_output)
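The reworded comment in the flash path flags an easy-to-miss assumption: `attention_mask[:, -q_len:]` only lines up with the current query tokens when padding sits on the left, which is the situation during cached decoding where q_len < kv_len. A toy illustration with a hypothetical mask layout (1 = real token, 0 = pad):

    import torch

    # left-padded batch, kv_len = 6; the model is decoding q_len = 2 new tokens
    attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1],
                                   [0, 1, 1, 1, 1, 1]])
    q_len = 2
    # the trailing q_len columns match the newest tokens only under left padding
    print(attention_mask[:, -q_len:])  # tensor([[1, 1], [1, 1]])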
@@ -103,7 +103,6 @@ def load_model_and_tokenizer(
             logger.info("Using dynamic NTK scaling.")

     elif hasattr(config, "rope_scaling"): # for LLaMA and Falcon models
-        require_version("transformers>=4.31.0", "RoPE scaling requires transformers>=4.31.0")
         if is_trainable:
             if model_args.rope_scaling == "dynamic":
                 logger.warning(
@@ -128,7 +127,7 @@ def load_model_and_tokenizer(
         else:
             logger.warning("Current model does not support RoPE scaling.")

-    # Set FlashAttention-2 and S^2-Attn
+    # Set FlashAttention-2
     if model_args.flash_attn:
         if getattr(config, "model_type", None) == "llama":
             LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
@@ -136,12 +135,22 @@ def load_model_and_tokenizer(
                 LlamaPatches._prepare_decoder_attention_mask
             )
             logger.info("Using FlashAttention-2 for faster training and inference.")
+        elif getattr(config, "model_type", None) == "qwen":
+            logger.info("Qwen models automatically enable FlashAttention if installed.")
         else:
             logger.warning("Current model does not support FlashAttention-2.")
     elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
         LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention
         logger.warning("Using `--flash_attn` for faster training in large context length.")

+    # Set shift short attention (S^2-Attn)
+    if is_trainable and model_args.shift_attn:
+        if getattr(config, "model_type", None) == "llama":
+            setattr(config, "group_size_ratio", 0.25)
+            logger.info("Using shift short attention with group_size_ratio=1/4.")
+        else:
+            logger.warning("Current model does not support shift short attention.")
+
     # Quantization configurations (using bitsandbytes library).
     is_mergeable = True
     if model_args.quantization_bit is not None:
@@ -176,14 +185,6 @@ def load_model_and_tokenizer(
         **config_kwargs
     )

-    # Set shift short attention (S^2-Attn)
-    if is_trainable and model_args.shift_attn:
-        if getattr(config, "model_type", None) == "llama":
-            setattr(model, "shift_ratio", 0.25)
-            logger.info("Using shift short attention proposed by LongLoRA.")
-        else:
-            logger.warning("Current model does not support shift short attention.")
-
     # Disable custom generate method (for Qwen and Baichuan2)
     if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
         model.generate = MethodType(PreTrainedModel.generate, model)
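The two loader hunks above move the S^2-Attn flag from `setattr(model, "shift_ratio", 0.25)` to `setattr(config, "group_size_ratio", 0.25)`, and the move is load-bearing: each patched attention layer checks `getattr(self.config, ...)`, and `self.config` is the config object shared across submodules, while an attribute stuck on the top-level model object is invisible there. A stripped-down sketch of that lookup, using stand-in classes rather than the Transformers API:

    from types import SimpleNamespace

    config = SimpleNamespace(model_type="llama")

    class PatchedAttention:
        def __init__(self, config):
            self.config = config  # every layer holds the same config object

        def shift_enabled(self, training: bool) -> bool:
            return bool(getattr(self.config, "group_size_ratio", None)) and training

    layer = PatchedAttention(config)
    setattr(config, "group_size_ratio", 0.25)  # what the fixed loader does
    assert layer.shift_enabled(training=True)  # the layer now observes the flag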
@@ -149,6 +149,9 @@ def get_train_args(
     if general_args.stage == "ppo" and data_args.streaming:
         raise ValueError("Streaming mode does not suppport PPO training currently.")

+    if general_args.stage == "ppo" and model_args.shift_attn:
+        raise ValueError("PPO training is incompatible with S^2-Attn.")
+
     if training_args.max_steps == -1 and data_args.streaming:
         raise ValueError("Please specify `max_steps` in streaming mode.")
@@ -29,6 +29,7 @@ def run_dpo(
     dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
     data_collator = DPODataCollatorWithPadding(
         tokenizer=tokenizer,
+        pad_to_multiple_of=4,
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
     )
@@ -27,7 +27,7 @@ def run_rm(
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
     dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
-    data_collator = PairwiseDataCollatorWithPadding(tokenizer)
+    data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=4)

     training_args_dict = training_args.to_dict()
     training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
@@ -33,6 +33,7 @@ def run_sft(

     data_collator = DataCollatorForSeq2Seq(
         tokenizer=tokenizer,
+        pad_to_multiple_of=4, # for shift short attention
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
     )
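The `pad_to_multiple_of=4` added to the collators pairs with the `group_size_ratio=0.25` set in the loader: the patch computes `groupsz = int(q_len * 0.25)` and asserts `q_len % groupsz == 0`, which holds exactly when q_len is a multiple of 4. A quick arithmetic check with an illustrative helper (not repo code):

    def divisible_into_groups(q_len: int, ratio: float = 0.25) -> bool:
        groupsz = int(q_len * ratio)
        return q_len % groupsz == 0

    assert divisible_into_groups(16)       # padded: 16 = 4 * 4
    assert not divisible_into_groups(18)   # unpadded: groupsz = 4, 18 % 4 = 2, so the patch's assert fires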