diff --git a/src/llmtuner/extras/callbacks.py b/src/llmtuner/extras/callbacks.py
index 8d7a1161..a3ff8fee 100644
--- a/src/llmtuner/extras/callbacks.py
+++ b/src/llmtuner/extras/callbacks.py
@@ -25,16 +25,16 @@ class SavePeftModelCallback(TrainerCallback):
         r"""
         Event called after a checkpoint save.
         """
-        output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
-        getattr(kwargs.get("model"), "pretrained_model").save_pretrained(output_dir)
-        return control
+        if args.should_save:
+            output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
+            getattr(kwargs.get("model"), "pretrained_model").save_pretrained(output_dir)

-    def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the end of training.
         """
-        getattr(kwargs.get("model"), "pretrained_model").save_pretrained(args.output_dir)
-        return control
+        if args.should_save:
+            getattr(kwargs.get("model"), "pretrained_model").save_pretrained(args.output_dir)


 class LogCallback(TrainerCallback):
diff --git a/src/llmtuner/extras/models/flash_llama.py b/src/llmtuner/extras/models/flash_llama.py
index 8a4dae2a..670c3e8f 100644
--- a/src/llmtuner/extras/models/flash_llama.py
+++ b/src/llmtuner/extras/models/flash_llama.py
@@ -1,22 +1,14 @@
 # coding=utf-8
 # Modified from:
 # [1] https://huggingface.co/Birchlabs/flash_llama/blob/main/modeling_flash_llama.py
-# [2] https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/main/modeling_flash_llama.py
-# [3] https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
+# [2] https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama2_flash_attn_monkey_patch.py
+# [3] https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/main/modeling_flash_llama.py
+# [4] https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
 # With fix from Alex Birch: https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/discussions/17

-from typing import List, Optional, Tuple, Union
-
 import torch
-import torch.nn.functional as F
-import torch.utils.checkpoint
-from torch import nn
-from torch.nn import CrossEntropyLoss
-
-from transformers.activations import ACT2FN
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from typing import Optional, Tuple
+from transformers.utils import logging
 from transformers.models.llama.configuration_llama import LlamaConfig

@@ -43,69 +35,34 @@ except ImportError:
 logger = logging.get_logger(__name__)

-_CONFIG_FOR_DOC = "LlamaConfig"

+class LlamaRMSNorm(torch.nn.Module):

-def rmsnorm_func(hidden_states, weight, variance_epsilon):
-    input_dtype = hidden_states.dtype
-    hidden_states = hidden_states.to(torch.float32)
-    variance = hidden_states.pow(2).mean(-1, keepdim=True)
-    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
-    return (weight * hidden_states).to(input_dtype)
-
-
-class LlamaRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
-        """
-        LlamaRMSNorm is equivalent to T5LayerNorm
-        """
         super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.register_buffer(
-            "variance_epsilon",
-            torch.tensor(eps),
-            persistent=False,
-        )
-
+        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
     def forward(self, hidden_states):
-        return rmsnorm_func(hidden_states, self.weight, self.variance_epsilon)
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return (self.weight * hidden_states).to(input_dtype) # for fp32 weight


 class FlashRotaryEmbedding(torch.nn.Module):
-    """
-    The rotary position embeddings from RoFormer_ (Su et. al).
-    A crucial insight from the method is that the query and keys are
-    transformed by rotation matrices which depend on the relative positions.
-    Other implementations are available in the Rotary Transformer repo_ and in
-    GPT-NeoX_, GPT-NeoX was an inspiration
-
-    .. _RoFormer: https://arxiv.org/abs/2104.09864
-    .. _repo: https://github.com/ZhuiyiTechnology/roformer
-    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
-
-    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
-    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
-    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
-    """
-
-    def __init__(self, dim: int, base=10000.0, interleaved=False, scale_base=None,
-                 scaling_factor=1.0, pos_idx_in_fp32=True, device=None):
-        """
-        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
-            of 1st half and 2nd half (GPT-NeoX style).
-        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
-            otherwise they might be in lower precision.
-            This option was added because previously (before 2023-07-02), when we construct
-            the position indices, we use the dtype of self.inv_freq. In most cases this would
-            be fp32, but if the model is trained in pure bf16 (not mixed precision), then
-            self.inv_freq would be bf16, and the position indices are also in bf16.
-            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
-            embeddings for some positions will coincide.
-            To maintain compatibility with models previously trained in pure bf16,
-            we add this option.
-        scaling_factor: RotaryEmbedding extended with linear scaling.
- """ + def __init__( + self, + dim: int, + base=10000.0, + interleaved=False, + scale_base=None, + scaling_factor=1.0, + pos_idx_in_fp32=True, + device=None + ): super().__init__() self.dim = dim self.base = float(base) @@ -116,8 +73,10 @@ class FlashRotaryEmbedding(torch.nn.Module): self.interleaved = interleaved self.scale_base = scale_base self.scaling_factor = scaling_factor - scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) - / (1.4 * dim) if scale_base is not None else None) + scale = ( + (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim) + if scale_base is not None else None + ) self.register_buffer("scale", scale) self._seq_len_cached = 0 @@ -127,28 +86,18 @@ class FlashRotaryEmbedding(torch.nn.Module): self._sin_k_cached = None def _compute_inv_freq(self, device=None): - return 1 / (self.base ** (torch.arange(0, self.dim, 2, device=device, - dtype=torch.float32) / self.dim)) - + return 1 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)) def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): - # Reset the tables if the sequence length has changed, - # if we're on a new device (possibly due to tracing for instance), - # or if we're switching from inference mode to training - if (seqlen > self._seq_len_cached or self._cos_cached.device != device + if ( + seqlen > self._seq_len_cached or self._cos_cached.device != device or self._cos_cached.dtype != dtype - or (self.training and self._cos_cached.is_inference())): + or (self.training and self._cos_cached.is_inference()) + ): self._seq_len_cached = seqlen - # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 - # And the output of arange can be quite large, so bf16 would lose a lot of precision. - # However, for compatibility reason, we add an option to use the dtype of self.inv_freq. if self.pos_idx_in_fp32: t = torch.arange(seqlen, device=device, dtype=torch.float32) t /= self.scaling_factor - # We want fp32 here as well since inv_freq will be multiplied with t, and the output - # will be large. Having it in bf16 will lose a lot of precision and cause the - # cos & sin output to change significantly. 
-                # We want to recompute self.inv_freq if it was not loaded in fp32
                 if self.inv_freq.dtype != torch.float32:
                     inv_freq = self.inv_freq.to(torch.float32)
                 else:
@@ -157,15 +106,14 @@ class FlashRotaryEmbedding(torch.nn.Module):
                 t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                 t /= self.scaling_factor
                 inv_freq = self.inv_freq
-            # Don't do einsum, it converts fp32 to fp16 under AMP
-            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
             freqs = torch.outer(t, inv_freq)
             if self.scale is None:
                 self._cos_cached = torch.cos(freqs).to(dtype)
                 self._sin_cached = torch.sin(freqs).to(dtype)
             else:
-                power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
-                          - seqlen // 2) / self.scale_base)
+                power = (
+                    (torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2) / self.scale_base
+                )
                 scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
                 # We want the multiplication by scale to happen in fp32
                 self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
@@ -174,7 +122,7 @@ class FlashRotaryEmbedding(torch.nn.Module):
                 self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

     def forward(self, q: torch.Tensor, k: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
+        r"""
         q: (batch, seqlen, nheads, headdim)
         k: (batch, seqlen, nheads, headdim)
         seqlen_offset: can be used in generation where the qkv being passed in is only the last
@@ -193,23 +141,8 @@ class FlashRotaryEmbedding(torch.nn.Module):
             assert False


-class LlamaMLP(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size
-        self.intermediate_size = config.intermediate_size
-        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-        self.act_fn = ACT2FN[config.hidden_act]
-
-    def forward(self, x):
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-
-
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
+    r"""
     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
     """
@@ -220,10 +153,9 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, slen, 2, num_key_value_heads * n_rep, head_dim)


-class LlamaAttention(nn.Module):
-    """Multi-headed attention from 'Attention Is All You Need' paper"""
+class LlamaAttention(torch.nn.Module):

-    def __init__(self, config: LlamaConfig):
+    def __init__(self, config: "LlamaConfig"):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
@@ -238,10 +170,11 @@ class LlamaAttention(nn.Module):
                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                 f" and `num_heads`: {self.num_heads})."
             )
-        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
-        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+        self.q_proj = torch.nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+        self.k_proj = torch.nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+        self.v_proj = torch.nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+        self.o_proj = torch.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

         self.register_buffer(
             "norm_factor",
@@ -254,10 +187,10 @@ class LlamaAttention(nn.Module):
         else:
             scaling_type = self.config.rope_scaling["type"]
             scaling_factor = self.config.rope_scaling["factor"]
-            assert scaling_type == 'linear'
-
+            assert scaling_type == "linear"
+
         self.rotary_emb = FlashRotaryEmbedding(
-            self.head_dim, base=10000, interleaved=False, scaling_factor=scaling_factor,
+            self.head_dim, base=10000, interleaved=False, scaling_factor=scaling_factor
         )

     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
@@ -270,8 +203,7 @@ class LlamaAttention(nn.Module):
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: bool = False,
-        use_cache: bool = False,
-        is_padded_inputs: Optional[bool] = False,
+        use_cache: bool = False
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

         bsz, q_len, h_size = hidden_states.size()
@@ -290,9 +222,9 @@ class LlamaAttention(nn.Module):
         q = q.view(bsz, q_len, self.num_heads, self.head_dim)
         k = k.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
         v = v.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
-
+
         q, k = self.rotary_emb(q, k, past_len)
-
+
         kv = torch.stack([k, v], 2)
         kv = repeat_kv(kv, self.num_key_value_groups)
@@ -300,27 +232,26 @@ class LlamaAttention(nn.Module):
         if has_layer_past:
             new_len = past_len+q.size(1)
             if new_len > past_kv.size(1):
-                past_kv = torch.cat([past_kv, torch.empty(bsz, 256, 2, kv.size(3), kv.size(4), dtype=kv.dtype, device=kv.device)], 1)
+                past_kv = torch.cat(
+                    [past_kv, torch.empty(bsz, 256, 2, kv.size(3), kv.size(4), dtype=kv.dtype, device=kv.device)], 1
+                )
             past_kv[:, past_len:new_len] = kv
             kv = past_kv[:, :new_len]
         else:
             past_kv = kv

-        past_key_value = (past_kv, past_len+q.size(1)) if use_cache else None
-
-        if is_padded_inputs:
+        past_key_value = (past_kv, past_len + q.size(1)) if use_cache else None
+        if attention_mask is not None:
             # varlen, ignore padding tokens, efficient for large batch with many paddings
-            logger.warning_once("padded")
+            logger.warning_once("padded sequences are less efficient")

-            assert attention_mask is not None
-
             unpadded_kv, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(kv, attention_mask)
             unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask[:, -q.size(1):])
             attn_outputs = flash_attn_varlen_kvpacked_func(
-                unpadded_q, unpadded_kv, cu_seqlens_q, cu_seqlens_k,
+                unpadded_q, unpadded_kv, cu_seqlens_q, cu_seqlens_k,
                 max_seqlen_q, max_seqlen_k,
-                dropout_p=0.0, softmax_scale=1.0/self.norm_factor,
+                dropout_p=0.0, softmax_scale=1.0/self.norm_factor,
                 causal=(not has_layer_past), return_attn_probs=output_attentions
             )
@@ -329,14 +260,13 @@ class LlamaAttention(nn.Module):
                 attn_output, indices_q, bsz, q_len
             ).reshape(bsz, q_len, h_size)
             attn_weights = attn_outputs[2] if output_attentions else None
-
+
         else:
-            # no padding tokens, more efficient
-
             attn_outputs = flash_attn_kvpacked_func(
-                q, kv, dropout_p=0.0, softmax_scale=1.0/self.norm_factor, causal=(not has_layer_past), return_attn_probs=output_attentions)
-
+                q, kv, dropout_p=0.0, softmax_scale=1.0/self.norm_factor,
+                causal=(not has_layer_past), return_attn_probs=output_attentions
+            )
             attn_output = attn_outputs[0] if output_attentions else attn_outputs
             attn_output = attn_output.reshape(bsz, q_len, h_size)
             attn_weights = attn_outputs[2] if output_attentions else None
@@ -349,377 +279,27 @@ class LlamaAttention(nn.Module):
         return attn_output, attn_weights, past_key_value


-class LlamaDecoderLayer(nn.Module):
-    def __init__(self, config: LlamaConfig):
-        super().__init__()
-        self.hidden_size = config.hidden_size
-        self.self_attn = LlamaAttention(config=config)
-        self.mlp = LlamaMLP(config)
-        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        is_padded_inputs: Optional[bool] = False,
-        output_attentions: Optional[bool] = False,
-        use_cache: Optional[bool] = False,
-    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-        """
-        Args:
-            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
-            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
-                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            use_cache (`bool`, *optional*):
-                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
-                (see `past_key_values`).
-            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
-        """
-
-        residual = hidden_states
-
-        hidden_states = self.input_layernorm(hidden_states)
-
-        # Self Attention
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
-            hidden_states=hidden_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
-            is_padded_inputs=is_padded_inputs,
-        )
-        hidden_states = residual + hidden_states
-
-        # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-
-        outputs = (hidden_states,)
-
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        if use_cache:
-            outputs += (present_key_value,)
-
-        return outputs
-
-
-LLAMA_START_DOCSTRING, LLAMA_INPUTS_DOCSTRING = "", ""
-
-
-@add_start_docstrings(
-    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
-    LLAMA_START_DOCSTRING,
-)
-class LlamaPreTrainedModel(PreTrainedModel):
-    config_class = LlamaConfig
-    base_model_prefix = "model"
-    supports_gradient_checkpointing = True
-    _no_split_modules = ["LlamaDecoderLayer"]
-    _skip_keys_device_placement = "past_key_values"
-
-    def _init_weights(self, module):
-        std = self.config.initializer_range
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, LlamaModel):
-            module.gradient_checkpointing = value
-
-
-@add_start_docstrings(
-    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
-    LLAMA_START_DOCSTRING,
-)
-class LlamaModel(LlamaPreTrainedModel):
-    """
-    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
-
-    Args:
-        config: LlamaConfig
-    """
-
-    def __init__(self, config: LlamaConfig):
-        super().__init__(config)
-        self.padding_idx = config.pad_token_id
-        self.vocab_size = config.vocab_size
-
-        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
-        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
-        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
-        self.gradient_checkpointing = False
-        # Initialize weights and apply final processing
-        self.post_init()
-
-    def get_input_embeddings(self):
-        return self.embed_tokens
-
-    def set_input_embeddings(self, value):
-        self.embed_tokens = value
-
-    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        is_padded_inputs: Optional[bool] = False,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        # retrieve input_ids and inputs_embeds
-        if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-        elif input_ids is not None:
-            batch_size, seq_length = input_ids.shape
-        elif inputs_embeds is not None:
-            batch_size, seq_length, _ = inputs_embeds.shape
-        else:
-            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-
-        seq_length_with_past = seq_length
-        past_key_values_length = 0
-
-        if past_key_values is not None:
-            past_key_values_length = past_key_values[0][0].shape[2]
-            seq_length_with_past = seq_length_with_past + past_key_values_length
-
-        position_ids = None
-
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
-
-        hidden_states = inputs_embeds
-
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-        next_decoder_cache = () if use_cache else None
-
-        for idx, decoder_layer in enumerate(self.layers):
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            past_key_value = past_key_values[idx] if past_key_values is not None else None
-
-            if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, None)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
-                    hidden_states,
-                    attention_mask,
-                    position_ids,
-                    None,
-                    is_padded_inputs
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_value,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    is_padded_inputs=is_padded_inputs,
-                )
-
-            hidden_states = layer_outputs[0]
-
-            if use_cache:
-                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-        hidden_states = self.norm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        next_cache = next_decoder_cache if use_cache else None
-        if not return_dict:
-            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
+# Disable the transformation of the attention mask in LlamaModel as flash attention
+# takes a boolean key_padding_mask. Fills in the past kv length for use in forward.
+def _prepare_decoder_attention_mask(
+    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+):
+    # [bsz, seq_len]
+    if past_key_values_length > 0 and attention_mask is not None:
+        attention_mask = torch.cat(
+            (
+                torch.full(
+                    (input_shape[0], past_key_values_length),
+                    True,
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device
+                ),
+                attention_mask
+            ),
+            dim=-1
         )
+
+    if attention_mask is not None and torch.all(attention_mask):
+        return None # This uses the faster call when training with full samples

-class LlamaForCausalLM(LlamaPreTrainedModel):
-    _tied_weights_keys = ["lm_head.weight"]
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.model = LlamaModel(config)
-        self.vocab_size = config.vocab_size
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-    def get_input_embeddings(self):
-        return self.model.embed_tokens
-
-    def set_input_embeddings(self, value):
-        self.model.embed_tokens = value
-
-    def get_output_embeddings(self):
-        return self.lm_head
-
-    def set_output_embeddings(self, new_embeddings):
-        self.lm_head = new_embeddings
-
-    def set_decoder(self, decoder):
-        self.model = decoder
-
-    def get_decoder(self):
-        return self.model
-
-    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        is_padded_inputs: Optional[bool] = None,
-    ) -> Union[Tuple, CausalLMOutputWithPast]:
-
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        is_padded_inputs = ((attention_mask is not None) and (not attention_mask.all().item()))
-
-        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-        outputs: "CausalLMOutputWithPast" = self.model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            is_padded_inputs=is_padded_inputs,
-        )
-
-        hidden_states = outputs[0]
-        logits = self.lm_head(hidden_states)
-        logits = logits.float()
-
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
-
-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
-
-        return CausalLMOutputWithPast(
-            loss=loss,
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
-
-    def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
-    ):
-        if past_key_values:
-            input_ids = input_ids[:, -1:]
-
-        position_ids = kwargs.get("position_ids", None)
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids}
-
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask,
-                "is_padded_inputs": ((attention_mask is not None) and (not attention_mask.all().item()))
-            }
-        )
-        return model_inputs
-
-    @staticmethod
-    def _reorder_cache(past_key_values, beam_idx):
-        reordered_past = ()
-        for layer_past in past_key_values:
-            reordered_past += (
-                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
-            )
-        return reordered_past
+    return attention_mask
diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index 820a5714..656d3918 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -132,8 +132,11 @@ def load_model_and_tokenizer(
     # Set flash attention
     if model_args.flash_attn and getattr(config, "model_type", None) == "llama":
-        from llmtuner.extras.models.flash_llama import LlamaForCausalLM
-        transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
+        import transformers.models.llama.modeling_llama as LlamaModule
+        from llmtuner.extras.models.flash_llama import LlamaRMSNorm, LlamaAttention, _prepare_decoder_attention_mask
+        LlamaModule.LlamaRMSNorm = LlamaRMSNorm
+        LlamaModule.LlamaAttention = LlamaAttention
+        LlamaModule.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
         if not hasattr(config, "num_key_value_heads"):
             setattr(config, "num_key_value_heads", getattr(config, "num_attention_heads"))
         if getattr(config, "pretraining_tp", 1) != 1:
diff --git a/src/llmtuner/tuner/dpo/trainer.py b/src/llmtuner/tuner/dpo/trainer.py
index 0036fe0f..c1d2f054 100644
--- a/src/llmtuner/tuner/dpo/trainer.py
+++ b/src/llmtuner/tuner/dpo/trainer.py
@@ -26,6 +26,7 @@ class CustomDPOTrainer(DPOTrainer):
         if ref_model is not None:
             disable_dropout_in_model(ref_model)

+        self.is_encoder_decoder = model.config.is_encoder_decoder
         self.ref_model = ref_model
         self.use_dpo_data_collator = True # hack to avoid warning
         self.label_pad_token_id = IGNORE_INDEX
diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py
index 49d2d1a9..95db49c2 100644
--- a/src/llmtuner/tuner/ppo/trainer.py
+++ b/src/llmtuner/tuner/ppo/trainer.py
@@ -5,6 +5,7 @@ from tqdm import tqdm
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

 from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 from trl import PPOTrainer
 from trl.core import LengthSampler, PPODecorators, logprobs_from_logits

@@ -96,7 +97,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             # Cast to inference mode
             unwrapped_model.gradient_checkpointing_disable()
             unwrapped_model.config.use_cache = True
-            unwrapped_model, layer_norm_params = cast_layernorm_dtype(unwrapped_model, self.compute_dtype)
             self.model.eval()

             # Get inputs
@@ -107,7 +107,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             # Cast to training mode
             unwrapped_model.gradient_checkpointing_enable()
             unwrapped_model.config.use_cache = False
-            unwrapped_model, _ = cast_layernorm_dtype(unwrapped_model, self.compute_dtype, layer_norm_params)
             self.model.train()

             # Run PPO step
@@ -134,7 +133,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                 reward_meter.reset()

             if (step+1) % self.args.save_steps == 0: # save checkpoint
-                self.save_model(os.path.join(self.args.output_dir, f"checkpoint-{step+1}"))
+                self.save_model(os.path.join(
+                    self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step)
+                ))
+                self.save_callback.on_save(
+                    self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
+                )

             if self.control.should_epoch_stop or self.control.should_training_stop:
                 break
@@ -165,8 +169,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         )

         input_ids = batch["input_ids"]
+        self.model, layer_norm_params = cast_layernorm_dtype(self.model, self.compute_dtype)
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
         response: torch.Tensor = unwrapped_model.generate(**gen_kwargs)
+        self.model, _ = cast_layernorm_dtype(self.model, self.compute_dtype, layer_norm_params)

         query, response = input_ids.detach().cpu(), response[:, input_ids.size(-1):].detach().cpu()
         queries, responses = [], []
@@ -294,6 +300,3 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         """
         if self.args.should_save:
             self._save(output_dir)
-            self.save_callback.on_save(
-                self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
-            )
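
Note (illustrative, not part of the patch): the loader.py hunk above swaps in the flash-attention pieces by monkey patching transformers.models.llama.modeling_llama instead of replacing the whole LlamaForCausalLM. A minimal sketch of how such a patch is meant to take effect follows; the checkpoint name is a placeholder and load_model_and_tokenizer already performs the equivalent steps. The module-level names have to be replaced before the model is built, because LlamaDecoderLayer resolves LlamaAttention and LlamaRMSNorm from that module when it is constructed, while _prepare_decoder_attention_mask is overridden on the LlamaModel class itself.

import torch
import transformers.models.llama.modeling_llama as LlamaModule
from transformers import AutoModelForCausalLM
from llmtuner.extras.models.flash_llama import LlamaRMSNorm, LlamaAttention, _prepare_decoder_attention_mask

# Patch before instantiation: decoder layers pick up the flash-attention
# classes from the module namespace when the model is constructed.
LlamaModule.LlamaRMSNorm = LlamaRMSNorm
LlamaModule.LlamaAttention = LlamaAttention
LlamaModule.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder checkpoint, for illustration only
    torch_dtype=torch.bfloat16,
)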