support longlora for main branch

Former-commit-id: 38af076a75c33da26d641780820694e4b7342d92
This commit is contained in:
hiyouga 2024-01-20 19:25:22 +08:00
parent b36dea7e7b
commit 69e8925249
7 changed files with 168 additions and 204 deletions

View File

@ -3,37 +3,24 @@ import torch
import torch.nn as nn import torch.nn as nn
from typing import Optional, Tuple from typing import Optional, Tuple
from transformers.utils import logging from transformers.utils import logging
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb from transformers.models.llama.modeling_llama import (
Cache, LlamaAttention, LlamaFlashAttention2, apply_rotary_pos_emb, repeat_kv
try: )
from transformers.models.llama.modeling_llama import repeat_kv
except ImportError:
print("Please upgrade `transformers`.")
from ..packages import is_flash_attn2_available
if is_flash_attn2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore
from flash_attn.bert_padding import pad_input, unpad_input # type: ignore
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
# Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py # Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
class LlamaShiftShortAttention(LlamaAttention): def llama_torch_attn_forward(
self: "LlamaAttention",
def forward(
self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional["Cache"] = None,
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False,
**kwargs **kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states) query_states = self.q_proj(hidden_states)
@ -46,18 +33,15 @@ class LlamaShiftShortAttention(LlamaAttention):
kv_seq_len = key_states.shape[-2] kv_seq_len = key_states.shape[-2]
if past_key_value is not None: if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2] kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None: # reuse k, v, self_attention if past_key_value is not None:
key_states = torch.cat([past_key_value[0], key_states], dim=2) cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
value_states = torch.cat([past_key_value[1], value_states], dim=2) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
past_key_value = (key_states, value_states) if use_cache else None
if getattr(self, "num_key_value_groups"):
key_states = repeat_kv(key_states, self.num_key_value_groups) key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups)
@ -83,6 +67,7 @@ class LlamaShiftShortAttention(LlamaAttention):
# upcast attention to fp32 # upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :) attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.transpose(1, 2).contiguous()
@ -101,18 +86,16 @@ class LlamaShiftShortAttention(LlamaAttention):
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, past_key_value
class LlamaFlashAttention2(LlamaAttention): # Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
def llama_flash_attn_forward(
def forward( self: "LlamaFlashAttention2",
self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False,
**kwargs **kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# LlamaFlashAttention2 attention does not support output_attentions # LlamaFlashAttention2 attention does not support output_attentions
output_attentions = False output_attentions = False
@ -129,26 +112,15 @@ class LlamaFlashAttention2(LlamaAttention):
kv_seq_len = key_states.shape[-2] kv_seq_len = key_states.shape[-2]
if past_key_value is not None: if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2] kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None: # reuse k, v, self_attention if past_key_value is not None:
key_states = torch.cat([past_key_value[0], key_states], dim=2) cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
value_states = torch.cat([past_key_value[1], value_states], dim=2) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
past_key_value = (key_states, value_states) if use_cache else None
# cast to half precision
input_dtype = query_states.dtype
if input_dtype == torch.float32:
logger.warning_once("The input hidden states seems to be silently casted in float32.")
query_states = query_states.to(self.config.torch_dtype)
key_states = key_states.to(self.config.torch_dtype)
value_states = value_states.to(self.config.torch_dtype)
if getattr(self, "num_key_value_groups", None):
key_states = repeat_kv(key_states, self.num_key_value_groups) key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups)
@ -156,6 +128,8 @@ class LlamaFlashAttention2(LlamaAttention):
key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
dropout_rate = self.attention_dropout if self.training else 0.0
if getattr(self.config, "group_size_ratio", None) and self.training: # shift if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio")) groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
@ -168,30 +142,10 @@ class LlamaFlashAttention2(LlamaAttention):
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states) query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
if attention_mask is not None: if attention_mask is not None:
attention_mask = attention_mask.reshape(bsz * num_groups, groupsz) attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
if attention_mask is not None: attn_output: torch.Tensor = self._flash_attention_forward(
logger.warning_once("Padded sequences are less efficient in FlashAttention.") query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
# -q_len: assumes left padding when q_len != kv_len
unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask[:, -q_len:])
unpadded_k, _, cu_seqlens_k, max_seqlen_k = unpad_input(key_states, attention_mask)
unpadded_v, _, _, _ = unpad_input(value_states, attention_mask)
attn_output_unpad = flash_attn_varlen_func(
unpadded_q,
unpadded_k,
unpadded_v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=0.0,
softmax_scale=None,
causal=True,
)
attn_output = pad_input(attn_output_unpad, indices_q, bsz, q_len)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, 0.0, softmax_scale=None, causal=True
) )
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
@ -209,16 +163,6 @@ class LlamaFlashAttention2(LlamaAttention):
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, past_key_value
# Disable the transformation of the attention mask in LlamaModel as flash attention def apply_llama_patch() -> None:
# takes a boolean padding_mask. Fills in the past kv length for use in forward. LlamaAttention.forward = llama_torch_attn_forward
def _prepare_decoder_attention_mask( LlamaFlashAttention2.forward = llama_flash_attn_forward
self,
attention_mask: torch.Tensor,
input_shape: torch.Tensor,
inputs_embeds: torch.Tensor,
past_key_values_length: int
) -> torch.Tensor:
if attention_mask is not None and torch.all(attention_mask):
return None # This uses the faster call when training with full samples
return attention_mask

View File

@ -1,4 +1,5 @@
import torch import torch
import inspect
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from transformers.integrations import is_deepspeed_zero3_enabled from transformers.integrations import is_deepspeed_zero3_enabled
from peft import PeftModel, TaskType, LoraConfig, get_peft_model from peft import PeftModel, TaskType, LoraConfig, get_peft_model
@ -108,6 +109,9 @@ def init_adapter(
if model_args.use_unsloth: if model_args.use_unsloth:
from unsloth import FastLlamaModel, FastMistralModel # type: ignore from unsloth import FastLlamaModel, FastMistralModel # type: ignore
unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length} unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length}
if "loftq_config" in inspect.signature(FastLlamaModel.get_peft_model).parameters:
unsloth_peft_kwargs["loftq_config"] = {}
if getattr(model.config, "model_type", None) == "llama": if getattr(model.config, "model_type", None) == "llama":
model = FastLlamaModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs) model = FastLlamaModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
elif getattr(model.config, "model_type", None) == "mistral": elif getattr(model.config, "model_type", None) == "mistral":

View File

@ -15,6 +15,7 @@ from ..extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES
from ..extras.logging import get_logger from ..extras.logging import get_logger
from ..extras.misc import get_current_device, infer_optim_dtype from ..extras.misc import get_current_device, infer_optim_dtype
from ..extras.packages import is_flash_attn2_available from ..extras.packages import is_flash_attn2_available
from ..extras.patches.llama_patch import apply_llama_patch
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer from transformers import PretrainedConfig, PreTrainedTokenizer
@ -23,7 +24,7 @@ if TYPE_CHECKING:
logger = get_logger(__name__) logger = get_logger(__name__)
SUPPORTED_CLASS_FOR_S2ATTN = [] # TODO: add llama SUPPORTED_CLASS_FOR_S2ATTN = ["llama"]
def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int): def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int):
@ -39,26 +40,25 @@ def _resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToke
Resize token embeddings. Resize token embeddings.
""" """
if is_deepspeed_zero3_enabled(): if is_deepspeed_zero3_enabled():
import deepspeed import deepspeed # type: ignore
with deepspeed.zero.GatheredParameters(model.get_input_embeddings().weight, modifier_rank=None): params = [model.get_input_embeddings().weight]
current_embedding_size = model.get_input_embeddings().weight.size(0) if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
params.append(model.get_output_embeddings().weight)
context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
else: else:
context_maybe_zero3 = nullcontext()
with context_maybe_zero3:
current_embedding_size = model.get_input_embeddings().weight.size(0) current_embedding_size = model.get_input_embeddings().weight.size(0)
if len(tokenizer) > current_embedding_size: if len(tokenizer) > current_embedding_size:
if not isinstance(model.get_output_embeddings(), torch.nn.Linear): if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
logger.warning("Current model does not support resizing token embeddings.") logger.warning("Current model does not support resizing token embeddings.")
return return
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64) model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
if is_deepspeed_zero3_enabled(): with context_maybe_zero3:
import deepspeed
params = [model.get_input_embeddings().weight]
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
params.append(model.get_output_embeddings().weight)
context = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
else:
context = nullcontext()
with context:
new_embedding_size = model.get_input_embeddings().weight.size(0) new_embedding_size = model.get_input_embeddings().weight.size(0)
num_new_tokens = new_embedding_size - current_embedding_size num_new_tokens = new_embedding_size - current_embedding_size
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens) _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
@ -136,6 +136,7 @@ def _configure_flashattn(config_kwargs: Dict[str, Any]) -> None:
def _configure_longlora(config: "PretrainedConfig") -> None: def _configure_longlora(config: "PretrainedConfig") -> None:
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN: if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
setattr(config, "group_size_ratio", 0.25) setattr(config, "group_size_ratio", 0.25)
apply_llama_patch()
logger.info("Using shift short attention with group_size_ratio=1/4.") logger.info("Using shift short attention with group_size_ratio=1/4.")
else: else:
logger.warning("Current model does not support shift short attention.") logger.warning("Current model does not support shift short attention.")

View File

@ -1,4 +1,5 @@
import torch import torch
from contextlib import nullcontext
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
from transformers import BatchEncoding, Trainer from transformers import BatchEncoding, Trainer
@ -93,7 +94,8 @@ class CustomDPOTrainer(DPOTrainer):
all_logps = self.get_batch_logps( all_logps = self.get_batch_logps(
all_logits, all_logits,
batch["labels"], batch["labels"],
average_log_prob=False average_log_prob=False,
label_pad_token_id=self.label_pad_token_id,
) )
batch_size = batch["input_ids"].size(0) // 2 batch_size = batch["input_ids"].size(0) // 2
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0) chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
@ -118,20 +120,19 @@ class CustomDPOTrainer(DPOTrainer):
) = self.concatenated_forward(model, batch) ) = self.concatenated_forward(model, batch)
with torch.no_grad(): with torch.no_grad():
if self.ref_model is None: if self.ref_model is None:
with self.accelerator.unwrap_model(self.model).disable_adapter(): ref_model = self.model
( ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
reference_chosen_logps,
reference_rejected_logps,
_,
_,
) = self.concatenated_forward(self.model, batch)
else: else:
ref_model = self.ref_model
ref_context = nullcontext()
with ref_context:
( (
reference_chosen_logps, reference_chosen_logps,
reference_rejected_logps, reference_rejected_logps,
_, _,
_, _,
) = self.concatenated_forward(self.ref_model, batch) ) = self.concatenated_forward(ref_model, batch)
losses, chosen_rewards, rejected_rewards = self.dpo_loss( losses, chosen_rewards, rejected_rewards = self.dpo_loss(
policy_chosen_logps, policy_chosen_logps,

View File

@ -95,7 +95,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Accordion(label="RLHF config", open=False) as rlhf_tab: with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
with gr.Row(): with gr.Row():
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1) dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
reward_model = gr.Dropdown(scale=3, allow_custom_value=True) dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01, scale=1)
reward_model = gr.Dropdown(scale=2, allow_custom_value=True)
refresh_btn = gr.Button(scale=1) refresh_btn = gr.Button(scale=1)
refresh_btn.click( refresh_btn.click(
@ -105,8 +106,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
queue=False queue=False
) )
input_elems.update({dpo_beta, reward_model}) input_elems.update({dpo_beta, dpo_ftx, reward_model})
elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, reward_model=reward_model, refresh_btn=refresh_btn)) elem_dict.update(dict(
rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model, refresh_btn=refresh_btn
))
with gr.Row(): with gr.Row():
cmd_preview_btn = gr.Button() cmd_preview_btn = gr.Button()

View File

@ -421,6 +421,16 @@ LOCALES = {
"info": "DPO 损失函数中 beta 超参数大小。" "info": "DPO 损失函数中 beta 超参数大小。"
} }
}, },
"dpo_ftx": {
"en": {
"label": "DPO-ftx weight",
"info": "The weight of SFT loss in the DPO-ftx."
},
"zh": {
"label": "DPO-ftx 权重",
"info": "DPO-ftx 中 SFT 损失的权重大小。"
}
},
"reward_model": { "reward_model": {
"en": { "en": {
"label": "Reward model", "label": "Reward model",

View File

@ -146,6 +146,7 @@ class Runner:
if args["stage"] == "dpo": if args["stage"] == "dpo":
args["dpo_beta"] = get("train.dpo_beta") args["dpo_beta"] = get("train.dpo_beta")
args["dpo_ftx"] = get("train.dpo_ftx")
if get("train.val_size") > 1e-6 and args["stage"] != "ppo": if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
args["val_size"] = get("train.val_size") args["val_size"] = get("train.val_size")